torchtitan/models/deepseek_v3/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -80,7 +80,7 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
+        n_layers=7,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
@@ -66,7 +66,7 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

 [compile]
 enable=true
-components = ["loss"] # ["model", "loss"]
+components = ["model", "loss"]

 [quantize.linear.float8]
 enable_fsdp_float8_all_gather = false
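To make the [compile] change above concrete, here is a minimal standalone sketch of how a components list such as ["model", "loss"] can gate what gets compiled. The CompileSettings dataclass below is a simplified, hypothetical stand-in rather than torchtitan's real CompileConfig; only the TOML values are taken from the diff.

# Illustrative sketch only: a simplified stand-in for consuming the [compile] table.
import tomllib
from dataclasses import dataclass, field


@dataclass
class CompileSettings:  # hypothetical, simplified
    enable: bool = False
    components: list[str] = field(default_factory=list)
    backend: str = "inductor"


raw = tomllib.loads(
    """
[compile]
enable = true
components = ["model", "loss"]
"""
)
cfg = CompileSettings(**raw["compile"])

# With ["model", "loss"], both the transformer blocks and the loss are compiled;
# the previous value ["loss"] would have compiled only the loss function.
compile_model = cfg.enable and "model" in cfg.components  # True
compile_loss = cfg.enable and "loss" in cfg.components    # True
print(compile_model, compile_loss)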
torchtitan/models/llama4/infra/parallelize.py (167 changes: 63 additions & 104 deletions)
@@ -8,11 +8,12 @@

 import torch
 import torch.nn as nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointWrapper,
-)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
+from torch.distributed.fsdp._fully_shard._fsdp_common import (
+    FSDPMeshInfo,
+    ShardPlacementResult,
+)
 from torch.distributed.tensor import Partial, Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -379,40 +380,61 @@ def apply_fsdp(

     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.items():
-        # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
-        # - the router and the shared experts are sharded together with the TransformerBlock
-        # - the routed experts are sharded with the remaining edp_mesh
+        # NOTE: When EP is enabled, In an MoE layer, we use shard_placement_fn to apply
+        # different FSDP mesh and shard placement to different parameters:
+        # - the routed experts are sharded on edp_mesh (Shard(0) or Shard(1) depending on num_experts)
+        # - other params (router, shared experts, attention, norms) are sharded with Shard(0) on dp_mesh
         if transformer_block.moe_enabled and ep_degree > 1:
-            fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = edp_mesh
-
-            # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
-            # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
-            # causes inefficiency, so we choose to do FSDP sharding on dim-1.
-            # Even when EP is not used, we may still want to shard the experts
-            # on non-0 dim. For now it may not be worth the complexity to support
-            # shard_placement_fn on the outer TransformerBlock-level FSDP.
-            _experts_shard_placement_fn = None
             assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
-            if (
-                edp_mesh["efsdp"].size() * ep_degree
-                > transformer_block.moe.experts.num_experts
-            ):
-                _experts_shard_placement_fn = lambda param: Shard(1)
+
+            # Create FSDPMeshInfo for different parameter groups
+            edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
+            # dp_mesh_info for non-expert parameters (Shard on dim 0)
+            dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)
+
+            # Collect expert parameter references for identification
+            expert_params = set(transformer_block.moe.experts.parameters())
+
+            # Determine expert shard placement dynamically:
+            # - When efsdp_size * ep_degree <= num_experts: use Shard(0), each rank owns complete experts
+            # - When efsdp_size * ep_degree > num_experts: use Shard(1), must shard within expert weights
+            num_experts = transformer_block.moe.experts.num_experts
+            if edp_mesh["efsdp"].size() * ep_degree > num_experts:
+                expert_shard_placement = Shard(1)
+            else:
+                expert_shard_placement = Shard(0)
+
+            def _shard_placement_fn(
+                param: nn.Parameter,
+                _expert_params: set = expert_params,
+                _expert_placement: Shard = expert_shard_placement,
+                _edp_mesh_info: FSDPMeshInfo = edp_mesh_info,
+                _dp_mesh_info: FSDPMeshInfo = dp_mesh_info,
+            ) -> ShardPlacementResult:
+                if param in _expert_params:
+                    # Expert parameters: use dynamic placement on edp_mesh
+                    return ShardPlacementResult(
+                        placement=_expert_placement, mesh_info=_edp_mesh_info
+                    )
+                else:
+                    # Non-expert parameters: use Shard(0) on dp_mesh
+                    return ShardPlacementResult(
+                        placement=Shard(0), mesh_info=_dp_mesh_info
+                    )

             fully_shard(
-                transformer_block.moe.experts,
-                **fsdp_mod_ep_config,
+                transformer_block,
+                **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
-                shard_placement_fn=_experts_shard_placement_fn,
+                shard_placement_fn=_shard_placement_fn,
             )
-
-        fully_shard(
-            transformer_block,
-            **fsdp_config,
-            reshard_after_forward=reshard_after_forward,
-        )
+        else:
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )

     # As an optimization, do not reshard_after_forward the last layers by default
     # since FSDP would prefetch them immediately after the forward pass
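As a minimal standalone sketch of the Shard(0)-versus-Shard(1) rule in the hunk above (illustrative only: the mesh sizes are made-up numbers, and it uses only the public Shard placement type, not fully_shard or the private FSDPMeshInfo / ShardPlacementResult helpers):

# Illustrative sketch, not part of the diff: the expert shard-placement rule in isolation.
from torch.distributed.tensor import Shard


def pick_expert_placement(efsdp_size: int, ep_degree: int, num_experts: int) -> Shard:
    # Expert weights are stacked as [num_experts, ...] and EP already splits dim 0
    # across ep_degree ranks. If the remaining FSDP shards (efsdp_size) would need
    # more dim-0 slices than there are experts, shard inside each expert instead.
    if efsdp_size * ep_degree > num_experts:
        return Shard(1)
    return Shard(0)


# 8 experts, EP=4, efsdp=2: 2 * 4 = 8, not > 8, so each rank keeps whole experts
print(pick_expert_placement(efsdp_size=2, ep_degree=4, num_experts=8))  # Shard(dim=0)
# 8 experts, EP=4, efsdp=4: 4 * 4 = 16 > 8, so shard within each expert's weight
print(pick_expert_placement(efsdp_size=4, ep_degree=4, num_experts=8))  # Shard(dim=1)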
@@ -449,17 +471,7 @@ def apply_fsdp(
     ):
         if next_transformer_block is not None:
             # pyrefly: ignore [missing-attribute]
-            if next_transformer_block.moe_enabled:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_forward_prefetch(
-                    # pyrefly: ignore [missing-attribute]
-                    [next_transformer_block, next_transformer_block.moe.experts]
-                )
-            else:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_forward_prefetch(
-                    [next_transformer_block]
-                )
+            transformer_block.set_modules_to_forward_prefetch([next_transformer_block])
         elif model.norm is not None and model.output is not None:
             # pyrefly: ignore [missing-attribute]
             transformer_block.set_modules_to_forward_prefetch(
@@ -481,17 +493,7 @@ def apply_fsdp(
     ):
         if prev_transformer_block is not None:
             # pyrefly: ignore [missing-attribute]
-            if prev_transformer_block.moe_enabled:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_backward_prefetch(
-                    # pyrefly: ignore [missing-attribute]
-                    [prev_transformer_block, prev_transformer_block.moe.experts]
-                )
-            else:
-                # pyrefly: ignore [missing-attribute]
-                transformer_block.set_modules_to_backward_prefetch(
-                    [prev_transformer_block]
-                )
+            transformer_block.set_modules_to_backward_prefetch([prev_transformer_block])
         elif model.tok_embeddings is not None:
             # pyrefly: ignore [missing-attribute]
             transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings])
@@ -598,60 +600,17 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
     torch._dynamo.config.capture_scalar_outputs = True

+    # With per-param mesh FSDP2, we no longer apply fully_shard on GroupedExperts separately.
+    # This means there's no graph break from FSDP hooks on experts, so we can compile the
+    # whole transformer block for both MoE and non-MoE layers.
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.named_children():
-        if transformer_block.moe_enabled:
-            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
-            # So we must weave compile wrappers around those FSDP hooks to
-            # prevent AC from falling back the whole graph to eager.
-            # TODO: Fix Compile(AC(graph break))
-
-            if isinstance(transformer_block, CheckpointWrapper):
-                # TODO: Make CheckpointWrapper a transparent wrapper
-                # unwrap so that .named_children() works
-                block = transformer_block._checkpoint_wrapped_module
-            else:
-                block = transformer_block
-
-            for attr_name, submod in block.named_children():
-                assert getattr(block, attr_name) == getattr(
-                    transformer_block, attr_name
-                )
-
-                if isinstance(submod, moe_module.MoE):
-                    # avoid graph breaking on the GroupedExperts' FSDP hooks
-                    # by wrapping each submod's forward instead of their __call__
-                    moe = submod
-                    for attr_name, submod in moe.named_children():
-                        if attr_name == "experts":
-                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
-                            # https://github.com/pytorch/torchtitan/issues/1940
-                            continue
-                        setattr(
-                            moe,
-                            attr_name,
-                            torch.compile(
-                                submod, backend=compile_config.backend, fullgraph=True
-                            ),
-                        )
-                else:
-                    setattr(
-                        block,
-                        attr_name,
-                        torch.compile(
-                            submod, backend=compile_config.backend, fullgraph=True
-                        ),
-                    )
-
-        else:
-            # If it's not a MoE layer, there is no FSDP(GroupedExperts)
-            # So we can compile the whole block
-            transformer_block = torch.compile(
-                transformer_block,
-                backend=compile_config.backend,
-                fullgraph=True,
-            )
+        transformer_block = torch.compile(
+            transformer_block,
+            backend=compile_config.backend,
+            fullgraph=True,
+        )

         # pyrefly: ignore [missing-attribute]
         model.layers.register_module(layer_id, transformer_block)
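To illustrate the simplified apply_compile path above with something runnable on one process, here is a small sketch with toy modules (these are not torchtitan's classes, and no FSDP is involved): since the experts submodule no longer carries its own FSDP wrapper, the whole block can be compiled in one torch.compile call with fullgraph=True and re-registered on the container.

# Illustrative sketch only: toy stand-ins for a transformer block with nested experts.
import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    def __init__(self, num_experts: int = 4, dim: int = 32):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # route every token to expert 0 to keep the toy example static-shaped
        return x @ self.w[0]


class ToyMoEBlock(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.experts = ToyExperts(dim=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.experts(self.norm(x))


layers = nn.ModuleDict({"0": ToyMoEBlock(), "1": ToyMoEBlock()})
for layer_id, block in layers.named_children():
    # compile each block as a whole and re-register it, mirroring the loop above
    layers.register_module(layer_id, torch.compile(block, fullgraph=True))

print(layers["0"](torch.randn(2, 8, 32)).shape)  # torch.Size([2, 8, 32])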
