Skip to content

Commit 9fc5a74

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 565a1e7 commit 9fc5a74

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1676,7 +1676,7 @@ def kv_seqoffsets_for_striped_for_rank(
16761676
max_segments_per_seq,
16771677
):
16781678
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1679-
AG kv seg ids and seg pos."""
1679+
AG kv seg ids and seg pos."""
16801680
# Calculate the segment pos change mask
16811681
segment_changes_first_true = jnp.concatenate(
16821682
[

0 commit comments

Comments
 (0)