Summary
Request to investigate integrating FlashOptim into Megatron Core's optimizer infrastructure. FlashOptim reduces AdamW memory from 16 to 7 bytes/parameter (5 with gradient release) by combining ULP-bounded master weight splitting with companded 8-bit optimizer state quantization — with no measurable quality loss and no wall-clock overhead.
Motivation
Optimizer states dominate training memory for large models. FlashOptim addresses this with two complementary techniques:
Master Weight Splitting: Splits FP32 weights into BF16 + INT8 error correction. Key insight: reconstruction error is bounded by the unit of least precision (ULP), so the exponent bits in the correction term are redundant. Achieves ~24-bit effective precision with bitwise-perfect reconstruction in 99.92% of values.
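The splitting idea can be sketched in NumPy. This is an illustrative analog, not FlashOptim's API: it splits FP32 into FP16 + INT8 (NumPy has no bfloat16), and `split_master_weight` / `reconstruct` are hypothetical names.

```python
import numpy as np

def split_master_weight(master):
    # High half: round the FP32 master weight to half precision.
    hi = master.astype(np.float16)
    residual = master.astype(np.float32) - hi.astype(np.float32)
    # The residual is at most half a ULP of the half-precision value,
    # and that ULP is determined by hi's exponent -- which is why the
    # correction term needs no exponent bits of its own. FP16 has a
    # 10-bit mantissa, so ULP(|x|) is roughly |x| * 2**-10, floored at
    # the smallest subnormal.
    ulp = np.maximum(np.abs(hi.astype(np.float32)) * 2.0**-10, 2.0**-24)
    # Store the residual as a signed INT8 fraction of the local ULP.
    lo = np.clip(np.round(residual / ulp * 128.0), -128, 127).astype(np.int8)
    return hi, lo

def reconstruct(hi, lo):
    ulp = np.maximum(np.abs(hi.astype(np.float32)) * 2.0**-10, 2.0**-24)
    return hi.astype(np.float32) + lo.astype(np.float32) / 128.0 * ulp
```

The INT8 term buys roughly seven extra mantissa bits over the bare half-precision cast, mirroring how the BF16 + INT8 pair recovers ~24-bit effective precision.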
Companded State Quantization: Applies nonlinear companding before INT8 quantization of optimizer states — 2x/(1+|x|) for momentum (signed), √x for variance (unsigned). This is critical: linear quantization of Adam states causes divergence due to heavy-tailed variance distributions.
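A minimal sketch of the companding round trip, using the two functions quoted above. The names and the single per-tensor scale are my simplifications, not the library's API (a production implementation would likely use blockwise scales):

```python
import numpy as np

def quantize_momentum(m):
    # Signed compander: y = 2x/(1+|x|) maps R into (-2, 2) and spends
    # most INT8 codes near zero, where momentum values concentrate.
    y = 2.0 * m / (1.0 + np.abs(m))
    scale = max(np.abs(y).max() / 127.0, 1e-12)
    return np.round(y / scale).astype(np.int8), scale

def dequantize_momentum(q, scale):
    y = q.astype(np.float32) * scale
    # The inverse of y = 2x/(1+|x|) is x = y/(2-|y|).
    return y / (2.0 - np.abs(y))

def quantize_variance(v):
    # Unsigned compander: sqrt compresses the heavy right tail of the
    # second-moment distribution before linear 8-bit (0..255) coding.
    z = np.sqrt(v)
    scale = max(z.max() / 255.0, 1e-12)
    return np.round(z / scale).astype(np.uint8), scale

def dequantize_variance(q, scale):
    z = q.astype(np.float32) * scale
    return z * z
```

Squaring on dequantization also guarantees the recovered variance is non-negative, which a linear signed code does not.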
Results:

| Task | Model | Reference | FlashOptim |
|---|---|---|---|
| ImageNet | ResNet-50 | 77.01% | 77.16% |
| Pretraining | GPT-2 124M | 3.263 loss | 3.265 loss |
| Finetuning | Llama-3.1-8B | 75.09% GSM8k | 74.98% GSM8k |
All differences are within measurement variance. The optimizer step is in fact 8% faster on Llama-3.1-8B, thanks to fused Triton kernels. FlashOptim supports SGD, AdamW, and Lion.
Memory Savings (bytes/parameter)
| | AdamW | FlashAdamW | FlashAdamW + grad release |
|---|---|---|---|
| Total | 16 | 7 | 5 |
Achieved via BF16+INT8 master weight splitting and companded INT8 optimizer states.
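One plausible per-parameter accounting that reproduces the table. The breakdown is my assumption (the issue does not itemize the totals), and it counts the BF16 model weights as model rather than optimizer memory:

```python
# Bytes per parameter held by the optimizer (assumed breakdown).
adamw = {"fp32 master": 4, "fp32 grad": 4, "momentum": 4, "variance": 4}
flash = {"bf16 master": 2, "int8 correction": 1, "bf16 grad": 2,
         "int8 momentum": 1, "int8 variance": 1}
# Gradient release frees the gradient buffer once it has been consumed.
flash_release = {k: v for k, v in flash.items() if k != "bf16 grad"}

print(sum(adamw.values()), sum(flash.values()), sum(flash_release.values()))
# -> 16 7 5
```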
Requested Feature
Investigate adding FlashAdamW as a supported optimizer in Megatron-Core's distributed optimizer, including compatibility with existing distributed checkpointing. FlashOptim >= 0.1.3 provides native DTensor support for PyTorch DCP/FSDP2 integration.
References
FlashAdamW with FSDP2/DCP support