torchtitan/models/deepseek_v3/infra/parallelize.py (7 changes: 6 additions, 1 deletion)
@@ -156,7 +156,12 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
torchtitan/models/llama4/infra/parallelize.py (107 changes: 68 additions, 39 deletions)
@@ -176,7 +176,12 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # dp_mesh is the mesh for FSDP/HSDP
@@ -590,7 +595,55 @@ def apply_moe_ep_tp(
         )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
+def _is_checkpoint_wrapped(module: nn.Module) -> bool:
+    return isinstance(module, CheckpointWrapper)
+
+
+def _apply_compile_all_except_moe_grouped_experts(
+    transformer_block: nn.Module,
+    compile_config: CompileConfig,
+):
+
+    if isinstance(transformer_block, CheckpointWrapper):
+        # TODO: Make CheckpointWrapper a transparent wrapper
+        # unwrap so that .named_children() works
+        block = transformer_block._checkpoint_wrapped_module
+    else:
+        block = transformer_block
+
+    for attr_name, submod in block.named_children():
+        assert getattr(block, attr_name) == getattr(transformer_block, attr_name)
+
+        if isinstance(submod, moe_module.MoE):
+            # avoid graph breaking on the GroupedExperts' FSDP hooks
+            # by wrapping each submod's forward instead of their __call__
+            moe = submod
+            for attr_name, submod in moe.named_children():
+                if attr_name == "experts":
+                    # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
+                    # https://github.com/pytorch/torchtitan/issues/1940
+                    continue
+                setattr(
+                    moe,
+                    attr_name,
+                    torch.compile(
+                        submod, backend=compile_config.backend, fullgraph=True
+                    ),
+                )
+        else:
+            setattr(
+                block,
+                attr_name,
+                torch.compile(submod, backend=compile_config.backend, fullgraph=True),
+            )
+
+
+def apply_compile(
+    model: nn.Module,
+    compile_config: CompileConfig,
+    ep_enabled: bool,
+    fsdp_enabled: bool,
+):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -600,51 +653,27 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
     torch._dynamo.config.capture_scalar_outputs = True
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.named_children():
-        if transformer_block.moe_enabled:
+
+        is_checkpoint_wrapped = _is_checkpoint_wrapped(transformer_block)
+        if (
+            transformer_block.moe_enabled
+            and is_checkpoint_wrapped
+            and fsdp_enabled
+            and ep_enabled
+        ):
             # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
             # So we must weave compile wrappers around those FSDP hooks to
             # prevent AC from falling back the whole graph to eager.
             # TODO: Fix Compile(AC(graph break))
 
-            if isinstance(transformer_block, CheckpointWrapper):
-                # TODO: Make CheckpointWrapper a transparent wrapper
-                # unwrap so that .named_children() works
-                block = transformer_block._checkpoint_wrapped_module
-            else:
-                block = transformer_block
-
-            for attr_name, submod in block.named_children():
-                assert getattr(block, attr_name) == getattr(
-                    transformer_block, attr_name
-                )
-
-                if isinstance(submod, moe_module.MoE):
-                    # avoid graph breaking on the GroupedExperts' FSDP hooks
-                    # by wrapping each submod's forward instead of their __call__
-                    moe = submod
-                    for attr_name, submod in moe.named_children():
-                        if attr_name == "experts":
-                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
-                            # https://github.com/pytorch/torchtitan/issues/1940
-                            continue
-                        setattr(
-                            moe,
-                            attr_name,
-                            torch.compile(
-                                submod, backend=compile_config.backend, fullgraph=True
-                            ),
-                        )
-                else:
-                    setattr(
-                        block,
-                        attr_name,
-                        torch.compile(
-                            submod, backend=compile_config.backend, fullgraph=True
-                        ),
-                    )
+            _apply_compile_all_except_moe_grouped_experts(
+                transformer_block, compile_config
+            )
 
         else:
             # If it's not a MoE layer, there is no FSDP(GroupedExperts)
+            # If EP is not enabled, there is no FSDP(GroupedExperts) either
+            # If FSDP is not enabled, there is no FSDP(GroupedExperts) either
            # So we can compile the whole block
             transformer_block = torch.compile(
                 transformer_block,
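To summarize the new control flow: a TransformerBlock is compiled as a single fullgraph unit unless it is a MoE block that is checkpoint-wrapped and sharded by FSDP with expert parallelism, in which case its children are compiled individually and the MoE's experts submodule is left uncompiled, so its FSDP hooks never land inside a compiled graph. The sketch below illustrates that decision on a toy block; ToyBlock, ToyMoE, ToyExperts, and compile_block are hypothetical stand-ins, not part of this PR, which implements the real logic in apply_compile and _apply_compile_all_except_moe_grouped_experts above.

import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    # Stand-in for GroupedExperts; in the special case below it stays uncompiled.
    def __init__(self, dim: int = 16, num_experts: int = 4):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.w[0]  # trivial "routing": every token goes to expert 0


class ToyMoE(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.router = nn.Linear(dim, 4)
        self.experts = ToyExperts(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _ = self.router(x)  # scores ignored in this toy
        return self.experts(x)


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 16, moe_enabled: bool = True):
        super().__init__()
        self.moe_enabled = moe_enabled
        self.attention = nn.Linear(dim, dim)
        self.moe = ToyMoE(dim) if moe_enabled else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.moe(self.attention(x))


def compile_block(
    block: ToyBlock, *, ac_wrapped: bool, fsdp_enabled: bool, ep_enabled: bool
) -> nn.Module:
    # Mirrors the shape of the new apply_compile decision, without the
    # CheckpointWrapper unwrapping or the real FSDP/EP machinery.
    if block.moe_enabled and ac_wrapped and fsdp_enabled and ep_enabled:
        for name, child in block.named_children():
            if isinstance(child, ToyMoE):
                for moe_name, moe_child in child.named_children():
                    if moe_name == "experts":
                        continue  # keep FSDP-hooked experts out of compiled graphs
                    setattr(child, moe_name, torch.compile(moe_child, fullgraph=True))
            else:
                setattr(block, name, torch.compile(child, fullgraph=True))
        return block
    # No FSDP(GroupedExperts) to graph-break on, so compile the whole block.
    return torch.compile(block, fullgraph=True)


x = torch.randn(2, 16)
sparse = compile_block(
    ToyBlock(moe_enabled=True), ac_wrapped=True, fsdp_enabled=True, ep_enabled=True
)
dense = compile_block(
    ToyBlock(moe_enabled=False), ac_wrapped=True, fsdp_enabled=True, ep_enabled=True
)
print(sparse(x).shape, dense(x).shape)  # both torch.Size([2, 16])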
torchtitan/models/qwen3/infra/parallelize.py (7 changes: 6 additions, 1 deletion)
@@ -137,7 +137,12 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
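The per-block special-casing above exists to keep compiled regions free of graph breaks (see the B200 issue link and the Compile(AC(graph break)) TODO in the diff). As a rough, hypothetical way to spot-check that a block traces into a single graph, and not something this PR adds, torch._dynamo.explain reports graph and graph-break counts for a callable:

import torch
import torch._dynamo as dynamo


class TinyBlock(torch.nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


# explain() traces the callable on example inputs and summarizes the capture.
report = dynamo.explain(TinyBlock())(torch.randn(2, 16))
print(report.graph_count, report.graph_break_count)  # expect 1 and 0 for a clean block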