[JAX] Add CP + THD + AG + Striped>1 + SWA support #2379
@@ -327,7 +327,7 @@ def test_cross_attn(
 ]

 DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
-    # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
+    # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
     pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
     pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),

Collaborator: nit: update pytest param id to include

 ]

@@ -351,12 +351,14 @@ def impl_test_context_parallel_attn(
         use_shardy,
         use_scan_ring=False,
         window_size=None,
+        stripe_size=0,
+        num_segments_per_seq=0,
     ):
         if qkv_layout.is_thd():
-            if cp_strategy == CPStrategy.ALL_GATHER:
-                pytest.skip("THD doesn't support all gather context parallelism.")
-            if not load_balanced and cp_strategy == CPStrategy.RING:
-                pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
+            if not load_balanced and (
+                cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
+            ):
+                pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")

Collaborator: today I learned

         assert not use_scan_ring or cp_strategy == CPStrategy.RING

@@ -382,7 +384,6 @@ def impl_test_context_parallel_attn(
         data_shape = batch, seqlen, num_head, hidden

         num_kv_heads = num_head // kv_groups

         runner = FusedAttnRunner(
             batch,
             seqlen,

@@ -407,6 +408,8 @@ def impl_test_context_parallel_attn(
             mesh_resource=mesh_resource,
             cp_strategy=cp_strategy,
             cp_load_balanced=load_balanced,
+            stripe_size=stripe_size,
+            num_segments_per_seq=num_segments_per_seq,
         )

         def check_has_backend_for_mask(mask_type):

@@ -457,7 +460,7 @@ def check_has_backend_for_mask(mask_type):
     @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
     @pytest.mark.parametrize(
         "qkv_layout, attn_mask_type",
-        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],

Collaborator: Why are we skipping a mask here?

Collaborator: Is this because the new test below is explicitly testing the index=-1 mask case? If so, can we remove this last mask from the

Collaborator (Author): I will revert this back as it makes no functional difference.

Collaborator (Author): However, the logic to skip this for the THD data layout in impl_test_context_parallel_attn() has changed, so just reverting the change will not be enough to skip this for THD data types. I will add an additional check in test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() to skip THD layouts.

Collaborator (Author): In either case it is just a different way to put things - the older change was skipping THD layouts by filtering them via DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], but now we will pass DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS and instead filter within test_context_parallel_allgather_attn_shardy() and test_context_parallel_allgather_attn() explicitly.

     )
     def test_context_parallel_allgather_attn_shardy(
         self,

@@ -486,6 +489,72 @@ def test_context_parallel_allgather_attn_shardy(
             use_shardy=True,
         )

+    @pytest_parametrize_wrapper(
+        "device_count,mesh_shape,mesh_axes,mesh_resource",
+        generate_context_parallel_configs_for_attn(),
+    )
+    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
+    @pytest.mark.parametrize("kv_groups", [1, 8])
+    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
+    @pytest.mark.parametrize(
+        "qkv_layout, attn_mask_type",
+        [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]],
+    )
+    @pytest.mark.parametrize(
+        "load_balanced",
+        [pytest.param(True, id="BALANCED")],
+    )
+    @pytest.mark.parametrize(
+        "stripe_size",
+        [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],

Collaborator: Doesn't have to be in this PR since I'm guessing this applies to more tests than updated here, but TE/JAX has a pytest util called

+    )
+    @pytest.mark.parametrize(
+        "window_size",
+        [
+            pytest.param((-1, -1), id="window_size(-1, -1)"),
+            pytest.param((5, 0), id="window_size(5, 0)"),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "num_segments_per_seq",
+        [pytest.param(2, id="SEG-2"), pytest.param(11, id="SEG-11")],
+    )
+    def test_context_parallel_allgather_striped_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+        window_size,
+        stripe_size,
+        num_segments_per_seq,
+    ):
+        if window_size != (-1, -1) and not qkv_layout.is_thd():

Collaborator: So BSHD with window_size = (-1, -1) is supported? Does that mean window_size (-1, -1) means don't do any striping?

Collaborator (Author): I will update this check to: I believe my parametrization of the inputs, [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], did not trigger a BSHD mask anyways, so this check never got triggered.

+            pytest.skip("Sliding window attention is only supported for THD layout")
+        self.impl_test_context_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.ALL_GATHER,
+            use_shardy=False,
+            window_size=window_size,
+            stripe_size=stripe_size,
+            num_segments_per_seq=num_segments_per_seq,
+        )

     @pytest_parametrize_wrapper(
         "device_count,mesh_shape,mesh_axes,mesh_resource",
         generate_context_parallel_configs_for_attn(),

@@ -495,7 +564,7 @@ def test_context_parallel_allgather_attn_shardy(
     @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
     @pytest.mark.parametrize(
         "qkv_layout, attn_mask_type",
-        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],

Collaborator: Same comment as above with this [:-1] slicing of the masks. If this array of masks [:-1] is for one feature and [-1:] is for this new CP striped attention feature, let's split them into two constant lists.

Collaborator (Author): I will revert this back as it makes no functional difference.

Collaborator (Author): https://github.com/NVIDIA/TransformerEngine/pull/2379/files#r2566226953

     )
     @pytest.mark.parametrize(
         "load_balanced",

@@ -538,7 +607,7 @@ def test_context_parallel_allgather_attn(
     @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
     @pytest.mark.parametrize(
         "qkv_layout, attn_mask_type",
-        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
     )
     @pytest.mark.parametrize(
         "load_balanced",

@@ -602,7 +671,7 @@ def test_context_parallel_ring_attn(
     @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
     @pytest.mark.parametrize(
         "qkv_layout, attn_mask_type",
-        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1],
     )
     def test_context_parallel_ring_attn_shardy(
         self,

@@ -639,31 +708,39 @@ def test_context_parallel_ring_attn_shardy(
     "L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
 }

+REORDER_STRATEGY = [
+    pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
+    pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
+    pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
+]


 class TestReorderCausalLoadBalancing:
     @pytest.mark.parametrize("cp_size", [2, 4, 8])
     @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
-    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
+    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
     @pytest.mark.parametrize(
-        "reorder_strategy",
-        [
-            pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
-            pytest.param(ReorderStrategy.Striped, id="Striped"),
-        ],
+        "reorder_strategy, stripe_size",
+        REORDER_STRATEGY,
     )
-    def test(self, cp_size, shape, qkv_format, reorder_strategy):
+    def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
         tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
         seq_dim = 1
         if qkv_format == QKVFormat.SBHD:
             tensor = tensor.swapaxes(0, 1)
             seq_dim = 0

+        if reorder_strategy == ReorderStrategy.Striped:
+            seq_lens = shape[seq_dim]
+            if seq_lens < (cp_size * stripe_size):
+                pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}")

         ref = tensor.copy()

-        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
-        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
+        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
+        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])

-        reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
-        inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
+        reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
+        inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)

         assert jnp.array_equal(inversed, ref)
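
To make the stripe_size semantics concrete, here is a small self-contained sketch. It is not the TE/JAX implementation, only an illustration of one plausible striped assignment: the sequence is cut into contiguous stripes of stripe_size tokens, the stripes are dealt round-robin to the cp_size ranks, and the inverse permutation restores the original order, which is the round-trip property the test above asserts.

import jax.numpy as jnp

def striped_reorder(x, cp_size, seq_dim, stripe_size):
    # Hypothetical illustration: group the sequence into stripes of
    # `stripe_size` tokens, then order stripes by owning rank so that
    # rank r ends up with stripes r, r + cp_size, r + 2*cp_size, ...
    seqlen = x.shape[seq_dim]
    num_stripes = seqlen // stripe_size
    stripe_ids = jnp.arange(num_stripes)
    perm = jnp.concatenate([stripe_ids[r::cp_size] for r in range(cp_size)])
    token_perm = (perm[:, None] * stripe_size + jnp.arange(stripe_size)).reshape(-1)
    return jnp.take(x, token_perm, axis=seq_dim)

def striped_inverse(x, cp_size, seq_dim, stripe_size):
    # Invert by scattering tokens back to their original positions.
    seqlen = x.shape[seq_dim]
    num_stripes = seqlen // stripe_size
    stripe_ids = jnp.arange(num_stripes)
    perm = jnp.concatenate([stripe_ids[r::cp_size] for r in range(cp_size)])
    token_perm = (perm[:, None] * stripe_size + jnp.arange(stripe_size)).reshape(-1)
    return jnp.take(x, jnp.argsort(token_perm), axis=seq_dim)

x = jnp.arange(16)  # 16 "tokens"
y = striped_reorder(x, cp_size=2, seq_dim=0, stripe_size=4)
# y == [0 1 2 3 8 9 10 11 4 5 6 7 12 13 14 15]: rank 0 gets stripes 0 and 2,
# rank 1 gets stripes 1 and 3, so each rank sees a balanced mix of early and
# late positions under a causal mask.
assert jnp.array_equal(striped_inverse(y, 2, 0, 4), x)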
@@ -230,6 +230,7 @@ def make_mask(
 @jax.jit
 def get_seqlens_and_offsets(segment_ids):
     batch, max_seqlen = segment_ids.shape
+    # TODO: should this be max_seqlen + 1 ?

Collaborator: reminder about TODO

Collaborator (Author): I am removing this for now as it is not related to this PR.

     bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
     seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
     seqlens = seqlens_with_zero[..., 1:]
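
As context for the TODO, here is a minimal standalone example of how the vmapped bincount turns segment IDs into per-segment lengths; the values are made up for illustration.

import jax
import jax.numpy as jnp
from functools import partial

# Segment IDs per batch row: 0 marks padding, 1..N label the segments.
segment_ids = jnp.array([[1, 1, 1, 2, 2, 0, 0, 0]])
batch, max_seqlen = segment_ids.shape

# bincount with length=max_seqlen counts IDs 0..max_seqlen-1 per row.
bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
seqlens = seqlens_with_zero[..., 1:]  # drop the padding (ID 0) count
print(seqlens)  # [[3 2 0 0 0 0 0]] -> segment 1 has 3 tokens, segment 2 has 2
# The TODO asks whether length should be max_seqlen + 1: with length=max_seqlen,
# a count for a segment ID equal to max_seqlen itself would be dropped, which
# only matters if that many distinct segments can actually appear in one row.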
@@ -352,6 +353,8 @@ class FusedAttnRunner:
     bias_shape: BiasShape
     window_size: Tuple[int, int]
     seq_desc_format: SeqDescFormat
+    stripe_size: int = 0
+    num_segments_per_seq: int = 0

     # Specifies sharding resources for distributed tests
     number_of_devices: int = 1

@@ -577,7 +580,9 @@ def generate_random_segment_ids(
             return segment_ids, segment_pos, segment_pad

         if self.qkv_layout.is_thd():
-            self.num_segments_per_seq = 2
+            # If using default num segments of 0, set to 2
+            if self.num_segments_per_seq == 0:

Collaborator: Why is the default overridden here? Can we make the default

Collaborator (Author): Yes, I will make this change. It was in my TODO but I missed it.

+                self.num_segments_per_seq = 2
             self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                 self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
             )
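
A minimal sketch of the suggested follow-up, assuming the field becomes Optional rather than overloading 0 as a sentinel; the class name here is a hypothetical stand-in for the test runner, not the actual change.

from dataclasses import dataclass
from typing import Optional

@dataclass
class FusedAttnRunnerSketch:  # hypothetical stand-in for FusedAttnRunner
    num_segments_per_seq: Optional[int] = None  # None = "use the THD default"

    def _setup_segments(self, is_thd: bool) -> None:
        if is_thd and self.num_segments_per_seq is None:
            # Only fill in the default when the caller did not request a
            # specific segment count.
            self.num_segments_per_seq = 2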
@@ -635,12 +640,14 @@ def generate_random_segment_ids(
                 strategy=reorder_strategy,
                 cp_size=self.cp_size,
                 seq_dim=seq_dim,
+                stripe_size=self.stripe_size,
             )
             self.cp_inverse_reorder_fn = partial(
                 inverse_reorder_causal_load_balancing,
                 strategy=reorder_strategy,
                 cp_size=self.cp_size,
                 seq_dim=seq_dim,
+                stripe_size=self.stripe_size,
             )
         else:
             # no-ops for non cp or non load balanced

@@ -771,7 +778,7 @@ def to_dp_shardings(x):

     def test_forward(self):
         """
-        Test forward without JIT
+        Test forward with JITted primitive and unJITted reference

Collaborator: Is there a reason the reference JAX impl is unjitted? It should be equivalent JIT'd and could speed up our tests. Ack, this is unrelated to this PR's focus, just mentioning it so we can discuss if this could be improved in a separate PR.

Collaborator (Author): Unsure, to be honest. I was expecting it to be jitted as well, but when I noticed it wasn't, I thought it should at least be explicit in the docstring to address later. Agree with you.

         """
         self._setup_inputs()

@@ -801,6 +808,7 @@ def test_forward(self):
             "window_size": self.window_size,
             "context_parallel_strategy": self.cp_strategy,
             "context_parallel_causal_load_balanced": self.cp_load_balanced,
+            "stripe_size": self.stripe_size,
         }

         customcall_fused_dpa_jit = jit(

@@ -896,6 +904,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
             "window_size": self.window_size,
             "context_parallel_strategy": self.cp_strategy,
             "context_parallel_causal_load_balanced": self.cp_load_balanced,
+            "stripe_size": self.stripe_size,
         }

         # We can compute dBias only for the [1, h, s, s] layout

@@ -386,23 +386,27 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
     return batch, q_max_seqlen, kv_max_seqlen


-def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
+def reorder_causal_load_balancing(
+    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1
+):
     """Reorders a tensor for load balancing the compute of causal attention."""
     if strategy == ReorderStrategy.DualChunkSwap:
         return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
     if strategy == ReorderStrategy.Striped:
-        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
+        # stripe_size > 1 is only supported for CP+THD+AG+Striped
+        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_size)
     raise ValueError(f"Unsupported {strategy=}")


 def inverse_reorder_causal_load_balancing(
-    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
+    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1
 ):
     """Inverse operation of `reorder_causal_load_balancing`."""
     if strategy == ReorderStrategy.DualChunkSwap:
         return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
     if strategy == ReorderStrategy.Striped:
-        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
+        # stripe_size > 1 is only supported for CP+THD+AG+Striped
+        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_size)
     raise ValueError(f"Unsupported {strategy=}")
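
For reference, a round trip through the updated public helpers might look like the following; the import path is assumed from the repo layout and the shapes are illustrative only.

import jax.numpy as jnp
# Assumed import path; the names themselves appear in the diff above.
from transformer_engine.jax.attention import (
    ReorderStrategy,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
)

x = jnp.zeros((2, 512, 8, 64))  # [batch, seqlen, heads, dim], so seq_dim=1
y = reorder_causal_load_balancing(
    x, ReorderStrategy.Striped, cp_size=4, seq_dim=1, stripe_size=64
)
z = inverse_reorder_causal_load_balancing(
    y, ReorderStrategy.Striped, cp_size=4, seq_dim=1, stripe_size=64
)
assert jnp.array_equal(z, x)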
@@ -988,7 +992,7 @@ def fused_attn_thd(
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
 def _fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
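
As context for the nondiff_argnums change (17 to 18 entries): with jax.custom_vjp, arguments listed in nondiff_argnums are treated as static and non-differentiable and are handed to the backward rule first, which is why the new stripe_size parameter at position 18 is appended to the tuple and threaded into _fused_attn_bwd_rule ahead of ctx and dz. Here is a small self-contained illustration of that pattern, not taken from the PR; scale_power is a toy stand-in for stripe_size.

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled(x, scale_power):
    # scale_power is a static integer, analogous to stripe_size.
    return x * (2.0 ** scale_power)

def scaled_fwd(x, scale_power):
    return scaled(x, scale_power), None  # no residuals needed

def scaled_bwd(scale_power, residuals, dz):
    # nondiff args arrive first in the bwd rule, just like stripe_size does.
    return (dz * (2.0 ** scale_power),)

scaled.defvjp(scaled_fwd, scaled_bwd)
print(jax.grad(scaled)(3.0, 2))  # 4.0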
@@ -1008,6 +1012,7 @@ def _fused_attn(
     context_parallel_causal_load_balanced: bool,
     context_parallel_axis: str,
     context_checkpoint_name: str = "context",
+    stripe_size: int = 0,
 ):
     output, _ = _fused_attn_fwd_rule(
         qkv,

@@ -1028,6 +1033,7 @@ def _fused_attn(
         context_parallel_causal_load_balanced,
         context_parallel_axis,
         context_checkpoint_name=context_checkpoint_name,
+        stripe_size=stripe_size,
     )
     return output

@@ -1051,6 +1057,7 @@ def _fused_attn_fwd_rule(
     context_parallel_causal_load_balanced,
     context_parallel_axis,
     context_checkpoint_name,
+    stripe_size,
 ):
     output, softmax_aux, rng_state = tex.fused_attn_fwd(
         qkv,

@@ -1070,6 +1077,7 @@ def _fused_attn_fwd_rule(
         context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
+        stripe_size=stripe_size,
     )
     output = checkpoint_name(output, context_checkpoint_name)
     softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)

@@ -1099,6 +1107,7 @@ def _fused_attn_bwd_rule(
     context_parallel_causal_load_balanced,
     context_parallel_axis,
     context_checkpoint_name,
+    stripe_size,
     ctx,
     dz,
 ):

@@ -1133,6 +1142,7 @@ def _fused_attn_bwd_rule(
         context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
+        stripe_size=stripe_size,
     )
     if attn_bias_type == AttnBiasType.NO_BIAS:
         grad_bias = None

@@ -1169,6 +1179,7 @@ def fused_attn(
     context_parallel_axis: str = "",
     context_checkpoint_name: str = "context",
     softmax_offset: Optional[jnp.ndarray] = None,
+    stripe_size: int = 0,

Collaborator: Should the default be 1 or 0?

 ):
     """
     Perform cuDNN fused attention.

@@ -1206,6 +1217,10 @@ def fused_attn(
         softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
             [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
             If provided, this parameter will receive gradients during backpropagation.
+        stripe_size (int):
+            Indicates the striping size to be used when using ReorderStrategy.Striped.
+            Currently, a stripe_size > 1 is only allowed for CP + THD + Striped + AG.
+            0 indicates no striping strategy.
     Returns:
         (jnp.ndarray): The output tensor from the fused attention.

@@ -1283,5 +1298,6 @@ def fused_attn(
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
         context_checkpoint_name=context_checkpoint_name,
+        stripe_size=stripe_size,
     )
     return output
Collaborator: Thanks for measuring the test runtime! What is our total L1 test time currently? We were intermittently failing previously but reduced attention tests by ~15 mins, iirc. This is now increasing by 11 mins, so I'm concerned we may hit timeouts again.

Collaborator (Author): I think with the merging of this PR we will be around the 80-85 min mark for the L1 tests, so I do not expect to hit timeouts (the limit is closer to 120 mins).

I believe after your changes to reduce the L1 test timing we had stopped hitting the limit anyway. My reduction of the attention tests was more of an additional effort, and we brought the total time closer to 70 mins if I remember right.

Now with sink attention and this PR, we should have a total increase of about ~15 mins, so I expect to hit ~85 mins. Nonetheless, I will report the findings from the last CI pipeline I run.

I do not think we are alarmingly close, but a clean-up in the future would only help.