Merge pull request #27 from kozistr/feature/diffrgrad-optimizer
[Feature] Implement DiffRGrad optimizer
kozistr authored Sep 23, 2021
2 parents 3c952d0 + 9b3c3d2 commit bc22b00
Showing 4 changed files with 177 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -6,6 +6,7 @@
from pytorch_optimizer.agc import agc
from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
from pytorch_optimizer.diffgrad import DiffGrad
from pytorch_optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.gc import centralize_gradient
from pytorch_optimizer.lookahead import Lookahead
from pytorch_optimizer.madgrad import MADGRAD
@@ -15,4 +16,4 @@
from pytorch_optimizer.sam import SAM
from pytorch_optimizer.sgdp import SGDP

-__VERSION__ = '0.0.8'
+__VERSION__ = '0.0.9'
172 changes: 172 additions & 0 deletions pytorch_optimizer/diffrgrad.py
@@ -0,0 +1,172 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE


class DiffRGrad(Optimizer):
"""
Reference 1 : https://github.com/shivram1987/diffGrad
Reference 2 : https://github.com/LiyuanLucasLiu/RAdam
Reference 3 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/diffgrad/diff_rgrad.py
Example :
from pytorch_optimizer import DiffRGrad
...
model = YourModel()
optimizer = DiffRGrad(model.parameters())
...
for input, output in data:
optimizer.zero_grad()
loss = loss_function(output, model(input))
loss.backward()
optimizer.step()
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.999),
weight_decay: float = 0.0,
n_sma_threshold: int = 5,
degenerated_to_sgd: bool = True,
eps: float = 1e-8,
):
"""Blend RAdam with DiffGrad
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate.
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
        :param weight_decay: float. weight decay (L2 penalty)
        :param n_sma_threshold: int. threshold on the length of the approximated SMA; the rectified update is used only above it (recommended: 5)
        :param degenerated_to_sgd: bool. whether to fall back to a momentum-SGD style update when the rectification term is not available
:param eps: float. term added to the denominator to improve numerical stability
"""
self.lr = lr
self.betas = betas
self.weight_decay = weight_decay
self.n_sma_threshold = n_sma_threshold
self.degenerated_to_sgd = degenerated_to_sgd
self.eps = eps

self.check_valid_parameters()

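        # parameter groups that override `betas` need their own rectification buffer,
        # since the cached step sizes below depend on beta2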
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
buffer=[[None, None, None] for _ in range(10)],
)

super().__init__(params, defaults)

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if not 0.0 <= self.betas[0] < 1.0:
raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
if not 0.0 <= self.betas[1] < 1.0:
raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

def __setstate__(self, state: STATE):
super().__setstate__(state)

def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

grad = p.grad.data.float()
if grad.is_sparse:
                    raise RuntimeError('DiffRGrad does not support sparse gradients')

p_data_fp32 = p.data.float()
state = self.state[p]

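                # lazily initialize the running averages and the previous gradient used by the dfc term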
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
state['previous_grad'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
state['previous_grad'] = state['previous_grad'].type_as(p_data_fp32)

exp_avg, exp_avg_sq, previous_grad = (
state['exp_avg'],
state['exp_avg_sq'],
state['previous_grad'],
)
beta1, beta2 = group['betas']

state['step'] += 1

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # compute the diffGrad friction coefficient (dfc): a sigmoid of the absolute change in the
                # gradient, so the update is damped where the gradient is barely changing (near an optimum)
                diff = abs(previous_grad - grad)
                dfc = 1.0 / (1.0 + torch.exp(-diff))

state['previous_grad'] = grad.clone()

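                # RAdam machinery: the per-group buffer caches (step, n_sma, step_size) indexed by step modulo 10,
                # so the rectification term is only recomputed when the step actually changes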
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
n_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
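                    # bias-correction term for the second moment, then the length of the approximated SMA
                    # (rho_t in the RAdam paper) and its maximum possible value (rho_inf)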
beta2_t = beta2 ** state['step']
n_sma_max = 2.0 / (1.0 - beta2) - 1.0
n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
buffered[1] = n_sma

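                    # when the SMA length is above the threshold, the variance of the adaptive learning
                    # rate is tractable and RAdam's rectified step size is used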
if n_sma >= self.n_sma_threshold:
step_size = math.sqrt(
(1 - beta2_t)
* (n_sma - 4)
/ (n_sma_max - 4)
* (n_sma - 2)
/ n_sma
* n_sma_max
/ (n_sma_max - 2)
) / (1.0 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size

if n_sma >= self.n_sma_threshold:
if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

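                    # bias correction of the second moment, sqrt(1 - beta2^t), is already folded into step_size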
denom = exp_avg_sq.sqrt().add_(group['eps'])

                    # scale the first moment by the friction coefficient and take the rectified adaptive step
                    p_data_fp32.addcdiv_(exp_avg * dfc.float(), denom, value=-step_size * group['lr'])
p.data.copy_(p_data_fp32)
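                # rectification term not available: fall back to a momentum-SGD style update
                # (step_size > 0 here only when degenerated_to_sgd is enabled)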
elif step_size > 0:
if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
p.data.copy_(p_data_fp32)

return loss
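
For reference, a minimal usage sketch based on the docstring example above: the model, data iterator, and loss function names are hypothetical, and the second parameter group is included only to show that groups overriding `betas` get their own rectification buffer.

from pytorch_optimizer import DiffRGrad

model = YourModel()  # hypothetical model with .encoder and .head sub-modules
optimizer = DiffRGrad(
    [
        {'params': model.encoder.parameters()},                      # uses the default betas
        {'params': model.head.parameters(), 'betas': (0.9, 0.99)},   # gets its own buffer
    ],
    lr=1e-3,
    weight_decay=1e-2,
)

for input, output in data:  # hypothetical data iterator
    optimizer.zero_grad()
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step()
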
4 changes: 2 additions & 2 deletions pytorch_optimizer/radam.py
@@ -28,19 +28,19 @@ def __init__(
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.999),
-        eps: float = 1e-8,
weight_decay: float = 0.0,
n_sma_threshold: int = 5,
degenerated_to_sgd: bool = False,
+        eps: float = 1e-8,
):
"""
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate.
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
-        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param n_sma_threshold: int. threshold on the length of the approximated SMA; the rectified update is used only above it (recommended: 5)
        :param degenerated_to_sgd: bool. whether to fall back to a momentum-SGD style update when the rectification term is not available
+        :param eps: float. term added to the denominator to improve numerical stability
"""
self.lr = lr
self.betas = betas
1 change: 1 addition & 0 deletions setup.py
@@ -56,6 +56,7 @@ def read_version() -> str:
'sam',
'asam',
'diffgrad',
'diffrgrad',
]
)
