Skip to content

Commit 547bf11

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a385c47 commit 547bf11

File tree

6 files changed

+227
-175
lines changed

6 files changed

+227
-175
lines changed

tests/jax/test_distributed_fused_attn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_cross_attn(
280280

281281
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
282282
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
283-
#TODO: Change the id to CPx2
283+
# TODO: Change the id to CPx2
284284
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
285285
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
286286
# KL test code
@@ -362,7 +362,7 @@ def impl_test_context_parallel_attn(
362362
mesh_resource=mesh_resource,
363363
cp_strategy=cp_strategy,
364364
cp_load_balanced=load_balanced,
365-
stripe_height=stripe_height
365+
stripe_height=stripe_height,
366366
)
367367

368368
def check_has_backend_for_mask(mask_type):
@@ -401,8 +401,8 @@ def check_has_backend_for_mask(mask_type):
401401
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
402402
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
403403

404-
#KL code
405-
#runner.test_backward()
404+
# KL code
405+
# runner.test_backward()
406406
runner.test_forward()
407407
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
408408

@@ -602,6 +602,7 @@ def test_context_parallel_ring_attn_shardy(
602602
pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
603603
]
604604

605+
605606
class TestReorderCausalLoadBalancing:
606607
@pytest.mark.parametrize("cp_size", [2, 4, 8])
607608
@pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
@@ -619,10 +620,9 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_height):
619620

620621
if reorder_strategy == ReorderStrategy.Striped:
621622
seq_lens = shape[seq_dim]
622-
if seq_lens < (cp_size*stripe_height):
623+
if seq_lens < (cp_size * stripe_height):
623624
pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_height=}")
624625

625-
626626
ref = tensor.copy()
627627

628628
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])

