Skip to content

Commit b54fb3a

Browse files
Clean up test code in TE common
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent 82fb8df commit b54fb3a

File tree

2 files changed

+0
-170
lines changed

2 files changed

+0
-170
lines changed

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 0 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,6 @@
4848

4949
namespace transformer_engine {
5050
namespace fused_attn {
51-
template <typename T>
52-
__global__ void print_tensor_elements_2(const T *const data, const size_t rows,
53-
const size_t start_cols, const size_t end_cols,
54-
const size_t cols) {
55-
if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
56-
for (size_t i = 0; i < rows; ++i) {
57-
for (size_t j = start_cols; j < end_cols; ++j) {
58-
const size_t idx = i * cols + j;
59-
printf("%8f ", static_cast<float>(data[idx]));
60-
}
61-
printf("\n");
62-
}
63-
}
64-
}
65-
6651
void fused_attn_arbitrary_seqlen_fwd_impl(
6752
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
6853
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
@@ -474,10 +459,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
474459
if (is_bias) {
475460
variant_pack[bias] = devPtrBias;
476461
}
477-
//KL test code
478-
bool print_tensors = true;
479-
// For the thd_regular case, the actual_b = 18
480-
bool print_tensors_custom_mask = actual_b >= 300 ? true : false;
481462

482463
if (is_padding) {
483464
constexpr size_t nthreads_per_block = 128;
@@ -489,79 +470,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
489470
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
490471
static_cast<int32_t *>(devActualSeqlenKV));
491472
NVTE_CHECK_CUDA(cudaGetLastError());
492-
//std::cout << "print_tensors: " << print_tensors <<
493-
// "print_tensors_custom_mask: "
494-
// << print_tensors_custom_mask << std::endl;
495-
if (print_tensors) {
496-
if (devPtrCuSeqlensQ) {
497-
if (print_tensors_custom_mask) {
498-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
499-
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, 8,
500-
/*does not matter for single row*/ actual_b);
501-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
502-
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 1024, 1032,
503-
/*does not matter for single row*/ actual_b);
504-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
505-
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 8184, 8192,
506-
/*does not matter for single row*/ actual_b);
507-
} else {
508-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
509-
static_cast<int32_t *>(devPtrCuSeqlensQ), 1, 0, actual_b,
510-
/*does not matter for single row*/ actual_b);
511-
}
512-
}
513-
if (devActualSeqlenQ) {
514-
if (print_tensors_custom_mask) {
515-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
516-
static_cast<int32_t *>(devActualSeqlenQ), 1, 0, 8,
517-
/*does not matter for single row*/ actual_b);
518-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
519-
static_cast<int32_t *>(devActualSeqlenQ), 1, 1024, 1032,
520-
/*does not matter for single row*/ actual_b);
521-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
522-
static_cast<int32_t *>(devActualSeqlenQ), 1, 8184, 8192,
523-
/*does not matter for single row*/ actual_b);
524-
} else {
525-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
526-
static_cast<int32_t *>(devActualSeqlenQ), 1, 0, actual_b,
527-
/*does not matter for single row*/ actual_b);
528-
}
529-
}
530-
if (devPtrCuSeqlensKV) {
531-
if (print_tensors_custom_mask) {
532-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
533-
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, 8,
534-
/*does not matter for single row*/ actual_b);
535-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
536-
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 1024, 1032,
537-
/*does not matter for single row*/ actual_b);
538-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
539-
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 8184, 8192,
540-
/*does not matter for single row*/ actual_b);
541-
} else {
542-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
543-
static_cast<int32_t *>(devPtrCuSeqlensKV), 1, 0, actual_b,
544-
/*does not matter for single row*/ actual_b);
545-
}
546-
}
547-
if (devActualSeqlenKV) {
548-
if (print_tensors_custom_mask) {
549-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
550-
static_cast<int32_t *>(devActualSeqlenKV), 1, 0, 8,
551-
/*does not matter for single row*/ actual_b);
552-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
553-
static_cast<int32_t *>(devActualSeqlenKV), 1, 1024, 1032,
554-
/*does not matter for single row*/ actual_b);
555-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
556-
static_cast<int32_t *>(devActualSeqlenKV), 1, 8184, 8192,
557-
/*does not matter for single row*/ actual_b);
558-
} else {
559-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
560-
static_cast<int32_t *>(devActualSeqlenKV), 1, 0, actual_b,
561-
/*does not matter for single row*/ actual_b);
562-
}
563-
}
564-
}
565473
variant_pack[seq_q] = devActualSeqlenQ;
566474
variant_pack[seq_kv] = devActualSeqlenKV;
567475
}
@@ -601,76 +509,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
601509
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
602510
devOffsetsV, devOffsetsO, devOffsetsS);
603511
NVTE_CHECK_CUDA(cudaGetLastError());
604-
if (print_tensors) {
605-
if (devPtrSeqOffsetsQ) {
606-
if (print_tensors_custom_mask) {
607-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
608-
static_cast<int32_t *>(devPtrSeqOffsetsQ), 1, 0, 8,
609-
/*does not matter for single row*/ actual_b);
610-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
611-
static_cast<int32_t *>(devPtrSeqOffsetsQ), 1, 1024, 1032,
612-
/*does not matter for single row*/ actual_b);
613-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
614-
static_cast<int32_t *>(devPtrSeqOffsetsQ), 1, 8184, 8192,
615-
/*does not matter for single row*/ actual_b);
616-
} else {
617-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
618-
static_cast<int32_t *>(devPtrSeqOffsetsQ), 1, 0, actual_b,
619-
/*does not matter for single row*/ actual_b);
620-
}
621-
}
622-
if (devOffsetsQ) {
623-
if (print_tensors_custom_mask) {
624-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
625-
static_cast<int64_t *>(devOffsetsQ), 1, 0, 8,
626-
/*does not matter for single row*/ actual_b);
627-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
628-
static_cast<int64_t *>(devOffsetsQ), 1, 1024, 1032,
629-
/*does not matter for single row*/ actual_b);
630-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
631-
static_cast<int64_t *>(devOffsetsQ), 1, 8184, 8192,
632-
/*does not matter for single row*/ actual_b);
633-
} else {
634-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
635-
static_cast<int64_t *>(devOffsetsQ), 1, 0, actual_b,
636-
/*does not matter for single row*/ actual_b);
637-
}
638-
}
639-
if (devPtrSeqOffsetsKV) {
640-
if (print_tensors_custom_mask) {
641-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
642-
static_cast<int32_t *>(devPtrSeqOffsetsKV), 1, 0, 8,
643-
/*does not matter for single row*/ actual_b);
644-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
645-
static_cast<int32_t *>(devPtrSeqOffsetsKV), 1, 1024, 1032,
646-
/*does not matter for single row*/ actual_b);
647-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
648-
static_cast<int32_t *>(devPtrSeqOffsetsKV), 1, 8184, 8192,
649-
/*does not matter for single row*/ actual_b);
650-
} else {
651-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
652-
static_cast<int32_t *>(devPtrSeqOffsetsKV), 1, 0, actual_b,
653-
/*does not matter for single row*/ actual_b);
654-
}
655-
}
656-
if (devOffsetsK) {
657-
if (print_tensors_custom_mask) {
658-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
659-
static_cast<int64_t *>(devOffsetsK), 1, 0, 8,
660-
/*does not matter for single row*/ actual_b);
661-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
662-
static_cast<int64_t *>(devOffsetsK), 1, 1024, 1032,
663-
/*does not matter for single row*/ actual_b);
664-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
665-
static_cast<int64_t *>(devOffsetsK), 1, 8184, 8192,
666-
/*does not matter for single row*/ actual_b);
667-
} else {
668-
print_tensor_elements_2<<<1, 1, 0, stream>>>(
669-
static_cast<int64_t *>(devOffsetsK), 1, 0, actual_b,
670-
/*does not matter for single row*/ actual_b);
671-
}
672-
}
673-
}
674512
if (is_ragged_q) {
675513
variant_pack[offset_q] = devOffsetsQ;
676514
variant_pack[offset_o] = devOffsetsO;

transformer_engine/common/fused_attn/utils.cu

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,14 +428,6 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
428428
OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) {
429429
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
430430
auto cu_seqlens_id = min(tid, actual_b);
431-
if (tid == 0) {
432-
printf("actual_b: %lld \n", (long long int)actual_b);
433-
printf("max_b: %lld \n", (long long int)max_b);
434-
printf("h: %lld \n", (long long int)h);
435-
printf("hg: %lld \n", (long long int)hg);
436-
printf("d_qk: %lld \n", (long long int)d_qk);
437-
printf("d_v: %lld \n", (long long int)d_v);
438-
}
439431
if (tid <= max_b) {
440432
if (offsets_s != nullptr) {
441433
offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];

0 commit comments

Comments (0)