[JAX] Add support for sink attention in JAX #2225
Conversation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci jax
/te-ci jax L1
/te-ci jax L1
/te-ci jax L1
Greptile Overview
Greptile Summary
This PR adds support for sink attention (softmax variants with an extra "sink" term in the denominator) to JAX, following the PyTorch implementation from PR #2148. The implementation introduces three softmax types: VANILLA (standard), OFF_BY_ONE (adds 1 to the denominator), and LEARNABLE (uses learnable per-head offset parameters).
Key Changes:
- Adds `AttnSoftmaxType` enum with three variants: VANILLA_SOFTMAX, OFF_BY_ONE_SOFTMAX, LEARNABLE_SOFTMAX
- Threads `softmax_type` and `softmax_offset` parameters throughout the attention pipeline, from Flax modules through JAX primitives to the C++/cuDNN backends
- Renames `SoftmaxType` to `SoftmaxFusion` to distinguish kernel fusion strategies from sink attention variants
- Updates all attention primitives (standard, context parallel with all-gather, ring attention) to handle `softmax_offset`
- Implements proper gradient computation for learnable softmax parameters with appropriate all-reduce operations
- Adds comprehensive test coverage including forward/backward passes and distributed scenarios
Implementation Details:
- For LEARNABLE_SOFTMAX, creates a learnable parameter of shape `(1, num_heads, 1, 1)` with proper sharding by head dimension (a minimal sketch follows below)
- OFF_BY_ONE_SOFTMAX is handled by setting `softmax_offset=1.0`
- Context parallel paths (ring attention, all-gather) return dummy gradients for `softmax_offset` as they don't support learnable variants
- The C++ layer properly packs `softmax_offset` tensors into the cuDNN tensor pack for both forward and backward passes
- Refactored C++ code consolidates multiple layout-specific calls into unified `nvte_fused_attn_fwd`/`nvte_fused_attn_bwd` calls
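A minimal Flax sketch of the learnable offset parameter described above (illustrative only, not the TransformerEngine code; the module and parameter names are assumptions, and the head-dimension sharding annotations used in TE are omitted):

```python
# Hedged sketch: declares a per-head learnable sink offset of shape
# (1, num_heads, 1, 1); module/parameter names are illustrative.
import jax.numpy as jnp
from flax import linen as nn

class SinkOffsetParam(nn.Module):
    num_heads: int

    @nn.compact
    def __call__(self):
        # One learnable logit per attention head, broadcastable against
        # attention logits of shape (batch, heads, seq_q, seq_kv).
        return self.param(
            "softmax_offset",
            nn.initializers.zeros,
            (1, self.num_heads, 1, 1),
            jnp.float32,
        )
```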
Testing:
Test runtime increased from ~3061s to ~3314s (+8%) due to additional test cases covering the three softmax variants
Confidence Score: 4/5
- This PR is generally safe to merge with one potential mathematical issue to verify
- The implementation is comprehensive and well-structured with proper gradient handling, sharding, and test coverage. However, there is one potential mathematical correctness issue in `jax_general_softmax` (transformer_engine/jax/cpp_extensions/softmax.py:834-853), where the offset handling needs verification - specifically whether callers pass the raw learnable parameter or its exponential for LEARNABLE_SOFTMAX
- transformer_engine/jax/cpp_extensions/softmax.py - verify that the `jax_general_softmax` function receives the correct pre-exponentiated offset values for LEARNABLE_SOFTMAX
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 4/5 | Adds AttnSoftmaxType enum with support for VANILLA, OFF_BY_ONE, and LEARNABLE softmax variants. Updates function signatures to accept softmax_type and softmax_offset parameters. Implementation looks solid with proper enum handling and parameter validation. |
| transformer_engine/jax/cpp_extensions/attention.py | 4/5 | Extensive changes to support softmax_offset throughout the attention primitives. Properly handles forward/backward passes, sharding, and context parallelism. The gradient handling for softmax_offset includes proper all-reduce for learnable softmax and dummy returns for CP paths that don't use it. |
| transformer_engine/jax/csrc/extensions/attention.cpp | 4/5 | C++ implementation updated to thread softmax_type and softmax_offset through the call chain. Adds proper tensor pack handling for softmax_offset in both forward and backward passes. The refactoring consolidates multiple layout-specific calls into a unified nvte_fused_attn_fwd call. |
| transformer_engine/jax/cpp_extensions/softmax.py | 3/5 | Adds jax_general_softmax for sink attention support and updates softmax functions to accept softmax_offset. The implementation adds offset to the denominator, but needs verification that callers pass the correct pre-exponentiated values for LEARNABLE_SOFTMAX. |
| transformer_engine/jax/flax/transformer.py | 4/5 | Adds softmax_type parameter to both fused and unfused attention implementations. Creates learnable softmax_offset parameter for LEARNABLE_SOFTMAX with proper sharding. Handles OFF_BY_ONE internally. PRE_SCALE_BIAS handling sets bias to None after adding to prevent double-addition. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Flax as Flax Transformer
participant Attn as JAX Attention
participant Prim as Attention Primitive
participant CPP as C++ Extension
participant cuDNN as cuDNN Backend
User->>Flax: Call attention with softmax_type
alt LEARNABLE_SOFTMAX
Flax->>Flax: Initialize learnable softmax_offset param
else OFF_BY_ONE_SOFTMAX
Flax->>Flax: Set softmax_offset = 1.0
else VANILLA_SOFTMAX
Flax->>Flax: Set softmax_offset = empty
end
Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
Attn->>Attn: Apply sharding constraints to softmax_offset
Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v, bias, softmax_offset, ...)
Prim->>CPP: FusedAttnForwardImpl(q, k, v, bias, softmax_offset, ...)
CPP->>CPP: PrepareFusedAttnForwardAuxTensors (adds softmax_offset to tensor pack)
CPP->>cuDNN: nvte_fused_attn_fwd(tensor_pack with softmax_offset)
cuDNN-->>CPP: output, softmax_aux
CPP-->>Prim: output, softmax_aux, rng_state
Prim-->>Attn: output
Attn-->>Flax: output
Note over User,cuDNN: Backward Pass
User->>Flax: Gradient computation
Flax->>Attn: fused_attn backward
Attn->>Prim: FusedAttnBwdPrimitive.bind(...)
Prim->>CPP: FusedAttnBackwardImpl(..., softmax_offset)
CPP->>cuDNN: nvte_fused_attn_bwd(...)
cuDNN-->>CPP: dq, dk, dv, dbias, dsoftmax_offset
CPP-->>Prim: dq, dk, dv, dbias, dsoftmax_offset
alt LEARNABLE_SOFTMAX
Prim->>Prim: all_reduce dsoftmax_offset across DP/FSDP
end
Prim-->>Attn: gradients
Attn-->>Flax: gradients
Flax-->>User: Updated parameters
15 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This PR adds support for sink attention to JAX, implementing three softmax variants: VANILLA, OFF_BY_ONE (zero sink), and LEARNABLE (learnable sink). The implementation follows the pattern established in PR #2148 for PyTorch.
Key Changes
- New `AttnSoftmaxType` enum: defines three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE)
- `softmax_offset` parameter: added throughout the attention pipeline to support sink attention
  - For LEARNABLE_SOFTMAX: a learnable parameter of shape `[1, num_heads, 1, 1]`
  - For OFF_BY_ONE_SOFTMAX: treated as an implicit offset of 0
- Renamed `SoftmaxType` → `SoftmaxFusion`: distinguishes fusion strategy from softmax variant
- Backend support: updated C++ extensions and cuDNN integration to handle the new softmax types
- Comprehensive tests: added test coverage for all three softmax types
Implementation Quality
Strengths:
- Well-structured changes following existing code patterns
- Comprehensive test coverage with reference implementations
- Proper gradient handling for learnable parameters
- Clean separation between fusion strategy and softmax type
Critical Issue Found:
- Bug in `transformer_engine/jax/flax/module.py:198`: OFF_BY_ONE_SOFTMAX incorrectly uses `softmax_offset = 1.0` instead of `softmax_offset = 0.0`. Because the offset is interpreted as a logit, this contributes exp(1) ≈ 2.718 to the denominator instead of the intended +1, producing incorrect attention weights. The test reference implementation correctly uses a zero logit, but the optimized path has the wrong value (see the numerical sketch below).
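A hedged numerical illustration (toy values, not the TE implementation) of why the offset must be the logit value 0.0: a logit of 0.0 contributes exp(0) = 1 to the denominator, while 1.0 contributes e ≈ 2.718.

```python
# Hedged illustration of the flagged issue; toy values only.
import jax.numpy as jnp

def off_by_one_softmax(logits, offset_logit):
    # The offset is treated as one extra logit that only enters the denominator.
    e = jnp.exp(logits)
    return e / (jnp.exp(offset_logit) + e.sum())

logits = jnp.array([2.0, 1.0, 0.5])
print(off_by_one_softmax(logits, 0.0))  # intended "+1" in the denominator
print(off_by_one_softmax(logits, 1.0))  # adds e ~= 2.718 instead, skewing the weights
```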
Confidence Score: 3/5
- This PR should not be merged without fixing the OFF_BY_ONE_SOFTMAX bug
- The implementation is well-structured and comprehensive, but contains a critical logical error in the OFF_BY_ONE_SOFTMAX implementation that will cause incorrect attention computation. The bug is in `module.py:198`, where `softmax_offset = 1.0` should be `softmax_offset = 0.0`. This discrepancy between the test reference implementation (which correctly uses zero) and the optimized path means tests may pass while the production code produces wrong results in certain cases.
- `transformer_engine/jax/flax/module.py` - Fix OFF_BY_ONE_SOFTMAX offset value from 1.0 to 0.0
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 3/5 | Added AttnSoftmaxType enum and softmax_offset parameter throughout attention pipeline. Implementation looks correct. |
| transformer_engine/jax/cpp_extensions/softmax.py | 2/5 | Added jax_general_softmax with offset support. Implementation is mathematically sound but usage in module.py has a bug. |
| transformer_engine/jax/flax/module.py | 2/5 | Updated Softmax module with sink attention support. Critical bug: OFF_BY_ONE_SOFTMAX uses offset=1.0 instead of offset=0.0. |
| transformer_engine/jax/flax/transformer.py | 4/5 | Added softmax_type parameter to attention modules with learnable parameter initialization for LEARNABLE_SOFTMAX. Implementation looks correct. |
| tests/jax/test_fused_attn.py | 5/5 | Comprehensive test coverage added for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE). Reference implementation matches expected behavior. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant DPA as DotProductAttention
participant Fused as _FusedDotProductAttention
participant Attn as fused_attn
participant Prim as FusedAttnFwdPrimitive
participant CPP as C++ Backend
participant cuDNN as cuDNN Kernel
Note over User,cuDNN: Sink Attention Flow (JAX)
User->>DPA: __call__(query, key, value, softmax_type='off_by_one')
DPA->>Fused: forward with softmax_type
alt softmax_type == LEARNABLE_SOFTMAX
Fused->>Fused: Initialize learnable param<br/>softmax_offset [1, h, 1, 1]
else softmax_type == OFF_BY_ONE_SOFTMAX
Note over Fused: No offset param needed<br/>(handled by backend)
end
Fused->>Attn: fused_attn(qkv, bias, softmax_offset,<br/>sequence_descriptor, softmax_type)
Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v,<br/>bias, softmax_offset, ...)
Prim->>CPP: FusedAttnForwardFFI(q, k, v, bias,<br/>softmax_offset, softmax_type)
CPP->>CPP: Prepare tensor pack:<br/>- softmax_aux<br/>- rng_state<br/>- bias (if provided)<br/>- softmax_offset (if provided)
CPP->>cuDNN: nvte_fused_attn_fwd(q, k, v, bias,<br/>softmax_offset, softmax_type, ...)
alt softmax_type == VANILLA_SOFTMAX
Note over cuDNN: S = exp(QK^T) / sum(exp(QK^T))
else softmax_type == OFF_BY_ONE_SOFTMAX
Note over cuDNN: S = exp(QK^T) / (1 + sum(exp(QK^T)))
else softmax_type == LEARNABLE_SOFTMAX
Note over cuDNN: S = exp(QK^T) / (exp(alpha) + sum(exp(QK^T)))
end
cuDNN-->>CPP: output, softmax_aux, rng_state
CPP-->>Prim: output, softmax_aux, rng_state
Prim-->>Attn: output
Attn-->>Fused: attention output
Fused-->>DPA: attention output
DPA-->>User: final output
Note over User,cuDNN: Backward Pass (if training)
User->>DPA: grad(output)
DPA->>Attn: backward
Attn->>Prim: FusedAttnBwdPrimitive
Prim->>CPP: FusedAttnBackwardFFI
CPP->>cuDNN: nvte_fused_attn_bwd
alt softmax_type == LEARNABLE_SOFTMAX
Note over cuDNN: Compute grad_softmax_offset
cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
else
cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, None
end
Prim-->>User: gradients
15 files reviewed, 2 comments
/te-ci jax
KshitijLakhani
left a comment
Looks mostly good to me. Thanks, Pawel, for adding this on the JAX side as well!
I've left a few comments and questions in the files.
Leaving a few more here:
- It would be good to add info / an example tensor in the comments where a specific case would be helpful to illustrate, e.g.:
Consider all negative scores: [-10.0, -12.0, -11.0, -15.0, -13.0, -14.0]
This represents a scenario where the query has no good match with any key—ideally, the model should attend to nothing.
Vanilla Softmax:
* Weights: [0.6337, 0.0858, 0.2331, 0.0043, 0.0315, 0.0116]
* Notes: Even though all scores are terrible, the model is forced to give 63% attention to the "least bad" position
Off-by-One Softmax:
* Weights: [0.000045, 0.000006, 0.000017, 0.0000003, 0.000002, 0.0000008]
* Sum: 0.000072 (nearly zero, compared to 1 for Vanilla)
* Notes: All attention weights are tiny—the model can effectively express "I don't want to attend to anything"
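The numbers above can be reproduced with a few lines of jax.numpy (a sketch using the unstabilized formulas for clarity):

```python
# Reproduces the example above: vanilla vs. off-by-one softmax on all-negative scores.
import jax.numpy as jnp

scores = jnp.array([-10.0, -12.0, -11.0, -15.0, -13.0, -14.0])
e = jnp.exp(scores)

vanilla = e / e.sum()
off_by_one = e / (1.0 + e.sum())

print(vanilla)           # ~[0.6337, 0.0858, 0.2331, 0.0043, 0.0315, 0.0116]
print(off_by_one)        # every weight ~1e-5 or smaller
print(off_by_one.sum())  # ~0.000072, vs. exactly 1.0 for vanilla
```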
- Thanks for consolidating the nvte_fused_attn_fwd() calls in `csrc/extensions/attention.cpp`, but I would strongly suggest decoupling that into a separate PR
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
This is only for arity - the idea is that the P2P helper will raise an exception in the "checking configs" stage itself, so we won't really need to worry about this part being executed, right?
Yes, this is executed if and only if softmax_offset has size 0.
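For context, a minimal sketch of the arity-matching pattern under discussion (a toy `jax.custom_vjp`, not the ring-attention code): the backward rule must return one gradient per primal input, so an unused softmax_offset still needs a placeholder gradient.

```python
# Toy illustration of returning a dummy gradient for an unused operand so the
# custom VJP's output arity matches its inputs; names are illustrative.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def toy_attn(q, softmax_offset):
    return 2.0 * q  # softmax_offset unused in this toy forward

def toy_attn_fwd(q, softmax_offset):
    return toy_attn(q, softmax_offset), (softmax_offset,)

def toy_attn_bwd(residuals, g):
    (softmax_offset,) = residuals
    dq = 2.0 * g
    dummy_doffset = jnp.zeros_like(softmax_offset)  # placeholder, never used
    return dq, dummy_doffset

toy_attn.defvjp(toy_attn_fwd, toy_attn_bwd)

q = jnp.ones((2, 3))
offset = jnp.zeros((1,))
print(jax.grad(lambda x: toy_attn(x, offset).sum())(q))
```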
        softmax_out = softmax_with_extra[..., :-1].astype(dtype)
    case AttnSoftmaxType.LEARNABLE_SOFTMAX:
        # Append learnable offset logit, apply standard softmax, then remove last column
        learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
Quick question: remind me again, does this match the logits shape?
# logits with shape (b, h_kv, num_groups, s_q, s_kv)
comment on line 89
Also, @pggPL, the test times do not record the additional test you add in test_layer, right?
Greptile Overview
Greptile Summary
This PR adds comprehensive sink attention support to JAX by introducing three softmax types: vanilla, off-by-one, and learnable. The implementation follows the pattern established in PR #2148 for PyTorch.
Key Changes:
- Added `AttnSoftmaxType` enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE)
- Introduced `softmax_type` parameter throughout the attention pipeline (transformer layers → attention modules → fused/unfused attention → softmax)
- Implemented learnable `softmax_offset` parameter for LEARNABLE_SOFTMAX, properly initialized and sharded across attention heads
- Extended the JAX fallback path with a `jax_general_softmax` function that correctly implements the sink attention mathematics
- Updated C++ extensions and cuDNN backend integration to handle softmax type and offset tensors
- Added comprehensive test coverage for all three softmax types with proper reference implementations
- Refactored `SoftmaxType` → `SoftmaxFusion` to avoid a naming collision with `AttnSoftmaxType`
Implementation Details:
- OFF_BY_ONE_SOFTMAX: sets offset=0.0, which adds exp(0 - x_max) = exp(-x_max) to the denominator, contributing +1 after normalization (see the sketch after this list)
- LEARNABLE_SOFTMAX: uses a learnable parameter `alpha` of shape (1, num_heads, 1, 1) as the offset logit
- Both fused (cuDNN) and unfused (JAX native) paths are supported
- Proper gradient flow for the learnable offset parameter in the backward pass
- Shape validation and dtype checks ensure the offset tensor is (1, H, 1, 1) float32
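A minimal sketch of this offset softmax (an assumption about what `jax_general_softmax` computes, based on the description above, not the TE source). With `offset_logit = 0.0` this reduces to the off-by-one variant; with a learnable per-head `alpha` it is the learnable variant.

```python
# Hedged sketch: numerically stable softmax with a per-head sink offset logit.
import jax.numpy as jnp

def offset_softmax(logits, offset_logit):
    # logits: (b, h, s_q, s_kv); offset_logit broadcastable, e.g. (1, h, 1, 1)
    x_max = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), offset_logit)
    unnorm = jnp.exp(logits - x_max)
    denom = unnorm.sum(axis=-1, keepdims=True) + jnp.exp(offset_logit - x_max)
    return unnorm / denom
```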
Confidence Score: 5/5
- This PR is safe to merge with high confidence
- The implementation is thorough, well-tested, and correctly follows established patterns from the PyTorch implementation. The mathematics of sink attention is properly implemented in both test reference and actual code. Previous concerns about offset handling have been correctly addressed (offset=0.0 for OFF_BY_ONE). Comprehensive test coverage includes all softmax types with proper reference implementations. The changes are non-breaking and extend existing functionality.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 5/5 | Added AttnSoftmaxType enum and integrated softmax_offset parameter throughout the fused attention pipeline. Changes properly thread the parameter through forward and backward passes. |
| transformer_engine/jax/flax/transformer.py | 5/5 | Added softmax_type parameter to attention modules with proper initialization of learnable softmax_offset parameters. Correctly passes parameters through both unfused and fused attention paths. |
| transformer_engine/jax/flax/module.py | 5/5 | Updated Softmax module to support sink attention by adding softmax_type parameter and handling softmax_offset. Correctly sets offset=0.0 for OFF_BY_ONE_SOFTMAX. |
| transformer_engine/jax/cpp_extensions/softmax.py | 5/5 | Added jax_general_softmax function implementing sink attention math. Updated all softmax functions to accept optional softmax_offset parameter. Implementation correctly handles offset as logit value. |
| transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Extended fused attention primitives to support softmax_type and softmax_offset. Properly validates shapes and dtypes, threads parameters through FFI calls. |
| transformer_engine/jax/csrc/extensions/attention.cpp | 5/5 | Updated C++ implementation to accept softmax_type parameter and handle softmax_offset tensor in aux tensor packs. Properly integrated with cuDNN backend. |
| tests/jax/test_fused_attn.py | 5/5 | Added comprehensive tests for all three softmax types (vanilla, off-by-one, learnable). Reference implementation correctly appends zero/learnable logit and applies standard softmax. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant MHA as MultiHeadAttention
participant DPA as DotProductAttention
participant Fused as _FusedDotProductAttention
participant Unfused as _UnfusedDotProductAttention
participant Softmax as Softmax Module
participant FusedAttn as fused_attn (JAX)
participant CPP as C++ Extensions
participant cuDNN as cuDNN Backend
User->>MHA: forward(q, k, v, softmax_type='off_by_one')
MHA->>MHA: Initialize softmax_offset param if LEARNABLE
MHA->>DPA: __call__(q, k, v, softmax_type)
alt Fused Attention Path
DPA->>Fused: __call__(q, k, v, softmax_type)
Fused->>Fused: Create softmax_offset param if LEARNABLE
Fused->>FusedAttn: fused_attn(qkv, softmax_offset, softmax_type)
FusedAttn->>CPP: fused_attn_fwd(softmax_type, softmax_offset)
CPP->>cuDNN: nvte_fused_attn_fwd(softmax_type, softmax_offset)
cuDNN-->>CPP: output
CPP-->>FusedAttn: output
FusedAttn-->>Fused: output
else Unfused Attention Path
DPA->>Unfused: __call__(q, k, v, softmax_type)
Unfused->>Unfused: Create softmax_offset param if LEARNABLE
Unfused->>Unfused: Compute Q·K^T attention logits
Unfused->>Softmax: __call__(logits, softmax_offset, softmax_type)
alt OFF_BY_ONE_SOFTMAX
Softmax->>Softmax: Set offset = 0.0
Softmax->>Softmax: jax_general_softmax(logits, offset=0.0)
Note over Softmax: Adds exp(0-x_max) to denominator = +1
else LEARNABLE_SOFTMAX
Softmax->>Softmax: Use learnable offset param
Softmax->>Softmax: jax_general_softmax(logits, offset=alpha)
Note over Softmax: Adds exp(alpha-x_max) to denominator
else VANILLA_SOFTMAX
Softmax->>Softmax: jax.nn.softmax(logits)
end
Softmax-->>Unfused: attention_weights
Unfused->>Unfused: attention_weights · V
Unfused-->>DPA: output
end
DPA-->>MHA: output
MHA-->>User: output
15 files reviewed, no comments
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds sink attention support to JAX, completing the implementation started in PR #2148 for PyTorch. The changes introduce three softmax variants: VANILLA_SOFTMAX (standard), OFF_BY_ONE_SOFTMAX (adds +1 to denominator), and LEARNABLE_SOFTMAX (adds learnable offset parameter).
Key Implementation Details
- Softmax offset handling: OFF_BY_ONE_SOFTMAX uses `offset=0.0`, which correctly adds `exp(0 - x_max)` to the denominator, resulting in +1 after normalization
- Unfused path: `_UnfusedDotProductAttention` initializes a learnable `softmax_offset` parameter with shape (1, h, 1, 1) for LEARNABLE_SOFTMAX
- Softmax module: properly detects that fused kernels don't support sink attention (line 45 in softmax.py returns False) and falls back to the JAX implementation
- JAX implementation: `jax_general_softmax` correctly handles max normalization by including the offset in the max calculation and adding `exp(offset - x_max)` to the denominator
- Context parallel: runtime validation ensures sink attention is not used with context parallelism (the check_supported method validates `softmax_type == VANILLA_SOFTMAX`)
- Gradient flow: the backward pass correctly returns `None` for `grad_softmax_offset` when not using LEARNABLE_SOFTMAX
Test Coverage
Comprehensive test coverage added across:
- `test_fused_attn.py`: tests all three softmax types with correct reference implementations
- `test_distributed_fused_attn.py`: distributed tests with proper constraint enforcement
- Naming cleanup: renamed `SoftmaxType` → `SoftmaxFusionType` to distinguish it from `AttnSoftmaxType`
Confidence Score: 5/5
- This PR is safe to merge with high confidence - implementation is mathematically correct, well-tested, and properly handles edge cases
- Score of 5 reflects: (1) Correct mathematical implementation of sink attention verified through code analysis, (2) Comprehensive test coverage including reference implementations, (3) Proper constraint validation preventing misuse with context parallelism, (4) Correct gradient handling in backward pass, (5) Clean fallback to JAX implementation when fused kernels unavailable, (6) All previous review comments were already addressed correctly
- No files require special attention - all implementations are correct and well-tested
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/transformer.py | 5/5 | Added sink attention support to _UnfusedDotProductAttention with proper learnable parameter initialization and sharding |
| transformer_engine/jax/flax/module.py | 5/5 | Added softmax_offset handling in Softmax module with proper fallback to JAX implementation when fused kernels don't support sink attention |
| transformer_engine/jax/cpp_extensions/softmax.py | 5/5 | Implemented jax_general_softmax with offset support for sink attention, correctly handling max normalization and denominator adjustment |
| transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Added softmax_offset parameter threading through attention primitives with proper validation that context parallel doesn't support sink attention |
| transformer_engine/jax/attention.py | 5/5 | Added softmax_offset parameter to fused_attn API with proper gradient handling in backward pass |
| tests/jax/test_fused_attn.py | 5/5 | Added comprehensive tests for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE) with correct reference implementations |
Sequence Diagram
sequenceDiagram
participant User
participant DotProductAttention
participant UnfusedDPA as _UnfusedDotProductAttention
participant Softmax as Softmax Module
participant SoftmaxKernel as is_softmax_kernel_available
participant JAXSoftmax as jax_general_softmax
participant FusedAttn as fused_attn
User->>DotProductAttention: call with softmax_type
alt Unfused Path
DotProductAttention->>UnfusedDPA: forward(q, k, v, ...)
alt LEARNABLE_SOFTMAX
UnfusedDPA->>UnfusedDPA: init learnable param (1, h, 1, 1)
else OFF_BY_ONE_SOFTMAX
UnfusedDPA->>UnfusedDPA: set softmax_offset = 0.0
end
UnfusedDPA->>UnfusedDPA: compute logits (QK^T)
UnfusedDPA->>Softmax: __call__(logits, mask, bias, softmax_offset)
Softmax->>SoftmaxKernel: check if kernel available
SoftmaxKernel-->>Softmax: False (sink attention not supported)
Softmax->>JAXSoftmax: jax_scaled_masked_softmax(logits, mask, scale, offset)
JAXSoftmax->>JAXSoftmax: compute x_max = max(x_max, offset)
JAXSoftmax->>JAXSoftmax: denominator += exp(offset - x_max)
JAXSoftmax-->>Softmax: softmax probs
Softmax-->>UnfusedDPA: attention weights
UnfusedDPA->>UnfusedDPA: compute output (weights @ V)
UnfusedDPA-->>User: attention output
else Fused Path
DotProductAttention->>FusedAttn: fused_attn(qkv, bias, softmax_offset, ...)
alt softmax_type != VANILLA
FusedAttn->>FusedAttn: validate CP not enabled
end
FusedAttn->>FusedAttn: forward with softmax_offset
FusedAttn->>FusedAttn: backward computes grad_softmax_offset
alt LEARNABLE_SOFTMAX
FusedAttn-->>User: output + gradients for softmax_offset
else
FusedAttn-->>User: output (grad_softmax_offset = None)
end
end
12 files reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds sink attention support to JAX, completing the implementation started in PR #2148 for PyTorch. It implements two types of sink attention: OFF_BY_ONE_SOFTMAX (adds +1 to denominator) and LEARNABLE_SOFTMAX (adds learnable parameter to denominator).
Key Changes:
- Adds `AttnSoftmaxType` enum with VANILLA, OFF_BY_ONE, and LEARNABLE variants
- Implements a `jax_general_softmax` function that correctly handles softmax with an offset by computing `max(x_max, offset)` and adding `exp(offset - x_max)` to the denominator
- Threads `softmax_type` and `softmax_offset` parameters through the entire attention stack (MultiHeadAttention → DotProductAttention → fused/unfused implementations)
- For LEARNABLE_SOFTMAX, creates a learnable parameter with shape `(1, num_heads, 1, 1)` and proper sharding
- Updates C++ bindings to pass softmax_type and softmax_offset to the cuDNN backend
- Fixes a critical bug from the initial implementation: `softmax_offset = 0.0` is now correctly set for OFF_BY_ONE_SOFTMAX (it was `1.0` in an earlier version)
- Adds comprehensive tests validating correctness against a reference implementation
- Properly handles gradients for the learnable offset in the backward pass
Architecture:
- Both fused (cuDNN) and unfused (JAX) attention paths fully support all three softmax types
- Renames `SoftmaxType` to `SoftmaxFusionType` to avoid confusion with `AttnSoftmaxType`
- Maintains backward compatibility with the default `softmax_type='vanilla'`
Confidence Score: 5/5
- This PR is safe to merge with high confidence
- The implementation is mathematically correct, well-tested, and follows established patterns from the PyTorch implementation. The critical offset calculation bug identified in previous reviews has been fixed (offset=0.0 for OFF_BY_ONE). The code properly validates input shapes/dtypes, handles gradients correctly for learnable parameters, and includes comprehensive test coverage. All changes are additive with backward compatibility maintained.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/module.py | 5/5 | Adds softmax_type parameter to Softmax module, correctly sets offset=0.0 for OFF_BY_ONE_SOFTMAX (fixes previous issue), properly validates LEARNABLE_SOFTMAX requires softmax_offset parameter |
| transformer_engine/jax/cpp_extensions/softmax.py | 5/5 | Implements jax_general_softmax with offset support, correctly computes max(x_max, offset) and adds exp(offset - x_max) to denominator, properly handles vanilla softmax fallback when offset is not needed |
| transformer_engine/jax/attention.py | 5/5 | Adds AttnSoftmaxType enum and threads softmax_offset through fused attention API, properly handles gradients for LEARNABLE_SOFTMAX, updates custom VJP implementation correctly |
| transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Updates FusedAttnHelper and primitives to support softmax_type and softmax_offset, validates shapes and dtypes correctly, properly handles gradient computation for learnable offset |
| transformer_engine/jax/flax/transformer.py | 5/5 | Adds softmax_type parameter to attention modules, creates learnable softmax_offset parameter with proper sharding, correctly threads parameter through both fused and unfused attention paths |
| transformer_engine/jax/csrc/extensions/attention.cpp | 5/5 | Updates C++ bindings to accept softmax_type and softmax_offset parameters, properly adds offset to auxiliary tensor packs for forward and backward passes |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant MHA as MultiHeadAttention
participant DPA as DotProductAttention
participant Fused as _FusedDotProductAttention
participant Unfused as _UnfusedDotProductAttention
participant Softmax as Softmax Module
participant FusedAttn as fused_attn
participant Primitive as FusedAttnFwdPrimitive
participant CPP as C++ Extension
participant cuDNN as cuDNN Backend
User->>MHA: forward(q, k, v, softmax_type)
MHA->>MHA: Initialize softmax_offset param<br/>if softmax_type == LEARNABLE
MHA->>DPA: forward(q, k, v, softmax_type)
alt Fused Attention Path
DPA->>Fused: forward(q, k, v)
Fused->>Fused: Initialize softmax_offset<br/>if LEARNABLE_SOFTMAX
Fused->>FusedAttn: fused_attn(qkv, bias, softmax_offset,<br/>softmax_type=...)
FusedAttn->>Primitive: bind(qkv, bias, softmax_offset, ...)
Primitive->>CPP: fused_attn_fwd_ffi(softmax_type, softmax_offset)
CPP->>CPP: PrepareFusedAttnForwardAuxTensors<br/>(adds softmax_offset to tensor pack)
CPP->>cuDNN: nvte_fused_attn_fwd(..., softmax_offset, ...)
cuDNN-->>CPP: output, softmax_aux
CPP-->>Primitive: output, softmax_aux, rng_state
Primitive-->>FusedAttn: output
FusedAttn-->>Fused: output
Fused-->>DPA: output
else Unfused Attention Path
DPA->>Unfused: forward(q, k, v)
Unfused->>Unfused: Initialize softmax_offset<br/>if LEARNABLE_SOFTMAX
Unfused->>Unfused: Compute logits = scale * Q @ K^T
Unfused->>Softmax: forward(logits, softmax_offset,<br/>softmax_type=...)
alt OFF_BY_ONE_SOFTMAX
Softmax->>Softmax: Set offset = 0.0
Softmax->>Softmax: jax_general_softmax(logits, offset=0.0)
Note over Softmax: Adds exp(0 - x_max) to denominator<br/>resulting in +1 offset
else LEARNABLE_SOFTMAX
Softmax->>Softmax: Use learnable offset param
Softmax->>Softmax: jax_general_softmax(logits, offset=alpha)
Note over Softmax: Adds exp(alpha - x_max) to denominator
else VANILLA_SOFTMAX
Softmax->>Softmax: Standard jax.nn.softmax
end
Softmax-->>Unfused: attn_weights
Unfused->>Unfused: output = attn_weights @ V
Unfused-->>DPA: output
end
DPA-->>MHA: output
MHA-->>User: output
Note over User,cuDNN: Backward pass mirrors forward,<br/>computing gradients for softmax_offset<br/>when softmax_type == LEARNABLE_SOFTMAX
16 files reviewed, no comments
Greptile Overview
Greptile Summary
This PR successfully implements sink attention support for JAX, following the PyTorch implementation from PR #2148. The changes add three softmax types: VANILLA_SOFTMAX (standard), OFF_BY_ONE_SOFTMAX (adds +1 to denominator), and LEARNABLE_SOFTMAX (learnable sink parameter).
Key changes:
- Introduced `AttnSoftmaxType` enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE)
- Separated concerns between `SoftmaxFusionType` (kernel fusion strategy) and `AttnSoftmaxType` (mathematical softmax variant)
- Added `softmax_offset` parameter throughout the attention stack (module → attention API → primitives)
- For OFF_BY_ONE_SOFTMAX: `softmax_offset=0.0` (a logit value, not the denominator contribution)
- For LEARNABLE_SOFTMAX: a learnable parameter with shape `(1, num_heads, 1, 1)`, properly sharded via `HEAD_AXES`
- Context parallel explicitly doesn't support sink attention (forced to VANILLA_SOFTMAX)
- Comprehensive test coverage added across all attention backends
Implementation correctness:
The softmax_offset=0.0 for OFF_BY_ONE_SOFTMAX is mathematically correct. The offset represents the logit value (not the denominator contribution). In jax_general_softmax, this computes exp(0 - x_max) which after normalization contributes +1 to the denominator, matching the reference implementation in tests that appends jnp.zeros to logits.
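A quick hedged check of that equivalence (toy shapes; this mirrors, but is not, the test reference):

```python
# Appending a zero "sink" logit and using a standard softmax is equivalent to
# adding +1 to the denominator of exp(x)/sum(exp(x)).
import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))

# Reference form: pad with a zero logit column, softmax, drop the sink column.
padded = jnp.concatenate([logits, jnp.zeros(logits.shape[:-1] + (1,))], axis=-1)
ref = jax.nn.softmax(padded, axis=-1)[..., :-1]

# Offset form: exp(x - x_max) / (sum(exp(x - x_max)) + exp(0 - x_max)).
x_max = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), 0.0)
unnorm = jnp.exp(logits - x_max)
off_by_one = unnorm / (unnorm.sum(axis=-1, keepdims=True) + jnp.exp(-x_max))

assert jnp.allclose(ref, off_by_one, atol=1e-6)
```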
Confidence Score: 5/5
- This PR is safe to merge with minimal risk - the implementation is mathematically sound, well-tested, and follows established patterns from the PyTorch implementation.
- Score of 5 reflects: (1) comprehensive test coverage with 36+ new test cases across multiple backends and configurations, (2) clean separation of `SoftmaxFusionType` vs `AttnSoftmaxType` with no breaking changes, (3) correct mathematical implementation verified against reference code, (4) proper handling of edge cases (context parallel limitation documented), (5) consistent API design following the PyTorch precedent, and (6) all validation checks for shapes/dtypes in place.
- No files require special attention. The previous comments flagging `softmax_offset=0.0` as incorrect were erroneous - the implementation is mathematically correct.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/module.py | 5/5 | Added AttnSoftmaxType support to Softmax module with proper handling of softmax_offset parameter. Implementation correctly sets softmax_offset=0.0 for OFF_BY_ONE_SOFTMAX (previous comments claiming this was wrong are incorrect). |
| transformer_engine/jax/attention.py | 5/5 | Added AttnSoftmaxType enum and integrated softmax_offset parameter throughout fused attention API. Includes proper from_str conversion method and documentation. |
| transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Extended fused attention primitives to support softmax_type and softmax_offset in both forward and backward passes. Proper validation of offset shapes and dtypes added. |
| tests/jax/test_distributed_softmax.py | 5/5 | Updated test suite to distinguish between SoftmaxFusionType (SCALED/SCALED_MASKED/SCALED_UPPER_TRIANG_MASKED) and AttnSoftmaxType (VANILLA/OFF_BY_ONE/LEARNABLE). Clean separation of concerns. |
Sequence Diagram
sequenceDiagram
participant User
participant DotProductAttention
participant UnfusedDPA as _UnfusedDotProductAttention
participant FusedDPA as _FusedDotProductAttention
participant Softmax
participant FusedAttn as fused_attn
participant Primitives as FusedAttnFwd/BwdPrimitive
User->>DotProductAttention: __call__(query, key, value, ...)
DotProductAttention->>DotProductAttention: Convert softmax_type string to AttnSoftmaxType enum
DotProductAttention->>DotProductAttention: Check if fused attention available
alt Fused Attention Available
DotProductAttention->>FusedDPA: __call__(query, key, value, ...)
FusedDPA->>FusedDPA: Initialize softmax_offset param if LEARNABLE_SOFTMAX
FusedDPA->>FusedAttn: fused_attn(qkv, bias, softmax_offset, ...)
FusedAttn->>Primitives: Forward/Backward with softmax_type & offset
Primitives-->>FusedAttn: Attention output
FusedAttn-->>FusedDPA: Output
FusedDPA-->>DotProductAttention: Output
else Unfused Attention
DotProductAttention->>UnfusedDPA: __call__(query, key, value, ...)
UnfusedDPA->>UnfusedDPA: Initialize softmax_offset param if LEARNABLE_SOFTMAX
UnfusedDPA->>UnfusedDPA: Compute Q @ K^T (attention logits)
UnfusedDPA->>UnfusedDPA: Convert attn_mask_type to SoftmaxFusionType
UnfusedDPA->>Softmax: __call__(logits, mask, bias, softmax_offset)
Softmax->>Softmax: Set softmax_offset=0.0 if OFF_BY_ONE_SOFTMAX
Softmax->>Softmax: Apply softmax with offset (jax_general_softmax)
Softmax-->>UnfusedDPA: Attention weights
UnfusedDPA->>UnfusedDPA: Compute attn_weights @ V
UnfusedDPA-->>DotProductAttention: Output
end
DotProductAttention-->>User: Attention output
5 files reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds sink attention support to JAX, implementing three softmax variants: VANILLA, OFF_BY_ONE, and LEARNABLE.
The implementation threads `softmax_type` and `softmax_offset` through the full attention stack, from the Flax modules down to the cuDNN backend.
The implementation is mathematically sound and follows established patterns from PyTorch sink attention support (PR #2148).
Confidence Score: 5/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User as User Code
participant MHA as MultiHeadAttention
participant DPA as DotProductAttention
participant Softmax as Softmax Module
participant FusedAttn as FusedAttn (cuDNN)
participant JAXSoftmax as jax_general_softmax
User->>MHA: call with softmax_type
Note over MHA: softmax_type: 'vanilla'/'off_by_one'/'learnable'
alt Learnable Softmax
MHA->>MHA: Create learnable param<br/>softmax_offset [1,h,1,1]
end
MHA->>DPA: Forward with softmax_type
alt Unfused Path
DPA->>Softmax: logits, softmax_type, offset
Note over Softmax: OFF_BY_ONE: offset=0.0<br/>LEARNABLE: offset=learned param
Softmax->>JAXSoftmax: jax_general_softmax(logits, offset)
Note over JAXSoftmax: x_max = max(x_max, offset)<br/>denom += exp(offset - x_max)
JAXSoftmax-->>Softmax: softmax output
Softmax-->>DPA: attention weights
else Fused Path (cuDNN)
DPA->>FusedAttn: Q, K, V, softmax_type, offset
Note over FusedAttn: Passes softmax_type and offset<br/>to cuDNN backend
FusedAttn->>FusedAttn: cuDNN fused kernel<br/>with sink attention
FusedAttn-->>DPA: attention output
end
DPA-->>MHA: attention output
Note over User,JAXSoftmax: Backward Pass
alt Learnable Softmax
MHA->>MHA: Accumulate grad_softmax_offset<br/>across DP/FSDP shards
MHA->>MHA: Update learnable param
end
16 files reviewed, no comments
/te-ci jax
16 files reviewed, no comments
/te-ci jax L1
phu0ngng
left a comment
LGTM. Thanks
Description
PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX.
Fixes #2070