tests/jax/test_fused_attn.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -474,12 +474,16 @@ def _setup_inputs(self):
474474
token_numbers_k = range(self.max_seqlen_kv)
475475
for batch_idx in range(q_shape[0]):
476476
for token_idx in token_numbers_q:
477-
q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * (token_idx + 1)
477+
q_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * (
478+
token_idx + 1
479+
)
478480
for token_idx in token_numbers_k:
479-
k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt(self.head_dim_qk)
481+
k_np[batch_idx][token_idx][0] = np.ones(self.head_dim_qk, self.dtype) * np.sqrt(
482+
self.head_dim_qk
483+
)
480484
v_np = np.ones(v_shape, self.dtype)
481485
# Set cols at multiples
482-
v_np[0,::4, 0, :] = np.arange(v_np.shape[3])
486+
v_np[0, ::4, 0, :] = np.arange(v_np.shape[3])
483487
self.q = jnp.array(q_np)
484488
self.k = jnp.array(k_np)
485489
self.v = jnp.array(v_np)
@@ -541,7 +545,7 @@ def generate_random_segment_ids(
541545
min_segment_size = 1
542546
if min_segment_len is not None:
543547
min_segment_size = min_segment_len[i][seg_id]
544-
#KL test code
548+
# KL test code
545549
min_segment_size = 4
546550
segment_size = rng.integers(min_segment_size, max_segment_size + 1)
547551
if current_pos + segment_size > sequence_length:
@@ -598,8 +602,16 @@ def generate_random_segment_ids(
598602
)
599603
self.segment_pos_q = self.segment_pos_kv = None
600604
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
601-
print(f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q: {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q: {self.seqlens_q}, \n self.offsets_q: { self.offsets_q} \n")
602-
print(f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv: {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv: {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n")
605+
print(
606+
f"self.segment_ids_q: {self.segment_ids_q}, \n self.segment_pos_q:"
607+
f" {self.segment_pos_q}, \n self.pad_q: {self.pad_q}, \n self.seqlens_q:"
608+
f" {self.seqlens_q}, \n self.offsets_q: { self.offsets_q} \n"
609+
)
610+
print(
611+
f"self.segment_ids_kv: {self.segment_ids_kv}, \n self.segment_pos_kv:"
612+
f" {self.segment_pos_kv}, \n self.pad_kv: {self.pad_kv}, \n self.seqlens_kv:"
613+
f" {self.seqlens_kv}, \n self.offsets_kv: { self.offsets_kv} \n"
614+
)
603615

604616
# For reference code
605617
self.mask = make_mask(
@@ -612,6 +624,7 @@ def generate_random_segment_ids(
612624
)
613625
# KL tet code
614626
import sys
627+
615628
with np.printoptions(threshold=sys.maxsize):
616629
print(f"self.mask: \n {self.mask}")
617630

@@ -876,7 +889,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
876889
"window_size": self.window_size,
877890
"context_parallel_strategy": self.cp_strategy,
878891
"context_parallel_causal_load_balanced": self.cp_load_balanced,
879-
#"stripe_height": self.stripe_height,
892+
# "stripe_height": self.stripe_height,
880893
}
881894

882895
# We can compute dBias only for the [1, h, s, s] layout

tests/jax/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1462,11 +1462,12 @@ def assert_allclose(
14621462
desired = desired.astype(jnp.float32)
14631463
# KL test code
14641464
import sys
1465+
14651466
mismatch_counter = 0
14661467
has_nonzero = jnp.any(actual != 0)
14671468
print(f"has_nonzero: {has_nonzero}")
14681469
with np.printoptions(threshold=sys.maxsize):
1469-
mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch
1470+
mismatch_mask = ~np.isclose(actual, desired, **tols) # True means mismatch
14701471
diff_indices = np.argwhere(mismatch_mask)
14711472
for idx in diff_indices:
14721473
idx_tuple = tuple(idx)

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
namespace transformer_engine {
5050
namespace fused_attn {
5151
template <typename T>
52-
__global__ void print_tensor_elements_2(const T *const data, const size_t rows, const size_t start_cols, const size_t end_cols, const size_t cols) {
52+
__global__ void print_tensor_elements_2(const T *const data, const size_t rows,
53+
const size_t start_cols, const size_t end_cols,
54+
const size_t cols) {
5355
if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
5456
for (size_t i = 0; i < rows; ++i) {
5557
for (size_t j = start_cols; j < end_cols; ++j) {
@@ -487,47 +489,47 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
487489
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
488490
static_cast<int32_t *>(devActualSeqlenKV));
489491
NVTE_CHECK_CUDA(cudaGetLastError());
490-
std::cout << "print_tensors: " << print_tensors <<
491-
"print_tensors_custom_mask: "
492-
<< print_tensors_custom_mask << std::endl;
493-
if (print_tensors)
494-
{
495-
if(devPtrCuSeqlensQ) {
496-
if(print_tensors_custom_mask)
497-
{
498-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, 8, /*does not matter for single row*/ actual_b);
499-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensQ), 1,
500-
1024, 1032,
501-
/*does not matter for single row*/ actual_b);
502-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensQ), 1,
503-
8184, 8192,
504-
/*does not matter for single row*/ actual_b);
505-
}
506-
else
507-
{
508-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b);
492+
std::cout << "print_tensors: " << print_tensors
493+
<< "print_tensors_custom_mask: " << print_tensors_custom_mask << std::endl;
494+
if (print_tensors) {
495+
if (devPtrCuSeqlensQ) {
496+
if (print_tensors_custom_mask) {
497+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
498+
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, 8,
499+
/*does not matter for single row*/ actual_b);
500+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
501+
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 1024, 1032,
502+
/*does not matter for single row*/ actual_b);
503+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
504+
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 8184, 8192,
505+
/*does not matter for single row*/ actual_b);
506+
} else {
507+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
508+
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, actual_b,
509+
/*does not matter for single row*/ actual_b);
509510
cudaDeviceSynchronize();
510511
}
511512
}
512513
if (devActualSeqlenQ) {
513-
if (print_tensors_custom_mask)
514-
{
515-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenQ), 1, 0, 8, /*does not matter for single row*/actual_b);
516-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenQ), 1,
517-
1024, 1032,
518-
/*does not matter for single row*/ actual_b);
519-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenQ), 1,
520-
8184, 8192,
521-
/*does not matter for single row*/ actual_b);
522-
}
523-
else {
524-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b);
514+
if (print_tensors_custom_mask) {
515+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
516+
static_cast<int32_t *>(devActualSeqlenQ), 1, 0, 8,
517+
/*does not matter for single row*/ actual_b);
518+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
519+
static_cast<int32_t *>(devActualSeqlenQ), 1, 1024, 1032,
520+
/*does not matter for single row*/ actual_b);
521+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
522+
static_cast<int32_t *>(devActualSeqlenQ), 1, 8184, 8192,
523+
/*does not matter for single row*/ actual_b);
524+
} else {
525+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
526+
static_cast<int32_t *>(devActualSeqlenQ), 1, 0, actual_b,
527+
/*does not matter for single row*/ actual_b);
525528
cudaDeviceSynchronize();
526529
}
527530
}
528-
if(devPtrCuSeqlensKV) {
529-
if(print_tensors_custom_mask)
530-
{
531+
if (devPtrCuSeqlensKV) {
532+
if (print_tensors_custom_mask) {
531533
print_tensor_elements_2<<<1, 1, 0, stream>>>(
532534
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, 8,
533535
/*does not matter for single row*/ actual_b);
@@ -537,15 +539,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
537539
print_tensor_elements_2<<<1, 1, 0, stream>>>(
538540
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 8184, 8192,
539541
/*does not matter for single row*/ actual_b);
540-
}
541-
else {
542-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b);
542+
} else {
543+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
544+
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, actual_b,
545+
/*does not matter for single row*/ actual_b);
543546
cudaDeviceSynchronize();
544547
}
545548
}
546-
if(devActualSeqlenKV) {
547-
if (print_tensors_custom_mask)
548-
{
549+
if (devActualSeqlenKV) {
550+
if (print_tensors_custom_mask) {
549551
print_tensor_elements_2<<<1, 1, 0, stream>>>(
550552
static_cast<int32_t *>(devActualSeqlenKV), 1, 0, 8,
551553
/*does not matter for single row*/ actual_b);
@@ -555,10 +557,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
555557
print_tensor_elements_2<<<1, 1, 0, stream>>>(
556558
static_cast<int32_t *>(devActualSeqlenKV), 1, 8184, 8192,
557559
/*does not matter for single row*/ actual_b);
558-
}
559-
else
560-
{
561-
print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b);
560+
} else {
561+
print_tensor_elements_2<<<1, 1, 0, stream>>>(
562+
static_cast<int32_t *>(devActualSeqlenKV), 1, 0, actual_b,
563+
/*does not matter for single row*/ actual_b);
562564
cudaDeviceSynchronize();
563565
}
564566
}
@@ -677,18 +679,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
677679
}
678680
}
679681
}
680-
if (is_ragged_q) {
681-
variant_pack[offset_q] = devOffsetsQ;
682-
variant_pack[offset_o] = devOffsetsO;
683-
}
684-
if (is_ragged_kv) {
685-
variant_pack[offset_k] = devOffsetsK;
686-
variant_pack[offset_v] = devOffsetsV;
687-
}
688-
if (is_ragged_q && cudnn_runtime_version >= 90600) {
689-
variant_pack[offset_stats] = devOffsetsS;
690-
}
682+
if (is_ragged_q) {
683+
variant_pack[offset_q] = devOffsetsQ;
684+
variant_pack[offset_o] = devOffsetsO;
685+
}
686+
if (is_ragged_kv) {
687+
variant_pack[offset_k] = devOffsetsK;
688+
variant_pack[offset_v] = devOffsetsV;
689+
}
690+
if (is_ragged_q && cudnn_runtime_version >= 90600) {
691+
variant_pack[offset_stats] = devOffsetsS;
691692
}
693+
}
692694

