4848
4949namespace transformer_engine {
5050namespace fused_attn {
// Debug-only helper: prints columns [start_cols, end_cols) of each of `rows`
// rows of a row-major rows x cols matrix that lives in device memory.
// Intended launch is <<<1, 1, 0, stream>>>; only thread (0,0) of block (0,0)
// does any work, so a larger launch merely wastes threads but stays correct
// (single-thread output keeps rows ordered and un-interleaved).
// Preconditions: `data` is device-accessible and end_cols <= cols.
template <typename T>
__global__ void print_tensor_elements_2(const T *const data, const size_t rows,
                                        const size_t start_cols, const size_t end_cols,
                                        const size_t cols) {
  if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    for (size_t i = 0; i < rows; ++i) {
      for (size_t j = start_cols; j < end_cols; ++j) {
        const size_t idx = i * cols + j;
        // Cast to double, not float: call sites pass int64_t offset tensors,
        // and a float cast silently loses precision for values >= 2^24.
        // Output format is unchanged (printf varargs promote float to double
        // anyway), but integer values up to 2^53 now print exactly.
        printf(" %8f ", static_cast<double>(data[idx]));
      }
      printf("\n");
    }
  }
}
65-
6651void fused_attn_arbitrary_seqlen_fwd_impl (
6752 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
6853 int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
@@ -474,10 +459,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
474459 if (is_bias) {
475460 variant_pack[bias] = devPtrBias;
476461 }
477- // KL test code
478- bool print_tensors = true ;
479- // For the thd_regular case, the actual_b = 18
480- bool print_tensors_custom_mask = actual_b >= 300 ? true : false ;
481462
482463 if (is_padding) {
483464 constexpr size_t nthreads_per_block = 128 ;
@@ -489,79 +470,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
489470 static_cast <const int32_t *>(devPtrCuSeqlensKV), static_cast <int32_t *>(devActualSeqlenQ),
490471 static_cast <int32_t *>(devActualSeqlenKV));
491472 NVTE_CHECK_CUDA (cudaGetLastError ());
492- // std::cout << "print_tensors: " << print_tensors <<
493- // "print_tensors_custom_mask: "
494- // << print_tensors_custom_mask << std::endl;
495- if (print_tensors) {
496- if (devPtrCuSeqlensQ) {
497- if (print_tensors_custom_mask) {
498- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
499- static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , 8 ,
500- /* does not matter for single row*/ actual_b);
501- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
502- static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 1024 , 1032 ,
503- /* does not matter for single row*/ actual_b);
504- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
505- static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 8184 , 8192 ,
506- /* does not matter for single row*/ actual_b);
507- } else {
508- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
509- static_cast <int32_t *>(devPtrCuSeqlensQ), 1 , 0 , actual_b,
510- /* does not matter for single row*/ actual_b);
511- }
512- }
513- if (devActualSeqlenQ) {
514- if (print_tensors_custom_mask) {
515- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
516- static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , 8 ,
517- /* does not matter for single row*/ actual_b);
518- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
519- static_cast <int32_t *>(devActualSeqlenQ), 1 , 1024 , 1032 ,
520- /* does not matter for single row*/ actual_b);
521- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
522- static_cast <int32_t *>(devActualSeqlenQ), 1 , 8184 , 8192 ,
523- /* does not matter for single row*/ actual_b);
524- } else {
525- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
526- static_cast <int32_t *>(devActualSeqlenQ), 1 , 0 , actual_b,
527- /* does not matter for single row*/ actual_b);
528- }
529- }
530- if (devPtrCuSeqlensKV) {
531- if (print_tensors_custom_mask) {
532- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
533- static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , 8 ,
534- /* does not matter for single row*/ actual_b);
535- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
536- static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 1024 , 1032 ,
537- /* does not matter for single row*/ actual_b);
538- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
539- static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 8184 , 8192 ,
540- /* does not matter for single row*/ actual_b);
541- } else {
542- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
543- static_cast <int32_t *>(devPtrCuSeqlensKV), 1 , 0 , actual_b,
544- /* does not matter for single row*/ actual_b);
545- }
546- }
547- if (devActualSeqlenKV) {
548- if (print_tensors_custom_mask) {
549- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
550- static_cast <int32_t *>(devActualSeqlenKV), 1 , 0 , 8 ,
551- /* does not matter for single row*/ actual_b);
552- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
553- static_cast <int32_t *>(devActualSeqlenKV), 1 , 1024 , 1032 ,
554- /* does not matter for single row*/ actual_b);
555- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
556- static_cast <int32_t *>(devActualSeqlenKV), 1 , 8184 , 8192 ,
557- /* does not matter for single row*/ actual_b);
558- } else {
559- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
560- static_cast <int32_t *>(devActualSeqlenKV), 1 , 0 , actual_b,
561- /* does not matter for single row*/ actual_b);
562- }
563- }
564- }
565473 variant_pack[seq_q] = devActualSeqlenQ;
566474 variant_pack[seq_kv] = devActualSeqlenKV;
567475 }
@@ -601,76 +509,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
601509 static_cast <int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
602510 devOffsetsV, devOffsetsO, devOffsetsS);
603511 NVTE_CHECK_CUDA (cudaGetLastError ());
604- if (print_tensors) {
605- if (devPtrSeqOffsetsQ) {
606- if (print_tensors_custom_mask) {
607- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
608- static_cast <int32_t *>(devPtrSeqOffsetsQ), 1 , 0 , 8 ,
609- /* does not matter for single row*/ actual_b);
610- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
611- static_cast <int32_t *>(devPtrSeqOffsetsQ), 1 , 1024 , 1032 ,
612- /* does not matter for single row*/ actual_b);
613- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
614- static_cast <int32_t *>(devPtrSeqOffsetsQ), 1 , 8184 , 8192 ,
615- /* does not matter for single row*/ actual_b);
616- } else {
617- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
618- static_cast <int32_t *>(devPtrSeqOffsetsQ), 1 , 0 , actual_b,
619- /* does not matter for single row*/ actual_b);
620- }
621- }
622- if (devOffsetsQ) {
623- if (print_tensors_custom_mask) {
624- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
625- static_cast <int64_t *>(devOffsetsQ), 1 , 0 , 8 ,
626- /* does not matter for single row*/ actual_b);
627- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
628- static_cast <int64_t *>(devOffsetsQ), 1 , 1024 , 1032 ,
629- /* does not matter for single row*/ actual_b);
630- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
631- static_cast <int64_t *>(devOffsetsQ), 1 , 8184 , 8192 ,
632- /* does not matter for single row*/ actual_b);
633- } else {
634- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
635- static_cast <int64_t *>(devOffsetsQ), 1 , 0 , actual_b,
636- /* does not matter for single row*/ actual_b);
637- }
638- }
639- if (devPtrSeqOffsetsKV) {
640- if (print_tensors_custom_mask) {
641- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
642- static_cast <int32_t *>(devPtrSeqOffsetsKV), 1 , 0 , 8 ,
643- /* does not matter for single row*/ actual_b);
644- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
645- static_cast <int32_t *>(devPtrSeqOffsetsKV), 1 , 1024 , 1032 ,
646- /* does not matter for single row*/ actual_b);
647- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
648- static_cast <int32_t *>(devPtrSeqOffsetsKV), 1 , 8184 , 8192 ,
649- /* does not matter for single row*/ actual_b);
650- } else {
651- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
652- static_cast <int32_t *>(devPtrSeqOffsetsKV), 1 , 0 , actual_b,
653- /* does not matter for single row*/ actual_b);
654- }
655- }
656- if (devOffsetsK) {
657- if (print_tensors_custom_mask) {
658- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
659- static_cast <int64_t *>(devOffsetsK), 1 , 0 , 8 ,
660- /* does not matter for single row*/ actual_b);
661- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
662- static_cast <int64_t *>(devOffsetsK), 1 , 1024 , 1032 ,
663- /* does not matter for single row*/ actual_b);
664- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
665- static_cast <int64_t *>(devOffsetsK), 1 , 8184 , 8192 ,
666- /* does not matter for single row*/ actual_b);
667- } else {
668- print_tensor_elements_2<<<1 , 1 , 0 , stream>>> (
669- static_cast <int64_t *>(devOffsetsK), 1 , 0 , actual_b,
670- /* does not matter for single row*/ actual_b);
671- }
672- }
673- }
674512 if (is_ragged_q) {
675513 variant_pack[offset_q] = devOffsetsQ;
676514 variant_pack[offset_o] = devOffsetsO;
0 commit comments