[JAX] Make all jax attention calls use non-packed common calls #2358
base: main
Conversation
Greptile Overview
Greptile Summary
Refactors JAX attention extension to use only nvte_fused_attn_fwd/bwd instead of the deprecated packed variants (nvte_fused_attn_fwd_qkvpacked and nvte_fused_attn_fwd_kvpacked). The PR moves pointer arithmetic from the common API layer into the JAX extension code.
Key changes:
- Unified all three layout types (QKV packed, KV packed, separate) to call the single nvte_fused_attn_fwd/bwd API
- Added pointer arithmetic in the JAX extension to extract K and V pointers from packed tensors (a hedged sketch follows this list)
- Removed unused tensor shape definitions and layout-specific branching in workspace size calculations
- Updated gradient zeroing logic in the backward pass to correctly handle packed tensor memory layouts
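To make the new dispatch concrete, here is a minimal hedged sketch of the per-layout pointer extraction summarized above. All names (`split_qkv`, `LayoutGroup`, the scalar parameters) are hypothetical; the real extension works with NVTE tensor handles and passes many more arguments to nvte_fused_attn_fwd/bwd.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for the layout groups discussed in this PR.
enum class LayoutGroup { NVTE_3HD, NVTE_HD_2HD, NVTE_HD_HD_HD };

struct QKVPointers {
  void *q, *k, *v;
};

// Derive separate Q/K/V pointers from possibly packed buffers (sketch only).
// type_size:      bytes per element (e.g. 2 for bf16)
// attn_heads:     number of query heads
// num_gqa_groups: number of KV heads
QKVPointers split_qkv(LayoutGroup layout, void *q, void *k, void *v, size_t type_size,
                      size_t attn_heads, size_t num_gqa_groups, size_t qk_head_dim,
                      size_t v_head_dim) {
  auto *q_bytes = static_cast<uint8_t *>(q);
  auto *k_bytes = static_cast<uint8_t *>(k);
  switch (layout) {
    case LayoutGroup::NVTE_3HD: {
      // QKV packed as [tokens, 3, attn_heads, qk_head_dim]: K and V blocks follow Q.
      size_t stride = type_size * attn_heads * qk_head_dim;
      return {q, q_bytes + stride, q_bytes + 2 * stride};
    }
    case LayoutGroup::NVTE_HD_2HD: {
      // KV packed as [tokens, 2, num_gqa_groups, v_head_dim]: V block follows K.
      // Per the review comments, the stride is based on v_head_dim, not qk_head_dim.
      size_t stride = type_size * num_gqa_groups * v_head_dim;
      return {q, k, k_bytes + stride};
    }
    case LayoutGroup::NVTE_HD_HD_HD:
    default:
      return {q, k, v};  // tensors are already separate
  }
}
```

Keeping this arithmetic in the JAX layer is what lets the common library expose the single nvte_fused_attn_fwd/bwd entry point for all three layouts.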
Critical issue found:
- Lines 287 and 517: the stride calculation for the KV-packed layout uses qk_head_dim but should use v_head_dim, since KV-packed tensors have shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]
Confidence Score: 1/5
- This PR contains critical pointer arithmetic bugs that will cause memory corruption or incorrect results
- Score of 1 (critical issues) due to the incorrect stride calculation in the KV-packed layout at lines 287 and 517. Using qk_head_dim instead of v_head_dim makes the V pointer point to the wrong memory location when these dimensions differ, leading to incorrect attention computation or potential memory access violations
- transformer_engine/jax/csrc/extensions/attention.cpp - lines 287 and 517 must be fixed before merge
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/csrc/extensions/attention.cpp | 1/5 | Refactors JAX attention to use only nvte_fused_attn_fwd/bwd, but has critical pointer arithmetic bug in KV-packed layout stride calculation using qk_head_dim instead of v_head_dim |
Sequence Diagram
sequenceDiagram
participant JAX as JAX Extension
participant Old as Old API (deprecated)
participant New as nvte_fused_attn_fwd
participant Kernel as Attention Kernel
Note over JAX,Kernel: Before PR #2358
JAX->>Old: nvte_fused_attn_fwd_qkvpacked(QKV)
Old->>Kernel: Extract Q,K,V pointers internally
Kernel-->>Old: Result
Old-->>JAX: Output
JAX->>Old: nvte_fused_attn_fwd_kvpacked(Q, KV)
Old->>Kernel: Extract K,V pointers internally
Kernel-->>Old: Result
Old-->>JAX: Output
Note over JAX,Kernel: After PR #2358
JAX->>JAX: Extract Q,K,V pointers from packed data
JAX->>New: nvte_fused_attn_fwd(Q, K, V)
New->>Kernel: Forward separate pointers
Kernel-->>New: Result
New-->>JAX: Output
Note over JAX: Pointer arithmetic handles:<br/>- NVTE_3HD: QKV packed<br/>- NVTE_HD_2HD: KV packed<br/>- NVTE_HD_HD_HD: All separate
1 file reviewed, 2 comments
    // Python passes: q=query, k=packed_kv, v=dummy
    // Extract V pointer from the packed k data
    // Note: For KV packed, qk_head_dim must equal v_head_dim
    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
logic: stride calculation uses qk_head_dim but KV packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim] - should use v_head_dim
Suggested change:
    - size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
    + size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);
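For illustration, a small self-contained sketch (hypothetical sizes, plain C++ rather than TE code) checking that the suggested v_head_dim-based stride lands on the V block of a KV-packed buffer:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Packed KV buffer of shape [tokens, 2, num_gqa_groups, v_head_dim], filled with 0,1,2,...
  const size_t tokens = 3, num_gqa_groups = 2, v_head_dim = 4;
  std::vector<float> kv(tokens * 2 * num_gqa_groups * v_head_dim);
  for (size_t i = 0; i < kv.size(); ++i) kv[i] = static_cast<float>(i);

  // Suggested stride, in bytes: one [num_gqa_groups, v_head_dim] slab.
  const size_t stride = sizeof(float) * num_gqa_groups * v_head_dim;
  const auto *k_bytes = reinterpret_cast<const uint8_t *>(kv.data());
  const auto *v_ptr = reinterpret_cast<const float *>(k_bytes + stride);

  // For token 0, V starts at flat index [0][1][0][0] = num_gqa_groups * v_head_dim.
  std::printf("v_ptr[0]=%.0f expected %.0f\n", v_ptr[0], kv[num_gqa_groups * v_head_dim]);
  return 0;
}
```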
    cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
    // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
    // Note: For KV packed, qk_head_dim must equal v_head_dim
    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
logic: stride calculation uses qk_head_dim but KV packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim] - should use v_head_dim
Suggested change:
    - size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
    + size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
Refactors JAX attention code to use unified nvte_fused_attn_fwd() and nvte_fused_attn_bwd() APIs instead of deprecated packed variants (nvte_fused_attn_fwd_kvpacked() and nvte_fused_attn_fwd_qkvpacked()). The pointer arithmetic for extracting K/V pointers from packed tensors is now handled in the JAX layer.
Key Changes:
- Removed conditional calls to packed-specific attention APIs
- Added pointer calculation logic to extract K/V pointers from packed QKV/KV tensors based on layout
- Unified all attention calls to use the single nvte_fused_attn_fwd/bwd API
- Updated workspace size calculation functions similarly
Issue Found:
- For the KV-packed layout (NVTE_HD_2HD), the stride calculation uses qk_head_dim but should use v_head_dim to match the actual tensor shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]. While a runtime check enforces that the two are equal, using v_head_dim is semantically correct.
Confidence Score: 4/5
- Safe to merge after fixing the stride calculation to use v_head_dim instead of qk_head_dim for the KV-packed layout
- The refactoring is well-structured and aligns with the goal of deprecating packed-specific APIs. However, the stride calculation issue (using qk_head_dim instead of v_head_dim) in the KV-packed layout needs to be fixed for semantic correctness, even though runtime checks enforce equality. The logic is sound otherwise, with proper handling of different layouts and appropriate memory clearing for ragged sequences.
- transformer_engine/jax/csrc/extensions/attention.cpp - Fix the stride calculation on lines 290 and 523 to use v_head_dim instead of qk_head_dim
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/csrc/extensions/attention.cpp | 4/5 | Refactors JAX attention to use unified nvte_fused_attn_fwd/bwd API; has stride calculation issue using qk_head_dim instead of v_head_dim for KV-packed layout |
Sequence Diagram
sequenceDiagram
participant JAX as JAX Python Layer
participant FwdImpl as FusedAttnForwardImpl
participant LayoutCheck as Layout Detection
participant PtrCalc as Pointer Calculation
participant API as nvte_fused_attn_fwd
JAX->>FwdImpl: Call with q, k, v pointers
FwdImpl->>LayoutCheck: Check layout_group (NVTE_3HD/HD_2HD/HD_HD_HD)
alt NVTE_3HD (QKV packed)
LayoutCheck->>PtrCalc: Extract K, V from packed Q
PtrCalc->>PtrCalc: k_ptr = q + stride<br/>v_ptr = q + 2*stride<br/>stride = typeSize * attn_heads * qk_head_dim
else NVTE_HD_2HD (KV packed)
LayoutCheck->>PtrCalc: Extract V from packed K
PtrCalc->>PtrCalc: v_ptr = k + stride<br/>stride = typeSize * num_gqa_groups * qk_head_dim
else NVTE_HD_HD_HD (separate)
LayoutCheck->>PtrCalc: Use pointers as-is
end
PtrCalc->>API: Call with separate q_ptr, k_ptr, v_ptr
API-->>FwdImpl: Return results
FwdImpl-->>JAX: Return output
1 file reviewed, no comments
/te-ci jax
Description
JAX calls nvte_fused_attn_fwd_kvpacked(), nvte_fused_attn_fwd_qkvpacked(), or nvte_fused_attn_fwd(). The first two will be deprecated by #2287, so this PR changes the JAX extension code to use only the last one; a small self-contained sketch of the resulting pointer arithmetic follows at the end of this description.
Type of change
Checklist:
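As mentioned above, a self-contained sketch (hypothetical sizes, plain C++ rather than the extension's actual tensor handling) checking that the byte-offset approach used for the QKV-packed (3HD) layout lands on the K and V blocks of a packed [tokens, 3, heads, dim] buffer; the KV-packed (HD_2HD) case works the same way with a single offset:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Packed QKV buffer of shape [tokens, 3, heads, dim], filled with 0,1,2,...
  const size_t tokens = 4, heads = 2, dim = 8;
  std::vector<float> qkv(tokens * 3 * heads * dim);
  for (size_t i = 0; i < qkv.size(); ++i) qkv[i] = static_cast<float>(i);

  // Same formula used for the 3HD layout, expressed in bytes.
  const size_t stride = sizeof(float) * heads * dim;
  const auto *q = reinterpret_cast<const uint8_t *>(qkv.data());
  const auto *k = reinterpret_cast<const float *>(q + stride);
  const auto *v = reinterpret_cast<const float *>(q + 2 * stride);

  // For token 0, K starts at flat index [0][1][0][0] and V at [0][2][0][0].
  std::printf("k[0]=%.0f expected %.0f\n", k[0], qkv[1 * heads * dim]);
  std::printf("v[0]=%.0f expected %.0f\n", v[0], qkv[2 * heads * dim]);
  return 0;
}
```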