From 54d3b05859df33c09dbe597cc808a7a107e76aea Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 1 Jan 2026 18:12:13 +0000
Subject: [PATCH] [mxfp8 moe training] mxfp8 all to all

stack-info: PR: https://github.com/pytorch/torchtitan/pull/2250, branch: danielvegamyhre/stack/4
---
 torchtitan/components/quantization/mx.py      |   2 +-
 torchtitan/config/job_config.py               |  30 +++
 torchtitan/distributed/expert_parallel.py     | 184 +++++++++++++++---
 .../models/deepseek_v3/infra/parallelize.py   |   5 +
 torchtitan/models/llama4/infra/parallelize.py |  18 +-
 torchtitan/models/qwen3/infra/parallelize.py  |   5 +
 6 files changed, 216 insertions(+), 28 deletions(-)

diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py
index 16b83381bd..cda64a281d 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx.py
@@ -57,7 +57,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             MXLinearConfig as TorchAOMXLinearConfig,
         )
 
-        mx_job_config: TorchAOMXLinearConfig = job_config.quantize.linear.mx
+        mx_job_config = job_config.quantize.linear.mx
         config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name)
         config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
             mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 4fdffb78cf..53bda77754 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -793,6 +793,36 @@ class MXGroupedMM:
     - "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in
       high-precision. This can be used to trade-off some performance for improved accuracy.
       For some smaller expert shapes, it is also better for performance.
+
+    This recipe also reduces communication volume by doing certain all-to-all operations in MXFP8.
+    Specifically, instead of doing MXFP8 quantization right before the MXFP8 grouped GEMMs for the
+    routed experts, we do the quantization earlier, before the preceding all-to-all dispatch, and
+    stay in MXFP8 through the token shuffle.
+
+    This speeds up the all-to-all by sending fewer bytes over the network. Similarly, in the backward
+    pass, we quantize before the all-to-all combine, stay in MXFP8 through token unpermutation, and then
+    do the MXFP8 grouped GEMM. The diagram below visualizes this:
+
+    Forward flow:
+    1. a2a dispatch forward: high-precision input, MXFP8 output
+    2. permute forward: MXFP8 input, MXFP8 output
+    3. MXFP8 grouped GEMM: MXFP8 input, high-precision output
+    4. unpermute forward: high-precision input and output
+    5. a2a combine forward: high-precision input and output
+
+    Backward flow:
+    1. a2a combine backward: high-precision input, MXFP8 output
+    2. unpermute backward: MXFP8 input, MXFP8 output
+    3. MXFP8 grouped GEMM: MXFP8 input, high-precision output
+    4. permute backward: high-precision input and output
+    5. a2a dispatch backward: high-precision input and output
+
+    Limitations:
+    - Not compatible with DeepEP
+
+    MXFP8 requires installation of torchao nightly build for CUDA 12.8+: https://github.com/pytorch/ao
+
+    Example: --quantize.grouped_mm.mx.recipe_name="mxfp8"
     """
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 8ee53e754e..05343b2899 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -107,32 +107,13 @@ def _token_dispatch(
         ep_degree = device_mesh.shape[0]
         num_local_experts = num_tokens_per_expert.shape[0] // ep_degree
 
-        # generate the input splits and output splits for all-to-all
-        with torch.no_grad():
-            num_tokens_per_expert_group = all_to_all_single(
-                num_tokens_per_expert,
-                None,
-                None,
-                group=device_mesh.get_group(),
-            )
-            # Need to wait explicitly because it is used by a triton kernel later
-            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-                num_tokens_per_expert_group
-            )
-            input_splits = (
-                num_tokens_per_expert.view(ep_degree, -1)
-                .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=True)
-            )
-            # NOTE: this would incur a device-to-host sync
-            output_splits = (
-                num_tokens_per_expert_group.view(ep_degree, -1)
-                .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=False)
-            )
-            self.input_splits = input_splits.tolist()
-            self.output_splits = output_splits.tolist()
+        # first all-to-all to calculate output splits from input splits.
+        # note: this will incur a d2h sync
+        (
+            self.input_splits,
+            self.output_splits,
+            num_tokens_per_expert_group,
+        ) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree)
 
         # perform all-to-all
         routed_input = all_to_all_single_autograd(
@@ -192,6 +173,111 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )
 
 
+class MXFP8ExpertParallel(ExpertParallel):
+    def __init__(self):
+        super().__init__()
+        self.input_splits = None
+        self.output_splits = None
+        self.input_shape = None
+        self.permuted_indices = None
+        try:
+            from torchao.prototype.moe_training.ep import (
+                a2a_combine_hp_fwd_mxfp8_bwd,
+                a2a_dispatch_mxfp8_fwd_hp_bwd,
+                permute_mxfp8_fwd_hp_bwd,
+                unpermute_hp_fwd_mxfp8_bwd,
+            )
+
+            self.a2a_dispatch_mxfp8_fwd_hp_bwd = a2a_dispatch_mxfp8_fwd_hp_bwd
+            self.permute_mxfp8_fwd_hp_bwd = permute_mxfp8_fwd_hp_bwd
+            self.unpermute_hp_fwd_mxfp8_bwd = unpermute_hp_fwd_mxfp8_bwd
+            self.a2a_combine_hp_fwd_mxfp8_bwd = a2a_combine_hp_fwd_mxfp8_bwd
+        except ImportError as e:
+            raise ImportError(
+                "MXFP8 expert parallel ops are not available. "
+                "Please install torchao nightly build for CUDA 12.8+: "
+                "https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#-installation."
+            ) from e
+
+    def _token_dispatch(
+        self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh
+    ) -> tuple[Tensor, Tensor]:
+        # annotate module input placements/sharding with input_layouts
+        routed_input, num_tokens_per_expert = inputs
+        ep_degree = device_mesh.shape[0]
+        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree
+
+        # first all-to-all to calculate output splits from input splits.
+        # note: this will incur a d2h sync
+        (
+            self.input_splits,
+            self.output_splits,
+            num_tokens_per_expert_group,
+        ) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree)
+
+        # perform all-to-all
+        # TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao
+        routed_input = self.a2a_dispatch_mxfp8_fwd_hp_bwd(
+            routed_input,
+            output_splits=self.output_splits,
+            input_splits=self.input_splits,
+            group_name=device_mesh.get_group().group_name,
+        )
+
+        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
+        # However, the num_tokens_per_expert_group is not of the final target format
+        # [#tokens for local expert 0, #tokens for local expert 1, ...]
+        # Rather, it is of the format
+        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
+        #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
+        # We need to perform another shuffle to get the correct layout, via the _permute function
+        # below, which also does padding to make sure the number of tokens each expert gets locally
+        # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+        # Note that this will create side effects when wrapping the for-loop implementation
+        # of GroupedExperts, as it does not need padding.
+
+        # TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao
+        (
+            self.input_shape,
+            routed_input,
+            self.permuted_indices,
+            num_tokens_per_expert_group,
+            _,
+        ) = self.permute_mxfp8_fwd_hp_bwd(
+            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
+        )
+
+        return routed_input, num_tokens_per_expert_group
+
+    def _token_combine(
+        self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh
+    ) -> Tensor:
+        # TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao
+        routed_output = self.unpermute_hp_fwd_mxfp8_bwd(
+            routed_output, self.permuted_indices, self.input_shape
+        )
+
+        # TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao
+        routed_output = self.a2a_combine_hp_fwd_mxfp8_bwd(
+            routed_output,
+            output_splits=self.input_splits,  # swap input/output splits to reverse all-to-all dispatch
+            input_splits=self.output_splits,
+            group_name=device_mesh.get_group().group_name,
+        )
+        return routed_output
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn=self._partition_fn,
+            # pyrefly: ignore [bad-argument-type]
+            input_fn=self._token_dispatch,
+            # pyrefly: ignore [bad-argument-type]
+            output_fn=self._token_combine,
+        )
+
+
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -382,3 +468,49 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         input_fn=self._token_dispatch,  # pyrefly: ignore [bad-argument-type]
         output_fn=self._token_combine,  # pyrefly: ignore [bad-argument-type]
     )
+
+
+def get_a2a_splits(
+    num_tokens_per_expert: torch.Tensor,
+    device_mesh: DeviceMesh,
+    ep_degree: int,
+) -> tuple[list[int], list[int], torch.Tensor]:
+    """
+    Get the input and output splits for all-to-all comms in expert parallelism.
+
+    Note: this incurs a device-to-host synchronization.
+
+    Args:
+        num_tokens_per_expert: Tensor of shape (num_experts,)
+        device_mesh: Device mesh for expert parallelism
+        ep_degree: Expert parallelism degree
+    Returns:
+        input_splits: list of shape (ep_degree,)
+        output_splits: list of shape (ep_degree,)
+        num_tokens_per_expert_group: Tensor of shape (num_experts,)
+    """
+
+    with torch.no_grad():
+        num_tokens_per_expert_group = all_to_all_single(
+            num_tokens_per_expert,
+            None,
+            None,
+            group=device_mesh.get_group(),
+        )
+        # Need to wait explicitly because it is used by a triton kernel later
+        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+            num_tokens_per_expert_group
+        )
+        input_splits = (
+            num_tokens_per_expert.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=True)
+        )
+        # NOTE: this would incur a device-to-host sync
+        output_splits = (
+            num_tokens_per_expert_group.view(ep_degree, -1)
+            .sum(dim=1)
+            .to(torch.device("cpu"), non_blocking=False)
+        )
+    return input_splits.tolist(), output_splits.tolist(), num_tokens_per_expert_group
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 19d9f946d2..39c7e2bea8 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -120,6 +120,10 @@ def parallelize_deepseekv3(
     else:
         use_deepep = False
 
+    use_mxfp8_a2a = (
+        "quantize.grouped_mm.mx" in job_config.model.converters
+        and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
+    )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)
 
@@ -131,6 +135,7 @@ def parallelize_deepseekv3(
             ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
             dual_pipe_v=dual_pipe_v,
             use_deepep=use_deepep,
+            use_mxfp8_a2a=use_mxfp8_a2a,
         )
 
     if parallel_dims.cp_enabled:
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 5e23b81ab1..8324080389 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -36,6 +36,7 @@
     DeepEPExpertParallel,
     ExpertParallel,
     ExpertTensorParallel,
+    MXFP8ExpertParallel,
     ReordererSequenceParallel,
     TensorParallel,
 )
@@ -134,6 +135,10 @@ def parallelize_llama(
     else:
         use_deepep = False
 
+    use_mxfp8_a2a = (
+        "quantize.grouped_mm.mx" in job_config.model.converters
+        and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
+    )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)
 
@@ -145,6 +150,7 @@ def parallelize_llama(
             ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
             dual_pipe_v=dual_pipe_v,
             use_deepep=use_deepep,
+            use_mxfp8_a2a=use_mxfp8_a2a,
         )
 
     attn_type = getattr(model.model_args, "attn_type", "sdpa")
@@ -505,8 +511,13 @@ def apply_moe_ep_tp(
     ep_etp_mesh: DeviceMesh | None,
     dual_pipe_v: bool = False,
     use_deepep: bool = False,
+    use_mxfp8_a2a: bool = False,
 ):
     assert ep_mesh is not None or tp_mesh is not None
+    assert not (use_deepep and use_mxfp8_a2a), (
+        'DeepEP and MXFP8 all-to-all (part of quantize.grouped_mm.mx.recipe_name="mxfp8_wgrad_with_hp") are not compatible. '
+        "Please choose one of them."
+    )
 
     # pyrefly: ignore [not-callable]
     for transformer_block in model.layers.values():
@@ -574,7 +585,12 @@ def apply_moe_ep_tp(
                 logger.info("Applying DeepEP to MoE layer")
             else:
                 # input / output sharding on the batch / tokens dim
-                experts_plan = ExpertParallel()
+                if use_mxfp8_a2a:
+                    logger.info("Applying MXFP8 Expert Parallelism to MoE")
+                    experts_plan = MXFP8ExpertParallel()
+                else:
+                    logger.info("Applying Expert Parallelism to MoE layer")
+                    experts_plan = ExpertParallel()
         else:
             experts_mesh = ep_etp_mesh
             experts_plan = ExpertTensorParallel()
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 4837dbc68e..f2ebab369f 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -108,6 +108,10 @@ def parallelize_qwen3(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)
 
+        use_mxfp8_a2a = (
+            "quantize.grouped_mm.mx" in job_config.model.converters
+            and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
+        )
         apply_moe_ep_tp(
             model,
             tp_mesh=parallel_dims.get_optional_mesh("tp"),
@@ -115,6 +119,7 @@ def parallelize_qwen3(
             etp_mesh=parallel_dims.get_optional_mesh("etp"),
             ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
             dual_pipe_v=dual_pipe_v,
+            use_mxfp8_a2a=use_mxfp8_a2a,
        )
 
     if parallel_dims.cp_enabled:
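Note on the splits bookkeeping in get_a2a_splits: input_splits is derived from the locally known num_tokens_per_expert, while output_splits is derived from the counts delivered by the first (metadata) all-to-all, which arrive grouped by sending EP rank as described in the NOTE inside _token_dispatch. Below is a minimal single-process sketch of that math; the helper emulate_counts_a2a and the example counts are illustrative only (not part of torchtitan or torchao), and the real code exchanges the counts with all_to_all_single over the EP process group.

# Minimal single-process sketch of the split math in get_a2a_splits (illustrative only).
import torch

ep_degree = 2
num_local_experts = 2  # experts owned by each EP rank

# Per-rank routing counts: entry [e] is how many of this rank's tokens go to global expert e.
num_tokens_per_expert = [
    torch.tensor([3, 1, 2, 4]),  # EP rank 0
    torch.tensor([5, 0, 1, 2]),  # EP rank 1
]

def emulate_counts_a2a(counts_per_rank, rank, ep, n_local):
    # What the counts all-to-all delivers to `rank`:
    # [counts for rank's local experts from sender 0, from sender 1, ...]
    return torch.cat([c.view(ep, n_local)[rank] for c in counts_per_rank])

for rank in range(ep_degree):
    counts = num_tokens_per_expert[rank]
    counts_group = emulate_counts_a2a(num_tokens_per_expert, rank, ep_degree, num_local_experts)
    input_splits = counts.view(ep_degree, -1).sum(dim=1).tolist()         # tokens sent to each EP rank
    output_splits = counts_group.view(ep_degree, -1).sum(dim=1).tolist()  # tokens received from each EP rank
    print(f"rank {rank}: input_splits={input_splits}, output_splits={output_splits}")
# rank 0: input_splits=[4, 6], output_splits=[4, 5]
# rank 1: input_splits=[5, 3], output_splits=[6, 3]

The symmetry here is what makes the combine direction cheap to set up: what rank r sends to rank s (input_splits[s] on r) is exactly what s receives from r (output_splits[r] on s), which is why _token_combine simply swaps the two lists to reverse the dispatch.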