diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5372018ae8..222275f775 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -327,7 +327,7 @@ def test_cross_attn( ] DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [ - # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + # Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes. pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ] @@ -351,12 +351,14 @@ def impl_test_context_parallel_attn( use_shardy, use_scan_ring=False, window_size=None, + stripe_size=0, + num_segments_per_seq=0, ): if qkv_layout.is_thd(): - if cp_strategy == CPStrategy.ALL_GATHER: - pytest.skip("THD doesn't support all gather context parallelism.") - if not load_balanced and cp_strategy == CPStrategy.RING: - pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + if not load_balanced and ( + cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER + ): + pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.") assert not use_scan_ring or cp_strategy == CPStrategy.RING @@ -382,7 +384,6 @@ def impl_test_context_parallel_attn( data_shape = batch, seqlen, num_head, hidden num_kv_heads = num_head // kv_groups - runner = FusedAttnRunner( batch, seqlen, @@ -407,6 +408,8 @@ def impl_test_context_parallel_attn( mesh_resource=mesh_resource, cp_strategy=cp_strategy, cp_load_balanced=load_balanced, + stripe_size=stripe_size, + num_segments_per_seq=num_segments_per_seq, ) def check_has_backend_for_mask(mask_type): @@ -457,7 +460,7 @@ def check_has_backend_for_mask(mask_type): @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + 
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) def test_context_parallel_allgather_attn_shardy( self, @@ -486,6 +489,72 @@ def test_context_parallel_allgather_attn_shardy( use_shardy=True, ) + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + [DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[-1]], + ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED")], + ) + @pytest.mark.parametrize( + "stripe_size", + [pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")], + ) + @pytest.mark.parametrize( + "window_size", + [ + pytest.param((-1, -1), id="window_size(-1, -1)"), + pytest.param((5, 0), id="window_size(5, 0)"), + ], + ) + @pytest.mark.parametrize( + "num_segments_per_seq", + [pytest.param(2, id="SEG-2"), pytest.param(11, id="SEG-11")], + ) + def test_context_parallel_allgather_striped_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + window_size, + stripe_size, + num_segments_per_seq, + ): + if window_size != (-1, -1) and not qkv_layout.is_thd(): + pytest.skip("Sliding window attention is only supported for THD layout") + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + CPStrategy.ALL_GATHER, + use_shardy=False, + window_size=window_size, + stripe_size=stripe_size, + num_segments_per_seq=num_segments_per_seq, + ) + @pytest_parametrize_wrapper( "device_count,mesh_shape,mesh_axes,mesh_resource", 
generate_context_parallel_configs_for_attn(), @@ -495,7 +564,7 @@ def test_context_parallel_allgather_attn_shardy( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) @pytest.mark.parametrize( "load_balanced", @@ -538,7 +607,7 @@ def test_context_parallel_allgather_attn( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) @pytest.mark.parametrize( "load_balanced", @@ -602,7 +671,7 @@ def test_context_parallel_ring_attn( @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS[:-1], ) def test_context_parallel_ring_attn_shardy( self, @@ -639,31 +708,39 @@ def test_context_parallel_ring_attn_shardy( "L2": [[4, 32, 12, 32], [1, 16, 1, 1]], } +REORDER_STRATEGY = [ + pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"), + pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"), + pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"), +] + class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) - @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) + @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD]) @pytest.mark.parametrize( - "reorder_strategy", - [ - pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"), - pytest.param(ReorderStrategy.Striped, id="Striped"), - ], + "reorder_strategy, stripe_size", + REORDER_STRATEGY, ) - def test(self, cp_size, shape, 
qkv_format, reorder_strategy): + def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) seq_dim = 1 if qkv_format == QKVFormat.SBHD: tensor = tensor.swapaxes(0, 1) seq_dim = 0 + if reorder_strategy == ReorderStrategy.Striped: + seq_lens = tensor.shape[seq_dim] + if seq_lens < (cp_size * stripe_size): + pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}") + ref = tensor.copy() - reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3]) - inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3]) + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4]) - reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim) - inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim) + reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size) + inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size) assert jnp.array_equal(inversed, ref) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f4caaef165..1c97662238 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -230,6 +230,7 @@ def make_mask( @jax.jit def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape + # TODO: should this be max_seqlen + 1 ? 
bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) seqlens = seqlens_with_zero[..., 1:] @@ -352,6 +353,8 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Tuple[int, int] seq_desc_format: SeqDescFormat + stripe_size: int = 0 + num_segments_per_seq: int = 0 # Specifies sharding resources for distributed tests number_of_devices: int = 1 @@ -577,7 +580,9 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.num_segments_per_seq = 2 + # If using default num segments of 0, set to 2 + if self.num_segments_per_seq == 0: + self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) @@ -635,12 +640,14 @@ def generate_random_segment_ids( strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, + stripe_size=self.stripe_size, ) self.cp_inverse_reorder_fn = partial( inverse_reorder_causal_load_balancing, strategy=reorder_strategy, cp_size=self.cp_size, seq_dim=seq_dim, + stripe_size=self.stripe_size, ) else: # no-ops for non cp or non load balanced @@ -771,7 +778,7 @@ def to_dp_shardings(x): def test_forward(self): """ - Test forward without JIT + Test forward with JITted primitive and unJITted reference """ self._setup_inputs() @@ -801,6 +808,7 @@ def test_forward(self): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, + "stripe_size": self.stripe_size, } customcall_fused_dpa_jit = jit( @@ -896,6 +904,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): "window_size": self.window_size, "context_parallel_strategy": self.cp_strategy, "context_parallel_causal_load_balanced": self.cp_load_balanced, + "stripe_size": self.stripe_size, } # We can compute dBias only for the [1, h, s, s] layout 
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0a32be9679..bc1b25cd82 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -386,23 +386,27 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int): +def reorder_causal_load_balancing( + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1 +): """Reorders a tensor for load balancing the compute of causal attention.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False) + # stripe_size > 1 is only supported for CP+THD+AG+Striped + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False, stripe_size) raise ValueError(f"Unsupported {strategy=}") def inverse_reorder_causal_load_balancing( - tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int = 1 ): """Inverse operation of `reorder_causal_load_balancing`.""" if strategy == ReorderStrategy.DualChunkSwap: return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) if strategy == ReorderStrategy.Striped: - return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True) + # stripe_size > 1 is only supported for CP+THD+AG+Striped + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True, stripe_size) raise ValueError(f"Unsupported {strategy=}") @@ -988,7 +992,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 
16, 17, 18)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -1008,6 +1012,7 @@ def _fused_attn( context_parallel_causal_load_balanced: bool, context_parallel_axis: str, context_checkpoint_name: str = "context", + stripe_size: int = 0, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -1028,6 +1033,7 @@ def _fused_attn( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name=context_checkpoint_name, + stripe_size=stripe_size, ) return output @@ -1051,6 +1057,7 @@ def _fused_attn_fwd_rule( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name, + stripe_size, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -1070,6 +1077,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + stripe_size=stripe_size, ) output = checkpoint_name(output, context_checkpoint_name) softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) @@ -1099,6 +1107,7 @@ def _fused_attn_bwd_rule( context_parallel_causal_load_balanced, context_parallel_axis, context_checkpoint_name, + stripe_size, ctx, dz, ): @@ -1133,6 +1142,7 @@ def _fused_attn_bwd_rule( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + stripe_size=stripe_size, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -1169,6 +1179,7 @@ def fused_attn( context_parallel_axis: str = "", context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, + stripe_size: int = 0, ): """ Perform cuDNN fused attention. @@ -1206,6 +1217,10 @@ def fused_attn( softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape [1, num_heads, 1, 1]. 
Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX. If provided, this parameter will receive gradients during backpropagation. + stripe_size (int): + Indicates the striping size to be used when using ReorderStrategy.Striped. + Currently, a stripe_size > 1 is only allowed for CP + THD + Striped + AG. + 0 indicates no striping strategy. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -1283,5 +1298,6 @@ def fused_attn( context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, context_checkpoint_name=context_checkpoint_name, + stripe_size=stripe_size, ) return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index f0778bfd29..ebf1c84a9d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -73,6 +73,7 @@ "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", + "stripe_size", ], ) @dataclass(frozen=True) @@ -93,6 +94,7 @@ class _FusedAttnConfig: context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA + stripe_size: int # Only for CP + Striped. For Ring P2P, stripe_size=1 only. 
@dataclass(frozen=True) @@ -527,7 +529,6 @@ def impl( segment_ids=(_q_segment_ids, _kv_segment_ids), segment_pos=(_q_segment_pos, _kv_segment_pos), ) - (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( sequence_descriptor.get_seqlens_and_offsets( config.attn_mask_type, @@ -536,7 +537,6 @@ def impl( config.max_segments_per_seq, ) ) - if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): @@ -1234,31 +1234,33 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig return combined.reshape(ori_tensor_shape) -def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool): +def reorder_causal_striped( + tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_size: int = 1 +): """Reorders a tensor for load balancing with striped pattern""" origin_shape = tensor.shape - if origin_shape[seq_dim] % cp_size != 0: + if origin_shape[seq_dim] % (cp_size * stripe_size) != 0: raise ValueError( - "Expected origin_shape[seq_dim] is multiple of cp_size but got" - f" {origin_shape[seq_dim]=} and {cp_size=}" + "Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got" + f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_size=}, {cp_size*stripe_size=}" ) if not is_inverse: new_shape = [ *origin_shape[:seq_dim], - *[origin_shape[seq_dim] // cp_size, cp_size], + *[origin_shape[seq_dim] // (cp_size * stripe_size), cp_size, stripe_size], *origin_shape[seq_dim + 1 :], ] else: new_shape = [ *origin_shape[:seq_dim], - *[cp_size, origin_shape[seq_dim] // cp_size], + *[cp_size, origin_shape[seq_dim] // (cp_size * stripe_size), stripe_size], *origin_shape[seq_dim + 1 :], ] - chunked_tensor = tensor.reshape(new_shape) - reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1) - return reordered_chunked_tensor.reshape(origin_shape) + striped_tensor = tensor.reshape(new_shape) + reordered_striped_tensor = jnp.swapaxes(striped_tensor, seq_dim, seq_dim + 1) + return 
reordered_striped_tensor.reshape(origin_shape) @dataclass(frozen=True) @@ -1272,26 +1274,44 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused attention" - allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + allowed_layouts = [ + QKVLayout.BSHD_BS2HD, + QKVLayout.BSHD_BSHD_BSHD, + QKVLayout.THD_T2HD, + QKVLayout.THD_THD_THD, + ] if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) + if (not self.config.qkv_layout.is_thd() and self.config.stripe_size != 0) or ( + self.config.qkv_layout.is_thd() and self.config.stripe_size == 0 + ): + raise ValueError( + f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped" + " load balancing with THD layouts" + ) + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] + if self.config.qkv_layout.is_thd(): + allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK) if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) + # Do not allow CP + AG + THD + Striped with NO_MASK + if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd(): + raise ValueError(f"{header} only supports CAUSAL_MASK for THD types") - if self.config.max_segments_per_seq != 1: + if self.config.max_segments_per_seq != 1 and (not self.config.qkv_layout.is_thd()): raise ValueError( - f"{header} only supports max_segments_per_seq == 1 got:" + f"{header} only supports max_segments_per_seq == 1 for BSHD layouts, got:" f" {self.config.max_segments_per_seq}" ) @@ -1305,10 +1325,25 @@ def 
check_supported(self): def get_adjusted_mask(self): """Converts the mask for context parallelism.""" - if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: + if ( + self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK + and not self.config.qkv_layout.is_thd() + ): # BSHD AG case only return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK + if ( + self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK + and self.config.qkv_layout.is_thd() + ): # THD AG case only + return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type + def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): + """Converts the max segments per seq for context parallelism AG + THD.""" + # Estimating adjusted max segments per seq + return ( + max_seqlen // (self.config.stripe_size * cp_size) + ) + self.config.max_segments_per_seq + def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" return _FusedAttnConfig( @@ -1324,10 +1359,29 @@ def get_step_config(self) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, + stripe_size=self.config.stripe_size, + ) + + def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=self.get_adjusted_mask(), + softmax_type=self.config.softmax_type, + qkv_layout=self.config.qkv_layout, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + 
cp_axis=self.config.cp_axis, + cp_striped_window_size=None, + stripe_size=self.config.stripe_size, ) def all_gather_kv(self, k, v): - """Performs a all-gather of k and v over context parallel ranks.""" + """Performs an all-gather of k and v over context parallel ranks.""" def ag(x): x = lax_paral_op( @@ -1335,7 +1389,10 @@ def ag(x): ) if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) + if self.config.qkv_layout.is_thd(): + x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_size) + else: + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) return x if self.config.qkv_layout.is_kvpacked(): @@ -1345,13 +1402,36 @@ def ag(x): return k, v # fall through + def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos): + """Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks.""" + kv_segment_ids = lax_paral_op( + kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + kv_segment_pos = lax_paral_op( + kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + if self.config.qkv_layout.is_thd(): + kv_segment_ids_ag = reorder_causal_striped( + kv_segment_ids, cp_size, 1, True, self.config.stripe_size + ) + kv_segment_pos_ag = reorder_causal_striped( + kv_segment_pos, cp_size, 1, True, self.config.stripe_size + ) + return kv_segment_ids_ag, kv_segment_pos_ag + return kv_segment_ids, kv_segment_pos # fall through + def reduce_scatter_dkv(self, dk, dv): """Performs a reduce-scatter of dk and dv over context parallel ranks.""" def rs(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_dual_chunk_swap(x, 
cp_size, 1, to_contiguous=False) + if self.config.qkv_layout.is_thd(): + x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_size) + else: + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) return lax_paral_op( x, @@ -1424,6 +1504,227 @@ def pad(x, npad): return dk, dv # fall through + # Below are the sharded post AG q seg ids and pos for a given rank: + # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]] + # segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]] + # seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]] + # seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]] + def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + """Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos""" + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = q_segment_ids != 0 + max_size = q_segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), 0 + ) + valid_segment_pos = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), 0 + ) + # Create a mask for actual valid entries (not padding) + actual_valid = valid_segment_ids != 0 + # First 
element is True only if it's actually valid + first_is_segment = actual_valid[..., 0:1] + + # Detect segment breaks in the valid tokens only (not full seq) + # Padding will always be true as the segment change condition is being applied + # on the valid segments (which have padding at the end so they'll always trigger True) + segment_changes = jnp.concatenate( + [ + first_is_segment, # First valid element starts a segment + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) + | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + ], + axis=-1, + ) + new_segment_ids = jnp.cumsum(segment_changes, axis=-1) + seqlens_pre = jax.vmap( + lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32) + )(actual_valid, new_segment_ids) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(seqlens_pre) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + return seqlens_all_pad_neg + + # Below are the sharded post AG q seg ids and pos for a given rank: + # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, False, False, False]] + # segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]] + # seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]] + def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq): + """Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos""" + segment_changes = jnp.concatenate( + [ + jnp.full( + (q_segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (q_segment_pos[..., 1:] != 
q_segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + # Remove any padded region segment changes + segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False) + # Get the indices for segment changes (these are the offsets) + seq_offsets = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0] + )(segment_changes_masked) + return seq_offsets + + # Below are the sharded post AG q seg ids and pos for a given rank: + # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # non_zero_mask = [[ True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True]] + # non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]] + # segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]] + # selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]] + def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq): + """Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos""" + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = kv_segment_ids != 0 + max_size = kv_segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non zero seg ids and seg pos using take_along_axis + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), 0 + ) + valid_segment_pos = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, 
clipped_indices, axis=-1), 0 + ) + actual_valid = valid_segment_ids != 0 + + # Detect segment breaks (only for non-zero segments) + segment_changes = jnp.concatenate( + [ + ( + (valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) + & actual_valid[..., 1:] + ) + | (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1), + actual_valid[..., -1:], + ], + axis=-1, + ) + # Get the indices for segment changes + segment_changes_valid = jax.vmap( + lambda sc_row, av_row: jnp.where( + sc_row & av_row, size=max_segments_per_seq, fill_value=-1 + )[0] + )(segment_changes, actual_valid) + safe_indices = jnp.maximum(segment_changes_valid, 0) + # Select values using take_along_axis per row + selected_values = jnp.where( + segment_changes_valid >= 0, + jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1, + -1, + ) + return selected_values + + # Below are the sharded post AG q seg ids and pos for a given rank: + # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]] + # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]] + # kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + # 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + # kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + # 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + # 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + # max_segments_per_seq = 7 + # Below are some intermediate representations: + # segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True, + # False, False, False, True, False, False, False]] + # segment_changes_indices = [[ 0, 8, 12, -1, -1, -1, -1, -1, -1]] + # segment_ids = [[ 1, 2, 2, -1, -1, -1, -1, -1, -1]] + # segment_changes_ag_first_true_masked = [[ True, False, 
False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, True, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False, False, False, False, False, False, False, False, False, + # False] + # segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1, -1, -1]] + # seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1, -1, -1]] + def kv_seqoffsets_for_striped_for_rank( + self, + kv_segment_pos, + kv_segment_ids, + kv_segment_pos_ag, + kv_segment_ids_ag, + max_segments_per_seq, + ): + """Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos, + AG kv seg ids and seg pos.""" + # Calculate the segment pos change mask + segment_changes_first_true = jnp.concatenate( + [ + jnp.full( + (kv_segment_pos.shape[0], 1), True, dtype=bool + ), # Assume valid element starts a segment and mask afterwards + (kv_segment_pos[..., 1:] != kv_segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + segment_changes_first_true_masked = jnp.where( + kv_segment_ids != 0, segment_changes_first_true, False + ) + + # Get segment change indices for rank + segment_changes_indices = jax.vmap( + lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq, fill_value=-1)[0] + )(segment_changes_first_true_masked) + # Get segment ids associated with the segment_changes_indices for rank + segment_ids = jax.vmap( + lambda sci_row, ksi_row: jnp.where(sci_row >= 0, ksi_row[sci_row], -1) + )(segment_changes_indices, kv_segment_ids) + + # Get segment change indices for AG + segment_changes_ag_first_true = jnp.concatenate( + [ + jnp.full( + (kv_segment_pos.shape[0], 1), True, dtype=bool + ), # Assume valid element starts a segment and mask afterwards + ( + kv_segment_pos_ag[..., 1:] != 
kv_segment_pos_ag[..., :-1] + 1 + ), # Segment pos changed + ], + axis=-1, + ) + segment_changes_ag_first_true_masked = jnp.where( + kv_segment_ids_ag != 0, segment_changes_ag_first_true, False + ) + # Get segment change indices for AG + segment_changes_ag_indices = jax.vmap( + lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq, fill_value=-1)[0] + )(segment_changes_ag_first_true_masked) + + # Use the segment ids picked per rank to get the offsets from the AG indices + seq_offsets = jax.vmap( + lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1) + )(segment_ids, segment_changes_ag_indices) + return seq_offsets + class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): """ @@ -1501,7 +1802,6 @@ def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed): q_seqlen_for_step = q_seqlen / (cp_size * 2) num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks - output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_split[sub_idx], k_unmasked, @@ -1722,6 +2022,318 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) +class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive + + This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unnecessary work.
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[5] = seed_sharding + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def impl( + q, + k, + v, + bias, + softmax_offset, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # cuDNN does not support right-aligned masking with dynamic sequence length padding. + # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch + # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor + # meeting the expectation of the SPMD model. + # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding + # mask/sequence length tensor to avoid this unrolled loop. + + # Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets + # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos, + # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in. 
+ def _cross_attn( + q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed + ): + # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive + # Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos() + # does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it + + kv_max_seqlen = k.shape[1] + # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ) + q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq + ) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( + q_segment_ids=_q_segment_ids, + q_segment_pos=_q_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( + kv_segment_pos=_kv_segment_pos, + kv_segment_ids=_kv_segment_ids, + kv_segment_pos_ag=kv_segment_pos_ag, + kv_segment_ids_ag=kv_segment_ids_ag, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q, # sharded for rank + k, # ag + v, # ag + bias, + softmax_offset, + seed, + q_seqlens_for_rank, + kv_seqlens_for_rank, + q_seq_offsets_for_rank, + kv_seq_offsets_for_rank, + q_seqlen, # q seg ids should be empty ids so just passing q seqlens (empty) instead + kv_seqlen, # kv seg ids should be empty ids so just passing kv seqlens (empty) instead + q_seq_offsets, # q seg pos should be empty pos so just passing q seqoffsets (empty) instead + k_seq_offsets, # kv seg 
pos should be empty pos so just passing kv seqoffsets (empty) instead + config=helper.get_step_config_for_striped( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ), + ) + return output, softmax_aux, rng_state + + # AG the k, v, kv_segment_ids and kv_segment_pos + k_ag, v_ag = helper.all_gather_kv(k, v) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( + _kv_segment_ids, _kv_segment_pos + ) + functions = [ + partial( + _cross_attn, + q, + k_ag, + v_ag, + bias, + softmax_offset, + _kv_segment_ids_ag, + _kv_segment_pos_ag, + seed, + ) + for idx in range(cp_size) + ] + return lax.switch(cp_rank, functions) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive) + + +class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism and Striped Load Balancing Primitive. + + This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks. + The gradients are subsequently reduce-scattered back to each context parallel rank. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unnecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + # Ensure we can support this configuration with context parallelism.
+ helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + softmax_offset_spec = get_padded_spec(arg_infos[4]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = ( + dq_sharding, + dk_sharding, + dv_sharding, + dbias_sharding, + dsoftmax_offset_sharding, + ) + + def impl( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. 
+ def _cross_attn_bwd( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + kv_segment_ids_ag, + _q_segment_pos, + kv_segment_pos_ag, + ): + # Helper generates the seqlens and offsets for q and kv and then passes them down to the FusedAttnBwdPrimitive + # Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos() + # does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it + + kv_max_seqlen = k.shape[1] + # Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq + adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ) + q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank( + _q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq + ) + q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank( + q_segment_ids=_q_segment_ids, + q_segment_pos=_q_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank( + kv_segment_ids=_kv_segment_ids, + kv_segment_pos=_kv_segment_pos, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank( + kv_segment_pos=_kv_segment_pos, + kv_segment_ids=_kv_segment_ids, + kv_segment_pos_ag=kv_segment_pos_ag, + kv_segment_ids_ag=kv_segment_ids_ag, + max_segments_per_seq=adjusted_max_segments_per_seq, + ) + + dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl( + q, # sharded for rank + k, # ag + v, # ag + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlens_for_rank, + kv_seqlens_for_rank, + q_seq_offsets_for_rank, + kv_seq_offsets_for_rank, + q_seqlen, # q seg ids should be empty ids so just passing q seqlens (empty) instead + kv_seqlen, # kv seg ids should
be empty ids so just passing kv seqlens (empty) instead + q_seq_offsets, # q seg pos should be empty pos so just passing q seqoffsets (empty) instead + k_seq_offsets, # kv seg pos should be empty pos so just passing kv seqoffsets (empty) instead + config=helper.get_step_config_for_striped( + max_seqlen=kv_max_seqlen, cp_size=cp_size + ), + ) + return dq_local, dk_local, dv_local, dbias_local + + # AG the k, v, kv_segment_ids and kv_segment_pos + k_ag, v_ag = helper.all_gather_kv(k, v) + _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos( + _kv_segment_ids, _kv_segment_pos + ) + + functions = [ + partial( + _cross_attn_bwd, + q, + k_ag, + v_ag, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids_ag, + _q_segment_pos, + _kv_segment_pos_ag, + ) + for idx in range(cp_size) + ] + + dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + # RS the dk and dv + dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) + + # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it) + dummy_dsoftmax_offset = jnp.empty_like(softmax_offset) + return dq, dk, dv, dbias, dummy_dsoftmax_offset + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive) + + @dataclass(frozen=True) class _FusedAttnCPWithP2PHelper: """Helper class to assist with running the P2P ring strategy for CP attention.""" @@ -1811,6 +2423,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, + stripe_size=self.config.stripe_size, ) def stack_kv(self, k, v): @@ -2693,6 +3306,7 @@ def fused_attn_fwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + stripe_size: int = 0, ) -> 
jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -2731,6 +3345,7 @@ def fused_attn_fwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + stripe_size (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: (jnp.ndarray): The output tensor from the fused attention. """ @@ -2796,12 +3411,16 @@ def fused_attn_fwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, + stripe_size=stripe_size, ) primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedAttnCPStripedWithAllGatherFwdPrimitive.outer_primitive + else: + primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: # We must use stripe attention for THD-RING if qkv_layout.is_thd(): @@ -2843,6 +3462,7 @@ def fused_attn_bwd( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + stripe_size: int = 0, ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -2882,6 +3502,7 @@ def fused_attn_bwd( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. 
+ stripe_size (int): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -2954,12 +3575,16 @@ def fused_attn_bwd( context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, + stripe_size=stripe_size, ) primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedAttnCPStripedWithAllGatherBwdPrimitive.outer_primitive + else: + primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: if qkv_layout.is_thd(): primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 556b587191..61deab5b80 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -176,6 +176,9 @@ def shardy_sharding_rule(*args): def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. 
+ Inner primitive - single device, no sharding awareness, eager mode fallback + Outer primitive - multi device, sharding aware, partition() distributes work, + used when there's a dev mesh context """ _primitive_registry[cls.__name__] = cls @@ -190,14 +193,17 @@ def name_of_wrapper_p(): inner_p = core.Primitive(cls.name) dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p.multiple_results = cls.multiple_results + # Define eager execution implementation (by invoking its MLIR lowering) inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.def_abstract_eval(cls.abstract) mlir.register_lowering(inner_p, cls.lowering, platform="cuda") cls.inner_primitive = inner_p + # Create the outer primitive for distributed execution outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results + # Define the eager execution implementation outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher