Skip to content

Commit

Permalink
improve implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Jan 8, 2025
1 parent ca5d6cb commit 69cc57f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 40 deletions.
38 changes: 14 additions & 24 deletions include/cell/copy/copy_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ struct LoadMatBase {
static constexpr int kNumPerAccess = kAccessInBits / kElmentBits;

/// @brief Returns the lane row of the current thread within a warp.
/// For ldmatrix, threads in a warp are arranged in a 16x2
/// column-major layout:
///
/// | | 0 | 1|
// For ldmatrix, threads in a warp are arranged in a 16x2
// column-major layout:
//
// | | 0 | 1|
// |--|---|---|
// |0 | 0 | 16|
// |1 | 2 | 17|
Expand Down Expand Up @@ -326,7 +326,12 @@ template <typename Global, typename Shared, typename WarpShape,
const tl::Layout kType = Shared::kType>
struct GlobalToSharedLoaderBase;

/// TODO(ying): Determine if a separate implementation is needed for copying an
/// atomic warp tile. The current implementation is quite simple, so it might be
/// possible to fold it into `GlobalToSharedLoaderImpl`.
///
/// @brief Load a single BaseTile from global memory to shared memory.
///
template <typename Global, typename Shared, typename WarpShape>
struct GlobalToSharedLoaderBase<Global, Shared, WarpShape,
tl::Layout::kRowMajor> {
Expand All @@ -346,41 +351,26 @@ struct GlobalToSharedLoaderBase<Global, Shared, WarpShape,
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;

DEVICE void copy(const DType* src, DType* dst) {
int offset = 0;
uint32_t s_ptr;

#pragma unroll
for (int i = 0; i < kExecCount; ++i) {
s_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(dst + offset));

        // a single memory access accesses 16 bytes
async_copy<16>(src + offset, s_ptr);
offset += kColStride;
}
        // a single memory access accesses 16 bytes
async_copy<16>(src,
static_cast<uint32_t>(__cvta_generic_to_shared(dst)));
}

/// @brief returns the lane row of the current thread within a warp.
DEVICE int lane_row_id() {
// NOTE: When copying a RowMajor data tile, the thread layout is
// interpreted as RowMajor.
int lane_id = threadIdx.x % WARP_SIZE;
return lane_id / kThreadsPerRow;
return lane_id / WarpShape::kColThreads;
}

/// @brief returns the lane col of the current thread within a warp.
DEVICE int lane_col_id() {
// NOTE: When copying a RowMajor data tile, the thread layout is
// interpreted as RowMajor.
int lane_id = threadIdx.x % WARP_SIZE;
return lane_id % kThreadsPerRow;
return lane_id % WarpShape::kColThreads;
}

private:
static constexpr int kThreadsPerRow = WarpShape::kColThreads;

static constexpr int kColStride = kThreadsPerRow * kNumPerAccess;
static constexpr int kExecCount = WarpShape::kCols / kColStride;
};

/// @brief Load a BaseTile from global memory to shared memory.
Expand Down
10 changes: 5 additions & 5 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include "traits/base.hpp"
#include "types/mod.hpp"

#include <cuda_runtime.h>

