Skip to content

Commit

Permalink
Add some comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Jan 2, 2025
1 parent 489f333 commit 62b5395
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
2 changes: 1 addition & 1 deletion benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class G2SCopyQK {
}
}

DEVICE void next_K_slice(int gK_slice, int gK_stride) {
// Advance the global-memory K tensor to the next K tile: step back one
// stride and jump forward `gK_slice` strides, i.e. a net offset of
// (gK_slice - 1) * gK_stride elements from the current position.
DEVICE void update_tile_K(int gK_slice, int gK_stride) {
    int const offset = gK_slice * gK_stride - gK_stride;
    gK.data() = gK.data() + offset;
}

Expand Down
56 changes: 36 additions & 20 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ __global__ void __launch_bounds__(Nthreads)
fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
Element* dO) {
constexpr float softmax_scale = 1.250000e-01f;
const bool load_q_once = (kTK == kK);

extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
auto* buf = reinterpret_cast<Element*>(buf_);
Expand Down Expand Up @@ -132,12 +133,6 @@ __global__ void __launch_bounds__(Nthreads)
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kN, kTN);

#ifdef DEBUG
g2s_copy_qk.print_gQ();
g2s_copy_v.print_gV();
g2s_copy_qk.print_gQ_data(0);
#endif

auto acc0 = get_acc<kTM, kTN>(mma);
auto acco = get_acc<kTM, kTP>(mma);

Expand Down Expand Up @@ -168,9 +163,9 @@ __global__ void __launch_bounds__(Nthreads)

int split_n = kN / kTN;
for (int n = 0; n < split_n; ++n) {
int split_k = kK / kTK - 1;
int slice_k = kK / kTK - 1;
// Pipeline
for (int k = 0; k < split_k; ++k) {
for (int k = 0; k < slice_k; ++k) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
__syncthreads();
Expand All @@ -184,11 +179,6 @@ __global__ void __launch_bounds__(Nthreads)
g2s_copy_v.prologue();
s2r_pipeline_qk.epilogue();

// Print acc0 data.
if (thread0()) {
printf("acc0: \n");
print(acc0), print("\n");
}
// scores = dot(q, k)
auto scores =
make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
Expand Down Expand Up @@ -241,9 +231,13 @@ __global__ void __launch_bounds__(Nthreads)
auto rP_Aregs =
make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));

// Load V into register and issue MMA.
int split_n = kN / kTN - 1;
for (int n = 0; n < split_n; ++n) {
/**
 * In FractalTensor, the `kTN` dimension is split again. To simplify the
 * current implementation of the pipelined FlashAttention, `tile_n` is
 * hardcoded to 0 at this point.
 */
const int tile_n = 0;
for (int tile_ = 0; tile_ < tile_n; ++tile_) {
// Barrier to ensure all data are loaded into shared memory.
cp_async_wait_flash<0>();
__syncthreads();
Expand All @@ -255,10 +249,32 @@ __global__ void __launch_bounds__(Nthreads)
__syncthreads();

if (n < split_n - 1) {
// Update the pointer of K.
g2s_copy_qk.next_K_slice(kTN, kK);
// TODO(KuangjuX): Assume load q once.
g2s_copy_qk.prologue_K();
/**
* Update K tile because the entire K Block will be processed in a
* single SM Block.
*
* For example, in `TileFusion`:
* ```cpp
* for (int n = 0; n < GIteratorV::sc0; ++n) {
* load_sv(gVs(n), sV);
* for (int k = 0; k < GIteratorQ::sc1; ++k) {
* load_sq(gQs(k), sQ);
* load_sk(gKs(k, n), sK);
* }
* }
* ```
*/
g2s_copy_qk.update_tile_K(kTN, kK);
/**
 * `load_q_once` means that `kK == kTK` at this point, i.e. Q fits in a
 * single shared-memory block and is loaded only once. In this case we
 * only need to update the pointer of K and reload it; the pointer of Q
 * needs no update, because no blocking along the k dimension is
 * executed and Q therefore stays resident in shared memory without
 * being reloaded.
 */
if (load_q_once) {
    g2s_copy_qk.prologue_K();
}
if (load_q_once) {
g2s_copy_qk.prologue_K();
}
}

s2r_pipeline_v.epilogue(rP_Aregs);
Expand Down

0 comments on commit 62b5395

Please sign in to comment.