feat(bench): Add features and fix some bugs for pipeline flashattention. #31

Draft · wants to merge 7 commits into base: master
3 changes: 3 additions & 0 deletions benchmarks/cpp/flashattention/convert.cuh
@@ -48,6 +48,9 @@ DEVICE auto convert_layout_C_Aregs() {
get<1>(l), get<1>(get<2>(l)));
}

/**
* @brief Convert a 3d register tensor into a 2d register tensor.
*/
template <class LayoutType>
DEVICE auto convert_layout_scores(LayoutType layout_s) {
using namespace cute;
41 changes: 41 additions & 0 deletions benchmarks/cpp/flashattention/copy.cuh
@@ -47,6 +47,17 @@ class G2SCopyQK {
gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride;
}

/**
* @brief Reset the pointer of the shared memory Q tensor.
*
* This function is called when `load_q_once` is false, i.e., when
* kTK < kK. In that case, the pointer of Q must be restored to the
* starting position before Q is reloaded in the next iteration.
*
* @param stride The stride in the K dimension.
*/
DEVICE void reset_tile_Q(int stride) { sQ.data() = sQ.data() + (-stride); }

/**
* @brief Preload the K matrix. When `load_q_once` is true, the Q matrix
* only needs to be loaded once and does not require repeated loading, while
@@ -480,6 +491,10 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,

TiledCopy tiled_copy;

#ifdef DEBUG
if (thread0()) {
print_latex(tiled_copy);
}
#endif

auto loader = tiled_copy.get_thread_slice(tid);

auto gQs = loader.partition_S(gQ);
@@ -490,10 +505,12 @@
int sQ_stride = size(sQ);
int sK_stride = size(sK);

#ifdef DEBUG
if (thread0()) {
printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
gQ_stride, sQ_stride, gK_stride, sK_stride);
}
#endif

detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
sQ_stride, gK_stride, sK_stride);
@@ -518,9 +535,11 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {

int sV_stride = size(sV);

#ifdef DEBUG
if (thread0()) {
printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
}
#endif

detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);

@@ -545,6 +564,15 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);

#ifdef DEBUG
if (thread0()) {
printf("sQ_Layout: ");
print(sQ_layout), print('\n');
printf("s2r_copy_q: ");
print(s2r_copy_q), print('\n');
}
#endif

auto sQ = s2r_thr_copy_q.partition_S(sQ_);
auto sK = s2r_thr_copy_k.partition_S(sK_);

@@ -556,6 +584,19 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);

#ifdef DEBUG
if (thread0()) {
printf("sQ_: ");
print(sQ_), print('\n');
printf("sQ: ");
print(sQ), print('\n');
printf("rQ_copy: ");
print(rQ_copy), print('\n');
printf("rQ_mma: ");
print(rQ_mma), print('\n');
}
#endif

int sQ_stride = size(sQ_);
int sK_stride = size(sK_);

87 changes: 79 additions & 8 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -23,6 +23,8 @@ template <typename Element_, const int kM, const int kN, const int kK,
struct FATraits : public Base {
using Element = Element_;

static_assert(kTP == kP, "The current implementation requires kTP == P.");

// Declare global to shared memory copy layout.
using GmemLayoutQ = Layout<Shape<Int<kTM>, Int<kTK>>, Stride<Int<kK>, _1>>;
using GmemLayoutK = Layout<Shape<Int<kTN>, Int<kTK>>, Stride<Int<kK>, _1>>;
@@ -93,7 +95,9 @@ template <typename Element, typename KeTraits, const int kM, const int kN,
__global__ void __launch_bounds__(Nthreads)
fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
Element* dO) {
constexpr float softmax_scale = 1.250000e-01f;
// constexpr float softmax_scale = 1.250000e-01f;
// TODO(KuangjuX): Use a fixed value for easy comparison.
constexpr float softmax_scale = 1.0f;
const bool load_q_once = (kTK == kK);

extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
@@ -137,6 +141,24 @@ __global__ void __launch_bounds__(Nthreads)
auto acc0 = get_acc<kTM, kTN>(mma);
auto acco = get_acc<kTM, kTP>(mma);

#ifdef DEBUG
if (thread0()) {
printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
(int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n",
(int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco));
}
#endif

/**
* In TileFusion, we use
* ```cpp
* using RegVec = RegTile<InType, tl::RowMajor<kAccMs, 2>>;
* ```
* We need to store the reduce results for both the top row and the bottom
* row simultaneously.
*/

auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
auto lse_new = make_fragment_like(m_new);

@@ -165,6 +187,8 @@ __global__ void __launch_bounds__(Nthreads)
int split_n = kN / kTN;
for (int n = 0; n < split_n; ++n) {
clear(acc0);

// When `load_q_once` is true, the following code is not executed.
int slice_k = kK / kTK - 1;
for (int k = 0; k < slice_k; ++k) {
// Barrier to ensure all data are loaded into shared memory.
@@ -178,6 +202,8 @@
cp_async_wait_flash<0>();
__syncthreads();
g2s_copy_v.prologue();
// When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only
// once, and `s2r_pipeline_qk.epilogue()` is executed once as well.
s2r_pipeline_qk.epilogue();

// scores = dot(q, k)
@@ -197,30 +223,46 @@
m_new(ax0) = max(m_new(ax0), scores_max(ax0));
}

auto acco_rowcol =
// Currently, `acco` stores the results from the previous iteration's
// computation.
auto previous_attn_block =
make_tensor(acco.data(), convert_layout_scores(acco.layout()));

// Renormalizatio for the previous block.
for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
#ifdef DEBUG
if (thread0()) {
printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores),
(int)size<1>(scores));
printf("previous_attn_block size<0>: %d, size<1>: %d\n",
(int)size<0>(previous_attn_block),
(int)size<1>(previous_attn_block));
}
#endif

// Renormalization for the previous block.
for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) {
// Compute `acc_o_scale = exp(m_i - m_ij)`
float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
lse_new(ax0) = lse_new(ax0) * scale;
for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
acco_rowcol(ax0, ax1) *= scale;
// Compute `acc_o = acc_o_scale * acc_o`
for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) {
previous_attn_block(ax0, ax1) *= scale;
}
}

for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
lse_new(ax0) = lse_new(ax0) * m_scaled;
// Compute `p = exp(qk - m_ij)`
float m_scaled = m_new(ax0) * softmax_scale;
for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
scores(ax0, ax1) =
exp(scores(ax0, ax1) * softmax_scale - m_scaled);
}
}

// Compute `l_ij = sum(p)`.
auto scores_sum = make_fragment_like(lse_new);
reduce_sum<4>(scores, scores_sum);

// Compute `l_i_new = exp(lse_i - m_ij) + l_ij`.
for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
}
@@ -274,10 +316,39 @@
*/
if (load_q_once) {
g2s_copy_qk.prologue_K();
} else {
/**
* In this case, we need to reset the pointer of Q to the
* starting position and simultaneously preload the Q and K.
*/
g2s_copy_qk.reset_tile_Q(kK);
g2s_copy_qk.prologue();
}
}

// Compute `acc_o = acc_o + dot(p, v)`
s2r_pipeline_v.epilogue(rP_Aregs);

// Compute `lse_i = m_ij + log(l_i_new)`.
for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
Collaborator Author:

In FractalTensor, the update of LSE is performed outside the loop (https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L278), whereas in the pseudocode, the update of LSE is done inside the loop.
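
For reference, a quick sketch of why the two placements agree (my own algebra, not taken from either repo; `m`, `l`, and `lse` denote the running row max, running row sum, and log-sum-exp of a row). Updating the LSE inside the loop,

$$
\mathrm{lse}_j \;=\; m_j + \log\!\big(e^{\,\mathrm{lse}_{j-1} - m_j} + \operatorname{rowsum}(P_j)\big)
\;=\; m_j + \log\!\big(e^{\,m_{j-1} - m_j}\,\ell_{j-1} + \operatorname{rowsum}(P_j)\big)
\;=\; m_j + \log \ell_j,
$$

is algebraically the same as keeping only the running sum $\ell_j = e^{\,m_{j-1} - m_j}\,\ell_{j-1} + \operatorname{rowsum}(P_j)$ and forming $\mathrm{lse} = m + \log \ell$ once after the loop; the out-of-loop variant simply saves one `exp`/`log` pair per K/V tile.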

}
}

// Normalize the attention block.
auto attn_block =
make_tensor(acco.data(), convert_layout_scores(acco.layout()));
for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) {
// TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i -
// lse_i)`.

// float scale = 1 / lse_new(ax0);
Collaborator Author:

In the pseudocode, `o_scale = exp(m_i - lse_i)` is used to rescale the final result, while in FractalTensor the operation is performed using `1 / lse_new(ax0)` (https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L277).

Collaborator:

I would like to give quick answers to your two questions and can provide detailed explanations of the reasoning. First of all, I have carefully derived it, and the pseudocode is correct; I believe you can follow it. The normalization factor is outside the loop intentionally, to reduce computational complexity.
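
A quick check of the algebra behind this (my own sketch; it ignores `softmax_scale`, which is fixed to 1.0f in this draft): writing $\ell_i$ for the running row sum and $\mathrm{lse}_i = m_i + \log \ell_i$, the two rescalings coincide,

$$
e^{\,m_i - \mathrm{lse}_i} \;=\; e^{\,m_i - (m_i + \log \ell_i)} \;=\; e^{-\log \ell_i} \;=\; \frac{1}{\ell_i},
$$

so `o_scale = exp(m_new - lse_new)` in the pseudocode equals FractalTensor's `1 / l`, provided the reciprocal is taken of the running sum $\ell_i$ rather than of the already-updated LSE value.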

float o_scale = exp(m_new(ax0) - lse_new(ax0));
// TODO(KuangjuX): Move this code into the loop?
// lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
Collaborator Author:

I moved the update of LSE inside the loop.

for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) {
attn_block(ax0, ax1) *= o_scale;
}
}

// Store O from registers to shared memory and then to global memory.
33 changes: 14 additions & 19 deletions benchmarks/cpp/flashattention/main.cu
@@ -4,31 +4,20 @@
#include "cutlass_fa.cuh"
#include "util.hpp"

template <const int kM, const int kN, const int kK, const int kP, const int kTM,
const int kTN, const int kTK, const int kTP, const int kWarpPerRow,
const int kWarpPerCol, const int kStagesQK, const int kStagesV>
void run(bool check = true) {
using InType = cutlass::half_t;
using AccType = cutlass::half_t;
using OutType = cutlass::half_t;

static constexpr int kM = 64;
static constexpr int kN = 64;
static constexpr int kK = 128;
static constexpr int kP = 128;

static constexpr int kTM = 64;
static constexpr int kTN = 64;
static constexpr int kTK = 128;
static constexpr int kTP = 128;

// Currently `kBatch` is fixed to 1.
static constexpr int kBatch = 1;

static constexpr int kWarpPerRow = 1;
static constexpr int kWarpPerCol = 1;
static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32;
static constexpr int kStagesQK = 1;
static constexpr int kStagesV = 1;

static_assert(kK == kTK,
"The current implementation requires kTK == K for now.");
// static_assert(kK == kTK,
// "The current implementation requires kTK == K for now.");
static_assert(kP == kTP,
"The current implementation requires kTP == P for now.");

@@ -100,7 +89,8 @@ void run(bool check = true) {
dim3 grid(block_x, block_y, block_z);
dim3 block(kThreads, 1, 1);

int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP);
int shm_input =
(kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV);
int shm_output = kTM * kTP;
int shm_size = shm_input < shm_output ? shm_output * sizeof(InType)
: shm_input * sizeof(InType);
@@ -125,4 +115,9 @@ void run(bool check = true) {
cudaDeviceSynchronize();
}

int main() { run(); }
int main() {
// <kM, kN, kK, kP, kTM, kTN, kTK, kTP, kWarpPerRow, kWarpPerCol, kStagesQK,
// kStagesV>
run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>();
// run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>();
}