Skip to content

Commit

Permalink
Add Build S2RPipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Dec 23, 2024
1 parent 5fe4fba commit f3a8d2c
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 198 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
"array": "cpp",
"string": "cpp",
"string_view": "cpp",
"span": "cpp"
"span": "cpp",
"bitset": "cpp",
"initializer_list": "cpp",
"utility": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
Expand Down
275 changes: 80 additions & 195 deletions benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "cuda_utils.cuh"

#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>

namespace benchmarks {
namespace cutlass_wrapper {
Expand All @@ -31,6 +32,10 @@ class G2SCopyQK {
cur_iter(0),
num_stage(num_stage) {}

DEVICE auto get_sQ() { return sQ; }

DEVICE auto get_sK() { return sK; }

// For debugging purpose.
DEVICE void print_gQ() {
if (thread0()) {
Expand Down Expand Up @@ -209,6 +214,45 @@ class G2SCopyV {
int num_stage;
};

template <typename SQTensor, typename RQTensor, typename SKTensor,
          typename RKTensor, typename RAccTensor, typename TiledCopyQ,
          typename TiledCopyK, typename TiledMma>
class S2RPipelineQK {
  public:
    // Software-pipelined shared-to-register (S2R) copy state for the Q/K
    // operands of the Q*K^T GEMM. The object only stores references to the
    // shared-memory tiles (sQ/sK), the register fragments (rQ/rK), the
    // accumulator, and the tiled-copy / MMA objects needed to drive them.
    //
    // NOTE(review): only the constructor exists so far — the prologue/body/
    // epilogue pipeline stages (see the commented-out S2RCopyTileQK draft
    // below in this file) have not been ported yet; confirm intended design.
    //
    // sQ_stride / sK_stride: per-iteration offsets applied to the shared
    //     tiles — presumably in elements, mirroring the G2S copy classes;
    //     verify once the pipeline stages are implemented.
    // num_stage: depth of the circular shared-memory buffer (default 2,
    //     i.e. double buffering, matching the G2S copy classes above).
    DEVICE S2RPipelineQK(SQTensor& sQ, RQTensor& rQ, SKTensor& sK, RKTensor& rK,
                         RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k,
                         TiledMma tiled_mma, int sQ_stride, int sK_stride,
                         int num_stage = 2)
        : sQ(sQ),
          rQ(rQ),
          sK(sK),
          rK(rK),
          acc(acc),
          copy_q(copy_q),
          copy_k(copy_k),
          tiled_mma(tiled_mma),
          sQ_stride(sQ_stride),
          sK_stride(sK_stride),
          num_stage(num_stage),
          cur_iter(0),
          cur_iter_sq(0) {}

  private:
    SQTensor& sQ;       // shared-memory Q tile (reference; caller owns storage)
    RQTensor& rQ;       // register fragment for Q
    SKTensor& sK;       // shared-memory K tile
    RKTensor& rK;       // register fragment for K
    RAccTensor& acc;    // GEMM accumulator
    TiledCopyQ copy_q;  // tiled S2R copy for operand A (Q), held by value
    TiledCopyK copy_k;  // tiled S2R copy for operand B (K), held by value
    TiledMma tiled_mma;
    int sQ_stride;
    int sK_stride;
    int num_stage;      // circular-buffer depth
    int cur_iter;       // iteration counter (for buffer wrap-around, unused yet)
    int cur_iter_sq;    // Q-tile iteration counter (unused yet)
};

} // namespace detail

template <typename Element, typename GlobalQLayout, typename SharedQLayout,
Expand Down Expand Up @@ -261,201 +305,42 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
return copy_v;
}

// template <class G2STiledCopy, class GTensor, class STensor>
// class G2SCopyTile {
// public:
// G2SCopyTile(G2STiledCopy& copy_v, GTensor& gV, STensor& sV, int
// gV_stride,
// int sV_stride, int num_stage = 2)
// : copy_v(copy_v),
// gV(gV),
// sV(sV),
// gV_stride(gV_stride),
// sV_stride(sV_stride),
// num_stage(num_stage),
// cur_iter(0) {}

// inline __device__ void prologue() {
// #pragma unroll
// for (int m = 0; m < size<1>(gV); ++m) {
// #pragma unroll
// for (int k = 0; k < size<2>(gV); ++k) {
// cute::copy(copy_v, gV(_, m, k), sV(_, m, k));
// }
// }

// // Copies using `cp.async` are separated into "commit groups" using
// // `cp_async_fence()`.
// cute::cp_async_fence();
// gV.data() = gV.data() + gV_stride;
// sV.data() = sV.data() + sV_stride;

// // Cyclically read SMEM Buffer
// if ((cur_iter + 1) % num_stage == 0) {
// sV.data() = sV.data() - sV_stride * num_stage;
// }

// cur_iter++;
// }

// inline __device__ void body() {
// #pragma unroll
// for (int m = 0; m < size<1>(gV); ++m) {
// #pragma unroll
// for (int k = 0; k < size<2>(gV); ++k) {
// cute::copy(copy_v, gV(_, m, k), sV(_, m, k));
// }
// }

// cute::cp_async_fence();
// gV.data() = gV.data() + gV_stride;
// sV.data() = sV.data() + sV_stride;

// if ((cur_iter + 1) % num_stage == 0) {
// sV.data() = sV.data() - sV_stride * num_stage;
// }

// cur_iter++;
// }

// inline __device__ void epilogue() {
// #pragma unroll
// for (int m = 0; m < size<1>(gV); ++m) {
// #pragma unroll
// for (int k = 0; k < size<2>(gV); ++k) {
// cute::copy(copy_v, gV(_, m, k), sV(_, m, k));
// }
// }

// cute::cp_async_fence();
// }

// private:
// G2STiledCopy& copy_v;
// GTensor& gV;
// STensor& sV;
// int gV_stride;
// int sV_stride;
// int num_stage;
// int cur_iter;
// }

// template <class TiledCopyQ, class TiledCopyK, class STensorQ, class RTensorQ,
// class STensorK, class RTensorK, class RTensorAcc, class TiledMMA>
// class S2RCopyTileQK {
// public:
// S2RCopyTileQK(TiledCopyQ& copy_q, TiledCopyK& copy_k, STensorQ& sQ,
// RTensorQ& rQ, STensorK& sK, RTensorK& rK, RTensorAcc& acc,
// TiledMMA tiled_mma, int sQ_stride, int rQ_stride,
// int sK_stride, int rK_stride, int num_stage = 2)
// : copy_q(copy_q),
// copy_k(copy_k),
// sQ(sQ),
// rQ(rQ),
// sK(sK),
// rK(rK),
// sQ_stride(sQ_stride),
// rQ_stride(rQ_stride),
// sK_stride(sK_stride),
// rK_stride(rK_stride),
// num_stage(num_stage),
// cur_iter(0),
// cur_iter_sq(0) {}

// inline __device__ void prologue() {
// cur_iter = 0;
// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{}));
// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{}));

// // Software pipelining Technique.
// #pragma unroll
// for (int i = 0; i < size<2>(rK); ++i) {
// if (i < size<2>(rK) - 1) {
// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1));
// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1));
// }

// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc);
// }

// sQ.data() = sQ.data() + sQ_stride;
// sK.data() = sK.data() + sK_stride;
// cur_iter++;
// }

// inline __device__ void body() {
// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{}));
// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{}));

// // Software pipelining Technique.
// // Loading from SMEM to RMEM is handled by LSU(Load/Store Unit),
// while
// // computation is handled by a computational unit(e.g., tensor
// cores).
// #pragma unroll
// for (int i = 0; i < size<2>(rK); ++i) {
// if (i < size<2>(rK) - 1) {
// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1));
// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1));
// }

// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc);
// }

// sQ.data() = sQ.data() + sQ_stride;
// sK.data() = sK.data() + sK_stride;

// if ((cur_iter + 1) % num_stage == 0) {
// sK.data() = sK.data() - sK_stride * num_stage;
// }

// cur_iter++;
// cur_iter_sq++;
// }

// inline __device__ void epilogue() {
// cute::copy(copy_q, sQ(_, _, _0{}), rQ(_, _, _0{}));
// cute::copy(copy_k, sK(_, _, _0{}), rK(_, _, _0{}));

// // Software pipelining Technique.
// #pragma unroll
// for (int i = 0; i < size<2>(rK); ++i) {
// if (i < size<2>(rK) - 1) {
// cute::copy(copy_q, sQ(_, _, i + 1), rQ(_, _, i + 1));
// cute::copy(copy_k, sK(_, _, i + 1), rK(_, _, i + 1));
// }

// cute::gemm(tiled_mma, rQ(_, _, i), rK(_, _, i), acc);
// }

// sQ.data() = sQ.data() - sQ_stride * cur_iter_sq;
// sK.data() = sK.data() + sK_stride;

// if ((cur_iter + 1) % num_stage == 0) {
// sK.data() = sK.data() - sK_stride * num_stage;
// }

// cur_iter++;
// cur_iter_sq = 0;
// }

// private:
// TiledCopyQ& copy_q;
// TiledCopyK& copy_k;
// STensorQ& sQ;
// RTensorQ& rQ;
// STensorK& sK;
// RTensorK& rK;
// RTensorAcc& acc;
// TiledMMA tiled_mma;
// int sQ_stride;
// int rQ_stride;
// int sK_stride;
// int rK_stride;
// int num_stage;
// int cur_iter;
// int cur_iter_sq;
// }
// Build the shared-to-register copy pipeline for the Q*K^T GEMM.
//
// sQ / sK: shared-memory tiles of Q and K (passed as cute tensor views).
// acc:     register accumulator for the partial Q*K^T results.
// sQ_stride / sK_stride: per-iteration tile advance offsets, forwarded to
//     the pipeline — presumably in elements; verify against the G2S copies.
// copy_atom / tiled_mma: the SMEM copy atom (e.g. LDSM) and MMA used to
//     derive the per-thread copy and fragment layouts.
//
// Returns the constructed detail::S2RPipelineQK.
template <typename SQTensor, typename SKTensor, typename RegAcc,
          typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_qk(SQTensor sQ, SKTensor sK, RegAcc acc, int sQ_stride,
                        int sK_stride, SmemCopyAtom copy_atom = SmemCopyAtom{},
                        TiledMma tiled_mma = TiledMma{}) {
    int tid = threadIdx.x;

    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Shared-to-register tiled copies for the two GEMM operands:
    // Q feeds operand A, K feeds operand B.
    auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma);
    auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma);
    auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
    auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);

