103 changes: 54 additions & 49 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -13,6 +13,10 @@
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_common import (
FSDPMeshInfo,
ShardPlacementResult,
)
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -379,40 +383,61 @@ def apply_fsdp(

# pyrefly: ignore [missing-attribute]
for layer_id, transformer_block in model.layers.items():
# NOTE: When EP is enabled, in an MoE layer we use the following FSDP wrapping
# - the router and the shared experts are sharded together with the TransformerBlock
# - the routed experts are sharded with the remaining edp_mesh
# NOTE: When EP is enabled, in an MoE layer we use shard_placement_fn to apply
# different FSDP meshes and shard placements to different parameters:
# - the routed experts are sharded on edp_mesh (Shard(0) or Shard(1) depending on num_experts)
# - other params (router, shared experts, attention, norms) are sharded with Shard(0) on dp_mesh
if transformer_block.moe_enabled and ep_degree > 1:
fsdp_mod_ep_config = fsdp_config.copy()
fsdp_mod_ep_config["mesh"] = edp_mesh

# NOTE: EP already shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
# Even when EP is not used, we may still want to shard the experts
# on non-0 dim. For now it may not be worth the complexity to support
# shard_placement_fn on the outer TransformerBlock-level FSDP.
_experts_shard_placement_fn = None
assert edp_mesh is not None
assert hasattr(transformer_block, "moe")
if (
edp_mesh["efsdp"].size() * ep_degree
> transformer_block.moe.experts.num_experts
):
_experts_shard_placement_fn = lambda param: Shard(1)

# Create FSDPMeshInfo for different parameter groups
edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
# dp_mesh_info for non-expert parameters (Shard on dim 0)
dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)

# Collect expert parameter references for identification
expert_params = set(transformer_block.moe.experts.parameters())

# Determine expert shard placement dynamically:
# - When efsdp_size * ep_degree <= num_experts: use Shard(0), each rank owns complete experts
# - When efsdp_size * ep_degree > num_experts: use Shard(1), must shard within expert weights
num_experts = transformer_block.moe.experts.num_experts
if edp_mesh["efsdp"].size() * ep_degree > num_experts:
expert_shard_placement = Shard(1)
else:
expert_shard_placement = Shard(0)

def _shard_placement_fn(
param: nn.Parameter,
_expert_params: set = expert_params,
_expert_placement: Shard = expert_shard_placement,
_edp_mesh_info: FSDPMeshInfo = edp_mesh_info,
_dp_mesh_info: FSDPMeshInfo = dp_mesh_info,
) -> ShardPlacementResult:
if param in _expert_params:
# Expert parameters: use dynamic placement on edp_mesh
return ShardPlacementResult(
placement=_expert_placement, mesh_info=_edp_mesh_info
)
else:
# Non-expert parameters: use Shard(0) on dp_mesh
return ShardPlacementResult(
placement=Shard(0), mesh_info=_dp_mesh_info
)

fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config,
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_shard_placement_fn,
)
else:
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)

fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)

# As an optimization, do not reshard_after_forward the last layers by default
# since FSDP would prefetch them immediately after the forward pass
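
The Shard(0)-vs-Shard(1) rule spelled out in the comments of this hunk can be checked in isolation. Below is a minimal, self-contained sketch of that decision, not part of the PR; the function name and the concrete sizes in the usage line are made up for illustration.

# Standalone sketch of the expert shard-placement rule described in the hunk above.
# Only torch.distributed.tensor.Shard is used; all sizes are illustrative.
from torch.distributed.tensor import Shard

def pick_expert_placement(num_experts: int, efsdp_size: int, ep_degree: int) -> Shard:
    # EP already splits the leading (num_experts) dim across ep_degree ranks.
    # If the expert-FSDP group stacked on top yields more shards than experts,
    # dim-0 sharding becomes uneven, so shard inside each expert's weight (dim 1).
    if efsdp_size * ep_degree > num_experts:
        return Shard(1)
    return Shard(0)

# Hypothetical sizes: 16 experts, EP degree 8, expert-FSDP group of 4 -> 32 > 16 -> Shard(1)
print(pick_expert_placement(num_experts=16, efsdp_size=4, ep_degree=8))  # Shard(dim=1)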
@@ -449,17 +474,7 @@ def apply_fsdp(
):
if next_transformer_block is not None:
# pyrefly: ignore [missing-attribute]
if next_transformer_block.moe_enabled:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_forward_prefetch(
# pyrefly: ignore [missing-attribute]
[next_transformer_block, next_transformer_block.moe.experts]
)
else:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_forward_prefetch(
[next_transformer_block]
)
transformer_block.set_modules_to_forward_prefetch([next_transformer_block])
elif model.norm is not None and model.output is not None:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_forward_prefetch(
@@ -481,17 +496,7 @@ def apply_fsdp(
):
if prev_transformer_block is not None:
# pyrefly: ignore [missing-attribute]
if prev_transformer_block.moe_enabled:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_backward_prefetch(
# pyrefly: ignore [missing-attribute]
[prev_transformer_block, prev_transformer_block.moe.experts]
)
else:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_backward_prefetch(
[prev_transformer_block]
)
transformer_block.set_modules_to_backward_prefetch([prev_transformer_block])
elif model.tok_embeddings is not None:
# pyrefly: ignore [missing-attribute]
transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings])
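
With the routed experts no longer wrapped as a separate FSDP module, each block's prefetch list shrinks to just its neighboring block, as the two hunks above show. The following is a minimal sketch, not taken from the PR, of how FSDP2's explicit prefetch hooks are typically wired over blocks that have already been through fully_shard; wire_prefetch and blocks are hypothetical names.

# Sketch: link consecutive FSDP2-wrapped blocks for explicit prefetching.
# Assumes each element of `blocks` was already passed to fully_shard, so it
# exposes set_modules_to_forward_prefetch / set_modules_to_backward_prefetch.
def wire_prefetch(blocks: list) -> None:
    for i, block in enumerate(blocks):
        if i + 1 < len(blocks):
            # During block i's forward, start all-gathering block i+1's parameters.
            block.set_modules_to_forward_prefetch([blocks[i + 1]])
        if i > 0:
            # During block i's backward, start all-gathering block i-1's parameters.
            block.set_modules_to_backward_prefetch([blocks[i - 1]])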