Muon Optimizer #2601
-
@caojiaolong I was pondering doing it, but it's in torch now https://github.com/pytorch/pytorch/blob/main/torch/optim/_muon.py ... so that lowers the priority for adding it to timm unless there are modes/features missing from the torch version. EDIT: I will add the mapping to timm to allow using torch.optim.Muon with the timm factory. One concern with the original: I had some questions regarding compatibility with convnet weight layouts, as the dim handling isn't always friendly with convs and transpose convs. I think some of that might have been fixed; I have to try out the torch version, haven't had a chance yet.
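To illustrate the layout concern, here is a minimal sketch (not timm or torch internals): Conv2d stores weights as (out_ch, in_ch, kH, kW) while ConvTranspose2d stores (in_ch, out_ch, kH, kW), so a naive flatten orthogonalizes over the wrong axis for transpose convs.

```python
import torch

# Minimal sketch (not timm/torch internals) of the conv layout issue noted above.
conv_w = torch.randn(64, 32, 3, 3)    # Conv2d weight: (out_ch, in_ch, kH, kW)
deconv_w = torch.randn(32, 64, 3, 3)  # ConvTranspose2d weight: (in_ch, out_ch, kH, kW)

# A naive flatten assumes dim 0 is the fan-out axis, which only holds for Conv2d.
conv_2d = conv_w.reshape(conv_w.shape[0], -1)        # (64, 288), as intended
deconv_2d = deconv_w.reshape(deconv_w.shape[0], -1)  # (32, 576), axes effectively swapped

# Transpose convs need dims 0/1 swapped first so the orthogonalized matrix is
# again (fan_out, fan_in * kH * kW).
deconv_fixed = deconv_w.transpose(0, 1).reshape(deconv_w.shape[1], -1)  # (64, 288)
print(conv_2d.shape, deconv_2d.shape, deconv_fixed.shape)
```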
-
I could never get Shampoo and SOAP working well (i.e., never found decent hparams for vision tasks). Another related optimizer, Kron, works quite well though.
-
@caojiaolong looking more closely, surprisingly the PyTorch impl of Muon only supports 2D parameters... so I guess that's not going to work. I figured they'd generalize it better if it was going to be included there...
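For context on why a 2D-only optimizer is limiting for vision models, a quick check (torchvision's resnet18 used here purely as an example):

```python
import torch
import torchvision

# Most weight tensors in a typical convnet are 4D conv kernels; only the final
# fc layer is a 2D matrix, so a 2D-only Muon covers very little of the model.
model = torchvision.models.resnet18()
ndims = sorted({p.ndim for p in model.parameters()})
n_2d = sum(1 for p in model.parameters() if p.ndim == 2)
n_4d = sum(1 for p in model.parameters() if p.ndim == 4)
print(ndims, n_2d, n_4d)  # e.g. [1, 2, 4] with a single 2D weight and ~20 4D conv weights
```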
-
@caojiaolong I pushed an impl I had sitting around closer to completion and created a PR (#2596), still testing some things. It's based on Keller's, like many of the others, but I added/integrated a few additional options and sped up the NS iteration a bit without getting too crazy. It should work decently with convnets and hybrid vit-cnn models, and it has two modes for that: flattening like Keller's, and another option that treats the spatial kernel dims as a batch dim for the NS iterations.
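A rough sketch of what those two conv-handling modes could look like (illustrative reshapes only; the names and signatures here are not the actual timm API):

```python
import torch

def conv_to_muon_matrices(w: torch.Tensor, mode: str = "flatten") -> torch.Tensor:
    """Reshape a (O, I, kH, kW) conv weight into 2D matrices for a Newton-Schulz step."""
    o, i, kh, kw = w.shape
    if mode == "flatten":
        # Keller-style: one (O, I*kH*kW) matrix per layer.
        return w.reshape(o, i * kh * kw)
    if mode == "spatial_batch":
        # Treat the kH*kW spatial positions as a batch of (O, I) matrices,
        # so the NS iterations run batched over kernel positions.
        return w.permute(2, 3, 0, 1).reshape(kh * kw, o, i)
    raise ValueError(f"unknown mode: {mode}")

w = torch.randn(64, 32, 3, 3)
print(conv_to_muon_matrices(w, "flatten").shape)        # torch.Size([64, 288])
print(conv_to_muon_matrices(w, "spatial_batch").shape)  # torch.Size([9, 64, 32])
```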
-
I tweaked the heuristics that assign params to Muon vs AdamW/NAdamW updates. This seemed to improve behaviour with convnets like EfficientNets / MobileNets that have lots of depthwise convs. I think I'm ready to merge the initial version, but if anyone following wants to try it, feedback is welcome.
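That kind of routing heuristic could look roughly like the sketch below (illustrative only; the actual rules in the PR may differ):

```python
import torch
import torch.nn as nn

def use_muon(name: str, p: torch.Tensor) -> bool:
    """Toy heuristic: route a parameter to Muon or to the AdamW/NAdamW fallback."""
    if p.ndim < 2:
        return False  # biases, norm scales, scalars -> fallback optimizer
    if any(k in name for k in ("head", "classifier", "embed")):
        return False  # final projection / embeddings -> fallback optimizer
    if p.ndim == 4 and p.shape[1] == 1:
        return False  # depthwise convs (one input channel per group) -> fallback optimizer
    return True

# Toy container, used only to enumerate parameters.
model = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32), nn.Linear(32, 10))
muon = [n for n, p in model.named_parameters() if use_muon(n, p)]
fallback = [n for n, p in model.named_parameters() if not use_muon(n, p)]
print(muon)      # ['0.weight', '2.weight']
print(fallback)  # ['0.bias', '1.weight', '1.bias', '2.bias']
```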
-
Hello @rwightman, thank you! I have tried the Muon optimizer on a private self-supervised learning task, but observed a performance gap compared to AdamW (frankly, it was worse), though that is probably task-specific. I'll try more tasks to see how Muon performs. Thanks for your quick response to the request for Muon!
-
@rwightman Thank you for your excellent work. Training details can be found in this W&B report.

export OMP_NUM_THREADS=2
export MKL_NUM_THREADS=2
# export HF_DATASETS_IN_MEMORY_MAX_SIZE=50240000
MODEL=convnext_tiny # drop-path 0.1, 0.1, 0.15, 0.2 for T, S, B, L
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=31929 bash distributed_train.sh 8 \
    `# Dataset parameters` \
    --dataset hfds//ImageNet_arrow_rgbdpa_new/ \
    --data-dir /.cache/huggingface/datasets/ \
    --train-split train \
    --val-split validation \
    --input-img-mode RGB \
    --input-key rgb \
    $() \
    `# Model parameters` \
    --model $MODEL \
    --num-classes 1000 \
    --input-size 3 224 224 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --batch-size 256 \
    --grad-accum-steps 2 \
    $() \
    `# Scripting / codegen` \
    `#--torchcompile inductor` \
    $() \
    `# Device & distributed` \
    --amp \
    $() \
    `# Optimizer parameters` \
    --opt muon \
    --weight-decay 0.05 \
    --opt-kwargs "fallback_list=['head.*', 'stem.*']" \
    $() \
    `# Learning rate schedule parameters` \
    --sched-on-updates \
    --lr-base 0.004 \
    --lr-base-size 4096 \
    --warmup-lr 1e-6 \
    --epochs 300 \
    --warmup-epochs 20 \
    $() \
    `# Augmentation & regularization parameters` \
    --aa rand-m9-mstd0.5-inc1 \
    --reprob 0.25 \
    --remode pixel \
    --cutmix 1.0 \
    --mixup 0.8 \
    --drop-path 0.1 \
    $() \
    `# Model Exponential Moving Average` \
    --model-ema \
    --model-ema-decay 0.9999 \
    --model-ema-warmup \
    $() \
    `# Misc` \
    --seed 42 \
    --log-interval 1 \
    --workers 8 \
    --pin-mem \
    --output output/train \
    --experiment convnext_tiny_rep_nmuon_rgb \
    --use-multi-epochs-loader \
    --log-wandb \
    --wandb-project RGBDPretrain

I'm now also experimenting with nMuon, and will share the results once it's done.
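As a rough sanity check of the effective learning rate implied by those flags (assuming timm's linear lr-base scaling, lr = lr_base * global_batch / lr_base_size, is what applies when --lr is not given):

```python
# Effective global batch and base-LR scaling implied by the command above
# (assumes linear lr-base scaling; the actual train.py logic may differ).
batch_size_per_gpu = 256
num_gpus = 8
grad_accum_steps = 2
global_batch = batch_size_per_gpu * num_gpus * grad_accum_steps  # 4096

lr_base, lr_base_size = 0.004, 4096
lr = lr_base * global_batch / lr_base_size
print(global_batch, lr)  # 4096 0.004 -> the base LR is used unscaled here
```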
-
@caojiaolong @sjiang95 so far, this mirrors my experience... early convergence is notably faster with Muon, but over a typical training schedule AdamW or NAdamW ends up winning by a bit. In larger scale LLM training, I feel you're often in a more 'undertrained' state where this convergence behaviour is highly beneficial, vs doing hundreds of epochs for smaller vision models on smaller data. It does seem like the LR could/should be pushed a bit higher for Muon, so I'm trying that right now. I've been leaving the stem as Muon but the final head projection as AdamW. I don't think there's anything inherently wrong with my Muon impl, or Muon generally. But as usual, it's hard to beat good ol' AdamW. Even with a theoretically better optimizer, you have to run through a whole search of LR, weight decay, and schedule to be 'fair' against the familiar hparams for AdamW.
-
I think a better vision test would be pretraining a much larger ViT model, but I don't currently have free hardware resources to give that a spin. I'm going to convert this to a discussion for ongoing updates...

-
The recently introduced [Muon Optimizer](https://github.com/KellerJordan/Muon) has shown promising results, claiming to outperform commonly used optimizers such as AdamW, Shampoo, and SOAP ([reference benchmarks](https://github.com/KellerJordan/modded-nanogpt/tree/master/records/102924_Optimizers)).
Given its potential advantages, it would be valuable to have Muon integrated into the timm library for broader experimentation and adoption. Is there any plan or interest in adding support for the Muon Optimizer to timm?