
Conversation

@fegin (Contributor) commented on Sep 12, 2025:

Similar to #1696, but this PR uses parallelize_module, similar to TP/SP.

This PR also requires pytorch/pytorch#162542.

device_mesh=world_mesh["cp"],
parallelize_plan=_ContextParallel(
    seq_dim=2,
    attention_type=_ContextParallel.AttentionType.FLEX,
A reviewer (Contributor) commented:

Does this only work for FlexAttention?
Is there a plan to consolidate SDPA and FlexAttention in terms of how CP is applied?

@fegin (author) replied:

This will work for both SDPA and Flex; we just need to pass in a different attention_type depending on which attention implementation is used.
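For illustration, a minimal sketch of that selection. The import path and the AttentionType.SDPA member are assumptions (only AttentionType.FLEX appears in this PR's diff); the actual API comes from pytorch/pytorch#162542.

```python
# Sketch only: _ContextParallel is introduced by pytorch/pytorch#162542; the
# import path and the AttentionType.SDPA member are assumptions, not confirmed API.
from torch.distributed.tensor.experimental._attention import _ContextParallel


def make_cp_plan(use_flex_attention: bool) -> _ContextParallel:
    # Pick the attention type matching the model's attention implementation.
    attention_type = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attention
        else _ContextParallel.AttentionType.SDPA  # assumed member for SDPA
    )
    # seq_dim=2 matches the [batch, heads, seq, head_dim] layout used in the diff.
    return _ContextParallel(seq_dim=2, attention_type=attention_type)
```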

if parallel_dims.cp_enabled:
    for block in model.layers.values():
        parallelize_module(
            module=block.attention.sdpa.attention_fn_wrapper,
A reviewer (Contributor) commented:

IIUC, for FlexAttention we need this wrapper because the block mask has to be obtained inside the FlexAttention class before calling the wrapper. For SDPA it seems unnecessary, since it is already a very thin wrapper.

If the concern is code branching, the code is going to branch a couple of lines below anyway, so I think it's fine.

@fegin (author) replied:

It's not just about unification. The wrapper must have exactly the same function signature as scaled_dot_product_attention, and our ScaledDotProductAttention doesn't meet this requirement. More importantly, we don't want the wrapper to break when the core library changes the function signature of scaled_dot_product_attention or flex_attention. So the best UX is to always ask users to wrap these APIs in a module whose forward is def forward(*args, **kwargs) -> Any, and TorchTitan should follow the same rule.
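A minimal sketch of such a wrapper, assuming roughly this shape (the class and attribute names are illustrative, not the PR's actual code):

```python
import torch.nn as nn
import torch.nn.functional as F


class AttentionFnWrapper(nn.Module):
    """Thin nn.Module whose forward just forwards *args/**kwargs, so it stays
    valid even if the wrapped attention API's signature changes."""

    def __init__(self, attention_fn=F.scaled_dot_product_attention):
        super().__init__()
        self.attention_fn = attention_fn

    def forward(self, *args, **kwargs):
        # No signature is pinned here; whatever the caller passes is forwarded
        # unchanged to the underlying attention op.
        return self.attention_fn(*args, **kwargs)
```

Because parallelize_module hooks an nn.Module, CP can intercept the call through this wrapper without depending on the exact signature of scaled_dot_product_attention or flex_attention.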

parallelize_module(
    module=block.attention.sdpa.attention_fn_wrapper,
    device_mesh=world_mesh["cp"],
    parallelize_plan=_ContextParallel(
A reviewer (Contributor) commented:

So after this change, we only need to specify the context-parallel plan for the attention module here, and CP for the other modules is still handled by the context manager?

@fegin (author) replied:

You can check the discussion in pytorch/pytorch#162542. It's definitely good to remove the context manager, but that may also have some implications for how users write the model, like the wrapper in this PR.
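For reference, a hedged sketch of how the two pieces coexist after this PR: the attention wrapper gets the _ContextParallel plan via parallelize_module, while sequence-dim buffers are still sharded by PyTorch's experimental context_parallel context manager. The _ContextParallel import path, the model attribute names, and the buffer seq dims are assumptions based on the snippets above; the context manager API is experimental and may change.

```python
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.parallel import parallelize_module

# Assumed import path for the plan class introduced in pytorch/pytorch#162542.
from torch.distributed.tensor.experimental._attention import _ContextParallel


def apply_cp_and_run(model, world_mesh, inputs):
    # Register the CP plan on each attention wrapper (as in this PR's diff).
    for block in model.layers.values():
        parallelize_module(
            module=block.attention.sdpa.attention_fn_wrapper,
            device_mesh=world_mesh["cp"],
            parallelize_plan=_ContextParallel(
                seq_dim=2,
                attention_type=_ContextParallel.AttentionType.FLEX,
            ),
        )

    # Inputs (and other sequence-dim buffers) are still sharded by the
    # existing context manager; buffer_seq_dims=[1] assumes a [batch, seq, ...]
    # input layout.
    with context_parallel(
        world_mesh["cp"],
        buffers=[inputs],
        buffer_seq_dims=[1],
    ):
        return model(inputs)
```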
