
Commit 29f47eb
fix copy v matrix from global to shared.
KuangjuX committed Dec 30, 2024
1 parent 49e9cf7 commit 29f47eb
Showing 3 changed files with 53 additions and 7 deletions.
46 changes: 43 additions & 3 deletions benchmarks/cpp/flashattention/copy.cuh
@@ -195,15 +195,55 @@ class G2SCopyV {
}

    DEVICE void prologue() {
        // Pipeline the copy operation.
#pragma unroll
        for (int m = 0; m < size<1>(gV); ++m) {
#pragma unroll
            for (int k = 0; k < size<2>(gV); ++k) {
                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
            }
        }

        cute::cp_async_fence();

        gV.data() = gV.data() + gV_stride;
        sV.data() = sV.data() + sV_stride;

        // Wrap the shared-memory pointer back to stage 0 once all
        // `num_stage` ring-buffer stages have been filled.
        if ((cur_iter + 1) % num_stage == 0) {
            sV.data() = sV.data() + (-sV_stride * num_stage);
        }

        cur_iter++;
    }

    DEVICE void body() {
        // Pipeline the copy operation.
#pragma unroll
        for (int m = 0; m < size<1>(gV); ++m) {
#pragma unroll
            for (int k = 0; k < size<2>(gV); ++k) {
                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
            }
        }

        cute::cp_async_fence();

        gV.data() = gV.data() + gV_stride;
        sV.data() = sV.data() + sV_stride;

        // Wrap the shared-memory pointer back to stage 0, as in prologue().
        if ((cur_iter + 1) % num_stage == 0) {
            sV.data() = sV.data() + (-sV_stride * num_stage);
        }

        cur_iter++;
    }

    DEVICE void epilogue() {
        // Copy the final tile; no pointer advance is needed afterwards.
#pragma unroll
        for (int m = 0; m < size<1>(gV); ++m) {
#pragma unroll
            for (int k = 0; k < size<2>(gV); ++k) {
                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
            }
        }
        cute::cp_async_fence();
    }

private:
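The prologue, body, and epilogue above share one pattern: advance gV by gV_stride and sV by sV_stride after each tile, and wrap the shared-memory pointer back to stage 0 every num_stage iterations. A minimal host-side sketch of that ring-buffer arithmetic (plain C++; the strides, stage count, and iteration count are illustrative, not the benchmark's actual values):

    #include <cstdio>

    // Sketch of the stage-pointer arithmetic in G2SCopyV; g_off/s_off
    // stand in for gV.data()/sV.data().
    int main() {
        const int num_stage = 2;     // shared-memory ring-buffer slots
        const int gV_stride = 4096;  // global-memory advance per tile
        const int sV_stride = 1024;  // elements per shared-memory stage
        long g_off = 0, s_off = 0;
        for (int cur_iter = 0; cur_iter < 6; ++cur_iter) {
            // ... cp.async copies from g_off into s_off are issued here ...
            g_off += gV_stride;  // the global pointer only moves forward
            s_off += sV_stride;  // the shared pointer moves to the next stage
            if ((cur_iter + 1) % num_stage == 0)
                s_off -= sV_stride * num_stage;  // wrap back to stage 0
            printf("iter %d: g_off=%ld s_off=%ld\n", cur_iter, g_off, s_off);
        }
        return 0;
    }

With num_stage = 2 the shared offset alternates between 1024 and 0 while the global offset grows monotonically, which is exactly the wrap the new if-blocks implement.
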
10 changes: 8 additions & 2 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -189,6 +189,7 @@ __global__ void __launch_bounds__(Nthreads)
printf("acc0: \n");
print(acc0), print("\n");
}
// scores = dot(q, k)
auto scores =
make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));

@@ -197,10 +198,10 @@

    auto scores_max = make_fragment_like(m_new);

-   // Compute row max.
+   // scores_max = reduce_max(scores, axis=1)
    reduce_max<4, true>(scores, scores_max);

-   // Compute new max vector.
+   // Compute the new partial max values.
    for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
        m_new(ax0) = max(m_new(ax0), scores_max(ax0));
    }
@@ -236,6 +237,10 @@
    // TODO: Understand the following code.
    auto frag = convert_type<Element>(scores);
    auto rP = make_tensor(make_rmem_ptr<Element>(&frag), scores.layout());
    // Why convert the layout?
    auto rP_Aregs =
        make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));

@@ -252,6 +254,10 @@
        cp_async_wait_flash<0>();
        __syncthreads();

        if (n < split_n - 1) {
            // Update the pointer of K.
        }

        s2r_pipeline_v.epilogue(rP_Aregs);
    }

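The reduce_max / m_new hunk above is the online-softmax bookkeeping: reduce each row of the freshly computed score tile to its maximum, then fold that into the running row maximum. A scalar sketch of those two steps (plain C++ with illustrative std::vector shapes; the kernel itself works on per-thread register fragments):

    #include <algorithm>
    #include <vector>

    // scores_max = reduce_max(scores, axis=1); m_new = max(m_new, scores_max)
    void update_running_max(const std::vector<std::vector<float>>& scores,
                            std::vector<float>& m_new) {
        for (size_t r = 0; r < scores.size(); ++r) {
            float row_max = scores[r][0];
            for (float v : scores[r])
                row_max = std::max(row_max, v);      // axis-1 reduction
            m_new[r] = std::max(m_new[r], row_max);  // fold into running max
        }
    }

Keeping a running maximum per row is what lets flash attention compute a numerically stable softmax without ever materializing the full score matrix.
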
4 changes: 2 additions & 2 deletions benchmarks/cpp/flashattention/main.cu
@@ -35,8 +35,8 @@ void run(bool check = true) {
    static constexpr int kWarpPerRow = 4;
    static constexpr int kWarpPerCol = 1;
    static constexpr int kThreads = 128;
-   static constexpr int kStagesQK = 2;
-   static constexpr int kStagesV = 2;
+   static constexpr int kStagesQK = 1;
+   static constexpr int kStagesV = 1;

    static_assert(kK == kTK,
                  "The current implementation requires kTK == K for now.");
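Reducing kStagesQK and kStagesV from 2 to 1 disables double buffering: every tile must be fully resident in shared memory before compute touches it, which matches the unconditional cp_async_wait_flash<0>() (wait for all outstanding copy groups) in cutlass_fa.cuh above. For contrast, a self-contained double-buffering sketch using CUDA's raw pipeline intrinsics rather than the benchmark's CuTe wrappers (kernel name, tile size, and thread count are illustrative):

    #include <cuda_pipeline_primitives.h>

    // Double-buffered copy loop: tile i + 1 is prefetched while tile i is
    // consumed. Launch with 128 threads per block; assumes tiles >= 1.
    __global__ void double_buffer_demo(const float* g, float* out, int tiles) {
        __shared__ float smem[2][128];
        int t = threadIdx.x;

        // Prologue: issue the async copy for tile 0.
        __pipeline_memcpy_async(&smem[0][t], &g[t], sizeof(float));
        __pipeline_commit();

        for (int i = 0; i < tiles; ++i) {
            if (i + 1 < tiles) {
                // Body: prefetch tile i + 1 into the other buffer.
                __pipeline_memcpy_async(&smem[(i + 1) % 2][t],
                                        &g[(i + 1) * 128 + t], sizeof(float));
                __pipeline_commit();
                __pipeline_wait_prior(1);  // tile i done; i + 1 may be in flight
            } else {
                __pipeline_wait_prior(0);  // epilogue: drain everything
            }
            __syncthreads();
            out[i * 128 + t] = smem[i % 2][t] * 2.f;  // consume tile i
            __syncthreads();
        }
    }

With a single stage the prefetch branch disappears and every iteration degenerates to the wait_prior(0) path, trading copy/compute overlap for a smaller shared-memory footprint.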
