
Muon Optimizer


A PyTorch implementation of the Muon optimizer - MomentUm Orthogonalized by Newton-Schulz.

⚠️ Important Notice: This codebase is largely AI-generated and is currently a work in progress. While the implementation follows the Muon optimizer specification, it has not been extensively tested in production environments. Use it with caution and test it thoroughly before relying on it in critical applications.

Overview

Muon is an optimizer that combines standard SGD-momentum with an orthogonalization post-processing step: each 2D parameter's update is replaced with its nearest semi-orthogonal matrix, computed by an efficient Newton-Schulz iteration. This provides better training stability and convergence for deep neural networks.
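Conceptually, each step is SGD with momentum whose update matrix is orthogonalized before being applied. A rough per-parameter sketch of that idea (illustrative only; it assumes zeropower_via_newtonschulz5 is importable from muon_optimizer as listed under Core Functions below):

import torch
from muon_optimizer import zeropower_via_newtonschulz5  # assumed export, see Core Functions

def sketch_muon_step(W, grad, buf, lr=0.02, beta=0.95, ns_steps=5):
    # Accumulate SGD-momentum, orthogonalize the resulting 2D update, then apply it.
    buf.mul_(beta).add_(grad)
    update = zeropower_via_newtonschulz5(buf, ns_steps)
    W.add_(update.to(W.dtype), alpha=-lr)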

Key Features

  • Orthogonalized Momentum Updates: Replaces each matrix parameter's update with its nearest semi-orthogonal matrix
  • Efficient Newton-Schulz Iteration: Fast orthogonalization using optimized quintic iteration
  • Distributed Training Support: Full support for multi-GPU and multi-node training
  • Hybrid Optimization: Combine Muon with AdamW for different parameter types
  • Comprehensive Features: Error handling, type hints, and logging
  • Stable bfloat16 Computation: Optimized for GPU training with mixed precision

Installation

# Install dependencies
pip install "torch>=2.7.1" numpy rich

# For benchmarking and visualization
pip install matplotlib plotly seaborn dash dash-bootstrap-components

# Development install
pip install -e .

Quick Start

Basic Usage

import torch
import torch.nn as nn
from muon_optimizer import SingleDeviceMuon

# Create a simple model
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)

# Initialize optimizer
optimizer = SingleDeviceMuon(model.parameters(), lr=0.02, momentum=0.95)

# Training loop
criterion = nn.CrossEntropyLoss()
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
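The loop above assumes a dataloader already exists; for a quick smoke test, one can be built from synthetic tensors whose shapes match the 784-input model (purely illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data just to exercise the training loop; shapes match the model above.
inputs = torch.randn(1024, 784)
targets = torch.randint(0, 10, (1024,))
dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=64, shuffle=True)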

Hybrid Optimization

For best results, use Muon for matrix weights and AdamW for other parameters:

from muon_optimizer import SingleDeviceMuonWithAuxAdam, create_muon_param_groups

# Create parameter groups automatically
param_groups = create_muon_param_groups(
    model,
    muon_lr=0.02,
    adam_lr=3e-4,
    muon_momentum=0.95,
    weight_decay=0.01
)

# Initialize hybrid optimizer
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
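To sanity-check how parameters were split, the returned groups can be inspected through the use_muon flag that create_muon_param_groups is expected to set on each group (illustrative only):

for group in param_groups:
    kind = "muon" if group["use_muon"] else "adamw"
    tensors = group["params"]
    print(f"{kind}: {len(tensors)} tensors, {sum(p.numel() for p in tensors):,} parameters")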

Manual Parameter Grouping

from muon_optimizer import SingleDeviceMuonWithAuxAdam

# Separate parameters by dimension
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
scalar_params = [p for p in model.parameters() if p.ndim < 2]

param_groups = [
    # Muon for matrix parameters
    {
        "params": matrix_params,
        "lr": 0.02,
        "momentum": 0.95,
        "weight_decay": 0.01,
        "ns_steps": 5,
        "use_muon": True
    },
    # AdamW for 1D parameters (biases)
    {
        "params": scalar_params,
        "lr": 3e-4,
        "betas": (0.9, 0.95),
        "eps": 1e-10,
        "weight_decay": 0.01,
        "use_muon": False
    }
]

optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

Examples

Run the Quadratic Optimization Demo

python minimalist_quadratic_optimization.py

This example demonstrates Muon optimization on simple 2D quadratic functions with visualization.

Run the MNIST Benchmark

python mnist_optimizer_benchmark.py

This comprehensive benchmark compares Muon against SGD and Adam optimizers on MNIST classification using a CNN.

API Reference

Core Classes

SingleDeviceMuon(params, lr=0.02, weight_decay=0, momentum=0.95, ns_steps=5)

Single-device Muon optimizer for non-distributed training.

Parameters:

  • params: Iterable of parameters to optimize
  • lr: Learning rate (default: 0.02)
  • weight_decay: Weight decay coefficient (default: 0)
  • momentum: Momentum coefficient (default: 0.95)
  • ns_steps: Newton-Schulz iteration steps (default: 5)

Muon(params, lr=0.02, weight_decay=0, momentum=0.95, ns_steps=5)

Distributed Muon optimizer for multi-GPU training. Requires torch.distributed initialization.

SingleDeviceMuonWithAuxAdam(param_groups)

Hybrid optimizer combining Muon and AdamW for different parameter groups.

Parameters:

  • param_groups: List of parameter group dictionaries with use_muon boolean flag

MuonWithAuxAdam(param_groups)

Distributed version of the hybrid optimizer.
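Both distributed variants assume the default process group is initialized before the optimizer is constructed. A minimal launch sketch (build_model() is a hypothetical helper; launch with torchrun):

import os
import torch
import torch.distributed as dist
from muon_optimizer import Muon

# Run with: torchrun --nproc_per_node=8 train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)  # build_model() is a placeholder for your model setup
optimizer = Muon(model.parameters(), lr=0.02, momentum=0.95)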

Utility Functions

create_muon_param_groups(model, muon_lr=0.02, adam_lr=3e-4, ...)

Automatically create parameter groups from a PyTorch model.

Parameters:

  • model: PyTorch model to extract parameters from
  • muon_lr: Learning rate for Muon parameters (default: 0.02)
  • adam_lr: Learning rate for AdamW parameters (default: 3e-4)
  • muon_momentum: Momentum for Muon parameters (default: 0.95)
  • muon_ns_steps: Newton-Schulz steps for Muon (default: 5)
  • adam_betas: Beta parameters for AdamW (default: (0.9, 0.95))
  • weight_decay: Weight decay coefficient (default: 0)
  • eps: Epsilon for AdamW numerical stability (default: 1e-10)

Core Functions

zeropower_via_newtonschulz5(G, steps)

Compute orthogonalization using Newton-Schulz iteration.
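A condensed sketch of the iteration, using the bfloat16, epsilon, and quintic coefficient choices listed under Architecture Notes (a simplified restatement, not necessarily line-for-line identical to the shipped implementation):

import torch

def newtonschulz5_sketch(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration that pushes the singular values of G toward 1.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    transpose = G.size(-2) > G.size(-1)
    if transpose:                         # iterate on the wide orientation
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bring the spectral norm under 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.mT if transpose else X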

muon_update(grad, momentum, beta, ns_steps, nesterov)

Perform a single Muon update step with orthogonalization.
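A hedged sketch of a single update step, reusing newtonschulz5_sketch from the previous snippet and the 4D-reshape rule from Architecture Notes (the momentum blend shown is one common formulation; the shipped code may additionally rescale the result):

def muon_update_sketch(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    # Blend the gradient into the momentum buffer, optionally apply a Nesterov lookahead,
    # flatten conv weights to 2D, then orthogonalize the result.
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp(momentum, beta) if nesterov else momentum
    if update.ndim == 4:
        update = update.view(update.size(0), -1)
    return newtonschulz5_sketch(update, steps=ns_steps)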

adam_update(grad, buf1, buf2, step, betas, eps)

Perform a single Adam update step.

Best Practices

Parameter Selection

  • Use Muon for: 2D+ matrix parameters (linear layers, conv weights)
  • Use AdamW for: 1D parameters (biases), embeddings, output layers
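Embeddings and output layers are 2D, so splitting by ndim alone (as in the manual example above) would send them to Muon; splitting by name is one way to follow this guidance. A sketch (the keyword substrings are assumptions about your model's parameter names):

def split_params_for_muon(model, adamw_keywords=("embed", "lm_head", "head", "bias", "norm")):
    # Route <2D tensors and anything matching the keywords to AdamW; the rest to Muon.
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        if p.ndim < 2 or any(k in name for k in adamw_keywords):
            adamw_params.append(p)
        else:
            muon_params.append(p)
    return muon_params, adamw_params

The two lists can then be dropped into the param_groups layout shown under Manual Parameter Grouping.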

Learning Rate Guidelines

  • Muon parameters: Start with lr=0.02, adjust based on convergence
  • AdamW parameters: Use standard rates like 3e-4

Memory Considerations

  • Muon requires momentum buffers (similar to SGD with momentum)
  • Orthogonalization uses bfloat16 for GPU efficiency
  • Parameters are automatically sorted by size for distributed efficiency

Testing

# Run all tests
python -m pytest muon_optimizer_test.py -v

# Run specific test class
python -m pytest muon_optimizer_test.py::TestMuonOptimizer -v

# Run with coverage
python -m pytest muon_optimizer_test.py --cov=muon_optimizer

Development

Code Quality Tools

# Format code
black muon_optimizer.py

# Sort imports
isort muon_optimizer.py

# Check linting
flake8 muon_optimizer.py

# Type checking
mypy muon_optimizer.py

Architecture Notes

  • Distributed Training: Uses torch.distributed.all_gather() for parameter synchronization
  • Numerical Stability: All orthogonalization performed in bfloat16 with 1e-7 epsilon
  • Quintic Newton-Schulz: Optimized coefficients (3.4445, -4.7750, 2.0315) for convergence
  • Automatic Reshaping: 4D conv weights reshaped to 2D for orthogonalization
  • Parameter Sorting: Automatically sorted by size for efficient distributed processing

Known Limitations

  • Requires PyTorch >= 2.7.1 and Python >= 3.13
  • Orthogonalization only applies to 2D+ parameters
  • Distributed mode requires proper torch.distributed initialization
  • Performance not extensively benchmarked on large-scale models

Citation

If you use this implementation in your research, please cite the original Muon paper:

@misc{muon_optimizer,
  title={Muon Optimizer: MomentUm Orthogonalized by Newton-Schulz},
  author={Keller Jordan and contributors},
  year={2024},
  url={https://kellerjordan.github.io/posts/muon/}
}

License

MIT License - see LICENSE file for details.

Acknowledgments

  • Original Muon implementation by Keller Jordan
  • Contributions from @scottjmaddox, @YouJiacheng, @jxbz, @leloykun
  • PyTorch team for the framework
