diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 9325a3cf4e..5836ffd1a7 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -639,13 +639,20 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
                     # by wrapping each submod's forward instead of their __call__
                     moe = submod
                     for attr_name, submod in moe.named_children():
-                        setattr(
-                            moe,
-                            attr_name,
-                            torch.compile(
-                                submod, backend=compile_config.backend, fullgraph=True
-                            ),
-                        )
+                        # temp workaround: compile everything except GroupedExperts.
+                        # Context: https://github.com/pytorch/torchtitan/pull/2250#discussion_r2713296483
+                        # TL;DR is this avoids a tensor metadata mismatch issue between forward output and
+                        # backward() input (upstream grad).
+                        if not isinstance(submod, moe_module.GroupedExperts):
+                            setattr(
+                                moe,
+                                attr_name,
+                                torch.compile(
+                                    submod,
+                                    backend=compile_config.backend,
+                                    fullgraph=True,
+                                ),
+                            )
                 else:
                     setattr(
                         block,
@@ -667,35 +674,6 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
         # pyrefly: ignore [missing-attribute]
         model.layers.register_module(layer_id, transformer_block)
 
-    # Patch some globals only once (apply_compile is called multiple times for PP setup)
-    already_patched = (
-        "_run_experts_grouped_mm_dynamic"
-        in moe_module._run_experts_grouped_mm.__qualname__
-    )
-    if not already_patched:
-        moe_module._run_experts_grouped_mm = torch.compile(
-            moe_module._run_experts_grouped_mm,
-            backend=compile_config.backend,
-            fullgraph=True,
-        )
-
-        if ep_enabled:
-            compiled_fn = moe_module._run_experts_grouped_mm
-
-            # keep function logic in sync with `already_patched` above
-            def _run_experts_grouped_mm_dynamic(
-                w1: torch.Tensor,
-                w2: torch.Tensor,
-                w3: torch.Tensor,
-                x: torch.Tensor,
-                num_tokens_per_expert: torch.Tensor,
-            ) -> torch.Tensor:
-                # dynamic number of tokens in expert parallel
-                torch._dynamo.mark_dynamic(x, 0)
-                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
-
     # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
     # https://github.com/pytorch/pytorch/issues/166460
 
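
For reference, a minimal self-contained sketch of the pattern the first hunk implements: compile each child of the MoE module individually and leave the grouped-experts submodule eager. ToyMoE, ToyRouter, ToyGroupedExperts, and compile_moe_children are illustrative stand-ins, not torchtitan's actual classes or helpers.

import torch
import torch.nn as nn


class ToyGroupedExperts(nn.Module):
    # Stand-in for the grouped-mm expert computation; left uncompiled below,
    # mirroring the temporary workaround in the diff.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


class ToyRouter(nn.Module):
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.gate(x), dim=-1)


class ToyMoE(nn.Module):
    def __init__(self, dim: int = 16, num_experts: int = 4):
        super().__init__()
        self.router = ToyRouter(dim, num_experts)
        self.experts = ToyGroupedExperts()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # toy combine: scale expert output by the top routing score
        scores = self.router(x)
        return self.experts(x) * scores.amax(dim=-1, keepdim=True)


def compile_moe_children(moe: nn.Module, backend: str = "inductor") -> None:
    # Wrap each child module with torch.compile individually, skipping the
    # grouped-experts submodule (the same shape as the isinstance check added
    # in the hunk above).
    for attr_name, submod in moe.named_children():
        if not isinstance(submod, ToyGroupedExperts):
            setattr(
                moe,
                attr_name,
                torch.compile(submod, backend=backend, fullgraph=True),
            )


moe = ToyMoE()
compile_moe_children(moe)
out = moe(torch.randn(2, 8, 16))
out.sum().backward()  # backward runs through the eager ToyGroupedExperts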