Skip to content

Commit

Permalink
Update copy plan.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Dec 22, 2024
1 parent 7a10f6f commit 5fe4fba
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 7 deletions.
94 changes: 92 additions & 2 deletions benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,36 @@ class G2SCopyQK {
}
}

// Debug helper: prints every element of the gQ tensor from one thread.
//
// Only a thread whose threadIdx.x equals `tid` prints; every other thread
// returns immediately. Elements are visited with mode 0 outermost and mode 2
// innermost: each mode-2 run goes on its own line, and a blank line separates
// consecutive mode-0 slices.
//
// NOTE: only threadIdx.x is compared, so in a launch with several blocks (or
// with multi-dimensional blocks) more than one thread can match and print.
DEVICE void print_gQ_data(int tid) {
    if (threadIdx.x != tid) return;

    printf("gQ data(%d): \n", tid);
    for (int outer = 0; outer < size<0>(gQ); ++outer) {
        for (int row = 0; row < size<1>(gQ); ++row) {
            for (int col = 0; col < size<2>(gQ); ++col) {
                print(gQ(outer, row, col));
                print(" ");
            }
            print("\n");
        }
        print("\n");
    }
}

// Debug helper: prints every element of the sQ tensor from one thread.
//
// Only a thread whose threadIdx.x equals `tid` prints; every other thread
// returns immediately. Elements are visited with mode 0 outermost and mode 2
// innermost: each mode-2 run goes on its own line, and a blank line separates
// consecutive mode-0 slices.
//
// NOTE: only threadIdx.x is compared, so in a launch with several blocks (or
// with multi-dimensional blocks) more than one thread can match and print.
DEVICE void print_sQ_data(int tid) {
    if (threadIdx.x != tid) return;

    printf("sQ data(%d): \n", tid);
    for (int outer = 0; outer < size<0>(sQ); ++outer) {
        for (int row = 0; row < size<1>(sQ); ++row) {
            for (int col = 0; col < size<2>(sQ); ++col) {
                print(sQ(outer, row, col));
                print(" ");
            }
            print("\n");
        }
        print("\n");
    }
}

