
Commit d657060

Merge pull request #54 from kozistr/refactor/optimizers
[Refactor] Optimizers
2 parents 6db0d49 + 619c169

16 files changed: 52 additions, 62 deletions

lint.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def get_configuration() -> Namespace:
     parser.add_argument(
         '-t',
         '--threshold',
-        default=9.9,
+        default=9.95,
         type=float,
     )

pytorch_optimizer/adabelief.py

Lines changed: 4 additions & 10 deletions
@@ -82,12 +82,6 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
 
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsgrad', False)
-            group.setdefault('adamd_debias_term', False)
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -152,11 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_residual = grad - exp_avg
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
 
+                exp_avg_var = exp_avg_var.add_(group['eps'])
                 if group['amsgrad']:
-                    max_exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']))
-                    de_nom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-                else:
-                    de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var)
+
+                de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                 if not self.rectify:
                     step_size = group['lr']

pytorch_optimizer/adabound.py

Lines changed: 3 additions & 9 deletions
@@ -80,11 +80,6 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
 
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsbound', False)
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -140,10 +135,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 if group['amsbound']:
-                    max_exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
-                    de_nom = max_exp_avg_sq.sqrt().add_(group['eps'])
-                else:
-                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
+                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+
+                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
                 bias_correction2 = 1.0 - beta2 ** state['step']
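Aside: the flattened amsbound branch above can be read in isolation as a small helper. A minimal sketch, assuming exp_avg_sq, max_exp_avg_sq, and eps are passed in directly rather than read from optimizer state (illustrative only, not part of this commit):

    import torch

    def amsbound_de_nom(exp_avg_sq: torch.Tensor,
                        max_exp_avg_sq: torch.Tensor,
                        amsbound: bool,
                        eps: float) -> torch.Tensor:
        # with amsbound, the element-wise running maximum feeds the denominator
        if amsbound:
            exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
        # sqrt() allocates a new tensor, so the in-place add_ leaves the moment estimate untouched
        return exp_avg_sq.sqrt().add_(eps)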

pytorch_optimizer/diffgrad.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base_optimizer import BaseOptimizer
-from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
 class DiffGrad(Optimizer, BaseOptimizer):

pytorch_optimizer/diffrgrad.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base_optimizer import BaseOptimizer
-from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
 class DiffRGrad(Optimizer, BaseOptimizer):

pytorch_optimizer/lars.py

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad = grad.add(p, alpha=g['weight_decay'])
                 param_norm = torch.norm(p)
                 update_norm = torch.norm(grad)
-                one = torch.ones_like(param_norm)
+                one = torch.ones_like(param_norm, device=param_norm.device)
 
                 q = torch.where(
                     param_norm > 0.0,
@@ -100,7 +100,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 param_state = self.state[p]
                 if 'mu' not in param_state:
-                    param_state['mu'] = torch.zeros_like(p)
+                    param_state['mu'] = torch.zeros_like(p, device=p.device)
 
                 mu = param_state['mu']
                 mu.mul_(g['momentum']).add_(grad)

pytorch_optimizer/pcgrad.py

Lines changed: 14 additions & 12 deletions
@@ -39,15 +39,15 @@ def validate_parameters(self):
 
     @torch.no_grad()
     def reset(self):
-        pass
+        self.zero_grad()
 
     def zero_grad(self):
         return self.optimizer.zero_grad(set_to_none=True)
 
     def step(self):
         return self.optimizer.step()
 
-    def set_grad(self, grads):
+    def set_grad(self, grads: List[torch.Tensor]):
         idx: int = 0
         for group in self.optimizer.param_groups:
             for p in group['params']:
@@ -74,7 +74,7 @@ def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tenso
     def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
         """pack the gradient of the parameters of the network for each objective
         :param objectives: Iterable[nn.Module]. a list of objectives
-        :return:
+        :return: torch.Tensor. packed gradients
         """
         grads, shapes, has_grads = [], [], []
         for objective in objectives:
@@ -89,27 +89,29 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List
 
         return grads, shapes, has_grads
 
-    def project_conflicting(self, grads, has_grads) -> torch.Tensor:
+    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
         """project conflicting
         :param grads: a list of the gradient of the parameters
         :param has_grads: a list of mask represent whether the parameter has gradient
-        :return:
+        :return: torch.Tensor. merged gradients
         """
-        shared = torch.stack(has_grads).prod(0).bool()
+        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()
 
-        pc_grad = deepcopy(grads)
+        pc_grad: List[torch.Tensor] = deepcopy(grads)
         for g_i in pc_grad:
             random.shuffle(grads)
             for g_j in grads:
-                g_i_g_j = torch.dot(g_i, g_j)
+                g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
                 if g_i_g_j < 0:
                     g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
 
-        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
+        merged_grad: torch.Tensor = torch.zeros_like(grads[0], device=grads[0].device)
+
+        shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
         if self.reduction == 'mean':
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
         else:
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
+            merged_grad[shared] = shared_pc_gradients.sum(dim=0)
 
         merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
 
@@ -121,7 +123,7 @@ def pc_backward(self, objectives: Iterable[nn.Module]):
         :return:
         """
         grads, shapes, has_grads = self.pack_grad(objectives)
+
        pc_grad = self.project_conflicting(grads, has_grads)
         pc_grad = un_flatten_grad(pc_grad, shapes[0])
-
         self.set_grad(pc_grad)
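The core of project_conflicting is untouched by this refactor: when two task gradients conflict (negative dot product), the component of one along the other is removed. A minimal standalone sketch of that projection on two flattened gradients (illustrative only, not the PCGrad API):

    import torch

    def project_away_conflict(g_i: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
        # g_i, g_j: flattened per-task gradients over the shared parameters
        g_i_g_j = torch.dot(g_i, g_j)
        if g_i_g_j < 0:  # conflicting directions
            g_i = g_i - g_i_g_j * g_j / (g_j.norm() ** 2)
        return g_i

    g_a = torch.tensor([1.0, 0.0])
    g_b = torch.tensor([-1.0, 1.0])
    print(project_away_conflict(g_a, g_b))  # tensor([0.5000, 0.5000]); orthogonal to g_b

The reduction flag exercised by the new test parametrization then decides whether the shared components of the projected gradients are averaged or summed before being written back via set_grad.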

pytorch_optimizer/radam.py

Lines changed: 3 additions & 5 deletions
@@ -1,5 +1,4 @@
 import math
-from typing import Dict
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -153,14 +152,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     step_size = -1
                     buffered[2] = step_size
 
+                if group['weight_decay'] != 0 and (n_sma >= self.n_sma_threshold or step_size > 0):
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
+
                 if n_sma >= self.n_sma_threshold:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                     p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
                 if p.dtype in (torch.float16, torch.bfloat16):
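The hunk above hoists the duplicated weight-decay blocks out of the two update branches. A quick throwaway check (not project code) that the single guard covers exactly the same cases as the removed per-branch checks:

    def old_applies(weight_decay: float, n_sma: float, n_sma_threshold: float, step_size: float) -> bool:
        # weight decay used to be applied inside each branch, guarded separately
        if n_sma >= n_sma_threshold:
            return weight_decay != 0
        if step_size > 0:
            return weight_decay != 0
        return False

    def new_applies(weight_decay: float, n_sma: float, n_sma_threshold: float, step_size: float) -> bool:
        # the hoisted single guard
        return weight_decay != 0 and (n_sma >= n_sma_threshold or step_size > 0)

    for wd in (0.0, 1e-3):
        for n_sma in (2.0, 6.0):
            for step_size in (-1.0, 0.5):
                assert old_applies(wd, n_sma, 4.0, step_size) == new_applies(wd, n_sma, 4.0, step_size)

In both the old and new versions the decay is applied to p_fp32 before the branch's parameter update, so only the duplication is removed.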

pytorch_optimizer/ralamb.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def get_gradient_norm(self) -> float:
                 if p.grad is None:
                     continue
 
-                norm_sq += torch.linalg.norm(p.grad).item() ** 2
+                norm_sq += torch.linalg.norm(p.grad).cpu().numpy() ** 2
 
         norm = math.sqrt(norm_sq)
 
@@ -147,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 state['step'] += 1
-                buffered = group['buffer'][int(state['step'] % 10)]
+                buffered = group['buffer'][state['step'] % 10]
 
                 bias_correction1 = 1.0 - beta1 ** state['step']

pytorch_optimizer/ranger.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import math
-from typing import Dict
 
 import torch
 from torch.optim.optimizer import Optimizer

pytorch_optimizer/ranger21.py

Lines changed: 3 additions & 10 deletions
@@ -148,11 +148,11 @@ def reset(self):
 
     @staticmethod
     def build_warm_up_iterations(total_iterations: int, beta2: float, warm_up_pct: float = 0.22) -> int:
-        beta_warm_up_iterations: int = math.ceil(2.0 / (1.0 - beta2))  # default un-tuned linear warmup
-        beta_pct: float = beta_warm_up_iterations / total_iterations
+        warm_up_iterations: int = math.ceil(2.0 / (1.0 - beta2))  # default un-tuned linear warmup
+        beta_pct: float = warm_up_iterations / total_iterations
         if beta_pct > 0.45:
             return int(warm_up_pct * total_iterations)
-        return beta_warm_up_iterations
+        return warm_up_iterations
 
     @staticmethod
     def build_warm_down_iterations(total_iterations: int, warm_down_pct: float = 0.72) -> int:
@@ -187,13 +187,6 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
 
         return new_lr
 
-    @staticmethod
-    def get_state_values(group, state: STATE):
-        beta1, beta2 = group['betas']
-        mean_avg = state['mean_avg']
-        variance_avg = state['variance_avg']
-        return beta1, beta2, mean_avg, variance_avg
-
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None

pytorch_optimizer/types.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch

pytorch_optimizer/utils.py

Lines changed: 3 additions & 3 deletions
@@ -43,12 +43,12 @@ def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
 
 def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
     idx: int = 0
-    un_flatten_grad: List[torch.Tensor] = []
+    un_flatten_grads: List[torch.Tensor] = []
     for shape in shapes:
         length = np.prod(shape)
-        un_flatten_grad.append(grads[idx : idx + length].view(shape).clone())
+        un_flatten_grads.append(grads[idx : idx + length].view(shape).clone())
        idx += length
-    return un_flatten_grad
+    return un_flatten_grads
 
 
 def channel_view(x: torch.Tensor) -> torch.Tensor:
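For context, un_flatten_grad is the inverse of flatten_grad (whose signature appears in the hunk header): it slices a flat gradient vector back into the original per-parameter shapes. A small round-trip usage sketch, assuming both helpers are importable from pytorch_optimizer.utils and that flatten_grad concatenates the flattened tensors, as its use in pcgrad.py suggests:

    import torch

    from pytorch_optimizer.utils import flatten_grad, un_flatten_grad

    grads = [torch.randn(2, 3), torch.randn(4)]
    shapes = [g.shape for g in grads]

    flat = flatten_grad(grads)                # single 1-D tensor of length 10
    restored = un_flatten_grad(flat, shapes)  # back to a [2x3, 4] list of tensors

    assert all(torch.equal(a, b) for a, b in zip(grads, restored))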

pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__VERSION__ = '0.4.1'
+__VERSION__ = '0.4.2'

tests/test_optimizer_parameters.py

Lines changed: 7 additions & 1 deletion
@@ -3,7 +3,7 @@
 import pytest
 from torch import nn
 
-from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, SafeFP16Optimizer, load_optimizers
+from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizers
 from tests.utils import Example
 
 OPTIMIZER_NAMES: List[str] = [
@@ -173,3 +173,9 @@ def test_safe_fp16_methods():
     optimizer.set_lr(lr=5e-1)
 
     assert optimizer.loss_scale == 2.0 ** (15 - 1)
+
+
+def test_ranger21_methods():
+    assert Ranger21.build_warm_up_iterations(1000, 0.999) == 220
+    assert Ranger21.build_warm_up_iterations(4500, 0.999) == 2000
+    assert Ranger21.build_warm_down_iterations(1000) == 280
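The new assertions follow directly from build_warm_up_iterations as shown in the ranger21.py hunk: ceil(2 / (1 - 0.999)) = 2000 warm-up steps; over a 1000-iteration run that is 200% of the run, which exceeds the 0.45 cap, so the 22% fallback gives 220, while over 4500 iterations 2000 / 4500 ≈ 0.44 stays under the cap and 2000 is returned as-is. A tiny re-derivation (the 280 warm-down figure corresponds to the remaining 28% of a 1000-iteration run with the default warm_down_pct = 0.72, though that function's body is not part of this diff):

    import math

    beta2 = 0.999
    warm_up_iterations = math.ceil(2.0 / (1.0 - beta2))  # 2000

    # 1000-iteration run: warm-up would be 200% of the run (> 0.45), so fall back to 22%
    assert warm_up_iterations / 1000 > 0.45
    assert int(0.22 * 1000) == 220

    # 4500-iteration run: 2000 / 4500 stays under the cap, so 2000 is used directly
    assert warm_up_iterations / 4500 <= 0.45
    assert warm_up_iterations == 2000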

tests/test_optimizers.py

Lines changed: 6 additions & 2 deletions
@@ -41,6 +41,8 @@
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'fixed_decay': True}, 200),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'rectify': False}, 200),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 200),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 200),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 200),
@@ -61,6 +63,7 @@
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 200),
     (SGDP, {'lr': 2e-1, 'weight_decay': 1e-3}, 500),
+    (SGDP, {'lr': 2e-1, 'weight_decay': 1e-3, 'nesterov': True}, 500),
     (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]
@@ -248,8 +251,9 @@ def test_adamd_optimizers(optimizer_adamd_config):
     assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss)
 
 
+@pytest.mark.parametrize('reduction', ('mean', 'sum'))
 @pytest.mark.parametrize('optimizer_pc_grad_config', OPTIMIZERS, ids=ids)
-def test_pc_grad_optimizers(optimizer_pc_grad_config):
+def test_pc_grad_optimizers(reduction, optimizer_pc_grad_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
@@ -259,7 +263,7 @@ def test_pc_grad_optimizers(optimizer_pc_grad_config):
     loss_fn_2: nn.Module = nn.L1Loss()
 
     optimizer_class, config, iterations = optimizer_pc_grad_config
-    optimizer = PCGrad(optimizer_class(model.parameters(), **config))
+    optimizer = PCGrad(optimizer_class(model.parameters(), **config), reduction=reduction)
 
     if optimizer_class.__name__ == 'RaLamb' and 'pre_norm' in config:
         return True
