Skip to content

Commit

Permalink
Add G2SCopyV class.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Dec 22, 2024
1 parent 8a81507 commit 7a10f6f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
68 changes: 61 additions & 7 deletions benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#pragma once
#include "cuda_utils.cuh"

#include <cute/tensor.hpp>

Expand All @@ -11,13 +12,14 @@ namespace cutlass_wrapper {
using namespace cute;

namespace detail {

template <typename GQTensor, typename SQTensor, typename GKTensor,
typename SKTensor, typename TiledCopy>
class G2SCopyQK {
public:
__device__ G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK,
TiledCopy tiled_copy, int gQ_stride, int sQ_stride,
int gK_stride, int sK_stride, int num_stage = 2)
DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK,
TiledCopy tiled_copy, int gQ_stride, int sQ_stride,
int gK_stride, int sK_stride, int num_stage = 2)
: gQ(gQ),
sQ(sQ),
gK(gK),
Expand All @@ -29,13 +31,14 @@ class G2SCopyQK {
cur_iter(0),
num_stage(num_stage) {}

inline __device__ void print_q() {
// For debugging purposes.
DEVICE void print_gQ() {
if (thread0()) {
print(gQ), print("\n");
}
}

inline __device__ void prologue() {
DEVICE void prologue() {
// Pipeline the copy operation.
#pragma unroll
for (int m = 0; m < size<1>(gQ); ++m) {
Expand Down Expand Up @@ -74,14 +77,44 @@ class G2SCopyQK {
SQTensor& sQ;
GKTensor& gK;
SKTensor& sK;

TiledCopy tiled_copy;

int gQ_stride;
int sQ_stride;
int gK_stride;
int sK_stride;
int cur_iter;
int num_stage;
};

/// Global-to-shared copy helper for the V operand (skeleton).
///
/// Bundles the source tensor `gV`, the destination tensor `sV`, the tiled-copy
/// object, two strides and the pipeline bookkeeping (`cur_iter`, `num_stage`).
/// At this point only the debug printer does anything; `prologue()` is an
/// empty placeholder and the strides / stage counters are stored but not yet
/// read by any method.
///
/// @tparam GVTensor  type of the source tensor passed to the constructor.
/// @tparam SVTensor  type of the destination tensor passed to the constructor.
/// @tparam TiledCopy type of the tiled-copy object.
template <typename GVTensor, typename SVTensor, typename TiledCopy>
class G2SCopyV {
  public:
    /// @param gV         source tensor; copied into the object.
    /// @param sV         destination tensor; copied into the object.
    /// @param tiled_copy tiled-copy object; kept for later use.
    /// @param gV_stride  stride associated with `gV` (stored, unused so far).
    /// @param sV_stride  stride associated with `sV` (stored, unused so far).
    /// @param num_stage  number of pipeline stages; defaults to 2.
    DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy,
                    int gV_stride, int sV_stride, int num_stage = 2)
        : gV(gV),
          sV(sV),
          // Previously the argument was accepted but never stored, leaving
          // the member default-constructed regardless of what was passed.
          tiled_copy(tiled_copy),
          gV_stride(gV_stride),
          sV_stride(sV_stride),
          cur_iter(0),
          num_stage(num_stage) {}

    /// Debugging aid: prints `gV` followed by a newline, only from the thread
    /// for which `thread0()` is true.
    DEVICE void print_gV() {
        if (thread0()) {
            print(gV);
            print("\n");
        }
    }

    /// Placeholder for the pipelined copy prologue; intentionally empty.
    DEVICE void prologue() {
        // Pipeline the copy operation.
    }

  private:
    // Held by value rather than by reference: make_g2s_v() in this file
    // builds the partitions as function locals and returns this object by
    // value, so reference members would dangle once that function returns.
    // NOTE(review): assumes GVTensor / SVTensor are cheap, non-owning views
    // (as the partitions produced in make_g2s_v presumably are) - confirm.
    GVTensor gV;
    SVTensor sV;
    TiledCopy tiled_copy;
    int gV_stride;
    int sV_stride;
    int cur_iter;
    int num_stage;
};
Expand Down Expand Up @@ -117,6 +150,27 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
return copy_qk;
}

/// Builds the calling thread's global-to-shared copy helper for V.
///
/// Steps, in order of data flow:
///   1. default-construct a `TiledCopy` and take the slice that belongs to
///      the thread identified by `threadIdx.x`;
///   2. wrap `gV_ptr` (via make_gmem_ptr) and `sV_ptr` (via make_smem_ptr) in
///      tensors with default-constructed `GlobalVLayout` / `SharedVLayout`;
///   3. partition the first tensor as copy source and the second as copy
///      destination;
///   4. hand both partitions, the tiled copy and the strides to
///      `detail::G2SCopyV` (its `num_stage` is left at its default).
///
/// @param gV_ptr    pointer to the V data used as copy source.
/// @param sV_ptr    pointer to the buffer used as copy destination.
/// @param gV_stride forwarded unchanged to G2SCopyV.
/// @param sV_stride forwarded unchanged to G2SCopyV.
/// @return the constructed `detail::G2SCopyV` object, by value.
template <typename Element, typename GlobalVLayout, typename SharedVLayout,
          typename TiledCopy>
DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
                       int sV_stride) {
    TiledCopy tiled_copy;

    const int thread_id = threadIdx.x;
    auto thr_copy = tiled_copy.get_thread_slice(thread_id);

    auto global_v = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{});
    auto shared_v = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{});

    // Per-thread views: source side of the copy and destination side.
    auto src_part = thr_copy.partition_S(global_v);
    auto dst_part = thr_copy.partition_D(shared_v);

    // NOTE(review): src_part / dst_part are locals of this function, so
    // G2SCopyV must not keep references to them past the return below -
    // verify against the class definition.
    detail::G2SCopyV v_copier(src_part, dst_part, tiled_copy, gV_stride,
                              sV_stride);

    return v_copier;
}

// template <class G2STiledCopy, class GTensor, class STensor>
// class G2SCopyTile {
// public:
Expand Down
8 changes: 7 additions & 1 deletion benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,13 @@ __global__ void __launch_bounds__(Nthreads)
typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK,
kTK);

g2s_copy_qk.print_q();
auto g2s_copy_v =
make_g2s_v<Element, typename KeTraits::GmemLayoutV,
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);

g2s_copy_qk.print_gQ();
g2s_copy_v.print_gV();
}

} // namespace cutlass_wrapper
Expand Down

0 comments on commit 7a10f6f

Please sign in to comment.