DEVICE void prologue() {
// Pipeline the copy operation.
#pragma unroll
Expand Down Expand Up @@ -65,13 +95,65 @@ class G2SCopyQK {

// Cyclically read the SMEM buffer
if ((cur_iter + 1) % num_stage == 0) {
sQ.data() = sQ.data() - sQ_stride * num_stage;
sK.data() = sK.data() - sK_stride * num_stage;
sQ.data() = sQ.data() + (-sQ_stride * num_stage);
sK.data() = sK.data() + (-sK_stride * num_stage);
}

cur_iter++;
}

// One step of the Q/K copy pipeline.
//
// Issues `tiled_copy` copies of the current gQ tile into sQ and of the current
// gK tile into sK, calls cute::cp_async_fence(), and then moves all four
// tensor base pointers forward by their per-step strides. When the step that
// just ran was the last of a group of `num_stage` steps, the sQ/sK pointers
// are moved back by `num_stage` strides. Finally `cur_iter` is incremented.
//
// NOTE(review): presumably the fence lets the caller's cp.async wait track
// these copies as one group — confirm against the CuTe documentation.
DEVICE void body() {
    // Copy the Q tile, one (m, k) slice at a time.
#pragma unroll
    for (int mi = 0; mi < size<1>(gQ); ++mi) {
#pragma unroll
        for (int ki = 0; ki < size<2>(gQ); ++ki) {
            cute::copy(tiled_copy, gQ(_, mi, ki), sQ(_, mi, ki));
        }
    }

    // Copy the K tile, one (m, k) slice at a time.
#pragma unroll
    for (int mi = 0; mi < size<1>(gK); ++mi) {
#pragma unroll
        for (int ki = 0; ki < size<2>(gK); ++ki) {
            cute::copy(tiled_copy, gK(_, mi, ki), sK(_, mi, ki));
        }
    }

    cute::cp_async_fence();

    // Evaluated before cur_iter changes: true when this step ends a group of
    // num_stage steps.
    const bool wrap_smem = ((cur_iter + 1) % num_stage == 0);

    // Move the source (gQ/gK) pointers to the next tile.
    gQ.data() = gQ.data() + gQ_stride;
    gK.data() = gK.data() + gK_stride;

    // Move the destination (sQ/sK) pointers to the next stage.
    sQ.data() = sQ.data() + sQ_stride;
    sK.data() = sK.data() + sK_stride;

    if (wrap_smem) {
        // Step back by num_stage strides. The offset is written as an added
        // negative value, matching the rest of this class.
        sQ.data() = sQ.data() + (-sQ_stride * num_stage);
        sK.data() = sK.data() + (-sK_stride * num_stage);
    }

    ++cur_iter;
}

// Final step of the Q/K copy pipeline.
//
// Issues `tiled_copy` copies of the current gQ tile into sQ and of the current
// gK tile into sK, then calls cute::cp_async_fence(). Unlike body(), it does
// not advance any tensor pointer and does not touch `cur_iter`.
DEVICE void epilogue() {
    // Copy the Q tile, one (m, k) slice at a time.
#pragma unroll
    for (int mi = 0; mi < size<1>(gQ); ++mi) {
#pragma unroll
        for (int ki = 0; ki < size<2>(gQ); ++ki) {
            cute::copy(tiled_copy, gQ(_, mi, ki), sQ(_, mi, ki));
        }
    }

    // Copy the K tile, one (m, k) slice at a time.
#pragma unroll
    for (int mi = 0; mi < size<1>(gK); ++mi) {
#pragma unroll
        for (int ki = 0; ki < size<2>(gK); ++ki) {
            cute::copy(tiled_copy, gK(_, mi, ki), sK(_, mi, ki));
        }
    }

    cute::cp_async_fence();
}

private:
GQTensor& gQ;
SQTensor& sQ;
Expand Down Expand Up @@ -109,6 +191,14 @@ class G2SCopyV {
// Pipeline the copy operation.
}

// Placeholder: the body is empty, so calling this currently does nothing.
// NOTE(review): presumably meant to hold the per-iteration copy step for V,
// analogous to the Q/K copy class — confirm before relying on it.
DEVICE void body() {
// TODO: pipeline the copy operation (not implemented yet).
}

// Placeholder: the body is empty, so calling this currently does nothing.
// NOTE(review): presumably meant to hold the final copy step for V,
// analogous to the Q/K copy class — confirm before relying on it.
DEVICE void epilogue() {
// TODO: pipeline the copy operation (not implemented yet).
}

private:
GVTensor& gV;
SVTensor& sV;
Expand Down
29 changes: 24 additions & 5 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
#include "cutlass/copy.cuh"
#include "cutlass/traits_base.cuh"

// #include <cute/tensor.hpp>

// template <const int kM, const int kN, const int kK, const int kP>
// using FAShape = TileShape<kM, kN, kK, kP>;

namespace benchmarks {
namespace cutlass_wrapper {
using namespace cute;
Expand Down Expand Up @@ -124,19 +119,43 @@ __global__ void __launch_bounds__(Nthreads)
typename KeTraits::TiledMma mma;
typename KeTraits::TiledCopyG2S tiled_copy_g2s;

// Build the copy plan for QK from global memory to shared memory.
auto g2s_copy_qk = make_g2s_qk<
Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ,
typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK,
typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK,
kTK);

// Build the copy plan for V from global memory to shared memory.
auto g2s_copy_v =
make_g2s_v<Element, typename KeTraits::GmemLayoutV,
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);

g2s_copy_qk.print_gQ();
g2s_copy_v.print_gV();
// g2s_copy_qk.print_gQ_data(0);

// Issue global to shared memory copy before the main loop.
g2s_copy_qk.prologue();

for (int n = 0; n < kN; n += kTN) {
int split_k = kK / kTK - 1;

// Pipeline
for (int k = 0; k < split_k; ++k) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_qk.body();
// Load data from shared memory into register and issue MMA.
}

cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_qk.print_sQ_data(0);
g2s_copy_v.prologue();
}
}

} // namespace cutlass_wrapper
Expand Down
7 changes: 7 additions & 0 deletions benchmarks/utils/cpp/cutlass/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ DEVICE void __copy_async() {
wait_group<0>();
}

// Emits the PTX instruction `cp.async.wait_group N`, with N supplied as a
// compile-time immediate operand.
//
// The instruction is emitted only when CP_ASYNC_SM80_ENABLED is defined;
// without that macro this function compiles to an empty body, so no wait is
// performed.
//
// NOTE(review): per the PTX ISA, `cp.async.wait_group N` waits until at most
// N of the most recently committed cp.async groups are still pending (N == 0
// waits for all of them) — confirm against the PTX documentation.
template <int N>
DEVICE void cp_async_wait_flash() {
#ifdef CP_ASYNC_SM80_ENABLED
    asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
#endif
}

// Copy a 2d data tile from global memory to shared memory
template <typename Element, typename SrcLayout, typename DstLayout,
typename TiledCopy>
Expand Down

0 comments on commit 5fe4fba

Please sign in to comment.