
Conversation

Collaborator

@KshitijLakhani KshitijLakhani commented Nov 13, 2025

Description

Add support for Striped>1 Reordering in TE JAX
Add support for CP + THD + AG + Striped>1 + SWA in TE JAX

Below is an example of what the pattern would deconstruct into after load balancing and striping

Unbalanced matrix with 4 segments (64×64): [screenshot]

Balanced matrix with stripe_size=4, cp_size=4 (each rank works on a 16×64 pattern) and SW: [screenshot]

CP0: THD with 5 segments using PBRCM and SW
CP1: THD with 5 segments using PBRCM and SW
CP2: THD with 4 segments using PBRCM and SW
CP3: THD with 5 segments using PBRCM and SW
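
For intuition, below is a minimal sketch of what round-robin striped reordering with a configurable stripe_size could look like. striped_reorder is a hypothetical illustration, not the actual reorder_causal_striped in TE JAX, and the round-robin stripe-to-rank assignment is an assumption chosen to match the example above.

import numpy as np

def striped_reorder(x: np.ndarray, cp_size: int, stripe_size: int, seq_dim: int = 0) -> np.ndarray:
    """Reorder the sequence dim so that stripe i (of stripe_size tokens) is owned by rank i % cp_size."""
    seqlen = x.shape[seq_dim]
    assert seqlen % (cp_size * stripe_size) == 0, "seqlen must be divisible by cp_size * stripe_size"
    num_stripes = seqlen // stripe_size
    stripes = np.split(x, num_stripes, axis=seq_dim)
    # Rank r receives stripes r, r + cp_size, r + 2*cp_size, ...
    per_rank = [np.concatenate(stripes[r::cp_size], axis=seq_dim) for r in range(cp_size)]
    return np.concatenate(per_rank, axis=seq_dim)

# 64 tokens, cp_size=4, stripe_size=4 -> each rank owns 16 tokens after the reorder
tokens = np.arange(64)
print(striped_reorder(tokens, cp_size=4, stripe_size=4)[:16])
# [ 0  1  2  3 16 17 18 19 32 33 34 35 48 49 50 51]  (rank 0's shard)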

Testing

Below are the timings for the minimal new test cases added

  • Added 128 L1 tests (104 passed, 24 skipped) for CP + THD + AG + Striped>1 + SWA
    NVTE_JAX_UNITTEST_LEVEL="L1" NVTE_JAX_TEST_TIMING=1 pytest -k "test_context_parallel_allgather_striped_attn" tests/jax/test_distributed_fused_attn.py
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_context_parallel_allgather_striped_attn                 | 128x |  719.65s | avg:   5.62s
================================================================================
TOTAL RUNTIME                                                |      |  719.65s |
================================================================================
  • Added 64 L2 tests (48 passed, 16 skipped) for CP + THD + AG + Striped>1 + SWA
    NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_JAX_TEST_TIMING=1 pytest -k "test_context_parallel_allgather_striped_attn" tests/jax/test_distributed_fused_attn.py
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_context_parallel_allgather_striped_attn                 |  64x |  123.81s | avg:   1.93s
================================================================================
TOTAL RUNTIME                                                |      |  123.81s |
================================================================================
  • Added 12 L1 tests for Reordering with Striped>1
    NVTE_JAX_TEST_TIMING=1 NVTE_JAX_UNITTEST_LEVEL="L1" pytest -k "TestReorderCausalLoadBalancing" tests/jax/test_distributed_fused_attn.py
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  27x |    3.32s | avg:   0.12s
================================================================================
TOTAL RUNTIME                                                |      |    3.32s |
================================================================================
  • I ran an additional testing sweep with randomized num_segments_per_seq = [2, 3, 4, 9, 11, 14], randomized seed = [12, 30, 42], and max_seqlens = [64, 1024, 2048], taking the full cross-product of these three sets of params to ensure robustness, and I saw no failures on B200x8 (a sketch of this parametrization follows below)
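
A rough sketch of how that sweep can be parametrized (the test name and body are placeholders, not the actual test code in this PR):

import itertools
import pytest

NUM_SEGMENTS = [2, 3, 4, 9, 11, 14]
SEEDS = [12, 30, 42]
MAX_SEQLENS = [64, 1024, 2048]

@pytest.mark.parametrize(
    "num_segments_per_seq, seed, max_seqlen",
    list(itertools.product(NUM_SEGMENTS, SEEDS, MAX_SEQLENS)),
)
def test_striped_cp_robustness_sweep(num_segments_per_seq, seed, max_seqlen):
    # Placeholder body: the real sweep ran the CP+THD+AG+Striped>1+SWA attention
    # path with these inputs and compared the CP output against a non-CP reference.
    assert num_segments_per_seq > 0 and max_seqlen > 0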

Changes

  • Added load balancing/reordering logic for stripe_size > 1
  • For the CP+THD+AG+Striped>1+SWA primitive, the different deconstructed patterns per CP rank could not be fully expressed by passing sharded segment ids and positions to get_seqlens_and_offsets(), so helper functions were added to construct the seqlens and offsets directly from the sharded segment ids and segment positions (a sketch of this idea follows below)
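
For illustration, a minimal sketch of deriving seqlens/offsets from one rank's sharded segment ids and positions is below. The function name and the padding-id-0 convention are assumptions for this sketch; the actual helpers in this PR are q/kv_seqlens_for_striped_for_rank() and q/kv_seqoffsets_for_striped_for_rank().

import numpy as np

def seqlens_and_offsets_from_segments(seg_ids: np.ndarray, seg_pos: np.ndarray):
    """Derive segment lengths and start offsets for one rank's shard.

    seg_ids/seg_pos are 1D token-level arrays; id 0 is assumed to mark padding.
    A new segment starts wherever the id changes or the position does not
    continue as previous position + 1 (e.g. across a stripe boundary).
    """
    valid = seg_ids != 0
    same_seg = np.concatenate((
        [False],
        (seg_ids[1:] == seg_ids[:-1]) & (seg_pos[1:] == seg_pos[:-1] + 1),
    ))
    starts = valid & ~same_seg
    offsets = np.flatnonzero(starts)
    # A segment ends at a valid token whose successor does not continue it.
    next_same = np.concatenate((same_seg[1:], [False]))
    ends = np.flatnonzero(valid & ~next_same) + 1
    return ends - offsets, offsets

ids = np.array([1, 1, 1, 2, 2, 0, 0, 3])
pos = np.array([0, 1, 2, 0, 1, 0, 0, 0])
print(seqlens_and_offsets_from_segments(ids, pos))  # (array([3, 2, 1]), array([0, 3, 7]))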

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

TODO

  • Optimize and compare this primitive perf to CP+THD+RingP2P+Striped+SWA for different data layouts

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Nov 13, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feature/striped-height-cp-thd branch from 547bf11 to 8af3492 on November 21, 2025 06:33
KshitijLakhani and others added 6 commits November 23, 2025 01:22
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… to reordering methods

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…. Fix the incorrect shape in striping inverser reordering

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…priate mask checks for AG+THD+CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD +AG + Striping. Add a method to reorder and all gather segment ids and offsets for kv

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feature/striped-height-cp-thd branch 7 times, most recently from 0799446 to b54fb3a on November 25, 2025 20:20
@KshitijLakhani KshitijLakhani marked this pull request as ready for review November 25, 2025 23:35
@KshitijLakhani KshitijLakhani changed the title CP + THD + AG + Striped {JAX] Add CP + THD + AG + Striped>1 + SWA support Nov 25, 2025
@KshitijLakhani KshitijLakhani changed the title {JAX] Add CP + THD + AG + Striped>1 + SWA support [JAX] Add CP + THD + AG + Striped>1 + SWA support Nov 25, 2025
@greptile-apps
Contributor

greptile-apps bot commented Nov 25, 2025

Greptile Overview

Greptile Summary

Adds support for striped load balancing with stripe_size > 1 for Context Parallel (CP) + THD (packed variable-length) layout + AllGather (AG) + Sliding Window Attention (SWA) configurations in JAX.

Key Changes:

  • Extended reorder_causal_striped() to support configurable stripe_size parameter for fine-grained load balancing patterns
  • Implemented new primitives FusedAttnCPStripedWithAllGatherFwdPrimitive and FusedAttnCPStripedWithAllGatherBwdPrimitive for striped attention with CP
  • Added helper functions to compute seqlens/offsets from segment IDs and positions: q_seqlens_for_striped_for_rank(), q_seqoffsets_for_striped_for_rank(), kv_seqlens_for_striped_for_rank(), kv_seqoffsets_for_striped_for_rank()
  • Updated API to thread stripe_size parameter through fused_attn() and related functions
  • Added 192 new test cases covering various stripe sizes, window sizes, and segment configurations

Testing:

  • 128 L1 tests and 64 L2 tests for CP+THD+AG+Striped+SWA configurations
  • Additional tests for reordering with stripe_size variations (1, 4)
  • Tests validate correctness across different stripe sizes (64, 128), window sizes, and segment counts (2, 11)

The implementation enables more flexible load balancing strategies for distributed attention computation, particularly beneficial for handling variable-length sequences with multiple segments per sequence.

Confidence Score: 4/5

  • This PR is safe to merge with moderate risk - the implementation is well-tested with 192 new test cases, but the complex segment ID/position manipulation logic requires careful validation
  • Score reflects solid test coverage (192 new tests with passing results), clear implementation of the striped load balancing logic, and well-documented helper functions. However, the complexity of the segment manipulation functions (q_seqlens_for_striped_for_rank, etc.) and the intricate reshaping/swapping operations in reorder_causal_striped introduce moderate risk that could surface in edge cases not covered by tests
  • transformer_engine/jax/cpp_extensions/attention.py - particularly the helper functions computing seqlens/offsets from segment IDs (lines 1516-1726), as these involve complex array manipulation logic

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/cpp_extensions/attention.py 4/5 Added stripe_size > 1 support for CP+THD+AG with extensive helper functions for computing seqlens and offsets from segment IDs/positions. Implementation includes load balancing logic and new primitives FusedAttnCPStripedWithAllGatherFwdPrimitive/BwdPrimitive
transformer_engine/jax/attention.py 5/5 Added stripe_size parameter to fused_attn API and reorder functions to enable stripe_size > 1 for CP+THD+AG patterns
tests/jax/test_distributed_fused_attn.py 5/5 Added 192 new test cases for CP+THD+AG+Striped>1+SWA (128 L1, 64 L2) and extended reorder tests to support stripe_size parameter and THD format

Sequence Diagram

sequenceDiagram
    participant User
    participant fused_attn
    participant FusedAttnCPStripedFwd
    participant Helper as _FusedAttnCPWithAllGatherHelper
    participant FusedAttnFwd as FusedAttnFwdPrimitive
    participant cuDNN

    User->>fused_attn: Call with stripe_size>1, THD layout, CP enabled
    fused_attn->>FusedAttnCPStripedFwd: Forward pass with sharded Q, K, V, segment_ids, segment_pos
    
    FusedAttnCPStripedFwd->>Helper: check_supported()
    Helper-->>FusedAttnCPStripedFwd: Validate THD+Striped config
    
    FusedAttnCPStripedFwd->>Helper: all_gather_kv(k, v)
    Helper->>Helper: lax.all_gather on CP axis
    Helper->>Helper: reorder_causal_striped(x, cp_size, seq_dim, True, stripe_size)
    Helper-->>FusedAttnCPStripedFwd: k_ag, v_ag
    
    FusedAttnCPStripedFwd->>Helper: all_gather_segment_ids_and_pos()
    Helper->>Helper: lax.all_gather + reorder_causal_striped
    Helper-->>FusedAttnCPStripedFwd: kv_segment_ids_ag, kv_segment_pos_ag
    
    FusedAttnCPStripedFwd->>Helper: q_seqlens_for_striped_for_rank()
    Helper->>Helper: Extract seqlens from sharded q_segment_ids/pos
    Helper-->>FusedAttnCPStripedFwd: q_seqlens_for_rank
    
    FusedAttnCPStripedFwd->>Helper: q_seqoffsets_for_striped_for_rank()
    Helper-->>FusedAttnCPStripedFwd: q_seq_offsets_for_rank
    
    FusedAttnCPStripedFwd->>Helper: kv_seqlens_for_striped_for_rank()
    Helper-->>FusedAttnCPStripedFwd: kv_seqlens_for_rank
    
    FusedAttnCPStripedFwd->>Helper: kv_seqoffsets_for_striped_for_rank()
    Helper-->>FusedAttnCPStripedFwd: kv_seq_offsets_for_rank
    
    FusedAttnCPStripedFwd->>FusedAttnFwd: impl(q, k_ag, v_ag, seqlens, offsets)
    FusedAttnFwd->>cuDNN: Execute fused attention kernel
    cuDNN-->>FusedAttnFwd: output, softmax_aux, rng_state
    FusedAttnFwd-->>FusedAttnCPStripedFwd: Attention output
    
    FusedAttnCPStripedFwd-->>fused_attn: output
    fused_attn-->>User: Final output

    Note over User,cuDNN: Backward pass follows similar pattern with<br/>reduce_scatter for gradients

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 3 comments

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@KshitijLakhani
Collaborator Author

KshitijLakhani commented Nov 26, 2025

/te-ci jax L0 L1 L2

Pipeline: 39181941 passes all tests (except lint test - to be fixed in a separate commit)

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

KshitijLakhani and others added 12 commits November 26, 2025 06:44
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Clean up test code in TE common

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Clean up debug statements

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

Fix type on fused attn tests

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
…b and not b+1 needed by cuDNN

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feature/striped-height-cp-thd branch from 578aaa5 to 51440db on November 26, 2025 06:46
Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

first_is_segment, # First valid element starts a segment
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
(valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
Contributor

logic: Duplicate concatenation condition causing shape mismatch. Line 1547 repeats the same condition as line 1546, which will result in segment_changes having incorrect dimensions (31 elements instead of expected 16). This will cause runtime errors when the array is used downstream.

KshitijLakhani and others added 2 commits November 26, 2025 07:29
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Fix incorrect greptile change

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feature/striped-height-cp-thd branch from 9fc5a74 to 9b5280b on November 26, 2025 07:30
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

Pipeline: 39192121 passes all L0, L1 and L2 tests

# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
Collaborator

nit: update pytest param id to include CPx2

if not load_balanced and (
cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
Collaborator

today I learned {var=} does var={var} in f-strings. useful!
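
For reference, a tiny illustration of that specifier (available since Python 3.8):

cp_strategy = "ALL_GATHER"
load_balanced = False
print(f"{cp_strategy=}")              # cp_strategy='ALL_GATHER'
print(f"{load_balanced=}, {2 + 3=}")  # load_balanced=False, 2 + 3=5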

@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Collaborator

Why are we skipping a mask here?

Collaborator

Is this because the new test below is explicitly testing the index=-1 mask case? If so, can we remove this last mask from the DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and create a new list of masks for striped below?

Collaborator Author

I will revert this back as it makes no functional difference
The THD data layout (which we are trying to explicitly skip with this change) is anyways being skipped as part of the tests, so it isn't fully needed

Collaborator Author

However, the logic to skip this for THD data layout in impl_test_context_parallel_attn() has changed and so just reverting the change will not be enough to skip this for THD data types, so I will add an additional check in test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() to skip for THD layouts

Collaborator Author

In either case it is just a different way to put things - the older change was skipping THD layouts by filtering them via DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], but now we will pass DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and instead filter within test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() explicitly

)
@pytest.mark.parametrize(
"stripe_size",
[pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
Collaborator

Doesn't have to be in this PR since I'm guessing this applies to more tests than updated here, but TE/JAX has a pytest util called pytest_parametrize_wrapper that automatically converts common types into string representations like stripe_64 so you don't need to list these manually

stripe_size,
num_segments_per_seq,
):
if window_size != (-1, -1) and not qkv_layout.is_thd():
Collaborator

So BSHD with window_size = (-1, -1) is supported? Does that mean window_size (-1, -1) means don't do any striping?

Collaborator Author

I will update this check to : if window_size != (-1, -1) and not qkv_layout.is_thd():
Any BSHD layouts should be skipped
Only THD with or without SWA should be allowed
Thanks for pointing it out

I believe my parametrization of the inputs, [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], did not trigger a BSHD mask anyway, so this check never got triggered.

@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Collaborator

Same comment as above with this [:-1] slicing of the masks. If this array of masks [:-1] is for one feature and [-1:] is for this new CP striped attention feature, let's split them into two constant lists.

Collaborator Author

I will revert this back as it makes no functional difference
The THD data layout (which we are trying to explicitly skip with this change) is anyways being skipped as part of the tests, so it isn't fully needed

Collaborator Author

https://github.com/NVIDIA/TransformerEngine/pull/2379/files#r2566226953
All the comments in the earlier thread apply here

@jax.jit
def get_seqlens_and_offsets(segment_ids):
batch, max_seqlen = segment_ids.shape
# TODO: should this be max_seqlen + 1 ?
Collaborator

reminder about TODO

Collaborator Author

I am removing this for now as it is not related to this PR.
Rather a general note for me - I'll take a look at this outside the scope of this PR

if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
# If using default num segments of 0, set to 2
if self.num_segments_per_seq == 0:
Collaborator

Why is the default overridden here? Can we make the default None instead of 0 to indicate it's not yet populated?

Collaborator Author

@KshitijLakhani KshitijLakhani Nov 26, 2025

Yes I will make this change. It was in my TODO but missed it.
Better that way.

def test_forward(self):
"""
Test forward without JIT
Test forward with JITted primitive and unJITted reference
Collaborator

Is there a reason the reference JAX impl is unjitted? It should be equivalent JIT'd and could speed up our tests. Ack, this is unrelated to this PR's focus, just mentioning it so we can discuss if this could be improved in a separate PR

Collaborator Author

Unsure, tbh. I was expecting it to be JITted as well, but when I noticed it wasn't, I thought it should at least be explicit in the docstring to address later. Agree with you.

QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
),
]

Collaborator

Thanks for measuring the test runtime! What is our total L1 test time at currently? We were intermittently failing previously but reduced attention tests by ~15mins iirc. This is now increasing by 11mins, so I'm concerned we may hit timeouts again.

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_context_parallel_allgather_striped_attn                 | 128x |  719.65s | avg:   5.62s
================================================================================
TOTAL RUNTIME                                                |      |  719.65s |
================================================================================

Collaborator Author

I think with the merging of this PR we will be around the 80-85 mins mark for the L1 tests, so I do not expect to hit timeouts (closer to 120 mins)

I believe after your changes to reduce the L1 test timing we had stopped hitting the limit anyways. My reduction to attention tests was more of an additional effort and we reduced total time to closer to 70 mins if I remember right.

Now with sink attention and this PR, we should have a total increase of about ~15 mins, so I expect to hit ~85 mins. Nonetheless, I will report the findings from the last CI pipeline I run.

I do not think we are alarmingly close but a clean up in the future would only help.

)
# Do not allow CP + AG + THD + Striped with NO_MASK
if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd():
raise ValueError(f"{header} only supports CAUSAL_MASK for THD types")
Collaborator

nit: We also support PADDING_CAUSAL_MASK for THD in addition to CAUSAL_MASK, based on lines 1301-1302, right?

Collaborator

We support PADDING_MASK as well, right?

Collaborator Author

@KshitijLakhani KshitijLakhani Nov 26, 2025

This should change to :

if self.config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd():
raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD types")

This function is in the AG helper class which was originally only supporting BSHD + AG + Dual Chunk.
I believe the AG BSHD supports allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]

This PR only adds support for AttnMaskType.PADDING_CAUSAL_MASK as it would be the larger case but extending support to AttnMaskType.CAUSAL_MASK in the future should not be too hard

To summarize, the CP + AG has two supported cases:

  1. BSHD + AG + Dual Chunk + (no mask + causal)
  2. THD + AG + Striped>1 + (padding causal) - this PR

@cyanguwa
Collaborator

cyanguwa commented Nov 26, 2025

I think it's Striped>=1 that we support, right? We can probably mention that >=128 is recommended as well somewhere, for performance reasons.

Did you check the CP path in the CP vs non-CP comparison in the unit tests really uses the cuDNN backend? Just trying to make sure it's not using the unfused path.

context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
stripe_size: int = 0,
Collaborator

Should the default be 1 or 0?

@KshitijLakhani
Collaborator Author

I think it's Striped>=1 that we support, right? We can probably mention that >=128 is recommended as well somewhere, for performance reasons.

Sure let me mention that in one of the comments

Did you check the CP path in the CP vs non-CP comparison in the unit tests really uses the cuDNN backend? Just trying to make sure it's not using the unfused path.

Yes it does use the fused cudnn backend for sure.
The entire debugging process relied on checking the seqlens and offsets being passed to the fused cuDNN API in TE common.
