Commit a385c47

TMP: Throwaway test commit
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent bbb001d

4 files changed: +209 −14 lines changed


tests/jax/test_fused_attn.py

Lines changed: 0 additions & 3 deletions
@@ -438,7 +438,6 @@ def _setup_inputs(self):
 self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
 self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
 self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
-breakpoint()

 key = jax.random.PRNGKey(0)
 q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
@@ -663,7 +662,6 @@ def generate_random_segment_ids(
 self.cp_reorder_fn(self.segment_pos_kv),
 ),
 )
-breakpoint()
 case _:
 raise ValueError(f"Unknown {self.seq_desc_format=}")
 else:
@@ -728,7 +726,6 @@ def to_dp_shardings(x):

 self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

-#jax.debug.breakpoint()
 if self.bias_shape == BiasShape._1HSS:
 self.bias_pspec = PartitionSpec(
 None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None

tests/jax/utils.py

Lines changed: 0 additions & 1 deletion
@@ -1465,7 +1465,6 @@ def assert_allclose(
 mismatch_counter = 0
 has_nonzero = jnp.any(actual != 0)
 print(f"has_nonzero: {has_nonzero}")
-breakpoint()
 with np.printoptions(threshold=sys.maxsize):
 mismatch_mask = ~np.isclose(actual, desired, **tols)  # True means mismatch
 diff_indices = np.argwhere(mismatch_mask)

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 9 additions & 0 deletions
@@ -506,6 +506,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 else
 {
 print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, actual_b, /*does not matter for single row*/actual_b);
+cudaDeviceSynchronize();
 }
 }
 if (devActualSeqlenQ) {
@@ -521,6 +522,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 }
 else {
 print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenQ), 1, 0, actual_b, /*does not matter for single row*/actual_b);
+cudaDeviceSynchronize();
 }
 }
 if(devPtrCuSeqlensKV) {
@@ -538,6 +540,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 }
 else {
 print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 if(devActualSeqlenKV) {
@@ -556,6 +559,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 else
 {
 print_tensor_elements_2<<<1, 1, 0, stream>>>(static_cast<int32_t *>(devActualSeqlenKV), 1, 0, actual_b, /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 }
@@ -597,6 +601,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
 static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
 devOffsetsV, devOffsetsO, devOffsetsS);
+cudaDeviceSynchronize();
 NVTE_CHECK_CUDA(cudaGetLastError());
 if (print_tensors) {
 if (devPtrSeqOffsetsQ) {
@@ -614,6 +619,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 print_tensor_elements_2<<<1, 1, 0, stream>>>(
 static_cast<int32_t *>(devPtrSeqOffsetsQ), 1, 0, actual_b,
 /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 if (devOffsetsQ) {
@@ -631,6 +637,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 print_tensor_elements_2<<<1, 1, 0, stream>>>(
 static_cast<int64_t *>(devOffsetsQ), 1, 0, actual_b,
 /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 if (devPtrSeqOffsetsKV) {
@@ -648,6 +655,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 print_tensor_elements_2<<<1, 1, 0, stream>>>(
 static_cast<int32_t *>(devPtrSeqOffsetsKV), 1, 0, actual_b,
 /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 if (devOffsetsK) {
@@ -665,6 +673,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 print_tensor_elements_2<<<1, 1, 0, stream>>>(
 static_cast<int64_t *>(devOffsetsK), 1, 0, actual_b,
 /*does not matter for single row*/ actual_b);
+cudaDeviceSynchronize();
 }
 }
 }
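
Note: all nine additions in fused_attn_f16_arbitrary_seqlen.cu follow one pattern: a cudaDeviceSynchronize() is placed directly after a single-thread debug launch of print_tensor_elements_2 (and after the offset-building kernel), so that device-side printf output is flushed and any kernel error surfaces before the host code moves on. Below is a minimal sketch of that pattern; debug_print_int32 and dump_device_buffer are hypothetical stand-ins, and the signature of print_tensor_elements_2 is only inferred from the call sites in this diff.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for a debug kernel such as print_tensor_elements_2:
// a single thread walks n int32 values and prints them with device-side printf.
__global__ void debug_print_int32(const int32_t *data, int n) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    for (int i = 0; i < n; ++i) {
      printf("elem[%d] = %d\n", i, static_cast<int>(data[i]));
    }
  }
}

void dump_device_buffer(const int32_t *dev_ptr, int n, cudaStream_t stream) {
  // Launch as in the diff: one block, one thread, on the given stream.
  debug_print_int32<<<1, 1, 0, stream>>>(dev_ptr, n);
  // Kernel launches are asynchronous, and device printf output is only flushed
  // at explicit synchronization points. Synchronizing here makes the printed
  // values appear (and any launch/runtime error surface) before the host
  // continues, which is what the added cudaDeviceSynchronize() calls do.
  cudaDeviceSynchronize();
}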
