Skip to content

Commit 81d70b3

Browse files
committed
Add debug flag
1 parent 50900bc commit 81d70b3

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

benchmarks/cpp/flashattention/copy.cuh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
491491

492492
TiledCopy tiled_copy;
493493

494+
// if (thread0()) {
495+
// print_latex(tiled_copy);
496+
// }
497+
494498
auto loader = tiled_copy.get_thread_slice(tid);
495499

496500
auto gQs = loader.partition_S(gQ);
@@ -501,10 +505,12 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
501505
int sQ_stride = size(sQ);
502506
int sK_stride = size(sK);
503507

508+
#ifdef DEBUG
504509
if (thread0()) {
505510
printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
506511
gQ_stride, sQ_stride, gK_stride, sK_stride);
507512
}
513+
#endif
508514

509515
detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
510516
sQ_stride, gK_stride, sK_stride);
@@ -529,9 +535,11 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {
529535

530536
int sV_stride = size(sV);
531537

538+
#ifdef DEBUG
532539
if (thread0()) {
533540
printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
534541
}
542+
#endif
535543

536544
detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);
537545

@@ -556,6 +564,15 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
556564
auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
557565
auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);
558566

567+
#ifdef DEBUG
568+
if (thread0()) {
569+
printf("sQ_Layout: ");
570+
print(sQ_layout), print('\n');
571+
printf("s2r_copy_q: ");
572+
print(s2r_copy_q), print('\n');
573+
}
574+
#endif
575+
559576
auto sQ = s2r_thr_copy_q.partition_S(sQ_);
560577
auto sK = s2r_thr_copy_k.partition_S(sK_);
561578

@@ -567,6 +584,19 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
567584
auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
568585
auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
569586

587+
#ifdef DEBUG
588+
if (thread0()) {
589+
printf("sQ_: ");
590+
print(sQ_), print('\n');
591+
printf("sQ: ");
592+
print(sQ), print('\n');
593+
printf("rQ_copy: ");
594+
print(rQ_copy), print('\n');
595+
printf("rQ_mma: ");
596+
print(rQ_mma), print('\n');
597+
}
598+
#endif
599+
570600
int sQ_stride = size(sQ_);
571601
int sK_stride = size(sK_);
572602

benchmarks/cpp/flashattention/cutlass_fa.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,14 @@ __global__ void __launch_bounds__(Nthreads)
141141
auto acc0 = get_acc<kTM, kTN>(mma);
142142
auto acco = get_acc<kTM, kTP>(mma);
143143

144+
#ifdef DEBUG
144145
if (thread0()) {
145146
printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
146147
(int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
147148
printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n",
148149
(int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco));
149150
}
151+
#endif
150152

151153
/**
152154
* In TileFusion, we use
@@ -226,13 +228,15 @@ __global__ void __launch_bounds__(Nthreads)
226228
auto previous_attn_block =
227229
make_tensor(acco.data(), convert_layout_scores(acco.layout()));
228230

231+
#ifdef DEBUG
229232
if (thread0()) {
230233
printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores),
231234
(int)size<1>(scores));
232235
printf("previous_attn_block size<0>: %d, size<1>: %d\n",
233236
(int)size<0>(previous_attn_block),
234237
(int)size<1>(previous_attn_block));
235238
}
239+
#endif
236240

237241
// Renormalization for the previous block.
238242
for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) {

benchmarks/cpp/flashattention/main.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ void run(bool check = true) {
8989
dim3 grid(block_x, block_y, block_z);
9090
dim3 block(kThreads, 1, 1);
9191

92-
int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP);
92+
int shm_input =
93+
(kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV);
9394
int shm_output = kTM * kTP;
9495
int shm_size = shm_input < shm_output ? shm_output * sizeof(InType)
9596
: shm_input * sizeof(InType);

0 commit comments

Comments
 (0)