Open
36 commits
def57db
Add generic stripe_height support for load balancing
KshitijLakhani Oct 15, 2025
e888025
Fix imports in test for deprecated jax.experimental.pjit
KshitijLakhani Oct 15, 2025
1fc957a
Add test case for stripe_height greater than 1. Add stripe_height arg…
Oct 16, 2025
6ade6a4
Add Striped 1 and 4 test cases. Refactor the Load Balancing test case…
Oct 16, 2025
e8067c6
Modify test code for CP + AG + THD + stripe height greater than 1
Oct 23, 2025
a33301a
Add stripe_height arg to fused attn and fused attn fwd API. Add appro…
Oct 23, 2025
4438298
TMP: Throwaway testing commit
Oct 23, 2025
3da94e5
Add comments in primitive registration process
Nov 5, 2025
a51b7d9
TMP: Throwaway test commit
KshitijLakhani Nov 13, 2025
fcee4f4
Undoing incorrect rebase/merge leftovers
Nov 21, 2025
5dfdaf5
TMP: Throwaway test commits
Nov 21, 2025
f6fb305
Add support for calculating q and kv seqlens and offsets per rank for…
Nov 21, 2025
104b51e
Augment jax primitive register code comments
Nov 21, 2025
2494084
Fix the array sizes and padding values returned for seqlens and offse…
KshitijLakhani Nov 23, 2025
c6e5966
Add support in new primitive for softmax_offset related changes. Put …
Nov 24, 2025
298ee6b
Add new set of helper functions for seqlens and seqoffsets fo AG+THD+…
KshitijLakhani Nov 24, 2025
1fa57b4
Add backward primitive for CP+THD+AG+Striped>1
KshitijLakhani Nov 25, 2025
7f205d5
Modify tests for backward primitive for CP+THD+AG+Striped>1
KshitijLakhani Nov 25, 2025
3115064
Move stripe_height along with other static args in fused_attn_bwd rul…
Nov 25, 2025
0391f41
Code clean up: remove older version for calculating seqlens and offse…
pre-commit-ci[bot] Nov 25, 2025
94af413
Add test for CP+THD+AG+Striped>1
KshitijLakhani Nov 25, 2025
a788a1d
Fix missing var
KshitijLakhani Nov 25, 2025
60191eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2025
d29f59a
Add SWA tests for AG+Striped>1+CP+THD+SWA
Nov 25, 2025
b01340a
Restoring test code
Nov 25, 2025
c5921af
Remove assert preventing SWA code path in CP+AG+Striped primitive
Nov 25, 2025
3bb1d5a
Parametrize num_segments_per_seq in tests
KshitijLakhani Nov 25, 2025
78fec5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2025
69dad1a
Clean up test code
KshitijLakhani Nov 25, 2025
2166305
Rename stripe_height to stripe_size
KshitijLakhani Nov 25, 2025
c5e0d6f
Code clean up and add additional comments
KshitijLakhani Nov 25, 2025
ab81a30
nit: Apply suggestions from code review
KshitijLakhani Nov 25, 2025
51440db
Fix seqoffsets length to be passed onto FusedAttn primitive as it is …
Nov 26, 2025
5e014af
Remove commented code
KshitijLakhani Nov 26, 2025
9b5280b
Fix linting issues
KshitijLakhani Nov 26, 2025
8841f5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2025
119 changes: 98 additions & 21 deletions tests/jax/test_distributed_fused_attn.py
@@ -327,7 +327,7 @@ def test_cross_attn(
]

Collaborator

Thanks for measuring the test runtime! What is our total L1 test time currently? We were intermittently failing previously but reduced attention tests by ~15 mins iirc. This is now increasing by 11 mins, so I'm concerned we may hit timeouts again.

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_context_parallel_allgather_striped_attn                 | 128x |  719.65s | avg:   5.62s
================================================================================
TOTAL RUNTIME                                                |      |  719.65s |
================================================================================

Collaborator Author

I think that once this PR is merged we will be around the 80-85 min mark for the L1 tests, so I do not expect to hit the timeout (which is closer to 120 mins).

I believe after your changes to reduce the L1 test timing we had stopped hitting the limit anyways. My reduction to attention tests was more of an additional effort and we reduced total time to closer to 70 mins if I remember right.

Now with sink attention and this PR, we should have a total increase of ~15 mins, so I expect to hit ~85 mins. Nonetheless, I will report the findings from the last CI pipeline I run.

I do not think we are alarmingly close but a clean up in the future would only help.
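
As a quick sanity check on the numbers in this thread, here is a back-of-the-envelope sketch that uses only the figures quoted above. The ~70 min baseline, ~15 min increase, and ~120 min timeout are the approximate values from the comments, not measurements.

```python
# Figures quoted in the review thread; baseline, increase, and limit are approximate.
striped_test_seconds = 719.65   # TOTAL RUNTIME from the summary table
striped_test_count = 128        # number of parametrized cases in the table

avg_seconds = striped_test_seconds / striped_test_count
striped_minutes = striped_test_seconds / 60

baseline_minutes = 70           # approximate L1 total before sink attention and this PR
added_minutes = 15              # approximate combined increase quoted above
timeout_minutes = 120           # approximate L1 timeout quoted above

projected_minutes = baseline_minutes + added_minutes
headroom_minutes = timeout_minutes - projected_minutes

print(f"avg per case: {avg_seconds:.2f}s, striped tests: {striped_minutes:.1f} min")
print(f"projected L1 total: {projected_minutes} min, headroom: {headroom_minutes} min")
```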

DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
Collaborator

nit: update pytest param id to include CPx2

]
@@ -351,12 +351,14 @@ def impl_test_context_parallel_attn(
use_shardy,
use_scan_ring=False,
window_size=None,
stripe_size=0,
num_segments_per_seq=0,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
if not load_balanced and (
cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
Collaborator

today I learned {var=} does var={var} in f-strings. useful!
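
For reference, a minimal sketch of the `=` specifier in f-strings (Python 3.8+). The `CPStrategy` enum below is a hypothetical stand-in for the one used in the tests. Note that the value is rendered with `repr()` unless a conversion is given.

```python
from enum import Enum


class CPStrategy(Enum):  # hypothetical stand-in for the test suite's enum
    ALL_GATHER = 1
    RING = 2


cp_strategy = CPStrategy.ALL_GATHER

# `{expr=}` expands to the expression text, an equals sign, and repr(value).
message = f"THD + {cp_strategy=} doesn't support unbalanced context parallelism."
print(message)

# A conversion can follow the `=`, e.g. `!s` to use str() instead of repr().
short = f"{cp_strategy.name=!s}"
print(short)
```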


assert not use_scan_ring or cp_strategy == CPStrategy.RING

@@ -382,7 +384,6 @@ def impl_test_context_parallel_attn(
data_shape = batch, seqlen, num_head, hidden

num_kv_heads = num_head // kv_groups

runner = FusedAttnRunner(
batch,
seqlen,
@@ -407,6 +408,8 @@ def impl_test_context_parallel_attn(
mesh_resource=mesh_resource,
cp_strategy=cp_strategy,
cp_load_balanced=load_balanced,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
)

def check_has_backend_for_mask(mask_type):
@@ -457,7 +460,7 @@ def check_has_backend_for_mask(mask_type):
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Collaborator

Why are we skipping a mask here?

Collaborator

Is this because the new test below is explicitly testing the index=-1 mask case? If so, can we remove this last mask from the DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and create a new list of masks for striped below?

Collaborator Author

I will revert this, as it makes no functional difference.
The THD data layout (which we are trying to explicitly skip with this change) is already being skipped as part of the tests, so the slicing isn't strictly needed.

Collaborator Author

However, the logic that skips THD layouts in impl_test_context_parallel_attn() has changed, so simply reverting will not be enough to skip them. I will add an explicit check in test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() to skip THD layouts.

Collaborator Author

In either case it is just a different way to put things - the older change was skipping THD layouts by filtering them via DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], but now we will pass DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and instead filter explicitly within test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn().
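
To illustrate the difference between the two approaches in this thread, here is a minimal sketch. The list entries and the `is_thd` helper are hypothetical placeholders; the real list holds layout and mask enum pairs that are not shown in this diff.

```python
# Hypothetical placeholder entries standing in for (qkv_layout, attn_mask_type) pairs.
LAYOUTS_MASKS = [
    ("bshd", "causal"),
    ("bshd", "no_mask"),
    ("thd", "padding_causal"),
]


def is_thd(layout):
    """Hypothetical stand-in for qkv_layout.is_thd()."""
    return layout == "thd"


# Option A: positional slicing, which silently depends on THD being the last entry.
by_slice = LAYOUTS_MASKS[:-1]

# Option B: filter on the property itself, independent of list order.
by_filter = [(layout, mask) for layout, mask in LAYOUTS_MASKS if not is_thd(layout)]

# Both agree for the list above, but only Option B survives an appended entry.
extended = LAYOUTS_MASKS + [("bshd", "padding")]
extended_by_slice = extended[:-1]
extended_by_filter = [(l, m) for l, m in extended if not is_thd(l)]
```

In a pytest test, the equivalent of Option B is an early `pytest.skip(...)` guarded by the layout check, which is what the author proposes above.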

)
def test_context_parallel_allgather_attn_shardy(
self,
@@ -486,6 +489,72 @@ def test_context_parallel_allgather_attn_shardy(
use_shardy=True,
)

@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
[DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED")],
)
@pytest.mark.parametrize(
"stripe_size",
[pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
Collaborator

Doesn't have to be in this PR since I'm guessing this applies to more tests than updated here, but TE/JAX has a pytest util called pytest_parametrize_wrapper that automatically converts common types into string representations like stripe_64 so you don't need to list these manually

)
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((5, 0), id="window_size(5, 0)"),
],
)
@pytest.mark.parametrize(
"num_segments_per_seq",
[pytest.param(2, id="SEG-2"), pytest.param(11, id="SEG-11")],
)
def test_context_parallel_allgather_striped_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
window_size,
stripe_size,
num_segments_per_seq,
):
if window_size != (-1, -1) and not qkv_layout.is_thd():
Collaborator

So BSHD with window_size = (-1, -1) is supported? Does that mean window_size (-1, -1) means don't do any striping?

Collaborator Author

I will update this check to: if window_size != (-1, -1) and not qkv_layout.is_thd():
Any BSHD layouts should be skipped.
Only THD, with or without SWA, should be allowed.
Thanks for pointing it out.

I believe my parametrization of the inputs, [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], did not include a BSHD layout anyway, so this check never got triggered.

pytest.skip("Sliding window attention is only supported for THD layout")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
window_size=window_size,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
)
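
The skip condition in `test_context_parallel_allgather_striped_attn` can be tabulated with a short sketch. `should_skip` is a hypothetical helper that mirrors the condition as written in the diff, with `is_thd` standing in for `qkv_layout.is_thd()`. As written, the condition only skips non-THD layouts that use a sliding window; a non-THD layout with `window_size == (-1, -1)` would not be skipped by this check alone.

```python
from itertools import product


def should_skip(window_size, is_thd):
    """Mirror of the condition as written in the test (hypothetical helper)."""
    return window_size != (-1, -1) and not is_thd


for window_size, is_thd in product([(-1, -1), (5, 0)], [True, False]):
    layout = "THD" if is_thd else "non-THD"
    action = "skip" if should_skip(window_size, is_thd) else "run"
    print(f"{layout:8s} window_size={window_size}: {action}")
```

Per the author's comment, the non-THD rows do not arise in this test because the parametrization passes only the last layout/mask entry.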

@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
@@ -495,7 +564,7 @@ def test_context_parallel_allgather_attn_shardy(
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
Collaborator

Same comment as above with this [:-1] slicing of the masks. If this array of masks [:-1] is for one feature and [-1:] is for this new CP striped attention feature, let's split them into two constant lists.

Collaborator Author

I will revert this, as it makes no functional difference.
The THD data layout (which we are trying to explicitly skip with this change) is already being skipped as part of the tests, so the slicing isn't strictly needed.

Collaborator Author

https://github.com/NVIDIA/TransformerEngine/pull/2379/files#r2566226953
All the comments in the earlier thread apply here

)
@pytest.mark.parametrize(
"load_balanced",
@@ -538,7 +607,7 @@ def test_context_parallel_allgather_attn(
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
)
@pytest.mark.parametrize(
"load_balanced",
@@ -602,7 +671,7 @@ def test_context_parallel_ring_attn(
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
)
def test_context_parallel_ring_attn_shardy(
self,
@@ -639,31 +708,39 @@ def test_context_parallel_ring_attn_shardy(
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}

REORDER_STRATEGY = [
pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
]


class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
"reorder_strategy, stripe_size",
REORDER_STRATEGY,
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0

if reorder_strategy == ReorderStrategy.Striped:
seq_lens = shape[seq_dim]
if seq_lens < (cp_size * stripe_size):
pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}")

ref = tensor.copy()

reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])

reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)

assert jnp.array_equal(inversed, ref)
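
The round-trip property checked by `TestReorderCausalLoadBalancing` can be sketched in plain Python. This is not the Transformer Engine implementation (`tex.attention.reorder_causal_striped` is not shown in this diff). The function names are hypothetical, and the sketch assumes that rank `r` receives stripes `r`, `r + cp_size`, `r + 2 * cp_size`, and so on, each `stripe_size` tokens long.

```python
def striped_reorder(tokens, cp_size, stripe_size):
    """Group stripes so that rank r holds stripes r, r + cp_size, r + 2 * cp_size, ..."""
    if len(tokens) % (cp_size * stripe_size) != 0:
        raise ValueError("sequence length must be a multiple of cp_size * stripe_size")
    stripes = [tokens[i:i + stripe_size] for i in range(0, len(tokens), stripe_size)]
    out = []
    for rank in range(cp_size):
        for stripe in stripes[rank::cp_size]:
            out.extend(stripe)
    return out


def inverse_striped_reorder(tokens, cp_size, stripe_size):
    """Undo striped_reorder by interleaving each rank's stripes back in order."""
    stripes = [tokens[i:i + stripe_size] for i in range(0, len(tokens), stripe_size)]
    per_rank = len(stripes) // cp_size
    out = []
    for j in range(per_rank):
        for rank in range(cp_size):
            out.extend(stripes[rank * per_rank + j])
    return out


tokens = list(range(16))
reordered = striped_reorder(tokens, cp_size=2, stripe_size=4)
restored = inverse_striped_reorder(reordered, cp_size=2, stripe_size=4)
print(reordered)
print(restored == tokens)
```

With `stripe_size=1` this reduces to the familiar token-level striping, which is why the test parametrizes both `Striped-1` and `Striped-4`.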
13 changes: 11 additions & 2 deletions tests/jax/test_fused_attn.py
@@ -230,6 +230,7 @@ def make_mask(
@jax.jit
def get_seqlens_and_offsets(segment_ids):
batch, max_seqlen = segment_ids.shape
# TODO: should this be max_seqlen + 1 ?
Collaborator

reminder about TODO

Collaborator Author

I am removing this for now as it is not related to this PR.
It was more of a general note to myself - I'll take a look at this outside the scope of this PR.

bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
seqlens = seqlens_with_zero[..., 1:]
@@ -352,6 +353,8 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
stripe_size: int = 0
num_segments_per_seq: int = 0

# Specifies sharding resources for distributed tests
number_of_devices: int = 1
@@ -577,7 +580,9 @@ def generate_random_segment_ids(
return segment_ids, segment_pos, segment_pad

if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
# If using default num segments of 0, set to 2
if self.num_segments_per_seq == 0:
Collaborator

Why is the default overridden here? Can we make the default None instead of 0 to indicate it's not yet populated?

Collaborator Author (@KshitijLakhani, Nov 26, 2025)

Yes, I will make this change. It was in my TODO but I missed it.
Better that way.

self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
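
A minimal sketch of the `None`-as-unset default suggested in the review thread on `num_segments_per_seq`. `RunnerConfig` and `resolve_num_segments` are hypothetical, trimmed-down stand-ins and not part of `FusedAttnRunner`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunnerConfig:  # hypothetical stand-in for the relevant FusedAttnRunner fields
    is_thd: bool
    num_segments_per_seq: Optional[int] = None  # None means "not set by the caller"

    def resolve_num_segments(self) -> Optional[int]:
        """Apply the THD default of 2 only when the caller left the field unset."""
        if self.is_thd and self.num_segments_per_seq is None:
            self.num_segments_per_seq = 2
        return self.num_segments_per_seq


print(RunnerConfig(is_thd=True).resolve_num_segments())
print(RunnerConfig(is_thd=True, num_segments_per_seq=11).resolve_num_segments())
```

Using `None` rather than `0` keeps "unset" distinct from any integer a caller might legitimately pass.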
@@ -635,12 +640,14 @@ def generate_random_segment_ids(
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
else:
# no-ops for non cp or non load balanced
@@ -771,7 +778,7 @@ def to_dp_shardings(x):

def test_forward(self):
"""
Test forward without JIT
Test forward with JITted primitive and unJITted reference
Collaborator

Is there a reason the reference JAX impl is unjitted? It should be equivalent JIT'd and could speed up our tests. Ack, this is unrelated to this PR's focus, just mentioning it so we can discuss if this could be improved in a separate PR

Collaborator Author

Unsure tbh. I was expecting it to be JITted as well, but when I noticed it wasn't I thought it should at least be explicit in the docstring so we can address it later. Agree with you.

"""
self._setup_inputs()

@@ -801,6 +808,7 @@ def test_forward(self):
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}

customcall_fused_dpa_jit = jit(
@@ -896,6 +904,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}

# We can compute dBias only for the [1, h, s, s] layout
26 changes: 21 additions & 5 deletions transformer_engine/jax/attention.py
@@ -386,23 +386,27 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen


def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
def reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1
):
"""Reorders a tensor for load balancing the compute of causal attention."""
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
# stripe_size > 1 is only supported for CP+THD+AG+Striped
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_size)
raise ValueError(f"Unsupported {strategy=}")


def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1
):
"""Inverse operation of `reorder_causal_load_balancing`."""
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
# stripe_size > 1 is only supported for CP+THD+AG+Striped
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_size)
raise ValueError(f"Unsupported {strategy=}")


@@ -988,7 +992,7 @@ def fused_attn_thd(
return output


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
@@ -1008,6 +1012,7 @@ def _fused_attn(
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
context_checkpoint_name: str = "context",
stripe_size: int = 0,
):
output, _ = _fused_attn_fwd_rule(
qkv,
@@ -1028,6 +1033,7 @@ def _fused_attn(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output

@@ -1051,6 +1057,7 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
@@ -1070,6 +1077,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
@@ -1099,6 +1107,7 @@ def _fused_attn_bwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
ctx,
dz,
):
@@ -1133,6 +1142,7 @@ def _fused_attn_bwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
@@ -1169,6 +1179,7 @@ def fused_attn(
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
stripe_size: int = 0,
Collaborator

Should the default be 1 or 0?

):
"""
Perform cuDNN fused attention.
@@ -1206,6 +1217,10 @@ def fused_attn(
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
stripe_size (int):
Indicates the striping size to be used when using ReorderStrategy.Striped.
Currently, a stripe_size > 1 is only allowed for CP + THD + Striped + AG
0 indicates no striping strategy
Returns:
(jnp.ndarray): The output tensor from the fused attention.

@@ -1283,5 +1298,6 @@ def fused_attn(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output
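
The `_fused_attn` changes in this file add `stripe_size` as one more entry in `nondiff_argnums` and as one more leading argument of the backward rule. A toy sketch of that `jax.custom_vjp` calling convention is below. `scaled_square` and its rules are hypothetical and unrelated to the fused attention primitive; `scale` plays the role of a static argument such as `stripe_size`.

```python
from functools import partial

import jax


# Argument 1 (`scale`) is non-differentiable configuration.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x * x


def _fwd(x, scale):
    # Forward rule: same argument order as the primal; returns (output, residuals).
    return scale * x * x, x


def _bwd(scale, residual_x, g):
    # Backward rule: non-differentiable args first, then residuals, then the cotangent.
    # Returns one gradient per differentiable argument (here only x).
    return (2.0 * scale * residual_x * g,)


scaled_square.defvjp(_fwd, _bwd)

grad_at_3 = jax.grad(scaled_square)(3.0, 2.0)
print(grad_at_3)
```

This ordering is why `stripe_size` appears before `ctx` and `dz` in `_fused_attn_bwd_rule`, and why the index `18` has to be added to `nondiff_argnums` when the argument is appended to `_fused_attn`.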