693695
if (is_dropout) {
694696
variant_pack[dropout_seed] = devPtrDropoutSeed;

transformer_engine/jax/attention.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,9 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
353353
return batch, q_max_seqlen, kv_max_seqlen
354354

355355

356-
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1):
356+
def reorder_causal_load_balancing(
357+
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1
358+
):
357359
"""Reorders a tensor for load balancing the compute of causal attention."""
358360
if strategy == ReorderStrategy.DualChunkSwap:
359361
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
@@ -363,7 +365,7 @@ def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: in
363365

364366

365367
def inverse_reorder_causal_load_balancing(
366-
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1
368+
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_height: int = 1
367369
):
368370
"""Inverse operation of `reorder_causal_load_balancing`."""
369371
if strategy == ReorderStrategy.DualChunkSwap:
@@ -498,7 +500,7 @@ def _segment_ids_pos_to_seqlens_offsets(
498500
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
499501
# examine only O(Q+KV) elements.
500502
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
501-
#TODO: Un comment the fast path
503+
# TODO: Un comment the fast path
502504
# if (attn_mask_type.is_causal() and window_size is None) or (
503505
# window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
504506
# ):
@@ -517,7 +519,7 @@ def _segment_ids_pos_to_seqlens_offsets(
517519
segment_ids_kv,
518520
lambda x, y: jnp.equal(x, y) * x,
519521
)
520-
#jax.debug.breakpoint()
522+
# jax.debug.breakpoint()
521523
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
522524
attn_mask = segment_mask
523525
if attn_mask_type.is_bottom_right():
@@ -579,7 +581,7 @@ def _segment_ids_pos_to_seqlens_offsets(
579581
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
580582
attn_mask_with_id, max_segments_per_seq
581583
)
582-
#jax.debug.breakpoint()
584+
# jax.debug.breakpoint()
583585
return q_seqlen, kv_seqlen, q_offset, kv_offset
584586

585587

@@ -659,7 +661,7 @@ def get_seqlens_and_offsets(
659661
window_size,
660662
max_segments_per_seq,
661663
)
662-
#jax.debug.breakpoint()
664+
# jax.debug.breakpoint()
663665
else:
664666
q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
665667
q_segment_ids,
@@ -1038,7 +1040,7 @@ def _fused_attn_fwd_rule(
10381040
context_parallel_strategy=context_parallel_strategy,
10391041
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
10401042
context_parallel_axis=context_parallel_axis,
1041-
stripe_height=stripe_height
1043+
stripe_height=stripe_height,
10421044
)
10431045
output = checkpoint_name(output, context_checkpoint_name)
10441046
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
@@ -1162,7 +1164,7 @@ def fused_attn(
11621164
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
11631165
context_parallel_axis (str): The name of the context parallel axis.
11641166
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
1165-
stripe_height (int):
1167+
stripe_height (int):
11661168
Indicates the striping height to be used when using ReorderStrategy.Striped.
11671169
Currently, a stripe_height > 1 is only allowed for CP + THD + Striped + AG
11681170
0 indicates no striping strategy

0 commit comments

Comments
 (0)