add muon to contrib/
leloykun committed Nov 2, 2024
1 parent d4592a6 commit a45f7a0
Showing 4 changed files with 138 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/api/contrib.rst
@@ -25,6 +25,8 @@ Experimental features and algorithms that don't meet the
MomoState
momo_adam
MomoAdamState
muon
MuonState
prodigy
ProdigyState
sam
@@ -82,6 +84,12 @@ Momo
.. autofunction:: momo_adam
.. autoclass:: MomoAdamState

Muon
~~~~
.. autofunction:: muon
.. autofunction:: scale_by_muon
.. autoclass:: MuonState

Prodigy
~~~~~~~
.. autofunction:: prodigy
3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
@@ -35,6 +35,9 @@
from optax.contrib._momo import momo_adam
from optax.contrib._momo import MomoAdamState
from optax.contrib._momo import MomoState
from optax.contrib._muon import muon
from optax.contrib._muon import scale_by_muon
from optax.contrib._muon import MuonState
from optax.contrib._privacy import differentially_private_aggregate
from optax.contrib._privacy import DifferentiallyPrivateAggregateState
from optax.contrib._privacy import dpsgd
1 change: 1 addition & 0 deletions optax/contrib/_common_test.py
@@ -44,6 +44,7 @@
dict(opt_name='dowg', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='muon', opt_kwargs=dict(learning_rate=1e-3)),
dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)),
dict(
opt_name='schedule_free_sgd',
126 changes: 126 additions & 0 deletions optax/contrib/_muon.py
@@ -0,0 +1,126 @@
from typing import Any, List, NamedTuple, Optional, Tuple, Union

import chex
import jax
import jax.numpy as jnp

from optax import tree_utils as otu
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import transform
from optax._src import utils


class MuonState(NamedTuple):
  """State for the Muon algorithm."""

  count: chex.Array  # shape=(), dtype=jnp.int32.
  mu: base.Updates

def scale_by_muon(
    newton_schulz_coeffs: Union[
        Tuple[float, float, float], List[Tuple[float, float, float]]
    ] = (3.4445, -4.7750, 2.0315),
    newton_schulz_steps: Optional[int] = 5,
    momentum: float = 0.95,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = True,
) -> base.GradientTransformation:
  r"""Rescale updates according to the Muon algorithm.

  Muon is a variant of Shampoo that uses the Newton-Schulz method to
  orthogonalize the momentum accumulated by the optimizer. Mathematically,
  it does steepest descent under the Schatten-p norm, for some large p.
  With p=infty, it is equivalent to Shampoo without accumulation, or
  steepest descent under the spectral norm.

  References:
    Jordan, `modded-nanogpt
    <https://github.com/KellerJordan/modded-nanogpt>`_, 2024

  Args:
    newton_schulz_coeffs: Coefficients for the Newton-Schulz method, either a
      single triple used for every iteration or one triple per iteration.
    newton_schulz_steps: Number of Newton-Schulz iterations.
    momentum: Exponential decay rate to track the first moment of past
      gradients.
    mu_dtype: Data type of the momentum accumulator.
    nesterov: Whether to use Nesterov momentum.

  Returns:
    A `GradientTransformation` object.
  """
  muon_coeffs = jnp.asarray(
      newton_schulz_coeffs
      if isinstance(newton_schulz_coeffs, list)
      else [newton_schulz_coeffs] * newton_schulz_steps
  )
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def muon_iterator(x, abc):
    # One quintic Newton-Schulz step: x <- a*x + b*(x x^T)x + c*(x x^T)^2 x,
    # which pushes the singular values of x toward 1.
    a, b, c = abc
    xxt = x @ x.T
    return a * x + b * xxt @ x + c * xxt @ xxt @ x, None

  def orthogonalize(x):
    # Frobenius normalization bounds the spectral norm by 1, which the
    # Newton-Schulz iteration needs in order to converge.
    x = x / jnp.linalg.norm(x, ord='fro')
    x, _ = jax.lax.scan(muon_iterator, x, muon_coeffs)
    return x

  def init_fn(params):
    mu = otu.tree_zeros_like(params, dtype=mu_dtype)  # First moment
    return MuonState(count=jnp.zeros([], jnp.int32), mu=mu)

  def update_fn(updates, state, params=None):
    del params
    mu = otu.tree_update_moment(updates, state.mu, momentum, 1)
    count_inc = numerics.safe_int32_increment(state.count)
    if nesterov:
      mu_hat = jax.tree.map(
          lambda m, g: momentum * m + (1 - momentum) * g,
          otu.tree_bias_correction(
              mu, momentum, numerics.safe_int32_increment(count_inc)
          ),
          otu.tree_bias_correction(updates, momentum, count_inc),
      )
    else:
      mu_hat = otu.tree_bias_correction(mu, momentum, count_inc)
    # Orthogonalize each (matrix-shaped) momentum leaf independently.
    updates = jax.tree.map(orthogonalize, mu_hat)
    mu = otu.tree_cast(mu, mu_dtype)
    return updates, MuonState(count=count_inc, mu=mu)

  return base.GradientTransformation(init_fn, update_fn)
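
# Illustrative sketch (not part of the committed file): the effect of the
# Newton-Schulz loop in ``scale_by_muon`` on a single matrix. After Frobenius
# normalization, five quintic iterations with the default coefficients drive
# every singular value into roughly the range 0.7-1.2, i.e. they approximately
# replace the momentum matrix with its nearest semi-orthogonal matrix:
#
#   g = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
#   x = g / jnp.linalg.norm(g, ord='fro')
#   for a, b, c in [(3.4445, -4.7750, 2.0315)] * 5:
#     x = a * x + b * (x @ x.T) @ x + c * (x @ x.T) @ (x @ x.T) @ x
#   print(jnp.linalg.svd(x, compute_uv=False))  # singular values near 1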


def muon(
    learning_rate: base.ScalarOrSchedule,
    newton_schulz_coeffs: Union[
        Tuple[float, float, float], List[Tuple[float, float, float]]
    ] = (3.4445, -4.7750, 2.0315),
    newton_schulz_steps: Optional[int] = 5,
    momentum: float = 0.95,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = True,
) -> base.GradientTransformation:
  r"""Muon: Momentum Orthogonalized by Newton-Schulz.

  Muon is a variant of Shampoo that uses the Newton-Schulz method to
  orthogonalize the momentum accumulated by the optimizer. Mathematically,
  it does steepest descent under the Schatten-p norm, for some large p.
  With p=infty, it is equivalent to Shampoo without accumulation, or
  steepest descent under the spectral norm.

  References:
    Jordan, `modded-nanogpt
    <https://github.com/KellerJordan/modded-nanogpt>`_, 2024

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    newton_schulz_coeffs: Coefficients for the Newton-Schulz method, either a
      single triple used for every iteration or one triple per iteration.
    newton_schulz_steps: Number of Newton-Schulz iterations.
    momentum: Exponential decay rate to track the first moment of past
      gradients.
    mu_dtype: Data type of the momentum accumulator.
    nesterov: Whether to use Nesterov momentum.

  Returns:
    The corresponding `GradientTransformation`.
  """
  return combine.chain(
      scale_by_muon(
          newton_schulz_coeffs,
          newton_schulz_steps,
          momentum,
          mu_dtype,
          nesterov=nesterov,
      ),
      transform.scale_by_learning_rate(learning_rate),
  )
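
A minimal usage sketch of the new optimizer (an illustration, not part of this commit): `muon` chains `scale_by_muon` with `scale_by_learning_rate`, so it drops into the standard optax init/update loop. The Newton-Schulz step treats each momentum leaf as a matrix, so the toy parameter below is 2D; the variable names are hypothetical.

  import jax
  import jax.numpy as jnp
  import optax
  from optax import contrib

  # A toy matrix-shaped parameter; Muon orthogonalizes 2D momentum leaves.
  params = {'w': jnp.eye(4) + 0.1}

  def loss_fn(p):
    return jnp.sum(p['w'] ** 2)

  opt = contrib.muon(learning_rate=1e-3)
  opt_state = opt.init(params)

  grads = jax.grad(loss_fn)(params)
  updates, opt_state = opt.update(grads, opt_state)
  params = optax.apply_updates(params, updates)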
