From 73becabaed1abee12708bd0a7aa206caa6f93b78 Mon Sep 17 00:00:00 2001
From: Eric Schreiber
Date: Wed, 4 Feb 2026 15:04:40 +0000
Subject: [PATCH] Clarify cases when manual graph breaks are needed.

---
 .../models/deepseek_v3/infra/parallelize.py   |   7 +-
 torchtitan/models/llama4/infra/parallelize.py | 107 +++++++++++-------
 torchtitan/models/qwen3/infra/parallelize.py  |   7 +-
 3 files changed, 80 insertions(+), 41 deletions(-)

diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 19d9f946d2..5f7364b3bd 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -156,7 +156,12 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 5e23b81ab1..bc51e87f58 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -176,7 +176,12 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # dp_mesh is the mesh for FSDP/HSDP
@@ -590,7 +595,55 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
+def _is_checkpoint_wrapped(module: nn.Module) -> bool:
+    return isinstance(module, CheckpointWrapper)
+
+
+def _apply_compile_all_except_moe_grouped_experts(
+    transformer_block: nn.Module,
+    compile_config: CompileConfig,
+):
+
+    if isinstance(transformer_block, CheckpointWrapper):
+        # TODO: Make CheckpointWrapper a transparent wrapper
+        # unwrap so that .named_children() works
+        block = transformer_block._checkpoint_wrapped_module
+    else:
+        block = transformer_block
+
+    for attr_name, submod in block.named_children():
+        assert getattr(block, attr_name) == getattr(transformer_block, attr_name)
+
+        if isinstance(submod, moe_module.MoE):
+            # avoid graph breaking on the GroupedExperts' FSDP hooks
+            # by wrapping each submod's forward instead of their __call__
+            moe = submod
+            for attr_name, submod in moe.named_children():
+                if attr_name == "experts":
+                    # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
+                    # https://github.com/pytorch/torchtitan/issues/1940
+                    continue
+                setattr(
+                    moe,
+                    attr_name,
+                    torch.compile(
+                        submod, backend=compile_config.backend, fullgraph=True
+                    ),
+                )
+        else:
+            setattr(
+                block,
+                attr_name,
+                torch.compile(submod, backend=compile_config.backend, fullgraph=True),
+            )
+
+
+def apply_compile(
+    model: nn.Module,
+    compile_config: CompileConfig,
+    ep_enabled: bool,
+    fsdp_enabled: bool,
+):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -600,51 +653,27 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
     torch._dynamo.config.capture_scalar_outputs = True
     # pyrefly: ignore [missing-attribute]
     for layer_id, transformer_block in model.layers.named_children():
-        if transformer_block.moe_enabled:
+
+        is_checkpoint_wrapped = _is_checkpoint_wrapped(transformer_block)
+        if (
+            transformer_block.moe_enabled
+            and is_checkpoint_wrapped
+            and fsdp_enabled
+            and ep_enabled
+        ):
             # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
             # So we must weave compile wrappers around those FSDP hooks to
             # prevent AC from falling back the whole graph to eager.
             # TODO: Fix Compile(AC(graph break))
-            if isinstance(transformer_block, CheckpointWrapper):
-                # TODO: Make CheckpointWrapper a transparent wrapper
-                # unwrap so that .named_children() works
-                block = transformer_block._checkpoint_wrapped_module
-            else:
-                block = transformer_block
-
-            for attr_name, submod in block.named_children():
-                assert getattr(block, attr_name) == getattr(
-                    transformer_block, attr_name
-                )
-
-                if isinstance(submod, moe_module.MoE):
-                    # avoid graph breaking on the GroupedExperts' FSDP hooks
-                    # by wrapping each submod's forward instead of their __call__
-                    moe = submod
-                    for attr_name, submod in moe.named_children():
-                        if attr_name == "experts":
-                            # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
-                            # https://github.com/pytorch/torchtitan/issues/1940
-                            continue
-                        setattr(
-                            moe,
-                            attr_name,
-                            torch.compile(
-                                submod, backend=compile_config.backend, fullgraph=True
-                            ),
-                        )
-                else:
-                    setattr(
-                        block,
-                        attr_name,
-                        torch.compile(
-                            submod, backend=compile_config.backend, fullgraph=True
-                        ),
-                    )
+            _apply_compile_all_except_moe_grouped_experts(
+                transformer_block, compile_config
+            )
         else:
             # If it's not a MoE layer, there is no FSDP(GroupedExperts)
+            # If EP is not enabled, there is no FSDP(GroupedExperts) either
+            # If FSDP is not enabled, there is no FSDP(GroupedExperts) either
             # So we can compile the whole block
             transformer_block = torch.compile(
                 transformer_block,
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 4837dbc68e..f26a2c06ee 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -137,7 +137,12 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if model_compile_enabled:
-        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
+        apply_compile(
+            model,
+            job_config.compile,
+            parallel_dims.ep_enabled,
+            parallel_dims.fsdp_enabled,
+        )
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
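
Note (not part of the patch): below is a minimal, self-contained sketch of the decision this change encodes. The names ToyBlock and compile_block are hypothetical and do not exist in torchtitan. Per-child compilation is only needed when a block is a MoE layer, is checkpoint-wrapped, and is trained with both FSDP and EP, because only then do FSDP hooks on the GroupedExperts graph-break inside the AC region; in every other case the whole block can be compiled as a single graph.

# Hypothetical toy example; ToyBlock / compile_block are illustrative only and
# not torchtitan code. Requires only PyTorch 2.x (torch.compile).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a TransformerBlock with an attention and a MoE child."""

    def __init__(self) -> None:
        super().__init__()
        self.attention = nn.Linear(16, 16)
        self.moe = nn.Linear(16, 16)  # stand-in for the MoE submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.moe(self.attention(x))


def compile_block(
    block: nn.Module,
    *,
    moe_enabled: bool,
    is_checkpoint_wrapped: bool,
    fsdp_enabled: bool,
    ep_enabled: bool,
) -> nn.Module:
    # Mirrors the gating added in this patch: per-child compile is only used
    # when all four conditions hold, i.e. when FSDP(GroupedExperts) hooks would
    # otherwise graph-break inside the checkpointed region.
    if moe_enabled and is_checkpoint_wrapped and fsdp_enabled and ep_enabled:
        for name, child in block.named_children():
            setattr(block, name, torch.compile(child, fullgraph=True))
        return block
    # No FSDP(GroupedExperts) to work around: compile the whole block at once.
    return torch.compile(block, fullgraph=True)


if __name__ == "__main__":
    block = compile_block(
        ToyBlock(),
        moe_enabled=True,
        is_checkpoint_wrapped=False,  # e.g. activation checkpointing disabled
        fsdp_enabled=True,
        ep_enabled=False,             # EP off, so whole-block compile is fine
    )
    print(block(torch.randn(2, 16)).shape)

The explicit fsdp_enabled/ep_enabled gating presumably also avoids paying the per-submodule compile path in configurations where no FSDP hooks exist on the experts, which is the case the added else-branch comments spell out.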