
Conversation

@pggPL (Collaborator) commented Oct 1, 2025

Description

PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX.

Fixes #2070
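
For orientation, a minimal usage sketch at the Flax level, assuming the string-valued softmax_type option ('vanilla', 'off_by_one', 'learnable') described in the reviews below; the DotProductAttention constructor arguments and init/apply plumbing shown here are assumptions for illustration, not the PR's verbatim API:

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention  # Flax module discussed in this PR

# NOTE: argument names below (head_dim, num_attention_heads) are assumptions for illustration.
attn = DotProductAttention(
    head_dim=64,
    num_attention_heads=16,
    softmax_type="off_by_one",  # or "learnable" to train a per-head sink logit
)

# (batch, seqlen, heads, head_dim) inputs
q = k = v = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)
variables = attn.init({"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}, q, k, v)
out = attn.apply(variables, q, k, v, rngs={"dropout": jax.random.PRNGKey(2)})
```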

BEFORE
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    1.97s | avg:   0.16s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  612.61s | avg:   3.83s
test_context_parallel_allgather_attn_shardy                  |  20x |   90.95s | avg:   4.55s
test_context_parallel_ring_attn                              | 640x | 1042.37s | avg:   1.63s
test_context_parallel_ring_attn_shardy                       |  20x |   37.74s | avg:   1.89s
test_cross_attn                                              |   6x |   31.82s | avg:   5.30s
test_distributed_gemm                                        |   6x |    6.10s | avg:   1.02s
test_layernorm                                               | 144x |   81.39s | avg:   0.57s
test_layernorm_mlp_grad                                      | 240x |  301.51s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  293.58s | avg:   1.22s
test_layernorm_mlp_layer                                     |  48x |   21.58s | avg:   0.45s
test_layernorm_mlp_layer_fp8                                 | 192x |   81.58s | avg:   0.42s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   91.23s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   25.98s | avg:   0.54s
test_rmsnorm                                                 |  72x |   29.43s | avg:   0.41s
test_self_attn                                               |  18x |   89.75s | avg:   4.99s
test_self_attn_shardy                                        |   6x |   17.32s | avg:   2.89s
test_softmax                                                 | 288x |  185.44s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   13.07s | avg:   0.54s
test_te_distributed_dense_grad                               |   6x |    5.12s | avg:   0.85s
================================================================================
TOTAL RUNTIME                                                |      | 3060.56s |
================================================================================

AFTER
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    2.20s | avg:   0.18s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  587.44s | avg:   3.67s
test_context_parallel_allgather_attn_shardy                  |  20x |   87.95s | avg:   4.40s
test_context_parallel_ring_attn                              | 640x | 1037.16s | avg:   1.62s
test_context_parallel_ring_attn_shardy                       |  20x |   41.83s | avg:   2.09s
test_cross_attn                                              |  18x |   89.76s | avg:   4.99s
test_distributed_gemm                                        |   6x |    5.74s | avg:   0.96s
test_layernorm                                               | 144x |   83.85s | avg:   0.58s
test_layernorm_mlp_grad                                      | 240x |  301.73s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  309.08s | avg:   1.29s
test_layernorm_mlp_layer                                     |  48x |   24.98s | avg:   0.52s
test_layernorm_mlp_layer_fp8                                 | 192x |   89.17s | avg:   0.46s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   92.58s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   26.29s | avg:   0.55s
test_rmsnorm                                                 |  72x |   29.52s | avg:   0.41s
test_self_attn                                               |  54x |  259.63s | avg:   4.81s
test_self_attn_shardy                                        |  18x |   43.51s | avg:   2.42s
test_softmax                                                 | 288x |  183.87s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   12.72s | avg:   0.53s
test_te_distributed_dense_grad                               |   6x |    4.74s | avg:   0.79s
================================================================================
TOTAL RUNTIME                                                |      | 3313.74s |
================================================================================

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits October 1, 2025 14:45
pre-commit-ci bot and others added 5 commits October 2, 2025 15:54
@pggPL (Collaborator, Author) commented Oct 6, 2025

/te-ci jax

pggPL and others added 2 commits October 7, 2025 14:30
@phu0ngng self-requested a review October 8, 2025 18:33
pggPL and others added 7 commits October 14, 2025 14:41
@pggPL (Collaborator, Author) commented Oct 14, 2025

/te-ci jax L1

@pggPL (Collaborator, Author) commented Oct 15, 2025

/te-ci jax L1

@pggPL (Collaborator, Author) commented Oct 15, 2025

/te-ci jax L1

@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds support for sink attention (softmax variants with an extra sink term in the denominator) to JAX, following the PyTorch implementation from PR #2148. The implementation introduces three softmax types: VANILLA (standard), OFF_BY_ONE (adds 1 to the denominator), and LEARNABLE (uses a learnable per-head offset parameter).

Key Changes:

  • Adds AttnSoftmaxType enum with three variants: VANILLA_SOFTMAX, OFF_BY_ONE_SOFTMAX, LEARNABLE_SOFTMAX
  • Threads softmax_type and softmax_offset parameters throughout the attention pipeline from Flax modules through JAX primitives to C++/cuDNN backends
  • Renames SoftmaxType to SoftmaxFusion to distinguish kernel fusion strategies from sink attention variants
  • Updates all attention primitives (standard, context parallel with all-gather, ring attention) to handle softmax_offset
  • Implements proper gradient computation for learnable softmax parameters with appropriate all-reduce operations
  • Adds comprehensive test coverage including forward/backward passes and distributed scenarios

Implementation Details (a reference sketch of this math follows the list below):

  • For LEARNABLE_SOFTMAX, creates a learnable parameter of shape (1, num_heads, 1, 1) with proper sharding by head dimension
  • OFF_BY_ONE_SOFTMAX is handled by setting softmax_offset=1.0
  • Context parallel paths (ring attention, all-gather) return dummy gradients for softmax_offset as they don't support learnable variants
  • The C++ layer properly packs softmax_offset tensors into the cuDNN tensor pack for both forward and backward passes
  • Refactored C++ code consolidates multiple layout-specific calls into unified nvte_fused_attn_fwd/bwd calls
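
A minimal sketch of the reference math behind these variants (mirroring the test reference implementation described later in the thread: append a sink logit, apply standard softmax, drop the extra column); the function name and shapes are illustrative, not the PR's actual code:

```python
import jax
import jax.numpy as jnp

def reference_sink_softmax(logits, offset_logit):
    """logits: (b, h, s_q, s_kv); offset_logit: scalar or broadcastable to (b, h, s_q, 1)."""
    sink = jnp.broadcast_to(jnp.asarray(offset_logit, logits.dtype), (*logits.shape[:-1], 1))
    probs = jax.nn.softmax(jnp.concatenate([logits, sink], axis=-1), axis=-1)
    return probs[..., :-1]  # rows may now sum to less than 1, i.e. "attend to nothing"

# VANILLA: no sink column; OFF_BY_ONE: offset_logit == 0.0 (adds +1 to the denominator);
# LEARNABLE: offset_logit is a trained per-head parameter of shape (1, h, 1, 1).
```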

Testing:
Test runtime increased from ~3061s to ~3314s (+8%) due to additional test cases covering the three softmax variants

Confidence Score: 4/5

  • This PR is generally safe to merge with one potential mathematical issue to verify
  • The implementation is comprehensive and well-structured with proper gradient handling, sharding, and test coverage. However, there's one potential mathematical correctness issue in jax_general_softmax (transformer_engine/jax/cpp_extensions/softmax.py:834-853) where the offset handling needs verification - specifically whether callers pass the raw learnable parameter or its exponential for LEARNABLE_SOFTMAX
  • transformer_engine/jax/cpp_extensions/softmax.py - verify that the jax_general_softmax function receives the correct pre-exponentiated offset values for LEARNABLE_SOFTMAX

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 4/5 Adds AttnSoftmaxType enum with support for VANILLA, OFF_BY_ONE, and LEARNABLE softmax variants. Updates function signatures to accept softmax_type and softmax_offset parameters. Implementation looks solid with proper enum handling and parameter validation.
transformer_engine/jax/cpp_extensions/attention.py 4/5 Extensive changes to support softmax_offset throughout the attention primitives. Properly handles forward/backward passes, sharding, and context parallelism. The gradient handling for softmax_offset includes proper all-reduce for learnable softmax and dummy returns for CP paths that don't use it.
transformer_engine/jax/csrc/extensions/attention.cpp 4/5 C++ implementation updated to thread softmax_type and softmax_offset through the call chain. Adds proper tensor pack handling for softmax_offset in both forward and backward passes. The refactoring consolidates multiple layout-specific calls into a unified nvte_fused_attn_fwd call.
transformer_engine/jax/cpp_extensions/softmax.py 3/5 Adds jax_general_softmax for sink attention support and updates softmax functions to accept softmax_offset. The implementation adds offset to the denominator, but needs verification that callers pass the correct pre-exponentiated values for LEARNABLE_SOFTMAX.
transformer_engine/jax/flax/transformer.py 4/5 Adds softmax_type parameter to both fused and unfused attention implementations. Creates learnable softmax_offset parameter for LEARNABLE_SOFTMAX with proper sharding. Handles OFF_BY_ONE internally. PRE_SCALE_BIAS handling sets bias to None after adding to prevent double-addition.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Flax as Flax Transformer
    participant Attn as JAX Attention
    participant Prim as Attention Primitive
    participant CPP as C++ Extension
    participant cuDNN as cuDNN Backend

    User->>Flax: Call attention with softmax_type
    
    alt LEARNABLE_SOFTMAX
        Flax->>Flax: Initialize learnable softmax_offset param
    else OFF_BY_ONE_SOFTMAX
        Flax->>Flax: Set softmax_offset = 1.0
    else VANILLA_SOFTMAX
        Flax->>Flax: Set softmax_offset = empty
    end
    
    Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
    Attn->>Attn: Apply sharding constraints to softmax_offset
    Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v, bias, softmax_offset, ...)
    
    Prim->>CPP: FusedAttnForwardImpl(q, k, v, bias, softmax_offset, ...)
    CPP->>CPP: PrepareFusedAttnForwardAuxTensors (adds softmax_offset to tensor pack)
    CPP->>cuDNN: nvte_fused_attn_fwd(tensor_pack with softmax_offset)
    cuDNN-->>CPP: output, softmax_aux
    CPP-->>Prim: output, softmax_aux, rng_state
    Prim-->>Attn: output
    Attn-->>Flax: output
    
    Note over User,cuDNN: Backward Pass
    
    User->>Flax: Gradient computation
    Flax->>Attn: fused_attn backward
    Attn->>Prim: FusedAttnBwdPrimitive.bind(...)
    Prim->>CPP: FusedAttnBackwardImpl(..., softmax_offset)
    CPP->>cuDNN: nvte_fused_attn_bwd(...)
    cuDNN-->>CPP: dq, dk, dv, dbias, dsoftmax_offset
    CPP-->>Prim: dq, dk, dv, dbias, dsoftmax_offset
    
    alt LEARNABLE_SOFTMAX
        Prim->>Prim: all_reduce dsoftmax_offset across DP/FSDP
    end
    
    Prim-->>Attn: gradients
    Attn-->>Flax: gradients
    Flax-->>User: Updated parameters

15 files reviewed, 1 comment


pggPL and others added 2 commits November 5, 2025 14:31
@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds support for sink attention to JAX, implementing three softmax variants: VANILLA, OFF_BY_ONE (zero sink), and LEARNABLE (learnable sink). The implementation follows the pattern established in PR #2148 for PyTorch.

Key Changes

  • New AttnSoftmaxType enum: Defines three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE)
  • softmax_offset parameter: Added throughout the attention pipeline to support sink attention
    • For LEARNABLE_SOFTMAX: a learnable parameter [1, num_heads, 1, 1]
    • For OFF_BY_ONE_SOFTMAX: treated as an implicit offset of 0
  • Renamed SoftmaxType → SoftmaxFusion: distinguishes fusion strategy from softmax variant
  • Backend support: Updated C++ extensions and cuDNN integration to handle new softmax types
  • Comprehensive tests: Added test coverage for all three softmax types

Implementation Quality

Strengths:

  • Well-structured changes following existing code patterns
  • Comprehensive test coverage with reference implementations
  • Proper gradient handling for learnable parameters
  • Clean separation between fusion strategy and softmax type

Critical Issue Found:

  • Bug in transformer_engine/jax/flax/module.py:198: OFF_BY_ONE_SOFTMAX incorrectly uses softmax_offset = 1.0 instead of softmax_offset = 0.0. This will produce incorrect attention weights when logits exceed 1.0. The test reference implementation correctly uses a zero logit, but the optimized path has the wrong value.

Confidence Score: 3/5

  • This PR should not be merged without fixing the OFF_BY_ONE_SOFTMAX bug
  • The implementation is well-structured and comprehensive, but contains a critical logical error in the OFF_BY_ONE_SOFTMAX implementation that will cause incorrect attention computation. The bug is in module.py:198 where softmax_offset = 1.0 should be softmax_offset = 0.0. This discrepancy between the test reference implementation (which correctly uses zero) and the optimized path means tests may pass while the production code produces wrong results in certain cases.
  • transformer_engine/jax/flax/module.py - Fix OFF_BY_ONE_SOFTMAX offset value from 1.0 to 0.0

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 3/5 Added AttnSoftmaxType enum and softmax_offset parameter throughout attention pipeline. Implementation looks correct.
transformer_engine/jax/cpp_extensions/softmax.py 2/5 Added jax_general_softmax with offset support. Implementation is mathematically sound but usage in module.py has a bug.
transformer_engine/jax/flax/module.py 2/5 Updated Softmax module with sink attention support. Critical bug: OFF_BY_ONE_SOFTMAX uses offset=1.0 instead of offset=0.0.
transformer_engine/jax/flax/transformer.py 4/5 Added softmax_type parameter to attention modules with learnable parameter initialization for LEARNABLE_SOFTMAX. Implementation looks correct.
tests/jax/test_fused_attn.py 5/5 Comprehensive test coverage added for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE). Reference implementation matches expected behavior.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant DPA as DotProductAttention
    participant Fused as _FusedDotProductAttention
    participant Attn as fused_attn
    participant Prim as FusedAttnFwdPrimitive
    participant CPP as C++ Backend
    participant cuDNN as cuDNN Kernel

    Note over User,cuDNN: Sink Attention Flow (JAX)
    
    User->>DPA: __call__(query, key, value, softmax_type='off_by_one')
    DPA->>Fused: forward with softmax_type
    
    alt softmax_type == LEARNABLE_SOFTMAX
        Fused->>Fused: Initialize learnable param<br/>softmax_offset [1, h, 1, 1]
    else softmax_type == OFF_BY_ONE_SOFTMAX
        Note over Fused: No offset param needed<br/>(handled by backend)
    end
    
    Fused->>Attn: fused_attn(qkv, bias, softmax_offset,<br/>sequence_descriptor, softmax_type)
    Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v,<br/>bias, softmax_offset, ...)
    
    Prim->>CPP: FusedAttnForwardFFI(q, k, v, bias,<br/>softmax_offset, softmax_type)
    
    CPP->>CPP: Prepare tensor pack:<br/>- softmax_aux<br/>- rng_state<br/>- bias (if provided)<br/>- softmax_offset (if provided)
    
    CPP->>cuDNN: nvte_fused_attn_fwd(q, k, v, bias,<br/>softmax_offset, softmax_type, ...)
    
    alt softmax_type == VANILLA_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / sum(exp(QK^T))
    else softmax_type == OFF_BY_ONE_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / (1 + sum(exp(QK^T)))
    else softmax_type == LEARNABLE_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / (exp(alpha) + sum(exp(QK^T)))
    end
    
    cuDNN-->>CPP: output, softmax_aux, rng_state
    CPP-->>Prim: output, softmax_aux, rng_state
    Prim-->>Attn: output
    Attn-->>Fused: attention output
    Fused-->>DPA: attention output
    DPA-->>User: final output
    
    Note over User,cuDNN: Backward Pass (if training)
    User->>DPA: grad(output)
    DPA->>Attn: backward
    Attn->>Prim: FusedAttnBwdPrimitive
    Prim->>CPP: FusedAttnBackwardFFI
    CPP->>cuDNN: nvte_fused_attn_bwd
    
    alt softmax_type == LEARNABLE_SOFTMAX
        Note over cuDNN: Compute grad_softmax_offset
        cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
    else
        cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, None
    end
    
    Prim-->>User: gradients

15 files reviewed, 2 comments


@pggPL (Collaborator, Author) commented Nov 5, 2025

/te-ci jax

@KshitijLakhani (Collaborator) left a comment

Mostly looks good to me. Thanks, Pawel, for adding this on the JAX side as well!
I've left a few comments and questions in the files.

Leaving a few more here:

  1. It would be good to add an explanatory comment / example tensor where a specific case is helpful, e.g. (a short snippet reproducing these numbers follows below):

     Consider all-negative scores: [-10.0, -12.0, -11.0, -15.0, -13.0, -14.0]. This represents a scenario where the query has no good match with any key; ideally, the model should attend to nothing.

     Vanilla softmax:
       • Weights: [0.6337, 0.0858, 0.2331, 0.0043, 0.0315, 0.0116]
       • Notes: even though all scores are terrible, the model is forced to give 63% attention to the "least bad" position

     Off-by-one softmax:
       • Weights: [0.000045, 0.000006, 0.000017, 0.0000003, 0.000002, 0.0000008]
       • Sum: 0.000072 (nearly zero, compared to exactly 1 for vanilla)
       • Notes: all attention weights are tiny, so the model can effectively express "I don't want to attend to anything"

  2. Thanks for consolidating the nvte_fused_attn_fwd() calls in csrc/extensions/attention.cpp, but I would strongly suggest decoupling that change into a separate PR.
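
A quick snippet reproducing the numbers in point 1 above (plain JAX, not TE code):

```python
import jax
import jax.numpy as jnp

scores = jnp.array([-10.0, -12.0, -11.0, -15.0, -13.0, -14.0])

vanilla = jax.nn.softmax(scores)                                  # sums to 1.0
off_by_one = jnp.exp(scores) / (1.0 + jnp.sum(jnp.exp(scores)))   # sums to ~7.2e-5

print(vanilla)           # ~[0.6337, 0.0858, 0.2331, 0.0043, 0.0315, 0.0116]
print(off_by_one)        # ~[4.5e-05, 6.1e-06, 1.7e-05, 3.1e-07, 2.3e-06, 8.3e-07]
print(off_by_one.sum())  # ~0.000072
```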

Comment on lines +2637 to +2639
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
Collaborator:

This is only for arity matching - the idea is that the P2P helper will raise an exception in the "checking configs" stage itself, so we won't really need to worry about this part being executed, right?

Collaborator Author:

Yes, this is executed if and only if softmax_offset has size 0.

softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case AttnSoftmaxType.LEARNABLE_SOFTMAX:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
Collaborator:

Quick question: remind me again, does this match the logits shape?

Collaborator Author:

See the comment on line 89: # logits with shape (b, h_kv, num_groups, s_q, s_kv)

@KshitijLakhani (Collaborator) commented:

Also, @pggPL, the test timings do not include the additional test you added in test_layer, right?
Asking differently, could you confirm whether the timings you posted cover all of the tests you are adding?
Thanks!

@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds comprehensive sink attention support to JAX by introducing three softmax types: vanilla, off-by-one, and learnable. The implementation follows the pattern established in PR #2148 for PyTorch.

Key Changes:

  • Added AttnSoftmaxType enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE)
  • Introduced softmax_type parameter throughout the attention pipeline (transformer layers → attention modules → fused/unfused attention → softmax)
  • Implemented learnable softmax_offset parameter for LEARNABLE_SOFTMAX, properly initialized and sharded across attention heads
  • Extended JAX fallback path with jax_general_softmax function that correctly implements sink attention mathematics
  • Updated C++ extensions and cuDNN backend integration to handle softmax type and offset tensors
  • Added comprehensive test coverage for all three softmax types with proper reference implementations
  • Refactored SoftmaxType → SoftmaxFusion to avoid naming collision with AttnSoftmaxType

Implementation Details:

  • OFF_BY_ONE_SOFTMAX: Sets offset=0.0, which adds exp(0-x_max)=exp(-x_max) to the denominator, contributing +1 after normalization
  • LEARNABLE_SOFTMAX: Uses a learnable parameter alpha of shape (1, num_heads, 1, 1) as the offset logit
  • Both fused (cuDNN) and unfused (JAX native) paths are supported
  • Proper gradient flow for learnable offset parameter in backward pass
  • Shape validation and dtype checks ensure offset tensor is (1, H, 1, 1) float32

Confidence Score: 5/5

  • This PR is safe to merge with high confidence
  • The implementation is thorough, well-tested, and correctly follows established patterns from the PyTorch implementation. The mathematics of sink attention is properly implemented in both test reference and actual code. Previous concerns about offset handling have been correctly addressed (offset=0.0 for OFF_BY_ONE). Comprehensive test coverage includes all softmax types with proper reference implementations. The changes are non-breaking and extend existing functionality.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 5/5 Added AttnSoftmaxType enum and integrated softmax_offset parameter throughout the fused attention pipeline. Changes properly thread the parameter through forward and backward passes.
transformer_engine/jax/flax/transformer.py 5/5 Added softmax_type parameter to attention modules with proper initialization of learnable softmax_offset parameters. Correctly passes parameters through both unfused and fused attention paths.
transformer_engine/jax/flax/module.py 5/5 Updated Softmax module to support sink attention by adding softmax_type parameter and handling softmax_offset. Correctly sets offset=0.0 for OFF_BY_ONE_SOFTMAX.
transformer_engine/jax/cpp_extensions/softmax.py 5/5 Added jax_general_softmax function implementing sink attention math. Updated all softmax functions to accept optional softmax_offset parameter. Implementation correctly handles offset as logit value.
transformer_engine/jax/cpp_extensions/attention.py 5/5 Extended fused attention primitives to support softmax_type and softmax_offset. Properly validates shapes and dtypes, threads parameters through FFI calls.
transformer_engine/jax/csrc/extensions/attention.cpp 5/5 Updated C++ implementation to accept softmax_type parameter and handle softmax_offset tensor in aux tensor packs. Properly integrated with cuDNN backend.
tests/jax/test_fused_attn.py 5/5 Added comprehensive tests for all three softmax types (vanilla, off-by-one, learnable). Reference implementation correctly appends zero/learnable logit and applies standard softmax.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant MHA as MultiHeadAttention
    participant DPA as DotProductAttention
    participant Fused as _FusedDotProductAttention
    participant Unfused as _UnfusedDotProductAttention
    participant Softmax as Softmax Module
    participant FusedAttn as fused_attn (JAX)
    participant CPP as C++ Extensions
    participant cuDNN as cuDNN Backend

    User->>MHA: forward(q, k, v, softmax_type='off_by_one')
    MHA->>MHA: Initialize softmax_offset param if LEARNABLE
    MHA->>DPA: __call__(q, k, v, softmax_type)
    
    alt Fused Attention Path
        DPA->>Fused: __call__(q, k, v, softmax_type)
        Fused->>Fused: Create softmax_offset param if LEARNABLE
        Fused->>FusedAttn: fused_attn(qkv, softmax_offset, softmax_type)
        FusedAttn->>CPP: fused_attn_fwd(softmax_type, softmax_offset)
        CPP->>cuDNN: nvte_fused_attn_fwd(softmax_type, softmax_offset)
        cuDNN-->>CPP: output
        CPP-->>FusedAttn: output
        FusedAttn-->>Fused: output
    else Unfused Attention Path
        DPA->>Unfused: __call__(q, k, v, softmax_type)
        Unfused->>Unfused: Create softmax_offset param if LEARNABLE
        Unfused->>Unfused: Compute Q·K^T attention logits
        Unfused->>Softmax: __call__(logits, softmax_offset, softmax_type)
        
        alt OFF_BY_ONE_SOFTMAX
            Softmax->>Softmax: Set offset = 0.0
            Softmax->>Softmax: jax_general_softmax(logits, offset=0.0)
            Note over Softmax: Adds exp(0-x_max) to denominator = +1
        else LEARNABLE_SOFTMAX
            Softmax->>Softmax: Use learnable offset param
            Softmax->>Softmax: jax_general_softmax(logits, offset=alpha)
            Note over Softmax: Adds exp(alpha-x_max) to denominator
        else VANILLA_SOFTMAX
            Softmax->>Softmax: jax.nn.softmax(logits)
        end
        
        Softmax-->>Unfused: attention_weights
        Unfused->>Unfused: attention_weights · V
        Unfused-->>DPA: output
    end
    
    DPA-->>MHA: output
    MHA-->>User: output

15 files reviewed, no comments


pggPL added 2 commits November 7, 2025 12:02
@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds sink attention support to JAX, completing the implementation started in PR #2148 for PyTorch. The changes introduce three softmax variants: VANILLA_SOFTMAX (standard), OFF_BY_ONE_SOFTMAX (adds +1 to denominator), and LEARNABLE_SOFTMAX (adds learnable offset parameter).

Key Implementation Details

  • Softmax offset handling: OFF_BY_ONE_SOFTMAX uses offset=0.0 which correctly adds exp(0-x_max) to the denominator, resulting in +1 after normalization
  • Unfused path: _UnfusedDotProductAttention initializes a learnable softmax_offset parameter with shape (1, h, 1, 1) for LEARNABLE_SOFTMAX
  • Softmax module: Properly detects that fused kernels don't support sink attention (line 45 in softmax.py returns False) and falls back to JAX implementation
  • JAX implementation: jax_general_softmax correctly handles max normalization by including the offset in the max calculation and adding exp(offset - x_max) to the denominator (see the sketch after this list)
  • Context parallel: Runtime validation ensures sink attention is not used with context parallelism (check_supported method validates softmax_type == VANILLA_SOFTMAX)
  • Gradient flow: Backward pass correctly returns None for grad_softmax_offset when not using LEARNABLE_SOFTMAX
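
A minimal sketch, under the assumptions above, of the numerically stable formulation the review describes for jax_general_softmax (the offset logit joins the row max and contributes exp(offset - x_max) to the denominator); this is illustrative, not the actual TE source:

```python
import jax.numpy as jnp

def general_softmax_sketch(logits, offset):
    """logits: (b, h, s_q, s_kv); offset: offset logit, scalar or (1, h, 1, 1)."""
    x_max = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), offset)
    unnormalized = jnp.exp(logits - x_max)
    denom = jnp.sum(unnormalized, axis=-1, keepdims=True) + jnp.exp(offset - x_max)
    return unnormalized / denom  # offset == 0.0 reproduces the off-by-one softmax
```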

Test Coverage

Comprehensive test coverage added across:

  • test_fused_attn.py: Tests all three softmax types with correct reference implementations
  • test_distributed_fused_attn.py: Distributed tests with proper constraint enforcement
  • Naming cleanup: renamed SoftmaxType → SoftmaxFusionType to distinguish it from AttnSoftmaxType

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - implementation is mathematically correct, well-tested, and properly handles edge cases
  • Score of 5 reflects: (1) Correct mathematical implementation of sink attention verified through code analysis, (2) Comprehensive test coverage including reference implementations, (3) Proper constraint validation preventing misuse with context parallelism, (4) Correct gradient handling in backward pass, (5) Clean fallback to JAX implementation when fused kernels unavailable, (6) All previous review comments were already addressed correctly
  • No files require special attention - all implementations are correct and well-tested

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/flax/transformer.py 5/5 Added sink attention support to _UnfusedDotProductAttention with proper learnable parameter initialization and sharding
transformer_engine/jax/flax/module.py 5/5 Added softmax_offset handling in Softmax module with proper fallback to JAX implementation when fused kernels don't support sink attention
transformer_engine/jax/cpp_extensions/softmax.py 5/5 Implemented jax_general_softmax with offset support for sink attention, correctly handling max normalization and denominator adjustment
transformer_engine/jax/cpp_extensions/attention.py 5/5 Added softmax_offset parameter threading through attention primitives with proper validation that context parallel doesn't support sink attention
transformer_engine/jax/attention.py 5/5 Added softmax_offset parameter to fused_attn API with proper gradient handling in backward pass
tests/jax/test_fused_attn.py 5/5 Added comprehensive tests for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE) with correct reference implementations

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant UnfusedDPA as _UnfusedDotProductAttention
    participant Softmax as Softmax Module
    participant SoftmaxKernel as is_softmax_kernel_available
    participant JAXSoftmax as jax_general_softmax
    participant FusedAttn as fused_attn

    User->>DotProductAttention: call with softmax_type
    
    alt Unfused Path
        DotProductAttention->>UnfusedDPA: forward(q, k, v, ...)
        
        alt LEARNABLE_SOFTMAX
            UnfusedDPA->>UnfusedDPA: init learnable param (1, h, 1, 1)
        else OFF_BY_ONE_SOFTMAX
            UnfusedDPA->>UnfusedDPA: set softmax_offset = 0.0
        end
        
        UnfusedDPA->>UnfusedDPA: compute logits (QK^T)
        UnfusedDPA->>Softmax: __call__(logits, mask, bias, softmax_offset)
        
        Softmax->>SoftmaxKernel: check if kernel available
        SoftmaxKernel-->>Softmax: False (sink attention not supported)
        
        Softmax->>JAXSoftmax: jax_scaled_masked_softmax(logits, mask, scale, offset)
        JAXSoftmax->>JAXSoftmax: compute x_max = max(x_max, offset)
        JAXSoftmax->>JAXSoftmax: denominator += exp(offset - x_max)
        JAXSoftmax-->>Softmax: softmax probs
        
        Softmax-->>UnfusedDPA: attention weights
        UnfusedDPA->>UnfusedDPA: compute output (weights @ V)
        UnfusedDPA-->>User: attention output
    else Fused Path
        DotProductAttention->>FusedAttn: fused_attn(qkv, bias, softmax_offset, ...)
        
        alt softmax_type != VANILLA
            FusedAttn->>FusedAttn: validate CP not enabled
        end
        
        FusedAttn->>FusedAttn: forward with softmax_offset
        FusedAttn->>FusedAttn: backward computes grad_softmax_offset
        
        alt LEARNABLE_SOFTMAX
            FusedAttn-->>User: output + gradients for softmax_offset
        else
            FusedAttn-->>User: output (grad_softmax_offset = None)
        end
    end

12 files reviewed, no comments


@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds sink attention support to JAX, completing the implementation started in PR #2148 for PyTorch. It implements two types of sink attention: OFF_BY_ONE_SOFTMAX (adds +1 to denominator) and LEARNABLE_SOFTMAX (adds learnable parameter to denominator).

Key Changes:

  • Adds AttnSoftmaxType enum with VANILLA, OFF_BY_ONE, and LEARNABLE variants
  • Implements jax_general_softmax function that correctly handles softmax with offset by computing max(x_max, offset) and adding exp(offset - x_max) to denominator
  • Threads softmax_type and softmax_offset parameters through entire attention stack (MultiHeadAttention → DotProductAttention → fused/unfused implementations)
  • For LEARNABLE_SOFTMAX, creates a learnable parameter with shape (1, num_heads, 1, 1) and proper sharding (see the sketch after this list)
  • Updates C++ bindings to pass softmax_type and softmax_offset to cuDNN backend
  • Fixes a bug from the initial implementation: softmax_offset is now correctly set to 0.0 for OFF_BY_ONE_SOFTMAX (it was 1.0 in an earlier revision)
  • Adds comprehensive tests validating correctness against reference implementation
  • Properly handles gradients for learnable offset in backward pass
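
A hedged sketch of how such a per-head learnable sink parameter could be declared in a Flax module with head-axis sharding metadata; the module name, axis label "heads", and zero initializer are assumptions (the PR itself uses TE's HEAD_AXES and its own initialization):

```python
import flax.linen as nn
import jax.numpy as jnp

class LearnableSink(nn.Module):
    """Illustrative only: a per-head sink logit with logical sharding metadata."""
    num_heads: int

    @nn.compact
    def __call__(self, logits):  # logits: (b, h, s_q, s_kv)
        offset = self.param(
            "softmax_offset",
            nn.with_logical_partitioning(nn.initializers.zeros, (None, "heads", None, None)),
            (1, self.num_heads, 1, 1),
            jnp.float32,
        )
        x_max = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), offset)
        e = jnp.exp(logits - x_max)
        return e / (jnp.sum(e, axis=-1, keepdims=True) + jnp.exp(offset - x_max))
```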

Architecture:

  • Both fused (cuDNN) and unfused (JAX) attention paths fully support all three softmax types
  • Renames SoftmaxType to SoftmaxFusionType to avoid confusion with AttnSoftmaxType
  • Maintains backward compatibility with default softmax_type='vanilla'

Confidence Score: 5/5

  • This PR is safe to merge with high confidence
  • The implementation is mathematically correct, well-tested, and follows established patterns from the PyTorch implementation. The critical offset calculation bug identified in previous reviews has been fixed (offset=0.0 for OFF_BY_ONE). The code properly validates input shapes/dtypes, handles gradients correctly for learnable parameters, and includes comprehensive test coverage. All changes are additive with backward compatibility maintained.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/flax/module.py 5/5 Adds softmax_type parameter to Softmax module, correctly sets offset=0.0 for OFF_BY_ONE_SOFTMAX (fixes previous issue), properly validates LEARNABLE_SOFTMAX requires softmax_offset parameter
transformer_engine/jax/cpp_extensions/softmax.py 5/5 Implements jax_general_softmax with offset support, correctly computes max(x_max, offset) and adds exp(offset - x_max) to denominator, properly handles vanilla softmax fallback when offset is not needed
transformer_engine/jax/attention.py 5/5 Adds AttnSoftmaxType enum and threads softmax_offset through fused attention API, properly handles gradients for LEARNABLE_SOFTMAX, updates custom VJP implementation correctly
transformer_engine/jax/cpp_extensions/attention.py 5/5 Updates FusedAttnHelper and primitives to support softmax_type and softmax_offset, validates shapes and dtypes correctly, properly handles gradient computation for learnable offset
transformer_engine/jax/flax/transformer.py 5/5 Adds softmax_type parameter to attention modules, creates learnable softmax_offset parameter with proper sharding, correctly threads parameter through both fused and unfused attention paths
transformer_engine/jax/csrc/extensions/attention.cpp 5/5 Updates C++ bindings to accept softmax_type and softmax_offset parameters, properly adds offset to auxiliary tensor packs for forward and backward passes

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant MHA as MultiHeadAttention
    participant DPA as DotProductAttention
    participant Fused as _FusedDotProductAttention
    participant Unfused as _UnfusedDotProductAttention
    participant Softmax as Softmax Module
    participant FusedAttn as fused_attn
    participant Primitive as FusedAttnFwdPrimitive
    participant CPP as C++ Extension
    participant cuDNN as cuDNN Backend

    User->>MHA: forward(q, k, v, softmax_type)
    MHA->>MHA: Initialize softmax_offset param<br/>if softmax_type == LEARNABLE
    MHA->>DPA: forward(q, k, v, softmax_type)
    
    alt Fused Attention Path
        DPA->>Fused: forward(q, k, v)
        Fused->>Fused: Initialize softmax_offset<br/>if LEARNABLE_SOFTMAX
        Fused->>FusedAttn: fused_attn(qkv, bias, softmax_offset,<br/>softmax_type=...)
        FusedAttn->>Primitive: bind(qkv, bias, softmax_offset, ...)
        Primitive->>CPP: fused_attn_fwd_ffi(softmax_type, softmax_offset)
        CPP->>CPP: PrepareFusedAttnForwardAuxTensors<br/>(adds softmax_offset to tensor pack)
        CPP->>cuDNN: nvte_fused_attn_fwd(..., softmax_offset, ...)
        cuDNN-->>CPP: output, softmax_aux
        CPP-->>Primitive: output, softmax_aux, rng_state
        Primitive-->>FusedAttn: output
        FusedAttn-->>Fused: output
        Fused-->>DPA: output
    else Unfused Attention Path
        DPA->>Unfused: forward(q, k, v)
        Unfused->>Unfused: Initialize softmax_offset<br/>if LEARNABLE_SOFTMAX
        Unfused->>Unfused: Compute logits = scale * Q @ K^T
        Unfused->>Softmax: forward(logits, softmax_offset,<br/>softmax_type=...)
        
        alt OFF_BY_ONE_SOFTMAX
            Softmax->>Softmax: Set offset = 0.0
            Softmax->>Softmax: jax_general_softmax(logits, offset=0.0)
            Note over Softmax: Adds exp(0 - x_max) to denominator<br/>resulting in +1 offset
        else LEARNABLE_SOFTMAX
            Softmax->>Softmax: Use learnable offset param
            Softmax->>Softmax: jax_general_softmax(logits, offset=alpha)
            Note over Softmax: Adds exp(alpha - x_max) to denominator
        else VANILLA_SOFTMAX
            Softmax->>Softmax: Standard jax.nn.softmax
        end
        
        Softmax-->>Unfused: attn_weights
        Unfused->>Unfused: output = attn_weights @ V
        Unfused-->>DPA: output
    end
    
    DPA-->>MHA: output
    MHA-->>User: output
    
    Note over User,cuDNN: Backward pass mirrors forward,<br/>computing gradients for softmax_offset<br/>when softmax_type == LEARNABLE_SOFTMAX

16 files reviewed, no comments


@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR successfully implements sink attention support for JAX, following the PyTorch implementation from PR #2148. The changes add three softmax types: VANILLA_SOFTMAX (standard), OFF_BY_ONE_SOFTMAX (adds +1 to denominator), and LEARNABLE_SOFTMAX (learnable sink parameter).

Key changes:

  • Introduced AttnSoftmaxType enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE)
  • Separated concerns between SoftmaxFusionType (kernel fusion strategy) and AttnSoftmaxType (mathematical softmax variant)
  • Added softmax_offset parameter throughout the attention stack (module → attention API → primitives)
  • For OFF_BY_ONE_SOFTMAX: softmax_offset=0.0 (logit value, not denominator contribution)
  • For LEARNABLE_SOFTMAX: learnable parameter with shape (1, num_heads, 1, 1) and proper sharding via HEAD_AXES
  • Context parallel explicitly doesn't support sink attention (forced to VANILLA_SOFTMAX)
  • Comprehensive test coverage added across all attention backends

Implementation correctness:
The softmax_offset=0.0 for OFF_BY_ONE_SOFTMAX is mathematically correct. The offset represents the logit value (not the denominator contribution). In jax_general_softmax, this computes exp(0 - x_max) which after normalization contributes +1 to the denominator, matching the reference implementation in tests that appends jnp.zeros to logits.
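
Stated symbolically, with alpha the sink logit and m the maximum over the row logits and alpha:

$$
\mathrm{softmax}^{\mathrm{sink}}_i(x;\alpha)
  = \frac{e^{x_i - m}}{e^{\alpha - m} + \sum_j e^{x_j - m}}
  = \frac{e^{x_i}}{e^{\alpha} + \sum_j e^{x_j}},
\qquad
\alpha = 0 \implies \text{denominator} = 1 + \sum_j e^{x_j}.
$$

So setting the offset logit to 0.0 yields exactly the off-by-one (+1) denominator, and appending a zero column to the logits is the same computation.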

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - the implementation is mathematically sound, well-tested, and follows established patterns from the PyTorch implementation.
  • Score of 5 reflects: (1) Comprehensive test coverage with 36+ new test cases across multiple backends and configurations, (2) Clean separation of SoftmaxFusionType vs AttnSoftmaxType with no breaking changes, (3) Correct mathematical implementation verified against reference code, (4) Proper handling of edge cases (context parallel limitation documented), (5) Consistent API design following PyTorch precedent, and (6) All validation checks for shapes/dtypes in place.
  • No files require special attention. The previous comments flagging softmax_offset=0.0 as incorrect were erroneous - the implementation is mathematically correct.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/flax/module.py 5/5 Added AttnSoftmaxType support to Softmax module with proper handling of softmax_offset parameter. Implementation correctly sets softmax_offset=0.0 for OFF_BY_ONE_SOFTMAX (previous comments claiming this was wrong are incorrect).
transformer_engine/jax/attention.py 5/5 Added AttnSoftmaxType enum and integrated softmax_offset parameter throughout fused attention API. Includes proper from_str conversion method and documentation.
transformer_engine/jax/cpp_extensions/attention.py 5/5 Extended fused attention primitives to support softmax_type and softmax_offset in both forward and backward passes. Proper validation of offset shapes and dtypes added.
tests/jax/test_distributed_softmax.py 5/5 Updated test suite to distinguish between SoftmaxFusionType (SCALED/SCALED_MASKED/SCALED_UPPER_TRIANG_MASKED) and AttnSoftmaxType (VANILLA/OFF_BY_ONE/LEARNABLE). Clean separation of concerns.

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant UnfusedDPA as _UnfusedDotProductAttention
    participant FusedDPA as _FusedDotProductAttention
    participant Softmax
    participant FusedAttn as fused_attn
    participant Primitives as FusedAttnFwd/BwdPrimitive

    User->>DotProductAttention: __call__(query, key, value, ...)
    DotProductAttention->>DotProductAttention: Convert softmax_type string to AttnSoftmaxType enum
    DotProductAttention->>DotProductAttention: Check if fused attention available
    
    alt Fused Attention Available
        DotProductAttention->>FusedDPA: __call__(query, key, value, ...)
        FusedDPA->>FusedDPA: Initialize softmax_offset param if LEARNABLE_SOFTMAX
        FusedDPA->>FusedAttn: fused_attn(qkv, bias, softmax_offset, ...)
        FusedAttn->>Primitives: Forward/Backward with softmax_type & offset
        Primitives-->>FusedAttn: Attention output
        FusedAttn-->>FusedDPA: Output
        FusedDPA-->>DotProductAttention: Output
    else Unfused Attention
        DotProductAttention->>UnfusedDPA: __call__(query, key, value, ...)
        UnfusedDPA->>UnfusedDPA: Initialize softmax_offset param if LEARNABLE_SOFTMAX
        UnfusedDPA->>UnfusedDPA: Compute Q @ K^T (attention logits)
        UnfusedDPA->>UnfusedDPA: Convert attn_mask_type to SoftmaxFusionType
        UnfusedDPA->>Softmax: __call__(logits, mask, bias, softmax_offset)
        Softmax->>Softmax: Set softmax_offset=0.0 if OFF_BY_ONE_SOFTMAX
        Softmax->>Softmax: Apply softmax with offset (jax_general_softmax)
        Softmax-->>UnfusedDPA: Attention weights
        UnfusedDPA->>UnfusedDPA: Compute attn_weights @ V
        UnfusedDPA-->>DotProductAttention: Output
    end
    
    DotProductAttention-->>User: Attention output

5 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@greptile-apps bot (Contributor) commented Nov 14, 2025

Greptile Overview

Greptile Summary

This PR adds sink attention support to JAX, implementing three softmax variants:

  • VANILLA_SOFTMAX: Standard softmax (baseline)
  • OFF_BY_ONE_SOFTMAX: Adds +1 to denominator by treating offset logit as 0.0
  • LEARNABLE_SOFTMAX: Adds learnable parameter to denominator, one per attention head

The implementation threads softmax_type and softmax_offset through the entire attention pipeline from high-level Flax modules down to C++ cuDNN bindings. Key architectural decisions:

  • Renamed SoftmaxType → SoftmaxFusionType to distinguish fusion strategy from attention softmax variant
  • For OFF_BY_ONE, sets softmax_offset = 0.0 (mathematically equivalent to appending a zero logit)
  • For LEARNABLE, creates a learnable parameter with shape [1, num_heads, 1, 1] with proper sharding annotations
  • Implements jax_general_softmax for unfused path: includes offset in max computation and adds exp(offset - x_max) to denominator
  • Gradients for learnable offset are accumulated across data parallel/FSDP shards in backward pass
  • Comprehensive test coverage with reference implementations validates correctness

The implementation is mathematically sound and follows established patterns from PyTorch sink attention support (PR #2148).

Confidence Score: 5/5

  • Safe to merge - well-tested feature addition with comprehensive coverage
  • Implementation is mathematically correct, follows established patterns from PyTorch version, includes comprehensive test coverage across all softmax types, properly handles distributed training with gradient accumulation, and passes all existing tests
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 5/5 Added AttnSoftmaxType enum and softmax_offset parameter threading throughout fused attention pipeline. Clean implementation with proper gradient handling.
transformer_engine/jax/flax/transformer.py 5/5 Added softmax_type parameter to attention modules, creates learnable softmax_offset parameter for LEARNABLE_SOFTMAX, sets offset=0.0 for OFF_BY_ONE_SOFTMAX.
transformer_engine/jax/flax/module.py 5/5 Updated Softmax module to distinguish softmax_fusion_type from softmax_type, correctly handles OFF_BY_ONE (offset=0.0) and LEARNABLE variants.
transformer_engine/jax/cpp_extensions/softmax.py 5/5 Added jax_general_softmax function with offset support. Mathematically correct: includes offset in max computation and adds exp(offset-x_max) to denominator.
transformer_engine/jax/cpp_extensions/attention.py 5/5 Added softmax_type to config and primitives, threads softmax_offset through forward/backward passes with proper sharding and gradient accumulation.
transformer_engine/jax/csrc/extensions/attention.cpp 5/5 Updated C++ bindings to accept softmax_type parameter and softmax_offset tensor, properly packs them into aux tensors for cuDNN backend.
tests/jax/test_fused_attn.py 5/5 Comprehensive test coverage for all three softmax types with reference implementation. Tests verify correctness of OFF_BY_ONE (appends zero) and LEARNABLE (appends learned value).

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant MHA as MultiHeadAttention
    participant DPA as DotProductAttention
    participant Softmax as Softmax Module
    participant FusedAttn as FusedAttn (cuDNN)
    participant JAXSoftmax as jax_general_softmax
    
    User->>MHA: call with softmax_type
    Note over MHA: softmax_type: 'vanilla'/'off_by_one'/'learnable'
    
    alt Learnable Softmax
        MHA->>MHA: Create learnable param<br/>softmax_offset [1,h,1,1]
    end
    
    MHA->>DPA: Forward with softmax_type
    
    alt Unfused Path
        DPA->>Softmax: logits, softmax_type, offset
        Note over Softmax: OFF_BY_ONE: offset=0.0<br/>LEARNABLE: offset=learned param
        Softmax->>JAXSoftmax: jax_general_softmax(logits, offset)
        Note over JAXSoftmax: x_max = max(x_max, offset)<br/>denom += exp(offset - x_max)
        JAXSoftmax-->>Softmax: softmax output
        Softmax-->>DPA: attention weights
    else Fused Path (cuDNN)
        DPA->>FusedAttn: Q, K, V, softmax_type, offset
        Note over FusedAttn: Passes softmax_type and offset<br/>to cuDNN backend
        FusedAttn->>FusedAttn: cuDNN fused kernel<br/>with sink attention
        FusedAttn-->>DPA: attention output
    end
    
    DPA-->>MHA: attention output
    
    Note over User,JAXSoftmax: Backward Pass
    
    alt Learnable Softmax
        MHA->>MHA: Accumulate grad_softmax_offset<br/>across DP/FSDP shards
        MHA->>MHA: Update learnable param
    end

@greptile-apps bot (Contributor) left a comment

16 files reviewed, no comments


@pggPL (Collaborator, Author) commented Nov 14, 2025

/te-ci jax

@greptile-apps bot (Contributor) left a comment

16 files reviewed, no comments


@pggPL (Collaborator, Author) commented Nov 14, 2025

/te-ci jax L1

@phu0ngng (Collaborator) left a comment

LGTM. Thanks

@pggPL merged commit 15cefbc into NVIDIA:main on Nov 18, 2025
23 of 25 checks passed

Successfully merging this pull request may close these issues: Add attention sink to flash attention

5 participants