
Conversation

pggPL (Collaborator) commented Nov 6, 2025

Description

JAX calls nvte_fused_attn_fwd_kvpacked(), nvte_fused_attn_fwd_qkvpacked(), or nvte_fused_attn_fwd(). The first two will be deprecated by #2287, so this PR changes the JAX extension code to use only the last one.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL changed the title from "[JAX] Make all jax attention calls to use non-packed common calls" to "[JAX] Make all jax attention calls use non-packed common calls" on Nov 6, 2025
greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Refactors JAX attention extension to use only nvte_fused_attn_fwd/bwd instead of the deprecated packed variants (nvte_fused_attn_fwd_qkvpacked and nvte_fused_attn_fwd_kvpacked). The PR moves pointer arithmetic from the common API layer into the JAX extension code.

Key changes:

  • Unified all three layout types (QKV packed, KV packed, separate) to call single nvte_fused_attn_fwd/bwd API
  • Added pointer arithmetic in the JAX extension to extract K and V pointers from packed tensors (sketched below)
  • Removed unused tensor shape definitions and layout-specific branching in workspace size calculations
  • Updated gradient zeroing logic in backward pass to correctly handle packed tensor memory layouts
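
The pointer arithmetic referenced in the list above can be pictured with a minimal, self-contained sketch. It follows the strides quoted in the sequence diagrams of this review; the enum and all identifiers (derive_kv_pointers, type_size, attn_heads, ...) are illustrative stand-ins, not the exact names in attention.cpp.

#include <cstddef>

// Stand-ins for the NVTE layout groups named in this review
// (NVTE_3HD, NVTE_HD_2HD, NVTE_HD_HD_HD); the real code uses the enum from the common headers.
enum class LayoutGroup { QkvPacked, KvPacked, Separate };

// Illustrative sketch of deriving K/V pointers from packed buffers before the
// unified nvte_fused_attn_fwd()/bwd() call. type_size stands in for typeToSize(dtype).
static void derive_kv_pointers(LayoutGroup group, void *q, void *&k, void *&v,
                               std::size_t type_size, std::size_t attn_heads,
                               std::size_t num_gqa_groups, std::size_t qk_head_dim,
                               std::size_t v_head_dim) {
  if (group == LayoutGroup::QkvPacked) {
    // QKV packed as [batch*seqlen, 3, attn_heads, qk_head_dim]:
    // K and V follow Q inside each token's 3HD block.
    std::size_t stride = type_size * attn_heads * qk_head_dim;
    k = static_cast<char *>(q) + stride;
    v = static_cast<char *>(q) + 2 * stride;
  } else if (group == LayoutGroup::KvPacked) {
    // KV packed as [batch*seqlen, 2, num_gqa_groups, v_head_dim]:
    // V follows K inside each token's 2HD block (the review's point is that
    // this stride should be expressed with v_head_dim).
    std::size_t stride = type_size * num_gqa_groups * v_head_dim;
    v = static_cast<char *>(k) + stride;
  }
  // Separate: q, k, v already point at distinct buffers; nothing to derive.
}

The extension then wraps the derived pointers in tensors and issues a single nvte_fused_attn_fwd/bwd call for every layout.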

Critical issue found:

  • Lines 287 and 517: Stride calculation for KV-packed layout uses qk_head_dim but should use v_head_dim since KV packed tensors have shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]

Confidence Score: 1/5

  • This PR contains critical pointer arithmetic bugs that will cause memory corruption or incorrect results
  • Score of 1 (critical issues) due to the incorrect stride calculation in the KV-packed layout at lines 287 and 517. Using qk_head_dim instead of v_head_dim will cause the V pointer to point to the wrong memory location when these dimensions differ, leading to incorrect attention computation or potential memory access violations
  • transformer_engine/jax/csrc/extensions/attention.cpp - lines 287 and 517 must be fixed before merge

Important Files Changed

File Analysis

Filename: transformer_engine/jax/csrc/extensions/attention.cpp
Score: 1/5
Overview: Refactors JAX attention to use only nvte_fused_attn_fwd/bwd, but has a critical pointer arithmetic bug in the KV-packed layout stride calculation (qk_head_dim used instead of v_head_dim)

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Extension
    participant Old as Old API (deprecated)
    participant New as nvte_fused_attn_fwd
    participant Kernel as Attention Kernel

    Note over JAX,Kernel: Before PR #2358
    JAX->>Old: nvte_fused_attn_fwd_qkvpacked(QKV)
    Old->>Kernel: Extract Q,K,V pointers internally
    Kernel-->>Old: Result
    Old-->>JAX: Output

    JAX->>Old: nvte_fused_attn_fwd_kvpacked(Q, KV)
    Old->>Kernel: Extract K,V pointers internally
    Kernel-->>Old: Result
    Old-->>JAX: Output

    Note over JAX,Kernel: After PR #2358
    JAX->>JAX: Extract Q,K,V pointers from packed data
    JAX->>New: nvte_fused_attn_fwd(Q, K, V)
    New->>Kernel: Forward separate pointers
    Kernel-->>New: Result
    New-->>JAX: Output

    Note over JAX: Pointer arithmetic handles:<br/>- NVTE_3HD: QKV packed<br/>- NVTE_HD_2HD: KV packed<br/>- NVTE_HD_HD_HD: All separate

1 file reviewed, 2 comments


// Python passes: q=query, k=packed_kv, v=dummy
// Extract V pointer from the packed k data
// Note: For KV packed, qk_head_dim must equal v_head_dim
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);

logic: the stride calculation uses qk_head_dim, but the KV-packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]; it should use v_head_dim

Suggested change
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);

cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
// Note: For KV packed, qk_head_dim must equal v_head_dim
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);

logic: the stride calculation uses qk_head_dim, but the KV-packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]; it should use v_head_dim

Suggested change
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);
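
To make the suggested fix concrete, here is a small self-contained example (illustrative numbers only, not taken from the PR) of the byte offset it computes; the assertion mirrors the constraint noted in the snippet above that KV-packed layouts require qk_head_dim == v_head_dim.

#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values only.
  const std::size_t type_size = 2;       // e.g. fp16/bf16, as typeToSize(dtype) would return
  const std::size_t num_gqa_groups = 8;
  const std::size_t qk_head_dim = 128;
  const std::size_t v_head_dim = 128;    // KV-packed layouts require equality with qk_head_dim

  // The packed KV tensor is [batch*seqlen, 2, num_gqa_groups, v_head_dim], so
  // V starts num_gqa_groups * v_head_dim elements after K within each token's
  // block; expressing the stride with v_head_dim ties it to the V part it locates.
  assert(qk_head_dim == v_head_dim);
  const std::size_t stride = type_size * num_gqa_groups * v_head_dim;
  std::printf("V starts %zu bytes after K within a token's 2HD block\n", stride);
  return 0;
}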

pggPL and others added 2 commits November 6, 2025 23:01
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Refactors JAX attention code to use unified nvte_fused_attn_fwd() and nvte_fused_attn_bwd() APIs instead of deprecated packed variants (nvte_fused_attn_fwd_kvpacked() and nvte_fused_attn_fwd_qkvpacked()). The pointer arithmetic for extracting K/V pointers from packed tensors is now handled in the JAX layer.

Key Changes:

  • Removed conditional calls to packed-specific attention APIs
  • Added pointer calculation logic to extract K/V pointers from packed QKV/KV tensors based on layout
  • Unified all attention calls to use the single nvte_fused_attn_fwd/bwd API
  • Updated workspace size calculation functions similarly

Issue Found:

  • For the KV-packed layout (NVTE_HD_2HD), the stride calculation uses qk_head_dim but should use v_head_dim to match the actual tensor shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]. While the two are enforced equal by a runtime check, using v_head_dim is semantically correct.

Confidence Score: 4/5

  • Safe to merge after fixing stride calculation to use v_head_dim instead of qk_head_dim for KV-packed layout
  • The refactoring is well-structured and aligns with the goal of deprecating packed-specific APIs. However, the stride calculation issue (using qk_head_dim instead of v_head_dim) in the KV-packed layout needs to be fixed for semantic correctness, even though runtime checks enforce equality. The logic is sound otherwise, with proper handling of different layouts and appropriate memory clearing for ragged sequences (see the sketch after this list).
  • transformer_engine/jax/csrc/extensions/attention.cpp - Fix stride calculation on lines 290 and 523 to use v_head_dim instead of qk_head_dim
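
The "appropriate memory clearing for ragged sequences" mentioned above follows the cudaMemsetAsync pattern visible in the earlier snippet (cudaMemsetAsync(dv, 0, product(v_shape) * typeToSize(dtype), stream)). Below is a hedged, self-contained sketch of that pattern for the separate-tensor case; the product helper mimics transformer_engine::jax::product, all other names are illustrative, and the PR's additional handling for packed layouts is not reproduced here.

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Mimics transformer_engine::jax::product for this sketch.
static std::size_t product(const std::vector<std::size_t> &shape) {
  std::size_t p = 1;
  for (std::size_t d : shape) p *= d;
  return p;
}

// Zero the gradient buffers before launching the backward kernel so that
// positions outside the ragged (cu_seqlens-defined) regions remain zero.
static void zero_grad_buffers(void *dq, const std::vector<std::size_t> &q_shape,
                              void *dk, const std::vector<std::size_t> &k_shape,
                              void *dv, const std::vector<std::size_t> &v_shape,
                              std::size_t type_size, cudaStream_t stream) {
  cudaMemsetAsync(dq, 0, product(q_shape) * type_size, stream);
  cudaMemsetAsync(dk, 0, product(k_shape) * type_size, stream);
  cudaMemsetAsync(dv, 0, product(v_shape) * type_size, stream);
}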

Important Files Changed

File Analysis

Filename: transformer_engine/jax/csrc/extensions/attention.cpp
Score: 4/5
Overview: Refactors JAX attention to use the unified nvte_fused_attn_fwd/bwd API; has a stride calculation issue (qk_head_dim used instead of v_head_dim) for the KV-packed layout

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Python Layer
    participant FwdImpl as FusedAttnForwardImpl
    participant LayoutCheck as Layout Detection
    participant PtrCalc as Pointer Calculation
    participant API as nvte_fused_attn_fwd

    JAX->>FwdImpl: Call with q, k, v pointers
    FwdImpl->>LayoutCheck: Check layout_group (NVTE_3HD/HD_2HD/HD_HD_HD)
    
    alt NVTE_3HD (QKV packed)
        LayoutCheck->>PtrCalc: Extract K, V from packed Q
        PtrCalc->>PtrCalc: k_ptr = q + stride<br/>v_ptr = q + 2*stride<br/>stride = typeSize * attn_heads * qk_head_dim
    else NVTE_HD_2HD (KV packed)
        LayoutCheck->>PtrCalc: Extract V from packed K
        PtrCalc->>PtrCalc: v_ptr = k + stride<br/>stride = typeSize * num_gqa_groups * qk_head_dim
    else NVTE_HD_HD_HD (separate)
        LayoutCheck->>PtrCalc: Use pointers as-is
    end
    
    PtrCalc->>API: Call with separate q_ptr, k_ptr, v_ptr
    API-->>FwdImpl: Return results
    FwdImpl-->>JAX: Return output

1 file reviewed, no comments


pggPL (Collaborator, Author) commented Nov 6, 2025

/te-ci jax
