-
Notifications
You must be signed in to change notification settings - Fork 541
FA num splits option #2357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FA num splits option #2357
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR exposes the num_splits parameter for FlashAttention v2 and v3 backends, allowing users to control memory optimization during attention computation.
Key Changes:
- Added optional
num_splitsparameter toDotProductAttention.forward()method - Passes
num_splitsto both FlashAttention v2 and v3 backend implementations when provided - Parameter is conditionally added to kwargs only when not
None
Areas for Improvement:
- Missing parameter documentation in the docstring
- No version compatibility check for flash-attn (unlike other optional parameters like
window_sizeanddeterministic) - No tests demonstrating the new functionality
Confidence Score: 4/5
- This PR is safe to merge with minor documentation improvements recommended
- The implementation correctly follows the existing pattern for optional parameters in FlashAttention backends. The changes are minimal and well-scoped. However, the score is not 5 due to: (1) missing parameter documentation, (2) lack of version compatibility checks that other optional parameters have, and (3) no accompanying tests. These are quality-of-life improvements rather than critical issues.
- No files require special attention - the implementation is straightforward and follows existing patterns
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | 4/5 | Added num_splits parameter to FlashAttention forward method and passes it to both FA v2 and FA v3 backends when provided |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | 4/5 | Added num_splits parameter to DotProductAttention forward signature and forwards it to FlashAttention backend |
Sequence Diagram
sequenceDiagram
participant User
participant DotProductAttention
participant FlashAttention
participant FlashAttn_v2
participant FlashAttn_v3
User->>DotProductAttention: forward(query, key, value, num_splits=...)
Note over DotProductAttention: New optional parameter
DotProductAttention->>FlashAttention: forward(..., num_splits=num_splits)
Note over FlashAttention: Check backend version
alt Flash Attention v2
FlashAttention->>FlashAttention: if num_splits is not None
FlashAttention->>FlashAttention: fa_optional_forward_kwargs["num_splits"] = num_splits
FlashAttention->>FlashAttn_v2: flash_attn_func(..., **fa_optional_forward_kwargs)
FlashAttn_v2-->>FlashAttention: output
else Flash Attention v3
FlashAttention->>FlashAttention: if num_splits is not None
FlashAttention->>FlashAttention: fa_3_optional_forward_kwargs["num_splits"] = num_splits
FlashAttention->>FlashAttn_v3: flash_attn_func(..., **fa_3_optional_forward_kwargs)
FlashAttn_v3-->>FlashAttention: output
end
FlashAttention-->>DotProductAttention: output
DotProductAttention-->>User: output
2 files reviewed, 3 comments
| inference_params: Optional[InferenceParams] = None, | ||
| pad_between_seqs: Optional[bool] = None, | ||
| fp8_output: Optional[bool] = False, | ||
| num_splits: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.")
| if num_splits is not None: | ||
| # Forward optional split control to flash-attn if available | ||
| fa_optional_forward_kwargs["num_splits"] = num_splits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Verify that the flash-attn version supports num_splits parameter. Unlike other optional parameters (e.g., window_size has fa_utils.v2_3_plus check, deterministic has fa_utils.v2_4_1_plus check), this parameter is added without a version guard.
| if num_splits is not None: | ||
| # Forward optional split control to flash-attn v3 if supported | ||
| fa_3_optional_forward_kwargs["num_splits"] = num_splits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Same as FA v2: verify flash-attn v3 supports num_splits to avoid potential TypeError
Description
I want to be able to control num splits in FA3. This exposes this argument
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: