[CP][RFC] Enable FlexCP for llama3 with parallelize_module #1707
base: main
Conversation
Similar to #1696, but this PR uses parallelize_module in the same way TP/SP does. This PR also requires pytorch/pytorch#162542.
```python
                device_mesh=world_mesh["cp"],
                parallelize_plan=_ContextParallel(
                    seq_dim=2,
                    attention_type=_ContextParallel.AttentionType.FLEX,
```
Does this only work for FlexAttention?
Is there a plan to consolidate SDPA and FlexAttention in terms of how CP is applied?
This will work for both SDPA and Flex. We just need to pass in a different attention type depending on which attention implementation is used.
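As a rough sketch of what that could look like (the import path and the SDPA enum member name are assumptions on my side; only AttentionType.FLEX appears in the diff above):

```python
# Rough sketch, not the exact PR code. The SDPA enum member name and the
# import path below are assumed; only AttentionType.FLEX is shown in the diff.
from torch.distributed.tensor.experimental._attention import _ContextParallel  # assumed path
from torch.distributed.tensor.parallel import parallelize_module


def cp_plan_for(use_flex_attn: bool) -> _ContextParallel:
    # Pick the attention type based on which attention the model uses.
    attn_type = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attn
        else _ContextParallel.AttentionType.SDPA  # assumed member name
    )
    return _ContextParallel(seq_dim=2, attention_type=attn_type)


# Applied the same way as in the diff; only the plan changes:
# parallelize_module(
#     module=block.attention.sdpa.attention_fn_wrapper,
#     device_mesh=world_mesh["cp"],
#     parallelize_plan=cp_plan_for(use_flex_attn),
# )
```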
```python
    if parallel_dims.cp_enabled:
        for block in model.layers.values():
            parallelize_module(
                module=block.attention.sdpa.attention_fn_wrapper,
```
IIUC, for FlexAttention we need this wrapper because the block mask has to be obtained inside the FlexAttention class before calling the wrapper. For SDPA it seems unnecessary? It is already a very thin wrapper.
If the concern is code branching, the code is going to branch a couple of lines below anyway, so I think it's fine.
It's not just about unification. That wrapper must have exactly the same function signature as scaled_dot_product_attention, and our ScaledDotProductAttention doesn't meet this requirement. More importantly, we don't want this wrapper to break when the core library changes the function signature of scaled_dot_product_attention or flex_attention. So the best UX is to always ask users to wrap the APIs in a module whose forward is def forward(*args, **kwargs) -> Any, and TorchTitan should follow this rule as well.
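A minimal sketch of that wrapper pattern (the class name and constructor here are hypothetical illustrations, not TorchTitan's actual attention_fn_wrapper): the wrapper declares no signature of its own, so it keeps working even if scaled_dot_product_attention or flex_attention change their signatures.

```python
from typing import Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionFnWrapper(nn.Module):
    """Hypothetical thin wrapper around an attention op.

    It forwards *args/**kwargs unchanged, so it never has to track the
    signature of the wrapped function, and a ParallelStyle can be applied
    to it as a regular nn.Module.
    """

    def __init__(self, attention_fn: Callable[..., torch.Tensor]) -> None:
        super().__init__()
        self.attention_fn = attention_fn

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.attention_fn(*args, **kwargs)


# e.g. wrapping SDPA; flex_attention could be wrapped the same way.
sdpa_wrapper = AttentionFnWrapper(F.scaled_dot_product_attention)
```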
```python
            parallelize_module(
                module=block.attention.sdpa.attention_fn_wrapper,
                device_mesh=world_mesh["cp"],
                parallelize_plan=_ContextParallel(
```
So after this change, we only need to specify the context parallel plan for the attention module here, and CP for the other modules is still handled by the context manager?
You can check the discussion in pytorch/pytorch#162542. It's definitely good to remove the context manager, but that may also have some implications for how users should write the model, like the wrapper in this PR.
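For illustration, a rough sketch of the split being discussed, assuming the non-attention buffers still go through torch.distributed.tensor.experimental.context_parallel; the variable names, loss computation, and sequence dims are illustrative, not TorchTitan's actual training loop:

```python
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel


def train_step(model, inputs, labels, cp_mesh):
    # Attention already carries the _ContextParallel plan from parallelize_module,
    # so only the remaining sequence-dim buffers are sharded by the context manager.
    with context_parallel(
        cp_mesh,
        buffers=[inputs, labels],
        buffer_seq_dims=[1, 1],  # assumed: sequence is dim 1 for these buffers
    ):
        logits = model(inputs)
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()
```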