diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
index b8b81b5..d4d3b0c 100644
--- a/benchmarks/cpp/flashattention/copy.cuh
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -340,12 +340,12 @@ class S2RPipelineQK {
 
 template <typename SVTensor, typename RVMmaView, typename RVCopyView,
           typename RegAcc, typename TiledCopy, typename TiledMma>
-class S2RPipelineQK {
+class S2RPipelineV {
   public:
-    DEVICE S2RPipelineQK(SVTensor& sV, RVMmaView& rV_mma_view,
-                         RVCopyView& rV_copy_view, RegAcc& acc,
-                         TiledCopy tiled_copy, TiledMma, tiled_mma,
-                         int sV_stride, int num_stage = 2)
+    DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view,
+                        RVCopyView& rV_copy_view, RegAcc& acc,
+                        TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride,
+                        int num_stage = 2)
         : sV(sV),
           rV_mma_view(rV_mma_view),
           rV_copy_view(rV_copy_view),
@@ -366,6 +366,8 @@ class S2RPipelineQK {
             cute::copy(tiled_copy, sV(_, _, i + 1),
                        rV_copy_view(_, _, i + 1));
         }
+        // TODO: Why do we need to use value(_, _, cur_iter *
+        // size<2>(rV_mma_view) + i)?
         cute::gemm(tiled_mma,
                    value(_, _, cur_iter * size<2>(rV_mma_view) + i),
                    rV_mma_view(_, _, i), acc);
@@ -375,6 +377,55 @@ class S2RPipelineQK {
         cur_iter++;
     }
 
+    template <typename RegValue>
+    DEVICE void body(RegValue& value) {
+        cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rV_mma_view); ++i) {
+            if (i < size<2>(rV_mma_view) - 1) {
+                cute::copy(tiled_copy, sV(_, _, i + 1),
+                           rV_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma,
+                       value(_, _, cur_iter * size<2>(rV_mma_view) + i),
+                       rV_mma_view(_, _, i), acc);
+        }
+
+        sV.data() = sV.data() + sV_stride;
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sv++;
+    }
+
+    template <typename RegValue>
+    DEVICE void epilogue(RegValue& value) {
+        cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rV_mma_view); ++i) {
+            if (i < size<2>(rV_mma_view) - 1) {
+                cute::copy(tiled_copy, sV(_, _, i + 1),
+                           rV_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma,
+                       value(_, _, cur_iter * size<2>(rV_mma_view) + i),
+                       rV_mma_view(_, _, i), acc);
+        }
+
+        sV.data() = sV.data() + (-sV_stride * cur_iter_sv);
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sv = 0;
+    }
+
   private:
     SVTensor& sV;
     RVMmaView& rV_mma_view;
@@ -386,7 +437,7 @@ class S2RPipelineQK {
     int num_stage;
     int cur_iter;
     int cur_iter_sv;
-}
+};
 
 }  // namespace detail
 
@@ -402,10 +453,8 @@ inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
     auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});
 
     if (thread0()) {
-        printf("gQ: \n");
-        print(gQ), print("\n");
-        printf("size<0>(gQ): %d, size<1>(gQ): %d\n", (int)size<0>(gQ),
-               (int)size<1>(gQ));
+        printf("sQ: \n");
+        print(sQ), print("\n");
     }
 
     auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
@@ -484,10 +533,11 @@ DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
     return s2r_pipeline_qk;
 }
 
-template <typename Element, typename SVLayout, typename SmemCopyAtom,
-          typename TiledMma>
-DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride,
-                       SmemCopyAtom copy_atom, TiledMma tiled_mma) {
+template <typename Element, typename SVLayout, typename RegAcc,
+          typename SmemCopyAtom, typename TiledMma>
+DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
+                       int sV_stride, SmemCopyAtom copy_atom,
+                       TiledMma tiled_mma) {
     int tid = threadIdx.x;
 
     auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);
@@ -501,10 +551,50 @@ DEVICE void make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, int sV_stride,
     auto rV_mma = thr_mma.partition_fragment_B(sV_);
     auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);
+
+    detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v,
+                                        tiled_mma, sV_stride);
+
+    return s2r_pipeline_v;
 }
 
-}  // namespace cutlass_wrapper
-}  // namespace benchmarks
+template <typename Element, typename SOLayout, typename RegO,
+          typename SmemCopyAtom, typename TiledMma>
+DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o,
+                        SmemCopyAtom copy_atom, TiledMma tiled_mma) {
+    auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);
+
+    auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma);
+    auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x);
+
+    auto sOs = r2s_thr_copy_o.partition_D(sO);
+    auto rO_copy_view = r2s_thr_copy_o.retile_S(o);
+
+    cute::copy(r2s_copy_o, rO_copy_view, sOs);
+}
+
+template <typename Element, typename GOLayout, typename SOLayout,
+          typename TiledCopy>
+DEVICE auto store_s2g_o(Element* gO_ptr, const Element* sO_ptr,
+                        GOLayout gO_layout, SOLayout sO_layout,
+                        TiledCopy tiled_copy) {
+    auto gO = make_tensor(make_gmem_ptr(gO_ptr), gO_layout);
+    auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);
+
+    auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+
+    auto gO_partition = thr_copy.partition_D(gO);
+    auto sO_partition = thr_copy.partition_S(sO);
+
+#pragma unroll
+    for (int m = 0; m < size<1>(gO_partition); ++m) {
+#pragma unroll
+        for (int n = 0; n < size<2>(gO_partition); ++n) {
+            cute::copy(tiled_copy, sO_partition(_, m, n),
+                       gO_partition(_, m, n));
+        }
+    }
+}
 
 }  // namespace cutlass_wrapper
-}  // namespace benchmarks
\ No newline at end of file
+}  // namespace benchmarks
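
Note on S2RPipelineV: the TODO in the gemm loop asks why the A operand is
indexed as value(_, _, cur_iter * size<2>(rV_mma_view) + i). The fragments of
P passed in as `value` span every V chunk the pipeline will consume, while
rV_mma_view only holds the chunk currently staged in shared memory; cur_iter
counts consumed chunks, so the offset keeps the P slice aligned with the
staged V slice. The standalone sketch below models only the circular-buffer
pointer arithmetic of body() on plain integers (num_stage, sv_stride, and the
iteration count are made-up values, not taken from the kernel):

    #include <cassert>

    int main() {
        const int num_stage = 2;   // double-buffered shared-memory stages
        const int sv_stride = 64;  // elements between consecutive stages
        int offset = 0;            // stands in for sV.data() minus the base

        for (int cur_iter = 0; cur_iter < 8; ++cur_iter) {
            // body(): advance to the next stage ...
            offset += sv_stride;
            // ... and wrap back once all stages have been used, so the
            // pointer cycles 64, 0, 64, 0, ... over the circular buffer.
            if ((cur_iter + 1) % num_stage == 0)
                offset -= sv_stride * num_stage;
            assert(offset == ((cur_iter + 1) % num_stage) * sv_stride);
        }
        return 0;
    }

epilogue() performs the same walk but then rewinds by sV_stride * cur_iter_sv
and resets cur_iter_sv, returning sV to the start of the tile for the next
outer iteration.
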
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
index ae9661a..b6bd8c3 100644
--- a/benchmarks/cpp/flashattention/cutlass_fa.cuh
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -12,6 +12,7 @@
 
 namespace benchmarks {
 namespace cutlass_wrapper {
+
 using namespace cute;
 
 //
@@ -113,7 +114,7 @@ __global__ void __launch_bounds__(Nthreads)
     Element* sQ_ptr = reinterpret_cast<Element*>(buf);
     Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK;
     Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK;
-    // Element* sO_ptr = sQ_ptr;
+    Element* sO_ptr = sQ_ptr;
 
     typename KeTraits::TiledMma mma;
     typename KeTraits::TiledCopyG2S tiled_copy_g2s;
@@ -154,6 +155,10 @@ __global__ void __launch_bounds__(Nthreads)
         typename KeTraits::SmemCopyAtom{}, mma);
     s2r_pipeline_qk.print_rQ();
 
+    auto s2r_pipeline_v =
+        make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, kTN,
+                   typename KeTraits::SmemCopyAtom{}, mma);
+
     // Issue global to shared memory copy before the main loop.
     g2s_copy_qk.prologue();
 
@@ -187,10 +192,10 @@ __global__ void __launch_bounds__(Nthreads)
         auto scores =
             make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
 
-        Tensor m_old = make_fragment_like(m_new);
+        auto m_old = make_fragment_like(m_new);
         copy(m_new, m_old);
 
-        Tensor scores_max = make_fragment_like(m_new);
+        auto scores_max = make_fragment_like(m_new);
 
         // Compute row max.
         reduce_max<4, true>(scores, scores_max);
@@ -212,6 +217,28 @@ __global__ void __launch_bounds__(Nthreads)
             }
         }
 
+        for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
+            float m_scaled = m_new(ax0) * softmax_scale;
+            lse_new(ax0) *= exp(m_old(ax0) * softmax_scale - m_scaled);
+            for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
+                scores(ax0, ax1) =
+                    exp(scores(ax0, ax1) * softmax_scale - m_scaled);
+            }
+        }
+
+        auto scores_sum = make_fragment_like(lse_new);
+        reduce_sum<4>(scores, scores_sum);
+
+        for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
+            lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
+        }
+
+        // TODO: Understand the following code.
+        auto frag = convert_type<Element>(scores);
+        auto rP = make_tensor(make_rmem_ptr<Element>(&frag), scores.layout());
+        auto rP_Aregs =
+            make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));
+
         // Load V into register and issue MMA.
         int split_n = kN / kTN - 1;
         for (int n = 0; n < split_n; ++n) {
@@ -219,8 +246,23 @@ __global__ void __launch_bounds__(Nthreads)
             cp_async_wait_flash<0>();
             __syncthreads();
             g2s_copy_v.body();
+            s2r_pipeline_v.body(rP_Aregs);
         }
+
+        cp_async_wait_flash<0>();
+        __syncthreads();
+
+        s2r_pipeline_v.epilogue(rP_Aregs);
     }
+
+    auto acco_f16 = convert_type<Element>(acco);
+
+    store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
+                typename KeTraits::SmemCopyAtom{}, mma);
+    __syncthreads();
+
+    store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
+                typename KeTraits::SmemLayoutO{}, tiled_copy_g2s);
 }
 
 }  // namespace cutlass_wrapper
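
Note on the rescaling block added above: this is the standard online-softmax
update. m_new holds the running row maximum, so the previously accumulated
row sum lse_new is rescaled by exp((m_old - m_new) * softmax_scale) and the
fresh scores are exponentiated against the new maximum before being summed
in. The block under the remaining TODO then converts the fp32 scores to the
MMA input type and reinterprets the fragment with an A-operand register
layout, so that P can be fed straight into the P*V gemm without a round trip
through shared memory. A scalar model of the update for a single row (the
tile width, softmax_scale, and the running values are made-up numbers):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float softmax_scale = 0.125f;
        float scores[4] = {1.f, 2.f, 3.f, 4.f};  // one row of Q * K^T
        float m_old = 2.5f;                      // running row max
        float lse_old = 1.7f;                    // running row sum

        // New running maximum over the old tiles and this tile.
        float m_new = std::max(m_old, *std::max_element(scores, scores + 4));

        // Rescale the old sum against the new maximum, then exponentiate
        // this tile's scores, mirroring m_scaled = m_new * softmax_scale.
        float m_scaled = m_new * softmax_scale;
        float lse_new = lse_old * std::exp(m_old * softmax_scale - m_scaled);
        for (float& s : scores) s = std::exp(s * softmax_scale - m_scaled);
        for (float s : scores) lse_new += s;  // reduce_sum over the row

        std::printf("m_new = %g, lse_new = %g\n", m_new, lse_new);
        return 0;
    }
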
diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh
index 08a4c1d..f783b98 100644
--- a/benchmarks/utils/cpp/cutlass/copy.cuh
+++ b/benchmarks/utils/cpp/cutlass/copy.cuh
@@ -54,12 +54,6 @@ DEVICE void copy_tile_g2s(const Element* src_data, Element* dst_data,
 
     auto src = loader.partition_S(gtile);
     auto dst = loader.partition_D(stile);
-
-#pragma unroll
-    for (int i = 0; i < int(size<1>(src)); ++i)
-#pragma unroll
-        for (int j = 0; j < int(size<2>(src)); ++j)
-            cute::copy(tiled_copy, src(_, i, j), dst(_, i, j));
 }
 
 // Copy a tensor from shared memory to global memory
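
Note on the copy_tile_g2s hunk: after this change the function still
partitions src and dst but no longer issues any copy. If the loops were
removed because cute::copy can handle the partitioned tensors directly, the
single-call form would look like the sketch below (the helper name is made
up, and this is not part of the patch):

    #include <cute/tensor.hpp>

    // Equivalent to the removed nested loop: cute::copy peels the M and N
    // modes of the partitioned tensors internally and issues one copy per
    // (CPY) slice.
    template <typename TiledCopy, typename SrcTensor, typename DstTensor>
    __device__ void copy_partitioned(TiledCopy const& tiled_copy,
                                     SrcTensor const& src, DstTensor& dst) {
        cute::copy(tiled_copy, src, dst);
    }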