@@ -1305,8 +1305,7 @@ def check_supported(self):
13051305 f"{ header } only supports masking types: "
13061306 f" { ',' .join (map (str , allowed_masks ))} got: { self .config .attn_mask_type } "
13071307 )
1308- # TODO: For now do not all CP + AG + THD + Striped with NO_MASK
1309- # TODO: For now do not allow CP + AG + THD + Striped with NO_MASK
1308+ # Do not allow CP + AG + THD + Striped with NO_MASK
13101309 if self .config .attn_mask_type is AttnMaskType .NO_MASK and self .config .qkv_layout .is_thd ():
13111310 raise ValueError (f"{ header } only supports CAUSAL_MASK for THD types" )
13121311
@@ -1339,6 +1338,7 @@ def get_adjusted_mask(self):
13391338 return self .config .attn_mask_type
13401339
13411340 def get_adjusted_max_segments_per_seq (self , max_seqlen , cp_size ):
1341+ """Converts the max segments per seq for context parallelism AG + THD."""
13421342 # Estimating adjusted max segments per seq
13431343 return (
13441344 max_seqlen // (self .config .stripe_size * cp_size )
@@ -1504,8 +1504,7 @@ def pad(x, npad):
15041504
15051505 return dk , dv # fall through
15061506
1507- # Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos
1508- # For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1507+ # Below are the sharded post AG q seg ids and pos for a given rank:
15091508 # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15101509 # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15111510 # max_segments_per_seq = 7
@@ -1515,6 +1514,7 @@ def pad(x, npad):
15151514 # seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
15161515 # seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]]
15171516 def q_seqlens_for_striped_for_rank (self , q_segment_ids , q_segment_pos , max_segments_per_seq ):
1517+ """Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos"""
15181518 # Create mask for non-zero seg ids and get the non-zero indices associated with the same
15191519 non_zero_mask = q_segment_ids != 0
15201520 max_size = q_segment_ids .shape [- 1 ]
@@ -1542,9 +1542,8 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15421542 segment_changes = jnp .concatenate (
15431543 [
15441544 first_is_segment , # First valid element starts a segment
1545- (valid_segment_ids [..., 1 :] != valid_segment_ids [..., :- 1 ])
1546- | (valid_segment_pos [..., 1 :] != valid_segment_pos [..., :- 1 ] + 1 ),
1547- (valid_segment_pos [..., 1 :] != valid_segment_pos [..., :- 1 ] + 1 ),
1545+ (valid_segment_ids [..., 1 :] != valid_segment_ids [..., :- 1 ]) |
1546+ (valid_segment_pos [..., 1 :] != valid_segment_pos [..., :- 1 ] + 1 )
15481547 ],
15491548 axis = - 1 ,
15501549 )
@@ -1558,8 +1557,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15581557 seqlens_all_pad_neg = jnp .where (seqlens_all == 0 , - 1 , seqlens_all )
15591558 return seqlens_all_pad_neg
15601559
1561- # Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos
1562- # For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1560+ # Below are the sharded post AG q seg ids and pos for a given rank:
15631561 # q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15641562 # q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15651563 # max_segments_per_seq = 7
@@ -1568,6 +1566,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15681566 # segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]
15691567 # seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]]
15701568 def q_seqoffsets_for_striped_for_rank (self , q_segment_ids , q_segment_pos , max_segments_per_seq ):
1569+ """Extract the q seqoffsets for striped primitive (post AG) from the sharded q seg ids and seg pos"""
15711570 segment_changes = jnp .concatenate (
15721571 [
15731572 jnp .full (
@@ -1580,14 +1579,12 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15801579 # Remove any padded region segment changes
15811580 segment_changes_masked = jnp .where (q_segment_ids != 0 , segment_changes , False )
15821581 # Get the indices for segment changes (these are the offsets)
1583- max_size = q_segment_pos .shape [- 1 ]
15841582 seq_offsets = jax .vmap (
15851583 lambda scm_row : jnp .where (scm_row , size = max_segments_per_seq , fill_value = - 1 )[0 ]
15861584 )(segment_changes_masked )
15871585 return seq_offsets
15881586
1589- # Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos
1590- # For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1587+ # Below are the sharded post AG kv seg ids and pos for a given rank:
15911588 # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15921589 # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15931590 # max_segments_per_seq = 7
@@ -1597,6 +1594,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15971594 # segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]]
15981595 # selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]]
15991596 def kv_seqlens_for_striped_for_rank (self , kv_segment_ids , kv_segment_pos , max_segments_per_seq ):
1597+ """Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos"""
16001598 # Create mask for non-zero seg ids and get the non-zero indices associated with the same
16011599 non_zero_mask = kv_segment_ids != 0
16021600 max_size = kv_segment_ids .shape [- 1 ]
@@ -1614,7 +1612,6 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se
16141612 non_zero_indices >= 0 , jnp .take_along_axis (kv_segment_pos , clipped_indices , axis = - 1 ), 0
16151613 )
16161614 actual_valid = valid_segment_ids != 0
1617- first_is_segment = actual_valid [..., 0 :1 ]
16181615
16191616 # Detect segment breaks (only for non-zero segments)
16201617 segment_changes = jnp .concatenate (
@@ -1643,9 +1640,7 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se
16431640 )
16441641 return selected_values
16451642
1646- # Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1647- # AG kv seg ids and seg pos.
1648- # For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1643+ # Below are the sharded post AG kv seg ids and pos for a given rank:
16491644 # kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
16501645 # kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
16511646 # kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -1679,6 +1674,8 @@ def kv_seqoffsets_for_striped_for_rank(
16791674 kv_segment_ids_ag ,
16801675 max_segments_per_seq ,
16811676 ):
1677+ """Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1678+ AG kv seg ids and seg pos."""
16821679 # Calculate the segment pos change mask
16831680 segment_changes_first_true = jnp .concatenate (
16841681 [
@@ -2082,7 +2079,7 @@ def impl(
20822079 # Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos,
20832080 # _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in.
20842081 def _cross_attn (
2085- idx , q , k , v , bias , softmax_offset , kv_segment_ids_ag , kv_segment_pos_ag , seed
2082+ q , k , v , bias , softmax_offset , kv_segment_ids_ag , kv_segment_pos_ag , seed
20862083 ):
20872084 # Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive
20882085 # Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos()
@@ -2143,7 +2140,6 @@ def _cross_attn(
21432140 functions = [
21442141 partial (
21452142 _cross_attn ,
2146- idx ,
21472143 q ,
21482144 k_ag ,
21492145 v_ag ,
@@ -2226,7 +2222,6 @@ def impl(
22262222
22272223 # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function.
22282224 def _cross_attn_bwd (
2229- idx ,
22302225 q ,
22312226 k ,
22322227 v ,
@@ -2306,7 +2301,6 @@ def _cross_attn_bwd(
23062301 functions = [
23072302 partial (
23082303 _cross_attn_bwd ,
2309- idx ,
23102304 q ,
23112305 k_ag ,
23122306 v_ag ,
0 commit comments