Skip to content

Commit ab81a30

Browse files
nit: Apply suggestions from code review
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Fix type on fused attn tests Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
1 parent c5e0d6f commit ab81a30

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

tests/jax/test_fused_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def to_dp_shardings(x):
776776
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
777777
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
778778

779-
def _test_forward(self):
779+
def test_forward(self):
780780
"""
781781
Test forward with JITted primitive and unJITted reference
782782
"""
@@ -1150,7 +1150,7 @@ class TestFusedAttn:
11501150
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
11511151
],
11521152
)
1153-
def test_forward(
1153+
def _test_forward(
11541154
b,
11551155
s_q,
11561156
s_kv,

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,7 @@ def check_supported(self):
13061306
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
13071307
)
13081308
# TODO: For now do not all CP + AG + THD + Striped with NO_MASK
1309+
# TODO: For now do not allow CP + AG + THD + Striped with NO_MASK
13091310
if self.config.attn_mask_type is AttnMaskType.NO_MASK and self.config.qkv_layout.is_thd():
13101311
raise ValueError(f"{header} only supports CAUSAL_MASK for THD types")
13111312

@@ -1380,7 +1381,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
13801381
)
13811382

13821383
def all_gather_kv(self, k, v):
1383-
"""Performs aa all-gather of k and v over context parallel ranks."""
1384+
"""Performs an all-gather of k and v over context parallel ranks."""
13841385

13851386
def ag(x):
13861387
x = lax_paral_op(
@@ -1402,7 +1403,7 @@ def ag(x):
14021403
return k, v # fall through
14031404

14041405
def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos):
1405-
"""Performs aa all-gather of kv segment ids and kv segment pos over context parallel ranks."""
1406+
"""Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks."""
14061407
kv_segment_ids = lax_paral_op(
14071408
kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
14081409
)

0 commit comments

Comments
 (0)