From ac87d5c920cfb6c85d21bd95716ab8d061a6e5b0 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Sat, 17 Jan 2026 20:36:55 +0000
Subject: [PATCH] [mxfp8 moe training] support wgrad_with_hp recipe

stack-info: PR: https://github.com/pytorch/torchtitan/pull/2249, branch: danielvegamyhre/stack/3
---
 torchtitan/components/quantization/mx.py | 4 ++--
 torchtitan/config/job_config.py          | 9 +++++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py
index 3bdd250c15..16b83381bd 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx.py
@@ -132,7 +132,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.recipe_name = job_config.quantize.grouped_mm.mx.recipe_name
         self.enabled = True

-        logger.info("MXFP8 MoE training enabled")
+        logger.info(f"MXFP8 MoE training enabled with recipe: {self.recipe_name}")

     def convert(self, model: nn.Module):
         """
@@ -154,7 +154,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
                 return True
             return False

-        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        config = MoETrainingConfig(scaling_type=MoEScalingType(self.recipe_name))
         quantize_(model, config=config, filter_fn=moe_module_filter_fn)
         logger.info(
             f"Converted MoE layers matching FQNS {self.moe_fqns} "
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index b3a24c7847..4fdffb78cf 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -784,10 +784,15 @@ class MXLinear:

 @dataclass
 class MXGroupedMM:
-    recipe_name: Literal["mxfp8"] = "mxfp8"
+    recipe_name: Literal["mxfp8", "mxfp8_wgrad_with_hp"] = "mxfp8"
     """
-    Quantization recipe name for grouped GEMMs. Options: ["mxfp8"]
+    Quantization recipe name for grouped GEMMs. Options: ["mxfp8", "mxfp8_wgrad_with_hp"]

+    Recipes:
+    - "mxfp8": Use MXFP8 for all 3 grouped GEMMs in the forward and backward pass (output, dgrad, wgrad).
+    - "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in high precision.
+      This can be used to trade off some performance for improved accuracy. For some smaller expert shapes,
+      it is also better for performance.
     Example: --quantize.grouped_mm.mx.recipe_name="mxfp8"
     """

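
Note on the MoEScalingType(self.recipe_name) change: the value-based constructor only works if MoEScalingType is an enum whose member values match the recipe strings accepted by MXGroupedMM.recipe_name. The torchao enum itself is not shown in this patch, so the snippet below is a minimal illustrative sketch with a hypothetical stand-in class (MoEScalingTypeSketch), not the real torchao definition.

# Illustrative sketch only: MoEScalingTypeSketch is a hypothetical stand-in for
# torchao's MoEScalingType, assuming its member values match the recipe strings.
from enum import Enum

class MoEScalingTypeSketch(Enum):
    MXFP8 = "mxfp8"
    MXFP8_WGRAD_WITH_HP = "mxfp8_wgrad_with_hp"

for recipe_name in ("mxfp8", "mxfp8_wgrad_with_hp"):
    # Value-based lookup, mirroring MoEScalingType(self.recipe_name) in convert().
    scaling_type = MoEScalingTypeSketch(recipe_name)
    print(recipe_name, "->", scaling_type)

# An unrecognized recipe string raises ValueError, so a typo in the job config
# fails fast at conversion time instead of silently falling back to a default.
try:
    MoEScalingTypeSketch("mxfp8_typo")
except ValueError as err:
    print("rejected:", err)

On the user side, the existing CLI example in the docstring generalizes directly to the new recipe: --quantize.grouped_mm.mx.recipe_name="mxfp8_wgrad_with_hp".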