From 793589e89ac4920610cf25756239794d7bb584a6 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Thu, 29 Jan 2026 15:11:43 -0800
Subject: [PATCH 1/4] [DeepSeek-V3] Add 16B model config for testing

Add deepseek_v3_16b.toml config for local testing with 4 GPUs.
---
 torchtitan/models/deepseek_v3/__init__.py                 | 2 +-
 .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 31e450eb04..c335409a72 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -80,7 +80,7 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
+        n_layers=8,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 00ec53310e..d6c66bbae7 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training"
 print_config = false

 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
@@ -38,7 +38,7 @@ min_lr_factor = 0.1
 local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 20
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
@@ -49,7 +49,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 8
+expert_parallel_degree = 2
 expert_tensor_parallel_degree = 1

 [checkpoint]

From 3bf7e27e24f329c1b9b1b8881dd43d9d1f4175dd Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 6 Feb 2026 17:17:27 -0800
Subject: [PATCH 2/4] [FSDP2] Enable per-param mesh FSDP2 for MoE
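
The routed experts no longer get a separate fully_shard(moe.experts) call on
edp_mesh. Instead, fully_shard is applied once per TransformerBlock with a
shard_placement_fn that sends the routed experts' parameters to edp_mesh and
everything else (router, shared experts, attention, norms) to dp_mesh with
Shard(0), using the per-param mesh support (FSDPMeshInfo /
ShardPlacementResult).

The expert placement depends on how many expert rows each rank would own.
A minimal, self-contained sketch of just that rule (hypothetical sizes, not
values taken from any config in this series):

    from torch.distributed.tensor import Shard

    def expert_placement(efsdp_size: int, ep_degree: int, num_experts: int) -> Shard:
        # EP already shards the routed experts on dim 0 (num_experts); once
        # efsdp_size * ep_degree exceeds num_experts, dim-0 sharding runs out
        # of rows, so shard within each expert weight on dim 1 instead.
        return Shard(1) if efsdp_size * ep_degree > num_experts else Shard(0)

    assert expert_placement(efsdp_size=2, ep_degree=8, num_experts=64) == Shard(0)
    assert expert_placement(efsdp_size=16, ep_degree=8, num_experts=64) == Shard(1)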
---
 torchtitan/models/deepseek_v3/__init__.py     |   2 +-
 .../train_configs/deepseek_v3_16b.toml        |   6 +-
 torchtitan/models/llama4/infra/parallelize.py | 103 +++++++++---------
 3 files changed, 58 insertions(+), 53 deletions(-)

diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index c335409a72..31e450eb04 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -80,7 +80,7 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=8,
+        n_layers=27,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index d6c66bbae7..00ec53310e 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training"
 print_config = false

 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
@@ -38,7 +38,7 @@ min_lr_factor = 0.1
 local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
-steps = 20
+steps = 1000
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
@@ -49,7 +49,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 2
+expert_parallel_degree = 8
 expert_tensor_parallel_degree = 1

 [checkpoint]
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 5e23b81ab1..e3902a358b 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -13,6 +13,10 @@
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
+from torch.distributed.fsdp._fully_shard._fsdp_common import (
+    FSDPMeshInfo,
+    ShardPlacementResult,
+)
 from torch.distributed.tensor import Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -379,40 +383,61 @@ def apply_fsdp(
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.items():
-        # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
-        # - the router and the shared experts are sharded together with the TransformerBlock
-        # - the routed experts are sharded with the remaining edp_mesh
+        # NOTE: When EP is enabled, in an MoE layer, we use shard_placement_fn to apply
+        # different FSDP mesh and shard placement to different parameters:
+        # - the routed experts are sharded on edp_mesh (Shard(0) or Shard(1) depending on num_experts)
+        # - other params (router, shared experts, attention, norms) are sharded with Shard(0) on dp_mesh
         if transformer_block.moe_enabled and ep_degree > 1:
-            fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = edp_mesh
-
-            # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
-            # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
-            # causes inefficiency, so we choose to do FSDP sharding on dim-1.
-            # Even when EP is not used, we may still want to shard the experts
-            # on non-0 dim. For now it may not be worth the complexity to support
-            # shard_placement_fn on the outer TransformerBlock-level FSDP.
-            _experts_shard_placement_fn = None
             assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
+
+            # Create FSDPMeshInfo for different parameter groups
+            edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
+            # dp_mesh_info for non-expert parameters (Shard on dim 0)
+            dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)
+
+            # Collect expert parameter references for identification
+            expert_params = set(transformer_block.moe.experts.parameters())
+
+            # Determine expert shard placement dynamically:
+            # - When efsdp_size * ep_degree <= num_experts: use Shard(0), each rank owns complete experts
+            # - When efsdp_size * ep_degree > num_experts: use Shard(1), must shard within expert weights
+            num_experts = transformer_block.moe.experts.num_experts
+            if edp_mesh["efsdp"].size() * ep_degree > num_experts:
+                expert_shard_placement = Shard(1)
+            else:
+                expert_shard_placement = Shard(0)
+
+            def _shard_placement_fn(
+                param: nn.Parameter,
+                _expert_params: set = expert_params,
+                _expert_placement: Shard = expert_shard_placement,
+                _edp_mesh_info: FSDPMeshInfo = edp_mesh_info,
+                _dp_mesh_info: FSDPMeshInfo = dp_mesh_info,
+            ) -> ShardPlacementResult:
+                if param in _expert_params:
+                    # Expert parameters: use dynamic placement on edp_mesh
+                    return ShardPlacementResult(
+                        placement=_expert_placement, mesh_info=_edp_mesh_info
+                    )
+                else:
+                    # Non-expert parameters: use Shard(0) on dp_mesh
+                    return ShardPlacementResult(
+                        placement=Shard(0), mesh_info=_dp_mesh_info
+                    )
+
             fully_shard(
-                transformer_block.moe.experts,
-                **fsdp_mod_ep_config,
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+                shard_placement_fn=_shard_placement_fn,
+            )
+        else:
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
-                shard_placement_fn=_experts_shard_placement_fn,
             )
-
-        fully_shard(
-            transformer_block,
-            **fsdp_config,
-            reshard_after_forward=reshard_after_forward,
-        )

     # As an optimization, do not reshard_after_forward the last layers by default
     # since FSDP would prefetch them immediately after the forward pass
@@ -449,17 +474,7 @@ def apply_fsdp(
     ):
         if next_transformer_block is not None:
             # pyrefly: ignore [missing-attribute]
-            if next_transformer_block.moe_enabled:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_forward_prefetch(
-                    # pyrefly: ignore [missing-attribute]
-                    [next_transformer_block, next_transformer_block.moe.experts]
-                )
-            else:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_forward_prefetch(
-                    [next_transformer_block]
-                )
+            transformer_block.set_modules_to_forward_prefetch([next_transformer_block])
         elif model.norm is not None and model.output is not None:
             # pyrefly: ignore [missing-attribute]
             transformer_block.set_modules_to_forward_prefetch(
@@ -481,17 +496,7 @@ def apply_fsdp(
     ):
         if prev_transformer_block is not None:
             # pyrefly: ignore [missing-attribute]
-            if prev_transformer_block.moe_enabled:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_backward_prefetch(
-                    # pyrefly: ignore [missing-attribute]
-                    [prev_transformer_block, prev_transformer_block.moe.experts]
-                )
-            else:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_backward_prefetch(
-                    [prev_transformer_block]
-                )
+            transformer_block.set_modules_to_backward_prefetch([prev_transformer_block])
         elif model.tok_embeddings is not None:
             # pyrefly: ignore [missing-attribute]
             transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings])

From cc03b3c8b49d4244947ab93c2b5ac6946d4f58e6 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 6 Feb 2026 17:52:53 -0800
Subject: [PATCH 3/4] [Llama4] Simplify torch.compile for MoE with per-param mesh FSDP2

With per-param mesh FSDP2, we no longer apply fully_shard on GroupedExperts
separately. This eliminates the graph break from FSDP hooks on experts, so we
can compile each whole transformer block instead of the previous per-submodule
workaround.
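
Because every block is now compiled with fullgraph=True, any remaining graph
break inside a block (e.g. from an FSDP hook) fails loudly instead of silently
falling back to eager. A tiny standalone illustration of that contract (not
torchtitan code; torch._dynamo.graph_break() is only used here to force a
break):

    import torch

    @torch.compile(fullgraph=True)
    def f(x):
        torch._dynamo.graph_break()  # deliberate graph break
        return x + 1

    try:
        f(torch.randn(4))
    except Exception as e:
        # fullgraph=True refuses to fall back to eager on a graph break
        print(type(e).__name__)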
---
 torchtitan/models/llama4/infra/parallelize.py | 64 +++----------------
 1 file changed, 9 insertions(+), 55 deletions(-)

diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index e3902a358b..00940528d5 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -8,9 +8,6 @@
 import torch
 import torch.nn as nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointWrapper,
-)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.fsdp._fully_shard._fsdp_common import (
     FSDPMeshInfo,
     ShardPlacementResult,
 )
 from torch.distributed.tensor import Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -603,60 +600,17 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
     torch._dynamo.config.capture_scalar_outputs = True
+
+    # With per-param mesh FSDP2, we no longer apply fully_shard on GroupedExperts separately.
+    # This means there's no graph break from FSDP hooks on experts, so we can compile the
+    # whole transformer block for both MoE and non-MoE layers.
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.named_children():
-        if transformer_block.moe_enabled:
-            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
-            # So we must weave compile wrappers around those FSDP hooks to
-            # prevent AC from falling back the whole graph to eager.
-            # TODO: Fix Compile(AC(graph break))
-
-            if isinstance(transformer_block, CheckpointWrapper):
-                # TODO: Make CheckpointWrapper a transparent wrapper
-                # unwrap so that .named_children() works
-                block = transformer_block._checkpoint_wrapped_module
-            else:
-                block = transformer_block
-
-            for attr_name, submod in block.named_children():
-                assert getattr(block, attr_name) == getattr(
-                    transformer_block, attr_name
-                )
-
-                if isinstance(submod, moe_module.MoE):
-                    # avoid graph breaking on the GroupedExperts' FSDP hooks
-                    # by wrapping each submod's forward instead of their __call__
-                    moe = submod
-                    for attr_name, submod in moe.named_children():
-                        if attr_name == "experts":
-                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
-                            # https://github.com/pytorch/torchtitan/issues/1940
-                            continue
-                        setattr(
-                            moe,
-                            attr_name,
-                            torch.compile(
-                                submod, backend=compile_config.backend, fullgraph=True
-                            ),
-                        )
-                else:
-                    setattr(
-                        block,
-                        attr_name,
-                        torch.compile(
-                            submod, backend=compile_config.backend, fullgraph=True
-                        ),
-                    )
-
-        else:
-            # If it's not a MoE layer, there is no FSDP(GroupedExperts)
-            # So we can compile the whole block
-            transformer_block = torch.compile(
-                transformer_block,
-                backend=compile_config.backend,
-                fullgraph=True,
-            )
-
+        transformer_block = torch.compile(
+            transformer_block,
+            backend=compile_config.backend,
+            fullgraph=True,
+        )
         # pyrefly: ignore [missing-attribute]
         model.layers.register_module(layer_id, transformer_block)

From 19f229d795678d554efa5611bee1f00c2dc5d297 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 6 Feb 2026 19:18:30 -0800
Subject: [PATCH 4/4] [DeepSeek-V3] Reduce n_layers and enable model compile

---
 torchtitan/models/deepseek_v3/__init__.py                 | 2 +-
 .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 31e450eb04..4e26a62b8a 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -80,7 +80,7 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
+        n_layers=7,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 00ec53310e..d24a5aa972 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -66,7 +66,7 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba
 
 [compile]
 enable=true
-components = ["loss"] # ["model", "loss"]
+components = ["model", "loss"]
 
 [quantize.linear.float8]
 enable_fsdp_float8_all_gather = false