    // Per-thread register fragments partitioned according to the MMA layout.
    auto rQ_org = thr_mma.partition_fragment_A(sQ);
    auto rK_org = thr_mma.partition_fragment_B(sK);

    // Retile the fragments to the copy layout so cute::copy can target them.
    auto rQ = s2r_thr_copy_q.retile_D(rQ_org);
    auto rK = s2r_thr_copy_k.retile_D(rK_org);

#ifdef DEBUG
    // Debug-only layout dumps, guarded with DEBUG for consistency with the
    // g2s_copy_* debug prints in cutlass_fa.cuh (previously these printed
    // unconditionally on every launch).
    if (thread0()) {
        printf("thr_mma: \n");
        print(thr_mma), print("\n");
        printf("s2r_copy_q: \n");
        print(s2r_copy_q), print("\n");
        printf("rQ_org: \n");
        print(rQ_org), print("\n");
        printf("rQ: \n");
        print(rQ), print("\n");
    }
#endif

    detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ, sK, rK, acc, s2r_copy_q,
                                          s2r_copy_k, tiled_mma, sQ_stride,
                                          sK_stride);

    // Fix: the pipeline was constructed but never returned, so `auto`
    // deduced void and the object was discarded immediately — callers had no
    // way to drive it. Return it instead (existing callers that ignore the
    // result are unaffected).
    // NOTE(review): S2RPipelineQK stores references to tensors local to this
    // function; the returned object must not be used after those views'
    // lifetime concerns are resolved (e.g. by storing copies) — verify.
    return s2r_pipeline_qk;
}

} // namespace cutlass_wrapper
} // namespace benchmarks
14 changes: 12 additions & 2 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct FATraits : public Base {
using SmemLayoutO =
decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTP>>{}));

using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;

static constexpr int kWarps = kThreads / 32;

// Declare MMA Operation: [16, 8, 16] * [1, 2, 1] -> [16, 16, 16]
Expand Down Expand Up @@ -132,9 +134,17 @@ __global__ void __launch_bounds__(Nthreads)
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);

#ifdef DEBUG
g2s_copy_qk.print_gQ();
g2s_copy_v.print_gV();
// g2s_copy_qk.print_gQ_data(0);
g2s_copy_qk.print_gQ_data(0);
#endif

auto sQ = g2s_copy_qk.get_sQ();
auto sK = g2s_copy_qk.get_sK();
auto acc0 = get_acc<kTM, kTN>(mma);

make_s2r_qk(sQ, sK, acc0, kTK, kTK, typename KeTraits::SmemCopyAtom{}, mma);

// Issue global to shared memory copy before the main loop.
g2s_copy_qk.prologue();
Expand All @@ -153,7 +163,7 @@ __global__ void __launch_bounds__(Nthreads)

cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_qk.print_sQ_data(0);
// g2s_copy_qk.print_sQ_data(0);
g2s_copy_v.prologue();
}
}
Expand Down

0 comments on commit f3a8d2c

Please sign in to comment.