From def57db1b6aa651a6003c4333f9a95b6a515f9fe Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 15 Oct 2025 23:48:25 +0000 Subject: [PATCH 01/36] Add generic stripe_height support for load balancing Signed-off-by: Kshitij Lakhani --- .../jax/cpp_extensions/attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index f0778bfd29..d259771786 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1234,31 +1234,31 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig return combined.reshape(ori_tensor_shape) -def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool): +def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_height:int = 1): """Reorders a tensor for load balancing with striped pattern""" origin_shape = tensor.shape - if origin_shape[seq_dim] % cp_size != 0: + if origin_shape[seq_dim] % (cp_size*stripe_height) != 0: raise ValueError( - "Expected origin_shape[seq_dim] is multiple of cp_size but got" - f" {origin_shape[seq_dim]=} and {cp_size=}" + "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_height but got" + f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_height=}, {cp_size*stripe_height=}" ) if not is_inverse: new_shape = [ *origin_shape[:seq_dim], - *[origin_shape[seq_dim] // cp_size, cp_size], + *[origin_shape[seq_dim] // (cp_size*stripe_height), cp_size, stripe_height], *origin_shape[seq_dim + 1 :], ] else: new_shape = [ *origin_shape[:seq_dim], - *[cp_size, origin_shape[seq_dim] // cp_size], + *[stripe_height, cp_size, origin_shape[seq_dim] // (cp_size*stripe_height)], *origin_shape[seq_dim + 1 :], ] - chunked_tensor = tensor.reshape(new_shape) - reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1) - 
return reordered_chunked_tensor.reshape(origin_shape) + striped_tensor = tensor.reshape(new_shape) + reordered_striped_tensor = jnp.swapaxes(striped_tensor, seq_dim, seq_dim + 1) + return reordered_striped_tensor.reshape(origin_shape) @dataclass(frozen=True) From e88802584fe9be59a239d4b2b10078fedb865854 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 15 Oct 2025 21:32:00 +0000 Subject: [PATCH 02/36] Fix imports in test for deprecated jax.experimental.pjit Signed-off-by: Kshitij Lakhani --- tests/jax/distributed_test_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 137fa480dd..170e13e054 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -154,9 +154,13 @@ def compare_ops( grad_args = tuple(range(len(inputs))) target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) +<<<<<<< HEAD target_jitter = jax.jit( target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings ) +======= + target_jitter = jax.jit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) +>>>>>>> 57b57292 (Fix imports in test for deprecated jax.experimental.pjit) target_fwd, target_grads = target_jitter(*inputs, **kwargs) target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text() From 1fc957a68eb884074a5d722d018ff54568422677 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 16 Oct 2025 14:36:34 -0700 Subject: [PATCH 03/36] Add test case for stripe_height greater than 1. 
Add stripe_height arg to reordering methods Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_distributed_fused_attn.py | 17 ++++++++++++----- transformer_engine/jax/attention.py | 8 ++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5372018ae8..dfc960109a 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -651,7 +651,14 @@ class TestReorderCausalLoadBalancing: pytest.param(ReorderStrategy.Striped, id="Striped"), ], ) - def test(self, cp_size, shape, qkv_format, reorder_strategy): + @pytest.mark.parametrize( + "stripe_height", + [ + pytest.param(1, id="stripe-1"), + pytest.param(4, id="stripe-4"), + ], + ) + def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) seq_dim = 1 if qkv_format == QKVFormat.SBHD: @@ -660,10 +667,10 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy): ref = tensor.copy() - reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3]) - inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3]) + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) - reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim) - inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim) + reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_height) + inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_height) assert jnp.array_equal(inversed, ref) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0a32be9679..b37d2d09be 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -386,23 +386,23 @@ def 
_obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int): +def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1): """Reorders a tensor for load balancing the compute of causal attention.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False) + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_height) raise ValueError(f"Unsupported {strategy=}") def inverse_reorder_causal_load_balancing( - tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 ): """Inverse operation of `reorder_causal_load_balancing`.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True) + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_height) raise ValueError(f"Unsupported {strategy=}") From 6ade6a461abf05a41a4b46ecbc783c4c4bc2a455 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 16 Oct 2025 16:59:29 -0700 Subject: [PATCH 04/36] Add Striped 1 and 4 test cases. Refactor the Load Balancing test case. 
Fix the incorrect shape in striping inverser reordering Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_distributed_fused_attn.py | 25 ++++++++++--------- .../jax/cpp_extensions/attention.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index dfc960109a..63d1ca26d3 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -639,24 +639,19 @@ def test_context_parallel_ring_attn_shardy( "L2": [[4, 32, 12, 32], [1, 16, 1, 1]], } +REORDER_STRATEGY = [ + pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"), + pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"), + pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"), +] class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize( - "reorder_strategy", - [ - pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"), - pytest.param(ReorderStrategy.Striped, id="Striped"), - ], - ) - @pytest.mark.parametrize( - "stripe_height", - [ - pytest.param(1, id="stripe-1"), - pytest.param(4, id="stripe-4"), - ], + "reorder_strategy, stripe_height", + REORDER_STRATEGY, ) def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) @@ -665,6 +660,12 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): tensor = tensor.swapaxes(0, 1) seq_dim = 0 + if reorder_strategy == ReorderStrategy.Striped: + seq_lens = shape[seq_dim] + if seq_lens < (cp_size*stripe_height): + pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_height=}") + + ref = tensor.copy() reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 
4]) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index d259771786..c682d41533 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1252,7 +1252,7 @@ def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool, else: new_shape = [ *origin_shape[:seq_dim], - *[stripe_height, cp_size, origin_shape[seq_dim] // (cp_size*stripe_height)], + *[cp_size, origin_shape[seq_dim] // (cp_size*stripe_height), stripe_height], *origin_shape[seq_dim + 1 :], ] From e8067c66d9bcdaf9624377a50593986ad158d1f9 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 23 Oct 2025 12:31:47 -0700 Subject: [PATCH 05/36] Modify test code for CP + AG + THD + stripe height greater than 1 Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_distributed_fused_attn.py | 11 ++++++++--- tests/jax/test_fused_attn.py | 7 ++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 63d1ca26d3..10a1dd2db7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -327,7 +327,8 @@ def test_cross_attn( ] DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ - # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes. 
+ #TODO: Change the id to CPx2 pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ] @@ -353,8 +354,8 @@ def impl_test_context_parallel_attn( window_size=None, ): if qkv_layout.is_thd(): - if cp_strategy == CPStrategy.ALL_GATHER: - pytest.skip("THD doesn't support all gather context parallelism.") + # if cp_strategy == CPStrategy.ALL_GATHER: + # pytest.skip("THD doesn't support all gather context parallelism.") if not load_balanced and cp_strategy == CPStrategy.RING: pytest.skip("THD + ring doesn't support unbalanced context parallelism.") @@ -383,6 +384,9 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups + # KL code For AG case only + stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 + runner = FusedAttnRunner( batch, seqlen, @@ -407,6 +411,7 @@ def impl_test_context_parallel_attn( mesh_resource=mesh_resource, cp_strategy=cp_strategy, cp_load_balanced=load_balanced, + stripe_height=stripe_height ) def check_has_backend_for_mask(mask_type): diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f4caaef165..fafde75b33 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -352,6 +352,7 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Tuple[int, int] seq_desc_format: SeqDescFormat + stripe_height: int = 0 # Specifies sharding resources for distributed tests number_of_devices: int = 1 @@ -635,12 +636,14 @@ def generate_random_segment_ids( strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, + stripe_height=self.stripe_height, ) self.cp_inverse_reorder_fn = partial( inverse_reorder_causal_load_balancing, strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, + stripe_height=self.stripe_height, ) else: # no-ops for non cp or non load balanced @@ -771,7 +774,7 @@ def to_dp_shardings(x): def test_forward(self): """ - Test forward without JIT + Test forward with JITted 
primitive and unJITted reference """ self._setup_inputs() @@ -801,6 +804,7 @@ def test_forward(self): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, + "stripe_height": self.stripe_height, } customcall_fused_dpa_jit = jit( @@ -896,6 +900,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, + #"stripe_height": self.stripe_height, } # We can compute dBias only for the [1, h, s, s] layout From a33301aafe85dff765a53dcbf042be2567b8fa02 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 23 Oct 2025 12:40:30 -0700 Subject: [PATCH 06/36] Add stripe_height arg to fused attn and fused attn fwd API. Add appropriate mask checks for AG+THD+CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD +AG + Striping. Add a method to reorder and all gather segment ids and offsets for kv Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 10 ++ .../jax/cpp_extensions/attention.py | 163 ++++++++++++++++-- 2 files changed, 163 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index b37d2d09be..74f50bcae4 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1008,6 +1008,7 @@ def _fused_attn( context_parallel_causal_load_balanced: bool, context_parallel_axis: str, context_checkpoint_name: str = "context", + stripe_height: int = 0, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -1028,6 +1029,7 @@ def _fused_attn( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name=context_checkpoint_name, + stripe_height=stripe_height, ) return output @@ -1051,6 +1053,7 @@ def _fused_attn_fwd_rule( context_parallel_causal_load_balanced, 
context_parallel_axis, context_checkpoint_name, + stripe_height, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -1070,6 +1073,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + stripe_height=stripe_height ) output = checkpoint_name(output, context_checkpoint_name) softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) @@ -1169,6 +1173,7 @@ def fused_attn( context_parallel_axis: str = "", context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, + stripe_height: int = 0, ): """ Perform cuDNN fused attention. @@ -1206,6 +1211,10 @@ def fused_attn( softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX. If provided, this parameter will receive gradients during backpropagation. + stripe_height (int): + Indicates the striping height to be used when using ReorderStrategy.Striped. + Currently, a stripe_height > 1 is only allowed for CP + THD + Striped + AG + 0 indicates no striping strategy Returns: (jnp.ndarray): The output tensor from the fused attention. 
@@ -1283,5 +1292,6 @@ def fused_attn( context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, context_checkpoint_name=context_checkpoint_name, + stripe_height=stripe_height, ) return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c682d41533..8ddccde763 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -73,6 +73,7 @@ "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", + "stripe_height", ], ) @dataclass(frozen=True) @@ -93,6 +94,7 @@ class _FusedAttnConfig: context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA + stripe_height: int # Only for CP + Striped. For, Ring P2P , stripe_height=1 only. @dataclass(frozen=True) @@ -527,7 +529,7 @@ def impl( segment_ids=(_q_segment_ids, _kv_segment_ids), segment_pos=(_q_segment_pos, _kv_segment_pos), ) - + breakpoint() (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( sequence_descriptor.get_seqlens_and_offsets( config.attn_mask_type, @@ -536,6 +538,7 @@ def impl( config.max_segments_per_seq, ) ) + #jax.debug.breakpoint() if config.qkv_layout.is_thd(): @@ -1272,26 +1275,40 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused attention" - allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - + + if (not self.config.qkv_layout.is_thd() and self.config.stripe_height != 0) or (self.config.qkv_layout.is_thd() and 
self.config.stripe_height == 0): + raise ValueError( + f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped load balancing with THD layouts" + ) + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - + + #TODO: Should AttnMaskType.PADDING_CAUSAL_MASK be allowed for CP + AG + THD + Striped ? + #TODO: Should Should AttnMaskType.NO_MASK be allowed for CP + AG + THD + Striped ? allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] + if self.config.qkv_layout.is_thd(): + allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK) if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) + #TODO: For now do not all CP + AG + THD + Striped with NO_MASK + if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd(): + raise ValueError( + f"{header} only supports CAUSAL_MASK for THD types" + ) - if self.config.max_segments_per_seq != 1: + if self.config.max_segments_per_seq != 1 and (not self.config.qkv_layout.is_thd): raise ValueError( - f"{header} only supports max_segments_per_seq == 1 got:" + f"{header} only supports max_segments_per_seq == 1 for BSHD layouts, got:" f" {self.config.max_segments_per_seq}" ) @@ -1305,12 +1322,15 @@ def check_supported(self): def get_adjusted_mask(self): """Converts the mask for context parallelism.""" - if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: + if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK and not self.config.qkv_layout.is_thd(): # BSHD only ? return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK + if self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd(): # THD only ? 
+ return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + #TODO: Should the max_segments_per_seq be different ? return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), @@ -1324,6 +1344,7 @@ def get_step_config(self) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, + stripe_height=self.config.stripe_height, ) def all_gather_kv(self, k, v): @@ -1335,7 +1356,10 @@ def ag(x): ) if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) + if self.config.qkv_layout.is_thd(): + x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_height) + else: + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) return x if self.config.qkv_layout.is_kvpacked(): @@ -1344,6 +1368,26 @@ def ag(x): return ag(k), ag(v) return k, v # fall through + + def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): + """Performs a all-gather of k and v over context parallel ranks.""" + + #TODO: Is the axis chosen right ? 
+ kv_segment_ids = lax_paral_op( + kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + kv_segment_pos = lax_paral_op( + kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + #jax.debug.breakpoint() + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + if self.config.qkv_layout.is_thd(): + kv_segment_ids_ag = reorder_causal_striped(kv_segment_ids, cp_size, 1, True, self.config.stripe_height) + kv_segment_pos_ag = reorder_causal_striped(kv_segment_pos, cp_size, 1, True, self.config.stripe_height) + return kv_segment_ids_ag, kv_segment_pos_ag + #TODO: Is the dual chunk case needed ? + return kv_segment_ids, kv_segment_pos # fall through def reduce_scatter_dkv(self, dk, dv): """Performs a reduce-scatter of dk and dv over context parallel ranks.""" @@ -1454,6 +1498,7 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + jax.debug.breakpoint() def impl( q, @@ -1473,6 +1518,7 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + breakpoint() # cuDNN does not support right-aligned masking with dynamic sequence length padding. 
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -1492,6 +1538,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): ) results = [] + breakpoint() for sub_idx in range(2): if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked @@ -1501,7 +1548,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): q_seqlen_for_step = q_seqlen / (cp_size * 2) num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks - + breakpoint() output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_split[sub_idx], k_unmasked, @@ -1721,6 +1768,92 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) +class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive + + This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. 
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + #jax.debug.breakpoint() + + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + #jax.debug.breakpoint() + + # cuDNN does not support right-aligned masking with dynamic sequence length padding. + # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch + # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor + # meeting the expectation of the SPMD model. + # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding + # mask/sequence length tensor to avoid this unrolled loop. 
+ def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q, + k, #ag + v, #ag + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + kv_segment_ids_ag, + _q_segment_pos, + kv_segment_pos_ag, + config=helper.get_step_config(), + ) + return output, softmax_aux, rng_state + + k_ag, v_ag = helper.all_gather_kv(k, v) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) + + functions = [ + partial(_cross_attn, idx, q, k_ag, v_ag, bias, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) + for idx in range(cp_size) + ] + + return lax.switch(cp_rank, functions) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive) @dataclass(frozen=True) class _FusedAttnCPWithP2PHelper: @@ -1811,6 +1944,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, + stripe_height=self.config.stripe_height, ) def stack_kv(self, k, v): @@ -2693,6 +2827,7 @@ def fused_attn_fwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + stripe_height: int = 0, ) -> jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -2731,6 +2866,7 @@ def fused_attn_fwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + stripe_height (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: (jnp.ndarray): The output tensor from the fused attention. 
""" @@ -2796,12 +2932,16 @@ def fused_attn_fwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, + stripe_height=stripe_height, ) primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedAttnCPStripedWithAllGatherFwdPrimitive.outer_primitive + else: + primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: # We must use stripe attention for THD-RING if qkv_layout.is_thd(): @@ -2818,6 +2958,7 @@ def fused_attn_fwd( *seq_desc_flatten, config=fused_config, ) + breakpoint() rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None)) return (output, softmax_aux, rng_state) @@ -2941,6 +3082,7 @@ def fused_attn_bwd( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + #TODO: stripe_height hardcoded for now as bwd is not being tests fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -2954,6 +3096,7 @@ def fused_attn_bwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, + stripe_height=0, ) primitive = None From 44382986af1bbe006f8885aa71223d41bd388ce2 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 23 Oct 2025 12:41:05 -0700 Subject: [PATCH 07/36] TMP: Throwaway testing commit Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_distributed_fused_attn.py | 8 +- tests/jax/test_fused_attn.py | 35 +++- tests/jax/utils.py | 23 ++- .../fused_attn_f16_arbitrary_seqlen.cu | 179 +++++++++++++++++- transformer_engine/common/fused_attn/utils.cu | 8 + 
transformer_engine/jax/attention.py | 16 +- .../jax/cpp_extensions/attention.py | 2 +- 7 files changed, 247 insertions(+), 24 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 10a1dd2db7..9fa327d777 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -331,6 +331,8 @@ def test_cross_attn( #TODO: Change the id to CPx2 pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), + # KL test code + pytest.param([2, 8, 16, 64], id="2-8xCP-16-64"), ] @@ -451,7 +453,9 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - runner.test_backward() + #KL code + #runner.test_backward() + runner.test_forward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] @pytest_parametrize_wrapper( @@ -653,7 +657,7 @@ def test_context_parallel_ring_attn_shardy( class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) - @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) + @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD]) @pytest.mark.parametrize( "reorder_strategy, stripe_height", REORDER_STRATEGY, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index fafde75b33..d738b46853 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -464,6 +464,7 @@ def _setup_inputs(self): self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) + breakpoint() key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key, 
softmax_key = jax.random.split(key, 6) @@ -490,9 +491,25 @@ def _setup_inputs(self): else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") - self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0) - self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0) - self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) + # self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0) + # self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0) + # self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) + # KL test code + q_np = np.zeros(q_shape, self.dtype) + k_np = np.zeros(k_shape, self.dtype) + token_numbers_q = range(self.max_seqlen_q) + token_numbers_k = range(self.max_seqlen_kv) + for batch_idx in range(q_shape[0]): + for token_idx in token_numbers_q: + q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * (token_idx + 1) + for token_idx in token_numbers_k: + k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt(self.head_dim_qk) + v_np = np.ones(v_shape, self.dtype) + # Set cols at multiples + v_np[0,::4, 0, :] = np.arange(v_np.shape[3]) + self.q = jnp.array(q_np) + self.k = jnp.array(k_np) + self.v = jnp.array(v_np) if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.bias_shape == BiasShape._1HSS: @@ -558,6 +575,8 @@ def generate_random_segment_ids( min_segment_size = 1 if min_segment_len is not None: min_segment_size = min_segment_len[i][seg_id] + #KL test code + min_segment_size = 4 segment_size = rng.integers(min_segment_size, max_segment_size + 1) if current_pos + segment_size > sequence_length: break @@ -613,6 +632,8 @@ def generate_random_segment_ids( ) self.segment_pos_q = self.segment_pos_kv = None self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + print(f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q: {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q: {self.seqlens_q}, \n 
self.offsets_q: { self.offsets_q} \n") + print(f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv: {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv: {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n") # For reference code self.mask = make_mask( @@ -623,6 +644,10 @@ def generate_random_segment_ids( self.attn_mask_type, self.window_size, ) + # KL tet code + import sys + with np.printoptions(threshold=sys.maxsize): + print(f"self.mask: \n {self.mask}") if self.cp_size > 1 and self.cp_load_balanced: if self.qkv_layout.is_thd(): @@ -671,6 +696,7 @@ def generate_random_segment_ids( self.cp_reorder_fn(self.segment_pos_kv), ), ) + breakpoint() case _: raise ValueError(f"Unknown {self.seq_desc_format=}") else: @@ -735,6 +761,7 @@ def to_dp_shardings(x): self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) + #jax.debug.breakpoint() if self.bias_shape == BiasShape._1HSS: self.bias_pspec = PartitionSpec( None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None @@ -1146,7 +1173,7 @@ class TestFusedAttn: pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), ], ) - def _test_forward( + def test_forward( b, s_q, s_kv, diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 7194e387c7..571da23a76 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1507,7 +1507,28 @@ def assert_allclose( actual = actual.astype(jnp.float32) if not isinstance(desired, float): desired = desired.astype(jnp.float32) - + # KL test code + import sys + mismatch_counter = 0 + has_nonzero = jnp.any(actual != 0) + print(f"has_nonzero: {has_nonzero}") + breakpoint() + with np.printoptions(threshold=sys.maxsize): + mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch + diff_indices = np.argwhere(mismatch_mask) + for idx in diff_indices: + idx_tuple = tuple(idx) + mismatch_counter += 1 + if mismatch_counter < 1024: + print(f"Index {idx_tuple}: 
a={actual[idx_tuple]}, d={desired[idx_tuple]}") + # Batch 0 and head 0 + # for seq_idx in range(actual.shape[1]): + # #print("Mismatch at positions:\n", np.argwhere(mismatch_mask[0,:,0,:])) # Pick indices where mask is True + # for d_idx in range(actual.shape[3]): + # # print mismatches + # #if mismatch_mask[0][seq_idx][0][d_idx] == True: + # print(f"seq_idx: {seq_idx}, d_idx: {d_idx}, A: {actual[0][seq_idx][0][d_idx]}, D: {desired[0][seq_idx][0][d_idx]}") + print(f"mismatch_counter: {mismatch_counter}") # Check if tensors are close np.testing.assert_allclose(actual, desired, **tols, **kwargs) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 14468b543a..1fbdee492a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -48,6 +48,19 @@ namespace transformer_engine { namespace fused_attn { +template +__global__ void print_tensor_elements_2(const T *const data, const size_t rows, const size_t start_cols, const size_t end_cols, const size_t cols) { + if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + for (size_t i = 0; i < rows; ++i) { + for (size_t j = start_cols; j < end_cols; ++j) { + const size_t idx = i * cols + j; + printf("%8f ", static_cast(data[idx])); + } + printf("\n"); + } + } +} + void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, @@ -459,6 +472,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bias) { variant_pack[bias] = devPtrBias; } + //KL test code + bool print_tensors = true; + // For the thd_regular case, the actual_b = 18 + bool print_tensors_custom_mask = actual_b >= 300 ? 
true : false; if (is_padding) { constexpr size_t nthreads_per_block = 128; @@ -470,6 +487,78 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); + std::cout << "print_tensors: " << print_tensors << + "print_tensors_custom_mask: " + << print_tensors_custom_mask << std::endl; + if (print_tensors) + { + if(devPtrCuSeqlensQ) { + if(print_tensors_custom_mask) + { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, 8, /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, + 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, + 8184, 8192, + /*does not matter for single row*/ actual_b); + } + else + { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + } + } + if (devActualSeqlenQ) { + if (print_tensors_custom_mask) + { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, 8, /*does not matter for single row*/actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, + 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, + 8184, 8192, + /*does not matter for single row*/ actual_b); + } + else { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + } + } + if(devPtrCuSeqlensKV) { + if(print_tensors_custom_mask) + { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensKV), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensKV), 1, 
1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensKV), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } + else { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + } + } + if(devActualSeqlenKV) { + if (print_tensors_custom_mask) + { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenKV), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenKV), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenKV), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } + else + { + print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + } + } + } variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -509,18 +598,88 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); NVTE_CHECK_CUDA(cudaGetLastError()); - if (is_ragged_q) { - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_o] = devOffsetsO; - } - if (is_ragged_kv) { - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; + if (print_tensors) { + if (devPtrSeqOffsetsQ) { + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsQ), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsQ), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsQ), 1, 8184, 
8192, + /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsQ), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); + } + } + if (devOffsetsQ) { + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsQ), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsQ), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsQ), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsQ), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); + } + } + if (devPtrSeqOffsetsKV) { + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsKV), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsKV), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsKV), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrSeqOffsetsKV), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); + } + } + if (devOffsetsK) { + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsK), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsK), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsK), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } else { + 
print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devOffsetsK), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); + } + } } - if (is_ragged_q && cudnn_runtime_version >= 90600) { - variant_pack[offset_stats] = devOffsetsS; + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } - } if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index df1eae0dd7..f635014fca 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -428,6 +428,14 @@ __device__ void cu_seqlens_padded_to_offsets_impl( OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; auto cu_seqlens_id = min(tid, actual_b); + if (tid == 0) { + printf("actual_b: %lld \n", (long long int)actual_b); + printf("max_b: %lld \n", (long long int)max_b); + printf("h: %lld \n", (long long int)h); + printf("hg: %lld \n", (long long int)hg); + printf("d_qk: %lld \n", (long long int)d_qk); + printf("d_v: %lld \n", (long long int)d_v); + } if (tid <= max_b) { if (offsets_s != nullptr) { offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 74f50bcae4..ed2d11cad6 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -536,12 +536,13 @@ def _segment_ids_pos_to_seqlens_offsets( # It does not need to involve SW for this mask's creation # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - 
window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): - return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - ) + #TODO: Un comment the fast path + # if (attn_mask_type.is_causal() and window_size is None) or ( + # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + # ): + # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + # ) # (1 = attend, 0 = masked) segment_mask = make_attention_mask( @@ -554,6 +555,7 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids_kv, lambda x, y: jnp.equal(x, y) * x, ) + #jax.debug.breakpoint() # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied attn_mask = segment_mask if attn_mask_type.is_bottom_right(): @@ -600,6 +602,7 @@ def _segment_ids_pos_to_seqlens_offsets( q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( attn_mask_with_id, max_segments_per_seq ) + #jax.debug.breakpoint() return q_seqlen, kv_seqlen, q_offset, kv_offset @@ -679,6 +682,7 @@ def get_seqlens_and_offsets( window_size, max_segments_per_seq, ) + #jax.debug.breakpoint() else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 8ddccde763..d36425c11a 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2948,7 +2948,7 @@ def fused_attn_fwd( primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive else: primitive = FusedRingAttnFwdPrimitive.outer_primitive - + print(f"qkv_for_primitive: \n {qkv_for_primitive}") seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) output, softmax_aux, rng_state = primitive.bind( *qkv_for_primitive, From 
3da94e538e7aee3ffb0ca796dd22efb39daefee2 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Wed, 5 Nov 2025 14:10:37 -0800 Subject: [PATCH 08/36] Add comments in primitive registration process Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/cpp_extensions/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 556b587191..ee62dc9128 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -176,6 +176,9 @@ def shardy_sharding_rule(*args): def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. + Inner primitive - single device, no sharding awareness, eager mode fallback + Outer primitive - multi device, sharding aware, partition() distributes work, + used when there's a dev mesh context """ _primitive_registry[cls.__name__] = cls @@ -190,15 +193,17 @@ def name_of_wrapper_p(): inner_p = core.Primitive(cls.name) dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p.multiple_results = cls.multiple_results + # Define eager execution implementation inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.def_abstract_eval(cls.abstract) mlir.register_lowering(inner_p, cls.lowering, platform="cuda") cls.inner_primitive = inner_p + # Create the outer primitive for distributed execution outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results - outer_p.def_impl(cls.outer_impl) + # Define the eager execution implementation outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) From a51b7d941cb7429289d8bf71b409254f6a001749 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: 
Thu, 13 Nov 2025 11:48:58 -0800 Subject: [PATCH 09/36] TMP: Throwaway test commit Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 3 - tests/jax/utils.py | 1 - .../fused_attn_f16_arbitrary_seqlen.cu | 9 + .../jax/cpp_extensions/attention.py | 210 +++++++++++++++++- 4 files changed, 209 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index d738b46853..082522377c 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -464,7 +464,6 @@ def _setup_inputs(self): self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) - breakpoint() key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6) @@ -696,7 +695,6 @@ def generate_random_segment_ids( self.cp_reorder_fn(self.segment_pos_kv), ), ) - breakpoint() case _: raise ValueError(f"Unknown {self.seq_desc_format=}") else: @@ -761,7 +759,6 @@ def to_dp_shardings(x): self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) - #jax.debug.breakpoint() if self.bias_shape == BiasShape._1HSS: self.bias_pspec = PartitionSpec( None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 571da23a76..4e37048890 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1512,7 +1512,6 @@ def assert_allclose( mismatch_counter = 0 has_nonzero = jnp.any(actual != 0) print(f"has_nonzero: {has_nonzero}") - breakpoint() with np.printoptions(threshold=sys.maxsize): mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch diff_indices = np.argwhere(mismatch_mask) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 
1fbdee492a..69e97c7555 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -506,6 +506,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + cudaDeviceSynchronize(); } } if (devActualSeqlenQ) { @@ -521,6 +522,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + cudaDeviceSynchronize(); } } if(devPtrCuSeqlensKV) { @@ -538,6 +540,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } if(devActualSeqlenKV) { @@ -556,6 +559,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } } @@ -597,6 +601,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + cudaDeviceSynchronize(); NVTE_CHECK_CUDA(cudaGetLastError()); if (print_tensors) { if (devPtrSeqOffsetsQ) { @@ -614,6 +619,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrSeqOffsetsQ), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } if (devOffsetsQ) { @@ -631,6 +637,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devOffsetsQ), 1, 0, actual_b, /*does 
not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } if (devPtrSeqOffsetsKV) { @@ -648,6 +655,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrSeqOffsetsKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } if (devOffsetsK) { @@ -665,6 +673,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devOffsetsK), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + cudaDeviceSynchronize(); } } } diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index d36425c11a..df6a074b22 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -521,6 +521,37 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): + DEBUG = True #os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" + # if DEBUG: + # jax.debug.print("FusedAttnFwdPrimitive.impl CALLED") + # jax.debug.print("Config: qkv_layout={}, attn_mask_type={}", + # str(config.qkv_layout), str(config.attn_mask_type)) + # jax.debug.print("Input shapes:") + # jax.debug.print(" q={}, k={}, v={}", q.shape, k.shape, v.shape) + # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) + + # def print_impl_inputs(q_val, k_val, v_val, q_seq, kv_seq, q_off, k_off): + # print(f"\n~~~ FusedAttnFwdPrimitive.impl INPUTS ~~~") + # print(f"Q: shape={q_val.shape}, mean={q_val.mean():.6f}, std={q_val.std():.6f}") + # print(f" First 5: {q_val.flatten()[:5]}") + + # print(f"K: shape={k_val.shape}, mean={k_val.mean():.6f}, std={k_val.std():.6f}") + # print(f" First 5: {k_val.flatten()[:5]}") + + # print(f"V: shape={v_val.shape}, mean={v_val.mean():.6f}, std={v_val.std():.6f}") + # print(f" First 5: {v_val.flatten()[:5]}") + + # print(f"\nSequence info:") + # print(f" q_seqlen: {q_seq}") + # print(f" kv_seqlen: 
{kv_seq}") + # print(f" q_seq_offsets: {q_off}") + # print(f" k_seq_offsets: {k_off}") + # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") + + # jax.debug.callback( + # print_impl_inputs, + # q, k, v, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets + # ) assert FusedAttnFwdPrimitive.inner_primitive is not None sequence_descriptor = SequenceDescriptor( @@ -529,7 +560,6 @@ def impl( segment_ids=(_q_segment_ids, _kv_segment_ids), segment_pos=(_q_segment_pos, _kv_segment_pos), ) - breakpoint() (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( sequence_descriptor.get_seqlens_and_offsets( config.attn_mask_type, @@ -538,9 +568,23 @@ def impl( config.max_segments_per_seq, ) ) - #jax.debug.breakpoint() - + # if DEBUG: + # jax.debug.print("After sequence_descriptor processing:") + # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) + + # def print_seq_descriptor(q_seq, kv_seq, q_off, k_off): + # print(f"\n~~~ SEQUENCE DESCRIPTOR OUTPUTS ~~~") + # print(f"q_seqlen (processed): {q_seq}") + # print(f"kv_seqlen (processed): {kv_seq}") + # print(f"q_seq_offsets (processed): {q_off}") + # print(f"k_seq_offsets (processed): {k_off}") + # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") + + # jax.debug.callback(print_seq_descriptor, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) + #jax.debug.print("Hello FA impl") if config.qkv_layout.is_thd(): + # if DEBUG: + # jax.debug.print("Processing THD layout...") def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -564,6 +608,10 @@ def convert_to_2d(offsets, batch, max_seqlen): assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}" kv_batch = q_batch = batch[0] + # if DEBUG: + # jax.debug.print(" batch={}, q_max_seqlen={}, kv_max_seqlen={}", + # q_batch, q_max_seqlen, kv_max_seqlen) + # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] @@ -592,9 +640,36 
@@ def convert_to_2d(offsets, batch, max_seqlen): k_seq_offsets = _fix_len_take( k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen ) + # if DEBUG: + # def print_thd_processing(q_seq, kv_seq, q_off, k_off): + # print(f"\n~~~ AFTER THD PROCESSING ~~~") + # print(f"q_seqlen (fixed): {q_seq}") + # print(f"kv_seqlen (fixed): {kv_seq}") + # print(f"q_seq_offsets (2d): {q_off}") + # print(f"k_seq_offsets (2d): {k_off}") + # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") + + # jax.debug.callback(print_thd_processing, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + # jax.debug.print(f"q_seqlen: {q_seqlen}, kv_seqlen: {kv_seqlen}") + # if DEBUG: + # jax.debug.print("Generated cumulative sequence lengths:") + # jax.debug.print(" q_cu_seqlen={}", q_cu_seqlen.shape) + # jax.debug.print(" kv_cu_seqlen={}", kv_cu_seqlen.shape) + + # def print_cu_seqlen(q_cu, kv_cu): + # print(f"\n~~~ CUMULATIVE SEQLENS ~~~") + # print(f"q_cu_seqlen: {q_cu}") + # print(f"kv_cu_seqlen: {kv_cu}") + # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~\n") + + # jax.debug.callback(print_cu_seqlen, q_cu_seqlen, kv_cu_seqlen) + + # if DEBUG: + # jax.debug.print("Calling inner_primitive.bind...") + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( q, @@ -1498,7 +1573,7 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - jax.debug.breakpoint() + #jax.debug.breakpoint() def impl( q, @@ -1518,7 +1593,8 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - breakpoint() + # jax.debug.print("Test CP DC AG") + #breakpoint() # cuDNN does not support right-aligned masking with dynamic sequence length padding. 
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -1529,6 +1605,8 @@ def impl( def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + #jax.debug.print("Test AG") + #jax.debug.print(f"kv_max_seqlen: {kv_max_seqlen}") assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" q_split = jnp.split(q, 2, axis=1) @@ -1538,7 +1616,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): ) results = [] - breakpoint() + #breakpoint() for sub_idx in range(2): if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked @@ -1548,7 +1626,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): q_seqlen_for_step = q_seqlen / (cp_size * 2) num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks - breakpoint() + #breakpoint() output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_split[sub_idx], k_unmasked, @@ -1777,6 +1855,15 @@ class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): @staticmethod def partition(config, mesh, arg_infos, result_infos): + DEBUG = True #os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" + if DEBUG: + print(f"STRIPED PARTITION CALLED (Compilation Phase)") + print(f"Mesh: {mesh}") + print(f"CP axis: {config.cp_axis}, size: {get_mesh_axis_size(config.cp_axis, mesh)}") + print(f"window_size: {config.window_size}, context_parallel_load_balanced: {config.context_parallel_load_balanced}, stripe_height: {config.stripe_height}") + print(f"Arg shapes: {[info.shape for info in arg_infos]}") + print(f"QKV layout: {config.qkv_layout}") + print(f"Attention mask type: {config.attn_mask_type}") # Call base implementation for non-context parallel mesh to avoid unecessary work. 
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 if not is_context_parallel: @@ -1794,7 +1881,6 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[4] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - #jax.debug.breakpoint() def impl( q, @@ -1813,7 +1899,39 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - #jax.debug.breakpoint() + # jax.debug.print("Test CP striped AG") + # if DEBUG: + # jax.debug.print("STRIPED IMPL CALLED (Execution Phase)") + # jax.debug.print("cp_size={}, cp_rank={}", cp_size, cp_rank) + + # # Print input shapes + # jax.debug.print("INPUT SHAPES:") + # jax.debug.print(" q.shape={}, k.shape={}, v.shape={}", q.shape, k.shape, v.shape) + # jax.debug.print(" bias.shape={}", bias.shape if bias is not None else "None") + # jax.debug.print(" q_seqlen.shape={}, kv_seqlen.shape={}", q_seqlen.shape, kv_seqlen.shape) + + # Print actual input values + # def print_inputs(q_val, k_val, v_val, q_seq_val, kv_seq_val, rank): + # print(f"\n--- STRIPED INPUTS (Rank {rank}) ---") + # print(f"Q: shape={q_val.shape}, dtype={q_val.dtype}") + # print(f" mean={q_val.mean():.6f}, std={q_val.std():.6f}") + # print(f" min={q_val.min():.6f}, max={q_val.max():.6f}") + # print(f" First 10 values: {q_val.flatten()[:10]}") + + # print(f"\nK: shape={k_val.shape}, dtype={k_val.dtype}") + # print(f" mean={k_val.mean():.6f}, std={k_val.std():.6f}") + # print(f" First 10 values: {k_val.flatten()[:10]}") + + # print(f"\nV: shape={v_val.shape}, dtype={v_val.dtype}") + # print(f" mean={v_val.mean():.6f}, std={v_val.std():.6f}") + # print(f" First 10 values: {v_val.flatten()[:10]}") + + # print(f"\nSequence lengths:") + # print(f" q_seqlen: {q_seq_val}") + # print(f" kv_seqlen: {kv_seq_val}") + # print(f"--------------------------------------\n") + + # jax.debug.callback(print_inputs, q, k, v, q_seqlen, 
kv_seqlen, cp_rank) # cuDNN does not support right-aligned masking with dynamic sequence length padding. # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -1822,6 +1940,35 @@ def impl( # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding # mask/sequence length tensor to avoid this unrolled loop. def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): + # if DEBUG: + # jax.debug.print("\n--- _cross_attn called: idx={} ---", idx) + # jax.debug.print(" q.shape={}, k.shape={}, v.shape={}", q.shape, k.shape, v.shape) + + # def print_cross_attn_inputs(q_val, k_val, v_val, kv_seg_ids, kv_seg_pos, idx_val): + # print(f"\n--- CROSS ATTN INPUTS (idx={idx_val}) ---") + # print(f"Q (local): shape={q_val.shape}") + # print(f" mean={q_val.mean():.6f}, std={q_val.std():.6f}") + # print(f" First 5 values: {q_val.flatten()[:5]}") + + # print(f"\nK (all-gathered): shape={k_val.shape}") + # print(f" mean={k_val.mean():.6f}, std={k_val.std():.6f}") + # print(f" First 5 values: {k_val.flatten()[:5]}") + + # print(f"\nV (all-gathered): shape={v_val.shape}") + # print(f" mean={v_val.mean():.6f}, std={v_val.std():.6f}") + # print(f" First 5 values: {v_val.flatten()[:5]}") + + # print(f"\nKV segment IDs: shape={kv_seg_ids.shape}") + # print(f" {kv_seg_ids.flatten()[:20]}") + + # print(f"\nKV segment positions: shape={kv_seg_pos.shape}") + # print(f" {kv_seg_pos.flatten()[:20]}") + # print(f"--------------------------------------------\n") + + # jax.debug.callback( + # print_cross_attn_inputs, + # q, k, v, kv_segment_ids_ag, kv_segment_pos_ag, idx + # ) output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q, k, #ag @@ -1838,11 +1985,55 @@ def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): kv_segment_pos_ag, config=helper.get_step_config(), ) + # if DEBUG: + # jax.debug.print(" Output from FusedAttnFwdPrimitive.impl:") + # jax.debug.print(" 
output.shape={}", output.shape) + # jax.debug.print(" softmax_aux.shape={}", softmax_aux.shape) + + # def print_cross_attn_outputs(out, softmax, idx_val): + # print(f"\n--- CROSS ATTN OUTPUTS (idx={idx_val}) ---") + # print(f"Output: shape={out.shape}") + # print(f" mean={out.mean():.6f}, std={out.std():.6f}") + # print(f" min={out.min():.6f}, max={out.max():.6f}") + # print(f" First 10 values: {out.flatten()[:10]}") + + # print(f"\nSoftmax aux: shape={softmax.shape}") + # print(f" mean={softmax.mean():.6f}") + # print(f" First 10 values: {softmax.flatten()[:10]}") + # print(f"--------------------------------------------\n") + return output, softmax_aux, rng_state k_ag, v_ag = helper.all_gather_kv(k, v) _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) + # if DEBUG: + # jax.debug.print("After all-gather:") + # jax.debug.print(" k_ag.shape={}, v_ag.shape={}", k_ag.shape, v_ag.shape) + # jax.debug.print(" kv_segment_ids_ag.shape={}, kv_segment_pos_ag.shape={}", + # _kv_segment_ids_ag.shape, _kv_segment_pos_ag.shape) + + # def print_all_gathered(k_gathered, v_gathered, seg_ids, seg_pos, rank): + # print(f"\n--- ALL-GATHERED DATA (Rank {rank}) ---") + # print(f"K (all-gathered): shape={k_gathered.shape}") + # print(f" mean={k_gathered.mean():.6f}, std={k_gathered.std():.6f}") + # print(f" First 5 values: {k_gathered.flatten()[:5]}") + + # print(f"\nV (all-gathered): shape={v_gathered.shape}") + # print(f" mean={v_gathered.mean():.6f}, std={v_gathered.std():.6f}") + # print(f" First 5 values: {v_gathered.flatten()[:5]}") + + # print(f"\nKV segment IDs (all-gathered): shape={seg_ids.shape}") + # print(f" {seg_ids.flatten()[:30]}") + + # print(f"\nKV segment pos (all-gathered): shape={seg_pos.shape}") + # print(f" {seg_pos.flatten()[:30]}") + # print(f"----------------------------------------\n") + + # jax.debug.callback( + # print_all_gathered, + # k_ag, v_ag, _kv_segment_ids_ag, _kv_segment_pos_ag, cp_rank + 
# ) functions = [ partial(_cross_attn, idx, q, k_ag, v_ag, bias, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) for idx in range(cp_size) @@ -2958,7 +3149,6 @@ def fused_attn_fwd( *seq_desc_flatten, config=fused_config, ) - breakpoint() rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None)) return (output, softmax_aux, rng_state) From fcee4f440ebff61f1d00f6cb431a0d10effd1638 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 20 Nov 2025 22:24:32 -0800 Subject: [PATCH 10/36] Undoing incorrect rebase/merge leftovers Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/distributed_test_base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 170e13e054..137fa480dd 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -154,13 +154,9 @@ def compare_ops( grad_args = tuple(range(len(inputs))) target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) -<<<<<<< HEAD target_jitter = jax.jit( target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings ) -======= - target_jitter = jax.jit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) ->>>>>>> 57b57292 (Fix imports in test for deprecated jax.experimental.pjit) target_fwd, target_grads = target_jitter(*inputs, **kwargs) target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text() From 5dfdaf59a9e13bf73da985587485f087f19ba524 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 20 Nov 2025 22:25:10 -0800 Subject: [PATCH 11/36] TMP: Throwaway test commits Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_fused_attn.py | 8 +- tests/jax/utils.py | 5 +- .../fused_attn_f16_arbitrary_seqlen.cu | 15 +- .../jax/cpp_extensions/attention.py | 154 ++++-------------- 4 files changed, 46 insertions(+), 136 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py 
index 082522377c..3e551ee934 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -598,7 +598,7 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 + self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=12 ) self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) # TODO(rewang): record only self attention and find the reason of cross attention @@ -644,9 +644,9 @@ def generate_random_segment_ids( self.window_size, ) # KL tet code - import sys - with np.printoptions(threshold=sys.maxsize): - print(f"self.mask: \n {self.mask}") + # import sys + # with np.printoptions(threshold=sys.maxsize): + # print(f"self.mask: \n {self.mask}") if self.cp_size > 1 and self.cp_load_balanced: if self.qkv_layout.is_thd(): diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 4e37048890..6a45ab437a 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1515,11 +1515,14 @@ def assert_allclose( with np.printoptions(threshold=sys.maxsize): mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch diff_indices = np.argwhere(mismatch_mask) + seq_set = set() for idx in diff_indices: idx_tuple = tuple(idx) - mismatch_counter += 1 + seq_set.add(idx_tuple[1]) + mismatch_counter += 1 if mismatch_counter < 1024: print(f"Index {idx_tuple}: a={actual[idx_tuple]}, d={desired[idx_tuple]}") + print(f"{sorted(seq_set)}") # Batch 0 and head 0 # for seq_idx in range(actual.shape[1]): # #print("Mismatch at positions:\n", np.argwhere(mismatch_mask[0,:,0,:])) # Pick indices where mask is True diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 69e97c7555..5056942d1a 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -487,9 +487,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); - std::cout << "print_tensors: " << print_tensors << - "print_tensors_custom_mask: " - << print_tensors_custom_mask << std::endl; + //std::cout << "print_tensors: " << print_tensors << + // "print_tensors_custom_mask: " + // << print_tensors_custom_mask << std::endl; if (print_tensors) { if(devPtrCuSeqlensQ) { @@ -506,7 +506,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); - cudaDeviceSynchronize(); } } if (devActualSeqlenQ) { @@ -522,7 +521,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); - cudaDeviceSynchronize(); } } if(devPtrCuSeqlensKV) { @@ -540,7 +538,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } if(devActualSeqlenKV) { @@ -559,7 +556,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( else { print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } } @@ -601,7 +597,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); - cudaDeviceSynchronize(); 
NVTE_CHECK_CUDA(cudaGetLastError()); if (print_tensors) { if (devPtrSeqOffsetsQ) { @@ -619,7 +614,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrSeqOffsetsQ), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } if (devOffsetsQ) { @@ -637,7 +631,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devOffsetsQ), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } if (devPtrSeqOffsetsKV) { @@ -655,7 +648,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrSeqOffsetsKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } if (devOffsetsK) { @@ -673,7 +665,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devOffsetsK), 1, 0, actual_b, /*does not matter for single row*/ actual_b); - cudaDeviceSynchronize(); } } } diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index df6a074b22..ffa12c0bd6 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -581,7 +581,7 @@ def impl( # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") # jax.debug.callback(print_seq_descriptor, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) - #jax.debug.print("Hello FA impl") + # jax.debug.print("Hello FA impl") if config.qkv_layout.is_thd(): # if DEBUG: # jax.debug.print("Processing THD layout...") @@ -1593,7 +1593,7 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - # jax.debug.print("Test CP DC AG") + #jax.debug.print("Test CP DC AG") - Gives a seg fault #breakpoint() # cuDNN does not support right-aligned masking with dynamic sequence length 
padding. @@ -1605,7 +1605,7 @@ def impl( def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) - #jax.debug.print("Test AG") + # jax.debug.print("Test cross attn ag") - Gives a seg fault #jax.debug.print(f"kv_max_seqlen: {kv_max_seqlen}") assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" @@ -1881,6 +1881,10 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[4] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + if DEBUG: + print(f"STRIPED PARTITION CALLED (Compilation Phase)") + print(f"Arg shardings: {[arg_i.sharding for arg_i in arg_infos]}") + print(f"Out shardings: {[out_i for out_i in out_shardings]}") def impl( q, @@ -1900,38 +1904,6 @@ def impl( cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) # jax.debug.print("Test CP striped AG") - # if DEBUG: - # jax.debug.print("STRIPED IMPL CALLED (Execution Phase)") - # jax.debug.print("cp_size={}, cp_rank={}", cp_size, cp_rank) - - # # Print input shapes - # jax.debug.print("INPUT SHAPES:") - # jax.debug.print(" q.shape={}, k.shape={}, v.shape={}", q.shape, k.shape, v.shape) - # jax.debug.print(" bias.shape={}", bias.shape if bias is not None else "None") - # jax.debug.print(" q_seqlen.shape={}, kv_seqlen.shape={}", q_seqlen.shape, kv_seqlen.shape) - - # Print actual input values - # def print_inputs(q_val, k_val, v_val, q_seq_val, kv_seq_val, rank): - # print(f"\n--- STRIPED INPUTS (Rank {rank}) ---") - # print(f"Q: shape={q_val.shape}, dtype={q_val.dtype}") - # print(f" mean={q_val.mean():.6f}, std={q_val.std():.6f}") - # print(f" min={q_val.min():.6f}, max={q_val.max():.6f}") - # print(f" First 10 values: {q_val.flatten()[:10]}") - - # print(f"\nK: shape={k_val.shape}, dtype={k_val.dtype}") - # print(f" mean={k_val.mean():.6f}, 
std={k_val.std():.6f}") - # print(f" First 10 values: {k_val.flatten()[:10]}") - - # print(f"\nV: shape={v_val.shape}, dtype={v_val.dtype}") - # print(f" mean={v_val.mean():.6f}, std={v_val.std():.6f}") - # print(f" First 10 values: {v_val.flatten()[:10]}") - - # print(f"\nSequence lengths:") - # print(f" q_seqlen: {q_seq_val}") - # print(f" kv_seqlen: {kv_seq_val}") - # print(f"--------------------------------------\n") - - # jax.debug.callback(print_inputs, q, k, v, q_seqlen, kv_seqlen, cp_rank) # cuDNN does not support right-aligned masking with dynamic sequence length padding. # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -1939,101 +1911,45 @@ def impl( # meeting the expectation of the SPMD model. # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding # mask/sequence length tensor to avoid this unrolled loop. + + # Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets + # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos, + # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in. 
def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): - # if DEBUG: - # jax.debug.print("\n--- _cross_attn called: idx={} ---", idx) - # jax.debug.print(" q.shape={}, k.shape={}, v.shape={}", q.shape, k.shape, v.shape) - - # def print_cross_attn_inputs(q_val, k_val, v_val, kv_seg_ids, kv_seg_pos, idx_val): - # print(f"\n--- CROSS ATTN INPUTS (idx={idx_val}) ---") - # print(f"Q (local): shape={q_val.shape}") - # print(f" mean={q_val.mean():.6f}, std={q_val.std():.6f}") - # print(f" First 5 values: {q_val.flatten()[:5]}") - - # print(f"\nK (all-gathered): shape={k_val.shape}") - # print(f" mean={k_val.mean():.6f}, std={k_val.std():.6f}") - # print(f" First 5 values: {k_val.flatten()[:5]}") - - # print(f"\nV (all-gathered): shape={v_val.shape}") - # print(f" mean={v_val.mean():.6f}, std={v_val.std():.6f}") - # print(f" First 5 values: {v_val.flatten()[:5]}") - - # print(f"\nKV segment IDs: shape={kv_seg_ids.shape}") - # print(f" {kv_seg_ids.flatten()[:20]}") - - # print(f"\nKV segment positions: shape={kv_seg_pos.shape}") - # print(f" {kv_seg_pos.flatten()[:20]}") - # print(f"--------------------------------------------\n") - - # jax.debug.callback( - # print_cross_attn_inputs, - # q, k, v, kv_segment_ids_ag, kv_segment_pos_ag, idx - # ) + # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive + # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function + # does not go down that route but instead just picks the seqlens and offsets passed onto it + + kv_max_seqlen = k.shape[1] + # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(max_seqlen=kv_max_seqlen, cp_size=cp_size) + q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq) + 
q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + kv_num_segments_for_rank, kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(kv_segment_ids=_kv_segment_ids, kv_segment_pos=_kv_segment_pos, max_segments_per_seq=adjusted_max_segments_per_seq) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + #kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( - q, + q, #sharded for rank k, #ag v, #ag bias, seed, - q_seqlen, - kv_seqlen, - q_seq_offsets, - k_seq_offsets, - _q_segment_ids, - kv_segment_ids_ag, - _q_segment_pos, - kv_segment_pos_ag, - config=helper.get_step_config(), + q_seqlens_for_rank, + kv_seqlens_for_rank, + q_seq_offsets_for_rank, + kv_seq_offsets_for_rank, + q_seqlen, #Should be empty ids but using placeholder + kv_seqlen, #Should be empty poss but using placeholder + q_seq_offsets, #Should be empty ids but using placeholder + k_seq_offsets, #Should be empty pos but using placeholder + config=helper.get_step_config_for_striped(max_seqlen=kv_max_seqlen, cp_size=cp_size), ) - # if DEBUG: - # jax.debug.print(" Output from FusedAttnFwdPrimitive.impl:") - # jax.debug.print(" output.shape={}", output.shape) - # jax.debug.print(" softmax_aux.shape={}", softmax_aux.shape) - - # def print_cross_attn_outputs(out, softmax, idx_val): - # print(f"\n--- CROSS ATTN OUTPUTS (idx={idx_val}) ---") - # print(f"Output: shape={out.shape}") - # print(f" mean={out.mean():.6f}, std={out.std():.6f}") - # print(f" 
min={out.min():.6f}, max={out.max():.6f}") - # print(f" First 10 values: {out.flatten()[:10]}") - - # print(f"\nSoftmax aux: shape={softmax.shape}") - # print(f" mean={softmax.mean():.6f}") - # print(f" First 10 values: {softmax.flatten()[:10]}") - # print(f"--------------------------------------------\n") - return output, softmax_aux, rng_state k_ag, v_ag = helper.all_gather_kv(k, v) + # Only the pos is needed for kv offsets calculation _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) - - # if DEBUG: - # jax.debug.print("After all-gather:") - # jax.debug.print(" k_ag.shape={}, v_ag.shape={}", k_ag.shape, v_ag.shape) - # jax.debug.print(" kv_segment_ids_ag.shape={}, kv_segment_pos_ag.shape={}", - # _kv_segment_ids_ag.shape, _kv_segment_pos_ag.shape) - - # def print_all_gathered(k_gathered, v_gathered, seg_ids, seg_pos, rank): - # print(f"\n--- ALL-GATHERED DATA (Rank {rank}) ---") - # print(f"K (all-gathered): shape={k_gathered.shape}") - # print(f" mean={k_gathered.mean():.6f}, std={k_gathered.std():.6f}") - # print(f" First 5 values: {k_gathered.flatten()[:5]}") - - # print(f"\nV (all-gathered): shape={v_gathered.shape}") - # print(f" mean={v_gathered.mean():.6f}, std={v_gathered.std():.6f}") - # print(f" First 5 values: {v_gathered.flatten()[:5]}") - - # print(f"\nKV segment IDs (all-gathered): shape={seg_ids.shape}") - # print(f" {seg_ids.flatten()[:30]}") - - # print(f"\nKV segment pos (all-gathered): shape={seg_pos.shape}") - # print(f" {seg_pos.flatten()[:30]}") - # print(f"----------------------------------------\n") - - # jax.debug.callback( - # print_all_gathered, - # k_ag, v_ag, _kv_segment_ids_ag, _kv_segment_pos_ag, cp_rank - # ) functions = [ partial(_cross_attn, idx, q, k_ag, v_ag, bias, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) for idx in range(cp_size) From f6fb305d92050e785ba5dc3e9ccbce55766a8278 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 20 Nov 2025 
22:30:21 -0800 Subject: [PATCH 12/36] Add support for calculating q and kv seqlens and offsets per rank for CP+THD+AG+SW+Striped>1 primitive Signed-off-by: Kshitij Janardan Lakhani --- .../jax/cpp_extensions/attention.py | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ffa12c0bd6..49610150b1 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1402,6 +1402,10 @@ def get_adjusted_mask(self): if self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd(): # THD only ? return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type + + def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): + # Estimating + return (max_seqlen // (self.config.stripe_height*cp_size)) + self.config.max_segments_per_seq def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" @@ -1421,6 +1425,24 @@ def get_step_config(self) -> _FusedAttnConfig: cp_striped_window_size=None, stripe_height=self.config.stripe_height, ) + + def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + #TODO: Should the max_segments_per_seq be different ? 
+ return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=self.get_adjusted_mask(), + qkv_layout=self.config.qkv_layout, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis=self.config.cp_axis, + cp_striped_window_size=None, + stripe_height=self.config.stripe_height, + ) def all_gather_kv(self, k, v): """Performs a all-gather of k and v over context parallel ranks.""" @@ -1542,6 +1564,225 @@ def pad(x, npad): return pad(dk, npad), pad(dv, npad) return dk, dv # fall through + + #TODO: max_segments_per_seq - might need some modifications for per rank compute as it won't be the same as the FA packed representation - maybe (max_segments_per_seq_new = seqlens/stripe_height + max_segments_per_seq) + #TODO: Do take a look at other implementations to check if flattening required ? + #QUESTION: Do take a look at other implementations to check if flattening required ? 
+ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + q_segment_ids_flat = q_segment_ids.reshape(-1) + q_segment_pos_flat = q_segment_pos.reshape(-1) + + # Create mask for non-zero segment IDs + non_zero_mask = q_segment_ids_flat != 0 + # Calculate indices from mask + max_size = q_segment_ids_flat.shape[0] + # Non zero segment id indices followed by padding of -1 at the end to get static size + non_zero_indices = jnp.where( + non_zero_mask, + size=max_size, + fill_value=-1 + )[0] + # print(f"{non_zero_indices=}") + # Pick non zero seg ids and seg pos + valid_segment_ids = jnp.where(non_zero_indices >= 0, q_segment_ids_flat[non_zero_indices], 0) + valid_segment_pos = jnp.where(non_zero_indices >= 0, q_segment_pos_flat[non_zero_indices], 0) + # print(f"{valid_segment_ids=}, {valid_segment_pos=}") + + # Create mask for actual valid entries (not padding) + # All Trues in the beginning for valid segment ids followed by padding of False + actual_valid_mask = valid_segment_ids != 0 + # print(f"{actual_valid_mask=}") + + # Detect segment changes, accounting for padding + # First element is True only if it's actually valid + first_is_segment = actual_valid_mask[0:1] + + segment_changes = jnp.concatenate([ + first_is_segment, # First valid element starts a segment + ((valid_segment_ids[1:] != valid_segment_ids[:-1]) | # Segment ID changed + (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1)) & # Position not consecutive + actual_valid_mask[1:] # Only consider actually valid positions + ]) + + # Create new segment IDs + new_segment_ids = jnp.cumsum(segment_changes) + # print(f"{new_segment_ids=}") + + # Can't use len() on traced values - use jnp.max instead + max_new_segments_per_seq = jnp.max(jnp.where(actual_valid_mask, new_segment_ids, 0)) + # print(f"{max_new_segments_per_seq=}") + + # Use bincount with a safe length + # Add 1 to handle 0-indexing, and ensure it's at least max_segments_per_seq + seqlens_all = jnp.bincount( + 
jnp.where(actual_valid_mask, new_segment_ids, 0).astype(jnp.int32), + length=max_segments_per_seq + )[1:] + # print(f"{seqlens_all=}") + + # Pad 0 at start prior to cumsum + seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) + # print(f"{seqlens_padded=}") + cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off + #print(f"{cum_seqlens_padded=}") + + return max_new_segments_per_seq, seqlens_padded + + #QUESTION: Do take a look at other implementations to check if flattening required ? + def q_seqoffsets_for_striped_for_rank(self, q_segment_pos, q_num_segments, max_segments_per_seq): + q_segment_pos_flat = q_segment_pos.reshape(-1) + # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 + # e.g. 01230000124567 + segment_changes = jnp.concatenate([ + jnp.array([True]), # First valid element starts a segment + (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed + ]) + #print(f"{segment_changes=}") + max_size = q_segment_pos_flat.shape[0] + seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() + + # Create index array (static shape) + seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) + # Create a mask (False: do not clip to the edge element, True: clip to edge element) + mask = seq_offsets_2_indices >= q_num_segments + # Get fill value dynamically by calculating the edge index + edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) + fill_value = seq_offsets_2[edge_index] + + seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) + + return seq_offsets[:max_segments_per_seq] + + + # Per rank! 
+ # Use full reordered kv seg id and offset(not) + # Look in every stripe_height section of kv_segment_ids + # i) if same as previous section segment id then add to seqlens counter or ii) if not same as previous section segment id then start seqlens counter + # monotonic constraint automatically applies the stripe_height constraint + #QUESTION: Do take a look at other implementations to check if flattening required ? + def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): + kv_segment_ids_flat = kv_segment_ids.reshape(-1) + kv_segment_pos_flat = kv_segment_pos.reshape(-1) + #print(f"{kv_segment_ids_flat=}, {kv_segment_pos_flat=}") + + # Create mask for non-zero segment IDs + non_zero_mask = kv_segment_ids_flat != 0 + #print(f"{non_zero_mask=}") + + # Filter to only non-zero segments + max_size = kv_segment_ids_flat.shape[0] + non_zero_indices = jnp.where( + non_zero_mask, + size=max_size, + fill_value=-1 + )[0] + valid_segment_ids = kv_segment_ids_flat[non_zero_indices] + valid_segment_pos = kv_segment_pos_flat[non_zero_indices] + actual_valid = valid_segment_ids != 0 + #print(f"{valid_segment_ids=}, {valid_segment_pos=}") + + # Detect segment breaks (only for non-zero segments) + segment_changes = jnp.concatenate([ + ((valid_segment_ids[1:] != valid_segment_ids[:-1]) & actual_valid[1:])| # Segment ID changed and not non zero + (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1), # Position not consecutive + jnp.array([True]), # Last valid element ends a segment + ]) + # Use the indices from segment_changes to pick out the offset value (which in turn will be the seq length for that segment) + segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_size, fill_value=-1)[0] + #print(f"{segment_changes_valid=}") + # Remove any + safe_indices = jnp.maximum(segment_changes_valid, 0) + #print(f"{safe_indices=}") + selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, 0) + seqlens 
= jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) + seqlens_cumsum_padded = jnp.cumsum(seqlens) + #print(f"{result=}") + return jnp.count_nonzero(selected_values).astype(int), seqlens[:max_segments_per_seq] + + #QUESTION: Do take a look at other implementations to check if flattening required ? + def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_num_segments, max_segments_per_seq): + # Calculate the segment pos change mask + kv_segment_pos_flat = kv_segment_pos.reshape(-1) + kv_segment_ids_flat = kv_segment_ids.reshape(-1) + kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) + #print(f"{kv_segment_pos_flat=}, {kv_segment_ids_flat=}") + # segment_changes=Array([ True, False, False, False, True, True, False, False, True, + # False, False, False, True, True, True, True], dtype=bool) + # segment_changes_first_false = jnp.concatenate([ + # jnp.array([False]), # Assume valid element starts a segment + # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed + # ]) + segment_changes_first_true = jnp.concatenate([ + jnp.array([True]), # Assume valid element starts a segment + (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed + ]) + #segment_changes = jnp.where(kv_segment_ids_flat[0]==1, segment_changes_first_true, segment_changes_first_true) + #print(f"{segment_changes_first_true=}") + + # Get segment change indices for rank + #print(f"{jnp.size(segment_changes_first_true)=}") + segment_changes_indices = jnp.argwhere(segment_changes_first_true, size=jnp.size(segment_changes_first_true), fill_value=-1).flatten() + #print(f"{segment_changes_indices=}") + # Get segment ids associated with the segment_changes_indices for rank + segment_ids = kv_segment_ids_flat[segment_changes_indices] + #print(f"{segment_ids=}") + + # Get segment change indices for AG + segment_changes_ag_first_true = jnp.concatenate([ + 
jnp.array([True]), # Assume valid element starts a segment + (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed + ]) + #print(f"{segment_changes_ag_first_true=}") + # Get segment change indices for AG + #print(f"{jnp.size(segment_changes_ag_first_true)=}") + segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true, size=jnp.size(segment_changes_ag_first_true), fill_value=-1).flatten() + #print(f"{segment_changes_ag_indices=}") + + # Use the segment ids picked per rank to get the offsets from the AG indices + seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], 0) + #print(f"{seq_offsets=}") + indices = jnp.arange(0, seq_offsets.size) + #print(f"{indices=}") + #print(f"{kv_num_segments=}") + arr = jnp.ones_like(seq_offsets) * seq_offsets[kv_num_segments-1] + #print(f"{arr=}") + seq_offsets_truncated = jnp.where(indices >= kv_num_segments, arr, seq_offsets) + #print(f"{seq_offsets_truncated=}") + return seq_offsets_truncated[:max_segments_per_seq] + #TODO: Come up with better names/organization for the new four functions + # def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_num_segments, max_segments_per_seq): + # # Calculate the segment pos change mask + # kv_segment_pos_flat = kv_segment_pos.reshape(-1) + # segment_changes = jnp.concatenate([ + # jnp.array([True]), # First valid element starts a segment + # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed + # ]) + # #print(f"{segment_changes=}") + + # # Calculate the offsets for the ag array + # kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) + # #print(f"{kv_segment_pos_ag_flat=}") + # segment_changes_ag = jnp.concatenate([ + # (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1), # Segment pos changed + # jnp.array([False]) + # ]) + # # segment_changes_ag = (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos 
changed + # #print(f"{segment_changes_ag=}") + # segment_offsets_ag = jnp.concatenate([jnp.array([0]), kv_segment_pos_ag_flat[segment_changes_ag] + 1]) # First valid element starts a segment + # #print(f"{segment_offsets_ag=}") + # segment_offsets_ag_cumsum = jnp.cumsum(segment_offsets_ag) + # #print(f"{segment_offsets_ag_cumsum=}") + + # # Use the segment_changes mask to find the segment ids where segments changes and then use the segment ids to pick the + # # offset value from the segment_offsets_ag_cumsum + # segment_change_ids = kv_segment_ids[segment_changes] - 1 + # #print(f"{segment_change_ids=}") + # seq_offsets = segment_offsets_ag_cumsum[segment_change_ids] + # seq_offsets_truncate = seq_offsets[:kv_num_segments] + # pad_width = jnp.maximum(0, max_segments_per_seq - seq_offsets_truncate[0].size) + # seq_offsets_padded = jnp.pad(seq_offsets_truncate, (0, pad_width), mode='edge') + # #print(f"{seq_offsets_truncate=}") + # return seq_offsets_padded class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): From 104b51e2ba6313069755e9ac6bda7b437ed74c35 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 20 Nov 2025 22:31:42 -0800 Subject: [PATCH 13/36] Augment jax primitive register code comments Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/cpp_extensions/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index ee62dc9128..5f717df13a 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -193,7 +193,7 @@ def name_of_wrapper_p(): inner_p = core.Primitive(cls.name) dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p.multiple_results = cls.multiple_results - # Define eager execution implementation + # Define eager execution implementation (by invoking it's MLIR lowering) inner_p.def_impl(partial(xla.apply_primitive, inner_p)) 
inner_p.def_abstract_eval(cls.abstract) mlir.register_lowering(inner_p, cls.lowering, platform="cuda") From 24940847ff7c55f0178d994574d0d7a12ce5fde1 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Sat, 22 Nov 2025 17:17:08 -0800 Subject: [PATCH 14/36] Fix the array sizes and padding values returned for seqlens and offsets to fit what the fused attn primitive non cp computation Signed-off-by: Kshitij Lakhani --- .../jax/cpp_extensions/attention.py | 100 +++++++++++------- 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 49610150b1..43bf7eac7d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1619,17 +1619,19 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme length=max_segments_per_seq )[1:] # print(f"{seqlens_all=}") + seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) # Pad 0 at start prior to cumsum - seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) + #seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) # print(f"{seqlens_padded=}") - cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off + #cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off #print(f"{cum_seqlens_padded=}") - return max_new_segments_per_seq, seqlens_padded + return max_new_segments_per_seq, seqlens_all_pad_neg #QUESTION: Do take a look at other implementations to check if flattening required ? - def q_seqoffsets_for_striped_for_rank(self, q_segment_pos, q_num_segments, max_segments_per_seq): + # TODO: q_num_segments not needed + def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): q_segment_pos_flat = q_segment_pos.reshape(-1) # QUESTION: Will this logic be affected if end padding stripes (i.e. 
seg pos =0) are present in between seg pos !=0 # e.g. 01230000124567 @@ -1638,20 +1640,38 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_pos, q_num_segments, max_s (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed ]) #print(f"{segment_changes=}") + # Remove any padded region segment changes + segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) + #print(f"{segment_changes_masked=}") + # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos_flat.shape[0] - seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() + seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq, fill_value=-1).flatten() + #print(f"{seq_offsets_2=}") + #seq_offsets = jnp.where(seq_offsets_2 !=-1, seq_offsets_2, seq_offsets_2[q_num_segments]) + return seq_offsets_2 + #print(f"{seq_offsets=}") + # q_segment_pos_flat = q_segment_pos.reshape(-1) + # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 + # # e.g. 
01230000124567 + # segment_changes = jnp.concatenate([ + # jnp.array([True]), # First valid element starts a segment + # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed + # ]) + # #print(f"{segment_changes=}") + # max_size = q_segment_pos_flat.shape[0] + # seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() - # Create index array (static shape) - seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) - # Create a mask (False: do not clip to the edge element, True: clip to edge element) - mask = seq_offsets_2_indices >= q_num_segments - # Get fill value dynamically by calculating the edge index - edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) - fill_value = seq_offsets_2[edge_index] + # # Create index array (static shape) + # seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) + # # Create a mask (False: do not clip to the edge element, True: clip to edge element) + # mask = seq_offsets_2_indices >= q_num_segments + # # Get fill value dynamically by calculating the edge index + # edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) + # fill_value = seq_offsets_2[edge_index] - seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) + # seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) - return seq_offsets[:max_segments_per_seq] + # return seq_offsets[:max_segments_per_seq] # Per rank! 
@@ -1676,8 +1696,8 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se size=max_size, fill_value=-1 )[0] - valid_segment_ids = kv_segment_ids_flat[non_zero_indices] - valid_segment_pos = kv_segment_pos_flat[non_zero_indices] + valid_segment_ids = jnp.where(non_zero_indices >= 0, kv_segment_ids_flat[non_zero_indices], 0) + valid_segment_pos = jnp.where(non_zero_indices >= 0, kv_segment_pos_flat[non_zero_indices], 0) actual_valid = valid_segment_ids != 0 #print(f"{valid_segment_ids=}, {valid_segment_pos=}") @@ -1685,26 +1705,28 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se segment_changes = jnp.concatenate([ ((valid_segment_ids[1:] != valid_segment_ids[:-1]) & actual_valid[1:])| # Segment ID changed and not non zero (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1), # Position not consecutive - jnp.array([True]), # Last valid element ends a segment + jnp.array([actual_valid[-1]]) # Last valid element ends a segment ]) # Use the indices from segment_changes to pick out the offset value (which in turn will be the seq length for that segment) - segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_size, fill_value=-1)[0] + segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_segments_per_seq, fill_value=-1)[0] #print(f"{segment_changes_valid=}") # Remove any safe_indices = jnp.maximum(segment_changes_valid, 0) #print(f"{safe_indices=}") - selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, 0) - seqlens = jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) - seqlens_cumsum_padded = jnp.cumsum(seqlens) + selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, -1) + # seqlens = jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) + # seqlens_cumsum_padded = jnp.cumsum(seqlens) #print(f"{result=}") - return 
jnp.count_nonzero(selected_values).astype(int), seqlens[:max_segments_per_seq] + return jnp.count_nonzero(selected_values).astype(int), selected_values #QUESTION: Do take a look at other implementations to check if flattening required ? - def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_num_segments, max_segments_per_seq): + # TODO: kv_num_segments not needed + def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, kv_num_segments, max_segments_per_seq): # Calculate the segment pos change mask kv_segment_pos_flat = kv_segment_pos.reshape(-1) kv_segment_ids_flat = kv_segment_ids.reshape(-1) kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) + kv_segment_ids_ag_flat = kv_segment_ids_ag.reshape(-1) #print(f"{kv_segment_pos_flat=}, {kv_segment_ids_flat=}") # segment_changes=Array([ True, False, False, False, True, True, False, False, True, # False, False, False, True, True, True, True], dtype=bool) @@ -1716,15 +1738,16 @@ def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_ jnp.array([True]), # Assume valid element starts a segment (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed ]) + segment_changes_first_true_masked = jnp.where(kv_segment_ids_flat!=0, segment_changes_first_true, False) #segment_changes = jnp.where(kv_segment_ids_flat[0]==1, segment_changes_first_true, segment_changes_first_true) #print(f"{segment_changes_first_true=}") # Get segment change indices for rank #print(f"{jnp.size(segment_changes_first_true)=}") - segment_changes_indices = jnp.argwhere(segment_changes_first_true, size=jnp.size(segment_changes_first_true), fill_value=-1).flatten() + segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq, fill_value=-1).flatten() #print(f"{segment_changes_indices=}") # Get segment ids associated with the segment_changes_indices for rank - 
segment_ids = kv_segment_ids_flat[segment_changes_indices] + segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) #print(f"{segment_ids=}") # Get segment change indices for AG @@ -1732,23 +1755,26 @@ def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_ jnp.array([True]), # Assume valid element starts a segment (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed ]) + segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag_flat!=0, segment_changes_ag_first_true, False) #print(f"{segment_changes_ag_first_true=}") # Get segment change indices for AG #print(f"{jnp.size(segment_changes_ag_first_true)=}") - segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true, size=jnp.size(segment_changes_ag_first_true), fill_value=-1).flatten() + segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true_masked), fill_value=-1).flatten() #print(f"{segment_changes_ag_indices=}") + # Use the segment ids picked per rank to get the offsets from the AG indices - seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], 0) - #print(f"{seq_offsets=}") - indices = jnp.arange(0, seq_offsets.size) - #print(f"{indices=}") - #print(f"{kv_num_segments=}") - arr = jnp.ones_like(seq_offsets) * seq_offsets[kv_num_segments-1] - #print(f"{arr=}") - seq_offsets_truncated = jnp.where(indices >= kv_num_segments, arr, seq_offsets) + seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], -1) + return seq_offsets + # #print(f"{seq_offsets=}") + # indices = jnp.arange(0, seq_offsets.size) + # #print(f"{indices=}") + # #print(f"{kv_num_segments=}") + # arr = jnp.ones_like(seq_offsets) * seq_offsets[kv_num_segments-1] + # #print(f"{arr=}") + # seq_offsets_truncated = jnp.where(indices >= kv_num_segments, arr, seq_offsets) 
#print(f"{seq_offsets_truncated=}") - return seq_offsets_truncated[:max_segments_per_seq] + # return seq_offsets_truncated[:max_segments_per_seq] #TODO: Come up with better names/organization for the new four functions # def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_num_segments, max_segments_per_seq): # # Calculate the segment pos change mask @@ -2165,9 +2191,9 @@ def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(max_seqlen=kv_max_seqlen, cp_size=cp_size) q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq) - q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_ids=_q_segment_ids ,q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) kv_num_segments_for_rank, kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(kv_segment_ids=_kv_segment_ids, kv_segment_pos=_kv_segment_pos, max_segments_per_seq=adjusted_max_segments_per_seq) - kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag,kv_num_segments=kv_num_segments_for_rank, 
max_segments_per_seq=adjusted_max_segments_per_seq) #kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( From c6e596637efa3df9b659e11b445b6538e80394eb Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 24 Nov 2025 12:25:34 -0800 Subject: [PATCH 15/36] Add support in new primitive for softmax_offset related changes. Put in missing primitive registering line in again. Increase the seqoffsets arrays lengths by 1 Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 2 +- transformer_engine/jax/cpp_extensions/attention.py | 13 ++++++++----- transformer_engine/jax/cpp_extensions/base.py | 1 + 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index ed2d11cad6..7d76d05872 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -992,7 +992,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 43bf7eac7d..a861be9731 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1432,6 +1432,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), + softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, 
scaling_factor=self.config.scaling_factor, dropout_probability=self.config.dropout_probability, @@ -1645,7 +1646,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_ #print(f"{segment_changes_masked=}") # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos_flat.shape[0] - seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq, fill_value=-1).flatten() + seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() #print(f"{seq_offsets_2=}") #seq_offsets = jnp.where(seq_offsets_2 !=-1, seq_offsets_2, seq_offsets_2[q_num_segments]) return seq_offsets_2 @@ -1744,7 +1745,7 @@ def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_ # Get segment change indices for rank #print(f"{jnp.size(segment_changes_first_true)=}") - segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq, fill_value=-1).flatten() + segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() #print(f"{segment_changes_indices=}") # Get segment ids associated with the segment_changes_indices for rank segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) @@ -2145,7 +2146,7 @@ def partition(config, mesh, arg_infos, result_infos): mesh, PartitionSpec(get_all_mesh_axes(), None) ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] - arg_shardings[4] = seed_sharding + arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) if DEBUG: @@ -2158,6 +2159,7 @@ def impl( k, v, bias, + softmax_offset, seed, q_seqlen, kv_seqlen, @@ -2182,7 +2184,7 @@ def impl( # Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets # Each rank sees the sharded view for 5 tensors -> q, 
_q_segment_ids, _q_segment_pos, # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in. - def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): + def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function # does not go down that route but instead just picks the seqlens and offsets passed onto it @@ -2201,6 +2203,7 @@ def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): k, #ag v, #ag bias, + softmax_offset, seed, q_seqlens_for_rank, kv_seqlens_for_rank, @@ -2218,7 +2221,7 @@ def _cross_attn(idx, q, k, v, bias, kv_segment_ids_ag, kv_segment_pos_ag, seed): # Only the pos is needed for kv offsets calculation _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) functions = [ - partial(_cross_attn, idx, q, k_ag, v_ag, bias, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) + partial(_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) for idx in range(cp_size) ] diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 5f717df13a..61deab5b80 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -204,6 +204,7 @@ def name_of_wrapper_p(): dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results # Define the eager execution implementation + outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) From 
298ee6b294d2d878dd3d38f65c3bb6b9a30e910e Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Mon, 24 Nov 2025 20:34:11 +0000 Subject: [PATCH 16/36] Add new set of helper functions for seqlens and seqoffsets fo AG+THD+CP+Stripe>1 which accounts for batching and seq offsets size b+1 Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 1 + .../jax/cpp_extensions/attention.py | 503 +++++++++++------- 2 files changed, 323 insertions(+), 181 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3e551ee934..3cf0fed066 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -230,6 +230,7 @@ def make_mask( @jax.jit def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape + #TODO: should this be max_seqlen + 1 ? bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) seqlens = seqlens_with_zero[..., 1:] diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a861be9731..3090f13842 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1568,111 +1568,207 @@ def pad(x, npad): #TODO: max_segments_per_seq - might need some modifications for per rank compute as it won't be the same as the FA packed representation - maybe (max_segments_per_seq_new = seqlens/stripe_height + max_segments_per_seq) #TODO: Do take a look at other implementations to check if flattening required ? - #QUESTION: Do take a look at other implementations to check if flattening required ? 
def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): - q_segment_ids_flat = q_segment_ids.reshape(-1) - q_segment_pos_flat = q_segment_pos.reshape(-1) + #q_segment_ids_flat = q_segment_ids.reshape(-1) + #q_segment_pos_flat = q_segment_pos.reshape(-1) + + # Create mask for non-zero segment IDs + non_zero_mask = q_segment_ids != 0 + print(f"{non_zero_mask=}") + # Calculate indices from mask + max_size = q_segment_ids.shape[-1] + # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + #print(f"{non_zero_indices=}") + + # Pick non zero seg ids and seg pos using take_along_axis + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, + jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), + 0 + ) + valid_segment_pos = jnp.where( + non_zero_indices >= 0, + jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), + 0 + ) + #print(f"{valid_segment_ids=},\n {valid_segment_pos=}") + + # Create mask for actual valid entries (not padding) + actual_valid = valid_segment_ids != 0 + #print(f"{actual_valid=}") + + # Detect segment changes, accounting for padding + # First element is True only if it's actually valid + first_is_segment = actual_valid[..., 0:1] + # Detect segment breaks in the valid tokens only (not full seq) + # Padding will always be true as the segment change condition is being applied + # on the valid segments (which have padding at the end so they'll always trigger True) + segment_changes = jnp.concatenate([ + first_is_segment, # First valid element starts a segment + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | + #((valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) & actual_valid[..., 1:]) + (valid_segment_pos[..., 1:] != 
valid_segment_pos[..., :-1] + 1) + ], axis=-1) + # segment_changes = jnp.concatenate([ + # first_is_segment, # First valid element starts a segment + # ((valid_segment_ids[...,1:] != valid_segment_ids[...,:-1]) | # Segment ID changed + # (valid_segment_pos[...,1:] != valid_segment_pos[...,:-1] + 1)) & # Position not consecutive + # actual_valid_mask[...,1:] # Only consider actually valid positions + # ], axis=-1) + #print(f"{segment_changes=}") + # segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) + # print(f"{segment_changes_masked=}") + + # Create new segment IDs using only valid indices (basically use the non zero indices to index into segment_changes_mask and then do a cumsum) + #new_segment_ids_pre = jax.vmap(lambda nzi_row, scm_row: jnp.where(nzi_row>=0, scm_row[nzi_row], False))(non_zero_indices, segment_changes) + #print(f"{new_segment_ids_pre=}") + new_segment_ids = jnp.cumsum(segment_changes, axis=-1) + #print(f"{new_segment_ids=}") + + # Can't use len() on traced values - use jnp.max instead + # max_new_segments_per_seq = jnp.max(jnp.where(actual_valid, new_segment_ids, 0)) + # print(f"{max_new_segments_per_seq=}") + #max_new_segments_per_seq = 0 #placeholder to be removed later on + + # Use bincount with a safe length + # Add 1 to handle 0-indexing, and ensure it's at least max_segments_per_seq + seqlens_pre = jax.vmap(lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32))(actual_valid, new_segment_ids) + #print(f"{seqlens_pre=}") + #print(f"{seqlens_pre.shape=}") + seqlens_all = jax.vmap(lambda sp_row : jnp.bincount( + sp_row, + length=max_segments_per_seq+1 + )[1:])(seqlens_pre) + #print(f"{seqlens_all=}") + seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) + #print(f"{seqlens_all_pad_neg=}") + + return max_new_segments_per_seq, seqlens_all_pad_neg + #QUESTION: Do take a look at other implementations to check if flattening required ? 
+ # def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + # q_segment_ids_flat = q_segment_ids.reshape(-1) + # q_segment_pos_flat = q_segment_pos.reshape(-1) - # Create mask for non-zero segment IDs - non_zero_mask = q_segment_ids_flat != 0 - # Calculate indices from mask - max_size = q_segment_ids_flat.shape[0] - # Non zero segment id indices followed by padding of -1 at the end to get static size - non_zero_indices = jnp.where( - non_zero_mask, - size=max_size, - fill_value=-1 - )[0] - # print(f"{non_zero_indices=}") - # Pick non zero seg ids and seg pos - valid_segment_ids = jnp.where(non_zero_indices >= 0, q_segment_ids_flat[non_zero_indices], 0) - valid_segment_pos = jnp.where(non_zero_indices >= 0, q_segment_pos_flat[non_zero_indices], 0) - # print(f"{valid_segment_ids=}, {valid_segment_pos=}") + # # Create mask for non-zero segment IDs + # non_zero_mask = q_segment_ids_flat != 0 + # # Calculate indices from mask + # max_size = q_segment_ids_flat.shape[0] + # # Non zero segment id indices followed by padding of -1 at the end to get static size + # non_zero_indices = jnp.where( + # non_zero_mask, + # size=max_size, + # fill_value=-1 + # )[0] + # # print(f"{non_zero_indices=}") + # # Pick non zero seg ids and seg pos + # valid_segment_ids = jnp.where(non_zero_indices >= 0, q_segment_ids_flat[non_zero_indices], 0) + # valid_segment_pos = jnp.where(non_zero_indices >= 0, q_segment_pos_flat[non_zero_indices], 0) + # # print(f"{valid_segment_ids=}, {valid_segment_pos=}") - # Create mask for actual valid entries (not padding) - # All Trues in the beginning for valid segment ids followed by padding of False - actual_valid_mask = valid_segment_ids != 0 - # print(f"{actual_valid_mask=}") + # # Create mask for actual valid entries (not padding) + # # All Trues in the beginning for valid segment ids followed by padding of False + # actual_valid_mask = valid_segment_ids != 0 + # # print(f"{actual_valid_mask=}") - # Detect segment 
changes, accounting for padding - # First element is True only if it's actually valid - first_is_segment = actual_valid_mask[0:1] + # # Detect segment changes, accounting for padding + # # First element is True only if it's actually valid + # first_is_segment = actual_valid_mask[0:1] - segment_changes = jnp.concatenate([ - first_is_segment, # First valid element starts a segment - ((valid_segment_ids[1:] != valid_segment_ids[:-1]) | # Segment ID changed - (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1)) & # Position not consecutive - actual_valid_mask[1:] # Only consider actually valid positions - ]) + # segment_changes = jnp.concatenate([ + # first_is_segment, # First valid element starts a segment + # ((valid_segment_ids[1:] != valid_segment_ids[:-1]) | # Segment ID changed + # (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1)) & # Position not consecutive + # actual_valid_mask[1:] # Only consider actually valid positions + # ]) - # Create new segment IDs - new_segment_ids = jnp.cumsum(segment_changes) - # print(f"{new_segment_ids=}") + # # Create new segment IDs + # new_segment_ids = jnp.cumsum(segment_changes) + # # print(f"{new_segment_ids=}") - # Can't use len() on traced values - use jnp.max instead - max_new_segments_per_seq = jnp.max(jnp.where(actual_valid_mask, new_segment_ids, 0)) - # print(f"{max_new_segments_per_seq=}") + # # Can't use len() on traced values - use jnp.max instead + # max_new_segments_per_seq = jnp.max(jnp.where(actual_valid_mask, new_segment_ids, 0)) + # # print(f"{max_new_segments_per_seq=}") - # Use bincount with a safe length - # Add 1 to handle 0-indexing, and ensure it's at least max_segments_per_seq - seqlens_all = jnp.bincount( - jnp.where(actual_valid_mask, new_segment_ids, 0).astype(jnp.int32), - length=max_segments_per_seq - )[1:] - # print(f"{seqlens_all=}") - seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) + # # Use bincount with a safe length + # # Add 1 to handle 0-indexing, and ensure it's at 
least max_segments_per_seq + # seqlens_all = jnp.bincount( + # jnp.where(actual_valid_mask, new_segment_ids, 0).astype(jnp.int32), + # length=max_segments_per_seq + # )[1:] + # # print(f"{seqlens_all=}") + # seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) - # Pad 0 at start prior to cumsum - #seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) - # print(f"{seqlens_padded=}") - #cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off - #print(f"{cum_seqlens_padded=}") + # # Pad 0 at start prior to cumsum + # #seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) + # # print(f"{seqlens_padded=}") + # #cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off + # #print(f"{cum_seqlens_padded=}") - return max_new_segments_per_seq, seqlens_all_pad_neg + # return max_new_segments_per_seq, seqlens_all_pad_neg #QUESTION: Do take a look at other implementations to check if flattening required ? - # TODO: q_num_segments not needed def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): - q_segment_pos_flat = q_segment_pos.reshape(-1) # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 # e.g. 
01230000124567 segment_changes = jnp.concatenate([ - jnp.array([True]), # First valid element starts a segment - (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed - ]) - #print(f"{segment_changes=}") + jnp.full((q_segment_pos.shape[0], 1), True, dtype=bool), # First valid element starts a segment + (q_segment_pos[...,1:] != q_segment_pos[...,:-1] + 1) # Segment pos changed + ], axis=-1) # Remove any padded region segment changes segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) - #print(f"{segment_changes_masked=}") # Get the indices for segment changes (these are the offsets) - max_size = q_segment_pos_flat.shape[0] - seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() - #print(f"{seq_offsets_2=}") - #seq_offsets = jnp.where(seq_offsets_2 !=-1, seq_offsets_2, seq_offsets_2[q_num_segments]) + max_size = q_segment_pos.shape[-1] + #seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() + seq_offsets_2 = jax.vmap(lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_masked) return seq_offsets_2 - #print(f"{seq_offsets=}") - # q_segment_pos_flat = q_segment_pos.reshape(-1) - # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 - # # e.g. 
01230000124567 - # segment_changes = jnp.concatenate([ - # jnp.array([True]), # First valid element starts a segment - # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed - # ]) - # #print(f"{segment_changes=}") - # max_size = q_segment_pos_flat.shape[0] - # seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() + + # # TODO: q_num_segments not needed + # def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): + # q_segment_pos_flat = q_segment_pos.reshape(-1) + # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 + # # e.g. 01230000124567 + # segment_changes = jnp.concatenate([ + # jnp.array([True]), # First valid element starts a segment + # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed + # ]) + # #print(f"{segment_changes=}") + # # Remove any padded region segment changes + # segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) + # #print(f"{segment_changes_masked=}") + # # Get the indices for segment changes (these are the offsets) + # max_size = q_segment_pos_flat.shape[0] + # seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq, fill_value=-1).flatten() + # #print(f"{seq_offsets_2=}") + # #seq_offsets = jnp.where(seq_offsets_2 !=-1, seq_offsets_2, seq_offsets_2[q_num_segments]) + # return seq_offsets_2 + # #print(f"{seq_offsets=}") + # # q_segment_pos_flat = q_segment_pos.reshape(-1) + # # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 + # # # e.g. 
01230000124567 + # # segment_changes = jnp.concatenate([ + # # jnp.array([True]), # First valid element starts a segment + # # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed + # # ]) + # # #print(f"{segment_changes=}") + # # max_size = q_segment_pos_flat.shape[0] + # # seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() - # # Create index array (static shape) - # seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) - # # Create a mask (False: do not clip to the edge element, True: clip to edge element) - # mask = seq_offsets_2_indices >= q_num_segments - # # Get fill value dynamically by calculating the edge index - # edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) - # fill_value = seq_offsets_2[edge_index] + # # # Create index array (static shape) + # # seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) + # # # Create a mask (False: do not clip to the edge element, True: clip to edge element) + # # mask = seq_offsets_2_indices >= q_num_segments + # # # Get fill value dynamically by calculating the edge index + # # edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) + # # fill_value = seq_offsets_2[edge_index] - # seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) + # # seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) - # return seq_offsets[:max_segments_per_seq] + # # return seq_offsets[:max_segments_per_seq] # Per rank! @@ -1681,92 +1777,171 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_ # i) if same as previous section segment id then add to seqlens counter or ii) if not same as previous section segment id then start seqlens counter # monotonic constraint automatically applies the stripe_height constraint #QUESTION: Do take a look at other implementations to check if flattening required ? 
+ # def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): + # kv_segment_ids_flat = kv_segment_ids.reshape(-1) + # kv_segment_pos_flat = kv_segment_pos.reshape(-1) + # #print(f"{kv_segment_ids_flat=}, {kv_segment_pos_flat=}") + + # # Create mask for non-zero segment IDs + # non_zero_mask = kv_segment_ids_flat != 0 + # #print(f"{non_zero_mask=}") + + # # Filter to only non-zero segments + # max_size = kv_segment_ids_flat.shape[0] + # non_zero_indices = jnp.where( + # non_zero_mask, + # size=max_size, + # fill_value=-1 + # )[0] + # valid_segment_ids = jnp.where(non_zero_indices >= 0, kv_segment_ids_flat[non_zero_indices], 0) + # valid_segment_pos = jnp.where(non_zero_indices >= 0, kv_segment_pos_flat[non_zero_indices], 0) + # actual_valid = valid_segment_ids != 0 + # #print(f"{valid_segment_ids=}, {valid_segment_pos=}") + + # # Detect segment breaks (only for non-zero segments) + # segment_changes = jnp.concatenate([ + # ((valid_segment_ids[1:] != valid_segment_ids[:-1]) & actual_valid[1:])| # Segment ID changed and not non zero + # (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1), # Position not consecutive + # jnp.array([actual_valid[-1]]) # Last valid element ends a segment + # ]) + # # Use the indices from segment_changes to pick out the offset value (which in turn will be the seq length for that segment) + # segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_segments_per_seq, fill_value=-1)[0] + # #print(f"{segment_changes_valid=}") + # # Remove any + # safe_indices = jnp.maximum(segment_changes_valid, 0) + # #print(f"{safe_indices=}") + # selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, -1) + # # seqlens = jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) + # # seqlens_cumsum_padded = jnp.cumsum(seqlens) + # #print(f"{result=}") + # return jnp.count_nonzero(selected_values).astype(int), selected_values def 
kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): - kv_segment_ids_flat = kv_segment_ids.reshape(-1) - kv_segment_pos_flat = kv_segment_pos.reshape(-1) - #print(f"{kv_segment_ids_flat=}, {kv_segment_pos_flat=}") - # Create mask for non-zero segment IDs - non_zero_mask = kv_segment_ids_flat != 0 - #print(f"{non_zero_mask=}") - + non_zero_mask = kv_segment_ids != 0 # Filter to only non-zero segments - max_size = kv_segment_ids_flat.shape[0] - non_zero_indices = jnp.where( - non_zero_mask, - size=max_size, - fill_value=-1 - )[0] - valid_segment_ids = jnp.where(non_zero_indices >= 0, kv_segment_ids_flat[non_zero_indices], 0) - valid_segment_pos = jnp.where(non_zero_indices >= 0, kv_segment_pos_flat[non_zero_indices], 0) + max_size = kv_segment_ids.shape[-1] + # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non zero seg ids and seg pos using take_along_axis + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, + jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), + 0 + ) + valid_segment_pos = jnp.where( + non_zero_indices >= 0, + jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), + 0 + ) actual_valid = valid_segment_ids != 0 - #print(f"{valid_segment_ids=}, {valid_segment_pos=}") + # Detect segment changes, accounting for padding + # First element is True only if it's actually valid + first_is_segment = actual_valid[..., 0:1] # Detect segment breaks (only for non-zero segments) segment_changes = jnp.concatenate([ - ((valid_segment_ids[1:] != valid_segment_ids[:-1]) & actual_valid[1:])| # Segment ID changed and not non zero - (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1), # Position not consecutive - 
jnp.array([actual_valid[-1]]) # Last valid element ends a segment - ]) - # Use the indices from segment_changes to pick out the offset value (which in turn will be the seq length for that segment) - segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_segments_per_seq, fill_value=-1)[0] - #print(f"{segment_changes_valid=}") - # Remove any + ((valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) & actual_valid[..., 1:]) | + (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + actual_valid[..., -1:] + ], axis=-1) + + # Get the indices for segment changes - apply vmap per row + segment_changes_valid = jax.vmap( + lambda sc_row, av_row: jnp.where(sc_row & av_row, size=max_segments_per_seq, fill_value=-1)[0] + )(segment_changes, actual_valid) + # Safe indices safe_indices = jnp.maximum(segment_changes_valid, 0) - #print(f"{safe_indices=}") - selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, -1) - # seqlens = jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) - # seqlens_cumsum_padded = jnp.cumsum(seqlens) - #print(f"{result=}") - return jnp.count_nonzero(selected_values).astype(int), selected_values - - #QUESTION: Do take a look at other implementations to check if flattening required ? 
- # TODO: kv_num_segments not needed + # Select values using take_along_axis per row + selected_values = jnp.where( + segment_changes_valid >= 0, + jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1, + -1 + ) + # Count non-zero per row or total + num_segments = jnp.count_nonzero(selected_values > 0, axis=-1).astype(int) # Per row + return num_segments, selected_values + def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, kv_num_segments, max_segments_per_seq): # Calculate the segment pos change mask - kv_segment_pos_flat = kv_segment_pos.reshape(-1) - kv_segment_ids_flat = kv_segment_ids.reshape(-1) - kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) - kv_segment_ids_ag_flat = kv_segment_ids_ag.reshape(-1) - #print(f"{kv_segment_pos_flat=}, {kv_segment_ids_flat=}") - # segment_changes=Array([ True, False, False, False, True, True, False, False, True, - # False, False, False, True, True, True, True], dtype=bool) - # segment_changes_first_false = jnp.concatenate([ - # jnp.array([False]), # Assume valid element starts a segment - # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed - # ]) segment_changes_first_true = jnp.concatenate([ - jnp.array([True]), # Assume valid element starts a segment - (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed - ]) - segment_changes_first_true_masked = jnp.where(kv_segment_ids_flat!=0, segment_changes_first_true, False) - #segment_changes = jnp.where(kv_segment_ids_flat[0]==1, segment_changes_first_true, segment_changes_first_true) - #print(f"{segment_changes_first_true=}") + jnp.full((kv_segment_pos.shape[0], 1), True, dtype=bool), # Assume valid element starts a segment + (kv_segment_pos[...,1:] != kv_segment_pos[...,:-1] + 1) # Segment pos changed + ], axis=-1) + segment_changes_first_true_masked = jnp.where(kv_segment_ids!=0, segment_changes_first_true, False) # Get segment change indices 
for rank - #print(f"{jnp.size(segment_changes_first_true)=}") - segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() - #print(f"{segment_changes_indices=}") + #segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() + segment_changes_indices = jax.vmap(lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_first_true_masked) # Get segment ids associated with the segment_changes_indices for rank - segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) - #print(f"{segment_ids=}") + #segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) + segment_ids = jax.vmap(lambda sci_row, ksi_row: jnp.where(sci_row>=0, ksi_row[sci_row], -1))(segment_changes_indices, kv_segment_ids) # Get segment change indices for AG segment_changes_ag_first_true = jnp.concatenate([ - jnp.array([True]), # Assume valid element starts a segment - (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed - ]) - segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag_flat!=0, segment_changes_ag_first_true, False) - #print(f"{segment_changes_ag_first_true=}") + jnp.full((kv_segment_pos.shape[0], 1), True, dtype=bool), # Assume valid element starts a segment + (kv_segment_pos_ag[...,1:] != kv_segment_pos_ag[...,:-1] + 1) # Segment pos changed + ], axis=-1) + segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag!=0, segment_changes_ag_first_true, False) # Get segment change indices for AG - #print(f"{jnp.size(segment_changes_ag_first_true)=}") - segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true_masked), fill_value=-1).flatten() - #print(f"{segment_changes_ag_indices=}") - + 
#segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true), fill_value=-1).flatten() + segment_changes_ag_indices = jax.vmap(lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_ag_first_true_masked) # Use the segment ids picked per rank to get the offsets from the AG indices - seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], -1) + seq_offsets = jax.vmap(lambda si_row, sca_row: jnp.where(si_row>0, sca_row[si_row-1], -1))(segment_ids, segment_changes_ag_indices) return seq_offsets + + #QUESTION: Do take a look at other implementations to check if flattening required ? + # TODO: kv_num_segments not needed + # def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, kv_num_segments, max_segments_per_seq): + # # Calculate the segment pos change mask + # kv_segment_pos_flat = kv_segment_pos.reshape(-1) + # kv_segment_ids_flat = kv_segment_ids.reshape(-1) + # kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) + # kv_segment_ids_ag_flat = kv_segment_ids_ag.reshape(-1) + # #print(f"{kv_segment_pos_flat=}, {kv_segment_ids_flat=}") + # # segment_changes=Array([ True, False, False, False, True, True, False, False, True, + # # False, False, False, True, True, True, True], dtype=bool) + # # segment_changes_first_false = jnp.concatenate([ + # # jnp.array([False]), # Assume valid element starts a segment + # # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed + # # ]) + # segment_changes_first_true = jnp.concatenate([ + # jnp.array([True]), # Assume valid element starts a segment + # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed + # ]) + # segment_changes_first_true_masked = jnp.where(kv_segment_ids_flat!=0, segment_changes_first_true, False) + # #segment_changes = jnp.where(kv_segment_ids_flat[0]==1, 
segment_changes_first_true, segment_changes_first_true) + # #print(f"{segment_changes_first_true=}") + + # # Get segment change indices for rank + # #print(f"{jnp.size(segment_changes_first_true)=}") + # segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq, fill_value=-1).flatten() + # #print(f"{segment_changes_indices=}") + # # Get segment ids associated with the segment_changes_indices for rank + # segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) + # #print(f"{segment_ids=}") + + # # Get segment change indices for AG + # segment_changes_ag_first_true = jnp.concatenate([ + # jnp.array([True]), # Assume valid element starts a segment + # (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed + # ]) + # segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag_flat!=0, segment_changes_ag_first_true, False) + # #print(f"{segment_changes_ag_first_true=}") + # # Get segment change indices for AG + # #print(f"{jnp.size(segment_changes_ag_first_true)=}") + # segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true_masked), fill_value=-1).flatten() + # #print(f"{segment_changes_ag_indices=}") + + + # # Use the segment ids picked per rank to get the offsets from the AG indices + # seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], -1) + # return seq_offsets # #print(f"{seq_offsets=}") # indices = jnp.arange(0, seq_offsets.size) # #print(f"{indices=}") @@ -1776,40 +1951,6 @@ def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_ # seq_offsets_truncated = jnp.where(indices >= kv_num_segments, arr, seq_offsets) #print(f"{seq_offsets_truncated=}") # return seq_offsets_truncated[:max_segments_per_seq] - #TODO: Come up with better names/organization for the new four functions - # def 
kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_num_segments, max_segments_per_seq): - # # Calculate the segment pos change mask - # kv_segment_pos_flat = kv_segment_pos.reshape(-1) - # segment_changes = jnp.concatenate([ - # jnp.array([True]), # First valid element starts a segment - # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed - # ]) - # #print(f"{segment_changes=}") - - # # Calculate the offsets for the ag array - # kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) - # #print(f"{kv_segment_pos_ag_flat=}") - # segment_changes_ag = jnp.concatenate([ - # (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1), # Segment pos changed - # jnp.array([False]) - # ]) - # # segment_changes_ag = (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed - # #print(f"{segment_changes_ag=}") - # segment_offsets_ag = jnp.concatenate([jnp.array([0]), kv_segment_pos_ag_flat[segment_changes_ag] + 1]) # First valid element starts a segment - # #print(f"{segment_offsets_ag=}") - # segment_offsets_ag_cumsum = jnp.cumsum(segment_offsets_ag) - # #print(f"{segment_offsets_ag_cumsum=}") - - # # Use the segment_changes mask to find the segment ids where segments changes and then use the segment ids to pick the - # # offset value from the segment_offsets_ag_cumsum - # segment_change_ids = kv_segment_ids[segment_changes] - 1 - # #print(f"{segment_change_ids=}") - # seq_offsets = segment_offsets_ag_cumsum[segment_change_ids] - # seq_offsets_truncate = seq_offsets[:kv_num_segments] - # pad_width = jnp.maximum(0, max_segments_per_seq - seq_offsets_truncate[0].size) - # seq_offsets_padded = jnp.pad(seq_offsets_truncate, (0, pad_width), mode='edge') - # #print(f"{seq_offsets_truncate=}") - # return seq_offsets_padded class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): From 1fa57b433821ccc1bde0053906d7b14fcb3ccf7d Mon Sep 17 00:00:00 2001 From: Kshitij 
Lakhani Date: Tue, 25 Nov 2025 06:16:02 +0000 Subject: [PATCH 17/36] Add backward primitive for CP+THD+AG+Striped>1 Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 2 + .../jax/cpp_extensions/attention.py | 185 +++++++++++++++++- 2 files changed, 182 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 7d76d05872..a2d9756580 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1109,6 +1109,7 @@ def _fused_attn_bwd_rule( context_checkpoint_name, ctx, dz, + stripe_height ): del context_checkpoint_name ( @@ -1141,6 +1142,7 @@ def _fused_attn_bwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + stripe_height=stripe_height ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 3090f13842..1ddeb41841 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1493,7 +1493,10 @@ def reduce_scatter_dkv(self, dk, dv): def rs(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) + if self.config.qkv_layout.is_thd(): + x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_height) + else: + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) return lax_paral_op( x, @@ -2359,7 +2362,6 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segmen return output, softmax_aux, rng_state k_ag, v_ag = helper.all_gather_kv(k, v) - # Only the pos is needed for kv offsets calculation _kv_segment_ids_ag, _kv_segment_pos_ag = 
helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) functions = [ partial(_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) @@ -2373,6 +2375,175 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segmen register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive) +class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism and Striped Load Balancing Primitive. + + This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks. + The gradients are subsequently reduce-scattered back to each context parallel rank. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unnecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + # Ensure we can support this configuration with context parallelism. 
+ helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + #TODO: Confirm the deletion + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + softmax_offset_spec = get_padded_spec(arg_infos[4]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = ( + dq_sharding, + dk_sharding, + dv_sharding, + dbias_sharding, + dsoftmax_offset_sharding, + ) + + def impl( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. 
+ def _cross_attn_bwd( + idx, + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive + # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function + # does not go down that route but instead just picks the seqlens and offsets passed onto it + + kv_max_seqlen = k.shape[1] + # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(max_seqlen=kv_max_seqlen, cp_size=cp_size) + q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_ids=_q_segment_ids ,q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + kv_num_segments_for_rank, kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(kv_segment_ids=_kv_segment_ids, kv_segment_pos=_kv_segment_pos, max_segments_per_seq=adjusted_max_segments_per_seq) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag,kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + #kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + + dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl( + q, #sharded for rank + k, #ag + v, #ag + bias, + softmax_offset, 
+ softmax_aux, + rng_state, + output, + doutput, + q_seqlens_for_rank, + kv_seqlens_for_rank, + q_seq_offsets_for_rank, + kv_seq_offsets_for_rank, + q_seqlen, #Should be empty ids but using placeholder + kv_seqlen, #Should be empty poss but using placeholder + q_seq_offsets, #Should be empty ids but using placeholder + k_seq_offsets, #Should be empty pos but using placeholder + config=helper.get_step_config_for_striped(max_seqlen=kv_max_seqlen, cp_size=cp_size), + ) + + # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. + # if config.attn_mask_type != AttnMaskType.NO_MASK: + # pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] + # dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) + + # results.append((dq_local, dk_local, dv_local, dbias_local)) + + # dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1) + # dk_local_pad = results[0][1] + results[1][1] + # dv_local_pad = results[0][2] + results[1][2] + # return dq_local, dk_local_pad, dv_local_pad, results[1][3] + return dq_local, dk_local, dv_local, dbias_local + + k_ag, v_ag = helper.all_gather_kv(k, v) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) + + functions = [ + partial( + _cross_attn_bwd, + idx, + q, + k_ag, + v_ag, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids_ag, + _q_segment_pos, + _kv_segment_pos_ag, + ) + for idx in range(cp_size) + ] + + dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) + + # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it) + dummy_dsoftmax_offset = jnp.empty_like(softmax_offset) + return dq, dk, dv, dbias, dummy_dsoftmax_offset + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive) + 
@dataclass(frozen=True) class _FusedAttnCPWithP2PHelper: """Helper class to assist with running the P2P ring strategy for CP attention.""" @@ -3501,6 +3672,7 @@ def fused_attn_bwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + stripe_height: int = 0, ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -3540,6 +3712,7 @@ def fused_attn_bwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + stripe_height (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -3599,7 +3772,6 @@ def fused_attn_bwd( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" - #TODO: stripe_height hardcoded for now as bwd is not being tests fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -3613,13 +3785,16 @@ def fused_attn_bwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, - stripe_height=0, + stripe_height=stripe_height, ) primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedAttnCPStripedWithAllGatherBwdPrimitive.outer_primitive + else: + primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: if qkv_layout.is_thd(): primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive From 
7f205d5325e818dd52abd8d94638ae402323838a Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 06:16:59 +0000 Subject: [PATCH 18/36] Modify tests for backward primitive for CP+THD+AG+Striped>1 Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 4 ++-- tests/jax/test_fused_attn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 9fa327d777..fdb68e6667 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -454,8 +454,8 @@ def check_has_backend_for_mask(mask_type): pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") #KL code - #runner.test_backward() - runner.test_forward() + runner.test_backward() + #runner.test_forward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] @pytest_parametrize_wrapper( diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3cf0fed066..07aeca69a2 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -925,7 +925,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, - #"stripe_height": self.stripe_height, + "stripe_height": self.stripe_height, } # We can compute dBias only for the [1, h, s, s] layout From 3115064659e168f2b7680283952544a24d26d04e Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 24 Nov 2025 22:57:05 -0800 Subject: [PATCH 19/36] Move stripe_height along with other static args in fused_attn_bwd rule. 
Fix typo in CP+AG+TH+Striped>1 primitive Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 2 +- transformer_engine/jax/cpp_extensions/attention.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index a2d9756580..ea0e225815 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1107,9 +1107,9 @@ def _fused_attn_bwd_rule( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name, + stripe_height, ctx, dz, - stripe_height ): del context_checkpoint_name ( diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1ddeb41841..05ab31fe2e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1636,7 +1636,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme # Can't use len() on traced values - use jnp.max instead # max_new_segments_per_seq = jnp.max(jnp.where(actual_valid, new_segment_ids, 0)) # print(f"{max_new_segments_per_seq=}") - #max_new_segments_per_seq = 0 #placeholder to be removed later on + max_new_segments_per_seq = 0 #placeholder to be removed later on # Use bincount with a safe length # Add 1 to handle 0-indexing, and ensure it's at least max_segments_per_seq @@ -2455,9 +2455,9 @@ def _cross_attn_bwd( q_seqlen, kv_seqlen, _q_segment_ids, - _kv_segment_ids, + kv_segment_ids_ag, _q_segment_pos, - _kv_segment_pos, + kv_segment_pos_ag, ): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function From 0391f41a2e09c74159bcd0413e2372e81ad6b724 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> 
Date: Tue, 25 Nov 2025 06:58:09 +0000 Subject: [PATCH 20/36] Code clean up: remove older version for calculating seqlens and offsets for CP+AG+THD+striped>1 primitive Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 12 +- tests/jax/test_fused_attn.py | 26 +- tests/jax/utils.py | 5 +- .../fused_attn_f16_arbitrary_seqlen.cu | 111 +-- transformer_engine/jax/attention.py | 20 +- .../jax/cpp_extensions/attention.py | 641 +++++++----------- 6 files changed, 350 insertions(+), 465 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index fdb68e6667..e0d60e06c6 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -328,7 +328,7 @@ def test_cross_attn( DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes. - #TODO: Change the id to CPx2 + # TODO: Change the id to CPx2 pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), # KL test code @@ -413,7 +413,7 @@ def impl_test_context_parallel_attn( mesh_resource=mesh_resource, cp_strategy=cp_strategy, cp_load_balanced=load_balanced, - stripe_height=stripe_height + stripe_height=stripe_height, ) def check_has_backend_for_mask(mask_type): @@ -453,9 +453,9 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - #KL code + # KL code runner.test_backward() - #runner.test_forward() + # runner.test_forward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] @pytest_parametrize_wrapper( @@ -654,6 +654,7 @@ def test_context_parallel_ring_attn_shardy( pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"), ] + class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest_parametrize_wrapper("shape", 
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) @@ -671,10 +672,9 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): if reorder_strategy == ReorderStrategy.Striped: seq_lens = shape[seq_dim] - if seq_lens < (cp_size*stripe_height): + if seq_lens < (cp_size * stripe_height): pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_height=}") - ref = tensor.copy() reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 07aeca69a2..6bdbd16d95 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -230,7 +230,7 @@ def make_mask( @jax.jit def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape - #TODO: should this be max_seqlen + 1 ? + # TODO: should this be max_seqlen + 1 ? bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) seqlens = seqlens_with_zero[..., 1:] @@ -501,12 +501,16 @@ def _setup_inputs(self): token_numbers_k = range(self.max_seqlen_kv) for batch_idx in range(q_shape[0]): for token_idx in token_numbers_q: - q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * (token_idx + 1) + q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * ( + token_idx + 1 + ) for token_idx in token_numbers_k: - k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt(self.head_dim_qk) + k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt( + self.head_dim_qk + ) v_np = np.ones(v_shape, self.dtype) # Set cols at multiples - v_np[0,::4, 0, :] = np.arange(v_np.shape[3]) + v_np[0, ::4, 0, :] = np.arange(v_np.shape[3]) self.q = jnp.array(q_np) self.k = jnp.array(k_np) self.v = jnp.array(v_np) @@ -575,7 +579,7 @@ def generate_random_segment_ids( min_segment_size = 1 if min_segment_len is not None: min_segment_size = min_segment_len[i][seg_id] - 
#KL test code + # KL test code min_segment_size = 4 segment_size = rng.integers(min_segment_size, max_segment_size + 1) if current_pos + segment_size > sequence_length: @@ -632,8 +636,16 @@ def generate_random_segment_ids( ) self.segment_pos_q = self.segment_pos_kv = None self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None - print(f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q: {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q: {self.seqlens_q}, \n self.offsets_q: { self.offsets_q} \n") - print(f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv: {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv: {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n") + print( + f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q:" + f" {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q:" + f" {self.seqlens_q}, \n self.offsets_q: { self.offsets_q} \n" + ) + print( + f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv:" + f" {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv:" + f" {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n" + ) # For reference code self.mask = make_mask( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 6a45ab437a..b5170eacc3 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1509,17 +1509,18 @@ def assert_allclose( desired = desired.astype(jnp.float32) # KL test code import sys + mismatch_counter = 0 has_nonzero = jnp.any(actual != 0) print(f"has_nonzero: {has_nonzero}") with np.printoptions(threshold=sys.maxsize): - mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch + mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch diff_indices = np.argwhere(mismatch_mask) seq_set = set() for idx in diff_indices: idx_tuple = tuple(idx) seq_set.add(idx_tuple[1]) - mismatch_counter += 1 + mismatch_counter += 1 if 
mismatch_counter < 1024: print(f"Index {idx_tuple}: a={actual[idx_tuple]}, d={desired[idx_tuple]}") print(f"{sorted(seq_set)}") diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 5056942d1a..bd0702125b 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -49,7 +49,9 @@ namespace transformer_engine { namespace fused_attn { template -__global__ void print_tensor_elements_2(const T *const data, const size_t rows, const size_t start_cols, const size_t end_cols, const size_t cols) { +__global__ void print_tensor_elements_2(const T *const data, const size_t rows, + const size_t start_cols, const size_t end_cols, + const size_t cols) { if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { for (size_t i = 0; i < rows; ++i) { for (size_t j = start_cols; j < end_cols; ++j) { @@ -487,45 +489,46 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); - //std::cout << "print_tensors: " << print_tensors << + //std::cout << "print_tensors: " << print_tensors << // "print_tensors_custom_mask: " // << print_tensors_custom_mask << std::endl; - if (print_tensors) - { - if(devPtrCuSeqlensQ) { - if(print_tensors_custom_mask) - { - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, 8, /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, - 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensQ), 1, - 8184, 8192, - /*does not matter for single row*/ actual_b); - } - else - { - print_tensor_elements_2<<<1, 1, 0, 
stream>>>(static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + if (print_tensors) { + if (devPtrCuSeqlensQ) { + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensQ), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensQ), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensQ), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); } } if (devActualSeqlenQ) { - if (print_tensors_custom_mask) - { - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, 8, /*does not matter for single row*/actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, - 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, - 8184, 8192, - /*does not matter for single row*/ actual_b); - } - else { - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b); + if (print_tensors_custom_mask) { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenQ), 1, 0, 8, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenQ), 1, 1024, 1032, + /*does not matter for single row*/ actual_b); + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenQ), 1, 8184, 8192, + /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenQ), 1, 0, actual_b, + /*does not matter for 
single row*/ actual_b); } } - if(devPtrCuSeqlensKV) { - if(print_tensors_custom_mask) - { + if (devPtrCuSeqlensKV) { + if (print_tensors_custom_mask) { print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrCuSeqlensKV), 1, 0, 8, /*does not matter for single row*/ actual_b); @@ -535,14 +538,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devPtrCuSeqlensKV), 1, 8184, 8192, /*does not matter for single row*/ actual_b); - } - else { - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); } } - if(devActualSeqlenKV) { - if (print_tensors_custom_mask) - { + if (devActualSeqlenKV) { + if (print_tensors_custom_mask) { print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devActualSeqlenKV), 1, 0, 8, /*does not matter for single row*/ actual_b); @@ -552,10 +555,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( print_tensor_elements_2<<<1, 1, 0, stream>>>( static_cast(devActualSeqlenKV), 1, 8184, 8192, /*does not matter for single row*/ actual_b); - } - else - { - print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b); + } else { + print_tensor_elements_2<<<1, 1, 0, stream>>>( + static_cast(devActualSeqlenKV), 1, 0, actual_b, + /*does not matter for single row*/ actual_b); } } } @@ -668,18 +671,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } } } - if (is_ragged_q) { - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_o] = devOffsetsO; - } - if (is_ragged_kv) { - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; - } - if (is_ragged_q && cudnn_runtime_version >= 90600) { - variant_pack[offset_stats] = devOffsetsS; - } + if (is_ragged_q) { + 
variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; } + } if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index ea0e225815..ac5c155baf 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -386,7 +386,9 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1): +def reorder_causal_load_balancing( + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 +): """Reorders a tensor for load balancing the compute of causal attention.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) @@ -396,7 +398,7 @@ def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: in def inverse_reorder_causal_load_balancing( - tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 ): """Inverse operation of `reorder_causal_load_balancing`.""" if strategy == ReorderStrategy.DualChunkSwap: @@ -536,7 +538,7 @@ def _segment_ids_pos_to_seqlens_offsets( # It does not need to involve SW for this mask's creation # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - #TODO: Un comment the fast path + # TODO: Un comment the fast path # if (attn_mask_type.is_causal() and window_size is None) or ( # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() # ): @@ -555,7 +557,7 @@ def 
_segment_ids_pos_to_seqlens_offsets( segment_ids_kv, lambda x, y: jnp.equal(x, y) * x, ) - #jax.debug.breakpoint() + # jax.debug.breakpoint() # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied attn_mask = segment_mask if attn_mask_type.is_bottom_right(): @@ -602,7 +604,7 @@ def _segment_ids_pos_to_seqlens_offsets( q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( attn_mask_with_id, max_segments_per_seq ) - #jax.debug.breakpoint() + # jax.debug.breakpoint() return q_seqlen, kv_seqlen, q_offset, kv_offset @@ -682,7 +684,7 @@ def get_seqlens_and_offsets( window_size, max_segments_per_seq, ) - #jax.debug.breakpoint() + # jax.debug.breakpoint() else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, @@ -1077,7 +1079,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, - stripe_height=stripe_height + stripe_height=stripe_height, ) output = checkpoint_name(output, context_checkpoint_name) softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) @@ -1142,7 +1144,7 @@ def _fused_attn_bwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, - stripe_height=stripe_height + stripe_height=stripe_height, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -1217,7 +1219,7 @@ def fused_attn( softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX. If provided, this parameter will receive gradients during backpropagation. - stripe_height (int): + stripe_height (int): Indicates the striping height to be used when using ReorderStrategy.Striped. 
Currently, a stripe_height > 1 is only allowed for CP + THD + Striped + AG 0 indicates no striping strategy diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 05ab31fe2e..67a4a1be99 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -94,7 +94,7 @@ class _FusedAttnConfig: context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA - stripe_height: int # Only for CP + Striped. For, Ring P2P , stripe_height=1 only. + stripe_height: int # Only for CP + Striped. For, Ring P2P , stripe_height=1 only. @dataclass(frozen=True) @@ -521,35 +521,35 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - DEBUG = True #os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" + DEBUG = True # os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" # if DEBUG: # jax.debug.print("FusedAttnFwdPrimitive.impl CALLED") - # jax.debug.print("Config: qkv_layout={}, attn_mask_type={}", + # jax.debug.print("Config: qkv_layout={}, attn_mask_type={}", # str(config.qkv_layout), str(config.attn_mask_type)) # jax.debug.print("Input shapes:") # jax.debug.print(" q={}, k={}, v={}", q.shape, k.shape, v.shape) # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) - + # def print_impl_inputs(q_val, k_val, v_val, q_seq, kv_seq, q_off, k_off): # print(f"\n~~~ FusedAttnFwdPrimitive.impl INPUTS ~~~") # print(f"Q: shape={q_val.shape}, mean={q_val.mean():.6f}, std={q_val.std():.6f}") # print(f" First 5: {q_val.flatten()[:5]}") - + # print(f"K: shape={k_val.shape}, mean={k_val.mean():.6f}, std={k_val.std():.6f}") # print(f" First 5: {k_val.flatten()[:5]}") - + # print(f"V: shape={v_val.shape}, mean={v_val.mean():.6f}, std={v_val.std():.6f}") # print(f" First 5: {v_val.flatten()[:5]}") - + # print(f"\nSequence info:") # print(f" q_seqlen: {q_seq}") # print(f" kv_seqlen: {kv_seq}") 
# print(f" q_seq_offsets: {q_off}") # print(f" k_seq_offsets: {k_off}") # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - + # jax.debug.callback( - # print_impl_inputs, + # print_impl_inputs, # q, k, v, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets # ) assert FusedAttnFwdPrimitive.inner_primitive is not None @@ -571,7 +571,7 @@ def impl( # if DEBUG: # jax.debug.print("After sequence_descriptor processing:") # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) - + # def print_seq_descriptor(q_seq, kv_seq, q_off, k_off): # print(f"\n~~~ SEQUENCE DESCRIPTOR OUTPUTS ~~~") # print(f"q_seqlen (processed): {q_seq}") @@ -579,7 +579,7 @@ def impl( # print(f"q_seq_offsets (processed): {q_off}") # print(f"k_seq_offsets (processed): {k_off}") # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - + # jax.debug.callback(print_seq_descriptor, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) # jax.debug.print("Hello FA impl") if config.qkv_layout.is_thd(): @@ -609,7 +609,7 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_batch = q_batch = batch[0] # if DEBUG: - # jax.debug.print(" batch={}, q_max_seqlen={}, kv_max_seqlen={}", + # jax.debug.print(" batch={}, q_max_seqlen={}, kv_max_seqlen={}", # q_batch, q_max_seqlen, kv_max_seqlen) # Gather valid q_seqlen, which is greater than 0 @@ -648,7 +648,7 @@ def convert_to_2d(offsets, batch, max_seqlen): # print(f"q_seq_offsets (2d): {q_off}") # print(f"k_seq_offsets (2d): {k_off}") # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - + # jax.debug.callback(print_thd_processing, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) @@ -658,19 +658,18 @@ def convert_to_2d(offsets, batch, max_seqlen): # jax.debug.print("Generated cumulative sequence lengths:") # jax.debug.print(" q_cu_seqlen={}", q_cu_seqlen.shape) # jax.debug.print(" kv_cu_seqlen={}", kv_cu_seqlen.shape) - + # def print_cu_seqlen(q_cu, kv_cu): # print(f"\n~~~ CUMULATIVE SEQLENS ~~~") # 
print(f"q_cu_seqlen: {q_cu}") # print(f"kv_cu_seqlen: {kv_cu}") # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - + # jax.debug.callback(print_cu_seqlen, q_cu_seqlen, kv_cu_seqlen) # if DEBUG: # jax.debug.print("Calling inner_primitive.bind...") - output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( q, k, @@ -1312,10 +1311,12 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig return combined.reshape(ori_tensor_shape) -def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_height:int = 1): +def reorder_causal_striped( + tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_height: int = 1 +): """Reorders a tensor for load balancing with striped pattern""" origin_shape = tensor.shape - if origin_shape[seq_dim] % (cp_size*stripe_height) != 0: + if origin_shape[seq_dim] % (cp_size * stripe_height) != 0: raise ValueError( "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_height but got" f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_height=}, {cp_size*stripe_height=}" @@ -1324,13 +1325,13 @@ def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool, if not is_inverse: new_shape = [ *origin_shape[:seq_dim], - *[origin_shape[seq_dim] // (cp_size*stripe_height), cp_size, stripe_height], + *[origin_shape[seq_dim] // (cp_size * stripe_height), cp_size, stripe_height], *origin_shape[seq_dim + 1 :], ] else: new_shape = [ *origin_shape[:seq_dim], - *[cp_size, origin_shape[seq_dim] // (cp_size*stripe_height), stripe_height], + *[cp_size, origin_shape[seq_dim] // (cp_size * stripe_height), stripe_height], *origin_shape[seq_dim + 1 :], ] @@ -1350,23 +1351,31 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused attention" - allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] + allowed_layouts = [ + 
QKVLayout.BSHD_BS2HD,
+            QKVLayout.BSHD_BSHD_BSHD,
+            QKVLayout.THD_T2HD,
+            QKVLayout.THD_THD_THD,
+        ]
         if self.config.qkv_layout not in allowed_layouts:
             raise ValueError(
                 f"{header} only supports layouts:"
                 f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
             )
-
-        if (not self.config.qkv_layout.is_thd() and self.config.stripe_height != 0) or (self.config.qkv_layout.is_thd() and self.config.stripe_height == 0):
+
+        if (not self.config.qkv_layout.is_thd() and self.config.stripe_height != 0) or (
+            self.config.qkv_layout.is_thd() and self.config.stripe_height == 0
+        ):
             raise ValueError(
-                f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped load balancing with THD layouts"
+                f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped"
+                " load balancing with THD layouts"
             )
-
+
         if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
             raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
-
-        #TODO: Should AttnMaskType.PADDING_CAUSAL_MASK be allowed for CP + AG + THD + Striped ?
-        #TODO: Should Should AttnMaskType.NO_MASK be allowed for CP + AG + THD + Striped ?
+
+        # TODO: Should AttnMaskType.PADDING_CAUSAL_MASK be allowed for CP + AG + THD + Striped ?
+        # TODO: Should AttnMaskType.NO_MASK be allowed for CP + AG + THD + Striped ?
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
         if self.config.qkv_layout.is_thd():
             allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK)
@@ -1375,11 +1384,9 @@ def check_supported(self):
                 f"{header} only supports masking types: "
                 f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
             )
-        #TODO: For now do not all CP + AG + THD + Striped with NO_MASK
+        # TODO: For now do not allow CP + AG + THD + Striped with NO_MASK
         if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd():
-            raise ValueError(
-                f"{header} only supports CAUSAL_MASK for THD types"
-            )
+            raise ValueError(f"{header} only supports CAUSAL_MASK for THD types")
 
         if self.config.max_segments_per_seq != 1 and (not self.config.qkv_layout.is_thd):
             raise ValueError(
@@ -1397,19 +1404,27 @@ def check_supported(self):
 
     def get_adjusted_mask(self):
         """Converts the mask for context parallelism."""
-        if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK and not self.config.qkv_layout.is_thd(): # BSHD only ?
+        if (
+            self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK
+            and not self.config.qkv_layout.is_thd()
+        ):  # BSHD only ?
             return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
-        if self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd(): # THD only ?
+        if (
+            self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
+            and self.config.qkv_layout.is_thd()
+        ):  # THD only ?
return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type - + def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): - # Estimating - return (max_seqlen // (self.config.stripe_height*cp_size)) + self.config.max_segments_per_seq + # Estimating + return ( + max_seqlen // (self.config.stripe_height * cp_size) + ) + self.config.max_segments_per_seq def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" - #TODO: Should the max_segments_per_seq be different ? + # TODO: Should the max_segments_per_seq be different ? return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), @@ -1425,10 +1440,10 @@ def get_step_config(self) -> _FusedAttnConfig: cp_striped_window_size=None, stripe_height=self.config.stripe_height, ) - + def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" - #TODO: Should the max_segments_per_seq be different ? + # TODO: Should the max_segments_per_seq be different ? return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), @@ -1466,26 +1481,30 @@ def ag(x): return ag(k), ag(v) return k, v # fall through - + def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): """Performs a all-gather of k and v over context parallel ranks.""" - #TODO: Is the axis chosen right ? + # TODO: Is the axis chosen right ? 
kv_segment_ids = lax_paral_op( kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) kv_segment_pos = lax_paral_op( kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) - #jax.debug.breakpoint() + # jax.debug.breakpoint() if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if self.config.qkv_layout.is_thd(): - kv_segment_ids_ag = reorder_causal_striped(kv_segment_ids, cp_size, 1, True, self.config.stripe_height) - kv_segment_pos_ag = reorder_causal_striped(kv_segment_pos, cp_size, 1, True, self.config.stripe_height) + kv_segment_ids_ag = reorder_causal_striped( + kv_segment_ids, cp_size, 1, True, self.config.stripe_height + ) + kv_segment_pos_ag = reorder_causal_striped( + kv_segment_pos, cp_size, 1, True, self.config.stripe_height + ) return kv_segment_ids_ag, kv_segment_pos_ag - #TODO: Is the dual chunk case needed ? - return kv_segment_ids, kv_segment_pos # fall through + # TODO: Is the dual chunk case needed ? + return kv_segment_ids, kv_segment_pos # fall through def reduce_scatter_dkv(self, dk, dv): """Performs a reduce-scatter of dk and dv over context parallel ranks.""" @@ -1569,255 +1588,62 @@ def pad(x, npad): return dk, dv # fall through - #TODO: max_segments_per_seq - might need some modifications for per rank compute as it won't be the same as the FA packed representation - maybe (max_segments_per_seq_new = seqlens/stripe_height + max_segments_per_seq) - #TODO: Do take a look at other implementations to check if flattening required ? 
def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): - #q_segment_ids_flat = q_segment_ids.reshape(-1) - #q_segment_pos_flat = q_segment_pos.reshape(-1) - - # Create mask for non-zero segment IDs + # Create mask for non-zero segment IDs non_zero_mask = q_segment_ids != 0 - print(f"{non_zero_mask=}") - # Calculate indices from mask + # Calculate indices from mask max_size = q_segment_ids.shape[-1] - # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) + # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) non_zero_indices = jax.vmap( lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] )(non_zero_mask) - #print(f"{non_zero_indices=}") - # Pick non zero seg ids and seg pos using take_along_axis - # Clip -1 to 0 for safe indexing + # Pick non zero seg ids and seg pos using take_along_axis + # Clip -1 to 0 for safe indexing clipped_indices = jnp.clip(non_zero_indices, 0, None) valid_segment_ids = jnp.where( - non_zero_indices >= 0, - jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), - 0 + non_zero_indices >= 0, jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), 0 ) valid_segment_pos = jnp.where( - non_zero_indices >= 0, - jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), - 0 + non_zero_indices >= 0, jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), 0 ) - #print(f"{valid_segment_ids=},\n {valid_segment_pos=}") - # Create mask for actual valid entries (not padding) + # Create mask for actual valid entries (not padding) actual_valid = valid_segment_ids != 0 - #print(f"{actual_valid=}") - # Detect segment changes, accounting for padding - # First element is True only if it's actually valid + # Detect segment changes, accounting for padding + # First element is True only if it's actually valid first_is_segment = actual_valid[..., 0:1] - # Detect segment breaks in the valid tokens 
only (not full seq) - # Padding will always be true as the segment change condition is being applied - # on the valid segments (which have padding at the end so they'll always trigger True) + # Detect segment breaks in the valid tokens only (not full seq) + # Padding will always be true as the segment change condition is being applied + # on the valid segments (which have padding at the end so they'll always trigger True) segment_changes = jnp.concatenate([ first_is_segment, # First valid element starts a segment (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | #((valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) & actual_valid[..., 1:]) (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) ], axis=-1) - # segment_changes = jnp.concatenate([ - # first_is_segment, # First valid element starts a segment - # ((valid_segment_ids[...,1:] != valid_segment_ids[...,:-1]) | # Segment ID changed - # (valid_segment_pos[...,1:] != valid_segment_pos[...,:-1] + 1)) & # Position not consecutive - # actual_valid_mask[...,1:] # Only consider actually valid positions - # ], axis=-1) - #print(f"{segment_changes=}") - # segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) - # print(f"{segment_changes_masked=}") - - # Create new segment IDs using only valid indices (basically use the non zero indices to index into segment_changes_mask and then do a cumsum) - #new_segment_ids_pre = jax.vmap(lambda nzi_row, scm_row: jnp.where(nzi_row>=0, scm_row[nzi_row], False))(non_zero_indices, segment_changes) - #print(f"{new_segment_ids_pre=}") new_segment_ids = jnp.cumsum(segment_changes, axis=-1) - #print(f"{new_segment_ids=}") - - # Can't use len() on traced values - use jnp.max instead - # max_new_segments_per_seq = jnp.max(jnp.where(actual_valid, new_segment_ids, 0)) - # print(f"{max_new_segments_per_seq=}") - max_new_segments_per_seq = 0 #placeholder to be removed later on - - # Use bincount with a safe length - # Add 1 to handle 
0-indexing, and ensure it's at least max_segments_per_seq seqlens_pre = jax.vmap(lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32))(actual_valid, new_segment_ids) - #print(f"{seqlens_pre=}") - #print(f"{seqlens_pre.shape=}") seqlens_all = jax.vmap(lambda sp_row : jnp.bincount( sp_row, length=max_segments_per_seq+1 )[1:])(seqlens_pre) - #print(f"{seqlens_all=}") seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) - #print(f"{seqlens_all_pad_neg=}") - return max_new_segments_per_seq, seqlens_all_pad_neg - #QUESTION: Do take a look at other implementations to check if flattening required ? - # def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): - # q_segment_ids_flat = q_segment_ids.reshape(-1) - # q_segment_pos_flat = q_segment_pos.reshape(-1) - - # # Create mask for non-zero segment IDs - # non_zero_mask = q_segment_ids_flat != 0 - # # Calculate indices from mask - # max_size = q_segment_ids_flat.shape[0] - # # Non zero segment id indices followed by padding of -1 at the end to get static size - # non_zero_indices = jnp.where( - # non_zero_mask, - # size=max_size, - # fill_value=-1 - # )[0] - # # print(f"{non_zero_indices=}") - # # Pick non zero seg ids and seg pos - # valid_segment_ids = jnp.where(non_zero_indices >= 0, q_segment_ids_flat[non_zero_indices], 0) - # valid_segment_pos = jnp.where(non_zero_indices >= 0, q_segment_pos_flat[non_zero_indices], 0) - # # print(f"{valid_segment_ids=}, {valid_segment_pos=}") - - # # Create mask for actual valid entries (not padding) - # # All Trues in the beginning for valid segment ids followed by padding of False - # actual_valid_mask = valid_segment_ids != 0 - # # print(f"{actual_valid_mask=}") - - # # Detect segment changes, accounting for padding - # # First element is True only if it's actually valid - # first_is_segment = actual_valid_mask[0:1] - - # segment_changes = jnp.concatenate([ - # first_is_segment, # First valid element starts a 
segment - # ((valid_segment_ids[1:] != valid_segment_ids[:-1]) | # Segment ID changed - # (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1)) & # Position not consecutive - # actual_valid_mask[1:] # Only consider actually valid positions - # ]) - - # # Create new segment IDs - # new_segment_ids = jnp.cumsum(segment_changes) - # # print(f"{new_segment_ids=}") - - # # Can't use len() on traced values - use jnp.max instead - # max_new_segments_per_seq = jnp.max(jnp.where(actual_valid_mask, new_segment_ids, 0)) - # # print(f"{max_new_segments_per_seq=}") - - # # Use bincount with a safe length - # # Add 1 to handle 0-indexing, and ensure it's at least max_segments_per_seq - # seqlens_all = jnp.bincount( - # jnp.where(actual_valid_mask, new_segment_ids, 0).astype(jnp.int32), - # length=max_segments_per_seq - # )[1:] - # # print(f"{seqlens_all=}") - # seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) - - # # Pad 0 at start prior to cumsum - # #seqlens_padded = jnp.concatenate([jnp.array([0]), seqlens_all]) - # # print(f"{seqlens_padded=}") - # #cum_seqlens_padded = jnp.cumsum(seqlens_padded) # TODO:Momentarily comment off - # #print(f"{cum_seqlens_padded=}") - - # return max_new_segments_per_seq, seqlens_all_pad_neg - - #QUESTION: Do take a look at other implementations to check if flattening required ? + def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): - # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 - # e.g. 
01230000124567 segment_changes = jnp.concatenate([ jnp.full((q_segment_pos.shape[0], 1), True, dtype=bool), # First valid element starts a segment (q_segment_pos[...,1:] != q_segment_pos[...,:-1] + 1) # Segment pos changed ], axis=-1) # Remove any padded region segment changes - segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) + segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False) # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos.shape[-1] - #seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq+1, fill_value=-1).flatten() seq_offsets_2 = jax.vmap(lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_masked) return seq_offsets_2 - # # TODO: q_num_segments not needed - # def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): - # q_segment_pos_flat = q_segment_pos.reshape(-1) - # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 - # # e.g. 
01230000124567 - # segment_changes = jnp.concatenate([ - # jnp.array([True]), # First valid element starts a segment - # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed - # ]) - # #print(f"{segment_changes=}") - # # Remove any padded region segment changes - # segment_changes_masked = jnp.where(q_segment_ids!=0, segment_changes, False) - # #print(f"{segment_changes_masked=}") - # # Get the indices for segment changes (these are the offsets) - # max_size = q_segment_pos_flat.shape[0] - # seq_offsets_2 = jnp.argwhere(segment_changes_masked, size=max_segments_per_seq, fill_value=-1).flatten() - # #print(f"{seq_offsets_2=}") - # #seq_offsets = jnp.where(seq_offsets_2 !=-1, seq_offsets_2, seq_offsets_2[q_num_segments]) - # return seq_offsets_2 - # #print(f"{seq_offsets=}") - # # q_segment_pos_flat = q_segment_pos.reshape(-1) - # # # QUESTION: Will this logic be affected if end padding stripes (i.e. seg pos =0) are present in between seg pos !=0 - # # # e.g. 01230000124567 - # # segment_changes = jnp.concatenate([ - # # jnp.array([True]), # First valid element starts a segment - # # (q_segment_pos_flat[1:] != q_segment_pos_flat[:-1] + 1) # Segment pos changed - # # ]) - # # #print(f"{segment_changes=}") - # # max_size = q_segment_pos_flat.shape[0] - # # seq_offsets_2 = jnp.argwhere(segment_changes, size=max_size, fill_value=-1).flatten() - - # # # Create index array (static shape) - # # seq_offsets_2_indices = jnp.arange(seq_offsets_2.shape[0]) - # # # Create a mask (False: do not clip to the edge element, True: clip to edge element) - # # mask = seq_offsets_2_indices >= q_num_segments - # # # Get fill value dynamically by calculating the edge index - # # edge_index = jnp.clip(q_num_segments - 1, 0, seq_offsets_2.shape[0] - 1) - # # fill_value = seq_offsets_2[edge_index] - - # # seq_offsets = jnp.where(mask, fill_value, seq_offsets_2) - - # # return seq_offsets[:max_segments_per_seq] - - - # Per rank! 
- # Use full reordered kv seg id and offset(not) - # Look in every stripe_height section of kv_segment_ids - # i) if same as previous section segment id then add to seqlens counter or ii) if not same as previous section segment id then start seqlens counter - # monotonic constraint automatically applies the stripe_height constraint - #QUESTION: Do take a look at other implementations to check if flattening required ? - # def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): - # kv_segment_ids_flat = kv_segment_ids.reshape(-1) - # kv_segment_pos_flat = kv_segment_pos.reshape(-1) - # #print(f"{kv_segment_ids_flat=}, {kv_segment_pos_flat=}") - - # # Create mask for non-zero segment IDs - # non_zero_mask = kv_segment_ids_flat != 0 - # #print(f"{non_zero_mask=}") - - # # Filter to only non-zero segments - # max_size = kv_segment_ids_flat.shape[0] - # non_zero_indices = jnp.where( - # non_zero_mask, - # size=max_size, - # fill_value=-1 - # )[0] - # valid_segment_ids = jnp.where(non_zero_indices >= 0, kv_segment_ids_flat[non_zero_indices], 0) - # valid_segment_pos = jnp.where(non_zero_indices >= 0, kv_segment_pos_flat[non_zero_indices], 0) - # actual_valid = valid_segment_ids != 0 - # #print(f"{valid_segment_ids=}, {valid_segment_pos=}") - - # # Detect segment breaks (only for non-zero segments) - # segment_changes = jnp.concatenate([ - # ((valid_segment_ids[1:] != valid_segment_ids[:-1]) & actual_valid[1:])| # Segment ID changed and not non zero - # (valid_segment_pos[1:] != valid_segment_pos[:-1] + 1), # Position not consecutive - # jnp.array([actual_valid[-1]]) # Last valid element ends a segment - # ]) - # # Use the indices from segment_changes to pick out the offset value (which in turn will be the seq length for that segment) - # segment_changes_valid = jnp.where(segment_changes & actual_valid, size=max_segments_per_seq, fill_value=-1)[0] - # #print(f"{segment_changes_valid=}") - # # Remove any - # safe_indices = 
jnp.maximum(segment_changes_valid, 0) - # #print(f"{safe_indices=}") - # selected_values = jnp.where(safe_indices !=0, valid_segment_pos[safe_indices] + 1, -1) - # # seqlens = jnp.concatenate([jnp.array([0]), jnp.where(segment_changes_valid >= 0, selected_values, 0)[:-1]]) - # # seqlens_cumsum_padded = jnp.cumsum(seqlens) - # #print(f"{result=}") - # return jnp.count_nonzero(selected_values).astype(int), selected_values def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): # Create mask for non-zero segment IDs non_zero_mask = kv_segment_ids != 0 @@ -1832,14 +1658,10 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se # Clip -1 to 0 for safe indexing clipped_indices = jnp.clip(non_zero_indices, 0, None) valid_segment_ids = jnp.where( - non_zero_indices >= 0, - jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), - 0 + non_zero_indices >= 0, jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), 0 ) valid_segment_pos = jnp.where( - non_zero_indices >= 0, - jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), - 0 + non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0 ) actual_valid = valid_segment_ids != 0 @@ -1847,15 +1669,23 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se # First element is True only if it's actually valid first_is_segment = actual_valid[..., 0:1] # Detect segment breaks (only for non-zero segments) - segment_changes = jnp.concatenate([ - ((valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) & actual_valid[..., 1:]) | - (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), - actual_valid[..., -1:] - ], axis=-1) + segment_changes = jnp.concatenate( + [ + ( + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) + & actual_valid[..., 1:] + ) + | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + actual_valid[..., -1:], + ], + axis=-1, + ) # 
Get the indices for segment changes - apply vmap per row segment_changes_valid = jax.vmap( - lambda sc_row, av_row: jnp.where(sc_row & av_row, size=max_segments_per_seq, fill_value=-1)[0] + lambda sc_row, av_row: jnp.where( + sc_row & av_row, size=max_segments_per_seq, fill_value=-1 + )[0] )(segment_changes, actual_valid) # Safe indices safe_indices = jnp.maximum(segment_changes_valid, 0) @@ -1863,98 +1693,64 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se selected_values = jnp.where( segment_changes_valid >= 0, jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1, - -1 + -1, ) # Count non-zero per row or total num_segments = jnp.count_nonzero(selected_values > 0, axis=-1).astype(int) # Per row return num_segments, selected_values - def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, kv_num_segments, max_segments_per_seq): + def kv_seqoffsets_for_striped_for_rank( + self, + kv_segment_pos, + kv_segment_ids, + kv_segment_pos_ag, + kv_segment_ids_ag, + kv_num_segments, + max_segments_per_seq, + ): # Calculate the segment pos change mask - segment_changes_first_true = jnp.concatenate([ - jnp.full((kv_segment_pos.shape[0], 1), True, dtype=bool), # Assume valid element starts a segment - (kv_segment_pos[...,1:] != kv_segment_pos[...,:-1] + 1) # Segment pos changed - ], axis=-1) - segment_changes_first_true_masked = jnp.where(kv_segment_ids!=0, segment_changes_first_true, False) + segment_changes_first_true = jnp.concatenate( + [ + jnp.full( + (kv_segment_pos.shape[0], 1), True, dtype=bool + ), # Assume valid element starts a segment + (kv_segment_pos[..., 1:] != kv_segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + segment_changes_first_true_masked = jnp.where( + kv_segment_ids != 0, segment_changes_first_true, False + ) # Get segment change indices for rank - #segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, 
size=max_segments_per_seq+1, fill_value=-1).flatten() segment_changes_indices = jax.vmap(lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_first_true_masked) # Get segment ids associated with the segment_changes_indices for rank - #segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) segment_ids = jax.vmap(lambda sci_row, ksi_row: jnp.where(sci_row>=0, ksi_row[sci_row], -1))(segment_changes_indices, kv_segment_ids) # Get segment change indices for AG - segment_changes_ag_first_true = jnp.concatenate([ - jnp.full((kv_segment_pos.shape[0], 1), True, dtype=bool), # Assume valid element starts a segment - (kv_segment_pos_ag[...,1:] != kv_segment_pos_ag[...,:-1] + 1) # Segment pos changed - ], axis=-1) - segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag!=0, segment_changes_ag_first_true, False) + segment_changes_ag_first_true = jnp.concatenate( + [ + jnp.full( + (kv_segment_pos.shape[0], 1), True, dtype=bool + ), # Assume valid element starts a segment + ( + kv_segment_pos_ag[..., 1:] != kv_segment_pos_ag[..., :-1] + 1 + ), # Segment pos changed + ], + axis=-1, + ) + segment_changes_ag_first_true_masked = jnp.where( + kv_segment_ids_ag != 0, segment_changes_ag_first_true, False + ) # Get segment change indices for AG - #segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true), fill_value=-1).flatten() segment_changes_ag_indices = jax.vmap(lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_ag_first_true_masked) # Use the segment ids picked per rank to get the offsets from the AG indices - seq_offsets = jax.vmap(lambda si_row, sca_row: jnp.where(si_row>0, sca_row[si_row-1], -1))(segment_ids, segment_changes_ag_indices) + seq_offsets = jax.vmap( + lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1) + )(segment_ids, 
segment_changes_ag_indices) return seq_offsets - #QUESTION: Do take a look at other implementations to check if flattening required ? - # TODO: kv_num_segments not needed - # def kv_seqoffsets_for_striped_for_rank(self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, kv_num_segments, max_segments_per_seq): - # # Calculate the segment pos change mask - # kv_segment_pos_flat = kv_segment_pos.reshape(-1) - # kv_segment_ids_flat = kv_segment_ids.reshape(-1) - # kv_segment_pos_ag_flat = kv_segment_pos_ag.reshape(-1) - # kv_segment_ids_ag_flat = kv_segment_ids_ag.reshape(-1) - # #print(f"{kv_segment_pos_flat=}, {kv_segment_ids_flat=}") - # # segment_changes=Array([ True, False, False, False, True, True, False, False, True, - # # False, False, False, True, True, True, True], dtype=bool) - # # segment_changes_first_false = jnp.concatenate([ - # # jnp.array([False]), # Assume valid element starts a segment - # # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed - # # ]) - # segment_changes_first_true = jnp.concatenate([ - # jnp.array([True]), # Assume valid element starts a segment - # (kv_segment_pos_flat[1:] != kv_segment_pos_flat[:-1] + 1) # Segment pos changed - # ]) - # segment_changes_first_true_masked = jnp.where(kv_segment_ids_flat!=0, segment_changes_first_true, False) - # #segment_changes = jnp.where(kv_segment_ids_flat[0]==1, segment_changes_first_true, segment_changes_first_true) - # #print(f"{segment_changes_first_true=}") - - # # Get segment change indices for rank - # #print(f"{jnp.size(segment_changes_first_true)=}") - # segment_changes_indices = jnp.argwhere(segment_changes_first_true_masked, size=max_segments_per_seq, fill_value=-1).flatten() - # #print(f"{segment_changes_indices=}") - # # Get segment ids associated with the segment_changes_indices for rank - # segment_ids = jnp.where(segment_changes_indices >= 0, kv_segment_ids_flat[segment_changes_indices], -1) - # #print(f"{segment_ids=}") - - # # Get 
segment change indices for AG - # segment_changes_ag_first_true = jnp.concatenate([ - # jnp.array([True]), # Assume valid element starts a segment - # (kv_segment_pos_ag_flat[1:] != kv_segment_pos_ag_flat[:-1] + 1) # Segment pos changed - # ]) - # segment_changes_ag_first_true_masked = jnp.where(kv_segment_ids_ag_flat!=0, segment_changes_ag_first_true, False) - # #print(f"{segment_changes_ag_first_true=}") - # # Get segment change indices for AG - # #print(f"{jnp.size(segment_changes_ag_first_true)=}") - # segment_changes_ag_indices = jnp.argwhere(segment_changes_ag_first_true_masked, size=jnp.size(segment_changes_ag_first_true_masked), fill_value=-1).flatten() - # #print(f"{segment_changes_ag_indices=}") - - - # # Use the segment ids picked per rank to get the offsets from the AG indices - # seq_offsets = jnp.where(segment_ids !=0, segment_changes_ag_indices[segment_ids-1], -1) - # return seq_offsets - # #print(f"{seq_offsets=}") - # indices = jnp.arange(0, seq_offsets.size) - # #print(f"{indices=}") - # #print(f"{kv_num_segments=}") - # arr = jnp.ones_like(seq_offsets) * seq_offsets[kv_num_segments-1] - # #print(f"{arr=}") - # seq_offsets_truncated = jnp.where(indices >= kv_num_segments, arr, seq_offsets) - #print(f"{seq_offsets_truncated=}") - # return seq_offsets_truncated[:max_segments_per_seq] - class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): """ @@ -1985,7 +1781,7 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - #jax.debug.breakpoint() + # jax.debug.breakpoint() def impl( q, @@ -2005,8 +1801,8 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - #jax.debug.print("Test CP DC AG") - Gives a seg fault - #breakpoint() + # jax.debug.print("Test CP DC AG") - Gives a seg fault + # breakpoint() # cuDNN does not support 
right-aligned masking with dynamic sequence length padding. # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -2018,7 +1814,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) # jax.debug.print("Test cross attn ag") - Gives a seg fault - #jax.debug.print(f"kv_max_seqlen: {kv_max_seqlen}") + # jax.debug.print(f"kv_max_seqlen: {kv_max_seqlen}") assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" q_split = jnp.split(q, 2, axis=1) @@ -2028,7 +1824,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): ) results = [] - #breakpoint() + # breakpoint() for sub_idx in range(2): if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked @@ -2038,7 +1834,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): q_seqlen_for_step = q_seqlen / (cp_size * 2) num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks - #breakpoint() + # breakpoint() output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_split[sub_idx], k_unmasked, @@ -2258,6 +2054,7 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) + class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): """ Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive @@ -2267,12 +2064,15 @@ class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): @staticmethod def partition(config, mesh, arg_infos, result_infos): - DEBUG = True #os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" + DEBUG = True # os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" if DEBUG: print(f"STRIPED PARTITION CALLED (Compilation Phase)") print(f"Mesh: {mesh}") print(f"CP axis: {config.cp_axis}, size: 
{get_mesh_axis_size(config.cp_axis, mesh)}") - print(f"window_size: {config.window_size}, context_parallel_load_balanced: {config.context_parallel_load_balanced}, stripe_height: {config.stripe_height}") + print( + f"window_size: {config.window_size}, context_parallel_load_balanced:" + f" {config.context_parallel_load_balanced}, stripe_height: {config.stripe_height}" + ) print(f"Arg shapes: {[info.shape for info in arg_infos]}") print(f"QKV layout: {config.qkv_layout}") print(f"Attention mask type: {config.attn_mask_type}") @@ -2326,26 +2126,50 @@ def impl( # mask/sequence length tensor to avoid this unrolled loop. # Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets - # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos, + # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos, # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in. - def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed): + def _cross_attn( + idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed + ): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function # does not go down that route but instead just picks the seqlens and offsets passed onto it - + kv_max_seqlen = k.shape[1] # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq - adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(max_seqlen=kv_max_seqlen, cp_size=cp_size) - q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq) - q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_ids=_q_segment_ids 
,q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) - kv_num_segments_for_rank, kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(kv_segment_ids=_kv_segment_ids, kv_segment_pos=_kv_segment_pos, max_segments_per_seq=adjusted_max_segments_per_seq) - kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag,kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) - #kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) - + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ) + q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq + ) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( + q_segment_ids=_q_segment_ids, + q_segment_pos=_q_segment_pos, + q_num_segments=q_num_segments_for_rank, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_num_segments_for_rank, kv_seqlens_for_rank = ( + helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + ) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( + kv_segment_pos=_kv_segment_pos, + kv_segment_ids=_kv_segment_ids, + kv_segment_pos_ag=kv_segment_pos_ag, + kv_segment_ids_ag=kv_segment_ids_ag, + kv_num_segments=kv_num_segments_for_rank, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + # kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, 
q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( - q, #sharded for rank - k, #ag - v, #ag + q, # sharded for rank + k, # ag + v, # ag bias, softmax_offset, seed, @@ -2353,18 +2177,33 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segmen kv_seqlens_for_rank, q_seq_offsets_for_rank, kv_seq_offsets_for_rank, - q_seqlen, #Should be empty ids but using placeholder - kv_seqlen, #Should be empty poss but using placeholder - q_seq_offsets, #Should be empty ids but using placeholder - k_seq_offsets, #Should be empty pos but using placeholder - config=helper.get_step_config_for_striped(max_seqlen=kv_max_seqlen, cp_size=cp_size), + q_seqlen, # Should be empty ids but using placeholder + kv_seqlen, # Should be empty poss but using placeholder + q_seq_offsets, # Should be empty ids but using placeholder + k_seq_offsets, # Should be empty pos but using placeholder + config=helper.get_step_config_for_striped( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ), ) return output, softmax_aux, rng_state k_ag, v_ag = helper.all_gather_kv(k, v) - _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( + _kv_segment_ids, _kv_segment_pos + ) functions = [ - partial(_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, _kv_segment_ids_ag, _kv_segment_pos_ag, seed) + partial( + _cross_attn, + idx, + q, + k_ag, + v_ag, + bias, + softmax_offset, + _kv_segment_ids_ag, + _kv_segment_pos_ag, + seed, + ) for idx in range(cp_size) ] @@ -2375,6 +2214,7 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segmen register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive) + class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): """ Fused Attention Backward with Context Parallelism and 
Striped Load Balancing Primitive. @@ -2397,7 +2237,7 @@ def partition(config, mesh, arg_infos, result_infos): helper = _FusedAttnCPWithAllGatherHelper(mesh, config) helper.check_supported() - #TODO: Confirm the deletion + # TODO: Confirm the deletion del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -2462,20 +2302,42 @@ def _cross_attn_bwd( # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function # does not go down that route but instead just picks the seqlens and offsets passed onto it - + kv_max_seqlen = k.shape[1] # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq - adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(max_seqlen=kv_max_seqlen, cp_size=cp_size) - q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq) - q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(q_segment_ids=_q_segment_ids ,q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) - kv_num_segments_for_rank, kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(kv_segment_ids=_kv_segment_ids, kv_segment_pos=_kv_segment_pos, max_segments_per_seq=adjusted_max_segments_per_seq) - kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag,kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) - #kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, 
max_segments_per_seq=adjusted_max_segments_per_seq) - + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ) + q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq + ) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( + q_segment_ids=_q_segment_ids, + q_segment_pos=_q_segment_pos, + q_num_segments=q_num_segments_for_rank, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_num_segments_for_rank, kv_seqlens_for_rank = ( + helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + ) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( + kv_segment_pos=_kv_segment_pos, + kv_segment_ids=_kv_segment_ids, + kv_segment_pos_ag=kv_segment_pos_ag, + kv_segment_ids_ag=kv_segment_ids_ag, + kv_num_segments=kv_num_segments_for_rank, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + # kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) + dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl( - q, #sharded for rank - k, #ag - v, #ag + q, # sharded for rank + k, # ag + v, # ag bias, softmax_offset, softmax_aux, @@ -2486,11 +2348,13 @@ def _cross_attn_bwd( kv_seqlens_for_rank, q_seq_offsets_for_rank, kv_seq_offsets_for_rank, - q_seqlen, #Should be empty ids but using placeholder - kv_seqlen, #Should be empty poss but using placeholder - q_seq_offsets, #Should be empty ids but using placeholder - k_seq_offsets, #Should be empty pos but using placeholder - config=helper.get_step_config_for_striped(max_seqlen=kv_max_seqlen, cp_size=cp_size), + q_seqlen, # Should be empty ids but using placeholder + 
kv_seqlen, # Should be empty poss but using placeholder + q_seq_offsets, # Should be empty ids but using placeholder + k_seq_offsets, # Should be empty pos but using placeholder + config=helper.get_step_config_for_striped( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ), ) # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. @@ -2504,10 +2368,12 @@ def _cross_attn_bwd( # dk_local_pad = results[0][1] + results[1][1] # dv_local_pad = results[0][2] + results[1][2] # return dq_local, dk_local_pad, dv_local_pad, results[1][3] - return dq_local, dk_local, dv_local, dbias_local + return dq_local, dk_local, dv_local, dbias_local k_ag, v_ag = helper.all_gather_kv(k, v) - _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(_kv_segment_ids, _kv_segment_pos) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( + _kv_segment_ids, _kv_segment_pos + ) functions = [ partial( @@ -2544,6 +2410,7 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive) + @dataclass(frozen=True) class _FusedAttnCPWithP2PHelper: """Helper class to assist with running the P2P ring strategy for CP attention.""" From 94af41377087835f998a7d4019cbf2f2610ca6a0 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 07:51:06 +0000 Subject: [PATCH 21/36] Add test for CP+THD+AG+Striped>1 Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 66 +++++++++++++++++++++--- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index e0d60e06c6..c61ce74ca3 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -354,12 +354,13 @@ def impl_test_context_parallel_attn( use_shardy, use_scan_ring=False, window_size=None, + stripe_height=0, ): if qkv_layout.is_thd(): # if cp_strategy == CPStrategy.ALL_GATHER: # pytest.skip("THD doesn't support 
all gather context parallelism.") - if not load_balanced and cp_strategy == CPStrategy.RING: - pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + if not load_balanced and (cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER): + pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.") assert not use_scan_ring or cp_strategy == CPStrategy.RING @@ -387,7 +388,7 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups # KL code For AG case only - stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 + #stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 runner = FusedAttnRunner( batch, @@ -440,6 +441,7 @@ def check_has_backend_for_mask(mask_type): # and exception if the step backend is not supported. This was a deliberate API # decision to keep the CP size or flag out of the function. has_backend = check_has_backend_for_mask(attn_mask_type) + #TODO: For PADDING_CAUSAL_MASK ? 
if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK: has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK) @@ -455,7 +457,6 @@ def check_has_backend_for_mask(mask_type): # KL code runner.test_backward() - # runner.test_forward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] @pytest_parametrize_wrapper( @@ -466,7 +467,7 @@ def check_has_backend_for_mask(mask_type): @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) def test_context_parallel_allgather_attn_shardy( self, @@ -495,6 +496,55 @@ def test_context_parallel_allgather_attn_shardy( use_shardy=True, ) + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1], + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED")], + ) + @pytest.mark.parametrize( + "stripe_height", + [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")], + ) + def test_context_parallel_allgather_striped_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + stripe_height, + ): + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + CPStrategy.ALL_GATHER, + use_shardy=False, + stripe_height=stripe_height, + ) + @pytest_parametrize_wrapper( 
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs_for_attn(), @@ -504,7 +554,7 @@ def test_context_parallel_allgather_attn_shardy( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) @pytest.mark.parametrize( "load_balanced", @@ -547,7 +597,7 @@ def test_context_parallel_allgather_attn( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) @pytest.mark.parametrize( "load_balanced", @@ -611,7 +661,7 @@ def test_context_parallel_ring_attn( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) def test_context_parallel_ring_attn_shardy( self, From a788a1d6db9a4a8d3b7804944c497c9fe2da4895 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 07:53:06 +0000 Subject: [PATCH 22/36] Fix missing var Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/cpp_extensions/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 67a4a1be99..7cae665b19 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1630,6 +1630,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme length=max_segments_per_seq+1 )[1:])(seqlens_pre) seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) + max_new_segments_per_seq = 0 #TODO: Remove return max_new_segments_per_seq, seqlens_all_pad_neg def 
q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): From 60191eb1ead6a52a797ecb13760fd80340d14887 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 07:54:09 +0000 Subject: [PATCH 23/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_fused_attn.py | 8 ++- .../jax/cpp_extensions/attention.py | 69 ++++++++++++------- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index c61ce74ca3..46d3fcf2e3 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -359,7 +359,9 @@ def impl_test_context_parallel_attn( if qkv_layout.is_thd(): # if cp_strategy == CPStrategy.ALL_GATHER: # pytest.skip("THD doesn't support all gather context parallelism.") - if not load_balanced and (cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER): + if not load_balanced and ( + cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER + ): pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.") assert not use_scan_ring or cp_strategy == CPStrategy.RING @@ -388,7 +390,7 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups # KL code For AG case only - #stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 + # stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 runner = FusedAttnRunner( batch, @@ -441,7 +443,7 @@ def check_has_backend_for_mask(mask_type): # and exception if the step backend is not supported. This was a deliberate API # decision to keep the CP size or flag out of the function. has_backend = check_has_backend_for_mask(attn_mask_type) - #TODO: For PADDING_CAUSAL_MASK ? 
+ # TODO: For PADDING_CAUSAL_MASK ? if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK: has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 7cae665b19..4e1fe5e08e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1587,7 +1587,7 @@ def pad(x, npad): return pad(dk, npad), pad(dv, npad) return dk, dv # fall through - + def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): # Create mask for non-zero segment IDs non_zero_mask = q_segment_ids != 0 @@ -1616,33 +1616,46 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme first_is_segment = actual_valid[..., 0:1] # Detect segment breaks in the valid tokens only (not full seq) # Padding will always be true as the segment change condition is being applied - # on the valid segments (which have padding at the end so they'll always trigger True) - segment_changes = jnp.concatenate([ - first_is_segment, # First valid element starts a segment - (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | - #((valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) & actual_valid[..., 1:]) - (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) - ], axis=-1) + # on the valid segments (which have padding at the end so they'll always trigger True) + segment_changes = jnp.concatenate( + [ + first_is_segment, # First valid element starts a segment + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | + # ((valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) & actual_valid[..., 1:]) + (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + ], + axis=-1, + ) new_segment_ids = jnp.cumsum(segment_changes, axis=-1) - seqlens_pre = jax.vmap(lambda av_row, nsi_row: jnp.where(av_row, 
nsi_row, 0).astype(jnp.int32))(actual_valid, new_segment_ids) - seqlens_all = jax.vmap(lambda sp_row : jnp.bincount( - sp_row, - length=max_segments_per_seq+1 - )[1:])(seqlens_pre) - seqlens_all_pad_neg = jnp.where(seqlens_all==0, -1, seqlens_all) - max_new_segments_per_seq = 0 #TODO: Remove + seqlens_pre = jax.vmap( + lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32) + )(actual_valid, new_segment_ids) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(seqlens_pre) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + max_new_segments_per_seq = 0 # TODO: Remove return max_new_segments_per_seq, seqlens_all_pad_neg - def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq): - segment_changes = jnp.concatenate([ - jnp.full((q_segment_pos.shape[0], 1), True, dtype=bool), # First valid element starts a segment - (q_segment_pos[...,1:] != q_segment_pos[...,:-1] + 1) # Segment pos changed - ], axis=-1) + def q_seqoffsets_for_striped_for_rank( + self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq + ): + segment_changes = jnp.concatenate( + [ + jnp.full( + (q_segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (q_segment_pos[..., 1:] != q_segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) # Remove any padded region segment changes segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False) # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos.shape[-1] - seq_offsets_2 = jax.vmap(lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_masked) + seq_offsets_2 = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_masked) return seq_offsets_2 def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, 
kv_segment_pos, max_segments_per_seq): @@ -1724,9 +1737,13 @@ def kv_seqoffsets_for_striped_for_rank( ) # Get segment change indices for rank - segment_changes_indices = jax.vmap(lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_first_true_masked) + segment_changes_indices = jax.vmap( + lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_first_true_masked) # Get segment ids associated with the segment_changes_indices for rank - segment_ids = jax.vmap(lambda sci_row, ksi_row: jnp.where(sci_row>=0, ksi_row[sci_row], -1))(segment_changes_indices, kv_segment_ids) + segment_ids = jax.vmap( + lambda sci_row, ksi_row: jnp.where(sci_row >= 0, ksi_row[sci_row], -1) + )(segment_changes_indices, kv_segment_ids) # Get segment change indices for AG segment_changes_ag_first_true = jnp.concatenate( @@ -1744,14 +1761,16 @@ def kv_seqoffsets_for_striped_for_rank( kv_segment_ids_ag != 0, segment_changes_ag_first_true, False ) # Get segment change indices for AG - segment_changes_ag_indices = jax.vmap(lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq+1, fill_value=-1)[0])(segment_changes_ag_first_true_masked) + segment_changes_ag_indices = jax.vmap( + lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_ag_first_true_masked) # Use the segment ids picked per rank to get the offsets from the AG indices seq_offsets = jax.vmap( lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1) )(segment_ids, segment_changes_ag_indices) return seq_offsets - + class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): """ From d29f59a7a1e1ff544a2843772ff6f33d4c08f043 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 25 Nov 2025 11:25:28 -0800 Subject: [PATCH 24/36] Add SWA tests for AG+Striped>1+CP+THD+SWA Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_distributed_fused_attn.py | 13 
++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 46d3fcf2e3..a81106be1d 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -507,7 +507,7 @@ def test_context_parallel_allgather_attn_shardy( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1], + [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], ) @pytest.mark.parametrize( "load_balanced", @@ -517,6 +517,13 @@ def test_context_parallel_allgather_attn_shardy( "stripe_height", [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")], ) + @pytest.mark.parametrize( + "window_size", + [ + pytest.param((-1, -1), id="window_size(-1, -1)"), + pytest.param((5, 0), id="window_size(5, 0)"), + ], + ) def test_context_parallel_allgather_striped_attn( self, device_count, @@ -529,8 +536,11 @@ def test_context_parallel_allgather_striped_attn( dtype, qkv_layout, load_balanced, + window_size, stripe_height, ): + if window_size != (-1, -1) and not qkv_layout.is_thd(): + pytest.skip("Sliding window attention is only supported for THD layout") self.impl_test_context_parallel_attn( device_count, mesh_shape, @@ -544,6 +554,7 @@ def test_context_parallel_allgather_striped_attn( load_balanced, CPStrategy.ALL_GATHER, use_shardy=False, + window_size=window_size, stripe_height=stripe_height, ) From b01340a738d6426895522d19b26aafba4ce16eff Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 25 Nov 2025 11:32:10 -0800 Subject: [PATCH 25/36] Restoring test code Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_fused_attn.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 6bdbd16d95..5e53560aa8 100644 --- 
a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -491,29 +491,9 @@ def _setup_inputs(self): else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") - # self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0) - # self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0) - # self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) - # KL test code - q_np = np.zeros(q_shape, self.dtype) - k_np = np.zeros(k_shape, self.dtype) - token_numbers_q = range(self.max_seqlen_q) - token_numbers_k = range(self.max_seqlen_kv) - for batch_idx in range(q_shape[0]): - for token_idx in token_numbers_q: - q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * ( - token_idx + 1 - ) - for token_idx in token_numbers_k: - k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt( - self.head_dim_qk - ) - v_np = np.ones(v_shape, self.dtype) - # Set cols at multiples - v_np[0, ::4, 0, :] = np.arange(v_np.shape[3]) - self.q = jnp.array(q_np) - self.k = jnp.array(k_np) - self.v = jnp.array(v_np) + self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0) + self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0) + self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.bias_shape == BiasShape._1HSS: @@ -579,8 +559,6 @@ def generate_random_segment_ids( min_segment_size = 1 if min_segment_len is not None: min_segment_size = min_segment_len[i][seg_id] - # KL test code - min_segment_size = 4 segment_size = rng.integers(min_segment_size, max_segment_size + 1) if current_pos + segment_size > sequence_length: break @@ -603,7 +581,7 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=12 + self.batch_size, self.max_seqlen_q, 
self.num_segments_per_seq, seed=42 ) self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) # TODO(rewang): record only self attention and find the reason of cross attention From c5921af9d1d6a3c7129b209c5b571e850d383bc6 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 25 Nov 2025 11:34:33 -0800 Subject: [PATCH 26/36] Remove assert preventing SWA code path in CP+AG+Striped primitive Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/cpp_extensions/attention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4e1fe5e08e..7d7ad1fb08 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2247,9 +2247,6 @@ class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): def partition(config, mesh, arg_infos, result_infos): # Call base implementation for non-context parallel mesh to avoid unecessary work. 
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 - assert ( - not is_context_parallel or config.window_size[0] == -1 - ), "Sliding window attention is not supported when context parallelism is enabled" if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) From 3bb1d5ac3e1d4a3da0eb87649e5487fed3a242b1 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 19:43:34 +0000 Subject: [PATCH 27/36] Parametrize num_segments_per_seq in tests Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 8 ++++++++ tests/jax/test_fused_attn.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index a81106be1d..207a426ed6 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -355,6 +355,7 @@ def impl_test_context_parallel_attn( use_scan_ring=False, window_size=None, stripe_height=0, + num_segments_per_seq=0, ): if qkv_layout.is_thd(): # if cp_strategy == CPStrategy.ALL_GATHER: @@ -417,6 +418,7 @@ def impl_test_context_parallel_attn( cp_strategy=cp_strategy, cp_load_balanced=load_balanced, stripe_height=stripe_height, + num_segments_per_seq=num_segments_per_seq, ) def check_has_backend_for_mask(mask_type): @@ -524,6 +526,10 @@ def test_context_parallel_allgather_attn_shardy( pytest.param((5, 0), id="window_size(5, 0)"), ], ) + @pytest.mark.parametrize( + "num_segments_per_seq", + [pytest.param(2, id="SEG-2"), pytest.param(11, id="SEG-11")], + ) def test_context_parallel_allgather_striped_attn( self, device_count, @@ -538,6 +544,7 @@ def test_context_parallel_allgather_striped_attn( load_balanced, window_size, stripe_height, + num_segments_per_seq, ): if window_size != (-1, -1) and not qkv_layout.is_thd(): pytest.skip("Sliding window attention is only supported for THD layout") @@ -556,6 +563,7 @@ def 
test_context_parallel_allgather_striped_attn( use_shardy=False, window_size=window_size, stripe_height=stripe_height, + num_segments_per_seq=num_segments_per_seq, ) @pytest_parametrize_wrapper( diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5e53560aa8..b068e7b69c 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -354,6 +354,7 @@ class FusedAttnRunner: window_size: Tuple[int, int] seq_desc_format: SeqDescFormat stripe_height: int = 0 + num_segments_per_seq: int = 0 # Specifies sharding resources for distributed tests number_of_devices: int = 1 @@ -579,7 +580,9 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.num_segments_per_seq = 2 + # If using default num segments of 0, set to 2 + if self.num_segments_per_seq==0: + self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) From 78fec5b369ab12633e6e9e7b38a4a9afbe3ed337 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:48:12 +0000 Subject: [PATCH 28/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index b068e7b69c..d27401714f 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -580,9 +580,9 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - # If using default num segments of 0, set to 2 - if self.num_segments_per_seq==0: - self.num_segments_per_seq = 2 + # If using default num segments of 0, set to 2 + if self.num_segments_per_seq == 0: + self.num_segments_per_seq = 2 
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) From 69dad1aaba9aa947f09bee26c91e9e0383ed298a Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 20:08:37 +0000 Subject: [PATCH 29/36] Clean up test code Signed-off-by: Kshitij Lakhani Clean up test code in TE common Signed-off-by: Kshitij Lakhani Clean up debug statements Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 10 -- tests/jax/test_fused_attn.py | 16 +- tests/jax/utils.py | 25 --- .../fused_attn_f16_arbitrary_seqlen.cu | 162 ------------------ transformer_engine/common/fused_attn/utils.cu | 8 - transformer_engine/jax/attention.py | 3 - .../jax/cpp_extensions/attention.py | 104 ----------- 7 files changed, 1 insertion(+), 327 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 207a426ed6..7b6ca67706 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -328,11 +328,8 @@ def test_cross_attn( DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes. 
- # TODO: Change the id to CPx2 pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), - # KL test code - pytest.param([2, 8, 16, 64], id="2-8xCP-16-64"), ] @@ -358,8 +355,6 @@ def impl_test_context_parallel_attn( num_segments_per_seq=0, ): if qkv_layout.is_thd(): - # if cp_strategy == CPStrategy.ALL_GATHER: - # pytest.skip("THD doesn't support all gather context parallelism.") if not load_balanced and ( cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER ): @@ -389,10 +384,6 @@ def impl_test_context_parallel_attn( data_shape = batch, seqlen, num_head, hidden num_kv_heads = num_head // kv_groups - - # KL code For AG case only - # stripe_height = 4 if qkv_layout.is_thd() and cp_strategy == CPStrategy.ALL_GATHER else 0 - runner = FusedAttnRunner( batch, seqlen, @@ -459,7 +450,6 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - # KL code runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index d27401714f..cba860380c 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -617,16 +617,6 @@ def generate_random_segment_ids( ) self.segment_pos_q = self.segment_pos_kv = None self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None - print( - f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q:" - f" {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q:" - f" {self.seqlens_q}, \n self.offsets_q: { self.offsets_q} \n" - ) - print( - f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv:" - f" {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv:" - f" {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n" - ) # For reference code self.mask = 
make_mask( @@ -637,10 +627,6 @@ def generate_random_segment_ids( self.attn_mask_type, self.window_size, ) - # KL tet code - # import sys - # with np.printoptions(threshold=sys.maxsize): - # print(f"self.mask: \n {self.mask}") if self.cp_size > 1 and self.cp_load_balanced: if self.qkv_layout.is_thd(): @@ -790,7 +776,7 @@ def to_dp_shardings(x): self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) - def test_forward(self): + def _test_forward(self): """ Test forward with JITted primitive and unJITted reference """ diff --git a/tests/jax/utils.py b/tests/jax/utils.py index b5170eacc3..cafb31aa85 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1507,31 +1507,6 @@ def assert_allclose( actual = actual.astype(jnp.float32) if not isinstance(desired, float): desired = desired.astype(jnp.float32) - # KL test code - import sys - - mismatch_counter = 0 - has_nonzero = jnp.any(actual != 0) - print(f"has_nonzero: {has_nonzero}") - with np.printoptions(threshold=sys.maxsize): - mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch - diff_indices = np.argwhere(mismatch_mask) - seq_set = set() - for idx in diff_indices: - idx_tuple = tuple(idx) - seq_set.add(idx_tuple[1]) - mismatch_counter += 1 - if mismatch_counter < 1024: - print(f"Index {idx_tuple}: a={actual[idx_tuple]}, d={desired[idx_tuple]}") - print(f"{sorted(seq_set)}") - # Batch 0 and head 0 - # for seq_idx in range(actual.shape[1]): - # #print("Mismatch at positions:\n", np.argwhere(mismatch_mask[0,:,0,:])) # Pick indices where mask is True - # for d_idx in range(actual.shape[3]): - # # print mismatches - # #if mismatch_mask[0][seq_idx][0][d_idx] == True: - # print(f"seq_idx: {seq_idx}, d_idx: {d_idx}, A: {actual[0][seq_idx][0][d_idx]}, D: {desired[0][seq_idx][0][d_idx]}") - print(f"mismatch_counter: {mismatch_counter}") # Check if tensors are close 
np.testing.assert_allclose(actual, desired, **tols, **kwargs) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index bd0702125b..14468b543a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -48,21 +48,6 @@ namespace transformer_engine { namespace fused_attn { -template -__global__ void print_tensor_elements_2(const T *const data, const size_t rows, - const size_t start_cols, const size_t end_cols, - const size_t cols) { - if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - for (size_t i = 0; i < rows; ++i) { - for (size_t j = start_cols; j < end_cols; ++j) { - const size_t idx = i * cols + j; - printf("%8f ", static_cast(data[idx])); - } - printf("\n"); - } - } -} - void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, @@ -474,10 +459,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bias) { variant_pack[bias] = devPtrBias; } - //KL test code - bool print_tensors = true; - // For the thd_regular case, the actual_b = 18 - bool print_tensors_custom_mask = actual_b >= 300 ? 
true : false; if (is_padding) { constexpr size_t nthreads_per_block = 128; @@ -489,79 +470,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); - //std::cout << "print_tensors: " << print_tensors << - // "print_tensors_custom_mask: " - // << print_tensors_custom_mask << std::endl; - if (print_tensors) { - if (devPtrCuSeqlensQ) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensQ), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensQ), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensQ), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensQ), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - if (devActualSeqlenQ) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenQ), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenQ), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenQ), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenQ), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - if (devPtrCuSeqlensKV) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensKV), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - 
static_cast(devPtrCuSeqlensKV), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensKV), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrCuSeqlensKV), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - if (devActualSeqlenKV) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenKV), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenKV), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenKV), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devActualSeqlenKV), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - } variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -601,76 +509,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); NVTE_CHECK_CUDA(cudaGetLastError()); - if (print_tensors) { - if (devPtrSeqOffsetsQ) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsQ), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsQ), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsQ), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsQ), 1, 0, actual_b, - 
/*does not matter for single row*/ actual_b); - } - } - if (devOffsetsQ) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsQ), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsQ), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsQ), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsQ), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - if (devPtrSeqOffsetsKV) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsKV), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsKV), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsKV), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devPtrSeqOffsetsKV), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - if (devOffsetsK) { - if (print_tensors_custom_mask) { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsK), 1, 0, 8, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsK), 1, 1024, 1032, - /*does not matter for single row*/ actual_b); - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsK), 1, 8184, 8192, - /*does not matter for single row*/ actual_b); - } else { - print_tensor_elements_2<<<1, 1, 0, stream>>>( - static_cast(devOffsetsK), 1, 0, actual_b, - /*does not matter for single row*/ actual_b); - } - } - } if (is_ragged_q) { 
variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index f635014fca..df1eae0dd7 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -428,14 +428,6 @@ __device__ void cu_seqlens_padded_to_offsets_impl( OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; auto cu_seqlens_id = min(tid, actual_b); - if (tid == 0) { - printf("actual_b: %lld \n", (long long int)actual_b); - printf("max_b: %lld \n", (long long int)max_b); - printf("h: %lld \n", (long long int)h); - printf("hg: %lld \n", (long long int)hg); - printf("d_qk: %lld \n", (long long int)d_qk); - printf("d_v: %lld \n", (long long int)d_v); - } if (tid <= max_b) { if (offsets_s != nullptr) { offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index ac5c155baf..55712ec8e6 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -557,7 +557,6 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids_kv, lambda x, y: jnp.equal(x, y) * x, ) - # jax.debug.breakpoint() # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied attn_mask = segment_mask if attn_mask_type.is_bottom_right(): @@ -604,7 +603,6 @@ def _segment_ids_pos_to_seqlens_offsets( q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( attn_mask_with_id, max_segments_per_seq ) - # jax.debug.breakpoint() return q_seqlen, kv_seqlen, q_offset, kv_offset @@ -684,7 +682,6 @@ def get_seqlens_and_offsets( window_size, max_segments_per_seq, ) - # jax.debug.breakpoint() else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, diff --git a/transformer_engine/jax/cpp_extensions/attention.py 
b/transformer_engine/jax/cpp_extensions/attention.py index 7d7ad1fb08..9ab58611a0 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -521,37 +521,6 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - DEBUG = True # os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" - # if DEBUG: - # jax.debug.print("FusedAttnFwdPrimitive.impl CALLED") - # jax.debug.print("Config: qkv_layout={}, attn_mask_type={}", - # str(config.qkv_layout), str(config.attn_mask_type)) - # jax.debug.print("Input shapes:") - # jax.debug.print(" q={}, k={}, v={}", q.shape, k.shape, v.shape) - # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) - - # def print_impl_inputs(q_val, k_val, v_val, q_seq, kv_seq, q_off, k_off): - # print(f"\n~~~ FusedAttnFwdPrimitive.impl INPUTS ~~~") - # print(f"Q: shape={q_val.shape}, mean={q_val.mean():.6f}, std={q_val.std():.6f}") - # print(f" First 5: {q_val.flatten()[:5]}") - - # print(f"K: shape={k_val.shape}, mean={k_val.mean():.6f}, std={k_val.std():.6f}") - # print(f" First 5: {k_val.flatten()[:5]}") - - # print(f"V: shape={v_val.shape}, mean={v_val.mean():.6f}, std={v_val.std():.6f}") - # print(f" First 5: {v_val.flatten()[:5]}") - - # print(f"\nSequence info:") - # print(f" q_seqlen: {q_seq}") - # print(f" kv_seqlen: {kv_seq}") - # print(f" q_seq_offsets: {q_off}") - # print(f" k_seq_offsets: {k_off}") - # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - - # jax.debug.callback( - # print_impl_inputs, - # q, k, v, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets - # ) assert FusedAttnFwdPrimitive.inner_primitive is not None sequence_descriptor = SequenceDescriptor( @@ -568,23 +537,7 @@ def impl( config.max_segments_per_seq, ) ) - # if DEBUG: - # jax.debug.print("After sequence_descriptor processing:") - # jax.debug.print(" q_seqlen={}, kv_seqlen={}", q_seqlen.shape, kv_seqlen.shape) - - # def print_seq_descriptor(q_seq, kv_seq, q_off, k_off): - 
# print(f"\n~~~ SEQUENCE DESCRIPTOR OUTPUTS ~~~") - # print(f"q_seqlen (processed): {q_seq}") - # print(f"kv_seqlen (processed): {kv_seq}") - # print(f"q_seq_offsets (processed): {q_off}") - # print(f"k_seq_offsets (processed): {k_off}") - # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - - # jax.debug.callback(print_seq_descriptor, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) - # jax.debug.print("Hello FA impl") if config.qkv_layout.is_thd(): - # if DEBUG: - # jax.debug.print("Processing THD layout...") def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -608,10 +561,6 @@ def convert_to_2d(offsets, batch, max_seqlen): assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}" kv_batch = q_batch = batch[0] - # if DEBUG: - # jax.debug.print(" batch={}, q_max_seqlen={}, kv_max_seqlen={}", - # q_batch, q_max_seqlen, kv_max_seqlen) - # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] @@ -640,35 +589,9 @@ def convert_to_2d(offsets, batch, max_seqlen): k_seq_offsets = _fix_len_take( k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen ) - # if DEBUG: - # def print_thd_processing(q_seq, kv_seq, q_off, k_off): - # print(f"\n~~~ AFTER THD PROCESSING ~~~") - # print(f"q_seqlen (fixed): {q_seq}") - # print(f"kv_seqlen (fixed): {kv_seq}") - # print(f"q_seq_offsets (2d): {q_off}") - # print(f"k_seq_offsets (2d): {k_off}") - # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - - # jax.debug.callback(print_thd_processing, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) - # jax.debug.print(f"q_seqlen: {q_seqlen}, kv_seqlen: {kv_seqlen}") - # if DEBUG: - # jax.debug.print("Generated cumulative sequence lengths:") - # jax.debug.print(" q_cu_seqlen={}", q_cu_seqlen.shape) - # jax.debug.print(" kv_cu_seqlen={}", 
kv_cu_seqlen.shape) - - # def print_cu_seqlen(q_cu, kv_cu): - # print(f"\n~~~ CUMULATIVE SEQLENS ~~~") - # print(f"q_cu_seqlen: {q_cu}") - # print(f"kv_cu_seqlen: {kv_cu}") - # print(f"~~~~~~~~~~~~~~~~~~~~~~~~~~\n") - - # jax.debug.callback(print_cu_seqlen, q_cu_seqlen, kv_cu_seqlen) - - # if DEBUG: - # jax.debug.print("Calling inner_primitive.bind...") output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( q, @@ -1485,14 +1408,12 @@ def ag(x): def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): """Performs a all-gather of k and v over context parallel ranks.""" - # TODO: Is the axis chosen right ? kv_segment_ids = lax_paral_op( kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) kv_segment_pos = lax_paral_op( kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) - # jax.debug.breakpoint() if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if self.config.qkv_layout.is_thd(): @@ -1801,7 +1722,6 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - # jax.debug.breakpoint() def impl( q, @@ -1821,8 +1741,6 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - # jax.debug.print("Test CP DC AG") - Gives a seg fault - # breakpoint() # cuDNN does not support right-aligned masking with dynamic sequence length padding. 
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -1833,8 +1751,6 @@ def impl( def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) - # jax.debug.print("Test cross attn ag") - Gives a seg fault - # jax.debug.print(f"kv_max_seqlen: {kv_max_seqlen}") assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" q_split = jnp.split(q, 2, axis=1) @@ -1844,7 +1760,6 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): ) results = [] - # breakpoint() for sub_idx in range(2): if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked @@ -1854,7 +1769,6 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): q_seqlen_for_step = q_seqlen / (cp_size * 2) num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks - # breakpoint() output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_split[sub_idx], k_unmasked, @@ -2084,18 +1998,6 @@ class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): @staticmethod def partition(config, mesh, arg_infos, result_infos): - DEBUG = True # os.environ.get("TE_DEBUG_STRIPED_ATTN", "0") == "1" - if DEBUG: - print(f"STRIPED PARTITION CALLED (Compilation Phase)") - print(f"Mesh: {mesh}") - print(f"CP axis: {config.cp_axis}, size: {get_mesh_axis_size(config.cp_axis, mesh)}") - print( - f"window_size: {config.window_size}, context_parallel_load_balanced:" - f" {config.context_parallel_load_balanced}, stripe_height: {config.stripe_height}" - ) - print(f"Arg shapes: {[info.shape for info in arg_infos]}") - print(f"QKV layout: {config.qkv_layout}") - print(f"Attention mask type: {config.attn_mask_type}") # Call base implementation for non-context parallel mesh to avoid unecessary work. 
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 if not is_context_parallel: @@ -2113,10 +2015,6 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[5] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - if DEBUG: - print(f"STRIPED PARTITION CALLED (Compilation Phase)") - print(f"Arg shardings: {[arg_i.sharding for arg_i in arg_infos]}") - print(f"Out shardings: {[out_i for out_i in out_shardings]}") def impl( q, @@ -2136,7 +2034,6 @@ def impl( ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) - # jax.debug.print("Test CP striped AG") # cuDNN does not support right-aligned masking with dynamic sequence length padding. # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch @@ -3521,7 +3418,6 @@ def fused_attn_fwd( primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive else: primitive = FusedRingAttnFwdPrimitive.outer_primitive - print(f"qkv_for_primitive: \n {qkv_for_primitive}") seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) output, softmax_aux, rng_state = primitive.bind( *qkv_for_primitive, From 21663051d8f1e936b3ec5bbc14427c270a867d44 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 20:54:30 +0000 Subject: [PATCH 30/36] Rename stripe_height to stripe_size Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 22 ++++----- tests/jax/test_fused_attn.py | 10 ++-- transformer_engine/jax/attention.py | 28 +++++------ .../jax/cpp_extensions/attention.py | 48 +++++++++---------- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7b6ca67706..70d9a3f464 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -351,7 +351,7 @@ def impl_test_context_parallel_attn( 
use_shardy, use_scan_ring=False, window_size=None, - stripe_height=0, + stripe_size=0, num_segments_per_seq=0, ): if qkv_layout.is_thd(): @@ -408,7 +408,7 @@ def impl_test_context_parallel_attn( mesh_resource=mesh_resource, cp_strategy=cp_strategy, cp_load_balanced=load_balanced, - stripe_height=stripe_height, + stripe_size=stripe_size, num_segments_per_seq=num_segments_per_seq, ) @@ -506,7 +506,7 @@ def test_context_parallel_allgather_attn_shardy( [pytest.param(True, id="BALANCED")], ) @pytest.mark.parametrize( - "stripe_height", + "stripe_size", [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")], ) @pytest.mark.parametrize( @@ -533,7 +533,7 @@ def test_context_parallel_allgather_striped_attn( qkv_layout, load_balanced, window_size, - stripe_height, + stripe_size, num_segments_per_seq, ): if window_size != (-1, -1) and not qkv_layout.is_thd(): @@ -552,7 +552,7 @@ def test_context_parallel_allgather_striped_attn( CPStrategy.ALL_GATHER, use_shardy=False, window_size=window_size, - stripe_height=stripe_height, + stripe_size=stripe_size, num_segments_per_seq=num_segments_per_seq, ) @@ -721,10 +721,10 @@ class TestReorderCausalLoadBalancing: @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD]) @pytest.mark.parametrize( - "reorder_strategy, stripe_height", + "reorder_strategy, stripe_size", REORDER_STRATEGY, ) - def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): + def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) seq_dim = 1 if qkv_format == QKVFormat.SBHD: @@ -733,15 +733,15 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height): if reorder_strategy == ReorderStrategy.Striped: seq_lens = shape[seq_dim] - if seq_lens < (cp_size * stripe_height): - pytest.skip(f"{seq_lens=} must be larger than 
{cp_size*stripe_height=}") + if seq_lens < (cp_size * stripe_size): + pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}") ref = tensor.copy() reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) - reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_height) - inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_height) + reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size) + inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size) assert jnp.array_equal(inversed, ref) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index cba860380c..4deb2a1856 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -353,7 +353,7 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Tuple[int, int] seq_desc_format: SeqDescFormat - stripe_height: int = 0 + stripe_size: int = 0 num_segments_per_seq: int = 0 # Specifies sharding resources for distributed tests @@ -640,14 +640,14 @@ def generate_random_segment_ids( strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, - stripe_height=self.stripe_height, + stripe_size=self.stripe_size, ) self.cp_inverse_reorder_fn = partial( inverse_reorder_causal_load_balancing, strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, - stripe_height=self.stripe_height, + stripe_size=self.stripe_size, ) else: # no-ops for non cp or non load balanced @@ -808,7 +808,7 @@ def _test_forward(self): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, - "stripe_height": self.stripe_height, + "stripe_size": self.stripe_size, } customcall_fused_dpa_jit = jit( @@ -904,7 +904,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): "window_size": self.window_size, 
"context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, - "stripe_height": self.stripe_height, + "stripe_size": self.stripe_size, } # We can compute dBias only for the [1, h, s, s] layout diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 55712ec8e6..95256f2cff 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -387,24 +387,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): def reorder_causal_load_balancing( - tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1 ): """Reorders a tensor for load balancing the compute of causal attention.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_height) + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_size) raise ValueError(f"Unsupported {strategy=}") def inverse_reorder_causal_load_balancing( - tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1 + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1 ): """Inverse operation of `reorder_causal_load_balancing`.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_height) + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_size) raise ValueError(f"Unsupported {strategy=}") @@ -1011,7 +1011,7 @@ def _fused_attn( context_parallel_causal_load_balanced: bool, context_parallel_axis: 
str, context_checkpoint_name: str = "context", - stripe_height: int = 0, + stripe_size: int = 0, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -1032,7 +1032,7 @@ def _fused_attn( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name=context_checkpoint_name, - stripe_height=stripe_height, + stripe_size=stripe_size, ) return output @@ -1056,7 +1056,7 @@ def _fused_attn_fwd_rule( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name, - stripe_height, + stripe_size, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -1076,7 +1076,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, - stripe_height=stripe_height, + stripe_size=stripe_size, ) output = checkpoint_name(output, context_checkpoint_name) softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) @@ -1106,7 +1106,7 @@ def _fused_attn_bwd_rule( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name, - stripe_height, + stripe_size, ctx, dz, ): @@ -1141,7 +1141,7 @@ def _fused_attn_bwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, - stripe_height=stripe_height, + stripe_size=stripe_size, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -1178,7 +1178,7 @@ def fused_attn( context_parallel_axis: str = "", context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, - stripe_height: int = 0, + stripe_size: int = 0, ): """ Perform cuDNN fused attention. @@ -1216,9 +1216,9 @@ def fused_attn( softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX. 
If provided, this parameter will receive gradients during backpropagation. - stripe_height (int): + stripe_size (int): Indicates the striping height to be used when using ReorderStrategy.Striped. - Currently, a stripe_height > 1 is only allowed for CP + THD + Striped + AG + Currently, a stripe_size > 1 is only allowed for CP + THD + Striped + AG 0 indicates no striping strategy Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -1297,6 +1297,6 @@ def fused_attn( context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, context_checkpoint_name=context_checkpoint_name, - stripe_height=stripe_height, + stripe_size=stripe_size, ) return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 9ab58611a0..1c7d1ae2f2 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -73,7 +73,7 @@ "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", - "stripe_height", + "stripe_size", ], ) @dataclass(frozen=True) @@ -94,7 +94,7 @@ class _FusedAttnConfig: context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA - stripe_height: int # Only for CP + Striped. For, Ring P2P , stripe_height=1 only. + stripe_size: int # Only for CP + Striped. For, Ring P2P , stripe_size=1 only. 
@dataclass(frozen=True) @@ -1235,26 +1235,26 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig def reorder_causal_striped( - tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_height: int = 1 + tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_size: int = 1 ): """Reorders a tensor for load balancing with striped pattern""" origin_shape = tensor.shape - if origin_shape[seq_dim] % (cp_size * stripe_height) != 0: + if origin_shape[seq_dim] % (cp_size * stripe_size) != 0: raise ValueError( - "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_height but got" - f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_height=}, {cp_size*stripe_height=}" + "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got" + f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_size=}, {cp_size*stripe_size=}" ) if not is_inverse: new_shape = [ *origin_shape[:seq_dim], - *[origin_shape[seq_dim] // (cp_size * stripe_height), cp_size, stripe_height], + *[origin_shape[seq_dim] // (cp_size * stripe_size), cp_size, stripe_size], *origin_shape[seq_dim + 1 :], ] else: new_shape = [ *origin_shape[:seq_dim], - *[cp_size, origin_shape[seq_dim] // (cp_size * stripe_height), stripe_height], + *[cp_size, origin_shape[seq_dim] // (cp_size * stripe_size), stripe_size], *origin_shape[seq_dim + 1 :], ] @@ -1286,8 +1286,8 @@ def check_supported(self): f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - if (not self.config.qkv_layout.is_thd() and self.config.stripe_height != 0) or ( - self.config.qkv_layout.is_thd() and self.config.stripe_height == 0 + if (not self.config.qkv_layout.is_thd() and self.config.stripe_size != 0) or ( + self.config.qkv_layout.is_thd() and self.config.stripe_size == 0 ): raise ValueError( f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped" @@ -1342,7 +1342,7 @@ def get_adjusted_mask(self): def get_adjusted_max_segments_per_seq(self, max_seqlen, 
cp_size): # Estimating return ( - max_seqlen // (self.config.stripe_height * cp_size) + max_seqlen // (self.config.stripe_size * cp_size) ) + self.config.max_segments_per_seq def get_step_config(self) -> _FusedAttnConfig: @@ -1361,7 +1361,7 @@ def get_step_config(self) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, - stripe_height=self.config.stripe_height, + stripe_size=self.config.stripe_size, ) def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: @@ -1380,7 +1380,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, - stripe_height=self.config.stripe_height, + stripe_size=self.config.stripe_size, ) def all_gather_kv(self, k, v): @@ -1393,7 +1393,7 @@ def ag(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if self.config.qkv_layout.is_thd(): - x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_height) + x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_size) else: x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) return x @@ -1418,10 +1418,10 @@ def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if self.config.qkv_layout.is_thd(): kv_segment_ids_ag = reorder_causal_striped( - kv_segment_ids, cp_size, 1, True, self.config.stripe_height + kv_segment_ids, cp_size, 1, True, self.config.stripe_size ) kv_segment_pos_ag = reorder_causal_striped( - kv_segment_pos, cp_size, 1, True, self.config.stripe_height + kv_segment_pos, cp_size, 1, True, self.config.stripe_size ) return kv_segment_ids_ag, kv_segment_pos_ag # TODO: Is the dual chunk case needed ? 
@@ -1434,7 +1434,7 @@ def rs(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if self.config.qkv_layout.is_thd(): - x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_height) + x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_size) else: x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) @@ -2414,7 +2414,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, - stripe_height=self.config.stripe_height, + stripe_size=self.config.stripe_size, ) def stack_kv(self, k, v): @@ -3297,7 +3297,7 @@ def fused_attn_fwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", - stripe_height: int = 0, + stripe_size: int = 0, ) -> jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -3336,7 +3336,7 @@ def fused_attn_fwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. - stripe_height (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing + stripe_size (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: (jnp.ndarray): The output tensor from the fused attention. 
""" @@ -3402,7 +3402,7 @@ def fused_attn_fwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, - stripe_height=stripe_height, + stripe_size=stripe_size, ) primitive = None @@ -3452,7 +3452,7 @@ def fused_attn_bwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", - stripe_height: int = 0, + stripe_size: int = 0, ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -3492,7 +3492,7 @@ def fused_attn_bwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. - stripe_height (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing + stripe_size (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -3565,7 +3565,7 @@ def fused_attn_bwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, - stripe_height=stripe_height, + stripe_size=stripe_size, ) primitive = None From c5e0d6f5a257646fff89041db05c0628f45812cf Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 25 Nov 2025 23:35:13 +0000 Subject: [PATCH 31/36] Code clean up and add additional comments Signed-off-by: Kshitij Lakhani --- tests/jax/test_distributed_fused_attn.py | 1 - tests/jax/utils.py | 1 + transformer_engine/jax/attention.py | 17 +- .../jax/cpp_extensions/attention.py | 195 ++++++++++-------- 4 files changed, 115 insertions(+), 99 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py 
b/tests/jax/test_distributed_fused_attn.py index 70d9a3f464..222275f775 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -436,7 +436,6 @@ def check_has_backend_for_mask(mask_type): # and exception if the step backend is not supported. This was a deliberate API # decision to keep the CP size or flag out of the function. has_backend = check_has_backend_for_mask(attn_mask_type) - # TODO: For PADDING_CAUSAL_MASK ? if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK: has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index cafb31aa85..7194e387c7 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1507,6 +1507,7 @@ def assert_allclose( actual = actual.astype(jnp.float32) if not isinstance(desired, float): desired = desired.astype(jnp.float32) + # Check if tensors are close np.testing.assert_allclose(actual, desired, **tols, **kwargs) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 95256f2cff..bc1b25cd82 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -393,6 +393,7 @@ def reorder_causal_load_balancing( if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) if strategy == ReorderStrategy.Striped: + # stripe_size > 1 is only supported for CP+THD+AG+Striped return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_size) raise ValueError(f"Unsupported {strategy=}") @@ -404,6 +405,7 @@ def inverse_reorder_causal_load_balancing( if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) if strategy == ReorderStrategy.Striped: + # stripe_size > 1 is only supported for CP+THD+AG+Striped return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, 
stripe_size) raise ValueError(f"Unsupported {strategy=}") @@ -538,13 +540,12 @@ def _segment_ids_pos_to_seqlens_offsets( # It does not need to involve SW for this mask's creation # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - # TODO: Un comment the fast path - # if (attn_mask_type.is_causal() and window_size is None) or ( - # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - # ): - # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - # ) + if (attn_mask_type.is_causal() and window_size is None) or ( + window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + ): + return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + ) # (1 = attend, 0 = masked) segment_mask = make_attention_mask( @@ -1217,7 +1218,7 @@ def fused_attn( [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX. If provided, this parameter will receive gradients during backpropagation. stripe_size (int): - Indicates the striping height to be used when using ReorderStrategy.Striped. + Indicates the striping size to be used when using ReorderStrategy.Striped. Currently, a stripe_size > 1 is only allowed for CP + THD + Striped + AG 0 indicates no striping strategy Returns: diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1c7d1ae2f2..a3d4eb2b59 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1297,8 +1297,6 @@ def check_supported(self): if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - # TODO: Should AttnMaskType.PADDING_CAUSAL_MASK be allowed for CP + AG + THD + Striped ? 
- # TODO: Should Should AttnMaskType.NO_MASK be allowed for CP + AG + THD + Striped ? allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] if self.config.qkv_layout.is_thd(): allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK) @@ -1330,24 +1328,23 @@ def get_adjusted_mask(self): if ( self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK and not self.config.qkv_layout.is_thd() - ): # BSHD only ? + ): # BSHD AG case only return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK if ( self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK and self.config.qkv_layout.is_thd() - ): # THD only ? + ): # THD AG case only return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): - # Estimating + # Estimating adjusted max segments per seq return ( max_seqlen // (self.config.stripe_size * cp_size) ) + self.config.max_segments_per_seq def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" - # TODO: Should the max_segments_per_seq be different ? return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), @@ -1365,8 +1362,7 @@ def get_step_config(self) -> _FusedAttnConfig: ) def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: - """Returns a _FusedAttnConfig for single CP step call to fused attention.""" - # TODO: Should the max_segments_per_seq be different ? 
+ """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=self.get_adjusted_mask(), @@ -1384,7 +1380,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: ) def all_gather_kv(self, k, v): - """Performs a all-gather of k and v over context parallel ranks.""" + """Performs aa all-gather of k and v over context parallel ranks.""" def ag(x): x = lax_paral_op( @@ -1406,8 +1402,7 @@ def ag(x): return k, v # fall through def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): - """Performs a all-gather of k and v over context parallel ranks.""" - + """Performs aa all-gather of kv segment ids and kv segment pos over context parallel ranks.""" kv_segment_ids = lax_paral_op( kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) @@ -1424,7 +1419,6 @@ def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): kv_segment_pos, cp_size, 1, True, self.config.stripe_size ) return kv_segment_ids_ag, kv_segment_pos_ag - # TODO: Is the dual chunk case needed ? return kv_segment_ids, kv_segment_pos # fall through def reduce_scatter_dkv(self, dk, dv): @@ -1509,17 +1503,25 @@ def pad(x, npad): return dk, dv # fall through + # Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos + # For e.g. 
below are the sharded post AG q seg ids and pos for a given rank: + # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]] + # segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]] + # seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]] + # seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]] def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): - # Create mask for non-zero segment IDs + # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = q_segment_ids != 0 - # Calculate indices from mask max_size = q_segment_ids.shape[-1] - # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) non_zero_indices = jax.vmap( lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] )(non_zero_mask) - # Pick non zero seg ids and seg pos using take_along_axis + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos # Clip -1 to 0 for safe indexing clipped_indices = jnp.clip(non_zero_indices, 0, None) valid_segment_ids = jnp.where( @@ -1528,13 +1530,11 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme valid_segment_pos = jnp.where( non_zero_indices >= 0, jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), 0 ) - - # Create mask for actual valid entries (not padding) + # Create a mask for actual valid entries (not padding) actual_valid = valid_segment_ids != 0 - - # Detect segment changes, accounting for padding # First element is True only if it's actually valid first_is_segment = actual_valid[..., 0:1] + # Detect segment 
breaks in the valid tokens only (not full seq) # Padding will always be true as the segment change condition is being applied # on the valid segments (which have padding at the end so they'll always trigger True) @@ -1555,12 +1555,18 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] )(seqlens_pre) seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) - max_new_segments_per_seq = 0 # TODO: Remove - return max_new_segments_per_seq, seqlens_all_pad_neg - - def q_seqoffsets_for_striped_for_rank( - self, q_segment_ids, q_segment_pos, q_num_segments, max_segments_per_seq - ): + return seqlens_all_pad_neg + + # Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos + # For e.g. below are the sharded post AG q seg ids and pos for a given rank: + # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, False, False, False]] + # segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]] + # seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]] + def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): segment_changes = jnp.concatenate( [ jnp.full( @@ -1574,17 +1580,25 @@ def q_seqoffsets_for_striped_for_rank( segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False) # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos.shape[-1] - seq_offsets_2 = jax.vmap( + seq_offsets = jax.vmap( lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] 
)(segment_changes_masked) - return seq_offsets_2 + return seq_offsets + # Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos + # For e.g. below are the sharded post AG q seg ids and pos for a given rank: + # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # non_zero_mask = [[ True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True]] + # non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]] + # segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]] + # selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]] def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): - # Create mask for non-zero segment IDs + # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = kv_segment_ids != 0 - # Filter to only non-zero segments max_size = kv_segment_ids.shape[-1] - # Get non-zero indices for each row (need to vmap underlying jnp.nonzero calls made by jnp.where) non_zero_indices = jax.vmap( lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] )(non_zero_mask) @@ -1599,10 +1613,8 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0 ) actual_valid = valid_segment_ids != 0 - - # Detect segment changes, accounting for padding - # First element is True only if it's actually valid first_is_segment = actual_valid[..., 0:1] + # Detect segment breaks (only for non-zero segments) segment_changes = jnp.concatenate( [ @@ -1615,14 +1627,12 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, 
kv_segment_pos, max_se ], axis=-1, ) - - # Get the indices for segment changes - apply vmap per row + # Get the indices for segment changes segment_changes_valid = jax.vmap( lambda sc_row, av_row: jnp.where( sc_row & av_row, size=max_segments_per_seq, fill_value=-1 )[0] )(segment_changes, actual_valid) - # Safe indices safe_indices = jnp.maximum(segment_changes_valid, 0) # Select values using take_along_axis per row selected_values = jnp.where( @@ -1630,17 +1640,42 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1, -1, ) - # Count non-zero per row or total - num_segments = jnp.count_nonzero(selected_values > 0, axis=-1).astype(int) # Per row - return num_segments, selected_values - + return selected_values + + # Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos, + # AG kv seg ids and seg pos. + # For e.g. below are the sharded post AG q seg ids and pos for a given rank: + # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + # 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + # kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + # 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + # 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True, + # False, False, False, True, False, False, False]] + # segment_changes_indices = [[ 0, 8, 12, -1, -1, -1, -1, -1, -1]] + # segment_ids = 
[[ 1, 2, 2, -1, -1, -1, -1, -1, -1]] + # segment_changes_ag_first_true_masked = [[ True, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, True, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False] + # segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1, -1, -1]] + # seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1, -1, -1]] def kv_seqoffsets_for_striped_for_rank( self, kv_segment_pos, kv_segment_ids, kv_segment_pos_ag, kv_segment_ids_ag, - kv_num_segments, max_segments_per_seq, ): # Calculate the segment pos change mask @@ -1648,7 +1683,7 @@ def kv_seqoffsets_for_striped_for_rank( [ jnp.full( (kv_segment_pos.shape[0], 1), True, dtype=bool - ), # Assume valid element starts a segment + ), # Assume valid element starts a segment and mask afterwards (kv_segment_pos[..., 1:] != kv_segment_pos[..., :-1] + 1), # Segment pos changed ], axis=-1, @@ -1671,7 +1706,7 @@ def kv_seqoffsets_for_striped_for_rank( [ jnp.full( (kv_segment_pos.shape[0], 1), True, dtype=bool - ), # Assume valid element starts a segment + ), # Assume valid element starts a segment and mask afterwards ( kv_segment_pos_ag[..., 1:] != kv_segment_pos_ag[..., :-1] + 1 ), # Segment pos changed @@ -2049,39 +2084,34 @@ def _cross_attn( idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed ): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive - # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function - # does not go down that route but instead just picks the seqlens and offsets passed onto it + # Unset the segment_ids and 
segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos() + # does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it kv_max_seqlen = k.shape[1] # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( max_seqlen=kv_max_seqlen, cp_size=cp_size ) - q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq ) q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( q_segment_ids=_q_segment_ids, q_segment_pos=_q_segment_pos, - q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq, ) - kv_num_segments_for_rank, kv_seqlens_for_rank = ( - helper.kv_seqlens_for_striped_for_rank( - kv_segment_ids=_kv_segment_ids, - kv_segment_pos=_kv_segment_pos, - max_segments_per_seq=adjusted_max_segments_per_seq, - ) + kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, ) kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag, - kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq, ) - # kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q, # sharded for rank @@ -2094,16 +2124,17 @@ def _cross_attn( kv_seqlens_for_rank, q_seq_offsets_for_rank, kv_seq_offsets_for_rank, - q_seqlen, # Should be empty ids but 
using placeholder - kv_seqlen, # Should be empty poss but using placeholder - q_seq_offsets, # Should be empty ids but using placeholder - k_seq_offsets, # Should be empty pos but using placeholder + q_seqlen, # q seg ids should be empty ids so just passing q seqlens (empty) instead + kv_seqlen, # kv seg ids should be empty ids so just passing kv seqlens (empty) instead + q_seq_offsets, # q seg pos should be empty pos so just passing q seqoffsets (empty) instead + k_seq_offsets, # kv seg pos should be empty pos so just passing kv seqoffsets (empty) instead config=helper.get_step_config_for_striped( max_seqlen=kv_max_seqlen, cp_size=cp_size ), ) return output, softmax_aux, rng_state + # AG the k, v, kv_segment_ids and kv_segment_pos k_ag, v_ag = helper.all_gather_kv(k, v) _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( _kv_segment_ids, _kv_segment_pos @@ -2123,7 +2154,6 @@ def _cross_attn( ) for idx in range(cp_size) ] - return lax.switch(cp_rank, functions) return mesh, impl, out_shardings, arg_shardings @@ -2151,7 +2181,6 @@ def partition(config, mesh, arg_infos, result_infos): helper = _FusedAttnCPWithAllGatherHelper(mesh, config) helper.check_supported() - # TODO: Confirm the deletion del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -2214,39 +2243,34 @@ def _cross_attn_bwd( kv_segment_pos_ag, ): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive - # Do not forget to unset the segment_ids and segment_pos so that the seqlens_from_segment_ids_pos() function - # does not go down that route but instead just picks the seqlens and offsets passed onto it + # Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos() + # does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it kv_max_seqlen = k.shape[1] # Estimate an adjusted 
max_segments_per_seq per rank based on the global max_segments_per_seq adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( max_seqlen=kv_max_seqlen, cp_size=cp_size ) - q_num_segments_for_rank, q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq ) q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( q_segment_ids=_q_segment_ids, q_segment_pos=_q_segment_pos, - q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq, ) - kv_num_segments_for_rank, kv_seqlens_for_rank = ( - helper.kv_seqlens_for_striped_for_rank( - kv_segment_ids=_kv_segment_ids, - kv_segment_pos=_kv_segment_pos, - max_segments_per_seq=adjusted_max_segments_per_seq, - ) + kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, ) kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( kv_segment_pos=_kv_segment_pos, kv_segment_ids=_kv_segment_ids, kv_segment_pos_ag=kv_segment_pos_ag, kv_segment_ids_ag=kv_segment_ids_ag, - kv_num_segments=kv_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq, ) - # kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(q_segment_pos=_q_segment_pos, q_num_segments=q_num_segments_for_rank, max_segments_per_seq=adjusted_max_segments_per_seq) dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl( q, # sharded for rank @@ -2262,28 +2286,17 @@ def _cross_attn_bwd( kv_seqlens_for_rank, q_seq_offsets_for_rank, kv_seq_offsets_for_rank, - q_seqlen, # Should be empty ids but using placeholder - kv_seqlen, # Should be empty poss but using placeholder - q_seq_offsets, # Should be empty ids but using placeholder - k_seq_offsets, # Should be empty pos but using placeholder + q_seqlen, # q seg 
ids should be empty ids so just passing q seqlens (empty) instead + kv_seqlen, # kv seg ids should be empty ids so just passing kv seqlens (empty) instead + q_seq_offsets, # q seg pos should be empty pos so just passing q seqoffsets (empty) instead + k_seq_offsets, # kv seg pos should be empty pos so just passing kv seqoffsets (empty) instead config=helper.get_step_config_for_striped( max_seqlen=kv_max_seqlen, cp_size=cp_size ), ) - - # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. - # if config.attn_mask_type != AttnMaskType.NO_MASK: - # pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] - # dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) - - # results.append((dq_local, dk_local, dv_local, dbias_local)) - - # dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1) - # dk_local_pad = results[0][1] + results[1][1] - # dv_local_pad = results[0][2] + results[1][2] - # return dq_local, dk_local_pad, dv_local_pad, results[1][3] return dq_local, dk_local, dv_local, dbias_local + # AG the k, v, kv_segment_ids and kv_segment_pos k_ag, v_ag = helper.all_gather_kv(k, v) _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( _kv_segment_ids, _kv_segment_pos @@ -2313,6 +2326,7 @@ def _cross_attn_bwd( ] dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + # RS the dk and dv dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it) @@ -3418,6 +3432,7 @@ def fused_attn_fwd( primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive else: primitive = FusedRingAttnFwdPrimitive.outer_primitive + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) output, softmax_aux, rng_state = primitive.bind( *qkv_for_primitive, From ab81a3074c990784b4b410052cd54e665db6f94c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:50:20 -0800 
Subject: [PATCH 32/36] nit: Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix type on fused attn tests Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_fused_attn.py | 4 ++-- transformer_engine/jax/cpp_extensions/attention.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 4deb2a1856..1c97662238 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -776,7 +776,7 @@ def to_dp_shardings(x): self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) - def _test_forward(self): + def test_forward(self): """ Test forward with JITted primitive and unJITted reference """ @@ -1150,7 +1150,7 @@ class TestFusedAttn: pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), ], ) - def test_forward( + def _test_forward( b, s_q, s_kv, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a3d4eb2b59..c154cb98a9 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1306,6 +1306,7 @@ def check_supported(self): f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) # TODO: For now do not all CP + AG + THD + Striped with NO_MASK + # TODO: For now do not allow CP + AG + THD + Striped with NO_MASK if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd(): raise ValueError(f"{header} only supports CAUSAL_MASK for THD types") @@ -1380,7 +1381,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: ) def all_gather_kv(self, k, v): - """Performs aa 
all-gather of k and v over context parallel ranks.""" + """Performs an all-gather of k and v over context parallel ranks.""" def ag(x): x = lax_paral_op( @@ -1402,7 +1403,7 @@ def ag(x): return k, v # fall through def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): - """Performs aa all-gather of kv segment ids and kv segment pos over context parallel ranks.""" + """Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks.""" kv_segment_ids = lax_paral_op( kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) From 51440db402074972dc4fbadc5cf48be39b1e15ea Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 25 Nov 2025 21:52:02 -0800 Subject: [PATCH 33/36] Fix seqoffsets length to be passed onto FusedAttn primitive as it is b and not b+1 needed by cuDNN Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/cpp_extensions/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c154cb98a9..db98e72871 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1582,7 +1582,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se # Get the indices for segment changes (these are the offsets) max_size = q_segment_pos.shape[-1] seq_offsets = jax.vmap( - lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0] )(segment_changes_masked) return seq_offsets @@ -1695,7 +1695,7 @@ def kv_seqoffsets_for_striped_for_rank( # Get segment change indices for rank segment_changes_indices = jax.vmap( - lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + lambda sc_row: jnp.where(sc_row, 
size=max_segments_per_seq, fill_value=-1)[0] )(segment_changes_first_true_masked) # Get segment ids associated with the segment_changes_indices for rank segment_ids = jax.vmap( @@ -1719,7 +1719,7 @@ def kv_seqoffsets_for_striped_for_rank( ) # Get segment change indices for AG segment_changes_ag_indices = jax.vmap( - lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq, fill_value=-1)[0] )(segment_changes_ag_first_true_masked) # Use the segment ids picked per rank to get the offsets from the AG indices From 5e014af48ce7773157abeff9fcc7ea0a3f3e47ef Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Tue, 25 Nov 2025 22:48:12 -0800 Subject: [PATCH 34/36] Remove commented code Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index db98e72871..1bce04ccc8 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1542,8 +1542,8 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme segment_changes = jnp.concatenate( [ first_is_segment, # First valid element starts a segment - (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | - # ((valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) & actual_valid[..., 1:]) + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) + | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), ], axis=-1, From 9b5280bc9a42c278455a75fe972ec3f5a6aff387 Mon 
Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 26 Nov 2025 07:20:45 +0000 Subject: [PATCH 35/36] Fix linting issues Signed-off-by: Kshitij Lakhani Fix incorrect greptile change Signed-off-by: Kshitij Lakhani --- .../jax/cpp_extensions/attention.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1bce04ccc8..8f42318586 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1305,8 +1305,7 @@ def check_supported(self): f"{header} only supports masking types: " f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) - # TODO: For now do not all CP + AG + THD + Striped with NO_MASK - # TODO: For now do not allow CP + AG + THD + Striped with NO_MASK + # Do not allow CP + AG + THD + Striped with NO_MASK if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd(): raise ValueError(f"{header} only supports CAUSAL_MASK for THD types") @@ -1339,6 +1338,7 @@ def get_adjusted_mask(self): return self.config.attn_mask_type def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): + """Converts the max segments per seq for context parallelism AG + THD.""" # Estimating adjusted max segments per seq return ( max_seqlen // (self.config.stripe_size * cp_size) @@ -1504,8 +1504,7 @@ def pad(x, npad): return dk, dv # fall through - # Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos - # For e.g. 
below are the sharded post AG q seg ids and pos for a given rank: + # Below are the sharded post AG q seg ids and pos for a given rank: # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] # max_segments_per_seq = 7 @@ -1515,6 +1514,7 @@ def pad(x, npad): # seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]] # seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]] def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + """Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos""" # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = q_segment_ids != 0 max_size = q_segment_ids.shape[-1] @@ -1542,9 +1542,8 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme segment_changes = jnp.concatenate( [ first_is_segment, # First valid element starts a segment - (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) - | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), - (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | + (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) ], axis=-1, ) @@ -1558,8 +1557,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) return seqlens_all_pad_neg - # Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos - # For e.g. 
below are the sharded post AG q seg ids and pos for a given rank: + # Below are the sharded post AG q seg ids and pos for a given rank: # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] # max_segments_per_seq = 7 @@ -1568,6 +1566,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme # segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]] # seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]] def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + """Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos""" segment_changes = jnp.concatenate( [ jnp.full( @@ -1580,14 +1579,12 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se # Remove any padded region segment changes segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False) # Get the indices for segment changes (these are the offsets) - max_size = q_segment_pos.shape[-1] seq_offsets = jax.vmap( lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0] )(segment_changes_masked) return seq_offsets - # Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos - # For e.g. 
below are the sharded post AG q seg ids and pos for a given rank: + # Below are the sharded post AG q seg ids and pos for a given rank: # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] # max_segments_per_seq = 7 @@ -1597,6 +1594,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se # segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]] # selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]] def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): + """Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos""" # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = kv_segment_ids != 0 max_size = kv_segment_ids.shape[-1] @@ -1614,7 +1612,6 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0 ) actual_valid = valid_segment_ids != 0 - first_is_segment = actual_valid[..., 0:1] # Detect segment breaks (only for non-zero segments) segment_changes = jnp.concatenate( @@ -1643,9 +1640,7 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se ) return selected_values - # Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos, - # AG kv seg ids and seg pos. - # For e.g. 
below are the sharded post AG q seg ids and pos for a given rank: + # Below are the sharded post AG q seg ids and pos for a given rank: # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] # kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -1679,6 +1674,8 @@ def kv_seqoffsets_for_striped_for_rank( kv_segment_ids_ag, max_segments_per_seq, ): + """Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos, + AG kv seg ids and seg pos.""" # Calculate the segment pos change mask segment_changes_first_true = jnp.concatenate( [ @@ -2082,7 +2079,7 @@ def impl( # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos, # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in. def _cross_attn( - idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed + q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed ): # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive # Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos() @@ -2143,7 +2140,6 @@ def _cross_attn( functions = [ partial( _cross_attn, - idx, q, k_ag, v_ag, @@ -2226,7 +2222,6 @@ def impl( # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. 
def _cross_attn_bwd( - idx, q, k, v, @@ -2306,7 +2301,6 @@ def _cross_attn_bwd( functions = [ partial( _cross_attn_bwd, - idx, q, k_ag, v_ag, From 8841f5b283c010d8b9cde16186cc32809b7d3eba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 07:31:09 +0000 Subject: [PATCH 36/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 8f42318586..ebf1c84a9d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1542,8 +1542,8 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme segment_changes = jnp.concatenate( [ first_is_segment, # First valid element starts a segment - (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) | - (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1) + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) + | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), ], axis=-1, )