Add Store O Implementation.
KuangjuX committed Dec 28, 2024
1 parent 1bd3c5c commit 7771636
Showing 3 changed files with 152 additions and 26 deletions.
124 changes: 107 additions & 17 deletions benchmarks/cpp/flashattention/copy.cuh
@@ -340,12 +340,12 @@ class S2RPipelineQK {

template <typename SVTensor, typename RVMmaView, typename RVCopyView,
typename RegAcc, typename TiledCopy, typename TiledMma>
class S2RPipelineQK {
class S2RPipelineV {
public:
DEVICE S2RPipelineQK(SVTensor& sV, RVMmaView& rV_mma_view,
RVCopyView& rV_copy_view, RegAcc& acc,
TiledCopy tiled_copy, TiledMma, tiled_mma,
int sV_stride, int num_stage = 2)
DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view,
RVCopyView& rV_copy_view, RegAcc& acc,
TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride,
int num_stage = 2)
: sV(sV),
rV_mma_view(rV_mma_view),
rV_copy_view(rV_copy_view),
@@ -366,6 +366,8 @@ class S2RPipelineQK {
cute::copy(tiled_copy, sV(_, _, i + 1),
rV_copy_view(_, _, i + 1));
}
// TODO: Why do we need to use value(_, _, cur_iter *
// size<2>(rV_mma_view) + i)?
cute::gemm(tiled_mma,
value(_, _, cur_iter * size<2>(rV_mma_view) + i),
rV_mma_view(_, _, i), acc);
@@ -375,6 +377,55 @@ class S2RPipelineQK {
cur_iter++;
}

template <typename RegValue>
DEVICE void body(RegValue& value) {
cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));

#pragma unroll
for (int i = 0; i < size<2>(rV_mma_view); ++i) {
if (i < size<2>(rV_mma_view) - 1) {
cute::copy(tiled_copy, sV(_, _, i + 1),
rV_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma,
value(_, _, cur_iter * size<2>(rV_mma_view) + i),
rV_mma_view(_, _, i), acc);
}

sV.data() = sV.data() + sV_stride;
if ((cur_iter + 1) % num_stage == 0) {
sV.data() = sV.data() + (-sV_stride * num_stage);
}

cur_iter++;
cur_iter_sv++;
}

template <typename RegValue>
DEVICE void epilogue(RegValue& value) {
cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));

#pragma unroll
for (int i = 0; i < size<2>(rV_mma_view); ++i) {
if (i < size<2>(rV_mma_view) - 1) {
cute::copy(tiled_copy, sV(_, _, i + 1),
rV_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma,
value(_, _, cur_iter * size<2>(rV_mma_view) + i),
rV_mma_view(_, _, i), acc);
}

sV.data() = sV.data() + (-sV_stride * cur_iter_sv);

if ((cur_iter + 1) % num_stage == 0) {
sV.data() = sV.data() + (-sV_stride * num_stage);
}

cur_iter++;
cur_iter_sv = 0;
}

private:
SVTensor& sV;
RVMmaView& rV_mma_view;
@@ -386,7 +437,7 @@ class S2RPipelineQK {
int num_stage;
int cur_iter;
int cur_iter_sv;
}
};

} // namespace detail
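Note: `S2RPipelineV` mirrors `S2RPipelineQK`: inside `body()`/`epilogue()` the shared-to-register copy for slice `i + 1` is issued while the MMA consumes slice `i`, and at the end of each `body()` the shared-memory base is advanced by `sV_stride`, wrapping after `num_stage` tiles so the pipeline cycles through a fixed ring of staging buffers (`epilogue()` instead rewinds the base for the next outer iteration). A minimal sketch of the ring step, with illustrative names that are not part of this file:

// Illustrative only: the stage-ring pointer arithmetic used at the end of
// body(). With num_stage staging buffers of `stride` elements each, the base
// cycles 0, 1, ..., num_stage - 1, 0, 1, ... as tiles are consumed.
template <typename T>
__device__ T* advance_stage(T* base, int stride, int num_stage, int cur_iter) {
    base += stride;                         // step to the next staging buffer
    if ((cur_iter + 1) % num_stage == 0) {  // last stage just consumed?
        base -= stride * num_stage;         // wrap back to stage 0
    }
    return base;
}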

@@ -402,10 +453,8 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});

if (thread0()) {
printf("gQ: \n");
print(gQ), print("\n");
printf("size<0>(gQ): %d, size<1>(gQ): %d\n", (int)size<0>(gQ),
(int)size<1>(gQ));
printf("sQ: \n");
print(sQ), print("\n");
}

auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
@@ -484,10 +533,11 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
return s2r_pipeline_qk;
}

template <typename Element, typename SVLayout, typename SmemCopyAtom,
typename TiledMma>
DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride,
SmemCopyAtom copy_atom, TiledMma tiled_mma) {
template <typename Element, typename SVLayout, typename RegAcc,
typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
int sV_stride, SmemCopyAtom copy_atom,
TiledMma tiled_mma) {
int tid = threadIdx.x;

auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);
@@ -501,10 +551,50 @@ DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride,

auto rV_mma = thr_mma.partition_fragment_B(sV_);
auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);

detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v,
tiled_mma, sV_stride);

return s2r_pipeline_v;
}

} // namespace cutlass_wrapper
} // namespace benchmarks
template <typename Element, typename SOLayout, typename RegO,
typename SmemCopyAtom, typename TiledMma>
DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o,
SmemCopyAtom copy_atom, TiledMma tiled_mma) {
auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);

auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma);
auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x);

auto sOs = r2s_thr_copy_o.partition_D(sO);
auto rO_copy_view = r2s_thr_copy_o.retile_S(o);

cute::copy(r2s_copy_o, rO_copy_view, sOs);
}

template <typename Element, typename GOLayout, typename SOLayout,
typename TiledCopy>
DEVICE auto store_s2g_o(Element* gO_ptr, const Element* sO_ptr,
GOLayout gO_layout, SOLayout sO_layout,
TiledCopy tiled_copy) {
auto gO = make_tensor(make_gmem_ptr(gO_ptr), gO_layout);
auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);

auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);

auto gO_partition = thr_copy.partition_D(gO);
auto sO_partition = thr_copy.partition_S(sO);

#pragma unroll
for (int m = 0; m < size<1>(gO_partition); ++m) {
#pragma unroll
for (int n = 0; n < size<2>(gO_partition); ++n) {
cute::copy(tiled_copy, sO_partition(_, m, n),
gO_partition(_, m, n));
}
}
}

} // namespace cutlass_wrapper
} // namespace benchmarks
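For context, the two new store helpers are intended to run back to back in the kernel epilogue: `store_r2s_o` moves the converted accumulator from registers into the shared-memory O tile through the MMA's C-operand tiled copy, and `store_s2g_o` then streams that tile out to global memory. A sketch of the call sequence, mirroring the usage added to cutlass_fa.cuh below (`acco`, `O`, `sO_ptr`, `KeTraits`, and `tiled_copy_g2s` are that kernel's names):

// Sketch of the epilogue store path assumed by these helpers.
auto acco_f16 = convert_type<Element>(acco);           // fp32 accumulator -> fp16
store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
            typename KeTraits::SmemCopyAtom{}, mma);   // registers -> shared
__syncthreads();                                       // make the smem tile visible
store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
            typename KeTraits::SmemLayoutO{}, tiled_copy_g2s);  // shared -> global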
48 changes: 45 additions & 3 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -12,6 +12,7 @@

namespace benchmarks {
namespace cutlass_wrapper {

using namespace cute;

//
@@ -113,7 +114,7 @@ __global__ void __launch_bounds__(Nthreads)
Element* sQ_ptr = reinterpret_cast<Element*>(buf);
Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK;
Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK;
// Element* sO_ptr = sQ_ptr;
Element* sO_ptr = sQ_ptr;

typename KeTraits::TiledMma mma;
typename KeTraits::TiledCopyG2S tiled_copy_g2s;
@@ -154,6 +155,10 @@
typename KeTraits::SmemCopyAtom{}, mma);
s2r_pipeline_qk.print_rQ();

auto s2r_pipeline_v =
make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, kTN,
typename KeTraits::SmemCopyAtom{}, mma);

// Issue global to shared memory copy before the main loop.
g2s_copy_qk.prologue();
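Note: `g2s_copy_qk.prologue()` pre-issues the asynchronous global-to-shared copies for the first tile so the main loop can overlap subsequent loads with compute. The overall shape of this prologue/body/epilogue pipeline, sketched with placeholder names (`g2s`, `s2r`, and `num_tiles` are illustrative, not identifiers from this kernel):

// Illustrative shape of the async-copy pipeline, not the actual kernel loop.
g2s.prologue();                        // issue cp.async loads for tile 0
for (int t = 0; t < num_tiles - 1; ++t) {
    cp_async_wait_flash<0>();          // wait until the staged tile has landed
    __syncthreads();
    g2s.body();                        // kick off loads for tile t + 1
    s2r.body();                        // shared -> register copies + MMA on tile t
}
cp_async_wait_flash<0>();
__syncthreads();
s2r.epilogue();                        // drain the last tile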

@@ -187,10 +192,10 @@
auto scores =
make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));

Tensor m_old = make_fragment_like(m_new);
auto m_old = make_fragment_like(m_new);
copy(m_new, m_old);

Tensor scores_max = make_fragment_like(m_new);
auto scores_max = make_fragment_like(m_new);

// Compute row max.
reduce_max<4, true>(scores, scores_max);
@@ -212,15 +217,52 @@
}
}

for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
lse_new(ax0) = lse_new(ax0) * m_scaled;
for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
scores(ax0, ax1) =
exp(scores(ax0, ax1) * softmax_scale - m_scaled);
}
}

auto scores_sum = make_fragment_like(lse_new);
reduce_sum<4>(scores, scores_sum);

for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
}

// TODO: Understand the following code.
auto frag = convert_type<Element>(scores);
auto rP = make_tensor(make_rmem_ptr<Element>(&frag), scores.layout());
auto rP_Aregs =
make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));

// Load V into register and issue MMA.
int split_n = kN / kTN - 1;
for (int n = 0; n < split_n; ++n) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_v.body();
s2r_pipeline_v.body(rP_Aregs);
}

cp_async_wait_flash<0>();
__syncthreads();

s2r_pipeline_v.epilogue(rP_Aregs);
}

auto acco_f16 = convert_type<Element>(acco);

store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
typename KeTraits::SmemCopyAtom{}, mma);
__syncthreads();

store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
typename KeTraits::SmemLayoutO{}, tiled_copy_g2s);
}

} // namespace cutlass_wrapper
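Note: the rescaling loop added in this hunk corresponds to the standard online-softmax update used by flash attention. Writing tau for `softmax_scale`, m for the per-row running max, l for the per-row running sum (`lse_new` here), and S for the new block of scores, the per-row update is (a hedged reconstruction; the accumulator rescale and the final 1/l normalization are outside this hunk):

\begin{aligned}
m_{\mathrm{new}} &= \max\bigl(m_{\mathrm{old}},\ \max_j S_j\bigr) \\
\alpha &= e^{\tau\,(m_{\mathrm{old}} - m_{\mathrm{new}})} \\
P_j &= e^{\tau S_j - \tau m_{\mathrm{new}}} \\
\ell_{\mathrm{new}} &= \alpha\,\ell_{\mathrm{old}} + \sum_j P_j
\end{aligned}

In the committed loop, `m_scaled` already holds alpha and is subtracted inside the exponent for the scores, whereas the form above subtracts tau * m_new there.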
6 changes: 0 additions & 6 deletions benchmarks/utils/cpp/cutlass/copy.cuh
@@ -54,12 +54,6 @@ DEVICE void copy_tile_g2s(const Element* src_data, Element* dst_data,

auto src = loader.partition_S(gtile);
auto dst = loader.partition_D(stile);

#pragma unroll
for (int i = 0; i < int(size<1>(src)); ++i)
#pragma unroll
for (int j = 0; j < int(size<2>(src)); ++j)
cute::copy(tiled_copy, src(_, i, j), dst(_, i, j));
}

// Copy a tensor from shared memory to global memory
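Note: the unrolled per-mode loop removed above is equivalent to a single `cute::copy` over the full partitioned tensors, which is the more common CuTe idiom; as committed, `copy_tile_g2s` itself no longer issues any copy. A one-line sketch, assuming `src`/`dst` are the rank-3 partitions built just above:

// Equivalent single-call form of the removed loop: cute::copy recurses over
// the remaining tensor modes, issuing one copy per (i, j) slice.
cute::copy(tiled_copy, src, dst);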
