@@ -506,6 +506,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
506506 else
507507 {
508508 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
509+ cudaDeviceSynchronize ();
509510 }
510511 }
511512 if (devActualSeqlenQ) {
@@ -521,6 +522,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
521522 }
522523 else {
523524 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
525+ cudaDeviceSynchronize ();
524526 }
525527 }
526528 if (devPtrCuSeqlensKV) {
@@ -538,6 +540,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
538540 }
539541 else {
540542 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
543+ cudaDeviceSynchronize ();
541544 }
542545 }
543546 if (devActualSeqlenKV) {
@@ -556,6 +559,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
556559 else
557560 {
558561 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenKV), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
562+ cudaDeviceSynchronize ();
559563 }
560564 }
561565 }
@@ -597,6 +601,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
597601 layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast <int32_t *>(devPtrSeqOffsetsQ),
598602 static_cast <int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
599603 devOffsetsV, devOffsetsO, devOffsetsS);
604+ cudaDeviceSynchronize ();
600605 NVTE_CHECK_CUDA (cudaGetLastError ());
601606 if (print_tensors) {
602607 if (devPtrSeqOffsetsQ) {
@@ -614,6 +619,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
614619 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
615620 static_cast <int32_t *>(devPtrSeqOffsetsQ), 1 , 0 , actual_b,
616621 /* does not matter for single row*/ actual_b);
622+ cudaDeviceSynchronize ();
617623 }
618624 }
619625 if (devOffsetsQ) {
@@ -631,6 +637,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
631637 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
632638 static_cast <int64_t *>(devOffsetsQ), 1 , 0 , actual_b,
633639 /* does not matter for single row*/ actual_b);
640+ cudaDeviceSynchronize ();
634641 }
635642 }
636643 if (devPtrSeqOffsetsKV) {
@@ -648,6 +655,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
648655 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
649656 static_cast <int32_t *>(devPtrSeqOffsetsKV), 1 , 0 , actual_b,
650657 /* does not matter for single row*/ actual_b);
658+ cudaDeviceSynchronize ();
651659 }
652660 }
653661 if (devOffsetsK) {
@@ -665,6 +673,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
665673 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
666674 static_cast <int64_t *>(devOffsetsK), 1 , 0 , actual_b,
667675 /* does not matter for single row*/ actual_b);
676+ cudaDeviceSynchronize ();
668677 }
669678 }
670679 }
0 commit comments