18 changes: 18 additions & 0 deletions torchtitan/config/job_config.py
@@ -399,6 +399,24 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token dispatch step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token combine step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
 
 @dataclass
 class Checkpoint:
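These two options are plain string fields on the `Parallelism` dataclass, so they are set like any other parallelism knob. Below is a minimal sketch assuming only the two fields added above; the stub is not the full torchtitan `Parallelism` dataclass, and in a TOML job config the fields would presumably sit under the `[parallelism]` table with the same names.

```python
# Minimal sketch of the new knobs; this stub mirrors only the two fields added
# in this diff, not the full torchtitan Parallelism dataclass.
from dataclasses import dataclass
from typing import Literal


@dataclass
class Parallelism:
    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"


# Enable the MXFP8 all-to-all on both the dispatch and combine paths.
parallelism = Parallelism(
    expert_parallel_a2a_dispatch_impl="mxfp8",
    expert_parallel_a2a_combine_impl="mxfp8",
)
assert parallelism.expert_parallel_a2a_combine_impl == "mxfp8"
```

Note that in a plain dataclass the `Literal` annotation only helps static type checking; an invalid string is not rejected here but fails later in `ExpertParallel._get_a2a_func` (see below).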
34 changes: 29 additions & 5 deletions torchtitan/distributed/expert_parallel.py
@@ -81,16 +81,40 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:


 class ExpertParallel(ParallelStyle):
-    def __init__(self):
"""
ExpertParallel is a parallel style for MoE, where each experts
are distributed across ranks along a given axis of the device mesh.

Args:
a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
"""

+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)
+
+    def _get_a2a_func(self, a2a_impl: str):
+        if a2a_impl == "default":
+            return all_to_all_single_autograd
+        elif a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                to_mxfp8_a2a_dequant,
+            )
+
+            return to_mxfp8_a2a_dequant
+        else:
+            raise ValueError(f"Unknown a2a_impl: {a2a_impl}")
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_size = device_mesh.size(0)
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -119,8 +143,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()
 
-        # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input = self.a2a_dispatch_func(
             routed_input,
             self.output_splits,
             self.input_splits,
@@ -148,7 +171,8 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        # For a2a combine, input splits and output splits are opposite of a2a dispatch.
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
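To make the "mxfp8" path concrete, here is a rough sketch of the quantize → all_to_all_single → dequantize pattern it relies on. The blocking, scaling, and byte-transport details below are illustrative assumptions, not the torchao implementation: the real kernel (`to_mxfp8_a2a_dequant`, imported above) fuses these steps, uses e8m0 scales, and supports autograd. Only `torch.distributed.all_to_all_single` and the standard tensor ops are real APIs here.

```python
# Illustrative sketch of a bandwidth-saving all-to-all: quantize to fp8 with
# per-block scales, exchange the compact payload, then dequantize on arrival.
# Assumptions: NCCL backend, hidden dim divisible by BLOCK, no autograd support.
import torch
import torch.distributed as dist

BLOCK = 32  # elements that share one scale


def quantized_all_to_all(
    x: torch.Tensor,          # [num_tokens, hidden]
    output_splits: list[int],
    input_splits: list[int],
    group,
) -> torch.Tensor:
    # 1) Quantize: per-block absmax scale, elements cast to fp8 (e4m3, max ~448).
    orig_dtype = x.dtype
    blocks = x.reshape(x.shape[0], -1, BLOCK)
    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    q = (blocks / scales).to(torch.float8_e4m3fn)

    def a2a(t: torch.Tensor) -> torch.Tensor:
        # all_to_all_single splits/concatenates along dim 0 (the token dim).
        out = t.new_empty((sum(output_splits), *t.shape[1:]))
        dist.all_to_all_single(out, t.contiguous(), output_splits, input_splits, group=group)
        return out

    # 2) Communicate the fp8 payload (shipped as raw bytes) and the scales.
    q_out = a2a(q.view(torch.uint8)).view(torch.float8_e4m3fn)
    s_out = a2a(scales)

    # 3) Dequantize back to the original precision on the receiving rank.
    return (q_out.to(orig_dtype) * s_out).reshape(sum(output_splits), -1)
```

Relative to shipping bf16 activations, the fp8 payload plus one scale per 32 elements roughly halves the bytes on the wire, which is the bandwidth saving the config docstrings refer to.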
16 changes: 14 additions & 2 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -107,6 +107,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
+            a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
         )
 
     model_compile_enabled = (
@@ -438,7 +440,11 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
+    a2a_dispatch_impl: str = "default",
+    a2a_combine_impl: str = "default",
 ):
+    logger.info(f"Using all-to-all dispatch: {a2a_dispatch_impl}")
+    logger.info(f"Using all-to-all combine: {a2a_combine_impl}")
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -487,13 +493,19 @@ def apply_moe_ep_tp(
         elif tp_mesh is None:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
         elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
 
         parallelize_module(
             module=transformer_block.moe.experts,
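Putting the three files together, a condensed wiring sketch of how the config values reach the parallel style. It assumes an already-built `experts_mesh`, a `transformer_block` with an `moe.experts` submodule, and a parsed `job_config`, and mirrors the `parallelize_module` call in the diff above; the `ExpertParallel` import path follows the file location shown earlier.

```python
# Condensed wiring sketch; experts_mesh, transformer_block, and job_config are
# assumed to already exist as in the surrounding torchtitan code.
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed.expert_parallel import ExpertParallel

experts_plan = ExpertParallel(
    a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
    a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
)
parallelize_module(
    module=transformer_block.moe.experts,
    device_mesh=experts_mesh,
    parallelize_plan=experts_plan,
)
```

With both knobs left at "default" this is behaviorally identical to the previous code path, since `_get_a2a_func("default")` simply returns `all_to_all_single_autograd`.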