[mxfp8 moe training] temp workaround: don't compile GroupedExperts #2268
base: danielvegamyhre/stack/5
Conversation
    ),
)
# temp workaround: compile everything except GroupedExperts
if not isinstance(submod, moe_module.GroupedExperts):
there's also another hardcoded call to compile _run_experts_grouped_mm here that you'd need to skip (I'm doing it in my patch in https://gist.github.com/bdhirsh/970a671b84c35cc95a76f33657ca4d69)
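For concreteness, here is a minimal sketch of the workaround being discussed (assumptions: a simplified compile loop, the GroupedExperts class matched by name, and the inductor backend; the real apply_compile in torchtitan is more involved):

```python
# Hypothetical sketch, not the actual torchtitan implementation.
import torch
import torch.nn as nn


def compile_all_but_grouped_experts(model: nn.Module, backend: str = "inductor") -> None:
    """Compile each child module except GroupedExperts.

    Temp workaround: compiling GroupedExperts currently hits a tensor metadata
    mismatch between the compiled forward output and the grad passed to
    backward(), so it stays eager. Per the comment above, any separate patch
    that compiles _run_experts_grouped_mm would need to be skipped as well.
    """
    for name, submod in model.named_children():
        # Stand-in for `isinstance(submod, moe_module.GroupedExperts)`.
        if type(submod).__name__ == "GroupedExperts":
            continue
        setattr(model, name, torch.compile(submod, backend=backend, fullgraph=True))
```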
torchtitan/models/llama4/__init__.py (Outdated)
| "17bx16e": TransformerModelArgs( | ||
| dim=5120, | ||
| n_layers=48, | ||
| n_layers=6, |
revert
@@ -687,14 +689,8 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
        in moe_module._run_experts_grouped_mm.__qualname__
    )
    if not already_patched:
can we remove this logic since we are not compiling GroupedExperts any more?
    submod, backend=compile_config.backend, fullgraph=True
),
)
# temp workaround: compile everything except GroupedExperts
Add a comment describing the issue this code is working around.
@tianyu-l Thanks for the early feedback. The top 3 PRs in this stack aren't quite ready; once scale testing is complete and I've confirmed the mxfp8 a2a (all-to-all) feature works as expected (perf, numerics), I'll polish them up, add PR descriptions, and let you know when they're ready. So far so good; just waiting for a long training run to finish to verify training stability and identical convergence to bf16. Should be ready early next week.
tianyu-l left a comment:
We don't have to land this if #2281 is going to be landed soon.
Stacked PRs:
[mxfp8 moe training] temp workaround: don't compile GroupedExperts
See the thread in #2250 (comment) for context.
TL;DR: this avoids a tensor metadata mismatch between the forward output and the backward() input (the upstream grad).
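For debugging, a hypothetical helper (not part of this PR) can surface that kind of mismatch by registering a tensor hook that compares the incoming grad's metadata against the forward output's:

```python
import torch


def check_grad_metadata(out: torch.Tensor, tag: str = "output") -> None:
    """Warn if the grad flowing into `out` has different metadata than `out`.

    `out` must require grad (i.e. be part of the autograd graph).
    """
    expected = (out.shape, out.dtype, out.stride())

    def hook(grad: torch.Tensor) -> torch.Tensor:
        got = (grad.shape, grad.dtype, grad.stride())
        if got != expected:
            print(f"[{tag}] metadata mismatch: forward {expected} vs grad {got}")
        return grad

    out.register_hook(hook)


# Usage sketch:
#   y = experts(x)                    # forward through the (compiled) module
#   check_grad_metadata(y, "experts")
#   y.sum().backward()
```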
Tests