namespace tilefusion::cell::copy {
using namespace atom;
namespace tl = tile_layout;
Expand Down Expand Up @@ -277,11 +275,13 @@ struct GlobalToSharedLoader {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

// This implementation uses a fixed 16x16 `BaseShape` as the atomic data
// tile accessed by threads in a single warp that issues a single load/store
// instruction.
// FIXME(ying): uncomment the following lines to automatically infer the
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`. using
// WarpShape =
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;

using WarpShape =
warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;

Expand All @@ -300,8 +300,8 @@ struct GlobalToSharedLoader {
static constexpr int kColExec = ExecCounter::kColExec;

static_assert(kRowExec && kColExec,
"Ensure that the execution count for all "
"rows and columns is greater than 0.");
"Ensure that the execution count for all rows and columns is "
"greater than 0.");

template <typename Global>
DEVICE void operator()(const Global& src, Shared& dst) {
Expand Down
27 changes: 27 additions & 0 deletions include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,33 @@ struct GlobalOffsetHelper {
}
};

/**
* FIXME(ying): `kIsSharedLayout` is a temporary fix for an issue in the current
* implementation where `RowMajor` and `ColMajor` layouts are not explicitly
* distinguished between shared memory and global memory. This should be
* addressed in the future with a more robust design. The issue arises as
* follows: suppose we have a shared memory tile with a row-major layout
* declared as:
* using Shared = SharedTile<__half, RowMajor<kRows, kCols>>;
*
* In physical memory, shared memory is organized in units of a base tile,
* which is contiguous in shared memory banks and can be accessed without
* bank conflicts. This differs from global memory, where data is laid out
* contiguously with specific strides defined by the given layout.
*
* These differences are transparent to front-end users. The conflicts in the
* current implementation arise from the fact that such a shared memory layout
* can be declared by the user as above, or created internally by constructs
* like `SharedTileIterator`. When calculating the offset of a warp tile in
* shared memory or copying data, the caller should be aware of the layout of
* the shared memory tile.
*
 * `kIsSharedLayout` is a temporary fix to address this issue. When set to
 * `false`, the layout is created by the front-end user. Since the user is not
 * aware of how data is physically stored, the layout parameters (e.g.,
 * `strides`) do not correctly reveal the physical layout of data in memory.
 * This requires further special treatment.
*/
template <typename WarpLayout, typename WarpShape, typename Shared,
const WarpReuse kMode, const tl::Layout kType = Shared::kType,
const bool kIsSharedLayout = IsSharedLayout<Shared>>
Expand Down
8 changes: 0 additions & 8 deletions include/types/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ struct SwizzledRowMajor;
/// @brief Swizzled row-major layout for storing half-typed 16x16 BaseTile.
template <typename BaseShape>
struct SwizzledRowMajor<32, BaseShape> {
// using BaseShape = traits::BaseTileShape<__half>;

static constexpr int kB = 2;
static constexpr int kM = 3;
static constexpr int kS = 3;
Expand All @@ -52,8 +50,6 @@ struct SwizzledRowMajor<32, BaseShape> {

template <typename BaseShape>
struct SwizzledRowMajor<64, BaseShape> {
// using BaseShape = traits::BaseTileShape<float>;

static constexpr int kB = 2;
static constexpr int kM = 2;
static constexpr int kS = 3;
Expand Down Expand Up @@ -106,8 +102,6 @@ struct SwizzledColMajor;

template <typename BaseShape>
struct SwizzledColMajor<64, BaseShape> {
// using BaseShape = traits::BaseTileShape<__half>;

static constexpr int kB = 2;
static constexpr int kM = 2;
static constexpr int kS = 3;
Expand All @@ -131,8 +125,6 @@ struct SwizzledColMajor<64, BaseShape> {

template <typename BaseShape>
struct SwizzledColMajor<128, BaseShape> {
// using BaseShape = traits::BaseTileShape<__half>;

static constexpr int kB = 2;
static constexpr int kM = 3;
static constexpr int kS = 3;
Expand Down
11 changes: 8 additions & 3 deletions tests/cpp/cell/test_g2s_load.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,19 @@ __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr,
__copy_async();
__syncthreads();

storer(inter, dst);
__syncthreads();

#if defined(DEBUG)
if (thread(0)) {
printf("\nshared\n");
inter.dump_value();

printf("\nglobal\n");
dst.dump_value();
printf("\n");
}
#endif

storer(inter, dst);
__syncthreads();
}

template <typename Element, typename WarpLayout, const int kRows,
Expand Down Expand Up @@ -133,6 +137,7 @@ TEST(GlobalToSharedLoad, test_row_major_load) {
run_test_row_major<__half, tl::RowMajor<2, 4>, 96, 128>();

run_test_row_major<float, tl::RowMajor<1, 1>, 16, 16>();
run_test_row_major<float, tl::RowMajor<1, 2>, 16, 32>();
run_test_row_major<float, tl::RowMajor<1, 4>, 32, 128>();
run_test_row_major<float, tl::RowMajor<4, 1>, 192, 32>();
run_test_row_major<float, tl::RowMajor<2, 2>, 64, 128>();
Expand Down

0 comments on commit 69cc57f

Please sign in to comment.