A PyTorch implementation of the Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
⚠️ Important Notice: This codebase is largely generated by AI and is currently a work in progress. While the implementation follows the Muon optimizer specification, it has not been extensively tested in production environments. Use it with caution and test thoroughly before relying on it for critical applications.
Muon is an optimizer that combines standard SGD-momentum with an orthogonalization post-processing step: each 2D parameter's update is replaced with (approximately) its nearest semi-orthogonal matrix, computed by an efficient Newton-Schulz iteration. This tends to improve training stability and convergence for deep neural networks.
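The orthogonalization step is a fixed number of quintic Newton-Schulz iterations run in bfloat16. The sketch below illustrates the idea using the coefficients and epsilon cited in the technical details later in this document; it is a simplified illustration, not necessarily this repository's exact implementation (the helper name `newton_schulz_orthogonalize` is ours and is reused in later sketches).

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration (sketch)."""
    a, b, c = 3.4445, -4.7750, 2.0315        # quintic coefficients from the Muon write-up
    X = G.to(torch.bfloat16)
    transposed = G.size(-2) > G.size(-1)
    if transposed:                            # work on the wide orientation
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)   # keep the spectral norm <= ~1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
```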
- Orthogonalized Momentum Updates: Replaces parameter updates with their nearest orthogonal matrix
- Efficient Newton-Schulz Iteration: Fast orthogonalization using optimized quintic iteration
- Distributed Training Support: Full support for multi-GPU and multi-node training
- Hybrid Optimization: Combine Muon with AdamW for different parameter types
- Comprehensive Features: Error handling, type hints, and logging
- Stable bfloat16 Computation: Optimized for GPU training with mixed precision
# Install dependencies
pip install "torch>=2.7.1" numpy rich
# For benchmarking and visualization
pip install matplotlib plotly seaborn dash dash-bootstrap-components
# Development install
pip install -e .
import torch
import torch.nn as nn
from muon_optimizer import SingleDeviceMuon
# Create a simple model
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)
# Initialize optimizer
optimizer = SingleDeviceMuon(model.parameters(), lr=0.02, momentum=0.95)
# Training loop
for batch_idx, (data, target) in enumerate(dataloader):
    optimizer.zero_grad()
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()
    optimizer.step()
For best results, use Muon for matrix weights and AdamW for other parameters:
from muon_optimizer import SingleDeviceMuonWithAuxAdam, create_muon_param_groups
# Create parameter groups automatically
param_groups = create_muon_param_groups(
    model,
    muon_lr=0.02,
    adam_lr=3e-4,
    muon_momentum=0.95,
    weight_decay=0.01
)
# Initialize hybrid optimizer
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
from muon_optimizer import SingleDeviceMuonWithAuxAdam
# Separate parameters by dimension
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
param_groups = [
    # Muon for matrix parameters
    {
        "params": matrix_params,
        "lr": 0.02,
        "momentum": 0.95,
        "weight_decay": 0.01,
        "ns_steps": 5,
        "use_muon": True
    },
    # AdamW for scalar parameters (biases)
    {
        "params": scalar_params,
        "lr": 3e-4,
        "betas": (0.9, 0.95),
        "eps": 1e-10,
        "weight_decay": 0.01,
        "use_muon": False
    }
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
python minimalist_quadratic_optimization.py
This example demonstrates Muon optimization on simple 2D quadratic functions with visualization.
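A minimal version of what the script does, assuming a toy quadratic in a 2x2 matrix parameter (the matrices `A` and `B` below are arbitrary placeholders; the script's actual objective and plotting code differ):

```python
import torch
from muon_optimizer import SingleDeviceMuon

# Toy quadratic objective f(W) = 0.5 * ||A @ W - B||^2 in a 2x2 matrix parameter.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
B = torch.eye(2)
W = torch.nn.Parameter(torch.randn(2, 2))

optimizer = SingleDeviceMuon([W], lr=0.02, momentum=0.95)
for step in range(200):
    optimizer.zero_grad()
    loss = 0.5 * (A @ W - B).pow(2).sum()
    loss.backward()
    optimizer.step()
print(f"final loss: {loss.item():.6f}")
```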
python mnist_optimizer_benchmark.py
This comprehensive benchmark compares Muon against SGD and Adam optimizers on MNIST classification using a CNN.
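Roughly, the benchmark amounts to training the same model under several optimizers and comparing the resulting curves. The sketch below is illustrative only; the script's actual architecture, hyperparameters, and training loop may differ:

```python
import torch.nn as nn
import torch.optim as optim
from muon_optimizer import SingleDeviceMuonWithAuxAdam, create_muon_param_groups

def make_cnn():
    # Hypothetical MNIST CNN; the benchmark's model may differ.
    return nn.Sequential(
        nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(32 * 7 * 7, 10),
    )

# One factory per optimizer so every run starts from a fresh model.
optimizer_factories = {
    "sgd": lambda m: optim.SGD(m.parameters(), lr=0.01, momentum=0.9),
    "adam": lambda m: optim.Adam(m.parameters(), lr=3e-4),
    "muon+adamw": lambda m: SingleDeviceMuonWithAuxAdam(create_muon_param_groups(m)),
}
# Each configuration is then trained on MNIST with an identical loop and the
# loss/accuracy curves are compared.
```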
Single-device Muon optimizer for non-distributed training.
Parameters:
- `params`: Iterable of parameters to optimize
- `lr`: Learning rate (default: 0.02)
- `weight_decay`: Weight decay coefficient (default: 0)
- `momentum`: Momentum coefficient (default: 0.95)
- `ns_steps`: Newton-Schulz iteration steps (default: 5)
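For reference, a constructor call with every argument spelled out (keyword names are assumed to match the parameter list above; `model` is the network from the quick-start example):

```python
from muon_optimizer import SingleDeviceMuon

optimizer = SingleDeviceMuon(
    model.parameters(),
    lr=0.02,           # learning rate (default)
    weight_decay=0.0,  # weight decay coefficient (default)
    momentum=0.95,     # momentum coefficient (default)
    ns_steps=5,        # Newton-Schulz iteration steps (default)
)
```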
Distributed Muon optimizer for multi-GPU training. Requires `torch.distributed` initialization.
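The distributed variant expects the process group to be set up before the optimizer is constructed. A minimal setup sketch, assuming a standard `torchrun` launch (the environment-variable handling below is the usual torchrun convention, not something specific to this repository):

```python
import os
import torch
import torch.distributed as dist

# Typically launched with: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# ... build the model on this device, then construct the distributed Muon
#     optimizer the same way as the single-device examples above ...
```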
Hybrid optimizer combining Muon and AdamW for different parameter groups.
Parameters:
- `param_groups`: List of parameter group dictionaries, each carrying a `use_muon` boolean flag
Distributed version of the hybrid optimizer.
Automatically create parameter groups from a PyTorch model.
Parameters:
- `model`: PyTorch model to extract parameters from
- `muon_lr`: Learning rate for Muon parameters (default: 0.02)
- `adam_lr`: Learning rate for AdamW parameters (default: 3e-4)
- `muon_momentum`: Momentum for Muon parameters (default: 0.95)
- `muon_ns_steps`: Newton-Schulz steps for Muon (default: 5)
- `adam_betas`: Beta parameters for AdamW (default: (0.9, 0.95))
- `weight_decay`: Weight decay coefficient (default: 0)
- `eps`: Epsilon for AdamW numerical stability (default: 1e-10)
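Any of these keyword arguments can be overridden when the defaults are not a good fit; for example (the values here are purely illustrative):

```python
from muon_optimizer import create_muon_param_groups, SingleDeviceMuonWithAuxAdam

param_groups = create_muon_param_groups(
    model,
    muon_lr=0.05,
    adam_lr=1e-3,
    muon_momentum=0.9,
    muon_ns_steps=6,
    adam_betas=(0.9, 0.99),
    weight_decay=0.0,
    eps=1e-8,
)
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
```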
Compute orthogonalization using Newton-Schulz iteration.
Perform a single Muon update step with orthogonalization.
Perform a single Adam update step.
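Put together, one Muon update for a single 2D parameter looks roughly like the sketch below (a conceptual outline reusing the `newton_schulz_orthogonalize` helper sketched in the overview; the function name and exact formulation are ours, not the module's internals):

```python
import torch

@torch.no_grad()
def muon_step_sketch(param, grad, buf, lr, momentum, weight_decay, ns_steps):
    # 1. SGD-style momentum accumulation (Nesterov-flavoured blend).
    buf.mul_(momentum).add_(grad)
    update = grad.add(buf, alpha=momentum)
    # 2. Replace the update with its (approximate) nearest semi-orthogonal matrix.
    update = newton_schulz_orthogonalize(update, steps=ns_steps)
    # 3. Rescale for the matrix shape, apply decoupled weight decay, take the step.
    update *= max(1.0, param.size(-2) / param.size(-1)) ** 0.5
    param.mul_(1 - lr * weight_decay)
    param.add_(update, alpha=-lr)
```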
- Use Muon for: 2D+ matrix parameters (linear layers, conv weights)
- Use AdamW for: 1D parameters (biases), embeddings, output layers
- Muon parameters: Start with `lr=0.02`, adjust based on convergence
- AdamW parameters: Use standard rates like `3e-4`
- Muon requires momentum buffers (similar to SGD with momentum)
- Orthogonalization uses bfloat16 for GPU efficiency
- Parameters are automatically sorted by size for distributed efficiency
# Run all tests
python -m pytest muon_optimizer_test.py -v
# Run specific test class
python -m pytest muon_optimizer_test.py::TestMuonOptimizer -v
# Run with coverage
python -m pytest muon_optimizer_test.py --cov=muon_optimizer
# Format code
black muon_optimizer.py
# Sort imports
isort muon_optimizer.py
# Check linting
flake8 muon_optimizer.py
# Type checking
mypy muon_optimizer.py
- Distributed Training: Uses `torch.distributed.all_gather()` for parameter synchronization
- Numerical Stability: All orthogonalization performed in bfloat16 with a 1e-7 epsilon
- Quintic Newton-Schulz: Optimized coefficients (3.4445, -4.7750, 2.0315) for convergence
- Automatic Reshaping: 4D conv weights reshaped to 2D for orthogonalization (see the sketch after this list)
- Parameter Sorting: Automatically sorted by size for efficient distributed processing
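The conv-weight reshaping amounts to flattening everything except the output-channel dimension before orthogonalizing, then restoring the original shape. A small sketch (relying on the hypothetical `newton_schulz_orthogonalize` helper from the overview):

```python
import torch

w = torch.randn(64, 32, 3, 3)        # (out_channels, in_channels, kH, kW)
w2d = w.reshape(w.size(0), -1)       # -> (64, 288): one row per output filter
w2d = newton_schulz_orthogonalize(w2d, steps=5)
w_restored = w2d.reshape_as(w)       # back to the original 4D shape
```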
- Requires PyTorch >= 2.7.1 and Python >= 3.13
- Orthogonalization only applies to 2D+ parameters
- Distributed mode requires proper `torch.distributed` initialization
- Performance not extensively benchmarked on large-scale models
If you use this implementation in your research, please cite the original Muon write-up:
@misc{muon_optimizer,
  title={Muon Optimizer: MomentUm Orthogonalized by Newton-schulz},
  author={Keller Jordan and contributors},
  year={2024},
  url={https://kellerjordan.github.io/posts/muon/}
}
MIT License - see LICENSE file for details.
- Original Muon implementation by Keller Jordan
- Contributions from @scottjmaddox, @YouJiacheng, @jxbz, @leloykun
- PyTorch team for the framework