From 7a10f6fbfeeae24adc7b8b5fd6cede3bc667426b Mon Sep 17 00:00:00 2001 From: KuangjuX <18630816527@163.com> Date: Sat, 21 Dec 2024 17:36:28 -0800 Subject: [PATCH] Add G2SCopyV class. --- benchmarks/cpp/flashattention/copy.cuh | 68 ++++++++++++++++++-- benchmarks/cpp/flashattention/cutlass_fa.cuh | 8 ++- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index f153761..ea4793a 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include "cuda_utils.cuh" #include @@ -11,13 +12,14 @@ namespace cutlass_wrapper { using namespace cute; namespace detail { + template class G2SCopyQK { public: - __device__ G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, - TiledCopy tiled_copy, int gQ_stride, int sQ_stride, - int gK_stride, int sK_stride, int num_stage = 2) + DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, + TiledCopy tiled_copy, int gQ_stride, int sQ_stride, + int gK_stride, int sK_stride, int num_stage = 2) : gQ(gQ), sQ(sQ), gK(gK), @@ -29,13 +31,14 @@ class G2SCopyQK { cur_iter(0), num_stage(num_stage) {} - inline __device__ void print_q() { + // For debugging purpose. + DEVICE void print_gQ() { if (thread0()) { print(gQ), print("\n"); } } - inline __device__ void prologue() { + DEVICE void prologue() { // Pipeline the copy operation. #pragma unroll for (int m = 0; m < size<1>(gQ); ++m) { @@ -74,14 +77,44 @@ class G2SCopyQK { SQTensor& sQ; GKTensor& gK; SKTensor& sK; - TiledCopy tiled_copy; - int gQ_stride; int sQ_stride; int gK_stride; int sK_stride; + int cur_iter; + int num_stage; +}; + +template +class G2SCopyV { + public: + DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy, + int gV_stride, int sV_stride, int num_stage = 2) + : gV(gV), + sV(sV), + gV_stride(gV_stride), + sV_stride(sV_stride), + cur_iter(0), + num_stage(num_stage) {} + + // For debugging purpose. + DEVICE void print_gV() { + if (thread0()) { + print(gV), print("\n"); + } + } + + DEVICE void prologue() { + // Pipeline the copy operation. + } + private: + GVTensor& gV; + SVTensor& sV; + TiledCopy tiled_copy; + int gV_stride; + int sV_stride; int cur_iter; int num_stage; }; @@ -117,6 +150,27 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr, return copy_qk; } +template +DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride, + int sV_stride) { + int tid = threadIdx.x; + + auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{}); + auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{}); + + TiledCopy tiled_copy; + + auto loader = tiled_copy.get_thread_slice(tid); + + auto gVs = loader.partition_S(gV); + auto sVs = loader.partition_D(sV); + + detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); + + return copy_v; +} + // template // class G2SCopyTile { // public: diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 4d5bfb9..2e37057 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -130,7 +130,13 @@ __global__ void __launch_bounds__(Nthreads) typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK, kTK); - g2s_copy_qk.print_q(); + auto g2s_copy_v = + make_g2s_v(V, sV_ptr, kN, kTN); + + g2s_copy_qk.print_gQ(); + g2s_copy_v.print_gV(); } } // namespace cutlass_wrapper