From 3eed106383d20cf3cacd2c30ecff308e79352a40 Mon Sep 17 00:00:00 2001
From: KuangjuX <18630816527@163.com>
Date: Thu, 2 Jan 2025 11:44:58 +0000
Subject: [PATCH] Add comments and fix some bugs.

---
 benchmarks/cpp/flashattention/copy.cuh       | 34 ++++++++----
 benchmarks/cpp/flashattention/cutlass_fa.cuh | 58 +++++++++-----------
 benchmarks/cpp/flashattention/main.cu        |  4 +-
 3 files changed, 50 insertions(+), 46 deletions(-)

diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
index 36051df..64409af 100644
--- a/benchmarks/cpp/flashattention/copy.cuh
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -512,18 +512,12 @@ template <typename Element, typename GlobalQLayout, typename SharedQLayout,
           typename GlobalKLayout, typename SharedKLayout, typename TiledCopy>
 inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
                                    const Element* gK_ptr, Element* sK_ptr,
-                                   int gQ_stride, int sQ_stride, int gK_stride,
-                                   int sK_stride) {
+                                   int gQ_stride, int gK_stride) {
     int tid = threadIdx.x;
 
     auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{});
     auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});
 
-    if (thread0()) {
-        printf("sQ: \n");
-        print(sQ), print("\n");
-    }
-
     auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
     auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{});
 
@@ -536,6 +530,14 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
     auto sQs = loader.partition_D(sQ);
     auto sKs = loader.partition_D(sK);
 
+    int sQ_stride = size(sQ);
+    int sK_stride = size(sK);
+
+    if (thread0()) {
+        printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
+               gQ_stride, sQ_stride, gK_stride, sK_stride);
+    }
+
     detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
                               sQ_stride, gK_stride, sK_stride);
 
@@ -544,8 +546,7 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
 
 template <typename Element, typename GlobalVLayout, typename SharedVLayout,
           typename TiledCopy>
-DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
-                       int sV_stride) {
+DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {
     int tid = threadIdx.x;
 
     auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{});
@@ -558,6 +559,12 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
     auto gVs = loader.partition_S(gV);
     auto sVs = loader.partition_D(sV);
 
+    int sV_stride = size(sV);
+
+    if (thread0()) {
+        printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
+    }
+
     detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);
 
     return copy_v;
@@ -567,7 +574,6 @@ template <typename Element, typename SQLayout, typename SKLayout,
           typename RegAcc, typename SmemCopyAtom, typename TiledMma>
 DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
                         SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc,
-                        int sQ_stride, int sK_stride,
                         SmemCopyAtom copy_atom = SmemCopyAtom{},
                         TiledMma tiled_mma = TiledMma{}) {
     int tid = threadIdx.x;
@@ -593,6 +599,9 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
     auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
     auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
 
+    int sQ_stride = size(sQ_);
+    int sK_stride = size(sK_);
+
     detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma,
                                           rK_copy, acc, s2r_copy_q, s2r_copy_k,
                                           tiled_mma, sQ_stride, sK_stride);
@@ -603,8 +612,7 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
 template <typename Element, typename SVLayout, typename RegAcc,
           typename SmemCopyAtom, typename TiledMma>
 DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
-                       int sV_stride, SmemCopyAtom copy_atom,
-                       TiledMma tiled_mma) {
+                       SmemCopyAtom copy_atom, TiledMma tiled_mma) {
     int tid = threadIdx.x;
 
     auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);
@@ -619,6 +627,8 @@ DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
     auto rV_mma = thr_mma.partition_fragment_B(sV_);
     auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);
 
+    int sV_stride = size(sV_);
+
     detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v,
                                         tiled_mma, sV_stride);
 
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
index c07b18f..c0a6d6b 100644
--- a/benchmarks/cpp/flashattention/cutlass_fa.cuh
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -15,20 +15,12 @@ namespace cutlass_wrapper {
 
 using namespace cute;
 
-//
-/// @brief
-/// @tparam Element_
 template <typename Element_, const int kM, const int kN, const int kK,
           const int kP, const int kTM, const int kTN, const int kTK,
          const int kTP, const int kWarpPerRow, const int kWarpPerCol,
          const int kThreads, const int kStagesQK, const int kStagesV,
          typename Base = TraitsBase<Element_>>
 struct FATraits : public Base {
-    // Q: [kM, kK] --> [length, hidden_qk]
-    // K: [kN, kK] --> [length, hidden_qk]
-    // V: [kP, kN] --> [length, hidden_v]
-    // O: [kM, kP] --> [length, hidden_v]
-    // assert(kM == kN)
     using Element = Element_;
 
     // Declare global to shared memory copy layout.
@@ -59,13 +51,6 @@ struct FATraits : public Base {
 
     static constexpr int kWarps = kThreads / 32;
 
-    // Declare MMA Operation: [16, 8, 16] * [1, 2, 1] -> [16, 16, 16]
-    // Legacy code
-    // using TiledMma =
-    //     TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
-    //              Layout<Shape<Int<kWarpPerRow>, _1, _1>>, Layout<Shape<_1,
-    //              _2, _1>>>;
-
     using TiledMma =
         TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
                  Layout<Shape<Int<kWarpPerRow>, Int<kWarpPerCol>, _1>>,
@@ -124,14 +109,19 @@ __global__ void __launch_bounds__(Nthreads)
     auto g2s_copy_qk = make_g2s_qk<
         Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ,
         typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK,
-        typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK,
-                                         kTK);
+        typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK);
 
     // Build the copy plan for V from global memory to shared memory.
+    /**
+     * The size of the V matrix is [kN, kP], and the size processed in a
+     * single SM Block is [kN, kTP]. When split along the N dimension, the
+     * size is [kTN, kTP]. Therefore, the stride for global memory should be
+     * set to kTN * kP.
+     */
     auto g2s_copy_v =
         make_g2s_v<Element, typename KeTraits::GmemLayoutV,
                    typename KeTraits::SmemLayoutV,
-                   typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);
+                   typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN * kP);
 
     auto acc0 = get_acc<kTM, kTN>(mma);
     auto acco = get_acc<kTM, kTP>(mma);
@@ -139,19 +129,13 @@ __global__ void __launch_bounds__(Nthreads)
     auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
     auto lse_new = make_fragment_like(m_new);
 
-    if (thread0()) {
-        printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
-               (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
-    }
-
     auto s2r_pipeline_qk =
         make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{},
-                    typename KeTraits::SmemLayoutK{}, acc0, kTK, kTK,
+                    typename KeTraits::SmemLayoutK{}, acc0,
                     typename KeTraits::SmemCopyAtom{}, mma);
-    s2r_pipeline_qk.print_rQ();
 
     auto s2r_pipeline_v =
-        make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, kTN,
+        make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco,
                    typename KeTraits::SmemCopyAtom{}, mma);
 
     // Issue global to shared memory copy before the main loop.
@@ -161,10 +145,16 @@ __global__ void __launch_bounds__(Nthreads)
     fill(m_new, -INFINITY);
     clear(acco);
 
+    /**
+     * Flash Attention performs two-level tiling for each SM Block, splitting
+     * along the N dimension and the K dimension. The Q matrix is split along
+     * the K dimension, the V matrix is split along the N dimension, and the K
+     * matrix is split along both dimensions simultaneously.
+     */
     int split_n = kN / kTN;
     for (int n = 0; n < split_n; ++n) {
+        clear(acc0);
         int slice_k = kK / kTK - 1;
-        // Pipeline
         for (int k = 0; k < slice_k; ++k) {
             // Barrier to ensure all data are loaded into shared memory.
             cp_async_wait_flash<0>();
@@ -280,14 +270,18 @@ __global__ void __launch_bounds__(Nthreads)
         s2r_pipeline_v.epilogue(rP_Aregs);
     }
 
-    auto acco_f16 = convert_type<Element>(acco);
+    // Store O from registers to shared memory and then to global memory.
+    // auto frag2 = convert_type<Element>(acco);
+    // auto acco_f16 = make_tensor(make_rmem_ptr<Element>(&frag2),
+    //                             acco.layout());
 
-    store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
-                typename KeTraits::SmemCopyAtom{}, mma);
-    __syncthreads();
+    // store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
+    //             typename KeTraits::SmemCopyAtom{}, mma);
+    // __syncthreads();
 
     store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
-                typename KeTraits::SmemLayoutO{}, tiled_copy_g2s);
+                typename KeTraits::SmemLayoutO{},
+                typename KeTraits::TiledCopyS2G{});
 }
 
 }  // namespace cutlass_wrapper
diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu
index 788608b..4567831 100644
--- a/benchmarks/cpp/flashattention/main.cu
+++ b/benchmarks/cpp/flashattention/main.cu
@@ -32,9 +32,9 @@ void run(bool check = true) {
 
     static constexpr int kBatch = 1;
 
-    static constexpr int kWarpPerRow = 4;
+    static constexpr int kWarpPerRow = 1;
     static constexpr int kWarpPerCol = 1;
-    static constexpr int kThreads = 128;
+    static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32;
     static constexpr int kStagesQK = 1;
     static constexpr int kStagesV = 1;
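
Note on the gV_stride fix in cutlass_fa.cuh: for a row-major V of shape
[kN, kP], advancing one [kTN, kTP] tile along the N dimension skips kTN full
rows of kP elements each, hence gV_stride = kTN * kP. A minimal host-side
sketch of that arithmetic (the concrete sizes and the helper v_tile_offset
are hypothetical, chosen only to mirror the benchmark's names; they are not
part of this patch):

    #include <cstdio>

    // Full V matrix: [kN, kP]; per-iteration tile: [kTN, kTP].
    constexpr int kN = 256, kP = 128;
    constexpr int kTN = 64, kTP = 128;

    // Element offset of the n-th [kTN, kTP] tile of V along the N
    // dimension, assuming row-major storage.
    constexpr int v_tile_offset(int n) { return n * kTN * kP; }

    int main() {
        // Advancing one tile along N skips kTN rows of kP elements,
        // which is exactly the gV_stride now passed to make_g2s_v.
        static_assert(v_tile_offset(1) == kTN * kP, "stride is kTN * kP");
        for (int n = 0; n < kN / kTN; ++n)
            std::printf("tile %d starts at element %d\n", n, v_tile_offset(n));
        return 0;
    }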
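Likewise, the shared-memory strides are now derived inside the copy plans as
size(sQ) / size(sK) / size(sV) instead of being passed by the caller: each
pipeline stage advances by exactly one full tile in shared memory, so the
stride is simply the tile's element count. A minimal device-side sketch of
that invariant (the [64, 64] tile shape is a placeholder, not the benchmark's
actual configuration):

    #include <cute/tensor.hpp>
    #include <cutlass/numeric_types.h>
    using namespace cute;

    // For a [kTM, kTK] shared-memory tile, size(sQ) is kTM * kTK: the
    // number of elements to skip to reach the next stage's tile.
    __device__ void smem_stride_sketch(cutlass::half_t* sQ_ptr) {
        auto sQ = make_tensor(make_smem_ptr(sQ_ptr),
                              make_layout(Shape<_64, _64>{}));
        int sQ_stride = size(sQ);  // 64 * 64 = 4096 elements per stage
        auto sQ_next = make_tensor(make_smem_ptr(sQ_ptr + sQ_stride),
                                   sQ.layout());  // next stage's tile
        (void)sQ_next;
    }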