diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst
index fdee8656..dacc57c5 100644
--- a/docs/api/contrib.rst
+++ b/docs/api/contrib.rst
@@ -25,6 +25,8 @@ Experimental features and algorithms that don't meet the
     MomoState
     momo_adam
     MomoAdamState
+    muon
+    MuonState
     prodigy
     ProdigyState
     sam
@@ -82,6 +84,12 @@ Momo
 .. autofunction:: momo_adam
 .. autoclass:: MomoAdamState
 
+Muon
+~~~~
+.. autofunction:: muon
+.. autofunction:: scale_by_muon
+.. autoclass:: MuonState
+
 Prodigy
 ~~~~~~~
 .. autofunction:: prodigy
diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py
index a310cc23..93c84e49 100644
--- a/optax/contrib/__init__.py
+++ b/optax/contrib/__init__.py
@@ -35,6 +35,9 @@
 from optax.contrib._momo import momo_adam
 from optax.contrib._momo import MomoAdamState
 from optax.contrib._momo import MomoState
+from optax.contrib._muon import muon
+from optax.contrib._muon import scale_by_muon
+from optax.contrib._muon import MuonState
 from optax.contrib._privacy import differentially_private_aggregate
 from optax.contrib._privacy import DifferentiallyPrivateAggregateState
 from optax.contrib._privacy import dpsgd
diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py
index b2011848..81f379d1 100644
--- a/optax/contrib/_common_test.py
+++ b/optax/contrib/_common_test.py
@@ -44,6 +44,7 @@
     dict(opt_name='dowg', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)),
+    dict(opt_name='muon', opt_kwargs=dict(learning_rate=1e-3)),
     dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)),
     dict(
         opt_name='schedule_free_sgd',
diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py
new file mode 100644
index 00000000..c78b1e61
--- /dev/null
+++ b/optax/contrib/_muon.py
@@ -0,0 +1,141 @@
+from typing import Any, List, NamedTuple, Optional, Tuple, Union
+
+import chex
+import jax
+import jax.numpy as jnp
+
+from optax import tree_utils as otu
+from optax._src import base
+from optax._src import combine
+from optax._src import numerics
+from optax._src import transform
+from optax._src import utils
+
+
+class MuonState(NamedTuple):
+  """State for the Muon algorithm."""
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+  mu: base.Updates
+
+
+def scale_by_muon(
+    newton_schulz_coeffs: Union[
+        Tuple[float, float, float], List[Tuple[float, float, float]]
+    ] = (3.4445, -4.7750, 2.0315),
+    newton_schulz_steps: Optional[int] = 5,
+    momentum: float = 0.95,
+    mu_dtype: Optional[chex.ArrayDType] = None,
+    *,
+    nesterov: bool = True,
+) -> base.GradientTransformation:
+  r"""Rescale updates according to the Muon algorithm.
+
+  Muon is a variant of Shampoo that uses the Newton-Schulz method to
+  orthogonalize the momentum accumulated by the optimizer. Mathematically, it
+  does steepest descent under the Schatten-p norm, for some large p. With
+  p=infty, it is equivalent to Shampoo without accumulation, or steepest
+  descent under the spectral norm.
+
+  References:
+    Jordan, `modded-nanogpt
+    <https://github.com/KellerJordan/modded-nanogpt>`_, 2024
+
+  Args:
+    newton_schulz_coeffs: Coefficients for the Newton-Schulz method.
+    newton_schulz_steps: Number of Newton-Schulz iterations.
+    momentum: Exponential decay rate to track the first moment of past
+      gradients.
+    mu_dtype: Data type of the momentum accumulator.
+    nesterov: Whether to use Nesterov momentum.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+  muon_coeffs = jnp.asarray(
+      newton_schulz_coeffs
+      if isinstance(newton_schulz_coeffs, list)
+      else [newton_schulz_coeffs] * newton_schulz_steps
+  )
+  # One Newton-Schulz iteration: x <- a*x + b*(x x^T)x + c*(x x^T)^2 x.
+  muon_iterator = lambda x, abc: (
+      abc[0]*x + abc[1]*(x@x.T)@x + abc[2]*(x@x.T)@(x@x.T)@x, 0
+  )
+  mu_dtype = utils.canonicalize_dtype(mu_dtype)
+
+  def init_fn(params):
+    mu = otu.tree_zeros_like(params, dtype=mu_dtype)  # First moment
+    return MuonState(count=jnp.zeros([], jnp.int32), mu=mu)
+
+  def update_fn(updates, state, params=None):
+    del params
+    mu = otu.tree_update_moment(updates, state.mu, momentum, 1)
+    count_inc = numerics.safe_int32_increment(state.count)
+    if nesterov:
+      mu_hat = jax.tree.map(
+          lambda m, g: momentum * m + (1 - momentum) * g,
+          otu.tree_bias_correction(
+              mu, momentum, numerics.safe_int32_increment(count_inc)
+          ),
+          otu.tree_bias_correction(updates, momentum, count_inc),
+      )
+    else:
+      mu_hat = otu.tree_bias_correction(mu, momentum, count_inc)
+
+    def orthogonalize(x):
+      # Normalize each (2D) leaf, then run the Newton-Schulz iteration on it.
+      x = x / jnp.linalg.norm(x, ord='fro')
+      x, _ = jax.lax.scan(muon_iterator, x, muon_coeffs)
+      return x
+
+    updates = jax.tree.map(orthogonalize, mu_hat)
+    mu = otu.tree_cast(mu, mu_dtype)
+    return updates, MuonState(count=count_inc, mu=mu)
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+def muon(
+    learning_rate: base.ScalarOrSchedule,
+    newton_schulz_coeffs: Union[
+        Tuple[float, float, float], List[Tuple[float, float, float]]
+    ] = (3.4445, -4.7750, 2.0315),
+    newton_schulz_steps: Optional[int] = 5,
+    momentum: float = 0.95,
+    mu_dtype: Optional[Any] = None,
+    *,
+    nesterov: bool = True,
+) -> base.GradientTransformation:
+  r"""Muon: Momentum Orthogonalized by Newton-Schulz.
+
+  Muon is a variant of Shampoo that uses the Newton-Schulz method to
+  orthogonalize the momentum accumulated by the optimizer. Mathematically, it
+  does steepest descent under the Schatten-p norm, for some large p. With
+  p=infty, it is equivalent to Shampoo without accumulation, or steepest
+  descent under the spectral norm.
+
+  References:
+    Jordan, `modded-nanogpt
+    <https://github.com/KellerJordan/modded-nanogpt>`_, 2024
+
+  Args:
+    learning_rate: A global scaling factor, either fixed or evolving along
+      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
+    newton_schulz_coeffs: Coefficients for the Newton-Schulz method.
+    newton_schulz_steps: Number of Newton-Schulz iterations.
+    momentum: Exponential decay rate to track the first moment of past
+      gradients.
+    mu_dtype: Data type of the momentum accumulator.
+    nesterov: Whether to use Nesterov momentum.
+
+  Returns:
+    The corresponding `GradientTransformation`.
+  """
+  return combine.chain(
+      scale_by_muon(
+          newton_schulz_coeffs,
+          newton_schulz_steps,
+          momentum,
+          mu_dtype,
+          nesterov=nesterov,
+      ),
+      transform.scale_by_learning_rate(learning_rate),
+  )
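For reviewers, a minimal usage sketch of the new optimizer (not part of the patch). It assumes the export lands as optax.contrib.muon, as the __init__.py hunk suggests; the toy weight matrix and quadratic loss below are made up for illustration, and the Newton-Schulz step operates on 2D parameters.

    import jax
    import jax.numpy as jnp
    import optax

    # Hypothetical 2D weight matrix; Muon orthogonalizes 2D momentum, so
    # vector or scalar parameters would need a different transform.
    params = jnp.ones((4, 8))

    opt = optax.contrib.muon(learning_rate=1e-3)
    state = opt.init(params)

    def loss(w):
      return jnp.sum(w ** 2)  # toy quadratic loss

    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)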