diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 5e23b81ab1..e3902a358b 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -13,6 +13,10 @@
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
+from torch.distributed.fsdp._fully_shard._fsdp_common import (
+    FSDPMeshInfo,
+    ShardPlacementResult,
+)
 from torch.distributed.tensor import Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -379,40 +383,61 @@ def apply_fsdp(
 
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.items():
-        # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
-        #   - the router and the shared experts are sharded together with the TransformerBlock
-        #   - the routed experts are sharded with the remaining edp_mesh
+        # NOTE: When EP is enabled, in an MoE layer, we use shard_placement_fn to apply
+        #   different FSDP mesh and shard placement to different parameters:
+        #   - the routed experts are sharded on edp_mesh (Shard(0) or Shard(1) depending on num_experts)
+        #   - other params (router, shared experts, attention, norms) are sharded with Shard(0) on dp_mesh
         if transformer_block.moe_enabled and ep_degree > 1:
-            fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = edp_mesh
-
-            # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
-            #       When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
-            #       causes inefficiency, so we choose to do FSDP sharding on dim-1.
-            #       Even when EP is not used, we may still want to shard the experts
-            #       on non-0 dim. For now it may not be worth the complexity to support
-            #       shard_placement_fn on the outer TransformerBlock-level FSDP.
-            _experts_shard_placement_fn = None
             assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
-            if (
-                edp_mesh["efsdp"].size() * ep_degree
-                > transformer_block.moe.experts.num_experts
-            ):
-                _experts_shard_placement_fn = lambda param: Shard(1)
+
+            # Create FSDPMeshInfo for different parameter groups
+            edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
+            # dp_mesh_info for non-expert parameters (Shard on dim 0)
+            dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)
+
+            # Collect expert parameter references for identification
+            expert_params = set(transformer_block.moe.experts.parameters())
+
+            # Determine expert shard placement dynamically:
+            #   - When efsdp_size * ep_degree <= num_experts: use Shard(0), each rank owns complete experts
+            #   - When efsdp_size * ep_degree > num_experts: use Shard(1), must shard within expert weights
+            num_experts = transformer_block.moe.experts.num_experts
+            if edp_mesh["efsdp"].size() * ep_degree > num_experts:
+                expert_shard_placement = Shard(1)
+            else:
+                expert_shard_placement = Shard(0)
+
+            def _shard_placement_fn(
+                param: nn.Parameter,
+                _expert_params: set = expert_params,
+                _expert_placement: Shard = expert_shard_placement,
+                _edp_mesh_info: FSDPMeshInfo = edp_mesh_info,
+                _dp_mesh_info: FSDPMeshInfo = dp_mesh_info,
+            ) -> ShardPlacementResult:
+                if param in _expert_params:
+                    # Expert parameters: use dynamic placement on edp_mesh
+                    return ShardPlacementResult(
+                        placement=_expert_placement, mesh_info=_edp_mesh_info
+                    )
+                else:
+                    # Non-expert parameters: use Shard(0) on dp_mesh
+                    return ShardPlacementResult(
+                        placement=Shard(0), mesh_info=_dp_mesh_info
+                    )
 
             fully_shard(
-                transformer_block.moe.experts,
-                **fsdp_mod_ep_config,
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+                shard_placement_fn=_shard_placement_fn,
+            )
+        else:
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
-                shard_placement_fn=_experts_shard_placement_fn,
             )
-
-        fully_shard(
-            transformer_block,
-            **fsdp_config,
-            reshard_after_forward=reshard_after_forward,
-        )
 
     # As an optimization, do not reshard_after_forward the last layers by default
     # since FSDP would prefetch them immediately after the forward pass
@@ -449,17 +474,7 @@ def apply_fsdp(
         ):
             if next_transformer_block is not None:
                 # pyrefly: ignore [missing-attribute]
-                if next_transformer_block.moe_enabled:
-                    # pyrefly: ignore [missing-attribute]
-                    transformer_block.set_modules_to_forward_prefetch(
-                        # pyrefly: ignore [missing-attribute]
-                        [next_transformer_block, next_transformer_block.moe.experts]
-                    )
-                else:
-                    # pyrefly: ignore [missing-attribute]
-                    transformer_block.set_modules_to_forward_prefetch(
-                        [next_transformer_block]
-                    )
+                transformer_block.set_modules_to_forward_prefetch([next_transformer_block])
             elif model.norm is not None and model.output is not None:
                 # pyrefly: ignore [missing-attribute]
                 transformer_block.set_modules_to_forward_prefetch(
@@ -481,17 +496,7 @@ def apply_fsdp(
         ):
             if prev_transformer_block is not None:
                 # pyrefly: ignore [missing-attribute]
-                if prev_transformer_block.moe_enabled:
-                    # pyrefly: ignore [missing-attribute]
-                    transformer_block.set_modules_to_backward_prefetch(
-                        # pyrefly: ignore [missing-attribute]
-                        [prev_transformer_block, prev_transformer_block.moe.experts]
-                    )
-                else:
-                    # pyrefly: ignore [missing-attribute]
-                    transformer_block.set_modules_to_backward_prefetch(
-                        [prev_transformer_block]
-                    )
+                transformer_block.set_modules_to_backward_prefetch([prev_transformer_block])
             elif model.tok_embeddings is not None:
                 # pyrefly: ignore [missing-attribute]
                 transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings])
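
For context on the new `shard_placement_fn`: the pre-patch code already passed a per-parameter callback (`lambda param: Shard(1)`) to `fully_shard`, but only on the inner `moe.experts` wrap, so placing the routed experts on `edp_mesh` required a second `fully_shard` call. Below is a minimal, standalone sketch of the placement rule this patch encodes, written against that existing callback form (parameter in, `Shard` out, `None` for the default). The helper name `make_expert_placement_fn` and its `efsdp_size`/`ep_degree` arguments are illustrative assumptions, not part of this patch, and the sketch does not model the `ShardPlacementResult`/`FSDPMeshInfo` pairing that additionally selects the mesh.

```python
# Sketch only (names are assumptions, see note above): build a shard_placement_fn
# that gives routed-expert parameters Shard(1) when dim-0 sharding would split
# finer than num_experts, and leaves every other parameter on the default Shard(0).
from typing import Callable, Optional

import torch.nn as nn
from torch.distributed.tensor import Shard


def make_expert_placement_fn(
    experts: nn.Module, efsdp_size: int, ep_degree: int
) -> Callable[[nn.Parameter], Optional[Shard]]:
    expert_params = {id(p) for p in experts.parameters()}
    num_experts = experts.num_experts  # same attribute the patch reads

    # efsdp_size * ep_degree <= num_experts: each rank keeps whole experts (Shard(0));
    # otherwise shard inside each expert's weight along dim 1.
    expert_placement = Shard(1) if efsdp_size * ep_degree > num_experts else Shard(0)

    def _placement_fn(param: nn.Parameter) -> Optional[Shard]:
        if id(param) in expert_params:
            return expert_placement
        return None  # None -> FSDP's default dim-0 sharding for non-expert params

    return _placement_fn
```

Returning a plain `Shard` cannot move a parameter onto a different mesh, which is why the patch introduces `ShardPlacementResult` carrying an `FSDPMeshInfo` alongside the placement, replacing the separate `fully_shard(transformer_block.moe.experts, ...)` wrap and the extra prefetch entries for `moe.experts`.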