Conversation

weifengpy (Contributor) commented Feb 7, 2026

After enabling per-param mesh FSDP2 (#2281), we still hit a torch.compile error with per-layer compile.

Command to repro: CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2 --profiling.enable_profiling --profiling.profile_freq=10 --training.steps=20

    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1885, in _call_impl
      return inner()
             ^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1833, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/weif/dry_replicate/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 501, in forward
      h = layer(h, self.freqs_cis, attention_masks, positions)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/_dynamo/eval_frame.py", line 471, in __call__
      return super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1885, in _call_impl
      return inner()
             ^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1833, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/weif/dry_replicate/pytorch/torch/_dynamo/eval_frame.py", line 1025, in compile_wrapper
      raise e.with_traceback(None) from e.__cause__  # User compiler error
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  torch._dynamo.exc.Unsupported: HOP: Unsafe side effect
    Higher Order Operator: torch.utils.checkpoint.checkpoint
    Explanation: Mutating a variable from outside the scope of this HOP is not supported.
    Hint: If the HOP is activation checkpointing (torch.utils.checkpoint.checkpoint), this points to a side effect in forward method. Eager activation checkpointing replays that side-effect while recomputing the forward in the backward. If you are ok with side-effect not replayed in the backward, try setting `torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True`
  
    Developer debug context: Attempted to mutate UserDefinedObjectVariable(ExpertParallel)
  
   For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0067.html
  
  from user code:
     File "/data/users/weif/dry_replicate/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 169, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/data/users/weif/dry_replicate/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 392, in forward
      x = x + self.moe(self.ffn_norm(x))
    File "/data/users/weif/dry_replicate/torchtitan/torchtitan/models/moe/moe.py", line 522, in forward
      routed_output = self.experts(routed_input, num_tokens_per_expert)
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1882, in _call_impl
      return inner()
    File "/data/users/weif/dry_replicate/pytorch/torch/nn/modules/module.py", line 1822, in inner
      args_result = hook(self, args)
    File "/data/users/weif/dry_replicate/pytorch/torch/distributed/tensor/_api.py", line 1101, in <lambda>
      lambda mod, inputs: input_fn(mod, inputs, device_mesh)
    File "/data/users/weif/dry_replicate/torchtitan/torchtitan/distributed/expert_parallel.py", line 134, in _token_dispatch
      self.input_splits = input_splits.tolist()
  
  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
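
The failing pattern is the attribute mutation in ExpertParallel._token_dispatch (self.input_splits = input_splits.tolist()), which runs inside the checkpointed region. Below is a minimal standalone sketch of the same class of failure; Dispatcher and Block are hypothetical stand-ins, not the torchtitan code, and the comment on the last line describes the expected (not verified) outcome:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class Dispatcher:
        # Stand-in for ExpertParallel: caches splits computed during forward.
        def __init__(self):
            self.input_splits = None

    class Block(nn.Module):
        def __init__(self, dispatcher):
            super().__init__()
            self.dispatcher = dispatcher
            self.lin = nn.Linear(8, 8)

        def forward(self, x):
            # Mutates an object that lives outside the checkpoint HOP,
            # analogous to `self.input_splits = input_splits.tolist()`.
            self.dispatcher.input_splits = [x.shape[0]]
            return self.lin(x)

    dispatcher = Dispatcher()
    block = Block(dispatcher)

    @torch.compile(fullgraph=True)
    def run(x):
        return checkpoint(block, x, use_reentrant=False)

    x = torch.randn(4, 8, requires_grad=True)
    run(x)  # expected to hit the same "HOP: Unsafe side effect" error

With fullgraph=True this should surface as a hard error rather than a silent graph break.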

Add deepseek_v3_16b.toml config for local testing with 4 GPUs.
With per-param mesh FSDP2, we no longer apply fully_shard on
GroupedExperts separately. This eliminates the graph break from
FSDP hooks on experts, so we can compile each whole transformer
block instead of the previous per-submodule workaround.
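
A sketch of the per-block compile this enables, assuming model.layers is the usual ModuleDict/ModuleList of TransformerBlock modules (the actual wiring lives in torchtitan's parallelize/apply-compile code, and fullgraph=True is an assumption here):

    import torch

    def compile_per_block(model: torch.nn.Module) -> None:
        # Compile each transformer block as one graph, rather than compiling
        # attention / MoE submodules separately to dodge FSDP hook breaks.
        for layer_id, block in model.layers.named_children():
            compiled = torch.compile(block, fullgraph=True)
            model.layers.register_module(layer_id, compiled)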
meta-cla bot added the CLA Signed label Feb 7, 2026
tianyu-l commented Feb 7, 2026

cc @xmfan @soulitzer

soulitzer (Contributor) commented:

We can bypass the side-effect check with torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True. Setting this flag applies everywhere, not just here, but it's probably fine because I doubt there are any side effects that we actually want replayed.
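
E.g. (sketch; the setting is process-wide and should be applied before the compiled regions first run):

    import torch._dynamo

    # Don't replay forward side effects when recomputing under checkpoint.
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True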
