
Commit 5baa713

Merge pull request #321 from kozistr/feature/orthogonalize
[Feature] Implement `OrthoGrad` optimizer
2 parents 61fbbd7 + 141e01e commit 5baa713

File tree

12 files changed: +111 lines, -21 lines

README.md

Lines changed: 8 additions & 2 deletions
@@ -10,7 +10,7 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **87 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -195,7 +195,9 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
 | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
 | SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
-| Grams | *Grams: Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
+| Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
+| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
+| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |

 ## Supported LR Scheduler

@@ -371,6 +373,10 @@ Correcting the norm of a gradient in each iteration based on the adaptive traini

 Updates only occur when the proposed update direction aligns with the current gradient.

+### Adam-ATAN2
+
+Adam-atan2 is a new numerically stable, scale-invariant version of Adam that eliminates the epsilon hyperparameter.
+
 ## Frequently asked questions

 [here](docs/qa.md)
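As an aside on the Adam-ATAN2 entry added above: a minimal sketch of the idea, assuming Adam-style first- and second-moment estimates are already computed. It is illustrative only, not the library's implementation.

# Illustrative only: eps-based vs. atan2-based Adam-style parameter updates.
import torch

exp_avg = torch.randn(4)     # first-moment (momentum) estimate
exp_avg_sq = torch.rand(4)   # second-moment estimate
lr, eps = 1e-3, 1e-8

# Classic Adam step: eps keeps the division finite when exp_avg_sq is tiny.
update_eps = exp_avg / (exp_avg_sq.sqrt() + eps)

# atan2-style step: atan2 is bounded and well defined even at (0, 0),
# so the eps hyperparameter disappears.
update_atan2 = torch.atan2(exp_avg, exp_avg_sq.sqrt())

param = torch.zeros(4)
param.add_(update_atan2, alpha=-lr)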

docs/changelogs/v3.3.3.md

Lines changed: 6 additions & 2 deletions
@@ -4,7 +4,11 @@

 * Implement `Grams` optimizer. (#317, #318)
 * [Grams: Gradient Descent with Adaptive Momentum Scaling](https://arxiv.org/abs/2412.17107)
-* Support `stable_adamw` variant for `ADOPT` and `AdEMAMix` optimizer. (#320)
+* Support `stable_adamw` variant for `ADOPT` and `AdEMAMix` optimizer. (#321)
 * `optimizer = ADOPT(model.parameters(), ..., stable_adamw=True)`
-* Implement an experimental optimizer `Ranger25` (not tested). (#320)
+* Implement an experimental optimizer `Ranger25` (not tested). (#321)
 * mixing `ADOPT + AdEMAMix + StableAdamW + Cautious + RAdam` optimizers.
+* Implement `OrthoGrad` optimizer. (#321)
+* [Grokking at the Edge of Numerical Stability](https://arxiv.org/abs/2501.04697)
+* Support `Adam-Atan2` feature for `Prodigy` optimizer when `eps` is None. (#321)
+* [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)

docs/index.md

Lines changed: 8 additions & 2 deletions
@@ -10,7 +10,7 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **87 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -195,7 +195,9 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
 | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
 | SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
-| Grams | *Grams: Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
+| Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
+| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
+| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |

 ## Supported LR Scheduler

@@ -371,6 +373,10 @@ Correcting the norm of a gradient in each iteration based on the adaptive traini

 Updates only occur when the proposed update direction aligns with the current gradient.

+### Adam-ATAN2
+
+Adam-atan2 is a new numerically stable, scale-invariant version of Adam that eliminates the epsilon hyperparameter.
+
 ## Frequently asked questions

 [here](docs/qa.md)

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -256,6 +256,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.OrthoGrad
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.PAdam
     :docstring:
     :members:

pyproject.toml

Lines changed: 6 additions & 6 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.3.2"
+version = "3.3.3"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <kozistr@gmail.com>"]
@@ -16,11 +16,11 @@ keywords = [
     "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
     "DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast", "GSAM",
     "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
-    "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
-    "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW",
-    "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE",
-    "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
-    "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
+    "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
+    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC",
+    "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
+    "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -116,13 +116,15 @@
     Muon,
     Nero,
     NovoGrad,
+    OrthoGrad,
     PAdam,
     PCGrad,
     Prodigy,
     QHAdam,
     RAdam,
     Ranger,
     Ranger21,
+    Ranger25,
     RotoGrad,
     SafeFP16Optimizer,
     ScalableShampoo,

pytorch_optimizer/optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
 from pytorch_optimizer.optimizer.muon import Muon
 from pytorch_optimizer.optimizer.nero import Nero
 from pytorch_optimizer.optimizer.novograd import NovoGrad
+from pytorch_optimizer.optimizer.orthograd import OrthoGrad
 from pytorch_optimizer.optimizer.padam import PAdam
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pid import PID

pytorch_optimizer/optimizer/orthograd.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import torch
+
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, PARAMETERS
+
+
+class OrthoGrad(BaseOptimizer):
+    r"""Grokking at the Edge of Numerical Stability.
+
+    A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param optimizer: OPTIMIZER. base optimizer.
+    """
+
+    def __init__(self, params: PARAMETERS, optimizer: OPTIMIZER = torch.optim.AdamW, **kwargs):
+        self.eps: float = 1e-30
+
+        super().__init__(params, {})
+        self.base_optimizer = optimizer(self.param_groups, **kwargs)
+
+    def __str__(self) -> str:
+        return 'OrthoGrad'
+
+    @torch.no_grad()
+    def reset(self):
+        pass
+
+    @torch.no_grad()
+    def orthogonalize_gradients(self, params) -> None:
+        for p in params:
+            if p.grad is None:
+                continue
+
+            w = p.view(-1)
+            g = p.grad.view(-1)
+
+            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(self.eps))
+            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
+            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(self.eps)))
+
+            p.grad.copy_(g_ortho_scaled.view_as(p.grad))
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        for group in self.param_groups:
+            self.orthogonalize_gradients(group['params'])
+        return self.base_optimizer.step(closure)
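
For context, a minimal usage sketch of the `OrthoGrad` wrapper added above, assuming the published `pytorch_optimizer` package; the model, data, and hyperparameters are placeholders.

# Minimal sketch: OrthoGrad wraps a base optimizer class and forwards its kwargs.
import torch
from torch import nn
from pytorch_optimizer import OrthoGrad

model = nn.Linear(10, 1)
optimizer = OrthoGrad(model.parameters(), optimizer=torch.optim.AdamW, lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # gradients are projected orthogonal to the weights, then AdamW steps
optimizer.zero_grad()

Because the wrapper only rewrites `p.grad` before delegating to `base_optimizer.step`, it composes with any optimizer that reads gradients from `p.grad`.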

pytorch_optimizer/optimizer/prodigy.py

Lines changed: 12 additions & 6 deletions
@@ -16,7 +16,7 @@ class Prodigy(BaseOptimizer):
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param betas: BETAS. betas.
-    :param beta3: float. coefficients for computing the Prodidy step-size using running averages. If set to None,
+    :param beta3: float. coefficients for computing the Prodigy step-size using running averages. If set to None,
         uses the value of square root of beta2.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param d_coef: float. Coefficient in the expression for the estimate of d.
@@ -26,7 +26,8 @@ class Prodigy(BaseOptimizer):
     :param fixed_decay: bool. fix weight decay.
     :param bias_correction: bool. turn on Adam's bias correction.
     :param safeguard_warmup: bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.
-    :param eps: float. term added to the denominator to improve numerical stability.
+    :param eps: float. term added to the denominator to improve numerical stability. when eps is None, use atan2 rather
+        than epsilon and division for parameter updates.
     """

     def __init__(
@@ -43,7 +44,7 @@ def __init__(
         fixed_decay: bool = False,
         bias_correction: bool = False,
         safeguard_warmup: bool = False,
-        eps: float = 1e-8,
+        eps: Optional[float] = 1e-8,
         **kwargs,
     ):
         self.validate_learning_rate(lr)
@@ -172,8 +173,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

-                de_nom = exp_avg_sq.sqrt().add_(d * group['eps'])
-
                 self.apply_weight_decay(
                     p,
                     p.grad,
@@ -183,6 +182,13 @@
                     fixed_decay=group['fixed_decay'],
                 )

-                p.addcdiv_(exp_avg, de_nom, value=-d_lr)
+                de_nom = exp_avg_sq.sqrt()
+
+                if group['eps'] is not None:
+                    de_nom.add_(d * group['eps'])
+                    p.addcdiv_(exp_avg, de_nom, value=-d_lr)
+                else:
+                    update = exp_avg.clone().atan2_(de_nom)
+                    p.add_(update, alpha=-d_lr)

         return loss
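
A minimal usage sketch of the new `eps=None` branch above, assuming the published `pytorch_optimizer` package; the model and data are placeholders.

# Minimal sketch: passing eps=None switches Prodigy to the atan2-style update above.
import torch
from torch import nn
from pytorch_optimizer import Prodigy

model = nn.Linear(10, 1)
optimizer = Prodigy(model.parameters(), lr=1.0, eps=None)  # no epsilon to tune

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()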

tests/constants.py

Lines changed: 5 additions & 1 deletion
@@ -84,7 +84,7 @@
     Tiger,
     Yogi,
 )
-from tests.utils import build_lookahead
+from tests.utils import build_lookahead, build_orthograd

 DECOUPLE_FLAGS: List[bool] = [True, False]
 ADAPTIVE_FLAGS: List[bool] = [True, False]
@@ -115,6 +115,7 @@
     'radam',
     'ranger',
     'ranger21',
+    'ranger25',
     'pnm',
     'adapnm',
     'adan',
@@ -180,6 +181,7 @@

 OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
+    (build_orthograd, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 5),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5),
@@ -441,6 +443,7 @@
     (SWATS, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
     (SWATS, {'lr': 5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 5),
     (Prodigy, {'lr': 5e1, 'beta3': None, 'weight_decay': 1e-3}, 10),
+    (Prodigy, {'lr': 5e0, 'beta3': None, 'weight_decay': 1e-3, 'eps': None}, 15),
     (Prodigy, {'lr': 5e1, 'beta3': 0.999, 'weight_decay': 1e-3}, 10),
     (Prodigy, {'lr': 1e1, 'beta3': 0.999, 'weight_decay': 1e-3, 'bias_correction': True}, 15),
     (Prodigy, {'lr': 1e0, 'beta3': 0.999, 'weight_decay': 1e-3, 'safeguard_warmup': True}, 15),
@@ -545,6 +548,7 @@
     (SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15),
     (Grams, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Ranger25, {'lr': 1e-1}, 25),
+    (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 25),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
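
`build_orthograd` comes from `tests/utils.py`, which is not part of this diff. A plausible shape, assuming it mirrors the existing `build_lookahead` pattern by wrapping a stock optimizer; the repository's actual helper may differ.

# Hypothetical sketch of tests/utils.build_orthograd (not shown in this commit).
from torch.optim import AdamW

from pytorch_optimizer import OrthoGrad


def build_orthograd(parameters, **kwargs):
    # Wrap a plain AdamW optimizer with OrthoGrad's gradient projection.
    return OrthoGrad(parameters, optimizer=AdamW, **kwargs)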
