From 3341d2c68618fca6fea5b0d3ba6dc5fd81900bb5 Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 01:23:46 +0000 Subject: [PATCH 1/7] Add some comments. --- benchmarks/cpp/flashattention/cutlass_fa.cuh | 25 +++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 6cae18b..0114965 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -23,6 +23,9 @@ template , Int>, Stride, _1>>; using GmemLayoutK = Layout, Int>, Stride, _1>>; @@ -137,6 +140,22 @@ __global__ void __launch_bounds__(Nthreads) auto acc0 = get_acc(mma); auto acco = get_acc(mma); + if (thread0()) { + printf("acc0 size<0>: %d, size<1>: %d\n", (int)size<0>(acc0), + (int)size<1>(acc0)); + printf("acco size<0>: %d, size<1>: %d\n", (int)size<0>(acco), + (int)size<1>(acco)); + } + + /** + * In TileFusion, we use + * ```cpp + * using RegVec = RegTile>; + * ``` + * We need to store the reduce results for both the top row and the bottom + * row simultaneously. + */ + auto m_new = make_tensor(Shape(acc0)>>{}); auto lse_new = make_fragment_like(m_new); @@ -165,6 +184,8 @@ __global__ void __launch_bounds__(Nthreads) int split_n = kN / kTN; for (int n = 0; n < split_n; ++n) { clear(acc0); + + // When `load_q_once` is true, the folling code is not executed. int slice_k = kK / kTK - 1; for (int k = 0; k < slice_k; ++k) { // Barrier to ensure all data are loaded into shared memory. @@ -178,6 +199,8 @@ __global__ void __launch_bounds__(Nthreads) cp_async_wait_flash<0>(); __syncthreads(); g2s_copy_v.prologue(); + // When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only + // once, and `s2r_pipeline_qk.epilogue()` is executed once as well. 
s2r_pipeline_qk.epilogue(); // scores = dot(q, k) @@ -200,7 +223,7 @@ __global__ void __launch_bounds__(Nthreads) auto acco_rowcol = make_tensor(acco.data(), convert_layout_scores(acco.layout())); - // Renormalizatio for the previous block. + // Renormalization for the previous block. for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) { float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); lse_new(ax0) = lse_new(ax0) * scale; From 7000e7a84ae965f7b0ebba2c0d1792ac9e5bf4af Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 05:54:20 +0000 Subject: [PATCH 2/7] Normalize the attention block. --- benchmarks/cpp/flashattention/convert.cuh | 3 ++ benchmarks/cpp/flashattention/cutlass_fa.cuh | 37 +++++++++++++++----- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh index 2480eee..703513d 100644 --- a/benchmarks/cpp/flashattention/convert.cuh +++ b/benchmarks/cpp/flashattention/convert.cuh @@ -48,6 +48,9 @@ DEVICE auto convert_layout_C_Aregs() { get<1>(l), get<1>(get<2>(l))); } +/** + * @brief Convert a 3d register tensor into a 2d register tensor. 
+ */ template DEVICE auto convert_layout_scores(LayoutType layout_s) { using namespace cute; diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 0114965..8220f40 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -141,10 +141,10 @@ __global__ void __launch_bounds__(Nthreads) auto acco = get_acc(mma); if (thread0()) { - printf("acc0 size<0>: %d, size<1>: %d\n", (int)size<0>(acc0), - (int)size<1>(acc0)); - printf("acco size<0>: %d, size<1>: %d\n", (int)size<0>(acco), - (int)size<1>(acco)); + printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", + (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0)); + printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n", + (int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco)); } /** @@ -220,15 +220,25 @@ __global__ void __launch_bounds__(Nthreads) m_new(ax0) = max(m_new(ax0), scores_max(ax0)); } - auto acco_rowcol = + // Currently, `acco` stores the results from the previous iteration's + // computation. + auto previous_attn_block = make_tensor(acco.data(), convert_layout_scores(acco.layout())); + if (thread0()) { + printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores), + (int)size<1>(scores)); + printf("previous_attn_block size<0>: %d, size<1>: %d\n", + (int)size<0>(previous_attn_block), + (int)size<1>(previous_attn_block)); + } + // Renormalization for the previous block. 
- for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) { + for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); lse_new(ax0) = lse_new(ax0) * scale; - for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) { - acco_rowcol(ax0, ax1) *= scale; + for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { + previous_attn_block(ax0, ax1) *= scale; } } @@ -303,6 +313,17 @@ __global__ void __launch_bounds__(Nthreads) s2r_pipeline_v.epilogue(rP_Aregs); } + // Normalize the attention block. + auto attn_block = + make_tensor(acco.data(), convert_layout_scores(acco.layout())); + for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) { + float scale = 1 / lse_new(ax0); + lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) { + attn_block(ax0, ax1) *= scale; + } + } + // Store O from registers to shared memory and then to global memory. store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco, typename KeTraits::StoreR2SCopyAtom{}, mma); From f69a4eccb327afcc433ac5413df8616611b2ddad Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 06:40:19 +0000 Subject: [PATCH 3/7] Add template testcase. 
--- benchmarks/cpp/flashattention/cutlass_fa.cuh | 1 - benchmarks/cpp/flashattention/main.cu | 30 ++++++++------------ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 8220f40..80fcda1 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -23,7 +23,6 @@ template void run(bool check = true) { using InType = cutlass::half_t; using AccType = cutlass::half_t; using OutType = cutlass::half_t; - static constexpr int kM = 64; - static constexpr int kN = 64; - static constexpr int kK = 128; - static constexpr int kP = 128; - - static constexpr int kTM = 64; - static constexpr int kTN = 64; - static constexpr int kTK = 128; - static constexpr int kTP = 128; - + // Currently `kBatch` is fixed to 1. static constexpr int kBatch = 1; - - static constexpr int kWarpPerRow = 1; - static constexpr int kWarpPerCol = 1; static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32; - static constexpr int kStagesQK = 1; - static constexpr int kStagesV = 1; - static_assert(kK == kTK, - "The current implementation requires kTK == K for now."); + // static_assert(kK == kTK, + // "The current implementation requires kTK == K for now."); static_assert(kP == kTP, "The current implementation requires kTP == P for now."); @@ -125,4 +114,9 @@ void run(bool check = true) { cudaDeviceSynchronize(); } -int main() { run(); } +int main() { + // + run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); + run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); +} From 5b072ee4ecc4cae81347f5c7d9b62ae7560ccaea Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 08:09:11 +0000 Subject: [PATCH 4/7] fix some code. 
--- benchmarks/cpp/flashattention/cutlass_fa.cuh | 29 ++++++++++++++++---- benchmarks/cpp/flashattention/main.cu | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 80fcda1..ee7b568 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -95,7 +95,9 @@ template (previous_attn_block); ++ax0) { + // Compute `acc_o_scale = exp(m_i - m_ij)` float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); lse_new(ax0) = lse_new(ax0) * scale; + // Compute `acc_o = acc_o_scale * acc_o` for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { previous_attn_block(ax0, ax1) *= scale; } } for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { - float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); - lse_new(ax0) = lse_new(ax0) * m_scaled; + // Compute `p = exp(qk - m_ij)` + float m_scaled = m_new(ax0) * softmax_scale; for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { scores(ax0, ax1) = exp(scores(ax0, ax1) * softmax_scale - m_scaled); } } + // Compute `l_ij = sum(p)`. auto scores_sum = make_fragment_like(lse_new); reduce_sum<4>(scores, scores_sum); + // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`. for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); } @@ -309,17 +315,28 @@ __global__ void __launch_bounds__(Nthreads) } } + // Compute `acc_o = acc_o + dot(p, v)` s2r_pipeline_v.epilogue(rP_Aregs); + + // Compute `lse_i = m_ij + log(l_i_new)`. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + } } // Normalize the attention block. 
auto attn_block = make_tensor(acco.data(), convert_layout_scores(acco.layout())); for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) { - float scale = 1 / lse_new(ax0); - lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + // TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i - + // lse_i)`. + + // float scale = 1 / lse_new(ax0); + float o_scale = exp(m_new(ax0) - lse_new(ax0)); + // TODO(KuangjuX): Move this code into loop? + // lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) { - attn_block(ax0, ax1) *= scale; + attn_block(ax0, ax1) *= o_scale; } } diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu index 5368ec4..314e673 100644 --- a/benchmarks/cpp/flashattention/main.cu +++ b/benchmarks/cpp/flashattention/main.cu @@ -118,5 +118,5 @@ int main() { // run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); - run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); + // run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); } From 338051a6e8c4168df27ede69e3b5788fdc6872fa Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 08:27:07 +0000 Subject: [PATCH 5/7] Add case for kK != kTK. --- benchmarks/cpp/flashattention/copy.cuh | 11 +++++++++ benchmarks/cpp/flashattention/cutlass_fa.cuh | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index ced7ae9..05ada5b 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -47,6 +47,17 @@ class G2SCopyQK { gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride; } + /** + * @brief Reset the pointer of the shared memory Q tensor. + * + * The current function is called when `load_q_once` is false, i.e., when + * kTK != kK. In this case, the pointer of Q needs to be restored to the + * starting position. 
 + * + * @param stride The stride in K dimension. + */ + DEVICE void reset_tile_Q(int stride) { sQ.data() = sQ.data() + (-stride); } + /** * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix * only needs to be loaded once and does not require repeated loading, while diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index ee7b568..a6eb856 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -312,6 +312,13 @@ __global__ void __launch_bounds__(Nthreads) */ if (load_q_once) { g2s_copy_qk.prologue_K(); + } else { /** * In this case, we need to reset the pointer of Q to the * starting position and simultaneously preload the Q and K. */ + g2s_copy_qk.reset_tile_Q(kK); + g2s_copy_qk.prologue(); } } From 50900bca7102333647b4c88d3004bfb9981bdab0 Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Mon, 6 Jan 2025 08:34:15 +0000 Subject: [PATCH 6/7] fix codespell error. --- benchmarks/cpp/flashattention/cutlass_fa.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index a6eb856..2b43db2 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -186,7 +186,7 @@ __global__ void __launch_bounds__(Nthreads) for (int n = 0; n < split_n; ++n) { clear(acc0); - // When `load_q_once` is true, the folling code is not executed. + // When `load_q_once` is true, the following code is not executed. int slice_k = kK / kTK - 1; for (int k = 0; k < slice_k; ++k) { // Barrier to ensure all data are loaded into shared memory. 
From 81d70b317186463f13143edf25e27d8971cf0ebd Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Fri, 10 Jan 2025 09:20:37 +0000 Subject: [PATCH 7/7] Add debug flag --- benchmarks/cpp/flashattention/copy.cuh | 30 ++++++++++++++++++++ benchmarks/cpp/flashattention/cutlass_fa.cuh | 4 +++ benchmarks/cpp/flashattention/main.cu | 3 +- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index 05ada5b..8909d0b 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -491,6 +491,10 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr, TiledCopy tiled_copy; + // if (thread0()) { + // print_latex(tiled_copy); + // } + auto loader = tiled_copy.get_thread_slice(tid); auto gQs = loader.partition_S(gQ); @@ -501,10 +505,12 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr, int sQ_stride = size(sQ); int sK_stride = size(sK); +#ifdef DEBUG if (thread0()) { printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n", gQ_stride, sQ_stride, gK_stride, sK_stride); } +#endif detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride, sQ_stride, gK_stride, sK_stride); @@ -529,9 +535,11 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) { int sV_stride = size(sV); +#ifdef DEBUG if (thread0()) { printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride); } +#endif detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); @@ -556,6 +564,15 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid); auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid); +#ifdef DEBUG + if (thread0()) { + printf("sQ_Layout: "); + print(sQ_layout), print('\n'); + printf("s2r_copy_q: "); + print(s2r_copy_q), print('\n'); + } +#endif + auto sQ = s2r_thr_copy_q.partition_S(sQ_); auto 
sK = s2r_thr_copy_k.partition_S(sK_); @@ -567,6 +584,19 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma); auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma); +#ifdef DEBUG + if (thread0()) { + printf("sQ_: "); + print(sQ_), print('\n'); + printf("sQ: "); + print(sQ), print('\n'); + printf("rQ_copy: "); + print(rQ_copy), print('\n'); + printf("rQ_mma: "); + print(rQ_mma), print('\n'); + } +#endif + int sQ_stride = size(sQ_); int sK_stride = size(sK_); diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 2b43db2..532d390 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -141,12 +141,14 @@ __global__ void __launch_bounds__(Nthreads) auto acc0 = get_acc(mma); auto acco = get_acc(mma); +#ifdef DEBUG if (thread0()) { printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0)); printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n", (int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco)); } +#endif /** * In TileFusion, we use @@ -226,6 +228,7 @@ __global__ void __launch_bounds__(Nthreads) auto previous_attn_block = make_tensor(acco.data(), convert_layout_scores(acco.layout())); +#ifdef DEBUG if (thread0()) { printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores), (int)size<1>(scores)); @@ -233,6 +236,7 @@ __global__ void __launch_bounds__(Nthreads) (int)size<0>(previous_attn_block), (int)size<1>(previous_attn_block)); } +#endif // Renormalization for the previous block. 
for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu index 314e673..8b1cf16 100644 --- a/benchmarks/cpp/flashattention/main.cu +++ b/benchmarks/cpp/flashattention/main.cu @@ -89,7 +89,8 @@ void run(bool check = true) { dim3 grid(block_x, block_y, block_z); dim3 block(kThreads, 1, 1); - int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); + int shm_input = + (kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV); int shm_output = kTM * kTP; int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) : shm_input * sizeof(InType);