Merge pull request #25 from kozistr/feature/diffgrad-optimizer
[Feature] Implement DiffGrad optimizer
kozistr authored Sep 23, 2021
2 parents 4735dce + cc9d942 commit 5113c54
Showing 19 changed files with 368 additions and 179 deletions.
4 changes: 3 additions & 1 deletion .pylintrc
@@ -57,7 +57,9 @@ disable=
fixme,
import-outside-toplevel,
consider-using-enumerate,

duplicate-code,
too-many-branches,
too-many-statements,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
91 changes: 67 additions & 24 deletions README.rst
@@ -4,6 +4,7 @@ pytorch-optimizer
| |workflow| |Documentation Status| |PyPI version| |PyPi download| |black|
| A bunch of optimizer implementations in PyTorch with clean code and strict types, plus several useful optimization ideas.
| Most of the implementations are based on the original papers, but with some tweaks added.
| Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
Documentation
@@ -53,6 +54,8 @@ Supported Optimizers
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | `github <https://github.com/clovaai/AdamP>`__ | `https://arxiv.org/abs/2006.08217 <https://arxiv.org/abs/2006.08217>`__ |
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| diffGrad | *An Optimization Method for Convolutional Neural Networks* | `github <https://github.com/shivram1987/diffGrad>`__ | `https://arxiv.org/abs/1909.11015v3 <https://arxiv.org/abs/1909.11015v3>`__ |
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | `github <https://github.com/facebookresearch/madgrad>`__ | `https://arxiv.org/abs/2101.11075 <https://arxiv.org/abs/2101.11075>`__ |
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | `github <https://github.com/LiyuanLucasLiu/RAdam>`__ | `https://arxiv.org/abs/1908.03265 <https://arxiv.org/abs/1908.03265>`__ |
@@ -70,42 +73,51 @@ of the ideas are applied in ``Ranger21`` optimizer.

Also, most of the captures are taken from ``Ranger21`` paper.

Adaptive Gradient Clipping (AGC)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+------------------------------------------+-------------------------------------+--------------------------------------------+
| `Adaptive Gradient Clipping`_ | `Gradient Centralization`_ | `Softplus Transformation`_ |
+------------------------------------------+-------------------------------------+--------------------------------------------+
| `Gradient Normalization`_ | `Norm Loss`_ | `Positive-Negative Momentum`_ |
+------------------------------------------+-------------------------------------+--------------------------------------------+
| `Linear learning rate warmup`_ | `Stable weight decay`_ | `Explore-exploit learning rate schedule`_ |
+------------------------------------------+-------------------------------------+--------------------------------------------+
| `Lookahead`_ | `Chebyshev learning rate schedule`_ | `(Adaptive) Sharpness-Aware Minimization`_ |
+------------------------------------------+-------------------------------------+--------------------------------------------+
| `On the Convergence of Adam and Beyond`_ | | |
+------------------------------------------+-------------------------------------+--------------------------------------------+

Adaptive Gradient Clipping
--------------------------

| This idea was originally proposed in the ``NFNet (Normalizer-Free Network)`` paper.
| AGC (Adaptive Gradient Clipping) clips gradients based on the ``unit-wise ratio of gradient norms to parameter norms``.
| ``AGC (Adaptive Gradient Clipping)`` clips gradients based on the ``unit-wise ratio of gradient norms to parameter norms``.
- code :
`github <https://github.com/deepmind/deepmind-research/tree/master/nfnets>`__
- code : `github <https://github.com/deepmind/deepmind-research/tree/master/nfnets>`__
- paper : `arXiv <https://arxiv.org/abs/2102.06171>`__
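
A minimal sketch of the unit-wise clipping rule described above (``clip_factor`` and ``eps`` are illustrative defaults, not necessarily those of this repository's ``agc`` helper):

::

    import torch

    def adaptive_clip_grad_(param: torch.Tensor, clip_factor: float = 1e-2, eps: float = 1e-3) -> None:
        """Unit-wise AGC: scale down the gradient of any output unit whose
        gradient norm exceeds ``clip_factor`` times its parameter norm."""
        if param.grad is None:
            return
        if param.ndim > 1:  # conv / linear weights: one "unit" per output row
            p_norm = param.detach().flatten(1).norm(dim=1)
            g_norm = param.grad.detach().flatten(1).norm(dim=1)
        else:  # biases and other 1-d parameters
            p_norm = param.detach().abs()
            g_norm = param.grad.detach().abs()
        max_norm = p_norm.clamp(min=eps) * clip_factor
        scale = (max_norm / g_norm.clamp(min=1e-6)).clamp(max=1.0)
        if param.ndim > 1:
            scale = scale.view(-1, *([1] * (param.ndim - 1)))
        param.grad.detach().mul_(scale)  # clip in place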

Gradient Centralization (GC)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Gradient Centralization
-----------------------

+-----------------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/gradient_centralization.png |
+-----------------------------------------------------------------------------------------------------------------+

Gradient Centralization (GC) operates directly on gradients by
centralizing the gradient to have zero mean.
``Gradient Centralization (GC)`` operates directly on gradients by centralizing the gradient to have zero mean.

- code :
`github <https://github.com/Yonghongwei/Gradient-Centralization>`__
- code : `github <https://github.com/Yonghongwei/Gradient-Centralization>`__
- paper : `arXiv <https://arxiv.org/abs/2004.01461>`__
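
The repository exports a ``centralize_gradient`` helper; the sketch below shows the same zero-mean idea with an assumed, simplified signature:

::

    import torch

    def centralize_gradient_sketch(grad: torch.Tensor) -> torch.Tensor:
        """Remove the mean over all non-output dimensions so that each output
        unit's gradient slice has zero mean (1-d parameters are left untouched)."""
        if grad.ndim > 1:
            grad = grad - grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True)
        return grad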

Softplus Transformation
~~~~~~~~~~~~~~~~~~~~~~~
-----------------------

Running the final variance denominator through the softplus function lifts extremely tiny values, keeping them numerically viable.

- paper : `arXiv <https://arxiv.org/abs/1908.00700>`__
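
A sketch of the transformation inside an Adam-style update, assuming a second-moment buffer ``exp_avg_sq`` (the ``beta=50`` default is an assumption, not taken from this repository):

::

    import torch
    import torch.nn.functional as F

    def softplus_denom(exp_avg_sq: torch.Tensor, beta: float = 50.0) -> torch.Tensor:
        # replaces the usual ``exp_avg_sq.sqrt().add_(eps)`` denominator:
        # softplus smoothly floors tiny values instead of adding a constant eps
        return F.softplus(exp_avg_sq.sqrt(), beta=beta)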

Gradient Normalization
~~~~~~~~~~~~~~~~~~~~~~
----------------------

Norm Loss
~~~~~~~~~
---------

+---------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/norm_loss.png |
@@ -114,7 +126,7 @@ Norm Loss
- paper : `arXiv <https://arxiv.org/abs/2103.06583>`__

Positive-Negative Momentum
~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------

+--------------------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/positive_negative_momentum.png |
@@ -123,8 +135,8 @@ Positive-Negative Momentum
- code : `github <https://github.com/zeke-xie/Positive-Negative-Momentum>`__
- paper : `arXiv <https://arxiv.org/abs/2103.17182>`__

Linear learning-rate warm-up
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Linear learning rate warmup
---------------------------

+----------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/linear_lr_warmup.png |
@@ -133,7 +145,7 @@ Linear learning-rate warm-up
- paper : `arXiv <https://arxiv.org/abs/1910.04209>`__

Stable weight decay
~~~~~~~~~~~~~~~~~~~
-------------------

+-------------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/stable_weight_decay.png |
@@ -142,8 +154,8 @@ Stable weight decay
- code : `github <https://github.com/zeke-xie/stable-weight-decay-regularization>`__
- paper : `arXiv <https://arxiv.org/abs/2011.11152>`__

Explore-exploit learning-rate schedule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Explore-exploit learning rate schedule
--------------------------------------

+---------------------------------------------------------------------------------------------------------------------+
| .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/explore_exploit_lr_schedule.png |
@@ -153,7 +165,7 @@ Explore-exploit learning-rate schedule
- paper : `arXiv <https://arxiv.org/abs/2003.03977>`__

Lookahead
~~~~~~~~~
---------

| ``k`` steps forward, 1 step back. ``Lookahead`` keeps an exponential moving average of the weights that is
| updated and substituted for the current weights every ``k_{lookahead}`` steps (5 by default).
@@ -162,14 +174,14 @@ Lookahead
- paper : `arXiv <https://arxiv.org/abs/1907.08610v2>`__
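
A minimal sketch of the synchronization step (the repository ships a ``Lookahead`` wrapper class; the helper name and ``alpha`` default here are illustrative):

::

    import torch

    def lookahead_sync_(fast_params, slow_params, alpha: float = 0.5) -> None:
        """Called once every ``k`` inner-optimizer steps: pull the slow weights
        toward the fast weights, then copy them back ("1 step back")."""
        with torch.no_grad():
            for fast, slow in zip(fast_params, slow_params):
                slow.add_(fast - slow, alpha=alpha)  # slow += alpha * (fast - slow)
                fast.copy_(slow)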

Chebyshev learning rate schedule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------

Acceleration via Fractal Learning Rate Schedules

- paper : `arXiv <https://arxiv.org/abs/2103.01338v1>`__

(Adaptive) Sharpness-Aware Minimization (A/SAM)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(Adaptive) Sharpness-Aware Minimization
---------------------------------------

| Sharpness-Aware Minimization (SAM) simultaneously minimizes loss value and loss sharpness.
| In particular, it seeks parameters that lie in neighborhoods having uniformly low loss.
@@ -178,6 +190,11 @@ Acceleration via Fractal Learning Rate Schedules
- ASAM paper : `paper <https://arxiv.org/abs/2102.11600>`__
- A/SAM code : `github <https://github.com/davda54/sam>`__
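
A self-contained sketch of the two-pass update, written without assuming this repository's ``SAM`` class API (``rho`` is the neighborhood radius):

::

    import torch

    def sam_update(model, loss_fn, data, target, base_optimizer, rho: float = 0.05) -> None:
        """(1) climb to the approximate worst point within an L2 ball of radius ``rho``,
        (2) compute the gradient there and apply it at the original weights."""
        loss_fn(model(data), target).backward()  # gradient at the current weights

        grads = [p.grad for p in model.parameters() if p.grad is not None]
        grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)
        scale = rho / (grad_norm + 1e-12)

        perturbations = []
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is None:
                    perturbations.append(None)
                    continue
                e = p.grad * scale
                p.add_(e)                # move toward the sharpest nearby point
                perturbations.append(e)
        model.zero_grad()

        loss_fn(model(data), target).backward()  # gradient at the perturbed weights
        with torch.no_grad():
            for p, e in zip(model.parameters(), perturbations):
                if e is not None:
                    p.sub_(e)            # restore the original weights
        base_optimizer.step()            # step with the sharpness-aware gradient
        base_optimizer.zero_grad()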

On the Convergence of Adam and Beyond
-------------------------------------

- paper : `paper <https://openreview.net/forum?id=ryQu7f-RZ>`__

Citations
---------

@@ -387,6 +404,32 @@ Adaptive Sharpness-Aware Minimization
year={2021}
}

diffGrad

::

@article{dubey2019diffgrad,
title={diffgrad: An optimization method for convolutional neural networks},
author={Dubey, Shiv Ram and Chakraborty, Soumendu and Roy, Swalpa Kumar and Mukherjee, Snehasis and Singh, Satish Kumar and Chaudhuri, Bidyut Baran},
journal={IEEE transactions on neural networks and learning systems},
volume={31},
number={11},
pages={4500--4511},
year={2019},
publisher={IEEE}
}

On the Convergence of Adam and Beyond

::

@article{reddi2019convergence,
title={On the convergence of adam and beyond},
author={Reddi, Sashank J and Kale, Satyen and Kumar, Sanjiv},
journal={arXiv preprint arXiv:1904.09237},
year={2019}
}

Author
------

4 changes: 3 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -1,9 +1,11 @@
# pylint: disable=unused-import
from pytorch_optimizer.adabelief import AdaBelief
from pytorch_optimizer.adabound import AdaBound
from pytorch_optimizer.adahessian import AdaHessian
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.agc import agc
from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
from pytorch_optimizer.diffgrad import DiffGrad
from pytorch_optimizer.gc import centralize_gradient
from pytorch_optimizer.lookahead import Lookahead
from pytorch_optimizer.madgrad import MADGRAD
@@ -13,4 +15,4 @@
from pytorch_optimizer.sam import SAM
from pytorch_optimizer.sgdp import SGDP

__VERSION__ = '0.0.7'
__VERSION__ = '0.0.8'
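
Given the new ``DiffGrad`` export above, usage should mirror the package's other optimizers; a hedged sketch (the constructor arguments are assumed to be Adam-style and are not shown in this diff):

::

    import torch
    import torch.nn.functional as F
    from pytorch_optimizer import DiffGrad

    model = torch.nn.Linear(10, 2)
    optimizer = DiffGrad(model.parameters(), lr=1e-3)  # ``lr`` kwarg assumed

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    F.cross_entropy(model(x), y).backward()
    optimizer.step()
    optimizer.zero_grad()
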
15 changes: 10 additions & 5 deletions pytorch_optimizer/adabelief.py
@@ -44,17 +44,22 @@ def __init__(
degenerated_to_sgd: bool = True,
):
"""AdaBelief optimizer
:param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
:param params: PARAMS. iterable of parameters to optimize
or dicts defining parameter groups
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param eps: float. term added to the denominator to improve numerical stability
:param betas: BETAS. coefficients used for computing running averages
of gradient and the squared hessian trace
:param eps: float. term added to the denominator
to improve numerical stability
:param weight_decay: float. weight decay (L2 penalty)
:param n_sma_threshold: (recommended is 5)
:param amsgrad: bool. whether to use the AMSBound variant
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param weight_decouple: bool. the optimizer uses decoupled weight decay
as in AdamW
:param fixed_decay: bool.
:param rectify: bool. perform the rectified update similar to RAdam
:param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
:param degenerated_to_sgd: bool. perform SGD update
when variance of gradient is high
"""
self.lr = lr
self.betas = betas
20 changes: 12 additions & 8 deletions pytorch_optimizer/adabound.py
@@ -15,7 +15,7 @@

class AdaBound(Optimizer):
"""
Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
Reference : https://github.com/Luolc/AdaBound
Example :
from pytorch_optimizer import AdaBound
...
@@ -43,14 +43,18 @@ def __init__(
amsbound: bool = False,
):
"""AdaBound optimizer
:param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
:param params: PARAMS. iterable of parameters to optimize
or dicts defining parameter groups
:param lr: float. learning rate
:param final_lr: float. final learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param betas: BETAS. coefficients used for computing running averages
of gradient and the squared hessian trace
:param gamma: float. convergence speed of the bound functions
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
:param weight_decouple: bool. the optimizer uses decoupled weight decay
as in AdamW
:param fixed_decay: bool.
:param eps: float. term added to the denominator to improve numerical stability
:param eps: float. term added to the denominator
to improve numerical stability
:param weight_decay: float. weight decay (L2 penalty)
:param amsbound: bool. whether to use the AMSBound variant
"""
@@ -75,11 +79,11 @@ def __init__(
self.base_lrs = [group['lr'] for group in self.param_groups]

def check_valid_parameters(self):
if 0.0 > self.lr:
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if 0.0 > self.eps:
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')
if 0.0 > self.weight_decay:
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
38 changes: 24 additions & 14 deletions pytorch_optimizer/adahessian.py
@@ -1,3 +1,5 @@
from typing import Dict, Iterable

import torch
from torch.optim import Optimizer

@@ -12,7 +14,7 @@

class AdaHessian(Optimizer):
"""
Reference : https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
Reference : https://github.com/davda54/ada-hessian
Example :
from pytorch_optimizer import AdaHessian
...
@@ -40,15 +42,21 @@ def __init__(
seed: int = 2147483647,
):
"""
:param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
:param params: PARAMS. iterable of parameters to optimize
or dicts defining parameter groups
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param eps: float. term added to the denominator to improve numerical stability
:param betas: BETAS. coefficients used for computing running averages
of gradient and the squared hessian trace
:param eps: float. term added to the denominator
to improve numerical stability
:param weight_decay: float. weight decay (L2 penalty)
:param hessian_power: float. exponent of the hessian trace
:param update_each: int. compute the hessian trace approximation only after *this* number of steps
:param n_samples: int. how many times to sample `z` for the approximation of the hessian trace
:param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
:param update_each: int. compute the hessian trace approximation
only after *this* number of steps
:param n_samples: int. how many times to sample `z`
for the approximation of the hessian trace
:param average_conv_kernel: bool. average out the hessian traces
of convolutional kernels as in the paper.
:param seed: int.
"""
self.lr = lr
@@ -63,8 +71,8 @@ def __init__(

self.check_valid_parameters()

# use a separate generator that deterministically generates the same `z`s across all GPUs
# in case of distributed training
# use a separate generator that deterministically generates
# the same `z`s across all GPUs in case of distributed training
self.generator: torch.Generator = torch.Generator().manual_seed(
self.seed
)
@@ -83,11 +91,11 @@ def __init__(
self.state[p]['hessian_step'] = 0

def check_valid_parameters(self):
if 0.0 > self.lr:
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if 0.0 > self.eps:
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')
if 0.0 > self.weight_decay:
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
@@ -96,7 +104,7 @@ def check_valid_parameters(self):
if not 0.0 <= self.hessian_power < 1.0:
raise ValueError(f'Invalid hessian_power : {self.hessian_power}')

def get_params(self):
def get_params(self) -> Iterable[Dict]:
"""Gets all parameters in all param_groups with gradients"""
return (
p
@@ -116,7 +124,9 @@ def zero_hessian(self):

@torch.no_grad()
def set_hessian(self):
"""Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
"""Computes the Hutchinson approximation of the hessian trace
and accumulates it for each trainable parameter
"""
params = []
for p in filter(
lambda param: param.grad is not None, self.get_params()
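
For reference, a standalone sketch of the Hutchinson estimator that ``set_hessian`` describes above (an illustration of the idea, not this file's exact implementation):

::

    import torch

    def hutchinson_hessian_estimate(params, grads):
        """One-sample Hutchinson estimate of the Hessian diagonal, z * (H z),
        with Rademacher z. Assumes ``grads`` was computed with create_graph=True."""
        zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]  # entries in {-1, +1}
        # one Hessian-vector product via double backprop
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
        return [h_z * z for h_z, z in zip(h_zs, zs)]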