diff --git a/.vscode/settings.json b/.vscode/settings.json index 60a735a..e654000 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,10 @@ "array": "cpp", "string": "cpp", "string_view": "cpp", - "span": "cpp" + "span": "cpp", + "bitset": "cpp", + "initializer_list": "cpp", + "utility": "cpp" }, "gotoSymbolStack.currentStackPosition": 0, "gotoSymbolStack.maxStackPosition": 0, diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index abe4e85..238dbba 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -5,6 +5,7 @@ #include "cuda_utils.cuh" #include +#include namespace benchmarks { namespace cutlass_wrapper { @@ -31,6 +32,10 @@ class G2SCopyQK { cur_iter(0), num_stage(num_stage) {} + DEVICE auto get_sQ() { return sQ; } + + DEVICE auto get_sK() { return sK; } + // For debugging purpose. DEVICE void print_gQ() { if (thread0()) { @@ -209,6 +214,45 @@ class G2SCopyV { int num_stage; }; +template +class S2RPipelineQK { + public: + DEVICE S2RPipelineQK(SQTensor& sQ, RQTensor& rQ, SKTensor& sK, RKTensor& rK, + RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, + TiledMma tiled_mma, int sQ_stride, int sK_stride, + int num_stage = 2) + : sQ(sQ), + rQ(rQ), + sK(sK), + rK(rK), + acc(acc), + copy_q(copy_q), + copy_k(copy_k), + tiled_mma(tiled_mma), + sQ_stride(sQ_stride), + sK_stride(sK_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sq(0) {} + + private: + SQTensor& sQ; + RQTensor& rQ; + SKTensor& sK; + RKTensor& rK; + RAccTensor& acc; + TiledCopyQ copy_q; + TiledCopyK copy_k; + TiledMma tiled_mma; + int sQ_stride; + int sK_stride; + int num_stage; + int cur_iter; + int cur_iter_sq; +}; + } // namespace detail template -// class G2SCopyTile { -// public: -// G2SCopyTile(G2STiledCopy& copy_v, GTensor& gV, STensor& sV, int -// gV_stride, -// int sV_stride, int num_stage = 2) -// : copy_v(copy_v), -// gV(gV), -// sV(sV), -// gV_stride(gV_stride), -// sV_stride(sV_stride), -// num_stage(num_stage), -// cur_iter(0) {} - -// inline __device__ void prologue() { -// #pragma unroll -// for (int m = 0; m < size<1>(gV); ++m) { -// #pragma unroll -// for (int k = 0; k < size<2>(gV); ++k) { -// cute::copy(copy_v, gV(_, m, k), sV(_, m, k)); -// } -// } - -// // Copies using `cp.async` are separted into "commit groups" using -// // `cp_async_fence()`. -// cute::cp_async_fence(); -// gV.data() = gV.data() + gV_stride; -// sV.data() = sV.data() + sV_stride; - -// // Circlically read SMEM Buffer -// if ((cur_iter + 1) % num_stage == 0) { -// sV.data() = sV.data() - sV_stride * num_stage; -// } - -// cur_iter++; -// } - -// inline __device__ void body() { -// #pragma unroll -// for (int m = 0; m < size<1>(gV); ++m) { -// #pragma unroll -// for (int k = 0; k < size<2>(gV); ++k) { -// cute::copy(copy_v, gV(_, m, k), sV(_, m, k)); -// } -// } - -// cute::cp_async_fence(); -// gV.data() = gV.data() + gV_stride; -// sV.data() = sV.data() + sV_stride; - -// if ((cur_iter + 1) % num_stage == 0) { -// sV.data() = sV.data() - sV_stride * num_stage; -// } - -// cur_iter++; -// } - -// inline __device__ void epilogue() { -// #pragma unroll -// for (int m = 0; m < size<1>(gV); ++m) { -// #pragma unroll -// for (int k = 0; k < size<2>(gV); ++k) { -// cute::copy(copy_v, gV(_, m, k), sV(_, m, k)); -// } -// } - -// cute::cp_async_fence(); -// } - -// private: -// G2STiledCopy& copy_v; -// GTensor& gV; -// STensor& sV; -// int gV_stride; -// int sV_stride; -// int num_stage; -// int cur_iter; -// } - -// template -// class S2RCopyTileQK { -// public: -// S2RCopyTileQK(TiledCopyQ& copy_q, TiledCopyK& copy_k, STensorQ& sQ, -// RTensorQ& rQ, STensorK& sK, RTensorK& rK, RTensorAcc& acc, -// TiledMMA tiled_mma, int sQ_stride, int rQ_stride, -// int sK_stride, int rK_stride, int num_stage = 2) -// : copy_q(copy_q), -// copy_k(copy_k), -// sQ(sQ), -// rQ(rQ), -// sK(sK), -// rK(rK), -// sQ_stride(sQ_stride), -// rQ_stride(rQ_stride), -// sK_stride(sK_stride), -// rK_stride(rK_stride), -// num_stage(num_stage), -// cur_iter(0), -// cur_iter_sq(0) {} - -// inline __device__ void prologue() { -// cur_iter = 0; -// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{})); -// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{})); - -// // Software pipelining Technique. -// #pragma unroll -// for (int i = 0; i < size<2>(rK); ++i) { -// if (i < size<2>(rK) - 1) { -// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1)); -// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1)); -// } - -// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc); -// } - -// sQ.data() = sQ.data() + sQ_stride; -// sK.data() = sK.data() + sK_stride; -// cur_iter++; -// } - -// inline __device__ void body() { -// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{})); -// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{})); - -// // Software pipelining Technique. -// // Loading from SMEM to RMEM is handled by LSU(Load/Store Unit), -// while -// // computation is handled by a computational unit(e.g., tensor -// cores). -// #pragma unroll -// for (int i = 0; i < size<2>(rK); ++i) { -// if (i < size<2>(rK) - 1) { -// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1)); -// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1)); -// } - -// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc); -// } - -// sQ.data() = sQ.data() + sQ_stride; -// sK.data() = sK.data() + sK_stride; - -// if ((cur_iter + 1) % num_stage == 0) { -// sK.data() = sK.data() - sK_stride * num_stage; -// } - -// cur_iter++; -// cur_iter_sq++; -// } - -// inline __device__ void epilogue() { -// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{})); -// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{})); - -// // Software pipelining Technique. -// #pragma unroll -// for (int i = 0; i < size<2>(rK); ++i) { -// if (i < size<2>(rK) - 1) { -// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1)); -// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1)); -// } - -// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc); -// } - -// sQ.data() = sQ.data() - sQ_stride * cur_iter_sq; -// sK.data() = sK.data() + sK_stride; - -// if ((cur_iter + 1) % num_stage == 0) { -// sK.data() = sK.data() - sK_stride * num_stage; -// } - -// cur_iter++; -// cur_iter_sq = 0; -// } - -// private: -// TiledCopyQ& copy_q; -// TiledCopyK& copy_k; -// STensorQ& sQ; -// RTensorQ& rQ; -// STensorK& sK; -// RTensorK& rK; -// RTensorAcc& acc; -// TiledMMA tiled_mma; -// int sQ_stride; -// int rQ_stride; -// int sK_stride; -// int rK_stride; -// int num_stage; -// int cur_iter; -// int cur_iter_sq; -// } +template +DEVICE auto make_s2r_qk(SQTensor sQ, SKTensor sK, RegAcc acc, int sQ_stride, + int sK_stride, SmemCopyAtom copy_atom = SmemCopyAtom{}, + TiledMma tiled_mma = TiledMma{}) { + int tid = threadIdx.x; + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma); + auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid); + auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid); + + auto rQ_org = thr_mma.partition_fragment_A(sQ); + auto rK_org = thr_mma.partition_fragment_B(sK); + + auto rQ = s2r_thr_copy_q.retile_D(rQ_org); + auto rK = s2r_thr_copy_k.retile_D(rK_org); + // auto rAcc = get_acc(rQ), size<1>(rK)>(tiled_mma); + + if (thread0()) { + printf("thr_mma: \n"); + print(thr_mma), print("\n"); + printf("s2r_copy_q: \n"); + print(s2r_copy_q), print("\n"); + printf("rQ_org: \n"); + print(rQ_org), print("\n"); + printf("rQ: \n"); + print(rQ), print("\n"); + } + + detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ, sK, rK, acc, s2r_copy_q, + s2r_copy_k, tiled_mma, sQ_stride, + sK_stride); +} } // namespace cutlass_wrapper } // namespace benchmarks \ No newline at end of file diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index a7fb342..316f3fb 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -55,6 +55,8 @@ struct FATraits : public Base { using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemCopyAtom = Copy_Atom; + static constexpr int kWarps = kThreads / 32; // Declare MMA Operation: [16, 8, 16] * [1, 2, 1] -> [16, 16, 16] @@ -132,9 +134,17 @@ __global__ void __launch_bounds__(Nthreads) typename KeTraits::SmemLayoutV, typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN); +#ifdef DEBUG g2s_copy_qk.print_gQ(); g2s_copy_v.print_gV(); - // g2s_copy_qk.print_gQ_data(0); + g2s_copy_qk.print_gQ_data(0); +#endif + + auto sQ = g2s_copy_qk.get_sQ(); + auto sK = g2s_copy_qk.get_sK(); + auto acc0 = get_acc(mma); + + make_s2r_qk(sQ, sK, acc0, kTK, kTK, typename KeTraits::SmemCopyAtom{}, mma); // Issue global to shared memory copy before the main loop. g2s_copy_qk.prologue(); @@ -153,7 +163,7 @@ __global__ void __launch_bounds__(Nthreads) cp_async_wait_flash<0>(); __syncthreads(); - g2s_copy_qk.print_sQ_data(0); + // g2s_copy_qk.print_sQ_data(0); g2s_copy_v.prologue(); } }