Commit b9357f8

add smoke test for kl shampoo and fix bug

1 parent ad5204b · commit b9357f8

2 files changed: +19 −2 lines changed

emerging_optimizers/soap/soap.py

Lines changed: 6 additions & 2 deletions

@@ -503,14 +503,18 @@ def update_kronecker_factors_kl_shampoo(
 
     # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis
    # G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(GG.T) and G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G.TG)
+    updates = []
     for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
         approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
         scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
 
         correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T
 
-        maybe_transpose_grad = grad.T if idx == 0 else grad
-        update = utils.eig.conjugate(correction, maybe_transpose_grad)
+        maybe_transpose_grad = grad.T if idx == 1 else grad
+        updates.append(utils.eig.conjugate(correction, maybe_transpose_grad))
+
+    # Note that updates calculated in the previous loop are in reverse order of the kronecker factor list they apply to
+    for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
         kronecker_factor.lerp_(update, 1 - shampoo_beta)
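For reference, here is a minimal standalone sketch of the corrected two-pass update. It assumes `utils.eig.conjugate(A, Q)` computes `Q.T @ A @ Q` (consistent with how `approx_eigvals` is read off above) and uses `eigval_exp = -1.0` to match the `λ^(−1)` in the comment; the local `conjugate` helper, seed, shapes, and hyperparameter values are illustrative, not the repository's API. The point it demonstrates is why the conjugations must be applied to the factors in reverse order: the idx == 0 pass conjugates the L-side correction by `G`, which produces the (n x n) update for the R factor, and vice versa.

```python
import torch


def conjugate(a: torch.Tensor, q: torch.Tensor, diag: bool = False) -> torch.Tensor:
    # Illustrative stand-in for utils.eig.conjugate: Q^T A Q, optionally just its diagonal.
    out = q.T @ a @ q
    return out.diagonal() if diag else out


torch.manual_seed(0)
m, n = 5, 3
grad = torch.randn(m, n)
eps, eigval_exp, shampoo_beta = 1e-8, -1.0, 0.95  # illustrative values

# Assumed ordering: kronecker_factor_list = [L, R], L ~ G@G.T (m x m), R ~ G.T@G (n x n)
kronecker_factor_list = [grad @ grad.T, grad.T @ grad]
eigenbasis_list = [torch.linalg.eigh(f).eigenvectors for f in kronecker_factor_list]

updates = []
for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
    approx_eigvals = conjugate(kronecker_factor, eigenbasis, diag=True)
    scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
    correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T

    # idx == 0: conjugate the (m x m) L-side correction by G itself,
    # giving G.T @ C_L @ G, an (n x n) update -- it applies to the *R* factor.
    # idx == 1: conjugate the (n x n) R-side correction by G.T,
    # giving G @ C_R @ G.T, an (m x m) update -- it applies to the *L* factor.
    maybe_transpose_grad = grad.T if idx == 1 else grad
    updates.append(conjugate(correction, maybe_transpose_grad))

# The updates come out in reverse order of the factors they apply to.
for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
    assert kronecker_factor.shape == update.shape
    kronecker_factor.lerp_(update, 1 - shampoo_beta)
```

Collecting the updates and applying them in a second loop (rather than writing to the other factor in-place mid-loop) presumably also keeps each factor's `approx_eigvals` computed from the pre-update factors.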

tests/test_soap.py

Lines changed: 13 additions & 0 deletions

@@ -47,6 +47,19 @@ def test_10steps_smoke(self):
             optimizer.step()
             param.grad = None
 
+    def test_with_kl_shampoo_10steps_smoke(self):
+        param = torch.randn(5, 3, requires_grad=True, device="cuda")
+        optimizer = soap.SOAP(
+            [param],
+            **self.default_config,
+            use_kl_shampoo=True,
+        )
+
+        for _ in range(10):
+            param.grad = torch.randn_like(param)
+            optimizer.step()
+            param.grad = None
+
 
 if __name__ == "__main__":
     absltest.main()
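The new smoke test mirrors the existing one but enables the KL-Shampoo path on a deliberately non-square (5, 3) parameter, which is exactly the shape class that exposes the transpose bug fixed above. A hedged illustration (again assuming `conjugate(A, Q) == Q.T @ A @ Q`): with the old `grad.T if idx == 0` condition, the idx == 0 conjugation cannot even be formed for a non-square gradient.

```python
import torch

grad = torch.randn(5, 3)      # non-square, like the test's (5, 3) param
correction_L = torch.eye(5)   # placeholder for the (m x m) L-side correction

fixed = grad.T @ correction_L @ grad  # fixed path: a (3, 3) update for the R factor
try:
    grad @ correction_L @ grad.T      # old path: (5, 3) @ (5, 5) cannot multiply
except RuntimeError as err:
    print("old transpose choice fails on non-square grads:", err)
```

Since the test allocates the param on "cuda", it presumably needs a GPU to run; with absltest it can be invoked directly, e.g. `python tests/test_soap.py` (the enclosing test-class name is not shown in this hunk).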
