diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh
index 2480eee..703513d 100644
--- a/benchmarks/cpp/flashattention/convert.cuh
+++ b/benchmarks/cpp/flashattention/convert.cuh
@@ -48,6 +48,9 @@ DEVICE auto convert_layout_C_Aregs() {
                        get<1>(l), get<1>(get<2>(l)));
 }
 
+/**
+ * @brief Convert a 3d register tensor into a 2d register tensor.
+ */
 template <typename LayoutType>
 DEVICE auto convert_layout_scores(LayoutType layout_s) {
     using namespace cute;
diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
index ced7ae9..8909d0b 100644
--- a/benchmarks/cpp/flashattention/copy.cuh
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -47,6 +47,17 @@ class G2SCopyQK {
         gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride;
     }
 
+    /**
+     * @brief Reset the pointer of the Q tile.
+     *
+     * This function is called when `load_q_once` is false, i.e., when
+     * kTK != kK. In this case, the pointer of Q needs to be restored to the
+     * starting position.
+     *
+     * @param stride The stride of the Q tile along the K dimension.
+     */
+    DEVICE void reset_tile_Q(int stride) { sQ.data() = sQ.data() + (-stride); }
+
     /**
      * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix
      * only needs to be loaded once and does not require repeated loading, while
@@ -480,6 +491,10 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
 
     TiledCopy tiled_copy;
 
+    // if (thread0()) {
+    //     print_latex(tiled_copy);
+    // }
+
     auto loader = tiled_copy.get_thread_slice(tid);
 
     auto gQs = loader.partition_S(gQ);
@@ -490,10 +505,12 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
     int sQ_stride = size(sQ);
     int sK_stride = size(sK);
 
+#ifdef DEBUG
     if (thread0()) {
         printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
               gQ_stride, sQ_stride, gK_stride, sK_stride);
     }
+#endif
 
     detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
                               sQ_stride, gK_stride, sK_stride);
@@ -518,9 +535,11 @@ DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {
 
     int sV_stride = size(sV);
 
+#ifdef DEBUG
     if (thread0()) {
         printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
     }
+#endif
 
     detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);
@@ -545,6 +564,15 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
     auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
     auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);
 
+#ifdef DEBUG
+    if (thread0()) {
+        printf("sQ_Layout: ");
+        print(sQ_layout), print('\n');
+        printf("s2r_copy_q: ");
+        print(s2r_copy_q), print('\n');
+    }
+#endif
+
     auto sQ = s2r_thr_copy_q.partition_S(sQ_);
     auto sK = s2r_thr_copy_k.partition_S(sK_);
@@ -556,6 +584,19 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
     auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
     auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
 
+#ifdef DEBUG
+    if (thread0()) {
+        printf("sQ_: ");
+        print(sQ_), print('\n');
+        printf("sQ: ");
+        print(sQ), print('\n');
+        printf("rQ_copy: ");
+        print(rQ_copy), print('\n');
+        printf("rQ_mma: ");
+        print(rQ_mma), print('\n');
+    }
+#endif
+
     int sQ_stride = size(sQ_);
     int sK_stride = size(sK_);
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
index 6cae18b..532d390 100644
--- a/benchmarks/cpp/flashattention/cutlass_fa.cuh
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -23,6 +23,8 @@ template <...>
     using GmemLayoutQ = Layout<Shape<Int<...>, Int<...>>, Stride<Int<...>, _1>>;
     using GmemLayoutK = Layout<Shape<Int<...>, Int<...>>, Stride<Int<...>, _1>>;
@@ -93,7 +95,9 @@ template <...>
     auto acc0 = get_acc<...>(mma);
     auto acco = get_acc<...>(mma);
 
+#ifdef DEBUG
+    if (thread0()) {
+        printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
+               (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
+        printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n",
+               (int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco));
+    }
+#endif
+
+    /**
+     * In TileFusion, we use
+     * ```cpp
+     * using RegVec = RegTile<...>;
+     * ```
+     * We need to store the reduce results for both the top row and the bottom
+     * row simultaneously.
+     */
+    auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
     auto lse_new = make_fragment_like(m_new);
@@ -165,6 +187,8 @@ __global__ void __launch_bounds__(Nthreads)
     int split_n = kN / kTN;
     for (int n = 0; n < split_n; ++n) {
         clear(acc0);
+
+        // When `load_q_once` is true, the following code is not executed.
         int slice_k = kK / kTK - 1;
         for (int k = 0; k < slice_k; ++k) {
             // Barrier to ensure all data are loaded into shared memory.
@@ -178,6 +202,8 @@ __global__ void __launch_bounds__(Nthreads)
         cp_async_wait_flash<0>();
         __syncthreads();
         g2s_copy_v.prologue();
+        // When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only
+        // once, and `s2r_pipeline_qk.epilogue()` is executed once as well.
         s2r_pipeline_qk.epilogue();
 
         // scores = dot(q, k)
@@ -197,30 +223,46 @@ __global__ void __launch_bounds__(Nthreads)
             m_new(ax0) = max(m_new(ax0), scores_max(ax0));
         }
 
-        auto acco_rowcol =
+        // Currently, `acco` stores the results from the previous iteration's
+        // computation.
+        auto previous_attn_block =
             make_tensor(acco.data(), convert_layout_scores(acco.layout()));
 
-        // Renormalizatio for the previous block.
-        for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
+#ifdef DEBUG
+        if (thread0()) {
+            printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores),
+                   (int)size<1>(scores));
+            printf("previous_attn_block size<0>: %d, size<1>: %d\n",
+                   (int)size<0>(previous_attn_block),
+                   (int)size<1>(previous_attn_block));
+        }
+#endif
+
+        // Renormalization for the previous block.
+        for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) {
+            // Compute `acc_o_scale = exp(m_i - m_ij)`.
             float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
             lse_new(ax0) = lse_new(ax0) * scale;
-            for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
-                acco_rowcol(ax0, ax1) *= scale;
+            // Compute `acc_o = acc_o_scale * acc_o`.
+            for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) {
+                previous_attn_block(ax0, ax1) *= scale;
             }
         }
 
         for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
-            float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
-            lse_new(ax0) = lse_new(ax0) * m_scaled;
+            // Compute `p = exp(qk - m_ij)`.
+            float m_scaled = m_new(ax0) * softmax_scale;
             for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
                 scores(ax0, ax1) = exp(scores(ax0, ax1) * softmax_scale - m_scaled);
             }
         }
 
+        // Compute `l_ij = sum(p)`.
         auto scores_sum = make_fragment_like(lse_new);
         reduce_sum<4>(scores, scores_sum);
 
+        // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`.
         for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
             lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
         }
@@ -274,10 +316,39 @@ __global__ void __launch_bounds__(Nthreads)
              */
            if (load_q_once) {
                g2s_copy_qk.prologue_K();
+            } else {
+                /**
+                 * In this case, we need to reset the pointer of Q to the
+                 * starting position and simultaneously preload the Q and K.
+                 */
+                g2s_copy_qk.reset_tile_Q(kK);
+                g2s_copy_qk.prologue();
             }
         }
 
+        // Compute `acc_o = acc_o + dot(p, v)`.
         s2r_pipeline_v.epilogue(rP_Aregs);
+
+        // Compute `lse_i = m_ij + log(l_i_new)`.
+        for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
+            m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
+        }
+    }
+
+    // Normalize the attention block.
+    auto attn_block =
+        make_tensor(acco.data(), convert_layout_scores(acco.layout()));
+    for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) {
+        // TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i -
+        // lse_i)`.
+
+        // float scale = 1 / lse_new(ax0);
+        float o_scale = exp(m_new(ax0) - lse_new(ax0));
+        // TODO(KuangjuX): Move this code into loop?
+        // lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
+        for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) {
+            attn_block(ax0, ax1) *= o_scale;
+        }
     }
 
     // Store O from registers to shared memory and then to global memory.
diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu
index bf3fb11..8b1cf16 100644
--- a/benchmarks/cpp/flashattention/main.cu
+++ b/benchmarks/cpp/flashattention/main.cu
@@ -4,31 +4,20 @@
 #include "cutlass_fa.cuh"
 #include "util.hpp"
 
+template <const int kM, const int kN, const int kK, const int kP,
+          const int kTM, const int kTN, const int kTK, const int kTP,
+          const int kWarpPerRow, const int kWarpPerCol, const int kStagesQK,
+          const int kStagesV>
 void run(bool check = true) {
     using InType = cutlass::half_t;
     using AccType = cutlass::half_t;
     using OutType = cutlass::half_t;
 
-    static constexpr int kM = 64;
-    static constexpr int kN = 64;
-    static constexpr int kK = 128;
-    static constexpr int kP = 128;
-
-    static constexpr int kTM = 64;
-    static constexpr int kTN = 64;
-    static constexpr int kTK = 128;
-    static constexpr int kTP = 128;
-
+    // Currently `kBatch` is fixed to 1.
     static constexpr int kBatch = 1;
-
-    static constexpr int kWarpPerRow = 1;
-    static constexpr int kWarpPerCol = 1;
     static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32;
-    static constexpr int kStagesQK = 1;
-    static constexpr int kStagesV = 1;
 
-    static_assert(kK == kTK,
-                  "The current implementation requires kTK == K for now.");
+    // static_assert(kK == kTK,
+    //               "The current implementation requires kTK == K for now.");
     static_assert(kP == kTP,
                   "The current implementation requires kTP == P for now.");
@@ -100,7 +89,8 @@ void run(bool check = true) {
     dim3 grid(block_x, block_y, block_z);
     dim3 block(kThreads, 1, 1);
 
-    int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP);
+    int shm_input =
+        (kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV);
     int shm_output = kTM * kTP;
     int shm_size = shm_input < shm_output ? shm_output * sizeof(InType)
                                           : shm_input * sizeof(InType);
@@ -125,4 +115,9 @@ void run(bool check = true) {
     cudaDeviceSynchronize();
 }
 
-int main() { run(); }
+int main() {
+    // run<kM, kN, kK, kP, kTM, kTN, kTK, kTP, kWarpPerRow, kWarpPerCol,
+    //     kStagesQK, kStagesV>();
+    run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>();
+    // run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>();
+}
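For readers following the `m_i`/`lse_i` bookkeeping that the new comments in `cutlass_fa.cuh` refer to, the sketch below restates the per-row online-softmax update in plain C++. It is an illustrative, standalone rendition of the standard FlashAttention recurrence, not code from this patch; the names `scores`, `acc_o`, `m_i`, and `l_i` are chosen to mirror the comments, and `softmax_scale` is assumed to already be folded into `scores`.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step for a single query row and one key block.
// (m_i, l_i) are the running row max and running row sum; acc_o is the
// running (unnormalized) output row. The dot(p, v) accumulation itself is
// omitted; only the rescaling bookkeeping is shown.
void online_softmax_step(std::vector<float>& scores,  // qk for this block
                         std::vector<float>& acc_o,   // running output row
                         float& m_i, float& l_i) {
    // m_ij = max(m_i, rowmax(scores))
    float m_ij = m_i;
    for (float s : scores) m_ij = std::max(m_ij, s);

    // acc_o_scale = exp(m_i - m_ij): rescale the previous blocks' output
    // and the running sum to the new maximum.
    float acc_o_scale = std::exp(m_i - m_ij);
    for (float& o : acc_o) o *= acc_o_scale;
    l_i *= acc_o_scale;

    // p = exp(qk - m_ij) and l_ij = sum(p).
    float l_ij = 0.f;
    for (float& s : scores) {
        s = std::exp(s - m_ij);
        l_ij += s;
    }

    // l_i_new = exp(m_i - m_ij) * l_i + l_ij (the first term is the rescaled
    // l_i computed above).
    l_i += l_ij;
    m_i = m_ij;

    // After the last block: lse_i = m_i + log(l_i), and the output is scaled
    // by o_scale = exp(m_i - lse_i) = 1 / l_i, which corresponds to the final
    // normalization loop over `attn_block` in the kernel.
}
```

Deferring the division by `l_i` to a single final pass, rather than renormalizing after every key block, is what lets the kernel keep `acc_o` unnormalized across the split-N loop.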