
Conversation

@danielvegamyhre (Contributor) commented Sep 26, 2025

NOTE: requires the torchao changes in pytorch/ao#3103 (wait for that PR to land).

Summary

  • Make the EP a2a impl configurable and add an "mxfp8" a2a impl option that uses torchao's mxfp8_sync_all_to_all_v, which provides a ~1.17x speedup over the default impl in microbenchmarks: [mxfp8 moe training] mxfp8 a2a with d2h sync ao#3103. (A minimal dispatch sketch follows this list.)
  • torchao's mxfp8_sync_all_to_all_v has exactly the same API as the functional collective all_to_all_single_autograd and is differentiable, so it can be used as a drop-in replacement for the default a2a impl.
  • torchao's mxfp8_sync_all_to_all_v simply quantizes the inputs, runs all_to_all_single on the e4m3 data and e8m0 scales, then dequantizes the outputs. There is quantization/dequantization overhead, but the net result is a ~1.17x speedup due to lower network bandwidth usage. (See the quantize/exchange/dequantize sketch after this list.)
  • Note: both the default and mxfp8 impls require a d2h sync to get input_splits/output_splits on the host for the a2a call.
    • I also explored a no-sync/on-device implementation using Triton + symmetric memory, and got it working e2e in a torchtitan PoC: [mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts ao#3088
    • That design preallocates over-sized symmetric memory buffers for exchanging a variable number of tokens (avoiding the sync required for exact allocation, at the risk of a crash or dropped tokens if the overflow-factor heuristic is wrong). This is fundamentally in conflict with the torchtitan MoE design, which does a d2h sync precisely so it can do exact allocation safely: extracting the variable-size outputs from the padded buffers reintroduces a d2h sync (regressing perf below baseline), and we can't avoid it because downstream ops would otherwise break on shape mismatches - essentially the whole model would need to be designed around the static padded shapes.
    • Therefore, we chose to integrate this more straightforward impl, which is natively compatible with the non-experimental torchtitan MoE design.
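
To make the configurability point concrete, here is a minimal sketch of how the a2a impl selection could look. The config plumbing and the torchao import path are assumptions for illustration; only the function name mxfp8_sync_all_to_all_v and its drop-in compatibility with all_to_all_single_autograd come from this PR description.

```python
# Minimal sketch (not the actual torchtitan code) of selecting the EP a2a impl from config.
from torch.distributed._functional_collectives import all_to_all_single_autograd


def get_a2a_impl(a2a_impl: str):
    """Return the a2a callable selected by an (assumed) `expert_parallel_a2a_impl` config field."""
    if a2a_impl == "default":
        return all_to_all_single_autograd
    if a2a_impl == "mxfp8":
        # Hypothetical import path; the real kernel lands with pytorch/ao#3103.
        from torchao.prototype.moe_training import mxfp8_sync_all_to_all_v
        return mxfp8_sync_all_to_all_v
    raise ValueError(f"unknown a2a impl: {a2a_impl!r}")


# Both callables are differentiable and share the same all_to_all_single-style signature,
# so swapping impls is a one-line change:
# routed_out = get_a2a_impl(cfg_a2a)(routed_in, output_splits, input_splits, ep_group)
```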

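Below is a conceptual, forward-only sketch of the quantize → exchange → dequantize flow described above. It is not the torchao kernel: the rounding and scale encoding are simplified (real mxfp8 sends 1-byte e8m0 exponents, the sketch sends float scales), the fp8 payload is exchanged as raw bytes, and the real kernel is additionally differentiable.

```python
import torch
import torch.distributed as dist

BLOCK_SIZE = 32     # mxfp8 groups elements into blocks of 32 sharing one scale
E4M3_MAX = 448.0    # max representable magnitude of torch.float8_e4m3fn


def mxfp8_a2a_sketch(x, output_splits, input_splits, group):
    """Forward-only sketch: x is (num_tokens, dim) with dim divisible by BLOCK_SIZE;
    output_splits/input_splits are host-side token counts (i.e. post d2h sync)."""
    num_tokens, dim = x.shape
    blocks = x.reshape(num_tokens, dim // BLOCK_SIZE, BLOCK_SIZE)

    # 1) Quantize: per-block power-of-two scale (stand-in for e8m0), payload in e4m3.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scales = torch.exp2(torch.ceil(torch.log2(amax / E4M3_MAX)))
    payload = (blocks / scales).to(torch.float8_e4m3fn)

    # 2) Exchange the fp8 payload (viewed as bytes) and the scales with two a2a calls.
    def a2a(t):
        out = t.new_empty((sum(output_splits),) + tuple(t.shape[1:]))
        dist.all_to_all_single(out, t, output_splits, input_splits, group=group)
        return out

    payload_out = a2a(payload.view(torch.uint8).reshape(num_tokens, dim))
    scales_out = a2a(scales.reshape(num_tokens, dim // BLOCK_SIZE))

    # 3) Dequantize on the receiving side.
    out_tokens = sum(output_splits)
    payload_out = payload_out.reshape(out_tokens, dim // BLOCK_SIZE, BLOCK_SIZE)
    deq = payload_out.view(torch.float8_e4m3fn).to(x.dtype) * scales_out.unsqueeze(-1)
    return deq.reshape(out_tokens, dim)
```

The bandwidth win comes from the payload shrinking to ~1 byte per element plus 1 byte of scale per 32 elements, at the cost of the quantize/dequantize steps on either side of the collective.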
Additional context

  • The MoE performance literature shows that ~47% of average runtime for flagship OSS MoE models (Qwen2, Phi3.5, Mixtral 8x7B) is due to exposed MoE comms.
  • In the torchtitan Llama4 debug model with EP=4, ~30% of MoE training runtime is a2a comms, most of it exposed (see the trace screenshot below), which directionally corroborates this.
  • We can optimize this by (1) quantizing the comms to minimize data sent over NVLink/IB, (2) avoiding the d2h sync that occurs in implementations which move the a2a output splits from device to host to compute the exact preallocation needed for incoming tokens (illustrated in the sketch after this list), and (3) finer-grained overlapping techniques.
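
As a concrete illustration of the d2h sync in (2), here is roughly how the per-rank token counts typically end up on the host in an EP token dispatch (a generic sketch, not torchtitan's exact code). The collective needs Python ints for the split sizes, and the .tolist() call is the blocking device-to-host copy.

```python
import torch
import torch.distributed as dist


def exchange_token_counts(tokens_per_rank_send: torch.Tensor, group) -> list[int]:
    """tokens_per_rank_send: int64 CUDA tensor of shape (ep_degree,) with the number of
    tokens this rank will send to each peer."""
    tokens_per_rank_recv = torch.empty_like(tokens_per_rank_send)
    # One element per peer, so the default equal-split all_to_all_single works here.
    dist.all_to_all_single(tokens_per_rank_recv, tokens_per_rank_send, group=group)
    # .tolist() forces a device-to-host copy and a CUDA sync; the resulting host-side
    # ints become the input_splits / output_splits for the token a2a.
    return tokens_per_rank_recv.tolist()
```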

~30% of Llama4 model profiled runtime is all-to-all comms

  • FSDP=4, EP=4, dim=5120, num_experts=16, seq_len=8192, local_batch_size=8
(screenshot: profiler trace)

47% avg runtime devoted to MoE comms in profiled OSS models

(screenshot: runtime breakdown for profiled OSS MoE models)

@meta-cla bot added the CLA Signed label Sep 26, 2025
@danielvegamyhre force-pushed the mx-a2a branch 2 times, most recently from 4527f8b to bba9c6a, September 27, 2025 00:13
@danielvegamyhre force-pushed the mx-a2a branch 2 times, most recently from fde6de2 to a48e631, September 29, 2025 22:27
@danielvegamyhre changed the title [WIP] Support mxfp8 on device all_to_all_v in expert parallel → Support mxfp8 on device all_to_all_v in expert parallel Sep 29, 2025
@danielvegamyhre marked this pull request as draft September 29, 2025 23:04
@danielvegamyhre changed the title Support mxfp8 on device all_to_all_v in expert parallel → Support mxfp8 all to all in expert parallel Sep 30, 2025
@danielvegamyhre changed the title Support mxfp8 all to all in expert parallel → [mxfp8 MoE training] Support mxfp8 all to all in expert parallel Sep 30, 2025
@danielvegamyhre marked this pull request as ready for review September 30, 2025 21:17
@danielvegamyhre marked this pull request as draft September 30, 2025 22:36