[mxfp8 MoE training] Support mxfp8 all to all in expert parallel #1765
NOTE: requires torchao changes pytorch/ao#3103 (wait for land)
Summary
"mxfp8"
a2a impl option which uses torchao mxfp8mxfp8_sync_all_to_all_v
, which provides a ~1.17x speedup over the default impl in microbenchmarks: [mxfp8 moe training] mxfp8 a2a with d2h sync ao#3103.mxfp8_sync_all_to_all_v
has the exact same API as functional collectiveall_to_all_single_autograd
, and is differentiable, so it can be used as a drop-in replacement for the default a2a impl.mxfp8_sync_all_to_all_v
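For illustration, a minimal sketch of the swap (the torchao import path below is an assumption, not confirmed; see pytorch/ao#3103 for where the kernel actually lands):

```python
import torch
from torch.distributed._functional_collectives import all_to_all_single_autograd

# Hypothetical import path, for illustration only.
from torchao.prototype.moe_training import mxfp8_sync_all_to_all_v


def dispatch_tokens(tokens: torch.Tensor, out_splits, in_splits, group, use_mxfp8: bool):
    # Both impls take (input, output_split_sizes, input_split_sizes, group)
    # and are differentiable, so switching is a one-line change.
    if use_mxfp8:
        return mxfp8_sync_all_to_all_v(tokens, out_splits, in_splits, group)
    return all_to_all_single_autograd(tokens, out_splits, in_splits, group)
```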
`mxfp8_sync_all_to_all_v` simply quantizes the inputs, runs `all_to_all_single` on the e4m3 data and e8m0 scales, then dequantizes the outputs. There is quantization/dequantization overhead, but a net 1.17x speedup due to lower network bandwidth usage.
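A minimal sketch of that quantize → exchange → dequantize pattern (not the torchao kernel itself: it is forward-only where the real kernel is differentiable, per-row float32 scales stand in for the real blockwise e8m0 scales, and the helper names are made up):

```python
import torch
import torch.distributed as dist


def quantize_fp8_rowwise(x: torch.Tensor):
    # Per-row absmax scaling into float8_e4m3fn. The real mxfp8 kernel uses
    # small blocks with power-of-two (e8m0) scales; float32 scales keep this
    # sketch simple.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (x * scale).to(torch.float8_e4m3fn), scale


def fp8_all_to_all_single(x, out_splits, in_splits, group=None):
    # 1. Quantize locally so fewer bytes go over the network.
    x_fp8, scale = quantize_fp8_rowwise(x)
    # 2. Exchange the e4m3 payload and the scales separately. The payload is
    #    viewed as uint8 so the collective sees a plain byte tensor.
    out_fp8 = torch.empty(sum(out_splits), x.shape[-1], dtype=torch.float8_e4m3fn, device=x.device)
    out_scale = torch.empty(sum(out_splits), 1, dtype=scale.dtype, device=x.device)
    dist.all_to_all_single(out_fp8.view(torch.uint8), x_fp8.view(torch.uint8), out_splits, in_splits, group=group)
    dist.all_to_all_single(out_scale, scale, out_splits, in_splits, group=group)
    # 3. Dequantize on the receiving side.
    return out_fp8.to(x.dtype) / out_scale
```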
Additional context

- 30% of Llama 4 model profiled runtime is all-to-all comms
- 47% avg runtime devoted to MoE comms in profiled OSS models