diff --git a/.vscode/settings.json b/.vscode/settings.json
index 81386e81..8a6e4a8e 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,20 +1,19 @@
 {
-  "gotoSymbolStack.currentStackPosition": 0,
-  "gotoSymbolStack.maxStackPosition": 0,
-  "gotoSymbolStack.filePositionInfo": [],
-  "files.associations": {
-    "*.tcc": "cpp",
-    "optional": "cpp",
-    "ratio": "cpp",
-    "system_error": "cpp",
-    "array": "cpp",
-    "functional": "cpp",
-    "tuple": "cpp",
-    "type_traits": "cpp",
-    "utility": "cpp",
-    "variant": "cpp",
-    "compare": "cpp",
-    "concepts": "cpp",
-    "random": "cpp"
-  }
+    "files.associations": {
+        "array": "cpp",
+        "string": "cpp",
+        "string_view": "cpp",
+        "span": "cpp",
+        "bitset": "cpp",
+        "initializer_list": "cpp",
+        "utility": "cpp",
+        "*.tcc": "cpp",
+        "chrono": "cpp",
+        "random": "cpp",
+        "limits": "cpp",
+        "semaphore": "cpp"
+    },
+    "gotoSymbolStack.currentStackPosition": 0,
+    "gotoSymbolStack.maxStackPosition": 0,
+    "gotoSymbolStack.filePositionInfo": []
 }
diff --git a/benchmarks/cpp/flashattention/CMakeLists.txt b/benchmarks/cpp/flashattention/CMakeLists.txt
new file mode 100644
index 00000000..2b73af97
--- /dev/null
+++ b/benchmarks/cpp/flashattention/CMakeLists.txt
@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
+# MIT License.
+# --------------------------------------------------------------------------
+
+cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
+project(flash_attention_bench LANGUAGES C CXX CUDA)
+
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
+                      "${PROJECT_SOURCE_DIR}/../../../cmake")
+set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party")
+
+include(generic)
+
+include_directories("${PROJECT_SOURCE_DIR}/../../../include")
+include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp")
+include_directories("${THIRD_PARTY_DIR}/cutlass/include")
+
+add_executable(flash_attn main.cu)
diff --git a/benchmarks/cpp/flashattention/Makefile b/benchmarks/cpp/flashattention/Makefile
new file mode 100644
index 00000000..2b592215
--- /dev/null
+++ b/benchmarks/cpp/flashattention/Makefile
@@ -0,0 +1,16 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+
+BUILD_DIR := build
+
+.PHONY: build clean
+
+build:
+	@mkdir -p $(BUILD_DIR)
+	@cd $(BUILD_DIR) && cmake .. && make -j$(shell nproc)
+
+clean:
+	@rm -rf $(BUILD_DIR)
diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh
new file mode 100644
index 00000000..2480eee4
--- /dev/null
+++ b/benchmarks/cpp/flashattention/convert.cuh
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "cuda_utils.cuh"
+
+#include <cute/layout.hpp>
+#include <cute/tensor.hpp>
+#include <cutlass/numeric_conversion.h>
+
+namespace benchmarks {
+namespace cutlass_wrapper {
+
+using namespace cute;
+
+template <typename To_type, typename Engine, typename Layout>
+CUTE_DEVICE auto convert_type(cute::Tensor<Engine, Layout> const& tensor) {
+    using From_type = typename Engine::value_type;
+    constexpr int numel = decltype(size(tensor))::value;
+    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+    auto frag =
+        convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(
+            tensor.data()));
+    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+}
+
+template <typename Layout>
+DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+    using namespace cute;
+    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
+    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+    auto l = logical_divide(rowcol_layout,
+                            Shape<Underscore, Shape<Underscore, Int<2>>>{});
+
+    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)),
+                                   get<0>(get<1>(get<1>(l)))),
+                       get<1>(get<0>(l)), get<1>(get<1>(get<1>(l))));
+}
+
+DEVICE auto convert_layout_C_Aregs() {
+    using namespace cute;
+    auto layout_s = Layout<Shape<Shape<_2, _2>, _2, _16>>{};
+    auto l = logical_divide(layout_s, Shape<Underscore, Underscore, _2>{});
+
+    return make_layout(
+        make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))),
+        get<1>(l), get<1>(get<2>(l)));
+}
+
+template <class LayoutType>
+DEVICE auto convert_layout_scores(LayoutType layout_s) {
+    using namespace cute;
+    static_assert(decltype(size<0>(layout_s))::value == 4);
+    static_assert(decltype(rank(layout_s))::value == 3);
+
+    auto l = logical_divide(layout_s, Shape<_2>{});
+    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)),
+                       make_layout(get<0>(get<0>(l)), get<2>(l)));
+}
+
+template <int ATOMNUM, class LayoutType>
+DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) {
+    using namespace cute;
+
+    auto l = logical_divide(layout_s, Shape<Underscore, Int<ATOMNUM>>{});
+    return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l)));
+}
+
+}  // namespace cutlass_wrapper
+}  // namespace benchmarks
diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh
new file mode 100644
index 00000000..ced7ae92
--- /dev/null
+++ b/benchmarks/cpp/flashattention/copy.cuh
@@ -0,0 +1,634 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "cuda_utils.cuh"
+
+#include <cute/tensor.hpp>
+#include <cutlass/numeric_conversion.h>
+
+namespace benchmarks {
+namespace cutlass_wrapper {
+
+using namespace cute;
+
+namespace detail {
+
+template <typename GQTensor, typename SQTensor, typename GKTensor,
+          typename SKTensor, typename TiledCopy>
+class G2SCopyQK {
+  public:
+    DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK,
+                     TiledCopy tiled_copy, int gQ_stride, int sQ_stride,
+                     int gK_stride, int sK_stride, int num_stage = 2)
+        : gQ(gQ),
+          sQ(sQ),
+          gK(gK),
+          sK(sK),
+          gQ_stride(gQ_stride),
+          sQ_stride(sQ_stride),
+          gK_stride(gK_stride),
+          sK_stride(sK_stride),
+          cur_iter(0),
+          cur_iter_sk(0),
+          num_stage(num_stage) {}
+
+    /**
+     * @brief Update the pointer of the global K tensor.
+     *
+     * Since the K matrix is split along both the n and k dimensions, the K
+     * pointer has to be advanced to the next kTN x kK tile whenever the
+     * iteration over the n dimension moves forward.
+     *
+     * @param gK_slice The number of rows (kTN) to advance in the N dimension.
+     * @param gK_stride The row stride (kK) of the K matrix.
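+     *
+     * A worked sketch assuming the default benchmark shape (kTN = 64, kK =
+     * kTK = 128): the kernel calls `update_tile_K(64, 128)`, which moves the
+     * pointer by (64 - 1) * 128 elements, undoing the single kTK-sized
+     * advance made while loading the current tile and then jumping to the
+     * first k tile of the next kTN rows of K.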
+     */
+    DEVICE void update_tile_K(int gK_slice, int gK_stride) {
+        gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride;
+    }
+
+    /**
+     * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix is
+     * loaded into shared memory only once and never reloaded, while the K
+     * matrix still has to be advanced and reloaded on every iteration.
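+     *
+     * In `cutlass_fa.cuh` this shows up as follows: when `load_q_once` holds,
+     * the main loop calls `update_tile_K()` followed by `prologue_K()` to
+     * stage K for the next N iteration while Q stays resident in shared
+     * memory.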
+     */
+    DEVICE void prologue_K() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gK); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gK); ++k) {
+                cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+
+        gK.data() = gK.data() + gK_stride;
+        sK.data() = sK.data() + sK_stride;
+
+        if ((cur_iter_sk + 1) % num_stage == 0) {
+            sK.data() = sK.data() + (-sK_stride * num_stage);
+        }
+
+        cur_iter_sk++;
+    }
+
+    DEVICE void prologue() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gQ); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gQ); ++k) {
+                cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k));
+            }
+        }
+
+#pragma unroll
+        for (int m = 0; m < size<1>(gK); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gK); ++k) {
+                cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+
+        gQ.data() = gQ.data() + gQ_stride;
+        sQ.data() = sQ.data() + sQ_stride;
+        gK.data() = gK.data() + gK_stride;
+        sK.data() = sK.data() + sK_stride;
+
+        // Cyclically reuse the shared-memory buffer: wrap the write pointers
+        // back to the first stage after num_stage stages have been filled.
+        if ((cur_iter + 1) % num_stage == 0) {
+            sQ.data() = sQ.data() + (-sQ_stride * num_stage);
+            sK.data() = sK.data() + (-sK_stride * num_stage);
+        }
+
+        cur_iter++;
+    }
+
+    DEVICE void body() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gQ); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gQ); ++k) {
+                cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k));
+            }
+        }
+
+#pragma unroll
+        for (int m = 0; m < size<1>(gK); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gK); ++k) {
+                cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+
+        gQ.data() = gQ.data() + gQ_stride;
+        sQ.data() = sQ.data() + sQ_stride;
+        gK.data() = gK.data() + gK_stride;
+        sK.data() = sK.data() + sK_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sQ.data() = sQ.data() + (-sQ_stride * num_stage);
+            sK.data() = sK.data() + (-sK_stride * num_stage);
+        }
+
+        cur_iter++;
+    }
+
+    DEVICE void epilogue() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gQ); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gQ); ++k) {
+                cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k));
+            }
+        }
+
+#pragma unroll
+        for (int m = 0; m < size<1>(gK); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gK); ++k) {
+                cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+    }
+
+  private:
+    GQTensor& gQ;
+    SQTensor& sQ;
+    GKTensor& gK;
+    SKTensor& sK;
+    TiledCopy tiled_copy;
+    int gQ_stride;
+    int sQ_stride;
+    int gK_stride;
+    int sK_stride;
+    int cur_iter;
+    int cur_iter_sk;
+    int num_stage;
+};
+
+template <typename GVTensor, typename SVTensor, typename TiledCopy>
+class G2SCopyV {
+  public:
+    DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy,
+                    int gV_stride, int sV_stride, int num_stage = 2)
+        : gV(gV),
+          sV(sV),
+          gV_stride(gV_stride),
+          sV_stride(sV_stride),
+          cur_iter(0),
+          num_stage(num_stage) {}
+
+    DEVICE void prologue() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+        gV.data() = gV.data() + gV_stride;
+        sV.data() = sV.data() + sV_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+    }
+
+    DEVICE void body() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+
+        cute::cp_async_fence();
+
+        gV.data() = gV.data() + gV_stride;
+        sV.data() = sV.data() + sV_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+    }
+
+    DEVICE void epilogue() {
+#pragma unroll
+        for (int m = 0; m < size<1>(gV); ++m) {
+#pragma unroll
+            for (int k = 0; k < size<2>(gV); ++k) {
+                cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k));
+            }
+        }
+        cute::cp_async_fence();
+    }
+
+  private:
+    GVTensor& gV;
+    SVTensor& sV;
+    TiledCopy tiled_copy;
+    int gV_stride;
+    int sV_stride;
+    int cur_iter;
+    int num_stage;
+};
+
+template <typename SQTensor, typename RQMmaView, typename RQCopyView,
+          typename SKTensor, typename RKMmaView, typename RKCopyView,
+          typename RAccTensor, typename TiledCopyQ, typename TiledCopyK,
+          typename TiledMma>
+class S2RPipelineQK {
+  public:
+    DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view,
+                         RQCopyView& rQ_copy_view, SKTensor& sK,
+                         RKMmaView& rK_mma_view, RKCopyView& rK_copy_view,
+                         RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k,
+                         TiledMma tiled_mma, int sQ_stride, int sK_stride,
+                         int num_stage = 2)
+        : sQ(sQ),
+          rQ_mma_view(rQ_mma_view),
+          rQ_copy_view(rQ_copy_view),
+          sK(sK),
+          rK_mma_view(rK_mma_view),
+          rK_copy_view(rK_copy_view),
+          acc(acc),
+          copy_q(copy_q),
+          copy_k(copy_k),
+          tiled_mma(tiled_mma),
+          sQ_stride(sQ_stride),
+          sK_stride(sK_stride),
+          num_stage(num_stage),
+          cur_iter(0),
+          cur_iter_sq(0) {}
+
+    DEVICE void prologue() {
+        cur_iter = 0;
+        cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
+        cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rK_mma_view); ++i) {
+            if (i < size<2>(rK_mma_view) - 1) {
+                cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
+                cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));
+            }
+            cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
+                       acc);
+        }
+        sQ.data() = sQ.data() + sQ_stride;
+        sK.data() = sK.data() + sK_stride;
+
+        cur_iter++;
+    }
+
+    DEVICE void body() {
+        cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
+        cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rK_mma_view); ++i) {
+            if (i < size<2>(rK_mma_view) - 1) {
+                cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1));
+                cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
+                       acc);
+        }
+        sQ.data() = sQ.data() + sQ_stride;
+        sK.data() = sK.data() + sK_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sK.data() = sK.data() + (-sK_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sq++;
+    }
+
+    DEVICE void epilogue() {
+        cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{}));
+        cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rK_mma_view); ++i) {
+            if (i < size<2>(rK_mma_view) - 1) {
+                cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1));
+                cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i),
+                       acc);
+        }
+
+        sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq);
+        sK.data() = sK.data() + sK_stride;
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sK.data() = sK.data() + (-sK_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sq = 0;
+    }
+
+  private:
+    SQTensor& sQ;
+    RQMmaView& rQ_mma_view;
+    RQCopyView& rQ_copy_view;
+    SKTensor& sK;
+    RKMmaView& rK_mma_view;
+    RKCopyView& rK_copy_view;
+    RAccTensor& acc;
+    TiledCopyQ copy_q;
+    TiledCopyK copy_k;
+    TiledMma tiled_mma;
+    int sQ_stride;
+    int sK_stride;
+    int num_stage;
+    int cur_iter;
+    int cur_iter_sq;
+};
+
+template <typename SVTensor, typename RVMmaView, typename RVCopyView,
+          typename RegAcc, typename TiledCopy, typename TiledMma>
+class S2RPipelineV {
+  public:
+    DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view,
+                        RVCopyView& rV_copy_view, RegAcc& acc,
+                        TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride,
+                        int num_stage = 2)
+        : sV(sV),
+          rV_mma_view(rV_mma_view),
+          rV_copy_view(rV_copy_view),
+          acc(acc),
+          tiled_copy(tiled_copy),
+          sV_stride(sV_stride),
+          num_stage(num_stage),
+          cur_iter(0),
+          cur_iter_sv(0) {}
+
+    template <typename RegValue>
+    DEVICE void prologue(RegValue& value) {
+        cur_iter = 0;
+        cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
+#pragma unroll
+        for (int i = 0; i < size<2>(rV_mma_view); ++i) {
+            if (i < size<2>(rV_mma_view) - 1) {
+                cute::copy(tiled_copy, sV(_, _, i + 1),
+                           rV_copy_view(_, _, i + 1));
+            }
+            // TODO(KuangjuX):  Understand this code. Why do we need to use
+            // `value(_, _, cur_iter * size<2>(rV_mma_view) + i)`?
+            cute::gemm(tiled_mma,
+                       value(_, _, cur_iter * size<2>(rV_mma_view) + i),
+                       rV_mma_view(_, _, i), acc);
+        }
+
+        sV.data() = sV.data() + sV_stride;
+        cur_iter++;
+    }
+
+    template <typename RegValue>
+    DEVICE void body(RegValue& value) {
+        cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rV_mma_view); ++i) {
+            if (i < size<2>(rV_mma_view) - 1) {
+                cute::copy(tiled_copy, sV(_, _, i + 1),
+                           rV_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma,
+                       value(_, _, cur_iter * size<2>(rV_mma_view) + i),
+                       rV_mma_view(_, _, i), acc);
+        }
+
+        sV.data() = sV.data() + sV_stride;
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sv++;
+    }
+
+    template <typename RegValue>
+    DEVICE void epilogue(RegValue& value) {
+        cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{}));
+
+#pragma unroll
+        for (int i = 0; i < size<2>(rV_mma_view); ++i) {
+            if (i < size<2>(rV_mma_view) - 1) {
+                cute::copy(tiled_copy, sV(_, _, i + 1),
+                           rV_copy_view(_, _, i + 1));
+            }
+            cute::gemm(tiled_mma,
+                       value(_, _, cur_iter * size<2>(rV_mma_view) + i),
+                       rV_mma_view(_, _, i), acc);
+        }
+
+        sV.data() = sV.data() + (-sV_stride * cur_iter_sv);
+
+        if ((cur_iter + 1) % num_stage == 0) {
+            sV.data() = sV.data() + (-sV_stride * num_stage);
+        }
+
+        cur_iter++;
+        cur_iter_sv = 0;
+    }
+
+  private:
+    SVTensor& sV;
+    RVMmaView& rV_mma_view;
+    RVCopyView& rV_copy_view;
+    RegAcc& acc;
+    TiledCopy tiled_copy;
+    TiledMma tiled_mma;
+    int sV_stride;
+    int num_stage;
+    int cur_iter;
+    int cur_iter_sv;
+};
+
+}  // namespace detail
+
+template <typename Element, typename GlobalQLayout, typename SharedQLayout,
+          typename GlobalKLayout, typename SharedKLayout, typename TiledCopy>
+inline __device__ auto make_g2s_qk(const Element* gQ_ptr, Element* sQ_ptr,
+                                   const Element* gK_ptr, Element* sK_ptr,
+                                   int gQ_stride, int gK_stride) {
+    int tid = threadIdx.x;
+
+    auto gQ = make_tensor(make_gmem_ptr(gQ_ptr), GlobalQLayout{});
+    auto sQ = make_tensor(make_smem_ptr(sQ_ptr), SharedQLayout{});
+
+    auto gK = make_tensor(make_gmem_ptr(gK_ptr), GlobalKLayout{});
+    auto sK = make_tensor(make_smem_ptr(sK_ptr), SharedKLayout{});
+
+    TiledCopy tiled_copy;
+
+    auto loader = tiled_copy.get_thread_slice(tid);
+
+    auto gQs = loader.partition_S(gQ);
+    auto gKs = loader.partition_S(gK);
+    auto sQs = loader.partition_D(sQ);
+    auto sKs = loader.partition_D(sK);
+
+    int sQ_stride = size(sQ);
+    int sK_stride = size(sK);
+
+    if (thread0()) {
+        printf("gQ_stride: %d, sQ_stride: %d, gK_stride: %d, sK_stride: %d\n",
+               gQ_stride, sQ_stride, gK_stride, sK_stride);
+    }
+
+    detail::G2SCopyQK copy_qk(gQs, sQs, gKs, sKs, tiled_copy, gQ_stride,
+                              sQ_stride, gK_stride, sK_stride);
+
+    return copy_qk;
+}
+
+template <typename Element, typename GlobalVLayout, typename SharedVLayout,
+          typename TiledCopy>
+DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) {
+    int tid = threadIdx.x;
+
+    auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{});
+    auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{});
+
+    TiledCopy tiled_copy;
+
+    auto loader = tiled_copy.get_thread_slice(tid);
+
+    auto gVs = loader.partition_S(gV);
+    auto sVs = loader.partition_D(sV);
+
+    int sV_stride = size(sV);
+
+    if (thread0()) {
+        printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride);
+    }
+
+    detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride);
+
+    return copy_v;
+}
+
+template <typename Element, typename SQLayout, typename SKLayout,
+          typename RegAcc, typename SmemCopyAtom, typename TiledMma>
+DEVICE auto make_s2r_qk(const Element* sQ_ptr, const Element* sK_ptr,
+                        SQLayout sQ_layout, SKLayout sK_layout, RegAcc acc,
+                        SmemCopyAtom copy_atom = SmemCopyAtom{},
+                        TiledMma tiled_mma = TiledMma{}) {
+    int tid = threadIdx.x;
+
+    auto sQ_ = make_tensor(make_smem_ptr(sQ_ptr), sQ_layout);
+    auto sK_ = make_tensor(make_smem_ptr(sK_ptr), sK_layout);
+
+    auto thr_mma = tiled_mma.get_thread_slice(tid);
+
+    auto s2r_copy_q = make_tiled_copy_A(copy_atom, tiled_mma);
+    auto s2r_copy_k = make_tiled_copy_B(copy_atom, tiled_mma);
+    auto s2r_thr_copy_q = s2r_copy_q.get_thread_slice(tid);
+    auto s2r_thr_copy_k = s2r_copy_k.get_thread_slice(tid);
+
+    auto sQ = s2r_thr_copy_q.partition_S(sQ_);
+    auto sK = s2r_thr_copy_k.partition_S(sK_);
+
+    // Thread partition for mma.
+    auto rQ_mma = thr_mma.partition_fragment_A(sQ_);
+    auto rK_mma = thr_mma.partition_fragment_B(sK_);
+
+    // Thread partition for shared to register copy.
+    auto rQ_copy = s2r_thr_copy_q.retile_D(rQ_mma);
+    auto rK_copy = s2r_thr_copy_k.retile_D(rK_mma);
+
+    int sQ_stride = size(sQ_);
+    int sK_stride = size(sK_);
+
+    detail::S2RPipelineQK s2r_pipeline_qk(sQ, rQ_mma, rQ_copy, sK, rK_mma,
+                                          rK_copy, acc, s2r_copy_q, s2r_copy_k,
+                                          tiled_mma, sQ_stride, sK_stride);
+
+    return s2r_pipeline_qk;
+}
+
+template <typename Element, typename SVLayout, typename RegAcc,
+          typename SmemCopyAtom, typename TiledMma>
+DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc,
+                       SmemCopyAtom copy_atom, TiledMma tiled_mma) {
+    int tid = threadIdx.x;
+
+    auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout);
+
+    auto thr_mma = tiled_mma.get_thread_slice(tid);
+
+    auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma);
+    auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid);
+
+    auto sV = s2r_thr_copy_v.partition_S(sV_);
+
+    auto rV_mma = thr_mma.partition_fragment_B(sV_);
+    auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma);
+
+    int sV_stride = size(sV_);
+
+    detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v,
+                                        tiled_mma, sV_stride);
+
+    return s2r_pipeline_v;
+}
+
+template <typename Element, typename SOLayout, typename RegO,
+          typename SmemCopyAtom, typename TiledMma>
+DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o,
+                        SmemCopyAtom copy_atom, TiledMma tiled_mma) {
+    auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);
+
+    auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma);
+    auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x);
+
+    auto src = r2s_thr_copy_o.retile_S(o);
+    auto dst = r2s_thr_copy_o.partition_D(sO);
+
+    cute::copy(r2s_copy_o, src, dst);
+}
+
+template <typename Element, typename GOLayout, typename SOLayout,
+          typename TiledCopy>
+DEVICE auto store_s2g_o(Element* gO_ptr, const Element* sO_ptr,
+                        GOLayout gO_layout, SOLayout sO_layout,
+                        TiledCopy tiled_copy) {
+    auto gO = make_tensor(make_gmem_ptr(gO_ptr), gO_layout);
+    auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout);
+
+    auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+
+    auto gO_partition = thr_copy.partition_D(gO);
+    auto sO_partition = thr_copy.partition_S(sO);
+
+#pragma unroll
+    for (int m = 0; m < size<1>(gO_partition); ++m) {
+#pragma unroll
+        for (int n = 0; n < size<2>(gO_partition); ++n) {
+            cute::copy(tiled_copy, sO_partition(_, m, n),
+                       gO_partition(_, m, n));
+        }
+    }
+}
+
+}  // namespace cutlass_wrapper
+}  // namespace benchmarks
diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh
new file mode 100644
index 00000000..6cae18bc
--- /dev/null
+++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh
@@ -0,0 +1,294 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "convert.cuh"
+#include "copy.cuh"
+#include "cuda_utils.cuh"
+#include "cutlass/copy.cuh"
+#include "cutlass/traits_base.cuh"
+#include "reduce.cuh"
+
+namespace benchmarks {
+namespace cutlass_wrapper {
+
+using namespace cute;
+
+template <typename Element_, const int kM, const int kN, const int kK,
+          const int kP, const int kTM, const int kTN, const int kTK,
+          const int kTP, const int kWarpPerRow, const int kWarpPerCol,
+          const int SmemKAtom = 64, const int kSwizzle = 3,
+          typename Base = AccessBase<Element_>>
+struct FATraits : public Base {
+    using Element = Element_;
+
+    // Declare global to shared memory copy layout.
+    using GmemLayoutQ = Layout<Shape<Int<kTM>, Int<kTK>>, Stride<Int<kK>, _1>>;
+    using GmemLayoutK = Layout<Shape<Int<kTN>, Int<kTK>>, Stride<Int<kK>, _1>>;
+    using GmemLayoutV = Layout<Shape<Int<kTP>, Int<kTN>>, Stride<Int<kN>, _1>>;
+    using GmemLayoutO = Layout<Shape<Int<kTM>, Int<kTP>>, Stride<Int<kP>, _1>>;
+
+    static constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32;
+
+    /**
+     * Define the layout atom of shared memory, i.e. the smallest tiling unit
+     * of shared memory. Larger shapes are built by tiling this atom.
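+     *
+     * For example, with the default template parameters (SmemKAtom = 64,
+     * kSwizzle = 3) the atom is an 8 x 64 row-major tile of half-precision
+     * elements with a swizzled address pattern; SmemLayoutQ below tiles it up
+     * to the kTM x kTK block.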
+     */
+    using SmemLayoutAtom = decltype(composition(
+        Swizzle<kSwizzle, 3, 3>{},
+        Layout<Shape<_8, Int<SmemKAtom>>, Stride<Int<SmemKAtom>, _1>>{}));
+
+    using SmemLayoutQ =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTK>>{}));
+    using SmemLayoutK =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTN>, Int<kTK>>{}));
+    using SmemLayoutV =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTP>, Int<kTN>>{}));
+    using SmemLayoutO =
+        decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTP>>{}));
+
+    /**
+     * On the Ampere architecture, loading from shared memory into registers
+     * uses the `ldmatrix` instruction, while storing from registers back to
+     * shared memory has no dedicated hardware instruction and falls back to a
+     * default copy.
+     */
+    using LoadS2RCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
+    using StoreR2SCopyAtom = Copy_Atom<DefaultCopy, Element>;
+
+    static constexpr int kWarps = kThreads / 32;
+
+    using TiledMma =
+        TiledMMA<MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+                 Layout<Shape<Int<kWarpPerRow>, Int<kWarpPerCol>, _1>>,
+                 Tile<Int<16 * kWarpPerRow>, Int<16 * kWarpPerCol>, _16>>;
+
+#ifdef CP_ASYNC_SM80_ENABLED
+    // for Ampere
+    using CopyInstG2S =
+        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;
+#else
+    using CopyInstG2S = Copy_Atom<DefaultCopy, Element>;
+#endif
+
+    // TODO(KuangjuX): Understand this configuration.
+    using GmemCopyLayoutAtom =
+        Layout<Shape<Int<kThreads / (SmemKAtom / 8)>, Int<SmemKAtom / 8>>,
+               Stride<Int<SmemKAtom / 8>, _1>>;
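+    // A plausible reading (an assumption, not taken from the original design
+    // notes): with half-precision elements each thread copies 8 values, i.e.
+    // 16 bytes per access, so one SmemKAtom-wide row takes SmemKAtom / 8
+    // threads; the atom arranges all kThreads threads into
+    // kThreads / (SmemKAtom / 8) such rows, and the Layout<Shape<_1, _8>>
+    // value layout below supplies the 8 values per thread.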
+
+    using TiledCopyG2S = decltype(make_tiled_copy(
+        CopyInstG2S{}, GmemCopyLayoutAtom{}, Layout<Shape<_1, _8>>{}));
+
+    using TiledCopyS2G = decltype(make_tiled_copy(
+        Copy_Atom<DefaultCopy, Element>{}, GmemCopyLayoutAtom{},
+        Layout<Shape<_1, _8>>{}));
+};
+
+template <typename Element, typename KeTraits, const int kM, const int kN,
+          const int kK, const int kP, const int kTM, const int kTN,
+          const int kTK, const int kTP, const int Nthreads, const int kStagesQK,
+          const int kStageV>
+__global__ void __launch_bounds__(Nthreads)
+    fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
+              Element* dO) {
+    constexpr float softmax_scale = 1.250000e-01f;
+    const bool load_q_once = (kTK == kK);
+
+    extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
+    auto* buf = reinterpret_cast<Element*>(buf_);
+
+    const Element* Q = dQ + blockIdx.z * kTM * kN + blockIdx.x * kTM * kK;
+    const Element* K = dK + blockIdx.z * kK * kN;
+    const Element* V = dV + blockIdx.z * kP * kN + blockIdx.y * kTP * kN;
+    Element* O =
+        dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP;
+
+    Element* sQ_ptr = reinterpret_cast<Element*>(buf);
+    Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK;
+    Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK;
+    Element* sO_ptr = sQ_ptr;
+
+    typename KeTraits::TiledMma mma;
+    typename KeTraits::TiledCopyG2S tiled_copy_g2s;
+
+    // Build the copy plan for QK from global memory to shared memory.
+    auto g2s_copy_qk = make_g2s_qk<
+        Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ,
+        typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK,
+        typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK);
+
+    /**
+     * In FractalTensor, the size of the V matrix is [kN, kP], and the size
+     * processed in a single SM Block is [kN, kTP]. When split along the N
+     * dimension, the size is [kTN, kTP]. Therefore, the stride for global
+     * memory should be set to kTN * kP.
+     *
+     * In the current implementation, the shape of the V matrix is [kP, kN], and
+     * the block size processed by a single Block is [kTP, kN]. Therefore, the
+     * stride only needs to be set to kTN each time.
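+     *
+     * With the defaults in `main.cu` (kP = kTP = 128, kN = kTN = 64), a single
+     * thread block covers the whole [128, 64] V block and each n iteration
+     * advances the global V pointer by kTN = 64 elements.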
+     */
+    auto g2s_copy_v =
+        make_g2s_v<Element, typename KeTraits::GmemLayoutV,
+                   typename KeTraits::SmemLayoutV,
+                   typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN);
+
+    auto acc0 = get_acc<kTM, kTN>(mma);
+    auto acco = get_acc<kTM, kTP>(mma);
+
+    auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
+    auto lse_new = make_fragment_like(m_new);
+
+    auto s2r_pipeline_qk =
+        make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{},
+                    typename KeTraits::SmemLayoutK{}, acc0,
+                    typename KeTraits::LoadS2RCopyAtom{}, mma);
+
+    auto s2r_pipeline_v =
+        make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco,
+                   typename KeTraits::LoadS2RCopyAtom{}, mma);
+
+    // Issue global to shared memory copy before the main loop.
+    g2s_copy_qk.prologue();
+
+    fill(lse_new, 0.0f);
+    fill(m_new, -INFINITY);
+    clear(acco);
+
+    /**
+     * Flash Attention performs two-level tiling for each SM Block, splitting
+     * along the N dimension and the K dimension. The Q matrix is split along
+     * the K dimension, the V matrix is split along the N dimension, and the K
+     * matrix is split along both dimensions simultaneously.
+     */
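+    // With the default benchmark shape (kN = kTN = 64, kK = kTK = 128),
+    // split_n is 1 and slice_k is 0, so the inner k loop below is skipped and
+    // the epilogue issues the single QK^T tile computation.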
+    int split_n = kN / kTN;
+    for (int n = 0; n < split_n; ++n) {
+        clear(acc0);
+        int slice_k = kK / kTK - 1;
+        for (int k = 0; k < slice_k; ++k) {
+            // Barrier to ensure all data are loaded into shared memory.
+            cp_async_wait_flash<0>();
+            __syncthreads();
+            g2s_copy_qk.body();
+            // Load data from shared memory into register and issue MMA.
+            s2r_pipeline_qk.body();
+        }
+
+        cp_async_wait_flash<0>();
+        __syncthreads();
+        g2s_copy_v.prologue();
+        s2r_pipeline_qk.epilogue();
+
+        // scores = dot(q, k)
+        auto scores =
+            make_tensor(acc0.data(), convert_layout_scores(acc0.layout()));
+
+        auto m_old = make_fragment_like(m_new);
+        copy(m_new, m_old);
+
+        auto scores_max = make_fragment_like(m_new);
+
+        // scores_max = reduce_max(scores, axis=1)
+        reduce_max<4, true>(scores, scores_max);
+
+        // Compute new partial max value.
+        for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
+            m_new(ax0) = max(m_new(ax0), scores_max(ax0));
+        }
+
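+        // Standard online-softmax renormalization: partial results computed
+        // under the old running maximum m_old are rescaled below by
+        // exp((m_old - m_new) * softmax_scale) so that they stay consistent
+        // with the new running maximum m_new.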
+        auto acco_rowcol =
+            make_tensor(acco.data(), convert_layout_scores(acco.layout()));
+
+        // Renormalization of the previous block's partial results.
+        for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
+            float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
+            lse_new(ax0) = lse_new(ax0) * scale;
+            for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
+                acco_rowcol(ax0, ax1) *= scale;
+            }
+        }
+
+        for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
+            float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
+            lse_new(ax0) = lse_new(ax0) * m_scaled;
+            for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
+                scores(ax0, ax1) =
+                    exp(scores(ax0, ax1) * softmax_scale - m_scaled);
+            }
+        }
+
+        auto scores_sum = make_fragment_like(lse_new);
+        reduce_sum<4>(scores, scores_sum);
+
+        for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
+            lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
+        }
+
+        // TODO(KuangjuX): Understand the following code.
+        auto frag = convert_type<Element>(scores);
+        auto rP = make_tensor(make_rmem_ptr<Element>(&frag), scores.layout());
+        auto rP_Aregs =
+            make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout()));
+
+        /**
+         * In FractalTensor, the `kTN` dimension is split again. To simplify
+         * the current pipelined flash attention implementation, `tile_n` is
+         * hardcoded to 0 at this point.
+         */
+        const int tile_n = 0;
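+        // Because `tile_n` is 0, the loop below never executes; the single V
+        // tile staged by `g2s_copy_v` is consumed entirely by the
+        // `s2r_pipeline_v.epilogue()` call at the end of this iteration.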
+        for (int tile_ = 0; tile_ < tile_n; ++tile_) {
+            // Barrier to ensure all data are loaded into shared memory.
+            cp_async_wait_flash<0>();
+            __syncthreads();
+            g2s_copy_v.body();
+            s2r_pipeline_v.body(rP_Aregs);
+        }
+
+        cp_async_wait_flash<0>();
+        __syncthreads();
+
+        if (n < split_n - 1) {
+            /**
+             * Update the K tile, because the entire K block is processed
+             * within a single SM block.
+             *
+             * For example, in `TileFusion`:
+             * ```cpp
+             * for (int n = 0; n < GIteratorV::sc0; ++n) {
+             *      load_sv(gVs(n), sV);
+             *      for (int k = 0; k < GIteratorQ::sc1; ++k) {
+             *          load_sq(gQs(k), sQ);
+             *          load_sk(gKs(k, n), sK);
+             *      }
+             * }
+             * ```
+             */
+            g2s_copy_qk.update_tile_K(kTN, kK);
+            /**
+             * `load_q_once` means that `kK == kTK`, so Q is loaded into
+             * shared memory as a single block exactly once. In that case only
+             * the K pointer has to be updated; the Q pointer needs no update
+             * because the blocking loop along the k dimension is never
+             * executed and Q never has to be reloaded.
+             */
+            if (load_q_once) {
+                g2s_copy_qk.prologue_K();
+            }
+        }
+
+        s2r_pipeline_v.epilogue(rP_Aregs);
+    }
+
+    // Store O from registers to shared memory and then to global memory.
+    store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco,
+                typename KeTraits::StoreR2SCopyAtom{}, mma);
+    __syncthreads();
+
+    store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
+                typename KeTraits::SmemLayoutO{},
+                typename KeTraits::TiledCopyS2G{});
+}
+
+}  // namespace cutlass_wrapper
+}  // namespace benchmarks
diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu
new file mode 100644
index 00000000..bf3fb116
--- /dev/null
+++ b/benchmarks/cpp/flashattention/main.cu
@@ -0,0 +1,128 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "cutlass_fa.cuh"
+#include "util.hpp"
+
+void run(bool check = true) {
+    using InType = cutlass::half_t;
+    using AccType = cutlass::half_t;
+    using OutType = cutlass::half_t;
+
+    static constexpr int kM = 64;
+    static constexpr int kN = 64;
+    static constexpr int kK = 128;
+    static constexpr int kP = 128;
+
+    static constexpr int kTM = 64;
+    static constexpr int kTN = 64;
+    static constexpr int kTK = 128;
+    static constexpr int kTP = 128;
+
+    static constexpr int kBatch = 1;
+
+    static constexpr int kWarpPerRow = 1;
+    static constexpr int kWarpPerCol = 1;
+    static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32;
+    static constexpr int kStagesQK = 1;
+    static constexpr int kStagesV = 1;
+
+    static_assert(kK == kTK,
+                  "The current implementation requires kTK == kK for now.");
+    static_assert(kP == kTP,
+                  "The current implementation requires kTP == kP for now.");
+
+    // initialize data
+    thrust::host_vector<InType> h_a(kM * kK * kBatch);
+
+    for (int i = 0; i < h_a.size(); ++i)
+        h_a[i] = static_cast<InType>(rand_float());
+
+    thrust::host_vector<InType> h_b(kK * kN * kBatch);
+    for (int i = 0; i < h_b.size(); ++i)
+        h_b[i] = static_cast<InType>(rand_float());
+
+    thrust::host_vector<InType> h_c(kN * kP * kBatch);
+    for (int i = 0; i < h_c.size(); ++i)
+        h_c[i] = static_cast<InType>(rand_float());
+
+    thrust::host_vector<InType> h_d(kM * kP * kBatch);
+    thrust::fill(h_d.begin(), h_d.end(), 0.);
+
+    // Host side memory initialization.
+    thrust::host_vector<InType> acc(kM * kN * kBatch);
+    thrust::fill(acc.begin(), acc.end(), 0.);
+
+    thrust::host_vector<InType> exp_values(kM * kP * kBatch);
+    thrust::fill(exp_values.begin(), exp_values.end(), 0.);
+
+    thrust::host_vector<InType> h_o(kM * kP * kBatch);
+    thrust::fill(h_o.begin(), h_o.end(), 0.);
+
+    thrust::host_vector<InType> cur_row_max(kM * kBatch);
+    thrust::fill(cur_row_max.begin(), cur_row_max.end(), 0.);
+
+    thrust::host_vector<InType> prev_row_max(kM * kBatch);
+    thrust::fill(prev_row_max.begin(), prev_row_max.end(), 0.);
+
+    thrust::host_vector<InType> new_row_max(kM * kBatch);
+    thrust::fill(new_row_max.begin(), new_row_max.end(), 0.);
+
+    thrust::host_vector<InType> prev_norm_vec(kM * kBatch);
+    thrust::fill(prev_norm_vec.begin(), prev_norm_vec.end(), 0.);
+
+    thrust::host_vector<InType> new_norm_vec(kM * kBatch);
+    thrust::fill(new_norm_vec.begin(), new_norm_vec.end(), 0.);
+
+    thrust::host_vector<InType> prev_sum_vec(kM * kBatch);
+    thrust::fill(prev_sum_vec.begin(), prev_sum_vec.end(), 0.);
+
+    thrust::host_vector<InType> cur_sum_vec(kM * kBatch);
+    thrust::fill(cur_sum_vec.begin(), cur_sum_vec.end(), 0.);
+
+    thrust::host_vector<InType> new_sum_vec(kM * kBatch);
+    thrust::fill(new_sum_vec.begin(), new_sum_vec.end(), 0.);
+
+    thrust::device_vector<InType> d_a = h_a;
+    thrust::device_vector<InType> d_b = h_b;
+    thrust::device_vector<InType> d_c = h_c;
+    thrust::device_vector<InType> d_d = h_d;
+
+    const InType* A = thrust::raw_pointer_cast(d_a.data());
+    const InType* B = thrust::raw_pointer_cast(d_b.data());
+    const InType* C = thrust::raw_pointer_cast(d_c.data());
+    InType* D = thrust::raw_pointer_cast(d_d.data());
+
+    int block_x = (kM + kTM - 1) / kTM;
+    int block_y = (kP + kTP - 1) / kTP;
+    int block_z = kBatch;
+
+    dim3 grid(block_x, block_y, block_z);
+    dim3 block(kThreads, 1, 1);
+
+    int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP);
+    int shm_output = kTM * kTP;
+    int shm_size = shm_input < shm_output ? shm_output * sizeof(InType)
+                                          : shm_input * sizeof(InType);
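+    // For the tile sizes above this is (64 * 128 + 128 * 64 + 64 * 128) * 2 B
+    // = 48 KiB for the inputs versus 64 * 128 * 2 B = 16 KiB for the output,
+    // so shm_size is exactly 48 KiB and the opt-in below is not triggered.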
+
+    using Traits =
+        benchmarks::cutlass_wrapper::FATraits<cutlass::half_t, kM, kN, kK, kP,
+                                              kTM, kTN, kTK, kTP, kWarpPerRow,
+                                              kWarpPerCol>;
+
+    auto fa_kernel =
+        benchmarks::cutlass_wrapper::fa_kernel<cutlass::half_t, Traits, kM, kN,
+                                               kK, kP, kTM, kTN, kTK, kTP,
+                                               kThreads, kStagesQK, kStagesV>;
+
+    if (shm_size > 48 * 1024) {
+        cudaFuncSetAttribute(
+            fa_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
+    }
+
+    fa_kernel<<<grid, block, shm_size, 0>>>(A, B, C, D);
+
+    cudaDeviceSynchronize();
+}
+
+int main() { run(); }
diff --git a/benchmarks/cpp/flashattention/reduce.cuh b/benchmarks/cpp/flashattention/reduce.cuh
new file mode 100644
index 00000000..f85d34d6
--- /dev/null
+++ b/benchmarks/cpp/flashattention/reduce.cuh
@@ -0,0 +1,143 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "cuda_utils.cuh"
+
+#include <cute/tensor.hpp>
+
+namespace benchmarks {
+namespace cutlass_wrapper {
+
+using namespace cute;
+
+struct MaxOp_float {
+    DEVICE float operator()(float const& x, float const& y) {
+        return max(x, y);
+    }
+};
+
+template <typename T>
+struct SumOp {
+    DEVICE T operator()(T const& x, T const& y) { return x + y; }
+};
+
+template <typename T>
+struct SumAbsOp {
+    DEVICE T operator()(T const& x, T const& y) { return x + abs(y); }
+};
+
+template <int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 ||
+                  THREADS == 4);
+    template <typename T, typename Operator>
+    static DEVICE T run(T x, Operator& op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+template <>
+struct Allreduce<2> {
+    template <typename T, typename Operator>
+    static DEVICE T run(T x, Operator& op) {
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+        return x;
+    }
+};
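+
+// For example, Allreduce<4>::run(x, op) combines a value across each group of
+// four lanes with two XOR shuffles (offset 2, then 1), leaving every lane in
+// the quad holding the reduction over all four inputs; this is the
+// instantiation used by quad_allreduce_ below.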
+
+template <bool zero_init, typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1, typename Operator>
+DEVICE void thread_reduce_(cute::Tensor<Engine0, Layout0> const& tensor,
+                           cute::Tensor<Engine1, Layout1>& summary,
+                           Operator& op) {
+    using namespace cute;
+    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
+#pragma unroll
+    for (int mi = 0; mi < size<0>(tensor); mi++) {
+        summary(mi) =
+            zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0));
+#pragma unroll
+        for (int ni = 1; ni < size<1>(tensor); ni++) {
+            summary(mi) = op(summary(mi), tensor(mi, ni));
+        }
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1, typename Operator>
+DEVICE void quad_allreduce_(cute::Tensor<Engine0, Layout0>& dst,
+                            cute::Tensor<Engine1, Layout1>& src, Operator& op) {
+    using namespace cute;
+    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+#pragma unroll
+    for (int i = 0; i < size(dst); i++) {
+        dst(i) = Allreduce<4>::run(src(i), op);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1, typename Operator>
+DEVICE void eight_allreduce_(cute::Tensor<Engine0, Layout0>& dst,
+                             cute::Tensor<Engine1, Layout1>& src,
+                             Operator& op) {
+    using namespace cute;
+    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+#pragma unroll
+    for (int i = 0; i < size(dst); i++) {
+        dst(i) = Allreduce<8>::run(src(i), op);
+    }
+}
+
+template <int Rthreads, typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1, typename Operator>
+DEVICE void allreduce_(cute::Tensor<Engine0, Layout0>& dst,
+                       cute::Tensor<Engine1, Layout1>& src, Operator& op) {
+    using namespace cute;
+    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+#pragma unroll
+    for (int i = 0; i < size(dst); i++) {
+        dst(i) = Allreduce<Rthreads>::run(src(i), op);
+    }
+}
+
+template <int Rthreads, bool zero_init = true, typename Engine0,
+          typename Layout0, typename Engine1, typename Layout1,
+          typename Operator>
+DEVICE void reduce_(cute::Tensor<Engine0, Layout0> const& tensor,
+                    cute::Tensor<Engine1, Layout1>& summary, Operator& op) {
+    thread_reduce_<zero_init>(tensor, summary, op);
+    allreduce_<Rthreads>(summary, summary, op);
+}
+
+template <int Rthreads, bool zero_init = true, typename Engine0,
+          typename Layout0, typename Engine1, typename Layout1>
+DEVICE void reduce_max(cute::Tensor<Engine0, Layout0> const& tensor,
+                       cute::Tensor<Engine1, Layout1>& max) {
+    MaxOp_float max_op;
+    reduce_<Rthreads, zero_init>(tensor, max, max_op);
+}
+
+template <int Rthreads, typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1>
+DEVICE void reduce_sum(cute::Tensor<Engine0, Layout0> const& tensor,
+                       cute::Tensor<Engine1, Layout1>& sum) {
+    SumOp<float> sum_op;
+    reduce_<Rthreads>(tensor, sum, sum_op);
+}
+
+template <int Rthreads, typename Engine0, typename Layout0, typename Engine1,
+          typename Layout1>
+DEVICE void reduce_sumabs(cute::Tensor<Engine0, Layout0> const& tensor,
+                          cute::Tensor<Engine1, Layout1>& sum) {
+    SumAbsOp<float> sumabs_op;
+    reduce_<Rthreads>(tensor, sum, sumabs_op);
+}
+
+}  // namespace cutlass_wrapper
+}  // namespace benchmarks
diff --git a/benchmarks/cpp/flashattention/util.hpp b/benchmarks/cpp/flashattention/util.hpp
new file mode 100644
index 00000000..1cc00eb4
--- /dev/null
+++ b/benchmarks/cpp/flashattention/util.hpp
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "util/debug.hpp"
+
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+
+float rand_float(float a = 1e-1, float b = 5e-2) {
+    float random = ((float)rand()) / (float)RAND_MAX;
+    float diff = b - a;
+    float r = random * diff;
+    return a + r;
+}
+
+bool check_results(const __half* values1, const __half* values2, int numel) {
+    bool passed = true;
+    const float epsilon = 1e-1;
+
+    for (int i = 0; i < numel; ++i) {
+        if (fabs(__half2float(values1[i]) - __half2float(values2[i])) >
+            epsilon) {
+            printf("%d-th value differs: %.3f vs. %.3f\n", i,
+                   __half2float(values1[i]), __half2float(values2[i]));
+            passed = false;
+            break;
+        }
+    }
+    return passed;
+}
diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh
index b3b105b8..08a4c1d0 100644
--- a/benchmarks/utils/cpp/cutlass/copy.cuh
+++ b/benchmarks/utils/cpp/cutlass/copy.cuh
@@ -32,6 +32,13 @@ DEVICE void __copy_async() {
     wait_group<0>();
 }
 
+template <int N>
+DEVICE void cp_async_wait_flash() {
+#if defined(CP_ASYNC_SM80_ENABLED)
+    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+#endif
+}
+
 // Copy a 2d data tile from global memory to shared memory
 template <typename Element, typename SrcLayout, typename DstLayout,
           typename TiledCopy>