diff --git a/.clang-format b/.clang-format index 54d6495..94b9090 100644 --- a/.clang-format +++ b/.clang-format @@ -5,6 +5,7 @@ ColumnLimit: 80 IndentWidth: 4 AccessModifierOffset: -2 DerivePointerAlignment: false +# If true, empty lines at the start of blocks are kept. KeepEmptyLinesAtTheStartOfBlocks: false SortIncludes: true IncludeBlocks: Regroup @@ -17,7 +18,15 @@ IncludeCategories: Priority: 2 - Regex: '"([A-Za-z0-9.\Q/-_\E])+"' Priority: 1 - + AllowShortLoopsOnASingleLine: true AllowShortIfStatementsOnASingleLine: true Cpp11BracedListStyle: true +# If true, always break after the template<...> of a template declaration. +AlwaysBreakTemplateDeclarations: true +# If false, a function call's arguments will either be all on the same line +# or will have one line each. +BinPackArguments: true +BreakConstructorInitializersBeforeComma: true +# The maximum number of consecutive empty lines to keep. +MaxEmptyLinesToKeep: 1 diff --git a/.vscode/settings.json b/.vscode/settings.json index 8a6e4a8..e09dcd8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,19 +1,20 @@ { - "files.associations": { - "array": "cpp", - "string": "cpp", - "string_view": "cpp", - "span": "cpp", - "bitset": "cpp", - "initializer_list": "cpp", - "utility": "cpp", - "*.tcc": "cpp", - "chrono": "cpp", - "random": "cpp", - "limits": "cpp", - "semaphore": "cpp" - }, - "gotoSymbolStack.currentStackPosition": 0, - "gotoSymbolStack.maxStackPosition": 0, - "gotoSymbolStack.filePositionInfo": [] + "files.associations": { + "array": "cpp", + "string": "cpp", + "string_view": "cpp", + "span": "cpp", + "bitset": "cpp", + "initializer_list": "cpp", + "utility": "cpp", + "*.tcc": "cpp", + "chrono": "cpp", + "random": "cpp", + "limits": "cpp", + "semaphore": "cpp", + "regex": "cpp" + }, + "gotoSymbolStack.currentStackPosition": 0, + "gotoSymbolStack.maxStackPosition": 0, + "gotoSymbolStack.filePositionInfo": [] } diff --git a/include/cell/copy/copy_atom.hpp b/include/cell/copy/copy_atom.hpp index b778564..ccf0de4 100644 --- a/include/cell/copy/copy_atom.hpp +++ b/include/cell/copy/copy_atom.hpp @@ -17,6 +17,44 @@ namespace tilefusion::cell::copy::atom { namespace tl = tile_layout; using namespace cute; +namespace { +template +DEVICE void async_copy(void const* g_ptr /*source*/, + uint32_t s_ptr /*destination*/) { + static_assert(size == 4 || size == 8 || size == 16); + +#if (__CUDA_ARCH__ >= 900) + // SM90, hopper + assert(false && "Not implemented yet."); +#elif (__CUDA_ARCH__ >= 800) + // SM80, SM86, ampere + // TODO(ying): add a wrapper to allow choosing between different caching + // policies (e.g. "cache all levels").
+ asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), + "l"(g_ptr), "n"(size)); +#else + // SM75, turing + unsigned tmp[size / 4]; + if constexpr (size == 16) { + asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3]) + : "l"(g_ptr)); + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), + "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), "r"(tmp[3])); + } else if constexpr (size == 8) { + asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" + : "=r"(tmp[0]), "=r"(tmp[1]) + : "l"(g_ptr)); + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), + "r"(tmp[0]), "r"(tmp[1])); + } else if constexpr (size == 4) { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr)); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0])); + } +#endif +} +} // namespace + template requires std::is_same_v || std::is_same_v @@ -28,27 +66,28 @@ struct LoadMatBase { static constexpr int kElmentBits = sizeof(DType) * 8; static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; - /// @brief returns the lane row of the current thread within a warp. - // For an example, in ldmatrix, threads in a warp are arranged as - // follows (a 16 x 2 column-major): - // - // | | 0 | 1| + /// @brief Returns the lane row of the current thread within a warp. + /// For ldmatrix, threads in a warp are arranged in a 16x2 + /// column-major layout: + /// + /// | | 0 | 1| // |--|---|---| // |0 | 0 | 16| // |1 | 2 | 17| // |2 | 4 | 18| // | |...|...| // |15| 15| 31| - // - // if threadIdx.x is 43, then its lane row is 8, lane col is 0. + /// For example, if threadIdx.x is 43, its lane_row is 11 and lane_col is 0. + + /// @brief Returns the lane row of the current thread within a warp. DEVICE int lane_row_id() { - int lane_id = threadIdx.x % warpSize; + int lane_id = threadIdx.x % WARP_SIZE; return lane_id % tl::num_rows; } /// @brief returns the lane col of the current thread within a warp. DEVICE int lane_col_id() { - int lane_id = threadIdx.x % warpSize; + int lane_id = threadIdx.x % WARP_SIZE; return lane_id / tl::num_rows; } @@ -100,7 +139,6 @@ struct BaseTileStorer { private: // the thread layout for wmma's output tile. using ThreadLayout = tile_layout::RowMajor<8, 4>; - static constexpr int kWarpSize = 32; // in the output of a wmma tile, each thread stores four segments in 2x2 // layout, and each fragment contains 2 elements regardless of the data @@ -116,11 +154,11 @@ struct BaseTileStorer { typename tl::SharedLayoutWrapper::Layout in_tile_; DEVICE int lane_row_id() { - return (threadIdx.x % kWarpSize) / tl::num_cols; + return (threadIdx.x % WARP_SIZE) / tl::num_cols; } DEVICE int lane_col_id() { - return (threadIdx.x % kWarpSize) % tl::num_cols; + return (threadIdx.x % WARP_SIZE) % tl::num_cols; } }; @@ -155,7 +193,6 @@ struct BaseTileStorer { private: // the thread layout for wmma's output tile.
using ThreadLayout = tile_layout::RowMajor<8, 4>; - static constexpr int kWarpSize = 32; // in the output of a wmma tile, each thread stores four segments in 2x2 // layout, and each fragment contains 2 elements regardless of the data @@ -171,11 +208,11 @@ struct BaseTileStorer { typename tl::SharedLayoutWrapper::Layout in_tile_; DEVICE int lane_row_id() { - return (threadIdx.x % kWarpSize) / tl::num_cols; + return (threadIdx.x % WARP_SIZE) / tl::num_cols; } DEVICE int lane_col_id() { - return (threadIdx.x % kWarpSize) % tl::num_cols; + return (threadIdx.x % WARP_SIZE) % tl::num_cols; } }; @@ -210,7 +247,6 @@ struct BaseTileStorer { private: // the thread layout for wmma's output tile. using ThreadLayout = tile_layout::ColMajor<4, 8>; - static constexpr int kWarpSize = 32; // in the output of a wmma tile, each thread stores four segments in 2x2 // layout, and each fragment contains 2 elements regardless of the data @@ -226,11 +262,11 @@ struct BaseTileStorer { typename tl::SharedLayoutWrapper::Layout in_tile_; DEVICE int lane_row_id() { - return (threadIdx.x % kWarpSize) % tl::num_rows; + return (threadIdx.x % WARP_SIZE) % tl::num_rows; } DEVICE int lane_col_id() { - return (threadIdx.x % kWarpSize) / tl::num_rows; + return (threadIdx.x % WARP_SIZE) / tl::num_rows; } }; @@ -246,10 +282,9 @@ struct BaseTileStorer { int lane_row = lane_row_id(); int lane_col = lane_col_id(); - // A base tile has a fixed shape of 16x16 (a 16x16 2D coordinate space - // with integer indices ranging from 0 to 255). `row` and `col` are used - // to calculate the index of an element within this 16x16 coordinate - // space. + // A base tile has a fixed shape of 16x16. Each thread accesses elements + // within this 16x16 coordinate space using `row` and `col` indices to + // calculate the appropriate memory offsets. int row = 0, col = 0; #pragma unroll for (int i = 0; i < kSegRows; ++i) { @@ -265,11 +300,9 @@ struct BaseTileStorer { private: // the thread layout for wmma's output tile. using ThreadLayout = tile_layout::ColMajor<4, 8>; - static constexpr int kWarpSize = 32; - // in the output of a wmma tile, each thread stores four segments in 2x2 - // layout, and each fragment contains 2 elements regardless of the data - // type + // Each thread stores four segments in a 2x2 layout in the WMMA output tile. + // Each segment contains 2 elements, regardless of the data type. static constexpr int kSegRows = 2; static constexpr int kSegCols = 2; @@ -281,102 +314,79 @@ struct BaseTileStorer { typename tl::SharedLayoutWrapper::Layout in_tile_; DEVICE int lane_row_id() { - return (threadIdx.x % kWarpSize) % tl::num_rows; + return (threadIdx.x % WARP_SIZE) % tl::num_rows; } DEVICE int lane_col_id() { - return (threadIdx.x % kWarpSize) / tl::num_rows; + return (threadIdx.x % WARP_SIZE) / tl::num_rows; } }; -template -struct GlobalToSharedBaseTileLoader; +template +struct GlobalToSharedLoaderBase; -/// @brief Implement loading a `16x16` BaseTile from global memory to shared -/// memory. -template -struct GlobalToSharedBaseTileLoader { +/// @brief Load a single BaseTile from global memory to shared memory. +template +struct GlobalToSharedLoaderBase { using DType = Shared::DType; - // NOTE: Please keep this thread layout strictly consistent with the thread - // layout for ldmatrix. - // The macro kernel breaks down the entire copy operation into iterations - // over 16x16 BaseTiles. To transfer a single BaseTile, threads in a warp - // are arranged in a 16x2 row-major layout. 
Each thread uses 128-bit data in - a single access. - using ThreadLayout = tile_layout::ColMajor<16, 2>; - static constexpr int kThreadsPerRow = tl::num_rows; - static constexpr int kThreadsPerCol = tl::num_cols; - - static constexpr int kWarpSize = 32; - + // accessed by derived classes, must be public. static constexpr int kNumPerAccess = traits::AccessBase::kNumPerAccess; + using GlobalLayout = tl::MatrixLayout; - using BaseShape = traits::BaseTileShape; - - static constexpr int kColStride = kThreadsPerCol * kNumPerAccess; - static constexpr int kExecCount = BaseShape::kCols / kColStride; - - using BaseTileGlobalLayout = - cute::Layout, Int>, - Stride, _1>>; - - using BaseTileSharedLayout = tl::SharedLayoutWrapper< - Shared, traits::AccessBase::kAccessInBits>::Layout; - -#ifdef CP_ASYNC_SM80_ENABLED - using CopyInst = - Copy_Atom, DType>; -#else - using CopyInst = Copy_Atom; -#endif - using TiledCopy = decltype(make_tiled_copy( - CopyInst{}, - cute::Layout, Int>, - Stride, _1>>{}, - cute::Layout>>{})); - - using DataLayoutPerThread = cute::Layout, _1>, - Stride<_1, Int>>; - - DEVICE GlobalToSharedBaseTileLoader() : tiled_copy_(TiledCopy{}) {} + using NonSwizzled = tl::RowMajor; + using Swizzled = + tl::SwizzledRowMajor::kAccessInBits, + WarpShape>; + using SharedLayout = + std::conditional_t; DEVICE void copy(const DType* src, DType* dst) { int offset = 0; + uint32_t s_ptr; + #pragma unroll for (int i = 0; i < kExecCount; ++i) { - auto src_tensor = - make_tensor(make_gmem_ptr(src + offset), data_layout_); - auto dst_tensor = - make_tensor(make_smem_ptr(dst + offset), data_layout_); - - cute::copy(tiled_copy_, src_tensor, dst_tensor); + s_ptr = + static_cast(__cvta_generic_to_shared(dst + offset)); + // a single memory access transfers 16 bytes + async_copy<16>(src + offset, s_ptr); offset += kColStride; } } + /// @brief returns the lane row of the current thread within a warp. DEVICE int lane_row_id() { - int lane_id = threadIdx.x % kWarpSize; - return lane_id % tl::num_rows; + // NOTE: When copying a RowMajor data tile, the thread layout is + // interpreted as RowMajor. + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id / kThreadsPerRow; } /// @brief returns the lane col of the current thread within a warp. DEVICE int lane_col_id() { - int lane_id = threadIdx.x % kWarpSize; - return lane_id / tl::num_rows; + // NOTE: When copying a RowMajor data tile, the thread layout is + // interpreted as RowMajor. + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id % kThreadsPerRow; } private: - DataLayoutPerThread data_layout_; - TiledCopy tiled_copy_; + static constexpr int kThreadsPerRow = WarpShape::kColThreads; + + static constexpr int kColStride = kThreadsPerRow * kNumPerAccess; + static constexpr int kExecCount = WarpShape::kCols / kColStride; }; -/// @brief Implement loading a `16x16` BaseTile from global memory to shared -/// memory. -template -struct GlobalToSharedBaseTileLoader { +/// @brief Load a BaseTile from global memory to shared memory.
+template +struct GlobalToSharedLoaderBase { using DType = Shared::DType; // The macro kernel breaks down the entire copy operation into iterations @@ -387,8 +397,6 @@ struct GlobalToSharedBaseTileLoader { static constexpr int kThreadsPerRow = tl::num_rows; static constexpr int kThreadsPerCol = tl::num_cols; - static constexpr int kWarpSize = 32; - static constexpr int kNumPerAccess = traits::AccessBase::kNumPerAccess; @@ -405,23 +413,6 @@ struct GlobalToSharedBaseTileLoader { using BaseTileSharedLayout = tl::SharedLayoutWrapper< Shared, traits::AccessBase::kAccessInBits>::Layout; -#ifdef CP_ASYNC_SM80_ENABLED - using CopyInst = - Copy_Atom, DType>; -#else - using CopyInst = Copy_Atom; -#endif - using TiledCopy = decltype(make_tiled_copy( - CopyInst{}, - cute::Layout, Int>, - Stride, _1>>{}, - cute::Layout, _1>>{})); - - using DataLayoutPerThread = cute::Layout, _1>, - Stride<_1, Int>>; - - DEVICE GlobalToSharedBaseTileLoader() : tiled_copy_(TiledCopy{}) {} - DEVICE void copy(const DType* src, DType* dst) { int offset = 0; #pragma unroll @@ -437,19 +428,36 @@ struct GlobalToSharedBaseTileLoader { } } - DEVICE int lane_col_id() { - int lane_id = threadIdx.x % kWarpSize; - return lane_id % tl::num_cols; + /// @brief returns the lane row of the current thread within a warp. + DEVICE int lane_row_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id / tl::num_cols; } /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_row_id() { - int lane_id = threadIdx.x % kWarpSize; - return lane_id / tl::num_cols; + DEVICE int lane_col_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id % tl::num_cols; } private: - DataLayoutPerThread data_layout_; + using DataPerThread = cute::Layout, _1>, + Stride<_1, Int>>; + + DataPerThread data_layout_; + +#ifdef CP_ASYNC_SM80_ENABLED + using CopyInst = + Copy_Atom, DType>; +#else + using CopyInst = Copy_Atom; +#endif + using TiledCopy = decltype(make_tiled_copy( + CopyInst{}, + cute::Layout, Int>, + Stride, _1>>{}, + data_layout_)); + TiledCopy tiled_copy_; }; @@ -464,8 +472,6 @@ struct SharedToGlobalBaseTileStorer { static constexpr int kThreadsPerRow = tl::num_rows; static constexpr int kThreadsPerCol = tl::num_cols; - static constexpr int kWarpSize = 32; - static constexpr int kNumPerAccess = traits::AccessBase::kNumPerAccess; @@ -516,13 +522,13 @@ struct SharedToGlobalBaseTileStorer { } DEVICE int lane_row_id() { - int lane_id = threadIdx.x % warpSize; + int lane_id = threadIdx.x % WARP_SIZE; return lane_id / tl::num_cols; } /// @brief returns the lane col of the current thread within a warp. DEVICE int lane_col_id() { - int lane_id = threadIdx.x % warpSize; + int lane_id = threadIdx.x % WARP_SIZE; return lane_id % tl::num_cols; } @@ -541,7 +547,6 @@ struct SharedToGlobalBaseTileStorer { using ThreadLayout = tile_layout::RowMajor<2, 16>; static constexpr int kThreadsPerRow = tl::num_rows; static constexpr int kThreadsPerCol = tl::num_cols; - static constexpr int kWarpSize = 32; static constexpr int kNumPerAccess = traits::AccessBase::kNumPerAccess; diff --git a/include/cell/copy/global_to_shared.hpp b/include/cell/copy/global_to_shared.hpp index 3b59ff1..00719c9 100644 --- a/include/cell/copy/global_to_shared.hpp +++ b/include/cell/copy/global_to_shared.hpp @@ -7,26 +7,43 @@ #include "traits/base.hpp" #include "types/mod.hpp" +#include + namespace tilefusion::cell::copy { using namespace atom; namespace tl = tile_layout; -template +/** + * @brief Load a warp tile from global memory to shared memory. 
+ * + * This function loads a warp tile whose shape is specified by `WarpShape` + * from global memory to shared memory. + * + * @tparam Global_ The type of the global memory pointer. + * @tparam Shared_ The type of the shared memory pointer. + * @tparam WarpShape_ The shape of the warp tile. + * @tparam kRowExec_ The number of rows to execute. + * @tparam kColExec_ The number of columns to execute. + * @tparam kType The type of the elements to be loaded. + */ +template struct GlobalToSharedLoaderImpl; -template -struct GlobalToSharedLoaderImpl - : public GlobalToSharedBaseTileLoader { +template +struct GlobalToSharedLoaderImpl + : public GlobalToSharedLoaderBase { using Global = Global_; using Shared = Shared_; using DType = Global::DType; - using LoadBase = - GlobalToSharedBaseTileLoader; - using BaseShape = traits::BaseTileShape; + using LoadBase = GlobalToSharedLoaderBase; + + using WarpShape = WarpShape_; static_assert(Global::kRows == Shared::kRows && Global::kCols == Shared::kCols, @@ -44,11 +61,12 @@ struct GlobalToSharedLoaderImpllane_row_id(); - int lane_col = this->lane_col_id() * kNumPerAccess; + int row = this->lane_row_id(); + int col = this->lane_col_id() * LoadBase::kNumPerAccess; - int src_lane_offset = src_layout_(lane_row, lane_col); - int dst_lane_offset = dst_layout_(lane_row, lane_col); + /// the pointer offset inside a warp tile. + int src_lane_offset = src_layout_(row, col); + int dst_lane_offset = dst_layout_(row, col); int src_offset = 0, dst_offset = 0; #pragma unroll @@ -64,38 +82,38 @@ struct GlobalToSharedLoaderImpl; + WarpShape::kRows * Global::kRowStride, + WarpShape::kCols>; SrcBaseTilesLayout src_base_tiles_; // a BaseTile is contiguously stored in shared memory using DstBaseTilesLayout = tl::MatrixLayout; + WarpShape::kRows * Shared::kRowStride, + WarpShape::kNumel>; DstBaseTilesLayout dst_base_tiles_; - typename LoadBase::BaseTileGlobalLayout src_layout_; - // the layout for a single BaseTile - typename LoadBase::BaseTileSharedLayout dst_layout_; + // Given a thread index, the layouts below return the data offset from which + // the thread should load from the global memory tile and where to store it + // in the shared memory tile, respectively. + typename LoadBase::GlobalLayout src_layout_; + typename LoadBase::SharedLayout dst_layout_; }; -template -struct GlobalToSharedLoaderImpl - : public GlobalToSharedBaseTileLoader { +template +struct GlobalToSharedLoaderImpl + : public GlobalToSharedLoaderBase { using Global = Global_; using Shared = Shared_; using DType = Global::DType; - using LoadBase = - GlobalToSharedBaseTileLoader; + using LoadBase = GlobalToSharedLoaderBase; static_assert(Global::kRows == Shared::kRows && Global::kCols == Shared::kCols, @@ -259,12 +277,14 @@ struct GlobalToSharedLoader { using DType = Shared::DType; using WarpLayout = WarpLayout_; - // FIXME(ying): automatically infer the warp-level tile shape instead - // of using a fixed `BaseShape`. - // using WarpShape = + // FIXME(ying): uncomment the following lines to automatically infer the + // warp-level tile shape instead of using a fixed 16x16 `BaseShape`. 
using + // WarpShape = // warp::WarpTileShape; - using WarpShape = traits::BaseTileShape; + using WarpShape = + warp::WarpTileShape, Shared::kType>; + static_assert(Shared::kRows % WarpShape::kRows == 0, "Shared::kRows must be divisible by WarpShape::kRows."); static_assert(Shared::kCols % WarpShape::kCols == 0, @@ -295,8 +315,9 @@ struct GlobalToSharedLoader { int offset_src = global_offset_.template get_warp_offset(); int offset_dst = shared_offset_.get_warp_offset(); - using Loader = GlobalToSharedLoaderImpl; + // Load a single warp tile from global memory to shared memory + using Loader = GlobalToSharedLoaderImpl; Loader loader; loader(src_ptr + offset_src, dst_ptr + offset_dst); @@ -313,17 +334,23 @@ struct SharedToGlobalStorer { using DType = Shared::DType; using WarpLayout = WarpLayout_; + using WarpShape = traits::BaseTileShape; + // FIXME(ying): automatically infer the warp-level tile shape instead // of using a fixed `BaseShape`. - using WarpShape = traits::BaseTileShape; + // using WarpShape = + // warp::WarpTileShape; + static_assert(Shared::kRows % WarpShape::kRows == 0, "Shared::kRows must be divisible by WarpShape::kRows."); static_assert(Shared::kCols % WarpShape::kCols == 0, "Shared::kCols must be divisible by WarpShape::kCols."); static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode + using SharedOffset = warp::SharedOffsetHelper; + using GlobalOffset = warp::GlobalOffsetHelper; using ExecCounter = warp::ExecCounter; diff --git a/include/cell/copy/warp.hpp b/include/cell/copy/warp.hpp index d0f2144..2555623 100644 --- a/include/cell/copy/warp.hpp +++ b/include/cell/copy/warp.hpp @@ -206,19 +206,19 @@ struct WarpTileShape { : TileLayout::kCols; // number of columns in a warp - static constexpr int kThreadPerRow = kCols / AccessInfo::kNumPerAccess; - static_assert(WARP_SIZE % kThreadPerRow == 0, + static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess; + static_assert(WARP_SIZE % kColThreads == 0, "Fail to infer warp thread layout."); - static constexpr int kThreadPerCol = WARP_SIZE / kThreadPerRow; + static constexpr int kRowThreads = WARP_SIZE / kColThreads; - static constexpr int kRows = kThreadPerCol; - static_assert(TileLayout::kRows % kThreadPerCol == 0, + static constexpr int kRows = kRowThreads; + static_assert(TileLayout::kRows % kRowThreads == 0, "The number of rows of the tile isn't evenly divisible by " "the number of threads in a column."); static constexpr int kNumel = kRows * kCols; - using WarpThreadLayout = tl::RowMajor; + using WarpThreadLayout = tl::RowMajor; }; template @@ -246,19 +246,19 @@ struct WarpTileShape { : TileLayout::kRows; // number of rows in a warp - static constexpr int kThreadPerCol = kRows / AccessInfo::kNumPerAccess; - static_assert(WARP_SIZE % kThreadPerCol == 0, + static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess; + static_assert(WARP_SIZE % kRowThreads == 0, "Fail to infer warp thread layout."); - static constexpr int kThreadPerRow = WARP_SIZE / kThreadPerCol; + static constexpr int kColThreads = WARP_SIZE / kRowThreads; - static constexpr int kCols = kThreadPerRow; - static_assert(TileLayout::kCols % kThreadPerRow == 0, + static constexpr int kCols = kColThreads; + static_assert(TileLayout::kCols % kColThreads == 0, "The number of columns of the tile isn't evenly divisible by " "the number of threads in a row."); static constexpr int kNumel = kRows * kCols; - using WarpThreadLayout = tl::ColMajor; + using WarpThreadLayout = tl::ColMajor; }; template diff --git a/include/types/layout.hpp 
b/include/types/layout.hpp index 874f962..0b11c61 100644 --- a/include/types/layout.hpp +++ b/include/types/layout.hpp @@ -19,64 +19,16 @@ namespace tilefusion::tile_layout { enum class Layout { kRowMajor = 0, kColMajor = 1 }; -HOST_DEVICE -const char* layout_type_to_str(Layout type) { - switch (type) { - case Layout::kRowMajor: - return "RowMajor"; - case Layout::kColMajor: - return "ColMajor"; - } - return "UnsupportedLayout"; -} - -namespace detail { using namespace cute; -template -struct SharedLayout { - using BaseShape = traits::BaseTileShape; - - static constexpr int kRows = kRows_; - static constexpr int kCols = kCols_; - - static constexpr int kRowStride = kRowStride_; - static constexpr int kColStride = kColStride_; - - static constexpr int kNumel = kRows * kCols; - - static constexpr Layout kType = kType_; - - DEVICE int operator()(int i, int j) const { - int tile_x = i / BaseShape::kRows; - int tile_y = j / BaseShape::kCols; - - int in_tile_x = i % BaseShape::kRows; - int in_tile_y = j % BaseShape::kCols; - - int tile_offset = tile_x * kRowStride + tile_y * kColStride; - int in_tile_offset = in_tile_(in_tile_x, in_tile_y); - - return tile_offset + in_tile_offset; - } - - private: - using BaseTileLayout = std::conditional_t< - kType == Layout::kRowMajor, - cute::Layout, Stride<_16, _1>>, /*RowMajor*/ - cute::Layout, Stride<_1, _16>>>; /*ColMajor*/ - BaseTileLayout in_tile_; -}; - -/// @brief Swizzled layout for 16x16 BaseTile. -template +/// @brief Swizzled layout for a single BaseTile. +template struct SwizzledRowMajor; /// @brief Swizzled row-major layout for storing half-typed 16x16 BaseTile. -template <> -struct SwizzledRowMajor<32> { - using BaseShape = traits::BaseTileShape<__half>; +template +struct SwizzledRowMajor<32, BaseShape> { + // using BaseShape = traits::BaseTileShape<__half>; static constexpr int kB = 2; static constexpr int kM = 3; @@ -92,17 +44,15 @@ struct SwizzledRowMajor<32> { cute::Layout, Int>, Stride, _1>>{})); - DEVICE SwizzledRowMajor() : swizzled_(SwizzledBaseTile{}) {}; - DEVICE int operator()(int i, int j) const { return swizzled_(i, j); } private: SwizzledBaseTile swizzled_; }; -template <> -struct SwizzledRowMajor<64> { - using BaseShape = traits::BaseTileShape; +template +struct SwizzledRowMajor<64, BaseShape> { + // using BaseShape = traits::BaseTileShape; static constexpr int kB = 2; static constexpr int kM = 2; @@ -116,8 +66,6 @@ struct SwizzledRowMajor<64> { cute::Layout, Int>, Stride, _1>>{})); - DEVICE SwizzledRowMajor() : swizzled_(SwizzledBaseTile{}) {}; - DEVICE int operator()(int i, int j) const { return swizzled_(i, j); } private: @@ -129,10 +77,8 @@ struct SwizzledRowMajor<64> { // 2^S columns, and each coordinate position has 2^M elements. // Therefore, to apply a swizzle function to a 2D data tile, the data // tile should have a shape that is a multiple of 2^B x 2^S x 2^M. -template <> -struct SwizzledRowMajor<128> { - using BaseShape = traits::BaseTileShape<__half>; - +template +struct SwizzledRowMajor<128, BaseShape> { static constexpr int kB = 2; static constexpr int kM = 3; static constexpr int kS = 3; @@ -148,21 +94,19 @@ struct SwizzledRowMajor<128> { using SwizzledBaseTile = decltype(composition(cute::Swizzle{}, LayoutAtom{})); - DEVICE SwizzledRowMajor() : swizzled_(SwizzledBaseTile{}) {}; - DEVICE int operator()(int i, int j) const { return swizzled_(i, j); } private: SwizzledBaseTile swizzled_; }; -/// @brief Swizzled column-major layout for 16x16 BaseTile. 
-template +/// @brief Swizzled column-major layout for a single BaseTile. +template struct SwizzledColMajor; -template <> -struct SwizzledColMajor<64> { - using BaseShape = traits::BaseTileShape<__half>; +template +struct SwizzledColMajor<64, BaseShape> { + // using BaseShape = traits::BaseTileShape<__half>; static constexpr int kB = 2; static constexpr int kM = 2; @@ -185,9 +129,9 @@ struct SwizzledColMajor<64> { SwizzledBaseTile swizzled_; }; -template <> -struct SwizzledColMajor<128> { - using BaseShape = traits::BaseTileShape<__half>; +template +struct SwizzledColMajor<128, BaseShape> { + // using BaseShape = traits::BaseTileShape<__half>; static constexpr int kB = 2; static constexpr int kM = 3; @@ -211,6 +155,54 @@ struct SwizzledColMajor<128> { SwizzledBaseTile swizzled_; }; +HOST_DEVICE +const char* layout_type_to_str(Layout type) { + switch (type) { + case Layout::kRowMajor: + return "RowMajor"; + case Layout::kColMajor: + return "ColMajor"; + } + return "UnsupportedLayout"; +} + +namespace detail { +template +struct SharedLayout { + using BaseShape = traits::BaseTileShape; + + static constexpr int kRows = kRows_; + static constexpr int kCols = kCols_; + + static constexpr int kRowStride = kRowStride_; + static constexpr int kColStride = kColStride_; + + static constexpr int kNumel = kRows * kCols; + + static constexpr Layout kType = kType_; + + DEVICE int operator()(int i, int j) const { + int tile_x = i / BaseShape::kRows; + int tile_y = j / BaseShape::kCols; + + int in_tile_x = i % BaseShape::kRows; + int in_tile_y = j % BaseShape::kCols; + + int tile_offset = tile_x * kRowStride + tile_y * kColStride; + int in_tile_offset = in_tile_(in_tile_x, in_tile_y); + + return tile_offset + in_tile_offset; + } + + private: + using BaseTileLayout = std::conditional_t< + kType == Layout::kRowMajor, + cute::Layout, Stride<_16, _1>>, /*RowMajor*/ + cute::Layout, Stride<_1, _16>>>; /*ColMajor*/ + BaseTileLayout in_tile_; +}; + template struct SharedLayoutWrapperImpl; @@ -239,7 +231,7 @@ template <> struct SharedLayoutWrapperImpl { using BaseShape = traits::BaseTileShape<__half>; - using Layout = SwizzledRowMajor<32>; + using Layout = SwizzledRowMajor<32, BaseShape>; }; /// @brief Shared memory layout for swizzled row-major layout with 16-bit data @@ -248,7 +240,7 @@ template <> struct SharedLayoutWrapperImpl { using BaseShape = traits::BaseTileShape<__half>; - using Layout = SwizzledRowMajor<64>; + using Layout = SwizzledRowMajor<64, BaseShape>; }; /// @brief Shared memory layout for swizzled col-major layout with 16-bit data @@ -257,7 +249,7 @@ template <> struct SharedLayoutWrapperImpl { using BaseShape = traits::BaseTileShape<__half>; - using Layout = SwizzledColMajor<64>; + using Layout = SwizzledColMajor<64, BaseShape>; }; /// @brief Shared memory layout for swizzled row-major layout with 16-bit data @@ -266,7 +258,7 @@ template <> struct SharedLayoutWrapperImpl { using BaseShape = traits::BaseTileShape<__half>; - using Layout = SwizzledRowMajor<128>; + using Layout = SwizzledRowMajor<128, BaseShape>; }; /// @brief Shared memory layout for swizzled col-major layout with 16-bit data @@ -275,7 +267,7 @@ template <> struct SharedLayoutWrapperImpl { using BaseShape = traits::BaseTileShape<__half>; - using Layout = SwizzledColMajor<128>; + using Layout = SwizzledColMajor<128, BaseShape>; }; /// @brief Helper for pretty printing a matrix layout's static shape-related @@ -330,7 +322,8 @@ using RowMajor = MatrixLayout; template using ColMajor = MatrixLayout; -/// @brief: Wrapper for creating 
non-swizzled or swizzled shared memory layout. +/// @brief Wrapper for creating a shared memory layout, which can be either +/// swizzled or non-swizzled based on the `Shared::kSwizzled` flag. template struct SharedLayoutWrapper { using Layout = diff --git a/tests/cpp/cell/test_g2s_load.cu b/tests/cpp/cell/test_g2s_load.cu index f75fb55..39163c5 100644 --- a/tests/cpp/cell/test_g2s_load.cu +++ b/tests/cpp/cell/test_g2s_load.cu @@ -28,6 +28,13 @@ __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr, __copy_async(); __syncthreads(); +#if defined(DEBUG) + if (thread(0)) { + printf("\nshared\n"); + inter.dump_value(); + } +#endif + storer(inter, dst); __syncthreads(); } diff --git a/tests/cpp/cell/test_layout.cu b/tests/cpp/cell/test_layout.cu index bb78f86..4f36282 100644 --- a/tests/cpp/cell/test_layout.cu +++ b/tests/cpp/cell/test_layout.cu @@ -35,7 +35,8 @@ void test_swizzled_function<__half>() { RowMajor layout1; // only siwizzle the first [16x16] half of the [kRows, kCols] matrix - using Swizzled = tl::detail::SwizzledRowMajor; + using BaseShape = traits::BaseTileShape<__half>; + using Swizzled = tl::SwizzledRowMajor; Swizzled layout2; Element* ptr = thrust::raw_pointer_cast(data.data()); @@ -74,7 +75,8 @@ void test_swizzled_function() { RowMajor layout1; // only siwizzle the first [16x16] half of the [kRows, kCols] matrix - using Swizzled = tl::detail::SwizzledRowMajor; + using BaseShape = traits::BaseTileShape<__half>; + using Swizzled = tl::SwizzledRowMajor; Swizzled layout2; for (int i = 0; i < RowMajor::kRows; ++i) {