[JAX] Add CP + THD + AG + Striped>1 + SWA support #2379
base: main
Conversation
Force-pushed from 547bf11 to 8af3492
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… to reordering methods Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…. Fix the incorrect shape in striping inverse reordering Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…priate mask checks for AG + THD + CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD + AG + Striping. Add a method to reorder and all-gather segment ids and offsets for kv Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Force-pushed from 0799446 to b54fb3a
Greptile Overview
Greptile Summary
Adds support for striped load balancing with stripe_size > 1 for context-parallel attention.
Key Changes:
Testing:
The implementation enables more flexible load balancing strategies for distributed attention computation, particularly beneficial for handling variable-length sequences with multiple segments per sequence.
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant fused_attn
participant FusedAttnCPStripedFwd
participant Helper as _FusedAttnCPWithAllGatherHelper
participant FusedAttnFwd as FusedAttnFwdPrimitive
participant cuDNN
User->>fused_attn: Call with stripe_size>1, THD layout, CP enabled
fused_attn->>FusedAttnCPStripedFwd: Forward pass with sharded Q, K, V, segment_ids, segment_pos
FusedAttnCPStripedFwd->>Helper: check_supported()
Helper-->>FusedAttnCPStripedFwd: Validate THD+Striped config
FusedAttnCPStripedFwd->>Helper: all_gather_kv(k, v)
Helper->>Helper: lax.all_gather on CP axis
Helper->>Helper: reorder_causal_striped(x, cp_size, seq_dim, True, stripe_size)
Helper-->>FusedAttnCPStripedFwd: k_ag, v_ag
FusedAttnCPStripedFwd->>Helper: all_gather_segment_ids_and_pos()
Helper->>Helper: lax.all_gather + reorder_causal_striped
Helper-->>FusedAttnCPStripedFwd: kv_segment_ids_ag, kv_segment_pos_ag
FusedAttnCPStripedFwd->>Helper: q_seqlens_for_striped_for_rank()
Helper->>Helper: Extract seqlens from sharded q_segment_ids/pos
Helper-->>FusedAttnCPStripedFwd: q_seqlens_for_rank
FusedAttnCPStripedFwd->>Helper: q_seqoffsets_for_striped_for_rank()
Helper-->>FusedAttnCPStripedFwd: q_seq_offsets_for_rank
FusedAttnCPStripedFwd->>Helper: kv_seqlens_for_striped_for_rank()
Helper-->>FusedAttnCPStripedFwd: kv_seqlens_for_rank
FusedAttnCPStripedFwd->>Helper: kv_seqoffsets_for_striped_for_rank()
Helper-->>FusedAttnCPStripedFwd: kv_seq_offsets_for_rank
FusedAttnCPStripedFwd->>FusedAttnFwd: impl(q, k_ag, v_ag, seqlens, offsets)
FusedAttnFwd->>cuDNN: Execute fused attention kernel
cuDNN-->>FusedAttnFwd: output, softmax_aux, rng_state
FusedAttnFwd-->>FusedAttnCPStripedFwd: Attention output
FusedAttnCPStripedFwd-->>fused_attn: output
fused_attn-->>User: Final output
Note over User,cuDNN: Backward pass follows similar pattern with<br/>reduce_scatter for gradients
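To make the all-gather + reorder step in the diagram concrete, here is a minimal sketch. The helper name and arguments follow the diagram, not the actual TE code; `reorder_fn` stands in for reorder_causal_striped, and the function must run inside shard_map/pmap with the CP mesh axis named `cp_axis_name`.

```python
from jax import lax

def all_gather_kv(k, v, reorder_fn, cp_axis_name, seq_dim):
    """Sketch: gather the KV shards from every CP rank along the sequence
    dimension, then restore the original (unstriped) token order."""
    def gather(x):
        # Concatenate the per-rank shards along the sequence dimension.
        x_ag = lax.all_gather(x, cp_axis_name, axis=seq_dim, tiled=True)
        # Undo the striped load-balancing permutation.
        return reorder_fn(x_ag)
    return gather(k), gather(v)
```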
5 files reviewed, 3 comments
5 files reviewed, 1 comment
/te-ci jax L0 L1 L2
5 files reviewed, no comments
Pipeline: 39181941 passes all tests (except lint test - to be fixed in a separate commit)
5 files reviewed, 1 comment
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Clean up test code in TE common Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Clean up debug statements Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Fix type on fused attn tests Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
…b and not b+1 needed by cuDNN Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Force-pushed from 578aaa5 to 51440db
5 files reviewed, 1 comment
first_is_segment,  # First valid element starts a segment
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
(valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
logic: Duplicate concatenation condition causing shape mismatch. Line 1547 repeats the same condition as line 1546, which will result in segment_changes having incorrect dimensions (31 elements instead of expected 16). This will cause runtime errors when the array is used downstream.
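For reference, a minimal self-contained sketch of the intended computation (array names follow the diff; the wrapping function is hypothetical), with the boundary condition appearing exactly once so the concatenated mask matches the sequence length:

```python
import jax.numpy as jnp

def count_segments(valid_segment_ids, valid_segment_pos):
    """Sketch: count segments in a packed (THD) sequence of shape [..., L];
    assumes segment id 0 marks padding."""
    first_is_segment = valid_segment_ids[..., :1] != 0  # first valid element starts a segment
    boundary = (
        (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
        | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1)
    )
    # A single concatenation keeps segment_changes at length L (e.g. 16, not 31).
    segment_changes = jnp.concatenate([first_is_segment, boundary], axis=-1)
    return segment_changes.sum(axis=-1)
```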
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Fix incorrect greptile change Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed from 9fc5a74 to 9b5280b
for more information, see https://pre-commit.ci
/te-ci jax L0 L1 L2
5 files reviewed, no comments
Pipeline: 39192121 passes all L0, L1 and L2 tests
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
nit: update pytest param id to include CPx2
if not load_balanced and (
    cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
    pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
today I learned {var=} does var={var} in f-strings. useful!
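For anyone else seeing it for the first time: `{var=}` is the Python 3.8+ f-string debug specifier; it expands to the expression text plus its repr.

```python
cp_strategy = "RING"
print(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
# prints: THD + cp_strategy='RING' doesn't support unbalanced context parallelism.
```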
@pytest.mark.parametrize(
    "qkv_layout, attn_mask_type",
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Why are we skipping a mask here?
Is this because the new test below is explicitly testing the index=-1 mask case? If so, can we remove this last mask from the DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and create a new list of masks for striped below?
I will revert this back as it makes no functional difference
The THD data layout (which we are trying to explicitly skip with this change) is anyways being skipped as part of the tests, so it isn't fully needed
However, the logic to skip this for THD data layout in impl_test_context_parallel_attn() has changed and so just reverting the change will not be enough to skip this for THD data types, so I will add an additional check in test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() to skip for THD layouts
In either case it is just a different way to put things - the older change was skipping THD layouts by filtering them via DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], but now we will pass DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and instead filter within test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() explicitly.
)
@pytest.mark.parametrize(
    "stripe_size",
    [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
Doesn't have to be in this PR since I'm guessing this applies to more tests than updated here, but TE/JAX has a pytest util called pytest_parametrize_wrapper that automatically converts common types into string representations like stripe_64 so you don't need to list these manually
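(I don't have the wrapper's signature at hand, but plain pytest can produce similar ids with an `ids` callable; the test name below is just for illustration.)

```python
import pytest

@pytest.mark.parametrize("stripe_size", [64, 128], ids=lambda v: f"stripe_{v}")
def test_stripe_size_ids(stripe_size):
    # ids render as stripe_64 / stripe_128 without listing pytest.param entries manually
    assert stripe_size % 64 == 0
```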
    stripe_size,
    num_segments_per_seq,
):
    if window_size != (-1, -1) and not qkv_layout.is_thd():
So BSHD with window_size = (-1, -1) is supported? Does that mean window_size (-1, -1) means don't do any striping?
I will update this check to : if window_size != (-1, -1) and not qkv_layout.is_thd():
Any BSHD layouts should be skipped
Only THD with or without SWA should be allowed
Thanks for pointing it out
I believe my parametrization of the inputs, [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], did not trigger a BSHD mask anyway, so this check never got triggered.
@pytest.mark.parametrize(
    "qkv_layout, attn_mask_type",
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Same comment as above with this [:-1] slicing of the masks. If this array of masks [:-1] is for one feature and [-1:] is for this new CP striped attention feature, let's split them into two constant lists.
I will revert this back as it makes no functional difference
The THD data layout (which we are trying to explicitly skip with this change) is anyways being skipped as part of the tests, so it isn't fully needed
https://github.com/NVIDIA/TransformerEngine/pull/2379/files#r2566226953
All the comments in the earlier thread apply here
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    batch, max_seqlen = segment_ids.shape
    # TODO: should this be max_seqlen + 1 ?
reminder about TODO
I am removing this for now as it is not related to this PR.
Rather a general note for me - I'll take a look at this outside the scope of this PR
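For reference on that TODO, a hedged sketch of how seqlens/offsets can be derived from segment ids (not the code in this PR; here `length=max_seqlen + 1` reserves a bucket for the padding id 0):

```python
import jax
import jax.numpy as jnp

@jax.jit
def get_seqlens_and_offsets_sketch(segment_ids):
    """segment_ids: [batch, max_seqlen]; 0 marks padding, 1..N label segments."""
    _, max_seqlen = segment_ids.shape
    # Per-row histogram of segment ids; bucket 0 counts padding tokens.
    counts = jax.vmap(lambda ids: jnp.bincount(ids, length=max_seqlen + 1))(segment_ids)
    seqlens = counts[:, 1:]                           # tokens per segment
    offsets = jnp.cumsum(seqlens, axis=-1) - seqlens  # start offset of each segment
    return seqlens, offsets
```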
if self.qkv_layout.is_thd():
    self.num_segments_per_seq = 2
# If using default num segments of 0, set to 2
if self.num_segments_per_seq == 0:
Why is the default overridden here? Can we make the default None instead of 0 to indicate it's not yet populated?
Yes I will make this change. It was in my TODO but missed it.
Better that way.
def test_forward(self):
    """
    Test forward without JIT
    Test forward with JITted primitive and unJITted reference
Is there a reason the reference JAX impl is unjitted? It should be equivalent JIT'd and could speed up our tests. Ack, this is unrelated to this PR's focus, just mentioning it so we can discuss if this could be improved in a separate PR
Unsure tbh. I was expecting it to be JITted as well, but when I noticed it wasn't I thought it should at least be explicit in the docstring to address later. Agree with you.
        QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
    ),
]
Thanks for measuring the test runtime! What is our total L1 test time at currently? We were intermittently failing previously but reduced attention tests by ~15mins iirc. This is now increasing by 11mins, so I'm concerned we may hit timeouts again.
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_context_parallel_allgather_striped_attn | 128x | 719.65s | avg: 5.62s
================================================================================
TOTAL RUNTIME | | 719.65s |
================================================================================
I think with the merging of this PR we will be around the 80-85 mins mark for the L1 tests, so I do not expect to hit timeouts (closer to 120 mins)
I believe after your changes to reduce the L1 test timing we had stopped hitting the limit anyways. My reduction to attention tests was more of an additional effort and we reduced total time to closer to 70 mins if I remember right.
Now with sink attention and this PR, we should have a total increase of about ~15 mins, so I expect to hit ~85 mins. Nonetheless, I will report the findings from the last CI pipeline I run.
I do not think we are alarmingly close but a clean up in the future would only help.
)
# Do not allow CP + AG + THD + Striped with NO_MASK
if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd():
    raise ValueError(f"{header} only supports CAUSAL_MASK for THD types")
nit: We also support PADDING_CAUSAL_MASK for THD in addition to CAUSAL_MASK, based on lines 1301-1302, right?
We support PADDING_MASK as well, right?
This should change to:
if self.config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd():
    raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD types")
This function is in the AG helper class, which originally only supported BSHD + AG + Dual Chunk.
I believe the AG BSHD path supports allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK].
This PR only adds support for AttnMaskType.PADDING_CAUSAL_MASK as it is the larger case, but extending support to AttnMaskType.CAUSAL_MASK in the future should not be too hard.
To summarize, the CP + AG has two supported cases:
- BSHD + AG + Dual Chunk + (no mask + causal)
- THD + AG + Striped>1 + (padding causal) - this PR
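Putting the two supported paths together, a hedged sketch of the resulting mask validation (config/helper names are assumed for illustration, not the actual TE code):

```python
from transformer_engine.jax.attention import AttnMaskType  # import path assumed

def check_supported_masks(config, header="Context parallel fused attention"):
    if config.qkv_layout.is_thd():
        # THD + AG + Striped>1 path added in this PR: only PADDING_CAUSAL_MASK for now.
        if config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK:
            raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD layouts")
    else:
        # Pre-existing BSHD + AG + Dual Chunk path: NO_MASK and CAUSAL_MASK.
        if config.attn_mask_type not in (AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK):
            raise ValueError(f"{header} only supports NO_MASK or CAUSAL_MASK for BSHD layouts")
```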
Did you check that the CP path in the CP vs non-CP comparison in the unit tests really uses the cuDNN backend? Just trying to make sure it's not using the unfused path.
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
stripe_size: int = 0,
Should the default be 1 or 0?
Sure let me mention that in one of the comments
Yes, it does use the fused cuDNN backend for sure.
Description
Add support for Striped>1 Reordering in TE JAX
Add support for CP + THD + AG + Striped>1 + SWA in TE JAX
Below is an example of what the pattern deconstructs into after load balancing and striping; a minimal reordering sketch follows the per-rank breakdown.
Unbalanced matrix with 4 segments (64*64) [figure]
Balanced matrix with stripe_size=4, cp_size=4 (each rank works on a 16*64 pattern) and SW [figure]
CP0: THD with 5 segments using PBRCM and SW
CP1: THD with 5 segments using PBRCM and SW
CP2: THD with 4 segments using PBRCM and SW
CP3: THD with 5 segments using PBRCM and SW
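A minimal sketch of striped reordering along the sequence dimension (this follows the reorder_causal_striped call shown in the diagram, but is not the TE implementation; it assumes the sequence length is divisible by cp_size * stripe_size):

```python
import jax.numpy as jnp

def reorder_causal_striped_sketch(x, cp_size, seq_dim, inverse, stripe_size):
    """Group stripes of `stripe_size` tokens by owning CP rank (stripe i -> rank i % cp_size).
    With inverse=True the permutation is undone, restoring the original token order."""
    seqlen = x.shape[seq_dim]
    num_stripes = seqlen // stripe_size
    x = x.reshape(x.shape[:seq_dim] + (num_stripes, stripe_size) + x.shape[seq_dim + 1:])
    order = jnp.argsort(jnp.arange(num_stripes) % cp_size)  # stable sort -> rank-major stripe order
    if inverse:
        order = jnp.argsort(order)  # inverse permutation restores original stripe order
    x = jnp.take(x, order, axis=seq_dim)
    return x.reshape(x.shape[:seq_dim] + (seqlen,) + x.shape[seq_dim + 2:])
```

Sharding the rank-major output evenly along the sequence dimension then gives each CP rank its own set of stripes.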
Testing
Below are the timings for the minimal new test cases added:
NVTE_JAX_UNITTEST_LEVEL="L1" NVTE_JAX_TEST_TIMING=1 pytest -k "test_context_parallel_allgather_striped_attn" tests/jax/test_distributed_fused_attn.py
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_JAX_TEST_TIMING=1 pytest -k "test_context_parallel_allgather_striped_attn" tests/jax/test_distributed_fused_attn.py
NVTE_JAX_TEST_TIMING=1 NVTE_JAX_UNITTEST_LEVEL="L1" pytest -k "TestReorderCausalLoadBalancing" tests/jax/test_distributed_fused_attn.py
Changes
Type of change
TODO
Checklist: