torchtitan/components/quantization/mx.py (2 additions, 2 deletions)

@@ -132,7 +132,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

        self.recipe_name = job_config.quantize.grouped_mm.mx.recipe_name
        self.enabled = True
-       logger.info("MXFP8 MoE training enabled")
+       logger.info(f"MXFP8 MoE training enabled with recipe: {self.recipe_name}")

    def convert(self, model: nn.Module):
        """

@@ -154,7 +154,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
                    return True
            return False

-       config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+       config = MoETrainingConfig(scaling_type=MoEScalingType(self.recipe_name))
        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
        logger.info(
            f"Converted MoE layers matching FQNS {self.moe_fqns} "
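
The change above builds the torchao scaling type from the configured recipe string instead of hard-coding MoEScalingType.MXFP8. A minimal sketch of the enum-by-value lookup this relies on, assuming MoEScalingType behaves like a string-valued Enum whose values match the config recipe names (the class below is an illustrative stand-in, not torchao's actual definition):

    from enum import Enum

    # Illustrative stand-in for torchao's MoEScalingType; the values are assumed
    # to match the recipe names accepted by quantize.grouped_mm.mx.recipe_name.
    class MoEScalingType(str, Enum):
        MXFP8 = "mxfp8"
        MXFP8_WGRAD_WITH_HP = "mxfp8_wgrad_with_hp"

    recipe_name = "mxfp8_wgrad_with_hp"
    scaling_type = MoEScalingType(recipe_name)  # Enum(value) lookup; raises ValueError on unknown recipes
    assert scaling_type is MoEScalingType.MXFP8_WGRAD_WITH_HP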

torchtitan/config/job_config.py (7 additions, 2 deletions)

@@ -784,10 +784,15 @@ class MXLinear:

@dataclass
class MXGroupedMM:
-    recipe_name: Literal["mxfp8"] = "mxfp8"
+    recipe_name: Literal["mxfp8", "mxfp8_wgrad_with_hp"] = "mxfp8"

Review comments on this line:

Contributor:
In the tutorial:
> The mxfp8_wgrad_with_hp recipe is required for MoE training with expert parallelism.
  • Why is it only required for EP?
  • Why isn't the default here wgrad with hp?
  • Why is mxfp8 an option if hp wgrad is "required"?

Contributor:
I just saw #2250 (comment).
Does it make sense that we don't give the user control, and just do:
  • mxfp8 when DeepEP is used, or when no EP is used
  • mxfp8_wgrad_with_hp when EP > 1 but DeepEP is not used
When EP is not enabled, which one should the user use?

Contributor Author (@danielvegamyhre), Feb 10, 2026:
> When EP is not enabled, which one should the user use?
  • For performance, it depends on the expert shapes, batch size, and sequence length. For directional guidance on which recipe is better in a given context, I'm thinking about working on some tables like we have for float8 (see screenshot below). There has been positive feedback from users on this.
  • For accuracy / improved step quality, wgrad_with_hp computes weight gradients in bf16, so this recipe can be more of a net benefit, not in terms of TPS but in terms of "time to target validation loss" or "time to some eval score threshold." Luca also found a perf benefit for certain smaller shapes with fp8_rowwise_with_gw_hp (Hopper recipe). These have to be evaluated through experimentation, though.
[Screenshot: float8 recipe guidance tables, 2026-02-09]
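
To make the selection rule suggested in the thread above concrete, a minimal sketch; ep_degree and use_deepep are hypothetical inputs here, not actual torchtitan config fields:

    def pick_mx_grouped_mm_recipe(ep_degree: int, use_deepep: bool) -> str:
        # EP > 1 without DeepEP: keep wgrad in high precision, per the suggestion above.
        if ep_degree > 1 and not use_deepep:
            return "mxfp8_wgrad_with_hp"
        # DeepEP in use, or no expert parallelism: plain mxfp8.
        return "mxfp8"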

"""
Quantization recipe name for grouped GEMMs. Options: ["mxfp8"]
Quantization recipe name for grouped GEMMs. Options: ["mxfp8", "mxfp8_wgrad_with_hp"]

Recipes:
- "mxfp8": Use MXFP8 for all 3 grouped GEMMs in the forward and backward pass (output, dgrad, wgrad).
- "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in high-precision.
This can be used to trade-off some performance for improved accuracy. For some smaller expert shapes,
it is also better for performance.
Example: --quantize.grouped_mm.mx.recipe_name="mxfp8"
"""
