We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 565a1e7 commit 9fc5a74Copy full SHA for 9fc5a74
transformer_engine/jax/cpp_extensions/attention.py
@@ -1676,7 +1676,7 @@ def kv_seqoffsets_for_striped_for_rank(
1676
max_segments_per_seq,
1677
):
1678
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
1679
- AG kv seg ids and seg pos."""
+ AG kv seg ids and seg pos."""
1680
# Calculate the segment pos change mask
1681
segment_changes_first_true = jnp.concatenate(
1682
[
0 commit comments