diff --git a/.vscode/settings.json b/.vscode/settings.json index 81386e81..8a6e4a8e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,20 +1,19 @@ { - "gotoSymbolStack.currentStackPosition": 0, - "gotoSymbolStack.maxStackPosition": 0, - "gotoSymbolStack.filePositionInfo": [], - "files.associations": { - "*.tcc": "cpp", - "optional": "cpp", - "ratio": "cpp", - "system_error": "cpp", - "array": "cpp", - "functional": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "variant": "cpp", - "compare": "cpp", - "concepts": "cpp", - "random": "cpp" - } + "files.associations": { + "array": "cpp", + "string": "cpp", + "string_view": "cpp", + "span": "cpp", + "bitset": "cpp", + "initializer_list": "cpp", + "utility": "cpp", + "*.tcc": "cpp", + "chrono": "cpp", + "random": "cpp", + "limits": "cpp", + "semaphore": "cpp" + }, + "gotoSymbolStack.currentStackPosition": 0, + "gotoSymbolStack.maxStackPosition": 0, + "gotoSymbolStack.filePositionInfo": [] } diff --git a/benchmarks/cpp/flashattention/CMakeLists.txt b/benchmarks/cpp/flashattention/CMakeLists.txt new file mode 100644 index 00000000..2b73af97 --- /dev/null +++ b/benchmarks/cpp/flashattention/CMakeLists.txt @@ -0,0 +1,19 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the +# MIT License. +# -------------------------------------------------------------------------- + +cmake_minimum_required(VERSION 3.25 FATAL_ERROR) +project(flash_attention_bench LANGUAGES C CXX CUDA) + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} + "${PROJECT_SOURCE_DIR}/../../../cmake") +set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") + +include(generic) + +include_directories("${PROJECT_SOURCE_DIR}/../../../include") +include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp") +include_directories("${THIRD_PARTY_DIR}/cutlass/include") + +add_executable(flash_attn main.cu) diff --git a/benchmarks/cpp/flashattention/Makefile b/benchmarks/cpp/flashattention/Makefile new file mode 100644 index 00000000..2b592215 --- /dev/null +++ b/benchmarks/cpp/flashattention/Makefile @@ -0,0 +1,16 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + + +BUILD_DIR := build + +.PHONY: build clean + +build: + @mkdir -p $(BUILD_DIR) + @cd $(BUILD_DIR) && cmake .. && make -j$(proc) + +clean: + @rm -rf $(BUILD_DIR) diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh new file mode 100644 index 00000000..2480eee4 --- /dev/null +++ b/benchmarks/cpp/flashattention/convert.cuh @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
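+//
+// convert.cuh: register-fragment type conversion (via
+// cutlass::NumericArrayConverter) and the cute layout reshapes that feed the
+// softmax(QK^T) scores back into the P @ V MMA as A-operand fragments.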
+ +#pragma once + +#include "cuda_utils.cuh" + +#include <cute/layout.hpp> +#include <cute/tensor.hpp> +#include <cutlass/numeric_conversion.h> + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +template <typename To_type, typename Engine, typename Layout> +CUTE_DEVICE auto convert_type(cute::Tensor<Engine, Layout> const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; + auto frag = + convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>( + tensor.data())); + return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()); +} + +template <typename Layout> +DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using namespace cute; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + auto l = logical_divide(rowcol_layout, + Shape<Underscore, Shape<Underscore, Int<2>>>{}); + + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), + get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), get<1>(get<1>(get<1>(l)))); +} + +DEVICE auto convert_layout_C_Aregs() { + using namespace cute; + auto layout_s = Layout<Shape<Shape<_2, _2>, _2, _16>>{}; + auto l = logical_divide(layout_s, Shape<Underscore, Underscore, _2>{}); + + return make_layout( + make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))), + get<1>(l), get<1>(get<2>(l))); +} + +template <class LayoutType> +DEVICE auto convert_layout_scores(LayoutType layout_s) { + using namespace cute; + static_assert(decltype(size<0>(layout_s))::value == 4); + static_assert(decltype(rank(layout_s))::value == 3); + + auto l = logical_divide(layout_s, Shape<_2>{}); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), + make_layout(get<0>(get<0>(l)), get<2>(l))); +} + +template <int ATOMNUM, class LayoutType> +DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) { + using namespace cute; + + auto l = logical_divide(layout_s, Shape<Underscore, Int<ATOMNUM>>{}); + return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l))); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh new file mode 100644 index 00000000..ced7ae92 --- /dev/null +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -0,0 +1,634 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "cuda_utils.cuh" + +#include <cute/tensor.hpp> +#include <cutlass/numeric_conversion.h> + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +namespace detail { + +template <typename GQTensor, typename SQTensor, typename GKTensor, + typename SKTensor, typename TiledCopy> +class G2SCopyQK { + public: + DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, + TiledCopy tiled_copy, int gQ_stride, int sQ_stride, + int gK_stride, int sK_stride, int num_stage = 2) + : gQ(gQ), + sQ(sQ), + gK(gK), + sK(sK), + gQ_stride(gQ_stride), + sQ_stride(sQ_stride), + gK_stride(gK_stride), + sK_stride(sK_stride), + cur_iter(0), + cur_iter_sk(0), + num_stage(num_stage) {} + + /** + * @brief Update the pointer of the global K tensor. 
+ * + * Since the K matrix is split along both the n and k dimensions, the + * pointer offset for the K matrix needs to be updated to the next kTN * kK + * position during the next n dimension iteration. + * + * @param gK_slice The stride in N dimension. + * @param gK_stride The stride in K dimension. + */ + DEVICE void update_tile_K(int gK_slice, int gK_stride) { + gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride; + } + + /** + * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix + * only needs to be loaded once and does not require repeated loading, while + * the K matrix needs to be updated and loaded. + */ + DEVICE void prologue_K() { +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter_sk + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter_sk++; + } + + DEVICE void prologue() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + // Circlically read SMEM Buffer + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void body() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void epilogue() { +#pragma unroll + for (int m = 0; m < size<1>(gQ); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } + +#pragma unroll + for (int m = 0; m < size<1>(gK); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } + + cute::cp_async_fence(); + } + + private: + GQTensor& gQ; + SQTensor& sQ; + GKTensor& gK; + SKTensor& sK; + TiledCopy tiled_copy; + int gQ_stride; + int sQ_stride; + int gK_stride; + int sK_stride; + int cur_iter; + int cur_iter_sk; + int num_stage; +}; + +template <typename GVTensor, typename SVTensor, typename TiledCopy> +class G2SCopyV { + public: + DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy, + int gV_stride, int sV_stride, int num_stage = 2) + : gV(gV), + sV(sV), + gV_stride(gV_stride), + sV_stride(sV_stride), + cur_iter(0), + num_stage(num_stage) {} + + 
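+    // Note on the pointer arithmetic used below: the global pointer advances
+    // by one tile (gV_stride) per call, while the shared-memory pointer is
+    // treated as a circular buffer. With num_stage stages, after the write to
+    // stage (num_stage - 1) the pointer is rewound by num_stage * sV_stride,
+    // so for num_stage = 2 the shared-memory write offsets cycle
+    // 0, sV_stride, 0, sV_stride, ...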
DEVICE void prologue() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + + cute::cp_async_fence(); + gV.data() = gV.data() + gV_stride; + sV.data() = sV.data() + sV_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void body() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + + cute::cp_async_fence(); + + gV.data() = gV.data() + gV_stride; + sV.data() = sV.data() + sV_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + } + + DEVICE void epilogue() { +#pragma unroll + for (int m = 0; m < size<1>(gV); ++m) { +#pragma unroll + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } + cute::cp_async_fence(); + } + + private: + GVTensor& gV; + SVTensor& sV; + TiledCopy tiled_copy; + int gV_stride; + int sV_stride; + int cur_iter; + int num_stage; +}; + +template <typename SQTensor, typename RQMmaView, typename RQCopyView, + typename SKTensor, typename RKMmaView, typename RKCopyView, + typename RAccTensor, typename TiledCopyQ, typename TiledCopyK, + typename TiledMma> +class S2RPipelineQK { + public: + DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view, + RQCopyView& rQ_copy_view, SKTensor& sK, + RKMmaView& rK_mma_view, RKCopyView& rK_copy_view, + RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, + TiledMma tiled_mma, int sQ_stride, int sK_stride, + int num_stage = 2) + : sQ(sQ), + rQ_mma_view(rQ_mma_view), + rQ_copy_view(rQ_copy_view), + sK(sK), + rK_mma_view(rK_mma_view), + rK_copy_view(rK_copy_view), + acc(acc), + copy_q(copy_q), + copy_k(copy_k), + tiled_mma(tiled_mma), + sQ_stride(sQ_stride), + sK_stride(sK_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sq(0) {} + + DEVICE void prologue() { + cur_iter = 0; + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + cur_iter++; + } + + DEVICE void body() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq++; + } + + DEVICE void epilogue() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < 
size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), + acc); + } + + sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq); + sK.data() = sK.data() + sK_stride; + + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq = 0; + } + + private: + SQTensor& sQ; + RQMmaView& rQ_mma_view; + RQCopyView& rQ_copy_view; + SKTensor& sK; + RKMmaView& rK_mma_view; + RKCopyView& rK_copy_view; + RAccTensor& acc; + TiledCopyQ copy_q; + TiledCopyK copy_k; + TiledMma tiled_mma; + int sQ_stride; + int sK_stride; + int num_stage; + int cur_iter; + int cur_iter_sq; +}; + +template <typename SVTensor, typename RVMmaView, typename RVCopyView, + typename RegAcc, typename TiledCopy, typename TiledMma> +class S2RPipelineV { + public: + DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view, + RVCopyView& rV_copy_view, RegAcc& acc, + TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride, + int num_stage = 2) + : sV(sV), + rV_mma_view(rV_mma_view), + rV_copy_view(rV_copy_view), + acc(acc), + tiled_copy(tiled_copy), + sV_stride(sV_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sv(0) {} + + template <typename RegValue> + DEVICE void prologue(RegValue& value) { + cur_iter = 0; + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + // TODO(KuangjuX): Understand this code. Why do we need to use + // `value(_, _, cur_iter * size<2>(rV_mma_view) + i)`? 
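            // (A plausible answer: `value` holds the P = softmax(QK^T)
            // fragment for the whole score tile, while a single call of this
            // pipeline only consumes size<2>(rV_mma_view) k-slices of V, so
            // the offset cur_iter * size<2>(rV_mma_view) + i walks through P
            // in matching chunks as successive calls advance cur_iter.)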
+ cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + cur_iter++; + } + + template <typename RegValue> + DEVICE void body(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + cur_iter_sv++; + } + + template <typename RegValue> + DEVICE void epilogue(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), + rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, + value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + (-sV_stride * cur_iter_sv); + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + cur_iter_sv = 0; + } + + private: + SVTensor& sV; + RVMmaView& rV_mma_view; + RVCopyView& rV_copy_view; + RegAcc& acc; + TiledCopy tiled_copy; + TiledMma tiled_mma; + int sV_stride; + int num_stage; + int cur_iter; + int cur_iter_sv; +}; + +} // namespace detail + +template <typename Element, typename GlobalQLayout, typename SharedQLayout, + typename GlobalKLayout, typename SharedKLayout, typename TiledCopy> +inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr, + const Element* gK_ptr, Element* sK_ptr, + int gQ_stride, int gK_stride) { + int tid = threadIdx.x; + + auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{}); + auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{}); + + auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{}); + auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{}); + + TiledCopy tiled_copy; + + auto loader = tiled_copy.get_thread_slice(tid); + + auto gQs = loader.partition_S(gQ); + auto gKs = loader.partition_S(gK); + auto sQs = loader.partition_D(sQ); + auto sKs = loader.partition_D(sK); + + int sQ_stride = size(sQ); + int sK_stride = size(sK); + + if (thread0()) { + printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n", + gQ_stride, sQ_stride, gK_stride, sK_stride); + } + + detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride, + sQ_stride, gK_stride, sK_stride); + + return copy_qk; +} + +template <typename Element, typename GlobalVLayout, typename SharedVLayout, + typename TiledCopy> +DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) { + int tid = threadIdx.x; + + auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{}); + auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{}); + + TiledCopy tiled_copy; + + auto loader = tiled_copy.get_thread_slice(tid); + + auto gVs = loader.partition_S(gV); + auto sVs = loader.partition_D(sV); + + int sV_stride = size(sV); + + if (thread0()) { + printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride); + } + + detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); + + return copy_v; +} + +template <typename Element, typename 
SQLayout, typename SKLayout, + typename RegAcc, typename SmemCopyAtom, typename TiledMma> +DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr, + SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc, + SmemCopyAtom copy_atom = SmemCopyAtom{}, + TiledMma tiled_mma = TiledMma{}) { + int tid = threadIdx.x; + + auto sQ_ = make_tensor(make_smem_ptr(sQ_ptr), sQ_layout); + auto sK_ = make_tensor(make_smem_ptr(sK_ptr), sK_layout); + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma); + auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid); + auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid); + + auto sQ = s2r_thr_copy_q.partition_S(sQ_); + auto sK = s2r_thr_copy_k.partition_S(sK_); + + // Thread partition for mma. + auto rQ_mma = thr_mma.partition_fragment_A(sQ_); + auto rK_mma = thr_mma.partition_fragment_B(sK_); + + // Thread partition for shared to register copy. + auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma); + auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma); + + int sQ_stride = size(sQ_); + int sK_stride = size(sK_); + + detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma, + rK_copy, acc, s2r_copy_q, s2r_copy_k, + tiled_mma, sQ_stride, sK_stride); + + return s2r_pipeline_qk; +} + +template <typename Element, typename SVLayout, typename RegAcc, + typename SmemCopyAtom, typename TiledMma> +DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc, + SmemCopyAtom copy_atom, TiledMma tiled_mma) { + int tid = threadIdx.x; + + auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout); + + auto thr_mma = tiled_mma.get_thread_slice(tid); + + auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid); + + auto sV = s2r_thr_copy_v.partition_S(sV_); + + auto rV_mma = thr_mma.partition_fragment_B(sV_); + auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma); + + int sV_stride = size(sV_); + + detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v, + tiled_mma, sV_stride); + + return s2r_pipeline_v; +} + +template <typename Element, typename SOLayout, typename RegO, + typename SmemCopyAtom, typename TiledMma> +DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o, + SmemCopyAtom copy_atom, TiledMma tiled_mma) { + auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); + + auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma); + auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x); + + auto src = r2s_thr_copy_o.retile_S(o); + auto dst = r2s_thr_copy_o.partition_D(sO); + + cute::copy(r2s_copy_o, src, dst); +} + +template <typename Element, typename GOLayout, typename SOLayout, + typename TiledCopy> +DEVICE auto store_s2g_o(Element* gO_ptr, const Element* sO_ptr, + GOLayout gO_layout, SOLayout sO_layout, + TiledCopy tiled_copy) { + auto gO = make_tensor(make_gmem_ptr(gO_ptr), gO_layout); + auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); + + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + auto gO_partition = thr_copy.partition_D(gO); + auto sO_partition = thr_copy.partition_S(sO); + +#pragma unroll + for (int m = 0; m < size<1>(gO_partition); ++m) { +#pragma unroll + for (int n = 0; n < size<2>(gO_partition); ++n) { + cute::copy(tiled_copy, sO_partition(_, m, n), + gO_partition(_, m, n)); + } + } +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git 
a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
new file mode 100644
index 00000000..6cae18bc
--- /dev/null
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -0,0 +1,294 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "convert.cuh"
+#include "copy.cuh"
+#include "cuda_utils.cuh"
+#include "cutlass/copy.cuh"
+#include "cutlass/traits_base.cuh"
+#include "reduce.cuh"
+
+namespace benchmarks {
+namespace cutlass_wrapper {
+
+using namespace cute;
+
+template <typename Element_, const int kM, const int kN, const int kK,
+          const int kP, const int kTM, const int kTN, const int kTK,
+          const int kTP, const int kWarpPerRow, const int kWarpPerCol,
+          const int SmemKAtom = 64, const int kSwizzle = 3,
+          typename Base = AccessBase<Element_>>
+struct FATraits : public Base {
+    using Element = Element_;
+
+    // Declare global to shared memory copy layout.
+    using GmemLayoutQ = Layout<Shape<Int<kTM>, Int<kTK>>, Stride<Int<kK>, _1>>;
+    using GmemLayoutK = Layout<Shape<Int<kTN>, Int<kTK>>, Stride<Int<kK>, _1>>;
+    using GmemLayoutV = Layout<Shape<Int<kTP>, Int<kTN>>, Stride<Int<kN>, _1>>;
+    using GmemLayoutO = Layout<Shape<Int<kTM>, Int<kTP>>, Stride<Int<kP>, _1>>;
+
+    static constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32;
+
+    /**
+     * Define the atomic layout of shared memory, which is the smallest
+     * configuration unit of shared memory. Larger shapes are tiled based on
+     * the atomic layout.
+     */
+    using SmemLayoutAtom = decltype(composition(
+        Swizzle<kSwizzle, 3, 3>{},
+        Layout<Shape<_8, Int<SmemKAtom>>, Stride<Int<SmemKAtom>, _1>>{}));
+
+    using SmemLayoutQ =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTK>>{}));
+    using SmemLayoutK =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTN>, Int<kTK>>{}));
+    using SmemLayoutV =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTP>, Int<kTN>>{}));
+    using SmemLayoutO =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTP>>{}));
+
+    /**
+     * On the Ampere architecture, loading from shared memory into registers
+     * uses the `ldmatrix` instruction, while storing from registers back to
+     * shared memory has no dedicated hardware support and falls back to a
+     * default copy.
+     */
+    using LoadS2RCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
+    using StoreR2SCopyAtom = Copy_Atom<DefaultCopy, Element>;
+
+    static constexpr int kWarps = kThreads / 32;
+
+    using TiledMma =
+        TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+                 Layout<Shape<Int<kWarpPerRow>, Int<kWarpPerCol>, _1>>,
+                 Tile<Int<16 * kWarpPerRow>, Int<16 * kWarpPerCol>, _16>>;
+
+#ifdef CP_ASYNC_SM80_ENABLED
+    // Use asynchronous copies (cp.async) on Ampere.
+    using CopyInstG2S =
+        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;
+#else
+    using CopyInstG2S = Copy_Atom<DefaultCopy, Element>;
+#endif
+
+    // TODO(KuangjuX): Understand this configuration.
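+    // (A likely reading of this configuration: each thread copies 128 bits
+    // per instruction, i.e. 8 half-precision elements, matching the
+    // Layout<Shape<_1, _8>> value layout below. One SmemKAtom-wide row is
+    // then covered by SmemKAtom / 8 threads, and the CTA's kThreads threads
+    // are arranged as kThreads / (SmemKAtom / 8) rows by SmemKAtom / 8
+    // columns.)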
+ using GmemCopyLayoutAtom = + Layout<Shape<Int<kThreads / (SmemKAtom / 8)>, Int<SmemKAtom / 8>>, + Stride<Int<SmemKAtom / 8>, _1>>; + + using TiledCopyG2S = decltype(make_tiled_copy( + CopyInstG2S{}, GmemCopyLayoutAtom{}, Layout<Shape<_1, _8>>{})); + + using TiledCopyS2G = decltype(make_tiled_copy( + Copy_Atom<DefaultCopy, Element>{}, GmemCopyLayoutAtom{}, + Layout<Shape<_1, _8>>{})); +}; + +template <typename Element, typename KeTraits, const int kM, const int kN, + const int kK, const int kP, const int kTM, const int kTN, + const int kTK, const int kTP, const int Nthreads, const int kStagesQK, + const int kStageV> +__global__ void __launch_bounds__(Nthreads) + fa_kernel(const Element* dQ, const Element* dK, const Element* dV, + Element* dO) { + constexpr float softmax_scale = 1.250000e-01f; + const bool load_q_once = (kTK == kK); + + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast<Element*>(buf_); + + const Element* Q = dQ + blockIdx.z * kTM * kN + blockIdx.x * kTM * kK; + const Element* K = dK + blockIdx.z * kK * kN; + const Element* V = dV + blockIdx.z * kP * kN + blockIdx.y * kTP * kN; + Element* O = + dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP; + + Element* sQ_ptr = reinterpret_cast<Element*>(buf); + Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK; + Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK; + Element* sO_ptr = sQ_ptr; + + typename KeTraits::TiledMma mma; + typename KeTraits::TiledCopyG2S tiled_copy_g2s; + + // Build the copy plan for QK from global memory to shared memory. + auto g2s_copy_qk = make_g2s_qk< + Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ, + typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK, + typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK); + + /** + * In FractalTensor, The size of the V matrix is [kN, kP], and the size + * processed in a single SM Block is [kN, kTP]. When split along the N + * dimension, the size is [kTN, kTP]. Therefore, the stride for global + * memory should be set to kTN * kP. + * + * In the current implementation, the shape of the V matrix is [kP, kN], and + * the block size processed by a single Block is [kTP, kN]. Therefore, the + * stride only needs to be set to kTN each time. + */ + auto g2s_copy_v = + make_g2s_v<Element, typename KeTraits::GmemLayoutV, + typename KeTraits::SmemLayoutV, + typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN); + + auto acc0 = get_acc<kTM, kTN>(mma); + auto acco = get_acc<kTM, kTP>(mma); + + auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{}); + auto lse_new = make_fragment_like(m_new); + + auto s2r_pipeline_qk = + make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{}, + typename KeTraits::SmemLayoutK{}, acc0, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + auto s2r_pipeline_v = + make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + // Issue global to shared memory copy before the main loop. + g2s_copy_qk.prologue(); + + fill(lse_new, 0.0f); + fill(m_new, -INFINITY); + clear(acco); + + /** + * Flash Attention performs two-level tiling for each SM Block, splitting + * along the N dimension and the K dimension. The Q matrix is split along + * the K dimension, the V matrix is split along the N dimension, and the K + * matrix is split along both dimensions simultaneously. 
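+     *
+     * For example, with the configuration used in main.cu
+     * (kN = kTN = 64, kK = kTK = 128), split_n == 1 and kK / kTK - 1 == 0,
+     * so the inner k loop below never executes: the prologue has already
+     * staged the single Q/K tile and the epilogue of the QK pipeline issues
+     * the remaining MMAs.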
+     */
+    int split_n = kN / kTN;
+    for (int n = 0; n < split_n; ++n) {
+        clear(acc0);
+        int slice_k = kK / kTK - 1;
+        for (int k = 0; k < slice_k; ++k) {
+            // Barrier to ensure all data are loaded into shared memory.
+            cp_async_wait_flash<0>();
+            __syncthreads();
+            g2s_copy_qk.body();
+            // Load data from shared memory into register and issue MMA.
+            s2r_pipeline_qk.body();
+        }
+
+        cp_async_wait_flash<0>();
+        __syncthreads();
+        g2s_copy_v.prologue();
+        s2r_pipeline_qk.epilogue();
+
+        // scores = dot(q, k)
+        auto scores =
+            make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
+
+        auto m_old = make_fragment_like(m_new);
+        copy(m_new, m_old);
+
+        auto scores_max = make_fragment_like(m_new);
+
+        // scores_max = reduce_max(scores, axis=1)
+        reduce_max<4, true>(scores, scores_max);
+
+        // Compute new partial max value.
+        for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
+            m_new(ax0) = max(m_new(ax0), scores_max(ax0));
+        }
+
+        auto acco_rowcol =
+            make_tensor(acco.data(), convert_layout_scores(acco.layout()));
+
+        // Renormalization for the previous block.
+        for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
+            float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
+            lse_new(ax0) = lse_new(ax0) * scale;
+            for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
+                acco_rowcol(ax0, ax1) *= scale;
+            }
+        }
+
+        for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
+            float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
+            lse_new(ax0) = lse_new(ax0) * m_scaled;
+            for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
+                scores(ax0, ax1) =
+                    exp(scores(ax0, ax1) * softmax_scale - m_scaled);
+            }
+        }
+
+        auto scores_sum = make_fragment_like(lse_new);
+        reduce_sum<4>(scores, scores_sum);
+
+        for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
+            lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
+        }
+
+        // TODO(KuangjuX): Understand the following code.
+        auto frag = convert_type<Element>(scores);
+        auto rP = make_tensor(make_rmem_ptr<Element>(&frag), scores.layout());
+        auto rP_Aregs =
+            make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));
+
+        /**
+         * In FractalTensor, the `kTN` dimension is split again. To simplify
+         * the current implementation of the pipelined flash attention,
+         * `tile_n` is hardcoded to 0 at this point.
+         */
+        const int tile_n = 0;
+        for (int tile_ = 0; tile_ < tile_n; ++tile_) {
+            // Barrier to ensure all data are loaded into shared memory.
+            cp_async_wait_flash<0>();
+            __syncthreads();
+            g2s_copy_v.body();
+            s2r_pipeline_v.body(rP_Aregs);
+        }
+
+        cp_async_wait_flash<0>();
+        __syncthreads();
+
+        if (n < split_n - 1) {
+            /**
+             * Update the K tile because the entire K block is processed
+             * within a single SM block.
+             *
+             * For example, in `TileFusion`:
+             * ```cpp
+             * for (int n = 0; n < GIteratorV::sc0; ++n) {
+             *     load_sv(gVs(n), sV);
+             *     for (int k = 0; k < GIteratorQ::sc1; ++k) {
+             *         load_sq(gQs(k), sQ);
+             *         load_sk(gKs(k, n), sK);
+             *     }
+             * }
+             * ```
+             */
+            g2s_copy_qk.update_tile_K(kTN, kK);
+            /**
+             * `load_q_once` means that `kK == kTK`, so the Q tile is loaded
+             * into shared memory only once. In this case we only need to
+             * update the pointer of K and do not need to update the pointer
+             * for Q: the blocking along the k dimension is never executed,
+             * so Q never has to be reloaded.
+             */
+            if (load_q_once) {
+                g2s_copy_qk.prologue_K();
+            }
+        }
+
+        s2r_pipeline_v.epilogue(rP_Aregs);
+    }
+
+    // Store O from registers to shared memory and then to global memory.
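+    // Note that sO aliases the Q staging buffer (sO_ptr == sQ_ptr above); the
+    // Q tile has already been consumed by the QK pipeline, and the host sizes
+    // the dynamic shared memory as max(input tiles, output tile) accordingly.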
+ store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco, + typename KeTraits::StoreR2SCopyAtom{}, mma); + __syncthreads(); + + store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{}, + typename KeTraits::SmemLayoutO{}, + typename KeTraits::TiledCopyS2G{}); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu new file mode 100644 index 00000000..bf3fb116 --- /dev/null +++ b/benchmarks/cpp/flashattention/main.cu @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cutlass_fa.cuh" +#include "util.hpp" + +void run(bool check = true) { + using InType = cutlass::half_t; + using AccType = cutlass::half_t; + using OutType = cutlass::half_t; + + static constexpr int kM = 64; + static constexpr int kN = 64; + static constexpr int kK = 128; + static constexpr int kP = 128; + + static constexpr int kTM = 64; + static constexpr int kTN = 64; + static constexpr int kTK = 128; + static constexpr int kTP = 128; + + static constexpr int kBatch = 1; + + static constexpr int kWarpPerRow = 1; + static constexpr int kWarpPerCol = 1; + static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32; + static constexpr int kStagesQK = 1; + static constexpr int kStagesV = 1; + + static_assert(kK == kTK, + "The current implementation requires kTK == K for now."); + static_assert(kP == kTP, + "The current implementation requires kTP == P for now."); + + // initialize data + thrust::host_vector<InType> h_a(kM * kK * kBatch); + + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast<InType>(rand_float()); + + thrust::host_vector<InType> h_b(kK * kN * kBatch); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast<InType>(rand_float()); + + thrust::host_vector<InType> h_c(kN * kP * kBatch); + for (int i = 0; i < h_c.size(); ++i) + h_c[i] = static_cast<InType>(rand_float()); + + thrust::host_vector<InType> h_d(kM * kP * kBatch); + thrust::fill(h_d.begin(), h_d.end(), 0.); + + // Host side memory initialization. 
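+    // These host-side buffers (acc, exp_values, and the row max/norm/sum
+    // vectors) appear to be scaffolding for a CPU reference implementation of
+    // the attention computation; in this benchmark they are allocated and
+    // zero-filled but not consumed yet, and the `check` flag is likewise
+    // unused for now.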
+ thrust::host_vector<InType> acc(kM * kN * kBatch); + thrust::fill(acc.begin(), acc.end(), 0.); + + thrust::host_vector<InType> exp_values(kM * kP * kBatch); + thrust::fill(exp_values.begin(), exp_values.end(), 0.); + + thrust::host_vector<InType> h_o(kM * kP * kBatch); + thrust::fill(h_o.begin(), h_o.end(), 0.); + + thrust::host_vector<InType> cur_row_max(kM * kBatch); + thrust::fill(cur_row_max.begin(), cur_row_max.end(), 0.); + + thrust::host_vector<InType> prev_row_max(kM * kBatch); + thrust::fill(prev_row_max.begin(), prev_row_max.end(), 0.); + + thrust::host_vector<InType> new_row_max(kM * kBatch); + thrust::fill(new_row_max.begin(), new_row_max.end(), 0.); + + thrust::host_vector<InType> prev_norm_vec(kM * kBatch); + thrust::fill(prev_norm_vec.begin(), prev_norm_vec.end(), 0.); + + thrust::host_vector<InType> new_norm_vec(kM * kBatch); + thrust::fill(new_norm_vec.begin(), new_norm_vec.end(), 0.); + + thrust::host_vector<InType> prev_sum_vec(kM * kBatch); + thrust::fill(prev_sum_vec.begin(), prev_sum_vec.end(), 0.); + + thrust::host_vector<InType> cur_sum_vec(kM * kBatch); + thrust::fill(cur_sum_vec.begin(), cur_sum_vec.end(), 0.); + + thrust::host_vector<InType> new_sum_vec(kM * kBatch); + thrust::fill(new_sum_vec.begin(), new_sum_vec.end(), 0.); + + thrust::device_vector<InType> d_a = h_a; + thrust::device_vector<InType> d_b = h_b; + thrust::device_vector<InType> d_c = h_c; + thrust::device_vector<InType> d_d = h_d; + + const InType* A = thrust::raw_pointer_cast(d_a.data()); + const InType* B = thrust::raw_pointer_cast(d_b.data()); + const InType* C = thrust::raw_pointer_cast(d_c.data()); + InType* D = thrust::raw_pointer_cast(d_d.data()); + + int block_x = (kM + kTM - 1) / kTM; + int block_y = (kP + kTP - 1) / kTP; + int block_z = kBatch; + + dim3 grid(block_x, block_y, block_z); + dim3 block(kThreads, 1, 1); + + int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); + int shm_output = kTM * kTP; + int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) + : shm_input * sizeof(InType); + + using Traits = + benchmarks::cutlass_wrapper::FATraits<cutlass::half_t, kM, kN, kK, kP, + kTM, kTN, kTK, kTP, kWarpPerRow, + kWarpPerCol>; + + auto fa_kernel = + benchmarks::cutlass_wrapper::fa_kernel<cutlass::half_t, Traits, kM, kN, + kK, kP, kTM, kTN, kTK, kTP, + kThreads, kStagesQK, kStagesV>; + + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute( + fa_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + } + + fa_kernel<<<grid, block, shm_size, 0>>>(A, B, C, D); + + cudaDeviceSynchronize(); +} + +int main() { run(); } diff --git a/benchmarks/cpp/flashattention/reduce.cuh b/benchmarks/cpp/flashattention/reduce.cuh new file mode 100644 index 00000000..f85d34d6 --- /dev/null +++ b/benchmarks/cpp/flashattention/reduce.cuh @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
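+//
+// reduce.cuh: thread-level row reductions followed by butterfly
+// (__shfl_xor_sync) warp reductions, providing the row-wise max and sum
+// primitives used by the online-softmax update in the flash attention kernel.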
+ +#pragma once + +#include "cuda_utils.cuh" + +#include <cute/tensor.hpp> + +namespace benchmarks { +namespace cutlass_wrapper { + +using namespace cute; + +struct MaxOp_float { + DEVICE float operator()(float const& x, float const& y) { + return max(x, y); + } +}; + +template <typename T> +struct SumOp { + DEVICE T operator()(T const& x, T const& y) { return x + y; } +}; + +template <typename T> +struct SumAbsOp { + DEVICE T operator()(T const& x, T const& y) { return x + abs(y); } +}; + +template <int THREADS> +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || + THREADS == 4); + template <typename T, typename Operator> + static DEVICE T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce<OFFSET>::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template <typename T, typename Operator> + static DEVICE T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template <bool zero_init, typename Engine0, typename Layout0, typename Engine1, + typename Layout1, typename Operator> +DEVICE void thread_reduce_(cute::Tensor<Engine0, Layout0> const& tensor, + cute::Tensor<Engine1, Layout1>& summary, + Operator& op) { + using namespace cute; + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = + zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template <typename Engine0, typename Layout0, typename Engine1, + typename Layout1, typename Operator> +DEVICE void quad_allreduce_(cute::Tensor<Engine0, Layout0>& dst, + cute::Tensor<Engine1, Layout1>& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template <typename Engine0, typename Layout0, typename Engine1, + typename Layout1, typename Operator> +DEVICE void eight_allreduce_(cute::Tensor<Engine0, Layout0>& dst, + cute::Tensor<Engine1, Layout1>& src, + Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<8>::run(src(i), op); + } +} + +template <int Rthreads, typename Engine0, typename Layout0, typename Engine1, + typename Layout1, typename Operator> +DEVICE void allreduce_(cute::Tensor<Engine0, Layout0>& dst, + cute::Tensor<Engine1, Layout1>& src, Operator& op) { + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<Rthreads>::run(src(i), op); + } +} + +template <int Rthreads, bool zero_init = true, typename Engine0, + typename Layout0, typename Engine1, typename Layout1, + typename Operator> +DEVICE void reduce_(cute::Tensor<Engine0, Layout0> const& tensor, + cute::Tensor<Engine1, Layout1>& summary, Operator& op) { + thread_reduce_<zero_init>(tensor, summary, op); + allreduce_<Rthreads>(summary, summary, op); +} + +template <int Rthreads, bool zero_init = true, typename Engine0, + typename Layout0, typename Engine1, typename Layout1> +DEVICE void 
reduce_max(cute::Tensor<Engine0, Layout0> const& tensor, + cute::Tensor<Engine1, Layout1>& max) { + MaxOp_float max_op; + reduce_<Rthreads, zero_init>(tensor, max, max_op); +} + +template <int Rthreads, typename Engine0, typename Layout0, typename Engine1, + typename Layout1> +DEVICE void reduce_sum(cute::Tensor<Engine0, Layout0> const& tensor, + cute::Tensor<Engine1, Layout1>& sum) { + SumOp<float> sum_op; + reduce_<Rthreads>(tensor, sum, sum_op); +} + +template <int Rthreads, typename Engine0, typename Layout0, typename Engine1, + typename Layout1> +DEVICE void reduce_sumabs(cute::Tensor<Engine0, Layout0> const& tensor, + cute::Tensor<Engine1, Layout1>& sum) { + SumAbsOp<float> sumabs_op; + reduce_<Rthreads>(tensor, sum, sumabs_op); +} + +} // namespace cutlass_wrapper +} // namespace benchmarks diff --git a/benchmarks/cpp/flashattention/util.hpp b/benchmarks/cpp/flashattention/util.hpp new file mode 100644 index 00000000..1cc00eb4 --- /dev/null +++ b/benchmarks/cpp/flashattention/util.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "util/debug.hpp" + +#include <thrust/device_vector.h> +#include <thrust/host_vector.h> + +float rand_float(float a = 1e-1, float b = 5e-2) { + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; +} + +bool check_results(const __half* values1, const __half* values2, int numel) { + bool passed = true; + const float epsilon = 1e-1; + + for (int i = 0; i < numel; ++i) { + if (fabs(__half2float(values1[i]) - __half2float(values2[i])) > + epsilon) { + printf("%d-th value differs: %.3f vs. %.3f\n", i, + __half2float(values1[i]), __half2float(values2[i])); + passed = false; + break; + } + } + return passed; +} diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh index b3b105b8..08a4c1d0 100644 --- a/benchmarks/utils/cpp/cutlass/copy.cuh +++ b/benchmarks/utils/cpp/cutlass/copy.cuh @@ -32,6 +32,13 @@ DEVICE void __copy_async() { wait_group<0>(); } +template <int N> +DEVICE void cp_async_wait_flash() { +#if defined(CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + // Copy a 2d data tile from global memory to shared memory template <typename Element, typename SrcLayout, typename DstLayout, typename TiledCopy>