Skip to content

Commit 51440db

Browse files
Kshitij  Janardan LakhaniKshitijLakhani
authored andcommitted
Fix seqoffsets length to be passed onto FusedAttn primitive as it is b and not b+1 needed by cuDNN
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
1 parent ab81a30 commit 51440db

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,7 +1582,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15821582
# Get the indices for segment changes (these are the offsets)
15831583
max_size = q_segment_pos.shape[-1]
15841584
seq_offsets = jax.vmap(
1585-
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0]
1585+
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0]
15861586
)(segment_changes_masked)
15871587
return seq_offsets
15881588

@@ -1695,7 +1695,7 @@ def kv_seqoffsets_for_striped_for_rank(
16951695

16961696
# Get segment change indices for rank
16971697
segment_changes_indices = jax.vmap(
1698-
lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq + 1, fill_value=-1)[0]
1698+
lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq, fill_value=-1)[0]
16991699
)(segment_changes_first_true_masked)
17001700
# Get segment ids associated with the segment_changes_indices for rank
17011701
segment_ids = jax.vmap(
@@ -1719,7 +1719,7 @@ def kv_seqoffsets_for_striped_for_rank(
17191719
)
17201720
# Get segment change indices for AG
17211721
segment_changes_ag_indices = jax.vmap(
1722-
lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq + 1, fill_value=-1)[0]
1722+
lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq, fill_value=-1)[0]
17231723
)(segment_changes_ag_first_true_masked)
17241724

17251725
# Use the segment ids picked per rank to get the offsets from the AG indices

0 commit comments

Comments
 (0)