Skip to content

Commit 9b5280b

Browse files
Fix linting issues
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Fix incorrect greptile change Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent 5e014af commit 9b5280b

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,8 +1305,7 @@ def check_supported(self):
13051305
f"{header} only supports masking types: "
13061306
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
13071307
)
1308-
# TODO: For now do not all CP + AG + THD + Striped with NO_MASK
1309-
# TODO: For now do not allow CP + AG + THD + Striped with NO_MASK
1308+
# Do not allow CP + AG + THD + Striped with NO_MASK
13101309
if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd():
13111310
raise ValueError(f"{header} only supports CAUSAL_MASK for THD types")
13121311

@@ -1339,6 +1338,7 @@ def get_adjusted_mask(self):
13391338
return self.config.attn_mask_type
13401339

13411340
def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size):
1341+
"""Converts the max segments per seq for context parallelism AG + THD."""
13421342
# Estimating adjusted max segments per seq
13431343
return (
13441344
max_seqlen // (self.config.stripe_size * cp_size)
@@ -1504,8 +1504,7 @@ def pad(x, npad):
15041504

15051505
return dk, dv # fall through
15061506

1507-
# Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos
1508-
# For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1507+
# Below are the sharded post AG q seg ids and pos for a given rank:
15091508
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15101509
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15111510
# max_segments_per_seq = 7
@@ -1515,6 +1514,7 @@ def pad(x, npad):
15151514
# seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
15161515
# seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]]
15171516
def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
1517+
"""Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos"""
15181518
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
15191519
non_zero_mask = q_segment_ids != 0
15201520
max_size = q_segment_ids.shape[-1]
@@ -1542,9 +1542,8 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15421542
segment_changes = jnp.concatenate(
15431543
[
15441544
first_is_segment, # First valid element starts a segment
1545-
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
1546-
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
1547-
(valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
1545+
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1]) |
1546+
(valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1)
15481547
],
15491548
axis=-1,
15501549
)
@@ -1558,8 +1557,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15581557
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
15591558
return seqlens_all_pad_neg
15601559

1561-
# Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos
1562-
# For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1560+
# Below are the sharded post AG q seg ids and pos for a given rank:
15631561
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15641562
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15651563
# max_segments_per_seq = 7
@@ -1568,6 +1566,7 @@ def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segme
15681566
# segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]
15691567
# seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]]
15701568
def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
1569+
"""Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos"""
15711570
segment_changes = jnp.concatenate(
15721571
[
15731572
jnp.full(
@@ -1580,14 +1579,12 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15801579
# Remove any padded region segment changes
15811580
segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False)
15821581
# Get the indices for segment changes (these are the offsets)
1583-
max_size = q_segment_pos.shape[-1]
15841582
seq_offsets = jax.vmap(
15851583
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0]
15861584
)(segment_changes_masked)
15871585
return seq_offsets
15881586

1589-
# Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos
1590-
# For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1587+
# Below are the sharded post AG q seg ids and pos for a given rank:
15911588
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
15921589
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
15931590
# max_segments_per_seq = 7
@@ -1597,6 +1594,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15971594
# segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]]
15981595
# selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]]
15991596
def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq):
1597+
"""Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos"""
16001598
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
16011599
non_zero_mask = kv_segment_ids != 0
16021600
max_size = kv_segment_ids.shape[-1]
@@ -1614,7 +1612,6 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se
16141612
non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0
16151613
)
16161614
actual_valid = valid_segment_ids != 0
1617-
first_is_segment = actual_valid[..., 0:1]
16181615

16191616
# Detect segment breaks (only for non-zero segments)
16201617
segment_changes = jnp.concatenate(
@@ -1643,9 +1640,7 @@ def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_se
16431640
)
16441641
return selected_values
16451642

1646-
# Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1647-
# AG kv seg ids and seg pos.
1648-
# For e.g. below are the sharded post AG q seg ids and pos for a given rank:
1643+
# Below are the sharded post AG q seg ids and pos for a given rank:
16491644
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
16501645
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
16511646
# kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -1679,6 +1674,8 @@ def kv_seqoffsets_for_striped_for_rank(
16791674
kv_segment_ids_ag,
16801675
max_segments_per_seq,
16811676
):
1677+
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1678+
AG kv seg ids and seg pos."""
16821679
# Calculate the segment pos change mask
16831680
segment_changes_first_true = jnp.concatenate(
16841681
[
@@ -2082,7 +2079,7 @@ def impl(
20822079
# Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos,
20832080
# _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in.
20842081
def _cross_attn(
2085-
idx, q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed
2082+
q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed
20862083
):
20872084
# Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive
20882085
# Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos()
@@ -2143,7 +2140,6 @@ def _cross_attn(
21432140
functions = [
21442141
partial(
21452142
_cross_attn,
2146-
idx,
21472143
q,
21482144
k_ag,
21492145
v_ag,
@@ -2226,7 +2222,6 @@ def impl(
22262222

22272223
# See comment in FusedAttnCPFwdPrimitive.partition for why we define this function.
22282224
def _cross_attn_bwd(
2229-
idx,
22302225
q,
22312226
k,
22322227
v,
@@ -2306,7 +2301,6 @@ def _cross_attn_bwd(
23062301
functions = [
23072302
partial(
23082303
_cross_attn_bwd,
2309-
idx,
23102304
q,
23112305
k_ag,
23122306
v_ag,

0 commit comments

Comments
 (0)