@@ -1582,7 +1582,7 @@ def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_se
15821582 # Get the indices for segment changes (these are the offsets)
15831583 max_size = q_segment_pos .shape [- 1 ]
15841584 seq_offsets = jax .vmap (
1585- lambda scm_row : jnp .where (scm_row , size = max_segments_per_seq + 1 , fill_value = - 1 )[0 ]
1585+ lambda scm_row : jnp .where (scm_row , size = max_segments_per_seq , fill_value = - 1 )[0 ]
15861586 )(segment_changes_masked )
15871587 return seq_offsets
15881588
@@ -1695,7 +1695,7 @@ def kv_seqoffsets_for_striped_for_rank(
16951695
16961696 # Get segment change indices for rank
16971697 segment_changes_indices = jax .vmap (
1698- lambda sc_row : jnp .where (sc_row , size = max_segments_per_seq + 1 , fill_value = - 1 )[0 ]
1698+ lambda sc_row : jnp .where (sc_row , size = max_segments_per_seq , fill_value = - 1 )[0 ]
16991699 )(segment_changes_first_true_masked )
17001700 # Get segment ids associated with the segment_changes_indices for rank
17011701 segment_ids = jax .vmap (
@@ -1719,7 +1719,7 @@ def kv_seqoffsets_for_striped_for_rank(
17191719 )
17201720 # Get segment change indices for AG
17211721 segment_changes_ag_indices = jax .vmap (
1722- lambda scag_row : jnp .where (scag_row , size = max_segments_per_seq + 1 , fill_value = - 1 )[0 ]
1722+ lambda scag_row : jnp .where (scag_row , size = max_segments_per_seq , fill_value = - 1 )[0 ]
17231723 )(segment_changes_ag_first_true_masked )
17241724
17251725 # Use the segment ids picked per rank to get the offsets from the AG indices
0 commit comments