// Review notes for this patch (leading text before the first `diff --git` header).
//
// Files touched: benchmarks/cpp/flashattention/copy.cuh,
// benchmarks/cpp/flashattention/cutlass_fa.cuh, benchmarks/utils/cpp/cutlass/copy.cuh.
//
// flashattention/copy.cuh, class G2SCopyQK:
//   * Adds print_gQ_data(tid) / print_sQ_data(tid): when threadIdx.x == tid, prints a
//     header with printf and then every gQ(i, j, k) / sQ(i, j, k) over size<0..2>.
//     NOTE(review): the guard tests threadIdx.x only; no blockIdx test is visible, so
//     presumably the matching thread of every block prints -- confirm this is intended.
//   * Rewrites the shared-memory pointer rewind from `ptr - stride * num_stage` to
//     `ptr + (-stride * num_stage)`. The reason for the change is not visible here.
//   * Adds body(): copies every (m, k) slice gQ -> sQ and gK -> sK with
//     cute::copy(tiled_copy, ...), calls cute::cp_async_fence(), advances the gQ/sQ/gK/sK
//     data pointers by their strides, and rewinds sQ/sK when
//     (cur_iter + 1) % num_stage == 0.
//   * Adds epilogue(): the same two copy loops followed by cute::cp_async_fence(); it
//     does not move any pointer.
// flashattention/copy.cuh, class G2SCopyV:
//   * Adds body() and epilogue() whose bodies contain only a comment.
//
// flashattention/cutlass_fa.cuh:
//   * Removes commented-out include/alias lines and adds two explanatory comments.
//   * The kernel hunk adds `__syncthreads(); g2s_copy_qk.body();` and later
//     `cp_async_wait_flash<0>(); __syncthreads(); g2s_copy_qk.print_sQ_data(0);
//     g2s_copy_v.prologue();`.
//     NOTE(review): print_sQ_data(0) looks like a debugging call -- confirm whether it
//     is meant to stay in the kernel.
//
// utils/cpp/cutlass/copy.cuh:
//   * Adds cp_async_wait_flash(), which emits `cp.async.wait_group N` through inline asm
//     only when CP_ASYNC_SM80_ENABLED is defined; otherwise its body is empty.
//
// NOTE(review): this capture appears lossy. Line breaks are collapsed and some text
// inside angle brackets seems to be missing (`template` without a parameter list even
// though `N` is used, `#include` without a target, `make_g2s_v()` without template
// arguments, and added `}` lines that close scopes never opened in view). Compare with
// the original commit before applying.
diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index ea4793a..abe4e85 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -38,6 +38,36 @@ class G2SCopyQK { } } + DEVICE void print_gQ_data(int tid) { + if (threadIdx.x == tid) { + printf("gQ data(%d): \n", tid); + for (int i = 0; i < size<0>(gQ); ++i) { + for (int j = 0; j < size<1>(gQ); ++j) { + for (int k = 0; k < size<2>(gQ); ++k) { + print(gQ(i, j, k)), print(" "); + } + print("\n"); + } + print("\n"); + } + } + } + + DEVICE void print_sQ_data(int tid) { + if (threadIdx.x == tid) { + printf("sQ data(%d): \n", tid); + for (int i = 0; i < size<0>(sQ); ++i) { + for (int j = 0; j < size<1>(sQ); ++j) { + for (int k = 0; k < size<2>(sQ); ++k) { + print(sQ(i, j, k)), print(" "); + } + print("\n"); + } + print("\n"); + } + } + } + DEVICE void prologue() { // Pipeline the copy operation. #pragma unroll @@ -65,13 +95,65 @@ class G2SCopyQK { // Circlically read SMEM Buffer if ((cur_iter + 1) % num_stage == 0) { - sQ.data() = sQ.data() - sQ_stride * num_stage; - sK.data() = sK.data() - sK_stride * num_stage; + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void body() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); } 
cur_iter++; } + DEVICE void epilogue() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + } + private: GQTensor& gQ; SQTensor& sQ; @@ -109,6 +191,14 @@ class G2SCopyV { // Pipeline the copy operation. } + DEVICE void body() { + // Pipeline the copy operation. + } + + DEVICE void epilogue() { + // Pipeline the copy operation. + } + private: GVTensor& gV; SVTensor& sV; diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index 2e37057..a7fb342 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -8,11 +8,6 @@ #include "cutlass/copy.cuh" #include "cutlass/traits_base.cuh" -// #include - -// template -// using FAShape = TileShape; - namespace benchmarks { namespace cutlass_wrapper { using namespace cute; @@ -124,12 +119,14 @@ __global__ void __launch_bounds__(Nthreads) typename KeTraits::TiledMma mma; typename KeTraits::TiledCopyG2S tiled_copy_g2s; + // Build the copy plan for QK from global memory to shared memory. auto g2s_copy_qk = make_g2s_qk< Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ, typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK, typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK, kTK); + // Build the copy plan for V from global memory to shared memory. auto g2s_copy_v = make_g2s_v(); + __syncthreads(); + g2s_copy_qk.body(); + // Load data from shared memory into register and issue MMA. 
+ } + + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_qk.print_sQ_data(0); + g2s_copy_v.prologue(); + } } } // namespace cutlass_wrapper diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh index b3b105b..08a4c1d 100644 --- a/benchmarks/utils/cpp/cutlass/copy.cuh +++ b/benchmarks/utils/cpp/cutlass/copy.cuh @@ -32,6 +32,13 @@ DEVICE void __copy_async() { wait_group<0>(); } +template +DEVICE void cp_async_wait_flash() { +#if defined(CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + // Copy a 2d data tile from global memory to shared memory template