Skip to content

Commit

Permalink
Add comments and fix some bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Jan 2, 2025
1 parent 62b5395 commit 3eed106
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 46 deletions.
34 changes: 22 additions & 12 deletions benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -512,18 +512,12 @@ template <typename Element, typename GlobalQLayout, typename SharedQLayout,
typename GlobalKLayout, typename SharedKLayout, typename TiledCopy>
inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
const Element* gK_ptr, Element* sK_ptr,
int gQ_stride, int sQ_stride, int gK_stride,
int sK_stride) {
int gQ_stride, int gK_stride) {
int tid = threadIdx.x;

auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{});
auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});

if (thread0()) {
printf("sQ: \n");
print(sQ), print("\n");
}

auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{});

Expand All @@ -536,6 +530,14 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
auto sQs = loader.partition_D(sQ);
auto sKs = loader.partition_D(sK);

int sQ_stride = size(sQ);
int sK_stride = size(sK);

if (thread0()) {
printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
gQ_stride, sQ_stride, gK_stride, sK_stride);
}

detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
sQ_stride, gK_stride, sK_stride);

Expand All @@ -544,8 +546,7 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,

template <typename Element, typename GlobalVLayout, typename SharedVLayout,
typename TiledCopy>
DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
int sV_stride) {
DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {
int tid = threadIdx.x;

auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{});
Expand All @@ -558,6 +559,12 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride,
auto gVs = loader.partition_S(gV);
auto sVs = loader.partition_D(sV);

int sV_stride = size(sV);

if (thread0()) {
printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
}

detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);

return copy_v;
Expand All @@ -567,7 +574,6 @@ template <typename Element, typename SQLayout, typename SKLayout,
typename RegAcc, typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc,
int sQ_stride, int sK_stride,
SmemCopyAtom copy_atom = SmemCopyAtom{},
TiledMma tiled_mma = TiledMma{}) {
int tid = threadIdx.x;
Expand All @@ -593,6 +599,9 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);

int sQ_stride = size(sQ_);
int sK_stride = size(sK_);

detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma,
rK_copy, acc, s2r_copy_q, s2r_copy_k,
tiled_mma, sQ_stride, sK_stride);
Expand All @@ -603,8 +612,7 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
template <typename Element, typename SVLayout, typename RegAcc,
typename SmemCopyAtom, typename TiledMma>
DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
int sV_stride, SmemCopyAtom copy_atom,
TiledMma tiled_mma) {
SmemCopyAtom copy_atom, TiledMma tiled_mma) {
int tid = threadIdx.x;

auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);
Expand All @@ -619,6 +627,8 @@ DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
auto rV_mma = thr_mma.partition_fragment_B(sV_);
auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);

int sV_stride = size(sV_);

detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v,
tiled_mma, sV_stride);

Expand Down
58 changes: 26 additions & 32 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@ namespace cutlass_wrapper {

using namespace cute;

//
/// @brief
/// @tparam Element_
template <typename Element_, const int kM, const int kN, const int kK,
const int kP, const int kTM, const int kTN, const int kTK,
const int kTP, const int kWarpPerRow, const int kWarpPerCol,
const int kThreads, const int SmemKAtom = 64, const int kSwizzle = 3,
typename Base = AccessBase<Element_>>
struct FATraits : public Base {
// Q: [kM, kK] --> [length, hidden_qk]
// K: [kN, kK] --> [length, hidden_qk]
// V: [kP, kN] --> [length, hidden_v]
// O: [kM, kP] --> [length, hidden_v]
// assert(kM == kN)
using Element = Element_;

// Declare global to shared memory copy layout.
Expand Down Expand Up @@ -59,13 +51,6 @@ struct FATraits : public Base {

static constexpr int kWarps = kThreads / 32;

// Declare MMA Operation: [16, 8, 16] * [1, 2, 1] -> [16, 16, 16]
// Legacy code
// using TiledMma =
// TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
// Layout<Shape<Int<kWarps>, _1, _1>>, Layout<Shape<_1, _2,
// _1>>>;

using TiledMma =
TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
Layout<Shape<Int<kWarpPerRow>, Int<kWarpPerCol>, _1>>,
Expand Down Expand Up @@ -124,34 +109,33 @@ __global__ void __launch_bounds__(Nthreads)
auto g2s_copy_qk = make_g2s_qk<
Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ,
typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK,
typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kK, kTK, kK,
kTK);
typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK);

// Build the copy plan for V from global memory to shared memory.
/**
* The size of the V matrix is [kN, kP], and the size processed in a single
* SM Block is [kN, kTP]. When split along the N dimension, the size is
* [kTN, kTP]. Therefore, the stride for global memory should be set to kTN
* * kP.
*/
auto g2s_copy_v =
make_g2s_v<Element, typename KeTraits::GmemLayoutV,
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN * kP);

auto acc0 = get_acc<kTM, kTN>(mma);
auto acco = get_acc<kTM, kTP>(mma);

auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
auto lse_new = make_fragment_like(m_new);

if (thread0()) {
printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
(int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
}

auto s2r_pipeline_qk =
make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{},
typename KeTraits::SmemLayoutK{}, acc0, kTK, kTK,
typename KeTraits::SmemLayoutK{}, acc0,
typename KeTraits::SmemCopyAtom{}, mma);
s2r_pipeline_qk.print_rQ();

auto s2r_pipeline_v =
make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, kTN,
make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco,
typename KeTraits::SmemCopyAtom{}, mma);

// Issue global to shared memory copy before the main loop.
Expand All @@ -161,10 +145,16 @@ __global__ void __launch_bounds__(Nthreads)
fill(m_new, -INFINITY);
clear(acco);

/**
* Flash Attention performs two-level tiling for each SM Block, splitting
* along the N dimension and the K dimension. The Q matrix is split along
* the K dimension, the V matrix is split along the N dimension, and the K
* matrix is split along both dimensions simultaneously.
*/
int split_n = kN / kTN;
for (int n = 0; n < split_n; ++n) {
clear(acc0);
int slice_k = kK / kTK - 1;
// Pipeline
for (int k = 0; k < slice_k; ++k) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
Expand Down Expand Up @@ -280,14 +270,18 @@ __global__ void __launch_bounds__(Nthreads)
s2r_pipeline_v.epilogue(rP_Aregs);
}

auto acco_f16 = convert_type<Element>(acco);
// Store O from registers to shared memory and then to global memory.
// auto frag2 = convert_type<Element>(acco);
// auto acco_f16 = make_tensor(make_rmem_ptr<Element>(&frag2),
// acco.layout());

store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
typename KeTraits::SmemCopyAtom{}, mma);
__syncthreads();
// store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
// typename KeTraits::SmemCopyAtom{}, mma);
// __syncthreads();

store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
typename KeTraits::SmemLayoutO{}, tiled_copy_g2s);
typename KeTraits::SmemLayoutO{},
typename KeTraits::TiledCopyS2G{});
}

} // namespace cutlass_wrapper
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/flashattention/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ void run(bool check = true) {

static constexpr int kBatch = 1;

static constexpr int kWarpPerRow = 4;
static constexpr int kWarpPerRow = 1;
static constexpr int kWarpPerCol = 1;
static constexpr int kThreads = 128;
static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32;
static constexpr int kStagesQK = 1;
static constexpr int kStagesV = 1;

Expand Down

0 comments on commit 3eed106

Please sign in to comment.