diff --git a/.vscode/settings.json b/.vscode/settings.json
index 81386e8..8a6e4a8 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,20 +1,19 @@
 {
-    "gotoSymbolStack.currentStackPosition": 0,
-    "gotoSymbolStack.maxStackPosition": 0,
-    "gotoSymbolStack.filePositionInfo": [],
-    "files.associations": {
-        "*.tcc": "cpp",
-        "optional": "cpp",
-        "ratio": "cpp",
-        "system_error": "cpp",
-        "array": "cpp",
-        "functional": "cpp",
-        "tuple": "cpp",
-        "type_traits": "cpp",
-        "utility": "cpp",
-        "variant": "cpp",
-        "compare": "cpp",
-        "concepts": "cpp",
-        "random": "cpp"
-    }
+    "files.associations": {
+        "array": "cpp",
+        "string": "cpp",
+        "string_view": "cpp",
+        "span": "cpp",
+        "bitset": "cpp",
+        "initializer_list": "cpp",
+        "utility": "cpp",
+        "*.tcc": "cpp",
+        "chrono": "cpp",
+        "random": "cpp",
+        "limits": "cpp",
+        "semaphore": "cpp"
+    },
+    "gotoSymbolStack.currentStackPosition": 0,
+    "gotoSymbolStack.maxStackPosition": 0,
+    "gotoSymbolStack.filePositionInfo": []
 }
diff --git a/benchmarks/cpp/flashattention/CMakeLists.txt b/benchmarks/cpp/flashattention/CMakeLists.txt
new file mode 100644
index 0000000..2b73af9
--- /dev/null
+++ b/benchmarks/cpp/flashattention/CMakeLists.txt
@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
+# MIT License.
+# --------------------------------------------------------------------------
+
+cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
+project(flash_attention_bench LANGUAGES C CXX CUDA)
+
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
+    "${PROJECT_SOURCE_DIR}/../../../cmake")
+set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party")
+
+include(generic)
+
+include_directories("${PROJECT_SOURCE_DIR}/../../../include")
+include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp")
+include_directories("${THIRD_PARTY_DIR}/cutlass/include")
+
+add_executable(flash_attn main.cu)
diff --git a/benchmarks/cpp/flashattention/Makefile b/benchmarks/cpp/flashattention/Makefile
new file mode 100644
index 0000000..2b59221
--- /dev/null
+++ b/benchmarks/cpp/flashattention/Makefile
@@ -0,0 +1,16 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+
+BUILD_DIR := build
+
+.PHONY: build clean
+
+build:
+	@mkdir -p $(BUILD_DIR)
+	@cd $(BUILD_DIR) && cmake .. && make -j$(shell nproc)
+
+clean:
+	@rm -rf $(BUILD_DIR)
diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh
new file mode 100644
index 0000000..2480eee
--- /dev/null
+++ b/benchmarks/cpp/flashattention/convert.cuh
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
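+//
+// convert.cuh: register-level helpers used by the kernel -- convert_type()
+// changes the element type of a register fragment, and the convert_layout_*()
+// functions re-view MMA accumulator layouts (e.g. into the A-operand register
+// layout expected by the second GEMM) without touching the underlying data.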
+ +#pragma once + +#include "cuda_utils.cuh" + +#include +#include +#include + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +template +CUTE_DEVICE auto convert_type(cute::Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using namespace cute; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + auto l = logical_divide(rowcol_layout, + Shape>>{}); + + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), + get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), get<1>(get<1>(get<1>(l)))); +} + +DEVICE auto convert_layout_C_Aregs() { + using namespace cute; + auto layout_s = Layout, _2, _16>>{}; + auto l = logical_divide(layout_s, Shape{}); + + return make_layout( + make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))), + get<1>(l), get<1>(get<2>(l))); +} + +template +DEVICE auto convert_layout_scores(LayoutType layout_s) { + using namespace cute; + static_assert(decltype(size<0>(layout_s))::value == 4); + static_assert(decltype(rank(layout_s))::value == 3); + + auto l = logical_divide(layout_s, Shape<_2>{}); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), + make_layout(get<0>(get<0>(l)), get<2>(l))); +} + +template +DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) { + using namespace cute; + + auto l = logical_divide(layout_s, Shape>{}); + return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l))); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh new file mode 100644 index 0000000..ced7ae9 --- /dev/null +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -0,0 +1,634 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "cuda_utils.cuh" + +#include +#include + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +namespace detail { + +template +class G2SCopyQK { + public: + DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, + TiledCopy tiled_copy, int gQ_stride, int sQ_stride, + int gK_stride, int sK_stride, int num_stage = 2) + : gQ(gQ), + sQ(sQ), + gK(gK), + sK(sK), + gQ_stride(gQ_stride), + sQ_stride(sQ_stride), + gK_stride(gK_stride), + sK_stride(sK_stride), + cur_iter(0), + cur_iter_sk(0), + num_stage(num_stage) {} + + /** + * @brief Update the pointer of the global K tensor. + * + * Since the K matrix is split along both the n and k dimensions, the + * pointer offset for the K matrix needs to be updated to the next kTN * kK + * position during the next n dimension iteration. + * + * @param gK_slice The stride in N dimension. + * @param gK_stride The stride in K dimension. + */ + DEVICE void update_tile_K(int gK_slice, int gK_stride) { + gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride; + } + + /** + * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix + * only needs to be loaded once and does not require repeated loading, while + * the K matrix needs to be updated and loaded. 
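+     *
+     * Concretely, as used in fa_kernel: `load_q_once` holds when kTK == kK,
+     * so the whole Q tile stays resident in shared memory across the outer
+     * loop over N and only K has to be streamed in again; prologue_K() issues
+     * that copy, advances the global K pointer, and wraps the shared-memory
+     * K buffer every `num_stage` stages.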
+ */ + DEVICE void prologue_K() { +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter_sk + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter_sk++; + } + + DEVICE void prologue() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + // Circlically read SMEM Buffer + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void body() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void epilogue() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + } + + private: + GQTensor& gQ; + SQTensor& sQ; + GKTensor& gK; + SKTensor& sK; + TiledCopy tiled_copy; + int gQ_stride; + int sQ_stride; + int gK_stride; + int sK_stride; + int cur_iter; + int cur_iter_sk; + int num_stage; +}; + +template +class G2SCopyV { + public: + DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy, + int gV_stride, int sV_stride, int num_stage = 2) + : gV(gV), + sV(sV), + gV_stride(gV_stride), + sV_stride(sV_stride), + cur_iter(0), + num_stage(num_stage) {} + + DEVICE void prologue() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + + cute::cp_async_fence(); + gV.data() = gV.data() + gV_stride; + sV.data() = sV.data() + sV_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void body() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + + cute::cp_async_fence(); + + gV.data() = gV.data() + gV_stride; + sV.data() = 
sV.data() + sV_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void epilogue() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + cute::cp_async_fence(); + } + + private: + GVTensor& gV; + SVTensor& sV; + TiledCopy tiled_copy; + int gV_stride; + int sV_stride; + int cur_iter; + int num_stage; +}; + +template +class S2RPipelineQK { + public: + DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view, + RQCopyView& rQ_copy_view, SKTensor& sK, + RKMmaView& rK_mma_view, RKCopyView& rK_copy_view, + RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, + TiledMma tiled_mma, int sQ_stride, int sK_stride, + int num_stage = 2) + : sQ(sQ), + rQ_mma_view(rQ_mma_view), + rQ_copy_view(rQ_copy_view), + sK(sK), + rK_mma_view(rK_mma_view), + rK_copy_view(rK_copy_view), + acc(acc), + copy_q(copy_q), + copy_k(copy_k), + tiled_mma(tiled_mma), + sQ_stride(sQ_stride), + sK_stride(sK_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sq(0) {} + + DEVICE void prologue() { + cur_iter = 0; + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + cur_iter++; + } + + DEVICE void body() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq++; + } + + DEVICE void epilogue() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + + sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq); + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq = 0; + } + + private: + SQTensor& sQ; + RQMmaView& rQ_mma_view; + RQCopyView& rQ_copy_view; + SKTensor& sK; + RKMmaView& rK_mma_view; + RKCopyView& rK_copy_view; + RAccTensor& acc; + TiledCopyQ copy_q; + TiledCopyK copy_k; + TiledMma tiled_mma; + int sQ_stride; + int sK_stride; + int num_stage; + int cur_iter; + int cur_iter_sq; +}; + +template +class S2RPipelineV { + public: + DEVICE S2RPipelineV(SVTensor& sV, 
RVMmaView& rV_mma_view, + RVCopyView& rV_copy_view, RegAcc& acc, + TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride, + int num_stage = 2) + : sV(sV), + rV_mma_view(rV_mma_view), + rV_copy_view(rV_copy_view), + acc(acc), + tiled_copy(tiled_copy), + sV_stride(sV_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sv(0) {} + + template + DEVICE void prologue(RegValue& value) { + cur_iter = 0; + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + // TODO(KuangjuX): Understand this code. Why do we need to use + // `value(_, _, cur_iter * size<2>(rV_mma_view) + i)`? + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + cur_iter++; + } + + template + DEVICE void body(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + cur_iter_sv++; + } + + template + DEVICE void epilogue(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + (-sV_stride * cur_iter_sv); + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + cur_iter_sv = 0; + } + + private: + SVTensor& sV; + RVMmaView& rV_mma_view; + RVCopyView& rV_copy_view; + RegAcc& acc; + TiledCopy tiled_copy; + TiledMma tiled_mma; + int sV_stride; + int num_stage; + int cur_iter; + int cur_iter_sv; +}; + +} // namespace detail + +template +inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr, + const Element* gK_ptr, Element* sK_ptr, + int gQ_stride, int gK_stride) { + int tid = threadIdx.x; + + auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{}); + auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{}); + + auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{}); + auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{}); + + TiledCopy tiled_copy; + + auto loader = tiled_copy.get_thread_slice(tid); + + auto gQs = loader.partition_S(gQ); + auto gKs = loader.partition_S(gK); + auto sQs = loader.partition_D(sQ); + auto sKs = loader.partition_D(sK); + + int sQ_stride = size(sQ); + int sK_stride = size(sK); + + if (thread0()) { + printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n", + gQ_stride, sQ_stride, gK_stride, sK_stride); + } + + detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride, + sQ_stride, gK_stride, sK_stride); + + return copy_qk; +} + +template +DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) { + int tid = threadIdx.x; + + auto gV = 
make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{}); + auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{}); + + TiledCopy tiled_copy; + + auto loader = tiled_copy.get_thread_slice(tid); + + auto gVs = loader.partition_S(gV); + auto sVs = loader.partition_D(sV); + + int sV_stride = size(sV); + + if (thread0()) { + printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride); + } + + detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); + + return copy_v; +} + +template +DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, + SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc, + SmemCopyAtom copy_atom = SmemCopyAtom{}, + TiledMma tiled_mma = TiledMma{}) { + int tid = threadIdx.x; + + auto sQ_ = make_tensor(make_smem_ptr(sQ_ptr), sQ_layout); + auto sK_ = make_tensor(make_smem_ptr(sK_ptr), sK_layout); + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma); + auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid); + auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid); + + auto sQ = s2r_thr_copy_q.partition_S(sQ_); + auto sK = s2r_thr_copy_k.partition_S(sK_); + + // Thread partition for mma. + auto rQ_mma = thr_mma.partition_fragment_A(sQ_); + auto rK_mma = thr_mma.partition_fragment_B(sK_); + + // Thread partition for shared to register copy. + auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma); + auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma); + + int sQ_stride = size(sQ_); + int sK_stride = size(sK_); + + detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma, + rK_copy, acc, s2r_copy_q, s2r_copy_k, + tiled_mma, sQ_stride, sK_stride); + + return s2r_pipeline_qk; +} + +template +DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc, + SmemCopyAtom copy_atom, TiledMma tiled_mma) { + int tid = threadIdx.x; + + auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout); + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid); + + auto sV = s2r_thr_copy_v.partition_S(sV_); + + auto rV_mma = thr_mma.partition_fragment_B(sV_); + auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma); + + int sV_stride = size(sV_); + + detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v, + tiled_mma, sV_stride); + + return s2r_pipeline_v; +} + +template +DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o, + SmemCopyAtom copy_atom, TiledMma tiled_mma) { + auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); + + auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma); + auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x); + + auto src = r2s_thr_copy_o.retile_S(o); + auto dst = r2s_thr_copy_o.partition_D(sO); + + cute::copy(r2s_copy_o, src, dst); +} + +template +DEVICE auto store_s2g_o(Element* gO_ptr, const Element* sO_ptr, + GOLayout gO_layout, SOLayout sO_layout, + TiledCopy tiled_copy) { + auto gO = make_tensor(make_gmem_ptr(gO_ptr), gO_layout); + auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); + + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + auto gO_partition = thr_copy.partition_D(gO); + auto sO_partition = thr_copy.partition_S(sO); + +#pragma unroll + for (int m = 0; m < size<1>(gO_partition); ++m) { +#pragma unroll + for (int n = 0; n < size<2>(gO_partition); ++n) { + cute::copy(tiled_copy, sO_partition(_, m, n), + gO_partition(_, 
m, n)); + } + } +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh new file mode 100644 index 0000000..6cae18b --- /dev/null +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "convert.cuh" +#include "copy.cuh" +#include "cuda_utils.cuh" +#include "cutlass/copy.cuh" +#include "cutlass/traits_base.cuh" +#include "reduce.cuh" + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +template > +struct FATraits : public Base { + using Element = Element_; + + // Declare global to shared memory copy layout. + using GmemLayoutQ = Layout, Int>, Stride, _1>>; + using GmemLayoutK = Layout, Int>, Stride, _1>>; + using GmemLayoutV = Layout, Int>, Stride, _1>>; + using GmemLayoutO = Layout, Int>, Stride, _1>>; + + static constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32; + + /** + * Define the atomic layout of shared memory, which is the smallest + * configuration unit of shared memory. Larger shapes are tiled based on the + * atomic layout. + */ + using SmemLayoutAtom = decltype(composition( + Swizzle{}, + Layout>, Stride, _1>>{})); + + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + + /** + * In the Ampere architecture, loading from shared memory to register memory + * requires the use of the `ldmatrix` instruction, while storing from + * register memory to shared memory does not have hardware support and uses + * a default copy instead.” + */ + using LoadS2RCopyAtom = Copy_Atom; + using StoreR2SCopyAtom = Copy_Atom; + + static constexpr int kWarps = kThreads / 32; + + using TiledMma = + TiledMMA, + Layout, Int, _1>>, + Tile, Int<16 * kWarpPerCol>, _16>>; + +#ifdef CP_ASYNC_SM80_ENABLED + // for Ampere + using CopyInstG2S = + Copy_Atom, Element>; +#else + using CopyInstG2S = Copy_Atom; +#endif + + // TODO(KuangjuX): Understand this configuration. 
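+    //
+    // A plausible reading (an assumption, not a verified description of this
+    // exact layout): each thread moves one 128-bit (8 x half) vector per
+    // copy, and the kThreads threads are arranged row-major as
+    // (kThreads / vectors_per_row) x vectors_per_row, so a single
+    // TiledCopyG2S step sweeps a contiguous row-major slab of the tile --
+    // the access pattern the cp.async (SM80) path above is built for.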
+ using GmemCopyLayoutAtom = + Layout, Int>, + Stride, _1>>; + + using TiledCopyG2S = decltype(make_tiled_copy( + CopyInstG2S{}, GmemCopyLayoutAtom{}, Layout>{})); + + using TiledCopyS2G = decltype(make_tiled_copy( + Copy_Atom{}, GmemCopyLayoutAtom{}, + Layout>{})); +}; + +template +__global__ void __launch_bounds__(Nthreads) + fa_kernel(const Element* dQ, const Element* dK, const Element* dV, + Element* dO) { + constexpr float softmax_scale = 1.250000e-01f; + const bool load_q_once = (kTK == kK); + + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + + const Element* Q = dQ + blockIdx.z * kTM * kN + blockIdx.x * kTM * kK; + const Element* K = dK + blockIdx.z * kK * kN; + const Element* V = dV + blockIdx.z * kP * kN + blockIdx.y * kTP * kN; + Element* O = + dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP; + + Element* sQ_ptr = reinterpret_cast(buf); + Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK; + Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK; + Element* sO_ptr = sQ_ptr; + + typename KeTraits::TiledMma mma; + typename KeTraits::TiledCopyG2S tiled_copy_g2s; + + // Build the copy plan for QK from global memory to shared memory. + auto g2s_copy_qk = make_g2s_qk< + Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ, + typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK, + typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK); + + /** + * In FractalTensor, The size of the V matrix is [kN, kP], and the size + * processed in a single SM Block is [kN, kTP]. When split along the N + * dimension, the size is [kTN, kTP]. Therefore, the stride for global + * memory should be set to kTN * kP. + * + * In the current implementation, the shape of the V matrix is [kP, kN], and + * the block size processed by a single Block is [kTP, kN]. Therefore, the + * stride only needs to be set to kTN each time. + */ + auto g2s_copy_v = + make_g2s_v(V, sV_ptr, kTN); + + auto acc0 = get_acc(mma); + auto acco = get_acc(mma); + + auto m_new = make_tensor(Shape(acc0)>>{}); + auto lse_new = make_fragment_like(m_new); + + auto s2r_pipeline_qk = + make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{}, + typename KeTraits::SmemLayoutK{}, acc0, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + auto s2r_pipeline_v = + make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + // Issue global to shared memory copy before the main loop. + g2s_copy_qk.prologue(); + + fill(lse_new, 0.0f); + fill(m_new, -INFINITY); + clear(acco); + + /** + * Flash Attention performs two-level tiling for each SM Block, splitting + * along the N dimension and the K dimension. The Q matrix is split along + * the K dimension, the V matrix is split along the N dimension, and the K + * matrix is split along both dimensions simultaneously. + */ + int split_n = kN / kTN; + for (int n = 0; n < split_n; ++n) { + clear(acc0); + int slice_k = kK / kTK - 1; + for (int k = 0; k < slice_k; ++k) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_qk.body(); + // Load data from shared memory into register and issue MMA. 
+ s2r_pipeline_qk.body(); + } + + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.prologue(); + s2r_pipeline_qk.epilogue(); + + // scores = dot(q, k) + auto scores = + make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); + + auto m_old = make_fragment_like(m_new); + copy(m_new, m_old); + + auto scores_max = make_fragment_like(m_new); + + // scores_max = reduce_max(scores, axis=1) + reduce_max<4, true>(scores, scores_max); + + // Compute new partial max value. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = max(m_new(ax0), scores_max(ax0)); + } + + auto acco_rowcol = + make_tensor(acco.data(), convert_layout_scores(acco.layout())); + + // Renormalizatio for the previous block. + for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) { + float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); + lse_new(ax0) = lse_new(ax0) * scale; + for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) { + acco_rowcol(ax0, ax1) *= scale; + } + } + + for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { + float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); + lse_new(ax0) = lse_new(ax0) * m_scaled; + for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { + scores(ax0, ax1) = + exp(scores(ax0, ax1) * softmax_scale - m_scaled); + } + } + + auto scores_sum = make_fragment_like(lse_new); + reduce_sum<4>(scores, scores_sum); + + for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { + lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); + } + + // TODO(KuangjuX): Understand the following code. + auto frag = convert_type(scores); + auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout()); + auto rP_Aregs = + make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout())); + + /** + * In FractalTensor, the `kTN` dimension is split again. To simplify the + * current implementation of rhe pipeline flashattention, the `tile_n` + * is hardcoded to 0 at this point. + */ + const int tile_n = 0; + for (int tile_ = 0; tile_ < tile_n; ++tile_) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.body(); + s2r_pipeline_v.body(rP_Aregs); + } + + cp_async_wait_flash<0>(); + __syncthreads(); + + if (n < split_n - 1) { + /** + * Update K tile because the entire K Block will be processed in a + * single SM Block. + * + * For example, In `TileFusion`: + * ```cpp + * for (int n = 0; n < GIteratorV::sc0; ++n) { + * load_sv(gVs(n), sV); + * for (int k = 0; k < GIteratorQ::sc1; ++k) { + * load_sq(gQs(k), sQ); + * load_sk(gKs(k, n), sK); + * } + * } + * ``` + */ + g2s_copy_qk.update_tile_K(kTN, kK); + /** + * `load_q_once` means that at this point `kK == kTK`, and the Q is + * loaded into shared memory in blocks only once. In this case, we + * only need to update the pointer of K and do not need to update + * the pointer for Q, because the blocking along the k dimension + * will not be executed, thus the Q is always reloaded. + */ + if (load_q_once) { + g2s_copy_qk.prologue_K(); + } + } + + s2r_pipeline_v.epilogue(rP_Aregs); + } + + // Store O from registers to shared memory and then to global memory. 
+ store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco, + typename KeTraits::StoreR2SCopyAtom{}, mma); + __syncthreads(); + + store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{}, + typename KeTraits::SmemLayoutO{}, + typename KeTraits::TiledCopyS2G{}); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu new file mode 100644 index 0000000..bf3fb11 --- /dev/null +++ b/benchmarks/cpp/flashattention/main.cu @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cutlass_fa.cuh" +#include "util.hpp" + +void run(bool check = true) { + using InType = cutlass::half_t; + using AccType = cutlass::half_t; + using OutType = cutlass::half_t; + + static constexpr int kM = 64; + static constexpr int kN = 64; + static constexpr int kK = 128; + static constexpr int kP = 128; + + static constexpr int kTM = 64; + static constexpr int kTN = 64; + static constexpr int kTK = 128; + static constexpr int kTP = 128; + + static constexpr int kBatch = 1; + + static constexpr int kWarpPerRow = 1; + static constexpr int kWarpPerCol = 1; + static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32; + static constexpr int kStagesQK = 1; + static constexpr int kStagesV = 1; + + static_assert(kK == kTK, + "The current implementation requires kTK == K for now."); + static_assert(kP == kTP, + "The current implementation requires kTP == P for now."); + + // initialize data + thrust::host_vector h_a(kM * kK * kBatch); + + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast(rand_float()); + + thrust::host_vector h_b(kK * kN * kBatch); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast(rand_float()); + + thrust::host_vector h_c(kN * kP * kBatch); + for (int i = 0; i < h_c.size(); ++i) + h_c[i] = static_cast(rand_float()); + + thrust::host_vector h_d(kM * kP * kBatch); + thrust::fill(h_d.begin(), h_d.end(), 0.); + + // Host side memory initialization. 
+ thrust::host_vector acc(kM * kN * kBatch); + thrust::fill(acc.begin(), acc.end(), 0.); + + thrust::host_vector exp_values(kM * kP * kBatch); + thrust::fill(exp_values.begin(), exp_values.end(), 0.); + + thrust::host_vector h_o(kM * kP * kBatch); + thrust::fill(h_o.begin(), h_o.end(), 0.); + + thrust::host_vector cur_row_max(kM * kBatch); + thrust::fill(cur_row_max.begin(), cur_row_max.end(), 0.); + + thrust::host_vector prev_row_max(kM * kBatch); + thrust::fill(prev_row_max.begin(), prev_row_max.end(), 0.); + + thrust::host_vector new_row_max(kM * kBatch); + thrust::fill(new_row_max.begin(), new_row_max.end(), 0.); + + thrust::host_vector prev_norm_vec(kM * kBatch); + thrust::fill(prev_norm_vec.begin(), prev_norm_vec.end(), 0.); + + thrust::host_vector new_norm_vec(kM * kBatch); + thrust::fill(new_norm_vec.begin(), new_norm_vec.end(), 0.); + + thrust::host_vector prev_sum_vec(kM * kBatch); + thrust::fill(prev_sum_vec.begin(), prev_sum_vec.end(), 0.); + + thrust::host_vector cur_sum_vec(kM * kBatch); + thrust::fill(cur_sum_vec.begin(), cur_sum_vec.end(), 0.); + + thrust::host_vector new_sum_vec(kM * kBatch); + thrust::fill(new_sum_vec.begin(), new_sum_vec.end(), 0.); + + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; + thrust::device_vector d_d = h_d; + + const InType* A = thrust::raw_pointer_cast(d_a.data()); + const InType* B = thrust::raw_pointer_cast(d_b.data()); + const InType* C = thrust::raw_pointer_cast(d_c.data()); + InType* D = thrust::raw_pointer_cast(d_d.data()); + + int block_x = (kM + kTM - 1) / kTM; + int block_y = (kP + kTP - 1) / kTP; + int block_z = kBatch; + + dim3 grid(block_x, block_y, block_z); + dim3 block(kThreads, 1, 1); + + int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); + int shm_output = kTM * kTP; + int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) + : shm_input * sizeof(InType); + + using Traits = + benchmarks::cutlass_wrapper::FATraits; + + auto fa_kernel = + benchmarks::cutlass_wrapper::fa_kernel; + + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute( + fa_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + } + + fa_kernel<<>>(A, B, C, D); + + cudaDeviceSynchronize(); +} + +int main() { run(); } diff --git a/benchmarks/cpp/flashattention/reduce.cuh b/benchmarks/cpp/flashattention/reduce.cuh new file mode 100644 index 0000000..f85d34d --- /dev/null +++ b/benchmarks/cpp/flashattention/reduce.cuh @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "cuda_utils.cuh" + +#include + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +struct MaxOp_float { + DEVICE float operator()(float const& x, float const& y) { + return max(x, y); + } +}; + +template +struct SumOp { + DEVICE T operator()(T const& x, T const& y) { return x + y; } +}; + +template +struct SumAbsOp { + DEVICE T operator()(T const& x, T const& y) { return x + abs(y); } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || + THREADS == 4); + template + static DEVICE T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static DEVICE T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template +DEVICE void thread_reduce_(cute::Tensor const& tensor, + cute::Tensor& summary, + Operator& op) { + using namespace cute; + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = + zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +DEVICE void quad_allreduce_(cute::Tensor& dst, + cute::Tensor& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +DEVICE void eight_allreduce_(cute::Tensor& dst, + cute::Tensor& src, + Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<8>::run(src(i), op); + } +} + +template +DEVICE void allreduce_(cute::Tensor& dst, + cute::Tensor& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce::run(src(i), op); + } +} + +template +DEVICE void reduce_(cute::Tensor const& tensor, + cute::Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + allreduce_(summary, summary, op); +} + +template +DEVICE void reduce_max(cute::Tensor const& tensor, + cute::Tensor& max) { + MaxOp_float max_op; + reduce_(tensor, max, max_op); +} + +template +DEVICE void reduce_sum(cute::Tensor const& tensor, + cute::Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +template +DEVICE void reduce_sumabs(cute::Tensor const& tensor, + cute::Tensor& sum) { + SumAbsOp sumabs_op; + reduce_(tensor, sum, sumabs_op); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/util.hpp b/benchmarks/cpp/flashattention/util.hpp new file mode 100644 index 0000000..1cc00eb --- /dev/null +++ b/benchmarks/cpp/flashattention/util.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
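+//
+// util.hpp: host-side helpers for the benchmark driver -- rand_float() fills
+// the input tensors with small random values and check_results() does an
+// element-wise comparison of two half-precision buffers within a fixed
+// tolerance.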
+ +#pragma once + +#include "util/debug.hpp" + +#include +#include + +float rand_float(float a = 1e-1, float b = 5e-2) { + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; +} + +bool check_results(const __half* values1, const __half* values2, int numel) { + bool passed = true; + const float epsilon = 1e-1; + + for (int i = 0; i < numel; ++i) { + if (fabs(__half2float(values1[i]) - __half2float(values2[i])) > + epsilon) { + printf("%d-th value differs: %.3f vs. %.3f\n", i, + __half2float(values1[i]), __half2float(values2[i])); + passed = false; + break; + } + } + return passed; +} diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh index b3b105b..08a4c1d 100644 --- a/benchmarks/utils/cpp/cutlass/copy.cuh +++ b/benchmarks/utils/cpp/cutlass/copy.cuh @@ -32,6 +32,13 @@ DEVICE void __copy_async() { wait_group<0>(); } +template +DEVICE void cp_async_wait_flash() { +#if defined(CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + // Copy a 2d data tile from global memory to shared memory template