diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
index 60a0e13..36051df 100644
--- a/benchmarks/cpp/flashattention/copy.cuh
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -76,7 +76,7 @@ class G2SCopyQK {
         }
     }
 
-    DEVICE void next_K_slice(int gK_slice, int gK_stride) {
+    DEVICE void update_tile_K(int gK_slice, int gK_stride) {
         gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride;
     }
 
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
index 803f33e..c07b18f 100644
--- a/benchmarks/cpp/flashattention/cutlass_fa.cuh
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -101,6 +101,7 @@ __global__ void __launch_bounds__(Nthreads)
     fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
               Element* dO) {
     constexpr float softmax_scale = 1.250000e-01f;
+    const bool load_q_once = (kTK == kK);
 
     extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
     auto* buf = reinterpret_cast<Element*>(buf_);
@@ -132,12 +133,6 @@ __global__ void __launch_bounds__(Nthreads)
                          typename KeTraits::SmemLayoutV,
                          typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);
 
-#ifdef DEBUG
-    g2s_copy_qk.print_gQ();
-    g2s_copy_v.print_gV();
-    g2s_copy_qk.print_gQ_data(0);
-#endif
-
     auto acc0 = get_acc(mma);
     auto acco = get_acc(mma);
 
@@ -168,9 +163,9 @@ __global__ void __launch_bounds__(Nthreads)
     int split_n = kN / kTN;
 
     for (int n = 0; n < split_n; ++n) {
-        int split_k = kK / kTK - 1;
+        int slice_k = kK / kTK - 1;
         // Pipeline
-        for (int k = 0; k < split_k; ++k) {
+        for (int k = 0; k < slice_k; ++k) {
             // Barrier to ensure all data are loaded into shared memory.
             cp_async_wait_flash<0>();
             __syncthreads();
@@ -184,11 +179,6 @@ __global__ void __launch_bounds__(Nthreads)
         g2s_copy_v.prologue();
         s2r_pipeline_qk.epilogue();
 
-        // Print acc0 data.
-        if (thread0()) {
-            printf("acc0: \n");
-            print(acc0), print("\n");
-        }
         // scores = dot(q, k)
         auto scores = make_tensor(acc0.data(),
                                   convert_layout_scores(acc0.layout()));
@@ -241,9 +231,13 @@ __global__ void __launch_bounds__(Nthreads)
         auto rP_Aregs = make_tensor(rP.data(),
                                     convert_layout_rowcol_Aregs(rP.layout()));
 
-        // Load V into register and issue MMA.
-        int split_n = kN / kTN - 1;
-        for (int n = 0; n < split_n; ++n) {
+        /**
+         * In FractalTensor, the `kTN` dimension is split again. To simplify
+         * the current implementation of the pipelined flash attention,
+         * `tile_n` is hardcoded to 0 at this point.
+         */
+        const int tile_n = 0;
+        for (int tile_ = 0; tile_ < tile_n; ++tile_) {
             // Barrier to ensure all data are loaded into shared memory.
             cp_async_wait_flash<0>();
             __syncthreads();
@@ -255,10 +249,32 @@ __global__ void __launch_bounds__(Nthreads)
         __syncthreads();
 
         if (n < split_n - 1) {
-            // Update the pointer of K.
-            g2s_copy_qk.next_K_slice(kTN, kK);
-            // TODO(KuangjuX): Assume load q once.
-            g2s_copy_qk.prologue_K();
+            /**
+             * Update the K tile because the entire K block is processed
+             * within a single SM block.
+             *
+             * For example, in `TileFusion`:
+             * ```cpp
+             * for (int n = 0; n < GIteratorV::sc0; ++n) {
+             *     load_sv(gVs(n), sV);
+             *     for (int k = 0; k < GIteratorQ::sc1; ++k) {
+             *         load_sq(gQs(k), sQ);
+             *         load_sk(gKs(k, n), sK);
+             *     }
+             * }
+             * ```
+             */
+            g2s_copy_qk.update_tile_K(kTN, kK);
+            /**
+             * `load_q_once` means that `kK == kTK`, i.e. Q is loaded into
+             * shared memory as a single block exactly once. In this case we
+             * only need to advance the pointer of K and never have to update
+             * the pointer for Q, because the blocking along the k dimension
+             * is not executed and Q therefore never has to be reloaded.
+             */
+            if (load_q_once) {
+                g2s_copy_qk.prologue_K();
+            }
         }
 
         s2r_pipeline_v.epilogue(rP_Aregs);
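To make the schedule introduced by these hunks easier to follow, below is a small host-side sketch that only mirrors the loop structure of the patched kernel. The tile sizes are made-up example values (the real ones come from the kernel's traits), and the printed labels summarize code that the diff does not show, so treat it as an illustration rather than the actual implementation.

```cpp
// Host-side trace of the patched loop structure (illustration only).
// kN/kTN/kK/kTK are assumed example values, not the kernel's real config.
#include <cstdio>

int main() {
    const int kN = 256, kTN = 64;  // K/V sequence length and its tile size
    const int kK = 64, kTK = 64;   // head (reduction) dim and its tile size

    const bool load_q_once = (kTK == kK);
    const int split_n = kN / kTN;

    for (int n = 0; n < split_n; ++n) {
        const int slice_k = kK / kTK - 1;
        // Pipelined Q/K slices; runs zero times when kK == kTK.
        for (int k = 0; k < slice_k; ++k) {
            std::printf("n=%d k=%d: load next Q/K slice\n", n, k);
        }
        std::printf("n=%d: QK^T epilogue, softmax, V prefetched\n", n);

        const int tile_n = 0;  // hardcoded to 0, as in the patch
        for (int tile_ = 0; tile_ < tile_n; ++tile_) {
            std::printf("n=%d tile=%d: pipelined PV body\n", n, tile_);
        }

        if (n < split_n - 1) {
            std::printf("n=%d: update_tile_K", n);
            // Q stays resident in shared memory, so only K is refetched.
            if (load_q_once) std::printf(" + prologue_K");
            std::printf("\n");
        }
        std::printf("n=%d: PV epilogue\n", n);
    }
    return 0;
}
```

Running the sketch makes the intent of `load_q_once` visible: with `kK == kTK` the inner slice loop never executes, so each outer iteration only re-issues the K copy while Q remains in shared memory.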