50 changes: 14 additions & 36 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -639,13 +639,20 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
     # by wrapping each submod's forward instead of their __call__
     moe = submod
     for attr_name, submod in moe.named_children():
-        setattr(
-            moe,
-            attr_name,
-            torch.compile(
-                submod, backend=compile_config.backend, fullgraph=True
-            ),
-        )
+        # temp workaround: compile everything except GroupedExperts.
+        # Context: https://github.com/pytorch/torchtitan/pull/2250#discussion_r2713296483
+        # TL;DR is this avoids a tensor metadata mismatch issue between forward output and
+        # backward() input (upstream grad).
+        if not isinstance(submod, moe_module.GroupedExperts):
Inline review comment (Contributor) on the line above:
there's also another hardcoded call to compile _run_experts_grouped_mm here that you'd need to skip (I'm doing it in my patch in https://gist.github.com/bdhirsh/970a671b84c35cc95a76f33657ca4d69)

+            setattr(
+                moe,
+                attr_name,
+                torch.compile(
+                    submod,
+                    backend=compile_config.backend,
+                    fullgraph=True,
+                ),
+            )
 else:
     setattr(
         block,
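For context on the change in this hunk, a minimal, self-contained sketch of the selective-compile pattern it introduces: wrap each child of the MoE module with torch.compile, but leave GroupedExperts eager to avoid the tensor-metadata mismatch mentioned in the comment. The classes below are toy stand-ins, not torchtitan's implementations.

```python
# Sketch only: compile every child of an MoE module except GroupedExperts,
# mirroring the workaround in the hunk above. All classes here are toy
# stand-ins, not torchtitan code.
import torch
import torch.nn as nn


class GroupedExperts(nn.Module):  # stand-in for torchtitan's GroupedExperts
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Router(nn.Module):  # stand-in for the MoE router
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class MoE(nn.Module):  # stand-in MoE block with two children
    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.router = Router(dim)
        self.experts = GroupedExperts()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.experts(self.router(x))


moe = MoE()
for attr_name, submod in moe.named_children():
    # workaround from the diff: leave GroupedExperts eager, compile the rest
    if not isinstance(submod, GroupedExperts):
        setattr(moe, attr_name, torch.compile(submod, fullgraph=True))

out = moe(torch.randn(4, 8))  # router runs compiled, experts stay eager
```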
@@ -667,35 +674,6 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
     # pyrefly: ignore [missing-attribute]
     model.layers.register_module(layer_id, transformer_block)
 
-# Patch some globals only once (apply_compile is called multiple times for PP setup)
-already_patched = (
-    "_run_experts_grouped_mm_dynamic"
-    in moe_module._run_experts_grouped_mm.__qualname__
-)
-if not already_patched:
-    moe_module._run_experts_grouped_mm = torch.compile(
-        moe_module._run_experts_grouped_mm,
-        backend=compile_config.backend,
-        fullgraph=True,
-    )
-
-    if ep_enabled:
-        compiled_fn = moe_module._run_experts_grouped_mm
-
-        # keep function logic in sync with `already_patched` above
-        def _run_experts_grouped_mm_dynamic(
-            w1: torch.Tensor,
-            w2: torch.Tensor,
-            w3: torch.Tensor,
-            x: torch.Tensor,
-            num_tokens_per_expert: torch.Tensor,
-        ) -> torch.Tensor:
-            # dynamic number of tokens in expert parallel
-            torch._dynamo.mark_dynamic(x, 0)
-            return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-        moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
-
 # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
 # https://github.com/pytorch/pytorch/issues/166460

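For readers skimming the deleted hunk above: it compiled the module-level _run_experts_grouped_mm once and, under expert parallelism, wrapped it so the token dimension is marked dynamic, avoiding recompiles when per-expert token counts change. This PR removes that global patching; the per-submodule compile in the first hunk is what remains. Below is a minimal, self-contained illustration of the mark_dynamic pattern; the toy function is a stand-in, not torchtitan's grouped-mm kernel.

```python
# Minimal illustration of the mark_dynamic pattern used by the removed hunk:
# compile a function once, then wrap it so dim 0 of the activation is treated
# as dynamic, so varying per-expert token counts do not trigger recompiles.
# _toy_expert_mm is a stand-in, not torchtitan's _run_experts_grouped_mm.
import torch


def _toy_expert_mm(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x @ w


compiled_fn = torch.compile(_toy_expert_mm, fullgraph=True)


def _toy_expert_mm_dynamic(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # number of tokens (dim 0 of x) varies across expert-parallel ranks/steps
    torch._dynamo.mark_dynamic(x, 0)
    return compiled_fn(w, x)


w = torch.randn(8, 8)
for n_tokens in (4, 16, 32):
    _ = _toy_expert_mm_dynamic(w, torch.randn(n_tokens, 8))
```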