From 29f47ebe5b688e1b3bb1a929d9b9f581bdb36a7a Mon Sep 17 00:00:00 2001
From: KuangjuX <18630816527@163.com>
Date: Sun, 29 Dec 2024 06:02:46 -0800
Subject: [PATCH] fix copying the V matrix from global to shared memory.

---
 benchmarks/cpp/flashattention/copy.cuh       | 46 ++++++++++++++++++--
 benchmarks/cpp/flashattention/cutlass_fa.cuh | 10 ++++-
 benchmarks/cpp/flashattention/main.cu        |  4 +-
 3 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
index d4d3b0c..c6b0df8 100644
--- a/benchmarks/cpp/flashattention/copy.cuh
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -195,15 +195,55 @@ class G2SCopyV {
     }
 
     DEVICE void prologue() {
-        // Pipeline the copy operation.
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+        gV.data() = gV.data() + gV_stride;
+        sV.data() = sV.data() + sV_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
     }
 
     DEVICE void body() {
-        // Pipeline the copy operation.
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+
+        gV.data() = gV.data() + gV_stride;
+        sV.data() = sV.data() + sV_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
     }
 
     DEVICE void epilogue() {
-        // Pipeline the copy operation.
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+        cute::cp_async_fence();
     }
 
   private:
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
index b6bd8c3..25cf1f8 100644
--- a/benchmarks/cpp/flashattention/cutlass_fa.cuh
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -189,6 +189,7 @@ __global__ void __launch_bounds__(Nthreads)
         printf("acc0: \n");
         print(acc0), print("\n");
     }
+    // scores = dot(q, k)
     auto scores =
         make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
 
@@ -197,10 +198,10 @@ __global__ void __launch_bounds__(Nthreads)
 
     auto scores_max = make_fragment_like(m_new);
 
-    // Compute row max.
+    // scores_max = reduce_max(scores, axis=1)
     reduce_max<4, true>(scores, scores_max);
 
-    // Compute new max vector.
+    // Compute new partial max value.
     for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
         m_new(ax0) = max(m_new(ax0), scores_max(ax0));
     }
@@ -236,6 +237,7 @@ __global__ void __launch_bounds__(Nthreads)
     // TODO: Understand the following code.
     auto frag = convert_type(scores);
     auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout());
+    // Why convert the layout?
    auto rP_Aregs =
         make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));
 
@@ -252,6 +254,10 @@ __global__ void __launch_bounds__(Nthreads)
         cp_async_wait_flash<0>();
         __syncthreads();
 
+        if (n < split_n - 1) {
+            // Update the pointer of K.
+        }
+
         s2r_pipeline_v.epilogue(rP_Aregs);
     }
 
diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu
index 640419a..788608b 100644
--- a/benchmarks/cpp/flashattention/main.cu
+++ b/benchmarks/cpp/flashattention/main.cu
@@ -35,8 +35,8 @@ void run(bool check = true) {
     static constexpr int kWarpPerRow = 4;
     static constexpr int kWarpPerCol = 1;
     static constexpr int kThreads = 128;
-    static constexpr int kStagesQK = 2;
-    static constexpr int kStagesV = 2;
+    static constexpr int kStagesQK = 1;
+    static constexpr int kStagesV = 1;
 
     static_assert(kK == kTK,
                   "The current implementation requires kTK == K for now.");
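
Note for reviewers (not part of the patch): the new G2SCopyV::prologue()/body() implement a standard multi-stage cp.async ring buffer. The global-memory pointer gV always advances by gV_stride, while the shared-memory pointer sV advances by sV_stride and wraps back by sV_stride * num_stage once every num_stage iterations, so the same shared stages are reused in rotation. Below is a minimal host-side sketch of that offset arithmetic in isolation; the scalar offsets and the constant values here are illustrative stand-ins, not code or values from this patch.

    #include <cstdio>

    int main() {
        const int num_stage = 2;    // number of shared-memory pipeline stages
        const int gV_stride = 512;  // illustrative: elements per global tile
        const int sV_stride = 512;  // illustrative: elements per shared stage
        long gV = 0;  // offset of the next global tile to load
        long sV = 0;  // offset of the shared-memory stage being filled
        for (int cur_iter = 0; cur_iter < 6; ++cur_iter) {
            // ... issue the tile copies here, then cp_async_fence() ...
            gV += gV_stride;  // global side moves strictly forward
            sV += sV_stride;  // shared side moves to the next stage
            if ((cur_iter + 1) % num_stage == 0) {
                sV -= sV_stride * num_stage;  // wrap around the ring
            }
            printf("iter %d: next gV offset = %ld, next sV offset = %ld\n",
                   cur_iter, gV, sV);
        }
        return 0;
    }

With num_stage = 2 the sV offset alternates between 0 and sV_stride, so the load into one stage can overlap the MMA consuming the other; with num_stage = 1 it wraps every iteration and stays at 0. The main.cu change dropping kStagesQK/kStagesV from 2 to 1 presumably validates that single-stage path first.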