Skip to content

Commit

Permalink
fix R2S Storer.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Jan 2, 2025
1 parent 3eed106 commit 44fba6a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions benchmarks/cpp/flashattention/copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,10 @@ DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o,
auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma);
auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x);

auto sOs = r2s_thr_copy_o.partition_D(sO);
auto rO_copy_view = r2s_thr_copy_o.retile_S(o);
auto src = r2s_thr_copy_o.retile_S(o);
auto dst = r2s_thr_copy_o.partition_D(sO);

cute::copy(r2s_copy_o, rO_copy_view, sOs);
cute::copy(r2s_copy_o, src, dst);
}

template <typename Element, typename GOLayout, typename SOLayout,
Expand Down
9 changes: 5 additions & 4 deletions benchmarks/cpp/flashattention/cutlass_fa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct FATraits : public Base {
decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<kTM>, Int<kTP>>{}));

using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;
using StoreR2SCopyAtom = Copy_Atom<DefaultCopy, Element>;

static constexpr int kWarps = kThreads / 32;

Expand Down Expand Up @@ -121,7 +122,7 @@ __global__ void __launch_bounds__(Nthreads)
auto g2s_copy_v =
make_g2s_v<Element, typename KeTraits::GmemLayoutV,
typename KeTraits::SmemLayoutV,
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN * kP);
typename KeTraits::TiledCopyG2S>(V, sV_ptr, kTN);

auto acc0 = get_acc<kTM, kTN>(mma);
auto acco = get_acc<kTM, kTP>(mma);
Expand Down Expand Up @@ -275,9 +276,9 @@ __global__ void __launch_bounds__(Nthreads)
// auto acco_f16 = make_tensor(make_rmem_ptr<Element>(&frag2),
// acco.layout());

// store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco_f16,
// typename KeTraits::SmemCopyAtom{}, mma);
// __syncthreads();
store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco,
typename KeTraits::StoreR2SCopyAtom{}, mma);
__syncthreads();

store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{},
typename KeTraits::SmemLayoutO{},
Expand Down

0 comments on commit 44fba6a

Please sign in to comment.