[mxfp8 moe training] support wgrad_with_hp recipe #2249
Conversation
stack-info: PR: #2249, branch: danielvegamyhre/stack/3
cc @tianyu-l for review.
tianyu-l left a comment
Would love to see some requests / tech reports / experiments / scientific studies / guidance on such new features -- with all these flexible options, it's not clear to users what works best.
We are publishing docs and an e2e tutorial for this on the torchao docsite, which we can link to from the torchtitan MXFP8 docs, if that works. If so, I can add a new PR on top of this stack with documentation updates.
Would love to read this tutorial.
Here is the tutorial preview: https://docs-preview.pytorch.org/pytorch/ao/3752/mxfp8_expert_parallel.html. The torchtitan config params will need to be updated since I changed MXFP8ExpertParallel to be applied automatically for the mxfp8_wgrad_with_hp recipe, but otherwise I think it is good to go. Let me know if you have any thoughts.
Test failures are unrelated.
```diff
 @dataclass
 class MXGroupedMM:
-    recipe_name: Literal["mxfp8"] = "mxfp8"
+    recipe_name: Literal["mxfp8", "mxfp8_wgrad_with_hp"] = "mxfp8"
```
In the tutorial:
> The mxfp8_wgrad_with_hp recipe is required for MoE training with expert parallelism.
- Why is it required only for EP?
- Why isn't wgrad with hp the default here?
- Why is mxfp8 an option if hp wgrad is "required"?
I just saw #2250 (comment).
Does it make sense that we don't give the user control, and instead just do:
- mxfp8 when DeepEP is used, or no EP is used
- mxfp8_wgrad_with_hp when EP > 1 but DeepEP is not used

When EP is not enabled, which one should the user use?
> When EP is not enabled, which one should the user use?
- For performance, it depends on the expert shapes, batch size, and sequence length. For directional guidance on which recipe is better in a given context, I'm thinking about working on some tables like we have for float8 (see screenshot below). There has been positive feedback from users on this.
- For accuracy / improved step quality, wgrad_with_hp computes weight gradients in bf16, so this recipe can be more of a net benefit, not in terms of TPS but in terms of "time to target validation loss" or "time to some eval score threshold." Luca also found a perf benefit for certain smaller shapes with fp8_rowwise_with_gw_hp (the Hopper recipe). These have to be evaluated through experimentation, though.
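For reference, a minimal sketch of the "no user control" heuristic proposed above. The names `select_mx_recipe`, `ep_degree`, and `use_deepep` are hypothetical, for illustration only, and are not torchtitan API:

```python
def select_mx_recipe(ep_degree: int, use_deepep: bool) -> str:
    """Pick an MXFP8 MoE training recipe per the heuristic discussed above.

    - "mxfp8" when DeepEP is used, or when EP is not used
    - "mxfp8_wgrad_with_hp" when EP > 1 but DeepEP is not used
    """
    if ep_degree > 1 and not use_deepep:
        return "mxfp8_wgrad_with_hp"
    return "mxfp8"

# Example: EP enabled without DeepEP falls back to high-precision wgrad.
assert select_mx_recipe(ep_degree=8, use_deepep=False) == "mxfp8_wgrad_with_hp"
assert select_mx_recipe(ep_degree=1, use_deepep=False) == "mxfp8"
```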
Stacked PRs:
Summary
The torchao MXFP8 MoE training code now supports a new recipe: wgrad_with_hp (described below). This PR updates the torchtitan integration to allow users to use it.
Recipes:
- "mxfp8": Use MXFP8 for all 3 grouped GEMMs in the forward and backward pass (output, dgrad, wgrad).
- "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in high-precision.
This can be used to trade-off some performance for improved accuracy.
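To make the split concrete, here is a minimal, self-contained sketch of the wgrad_with_hp idea for a single (non-grouped) GEMM. The names `fake_mx_quant` and `GemmWgradWithHP` are illustrative stand-ins: the "quantization" is just a lossy float8 round-trip rather than real MX block scaling, and the actual torchao implementation uses grouped-GEMM kernels:

```python
import torch

def fake_mx_quant(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for MXFP8 quantization: a lossy float8_e4m3fn round-trip.
    # Real MXFP8 uses block-wise scales; this only marks which GEMM
    # operands run in reduced precision in this sketch.
    return t.to(torch.float8_e4m3fn).to(t.dtype)

class GemmWgradWithHP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        # Forward output GEMM on reduced-precision operands.
        return fake_mx_quant(x) @ fake_mx_quant(w).t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # dgrad also on reduced-precision operands.
        grad_x = fake_mx_quant(grad_out) @ fake_mx_quant(w)
        # wgrad kept in high precision (bf16), per the recipe description.
        grad_w = grad_out.t().to(torch.bfloat16) @ x.to(torch.bfloat16)
        return grad_x, grad_w.to(w.dtype)

# Usage sketch: gradients flow normally, with wgrad computed in bf16.
x = torch.randn(16, 32, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(64, 32, dtype=torch.bfloat16, requires_grad=True)
GemmWgradWithHP.apply(x, w).sum().backward()
```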
Note: I plan to do some benchmarking to provide more concrete guidance to users on which expert shapes will result in better TPS using wgrad_with_hp.
Tests
mxfp8 recipe:
mxfp8_wgrad_with_hp recipe: