
Conversation

@danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Jan 17, 2026

Stacked PRs:


Summary

torchao MXFP8 MoE training code now supports a new recipe: wgrad_with_hp (described below). This PR updates the torchtitan integration so users can enable it.

Recipes:
- "mxfp8": Use MXFP8 for all 3 grouped GEMMs in the forward and backward pass (output, dgrad, wgrad).
- "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in high-precision.
This can be used to trade-off some performance for improved accuracy.

Note: I plan to do some benchmarking to provide more concrete guidance to users on which expert shapes will result in better TPS with wgrad_with_hp.
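As a rough illustration of the difference between the two recipes (this mapping is a sketch for exposition, not torchao code), the following shows which of the three grouped GEMMs each recipe computes in MXFP8:

# Illustrative sketch only, not torchao's implementation: which grouped GEMMs
# each recipe runs in MXFP8 vs. high precision.
RECIPE_TO_MXFP8_GEMMS = {
    # forward output, input gradient (dgrad), and weight gradient (wgrad) all in MXFP8
    "mxfp8": {"output": True, "dgrad": True, "wgrad": True},
    # weight gradient stays in high precision (e.g. bf16) for better accuracy
    "mxfp8_wgrad_with_hp": {"output": True, "dgrad": True, "wgrad": False},
}

def uses_mxfp8(recipe_name: str, gemm: str) -> bool:
    """Return True if the given grouped GEMM runs in MXFP8 under the recipe."""
    return RECIPE_TO_MXFP8_GEMMS[recipe_name][gemm]

assert uses_mxfp8("mxfp8", "wgrad")
assert not uses_mxfp8("mxfp8_wgrad_with_hp", "wgrad")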

Tests

mxfp8 recipe:

CONFIG_FILE=/home/dev/torchtitan/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml ./run_train.sh \
--metrics.log_freq=10 \
--training.steps=200 \
--parallelism.data_parallel_shard_degree=8 \
--parallelism.expert_parallel_degree=8 \
--parallelism.tensor_parallel_degree=1 \
--parallelism.expert_tensor_parallel_degree=1 \
--profiling.enable_profiling --profiling.profile_freq=30 \
--training.seq_len=8192 \
--activation_checkpoint.mode=none \
--model.print_after_conversion \
--training.local_batch_size=12 \
--model.converters="quantize.grouped_mm.mx,quantize.linear.mx" \
--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="cuda" \
--quantize.linear.mx.filter_fqns="output,router.gate,wk,wv" \
--quantize.grouped_mm.mx.fqns="experts" --quantize.grouped_mm.mx.recipe_name="mxfp8" \
--compile.enable --debug.moe_force_load_balance

mxfp8_wgrad_with_hp recipe:

CONFIG_FILE=/home/dev/torchtitan/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml ./run_train.sh \
--metrics.log_freq=10 \
--training.steps=200 \
--parallelism.data_parallel_shard_degree=8 \
--parallelism.expert_parallel_degree=8 \
--parallelism.tensor_parallel_degree=1 \
--parallelism.expert_tensor_parallel_degree=1 \
--profiling.enable_profiling --profiling.profile_freq=30 \
--training.seq_len=8192 \
--activation_checkpoint.mode=none \
--model.print_after_conversion \
--training.local_batch_size=12 \
--model.converters="quantize.grouped_mm.mx,quantize.linear.mx" \
--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="cuda" \
--quantize.linear.mx.filter_fqns="output,router.gate,wk,wv" \
--quantize.grouped_mm.mx.fqns="experts" --quantize.grouped_mm.mx.recipe_name="mxfp8_wgrad_with_hp" \
--compile.enable --debug.moe_force_load_balance

danielvegamyhre added a commit that referenced this pull request Jan 17, 2026
stack-info: PR: #2249, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from df304f1 to a59fe09 on January 17, 2026 20:37
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 17, 2026
@danielvegamyhre
Contributor Author

cc @tianyu-l for review. pre-commit passes locally but fails in CI; not sure why yet. I did a fresh pip install of requirements-dev.txt.

Contributor

@tianyu-l tianyu-l left a comment


Would love to see some requests / tech reports / experiments / scientific studies / guidance on such new features -- with all these flexible options, it's not clear to users what works best.

stack-info: PR: #2249, branch: danielvegamyhre/stack/3
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/3 branch from a59fe09 to ac87d5c on February 3, 2026 20:03
@danielvegamyhre
Contributor Author

Would love to see some requests / tech reports / experiments / scientific studies / guidance on such new features -- with all these flexible options, it's not clear to users what works best.

We are publishing docs and an e2e tutorial on the torchao docsite for this, which we can link to from the torchtitan MXFP8 docs, if that works? If so, I can add a new PR on top of this stack with documentation updates.

@tianyu-l
Contributor

tianyu-l commented Feb 5, 2026

We are publishing docs and an e2e tutorial on torchao docsite on this

would love to read this tutorial

@danielvegamyhre
Contributor Author

danielvegamyhre commented Feb 5, 2026

We are publishing docs and an e2e tutorial on torchao docsite on this

would love to read this tutorial

Here is the tutorial preview: https://docs-preview.pytorch.org/pytorch/ao/3752/mxfp8_expert_parallel.html

The torchtitan config params will need to be updated since I changed MXFP8ExpertParallel to be applied automatically for the mxfp8_wgrad_with_hp recipe, but otherwise I think it is good to go. Let me know if you have any thoughts.

@danielvegamyhre
Contributor Author

test failures are unrelated

@dataclass
class MXGroupedMM:
-    recipe_name: Literal["mxfp8"] = "mxfp8"
+    recipe_name: Literal["mxfp8", "mxfp8_wgrad_with_hp"] = "mxfp8"
Contributor


In the tutorial

The mxfp8_wgrad_with_hp recipe is required for MoE training with expert parallelism.

  • why is it only required for EP?
  • why is the default here not wgrad_with_hp?
  • why is mxfp8 an option if hp wgrad is "required"?

Contributor


I just saw #2250 (comment)

Does it make sense that we don't give the user control, and just do the following (sketched below)?

  • mxfp8 when DeepEP is used, or no EP is used
  • mxfp8_wgrad_with_hp when EP > 1 but DeepEP is not used

When EP is not enabled, which one should the user use?
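A rough sketch of the selection policy suggested above (hypothetical helper, not existing torchtitan or torchao code):

def choose_mx_recipe(ep_degree: int, use_deepep: bool) -> str:
    """Hypothetical auto-selection of the MXFP8 MoE recipe, per the suggestion above."""
    # mxfp8 when DeepEP is used, or when no expert parallelism is used
    if use_deepep or ep_degree <= 1:
        return "mxfp8"
    # keep the weight gradient in high precision when EP > 1 without DeepEP
    return "mxfp8_wgrad_with_hp"

assert choose_mx_recipe(ep_degree=1, use_deepep=False) == "mxfp8"
assert choose_mx_recipe(ep_degree=8, use_deepep=False) == "mxfp8_wgrad_with_hp"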

Contributor Author

@danielvegamyhre danielvegamyhre Feb 10, 2026


When EP is not enabled, which one should user use?

  • For performance, it depends on the expert shapes, batch size, and sequence length. For directional guidance on which recipe is better in a given context, I'm thinking about working on some tables like we have for float8 (see the screenshot below). There has been positive feedback from users on this.
  • For accuracy / improved step quality, wgrad_with_hp computes weight gradients in bf16, so this recipe can be more of a net benefit: not in terms of TPS, but in terms of "time to target validation loss" or "time to some eval score threshold." Luca also found a perf benefit for certain smaller shapes with fp8_rowwise_with_gw_hp (the Hopper recipe). These have to be evaluated through experimentation, though.
[Screenshot: float8 recipe performance guidance tables, 2026-02-09]

