2 changes: 1 addition & 1 deletion torchtitan/components/quantization/mx.py
@@ -57,7 +57,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
MXLinearConfig as TorchAOMXLinearConfig,
)

mx_job_config: TorchAOMXLinearConfig = job_config.quantize.linear.mx
mx_job_config = job_config.quantize.linear.mx
config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name)
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
30 changes: 30 additions & 0 deletions torchtitan/config/job_config.py
@@ -793,6 +793,36 @@ class MXGroupedMM:
- "mxfp8_wgrad_with_hp": Use MXFP8 for forward output and dgrad, but keep wgrad in high-precision.
This can be used to trade off some performance for improved accuracy. For some smaller expert shapes,
it is also better for performance.

This recipe also reduces communication volume by doing certain all-to-all operations in MXFP8.
Specifically, instead of doing MXFP8 quantization right before the MXFP8 grouped GEMMs for the
routed experts, we do the quantization earlier, before the preceding all-to-all dispatch, and
stay in MXFP8 through the token shuffle.

This speeds up the all-to-all by sending fewer bytes over the network. Similarly, in the backward
pass, we quantize before the all-to-all combine, stay in MXFP8 through token unpermutation, and then
do the MXFP8 grouped GEMM. The flows below illustrate this:

Forward flow:
1. a2a dispatch forward: high-precision input, MXFP8 output
2. permute forward: MXFP8 input, MXFP8 output
3. MXFP8 grouped GEMM: MXFP8 input, high-precision output
4. unpermute forward: high-precision input and output
5. a2a combine forward: high-precision input and output

Backward flow:
1. a2a combine backward: high-precision input, MXFP8 output
2. unpermute backward: MXFP8 input, MXFP8 output
3. MXFP8 grouped GEMM: MXFP8 input, high-precision output
4. permute backward: high-precision input and output
5. a2a dispatch backward: high-precision input and output

Limitations:
- Not compatible with DeepEP

MXFP8 requires the torchao nightly build for CUDA 12.8+: https://github.com/pytorch/ao


Example: --quantize.grouped_mm.mx.recipe_name="mxfp8"
"""

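The docstring above notes that the recipe cuts communication volume by quantizing before the dispatch all-to-all. A rough, standalone byte-count sketch of that saving, assuming the standard MX block size of 32 elements with a one-byte E8M0 scale per block and a hypothetical hidden size (the actual torchao payload layout may differ):

```python
# Back-of-envelope: bytes per token sent over the all-to-all in bf16 vs. MXFP8.
# Assumptions (not from this PR): hidden_dim = 7168, MX block size = 32,
# one E8M0 scale byte per block.
hidden_dim = 7168
bf16_bytes = hidden_dim * 2                               # 2 bytes per bf16 element
mxfp8_bytes = hidden_dim * 1 + (hidden_dim // 32) * 1     # fp8 data + per-block scales
print(bf16_bytes, mxfp8_bytes, mxfp8_bytes / bf16_bytes)  # 14336 7392 0.515625
```

Under these assumptions, the dispatch payload is roughly halved relative to bf16.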
184 changes: 158 additions & 26 deletions torchtitan/distributed/expert_parallel.py
@@ -107,32 +107,13 @@ def _token_dispatch(
ep_degree = device_mesh.shape[0]
num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

# generate the input splits and output splits for all-to-all
with torch.no_grad():
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=device_mesh.get_group(),
)
# Need to wait explicitly because it is used by a triton kernel later
# which doesn't realize that AsyncCollectiveTensor needs unwrapping
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
self.input_splits = input_splits.tolist()
self.output_splits = output_splits.tolist()
# first all-to-all to calculate output splits from input splits.
# note: this will incur a d2h sync
(
self.input_splits,
self.output_splits,
num_tokens_per_expert_group,
) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree)

# perform all-to-all
routed_input = all_to_all_single_autograd(
@@ -192,6 +173,111 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)


class MXFP8ExpertParallel(ExpertParallel):
Reviewer comment (Contributor): Since it's related to mxfp8 only, let's put this in quantization/mx.py?

def __init__(self):
super().__init__()
self.input_splits = None
self.output_splits = None
self.input_shape = None
self.permuted_indices = None
try:
from torchao.prototype.moe_training.ep import (
a2a_combine_hp_fwd_mxfp8_bwd,
a2a_dispatch_mxfp8_fwd_hp_bwd,
permute_mxfp8_fwd_hp_bwd,
unpermute_hp_fwd_mxfp8_bwd,
)

self.a2a_dispatch_mxfp8_fwd_hp_bwd = a2a_dispatch_mxfp8_fwd_hp_bwd
self.permute_mxfp8_fwd_hp_bwd = permute_mxfp8_fwd_hp_bwd
self.unpermute_hp_fwd_mxfp8_bwd = unpermute_hp_fwd_mxfp8_bwd
self.a2a_combine_hp_fwd_mxfp8_bwd = a2a_combine_hp_fwd_mxfp8_bwd
except ImportError as e:
raise ImportError(
"MXFP8 expert parallel ops are not available."
"Please install torchao nightly build for CUDA 12.8+: "
"https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#-installation."
) from e

def _token_dispatch(
self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh
) -> tuple[Tensor, Tensor]:
# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_degree = device_mesh.shape[0]
num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

# first all-to-all to calculate output splits from input splits.
# note: this will incur a d2h sync
(
self.input_splits,
self.output_splits,
num_tokens_per_expert_group,
) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree)

# perform all-to-all
# TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao
routed_input = self.a2a_dispatch_mxfp8_fwd_hp_bwd(
routed_input,
output_splits=self.output_splits,
input_splits=self.input_splits,
group_name=device_mesh.get_group().group_name,
)

# NOTE: After this all-to-all, the routed input is put on proper EP rank.
# However, the num_tokens_per_expert_group is not of the final target format
# [#tokens for local expert 0, #tokens for local expert 1, ...]
# Rather, it is of the format
# [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
# #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
# We need to perform another shuffle to get the correct layout, via the _permute function
# below, which also does padding to make sure the number of tokens each expert gets locally
# is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
# Note that this will create side effects when wrapping the for-loop implementation
# of GroupedExperts, as it does not need padding.

# TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao
(
self.input_shape,
routed_input,
self.permuted_indices,
num_tokens_per_expert_group,
_,
) = self.permute_mxfp8_fwd_hp_bwd(
routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
)

return routed_input, num_tokens_per_expert_group

def _token_combine(
self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh
) -> Tensor:
# TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao
routed_output = self.unpermute_hp_fwd_mxfp8_bwd(
routed_output, self.permuted_indices, self.input_shape
)

# TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao
routed_output = self.a2a_combine_hp_fwd_mxfp8_bwd(
routed_output,
output_splits=self.input_splits, # swap input/output splits to reverse all-to-all dispatch
input_splits=self.output_splits,
group_name=device_mesh.get_group().group_name,
)
return routed_output

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
partition_fn=self._partition_fn,
# pyrefly: ignore [bad-argument-type]
input_fn=self._token_dispatch,
# pyrefly: ignore [bad-argument-type]
output_fn=self._token_combine,
)


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
def _token_dispatch(self, mod, inputs, device_mesh):
@@ -382,3 +468,49 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
input_fn=self._token_dispatch, # pyrefly: ignore [bad-argument-type]
output_fn=self._token_combine, # pyrefly: ignore [bad-argument-type]
)


def get_a2a_splits(
Reviewer comment (Contributor): I now think it's better to make _get_a2a_splits a static method of ExpertParallel, so that subclasses of it can inherit.

num_tokens_per_expert: torch.Tensor,
device_mesh: DeviceMesh,
ep_degree: int,
) -> tuple[list[int], list[int], torch.Tensor]:
"""
Get the input and output splits for all-to-all comms in expert parallelism.

Note: this incurs a device-to-host synchronization.

Args:
num_tokens_per_expert: Tensor of shape (num_experts,)
device_mesh: Device mesh for expert parallelism
ep_degree: Expert parallelism degree
Returns:
input_splits: list of shape (ep_degree,)
output_splits: list of shape (ep_degree,)
num_tokens_per_expert_group: Tensor of shape (num_experts,)
"""

with torch.no_grad():
num_tokens_per_expert_group = all_to_all_single(
num_tokens_per_expert,
None,
None,
group=device_mesh.get_group(),
)
# Need to wait explicitly because it is used by a triton kernel later
# which doesn't realize that AsyncCollectiveTensor needs unwrapping
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
return input_splits.tolist(), output_splits.tolist(), num_tokens_per_expert_group
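As a quick sanity check on the hunk above, the split arithmetic in get_a2a_splits can be exercised in isolation. A minimal sketch with made-up token counts, assuming two EP ranks and two local experts per rank:

```python
import torch

# Suppose ep_degree = 2 and 2 local experts per rank, so num_experts = 4.
ep_degree = 2
num_tokens_per_expert = torch.tensor([3, 5, 2, 4])  # tokens this rank routes to experts 0..3

# Group token counts by destination EP rank (experts 0-1 live on rank 0, experts 2-3 on rank 1).
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits.tolist())  # [8, 6]: 8 tokens go to EP rank 0, 6 to EP rank 1
```

output_splits is the same reduction applied to the tensor received from the first all-to-all, i.e. the per-rank counts of tokens arriving on this rank.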
5 changes: 5 additions & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -120,6 +120,10 @@ def parallelize_deepseekv3(
else:
use_deepep = False

use_mxfp8_a2a = (
Reviewer comment (Contributor): Since we have the DeepEP backend now enabled with this config https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/job_config.py#L474, I think we should just add another backend mxfp8, and pass it to apply_moe_ep_tp, instead of using both use_deepep and use_mxfp8_a2a.

"quantize.grouped_mm.mx" in job_config.model.converters
and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
)
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)

@@ -131,6 +135,7 @@ def parallelize_deepseekv3(
ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
dual_pipe_v=dual_pipe_v,
use_deepep=use_deepep,
use_mxfp8_a2a=use_mxfp8_a2a,
)

if parallel_dims.cp_enabled:
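For reference, a self-contained sketch of the use_mxfp8_a2a gating added above, with the relevant JobConfig fields mocked via SimpleNamespace (the mock is illustrative only, not torchtitan API):

```python
from types import SimpleNamespace

# Mock only the JobConfig fields read by the gating logic in this hunk.
job_config = SimpleNamespace(
    model=SimpleNamespace(converters=["quantize.grouped_mm.mx"]),
    quantize=SimpleNamespace(
        grouped_mm=SimpleNamespace(mx=SimpleNamespace(recipe_name="mxfp8_wgrad_with_hp"))
    ),
)

use_mxfp8_a2a = (
    "quantize.grouped_mm.mx" in job_config.model.converters
    and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
)
print(use_mxfp8_a2a)  # True -> the MXFP8 EP plan would be selected in apply_moe_ep_tp
```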
18 changes: 17 additions & 1 deletion torchtitan/models/llama4/infra/parallelize.py
@@ -36,6 +36,7 @@
DeepEPExpertParallel,
ExpertParallel,
ExpertTensorParallel,
MXFP8ExpertParallel,
ReordererSequenceParallel,
TensorParallel,
)
@@ -134,6 +135,10 @@ def parallelize_llama(
else:
use_deepep = False

use_mxfp8_a2a = (
"quantize.grouped_mm.mx" in job_config.model.converters
and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
)
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)

@@ -145,6 +150,7 @@ def parallelize_llama(
ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
dual_pipe_v=dual_pipe_v,
use_deepep=use_deepep,
use_mxfp8_a2a=use_mxfp8_a2a,
)

attn_type = getattr(model.model_args, "attn_type", "sdpa")
@@ -505,8 +511,13 @@ def apply_moe_ep_tp(
ep_etp_mesh: DeviceMesh | None,
dual_pipe_v: bool = False,
use_deepep: bool = False,
use_mxfp8_a2a: bool = False,
):
assert ep_mesh is not None or tp_mesh is not None
assert not (use_deepep and use_mxfp8_a2a), (
'DeepEP and MXFP8 all-to-all (part of quantize.grouped_mm.mx.recipe_name="mxfp8_wgrad_with_hp") are not compatible. '
"Please use at most one of them."
)

# pyrefly: ignore [not-callable]
for transformer_block in model.layers.values():
@@ -574,7 +585,12 @@ def apply_moe_ep_tp(
logger.info("Applying DeepEP to MoE layer")
else:
# input / output sharding on the batch / tokens dim
experts_plan = ExpertParallel()
if use_mxfp8_a2a:
logger.info("Applying MXFP8 Expert Parallelism to MoE")
experts_plan = MXFP8ExpertParallel()
else:
logger.info("Applying Expert Parallelism to MoE layer")
experts_plan = ExpertParallel()
else:
experts_mesh = ep_etp_mesh
experts_plan = ExpertTensorParallel()
5 changes: 5 additions & 0 deletions torchtitan/models/qwen3/infra/parallelize.py
@@ -108,13 +108,18 @@ def parallelize_qwen3(
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims)

use_mxfp8_a2a = (
"quantize.grouped_mm.mx" in job_config.model.converters
and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
)
apply_moe_ep_tp(
model,
tp_mesh=parallel_dims.get_optional_mesh("tp"),
ep_mesh=parallel_dims.get_optional_mesh("ep"),
etp_mesh=parallel_dims.get_optional_mesh("etp"),
ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
dual_pipe_v=dual_pipe_v,
use_mxfp8_a2a=use_mxfp8_a2a,
)

if parallel_dims.cp_enabled: