4949namespace transformer_engine {
5050namespace fused_attn {
5151template <typename T>
52- __global__ void print_tensor_elements_2 (const T *const data, const size_t rows, const size_t start_cols, const size_t end_cols, const size_t cols) {
52+ __global__ void print_tensor_elements_2 (const T *const data, const size_t rows,
53+ const size_t start_cols, const size_t end_cols,
54+ const size_t cols) {
5355 if ((threadIdx .x == 0 ) && (threadIdx .y == 0 ) && (blockIdx .x == 0 ) && (blockIdx .y == 0 )) {
5456 for (size_t i = 0 ; i < rows; ++i) {
5557 for (size_t j = start_cols; j < end_cols; ++j) {
@@ -487,47 +489,47 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
487489 static_cast <const int32_t *>(devPtrCuSeqlensKV), static_cast <int32_t *>(devActualSeqlenQ),
488490 static_cast <int32_t *>(devActualSeqlenKV));
489491 NVTE_CHECK_CUDA (cudaGetLastError ());
490- std::cout << " print_tensors: " << print_tensors <<
491- " print_tensors_custom_mask: "
492- << print_tensors_custom_mask << std::endl;
493- if (print_tensors)
494- {
495- if (devPtrCuSeqlensQ) {
496- if (print_tensors_custom_mask)
497- {
498- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , 8 , /* does not matter for single row*/ actual_b);
499- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensQ), 1 ,
500- 1024 , 1032 ,
501- /* does not matter for single row*/ actual_b);
502- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensQ), 1 ,
503- 8184 , 8192 ,
504- /* does not matter for single row*/ actual_b);
505- }
506- else
507- {
508- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
492+ std::cout << " print_tensors: " << print_tensors
493+ << " print_tensors_custom_mask: " << print_tensors_custom_mask << std::endl;
494+ if (print_tensors) {
495+ if (devPtrCuSeqlensQ) {
496+ if (print_tensors_custom_mask) {
497+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
498+ static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , 8 ,
499+ /* does not matter for single row*/ actual_b);
500+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
501+ static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 1024 , 1032 ,
502+ /* does not matter for single row*/ actual_b);
503+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
504+ static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 8184 , 8192 ,
505+ /* does not matter for single row*/ actual_b);
506+ } else {
507+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
508+ static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , actual_b,
509+ /* does not matter for single row*/ actual_b);
509510 cudaDeviceSynchronize ();
510511 }
511512 }
512513 if (devActualSeqlenQ) {
513- if (print_tensors_custom_mask)
514- {
515- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , 8 , /* does not matter for single row*/ actual_b);
516- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenQ), 1 ,
517- 1024 , 1032 ,
518- /* does not matter for single row*/ actual_b);
519- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenQ), 1 ,
520- 8184 , 8192 ,
521- /* does not matter for single row*/ actual_b);
522- }
523- else {
524- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
514+ if (print_tensors_custom_mask) {
515+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
516+ static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , 8 ,
517+ /* does not matter for single row*/ actual_b);
518+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
519+ static_cast <int32_t *>(devActualSeqlenQ), 1 , 1024 , 1032 ,
520+ /* does not matter for single row*/ actual_b);
521+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
522+ static_cast <int32_t *>(devActualSeqlenQ), 1 , 8184 , 8192 ,
523+ /* does not matter for single row*/ actual_b);
524+ } else {
525+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
526+ static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , actual_b,
527+ /* does not matter for single row*/ actual_b);
525528 cudaDeviceSynchronize ();
526529 }
527530 }
528- if (devPtrCuSeqlensKV) {
529- if (print_tensors_custom_mask)
530- {
531+ if (devPtrCuSeqlensKV) {
532+ if (print_tensors_custom_mask) {
531533 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
532534 static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , 8 ,
533535 /* does not matter for single row*/ actual_b);
@@ -537,15 +539,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
537539 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
538540 static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 8184 , 8192 ,
539541 /* does not matter for single row*/ actual_b);
540- }
541- else {
542- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
542+ } else {
543+ print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
544+ static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , actual_b,
545+ /* does not matter for single row*/ actual_b);
543546 cudaDeviceSynchronize ();
544547 }
545548 }
546- if (devActualSeqlenKV) {
547- if (print_tensors_custom_mask)
548- {
549+ if (devActualSeqlenKV) {
550+ if (print_tensors_custom_mask) {
549551 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
550552 static_cast <int32_t *>(devActualSeqlenKV), 1 , 0 , 8 ,
551553 /* does not matter for single row*/ actual_b);
@@ -555,10 +557,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
555557 print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
556558 static_cast <int32_t *>(devActualSeqlenKV), 1 , 8184 , 8192 ,
557559 /* does not matter for single row*/ actual_b);
558- }
559- else
560- {
561- print_tensor_elements_2 <<< 1 , 1 , 0 , stream>>> ( static_cast < int32_t *>(devActualSeqlenKV), 1 , 0 , actual_b, /* does not matter for single row*/ actual_b);
560+ } else {
561+ print_tensor_elements_2 <<< 1 , 1 , 0 , stream>>> (
562+ static_cast < int32_t *>(devActualSeqlenKV), 1 , 0 , actual_b,
563+ /* does not matter for single row*/ actual_b);
562564 cudaDeviceSynchronize ();
563565 }
564566 }
@@ -677,18 +679,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
677679 }
678680 }
679681 }
680- if (is_ragged_q) {
681- variant_pack[offset_q] = devOffsetsQ;
682- variant_pack[offset_o] = devOffsetsO;
683- }
684- if (is_ragged_kv) {
685- variant_pack[offset_k] = devOffsetsK;
686- variant_pack[offset_v] = devOffsetsV;
687- }
688- if (is_ragged_q && cudnn_runtime_version >= 90600 ) {
689- variant_pack[offset_stats] = devOffsetsS;
690- }
682+ if (is_ragged_q) {
683+ variant_pack[offset_q] = devOffsetsQ;
684+ variant_pack[offset_o] = devOffsetsO;
685+ }
686+ if (is_ragged_kv) {
687+ variant_pack[offset_k] = devOffsetsK;
688+ variant_pack[offset_v] = devOffsetsV;
689+ }
690+ if (is_ragged_q && cudnn_runtime_version >= 90600 ) {
691+ variant_pack[offset_stats] = devOffsetsS;
691692 }
693+ }
692694
693695 if (is_dropout) {
694696 variant_pack[dropout_seed] = devPtrDropoutSeed;
0 commit comments