[mxfp8 moe training] mxfp8 all to all #2250
base: danielvegamyhre/stack/3
@@ -107,32 +107,13 @@ def _token_dispatch( | |
| ep_degree = device_mesh.shape[0] | ||
| num_local_experts = num_tokens_per_expert.shape[0] // ep_degree | ||
|
|
||
| # generate the input splits and output splits for all-to-all | ||
| with torch.no_grad(): | ||
| num_tokens_per_expert_group = all_to_all_single( | ||
| num_tokens_per_expert, | ||
| None, | ||
| None, | ||
| group=device_mesh.get_group(), | ||
| ) | ||
| # Need to wait explicitly because it is used by a triton kernel later | ||
| # which doesn't realize that AsyncCollectiveTensor needs unwrapping | ||
| num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( | ||
| num_tokens_per_expert_group | ||
| ) | ||
| input_splits = ( | ||
| num_tokens_per_expert.view(ep_degree, -1) | ||
| .sum(dim=1) | ||
| .to(torch.device("cpu"), non_blocking=True) | ||
| ) | ||
| # NOTE: this would incur a device-to-host sync | ||
| output_splits = ( | ||
| num_tokens_per_expert_group.view(ep_degree, -1) | ||
| .sum(dim=1) | ||
| .to(torch.device("cpu"), non_blocking=False) | ||
| ) | ||
| self.input_splits = input_splits.tolist() | ||
| self.output_splits = output_splits.tolist() | ||
| # first all-to-all to calculate output splits from input splits. | ||
| # note: this will incur a d2h sync | ||
| ( | ||
| self.input_splits, | ||
| self.output_splits, | ||
| num_tokens_per_expert_group, | ||
| ) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree) | ||
|
|
||
| # perform all-to-all | ||
| routed_input = all_to_all_single_autograd( | ||
|
|
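As a rough aside on the refactor above: `get_a2a_splits` derives per-rank split sizes from per-expert token counts. A minimal single-process sketch (hypothetical numbers, not part of this diff) of how `input_splits` falls out of `num_tokens_per_expert`:

```python
import torch

# Toy illustration: 2 EP ranks, 2 local experts per rank, so 4 global experts.
# num_tokens_per_expert counts how many tokens this rank routes to each global expert.
ep_degree = 2
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])

# input_splits[i] = total tokens this rank sends to EP rank i
# (experts 0-1 live on rank 0, experts 2-3 live on rank 1)
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]

# output_splits is computed the same way, but from the counts received via the
# preliminary all_to_all_single, since only the senders know how much they send.
```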
@@ -192,6 +173,111 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: | |
| ) | ||
|
|
||
|
|
||
| class MXFP8ExpertParallel(ExpertParallel): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.input_splits = None | ||
| self.output_splits = None | ||
| self.input_shape = None | ||
| self.permuted_indices = None | ||
| try: | ||
| from torchao.prototype.moe_training.ep import ( | ||
| a2a_combine_hp_fwd_mxfp8_bwd, | ||
| a2a_dispatch_mxfp8_fwd_hp_bwd, | ||
| permute_mxfp8_fwd_hp_bwd, | ||
| unpermute_hp_fwd_mxfp8_bwd, | ||
| ) | ||
|
|
||
| self.a2a_dispatch_mxfp8_fwd_hp_bwd = a2a_dispatch_mxfp8_fwd_hp_bwd | ||
| self.permute_mxfp8_fwd_hp_bwd = permute_mxfp8_fwd_hp_bwd | ||
| self.unpermute_hp_fwd_mxfp8_bwd = unpermute_hp_fwd_mxfp8_bwd | ||
| self.a2a_combine_hp_fwd_mxfp8_bwd = a2a_combine_hp_fwd_mxfp8_bwd | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "MXFP8 expert parallel ops are not available." | ||
| "Please install torchao nightly build for CUDA 12.8+: " | ||
| "https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#-installation." | ||
| ) from e | ||
|
|
||
| def _token_dispatch( | ||
| self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh | ||
| ) -> tuple[Tensor, Tensor]: | ||
| # annotate module input placements/sharding with input_layouts | ||
| routed_input, num_tokens_per_expert = inputs | ||
| ep_degree = device_mesh.shape[0] | ||
| num_local_experts = num_tokens_per_expert.shape[0] // ep_degree | ||
|
|
||
| # first all-to-all to calculate output splits from input splits. | ||
| # note: this will incur a d2h sync | ||
| ( | ||
| self.input_splits, | ||
| self.output_splits, | ||
| num_tokens_per_expert_group, | ||
| ) = get_a2a_splits(num_tokens_per_expert, device_mesh, ep_degree) | ||
|
|
||
| # perform all-to-all | ||
| # TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao | ||
| routed_input = self.a2a_dispatch_mxfp8_fwd_hp_bwd( | ||
| routed_input, | ||
| output_splits=self.output_splits, | ||
| input_splits=self.input_splits, | ||
| group_name=device_mesh.get_group().group_name, | ||
| ) | ||
|
|
||
| # NOTE: After this all-to-all, the routed input resides on the proper EP rank. | ||
| # However, the num_tokens_per_expert_group is not of the final target format | ||
| # [#tokens for local expert 0, #tokens for local expert 1, ...] | ||
| # Rather, it is of the format | ||
| # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., | ||
| # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] | ||
| # We need to perform another shuffle to get the correct layout, via the permute call | ||
| # below, which also pads so that the number of tokens each expert gets locally | ||
| # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. | ||
| # Note that this will create side effects when wrapping the for-loop implementation | ||
| # of GroupedExperts, as it does not need padding. | ||
|
|
||
| # TODO: set use_mxfp8=self.use_mxfp8_a2a_dispatch_fwd when the option is available in torchao | ||
| ( | ||
| self.input_shape, | ||
| routed_input, | ||
| self.permuted_indices, | ||
| num_tokens_per_expert_group, | ||
| _, | ||
| ) = self.permute_mxfp8_fwd_hp_bwd( | ||
| routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts | ||
| ) | ||
|
|
||
| return routed_input, num_tokens_per_expert_group | ||
|
|
||
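To make the layout reshuffle described in the NOTE above concrete, here is a minimal sketch with hypothetical counts; the real permute op additionally pads each local expert's group to a multiple of `TOKEN_GROUP_ALIGN_SIZE_M` and returns the permuted indices used later to undo the shuffle:

```python
import torch

# Hypothetical counts after the dispatch all-to-all, for ep_degree=2 and
# num_local_experts=2, laid out rank-major as described in the NOTE:
# [local expert 0 from rank 0, local expert 1 from rank 0,
#  local expert 0 from rank 1, local expert 1 from rank 1]
num_tokens_per_expert_group = torch.tensor([3, 1, 2, 4])
ep_degree, num_local_experts = 2, 2

# Regrouping so all contributions to the same local expert are adjacent gives
# the per-local-expert totals that the grouped GEMM expects:
per_local_expert = (
    num_tokens_per_expert_group.view(ep_degree, num_local_experts).sum(dim=0)
)
print(per_local_expert.tolist())  # [5, 5] -> expert 0: 3+2 tokens, expert 1: 1+4
```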
| def _token_combine( | ||
| self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh | ||
| ) -> Tensor: | ||
| # TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao | ||
| routed_output = self.unpermute_hp_fwd_mxfp8_bwd( | ||
| routed_output, self.permuted_indices, self.input_shape | ||
| ) | ||
|
|
||
| # TODO: set use_mxfp8=self.use_mxfp8_a2a_combine_bwd when the option is available in torchao | ||
| routed_output = self.a2a_combine_hp_fwd_mxfp8_bwd( | ||
| routed_output, | ||
| output_splits=self.input_splits, # swap input/output splits to reverse all-to-all dispatch | ||
| input_splits=self.output_splits, | ||
| group_name=device_mesh.get_group().group_name, | ||
| ) | ||
| return routed_output | ||
|
|
||
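The split swap in `_token_combine` above can be summarized with a small sketch (hypothetical numbers): the combine all-to-all is the exact reverse of the dispatch, so the dispatch's output splits become the combine's input splits and vice versa.

```python
# Hypothetical splits recorded during the dispatch step on one rank:
dispatch_input_splits = [4, 6]   # tokens sent to EP ranks 0 and 1
dispatch_output_splits = [5, 2]  # tokens received from EP ranks 0 and 1

# The combine all-to-all routes each token back to the rank it came from,
# so the two split lists simply swap roles:
combine_input_splits = dispatch_output_splits   # send back what was received
combine_output_splits = dispatch_input_splits   # receive back what was sent
assert sum(combine_output_splits) == sum(dispatch_input_splits)
```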
| def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: | ||
| return distribute_module( | ||
| module, | ||
| device_mesh, | ||
| partition_fn=self._partition_fn, | ||
| # pyrefly: ignore [bad-argument-type] | ||
| input_fn=self._token_dispatch, | ||
| # pyrefly: ignore [bad-argument-type] | ||
| output_fn=self._token_combine, | ||
| ) | ||
|
|
||
|
|
||
| # This class is for dp2ep with TP (without TP we can just use ExpertParallel) | ||
| class ExpertTensorParallel(ExpertParallel): | ||
| def _token_dispatch(self, mod, inputs, device_mesh): | ||
|
|
@@ -382,3 +468,49 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: | |
| input_fn=self._token_dispatch, # pyrefly: ignore [bad-argument-type] | ||
| output_fn=self._token_combine, # pyrefly: ignore [bad-argument-type] | ||
| ) | ||
|
|
||
|
|
||
| def get_a2a_splits( | ||
|
Contributor comment: I now think it's better to make …
| num_tokens_per_expert: torch.Tensor, | ||
| device_mesh: DeviceMesh, | ||
| ep_degree: int, | ||
| ) -> tuple[list[int], list[int], torch.Tensor]: | ||
| """ | ||
| Get the input and output splits for all-to-all comms in expert parallelism. | ||
|
|
||
| Note: this incurs a device-to-host synchronization. | ||
|
|
||
| Args: | ||
| num_tokens_per_expert: Tensor of shape (num_experts,) | ||
| device_mesh: Device mesh for expert parallelism | ||
| ep_degree: Expert parallelism degree | ||
| Returns: | ||
| input_splits: list of shape (ep_degree,) | ||
| output_splits: list of shape (ep_degree,) | ||
| num_tokens_per_expert_group: Tensor of shape (num_experts,) | ||
| """ | ||
|
|
||
| with torch.no_grad(): | ||
| num_tokens_per_expert_group = all_to_all_single( | ||
| num_tokens_per_expert, | ||
| None, | ||
| None, | ||
| group=device_mesh.get_group(), | ||
| ) | ||
| # Need to wait explicitly because it is used by a triton kernel later | ||
| # which doesn't realize that AsyncCollectiveTensor needs unwrapping | ||
| num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( | ||
| num_tokens_per_expert_group | ||
| ) | ||
| input_splits = ( | ||
| num_tokens_per_expert.view(ep_degree, -1) | ||
| .sum(dim=1) | ||
| .to(torch.device("cpu"), non_blocking=True) | ||
| ) | ||
| # NOTE: this would incur a device-to-host sync | ||
| output_splits = ( | ||
| num_tokens_per_expert_group.view(ep_degree, -1) | ||
| .sum(dim=1) | ||
| .to(torch.device("cpu"), non_blocking=False) | ||
| ) | ||
| return input_splits.tolist(), output_splits.tolist(), num_tokens_per_expert_group | ||
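For reference, a hedged usage sketch of `get_a2a_splits` (assuming a torchrun launch and that the function is imported from the expert-parallel module this diff modifies):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

ep_degree = 4
device_mesh = init_device_mesh("cuda", (ep_degree,), mesh_dim_names=("ep",))

num_experts = 16  # must be divisible by ep_degree
num_tokens_per_expert = torch.randint(0, 32, (num_experts,), device="cuda")

input_splits, output_splits, num_tokens_per_expert_group = get_a2a_splits(
    num_tokens_per_expert, device_mesh, ep_degree
)
# input_splits[i]:  tokens this rank sends to EP rank i
# output_splits[i]: tokens this rank receives from EP rank i (known only after
#                   the preliminary all-to-all, hence the d2h sync noted above)
```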
|
|
@@ -120,6 +120,10 @@ def parallelize_deepseekv3( | |
| else: | ||
| use_deepep = False | ||
|
|
||
| use_mxfp8_a2a = ( | ||
|
Contributor comment: Since we have DeepEP backend now enabled with this config: https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/job_config.py#L474 …
| "quantize.grouped_mm.mx" in job_config.model.converters | ||
| and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp" | ||
| ) | ||
| if parallel_dims.tp_enabled or parallel_dims.ep_enabled: | ||
| dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) | ||
|
|
||
|
|
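For clarity, a minimal sketch of when `use_mxfp8_a2a` evaluates to true, using a hypothetical stand-in for `job_config` (the real object comes from torchtitan's job config; field values here are illustrative only):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the relevant slice of job_config, mirroring the
# condition added above.
job_config = SimpleNamespace(
    model=SimpleNamespace(converters=["quantize.grouped_mm.mx"]),
    quantize=SimpleNamespace(
        grouped_mm=SimpleNamespace(mx=SimpleNamespace(recipe_name="mxfp8_wgrad_with_hp"))
    ),
)

use_mxfp8_a2a = (
    "quantize.grouped_mm.mx" in job_config.model.converters
    and job_config.quantize.grouped_mm.mx.recipe_name == "mxfp8_wgrad_with_hp"
)
assert use_mxfp8_a2a  # MXFP8 all-to-all is enabled only under this recipe
```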
@@ -131,6 +135,7 @@ def parallelize_deepseekv3( | |
| ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), | ||
| dual_pipe_v=dual_pipe_v, | ||
| use_deepep=use_deepep, | ||
| use_mxfp8_a2a=use_mxfp8_a2a, | ||
| ) | ||
|
|
||
| if parallel_dims.cp_enabled: | ||
|
|
||
Review comment: Since it's related to mxfp8 only, let's put this in quantization/mx.py?