diff --git a/include/cell/copy/copy_atom.hpp b/include/cell/copy/copy_atom.hpp
index ccf0de4..9a30e30 100644
--- a/include/cell/copy/copy_atom.hpp
+++ b/include/cell/copy/copy_atom.hpp
@@ -67,10 +67,10 @@ struct LoadMatBase {
     static constexpr int kNumPerAccess = kAccessInBits / kElmentBits;
 
     /// @brief Returns the lane row of the current thread within a warp.
-    /// For ldmatrix, threads in a warp are arranged in a 16x2
-    /// column-major layout:
-    ///
-    /// |  | 0 | 1|
+    // For ldmatrix, threads in a warp are arranged in a 16x2
+    // column-major layout:
+    //
+    // |  | 0 | 1|
     // |--|---|---|
     // |0 | 0 | 16|
     // |1 | 2 | 17|
@@ -326,7 +326,12 @@ template 
 struct GlobalToSharedLoaderBase;
 
+/// TODO(ying): Determine if a separate implementation is needed for copying an
+/// atomic warp tile. The current implementation is quite simple, so it might be
+/// possible to simplify it into `GlobalToSharedLoaderImpl`.
+///
 /// @brief Load a single BaseTile from global memory to shared memory.
+///
 template 
 struct GlobalToSharedLoaderBase {
@@ -346,18 +351,9 @@ struct GlobalToSharedLoaderBase;
     DEVICE void copy(const DType* src, DType* dst) {
-        int offset = 0;
-        uint32_t s_ptr;
-
-#pragma unroll
-        for (int i = 0; i < kExecCount; ++i) {
-            s_ptr =
-                static_cast<uint32_t>(__cvta_generic_to_shared(dst + offset));
-
-            // a single memory access access 16 bytes
-            async_copy<16>(src + offset, s_ptr);
-            offset += kColStride;
-        }
+        // a single memory access accesses 16 bytes
+        async_copy<16>(src,
+                       static_cast<uint32_t>(__cvta_generic_to_shared(dst)));
     }
 
     /// @brief returns the lane row of the current thread within a warp.
@@ -365,7 +361,7 @@ struct GlobalToSharedLoaderBase
-
 namespace tilefusion::cell::copy {
 using namespace atom;
 namespace tl = tile_layout;
@@ -277,11 +275,13 @@ struct GlobalToSharedLoader {
     using DType = Shared::DType;
     using WarpLayout = WarpLayout_;
 
+    // This implementation uses a fixed 16x16 `BaseShape` as the atomic data
+    // tile accessed by threads in a single warp that issues a single
+    // load/store instruction.
     // FIXME(ying): uncomment the following lines to automatically infer the
    // warp-level tile shape instead of using a fixed 16x16 `BaseShape`. using
     // WarpShape =
     //     warp::WarpTileShape;
-
     using WarpShape = warp::WarpTileShape, Shared::kType>;
@@ -300,8 +300,8 @@ struct GlobalToSharedLoader {
     static constexpr int kColExec = ExecCounter::kColExec;
 
     static_assert(kRowExec && kColExec,
-                  "Ensure that the execution count for all "
-                  "rows and columns is greater than 0.");
+                  "Ensure that the execution count for all rows and columns is "
+                  "greater than 0.");
 
     template 
     DEVICE void operator()(const Global& src, Shared& dst) {
diff --git a/include/cell/copy/warp.hpp b/include/cell/copy/warp.hpp
index 2555623..efd5c44 100644
--- a/include/cell/copy/warp.hpp
+++ b/include/cell/copy/warp.hpp
@@ -292,6 +292,33 @@ struct GlobalOffsetHelper {
     }
 };
 
+/**
+ * FIXME(ying): `kIsSharedLayout` is a temporary fix for an issue in the current
+ * implementation where `RowMajor` and `ColMajor` layouts are not explicitly
+ * distinguished between shared memory and global memory. This should be
+ * addressed in the future with a more robust design. The issue arises as
+ * follows: suppose we have a shared memory tile with a row-major layout
+ * declared as:
+ *     using Shared = SharedTile<__half, RowMajor>;
+ *
+ * In physical memory, shared memory is organized in units of a base tile,
+ * which is contiguous in shared memory banks and can be accessed without
+ * bank conflicts.
+ * This differs from global memory, where data is laid out
+ * contiguously with specific strides defined by the given layout.
+ *
+ * These differences are transparent to front-end users. The conflicts in the
+ * current implementation arise from the fact that such a shared memory layout
+ * can be declared by the user as above, or created internally by constructs
+ * like `SharedTileIterator`. When calculating the offset of a warp tile in
+ * shared memory or copying data, the caller should be aware of the layout of
+ * the shared memory tile.
+ *
+ * `kIsSharedLayout` is a temporary fix to address this issue. When set to
+ * `false`, the layout was declared by the front-end user; since the user is
+ * not aware of how data is physically stored, the layout parameters (e.g.,
+ * `strides`) do not correctly reveal the physical layout of data in memory.
+ * This requires further special treatment.
+ */
 template >
diff --git a/include/types/layout.hpp b/include/types/layout.hpp
index 0b11c61..2963c7f 100644
--- a/include/types/layout.hpp
+++ b/include/types/layout.hpp
@@ -28,8 +28,6 @@ struct SwizzledRowMajor;
 /// @brief Swizzled row-major layout for storing half-typed 16x16 BaseTile.
 template 
 struct SwizzledRowMajor<32, BaseShape> {
-    // using BaseShape = traits::BaseTileShape<__half>;
-
     static constexpr int kB = 2;
     static constexpr int kM = 3;
     static constexpr int kS = 3;
@@ -52,8 +50,6 @@ template 
 struct SwizzledRowMajor<64, BaseShape> {
-    // using BaseShape = traits::BaseTileShape;
-
     static constexpr int kB = 2;
     static constexpr int kM = 2;
     static constexpr int kS = 3;
@@ -106,8 +102,6 @@ struct SwizzledColMajor;
 template 
 struct SwizzledColMajor<64, BaseShape> {
-    // using BaseShape = traits::BaseTileShape<__half>;
-
     static constexpr int kB = 2;
     static constexpr int kM = 2;
     static constexpr int kS = 3;
@@ -131,8 +125,6 @@ struct SwizzledColMajor<64, BaseShape> {
 template 
 struct SwizzledColMajor<128, BaseShape> {
-    // using BaseShape = traits::BaseTileShape<__half>;
-
     static constexpr int kB = 2;
     static constexpr int kM = 3;
     static constexpr int kS = 3;
diff --git a/tests/cpp/cell/test_g2s_load.cu b/tests/cpp/cell/test_g2s_load.cu
index 39163c5..58e6e72 100644
--- a/tests/cpp/cell/test_g2s_load.cu
+++ b/tests/cpp/cell/test_g2s_load.cu
@@ -28,15 +28,19 @@ __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr,
     __copy_async();
     __syncthreads();
 
+    storer(inter, dst);
+    __syncthreads();
+
 #if defined(DEBUG)
     if (thread(0)) {
         printf("\nshared\n");
         inter.dump_value();
+
+        printf("\nglobal\n");
+        dst.dump_value();
+        printf("\n");
     }
 #endif
-
-    storer(inter, dst);
-    __syncthreads();
 }
 
 template , 96, 128>();
     run_test_row_major, 16, 16>();
+    run_test_row_major, 16, 32>();
     run_test_row_major, 32, 128>();
     run_test_row_major, 192, 32>();
     run_test_row_major, 64, 128>();
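Note (reviewer sketch, not part of the patch): the simplified `GlobalToSharedLoaderBase::copy` above reduces to converting the shared-memory destination with `__cvta_generic_to_shared` and issuing a single 16-byte asynchronous copy. The snippet below shows one common way such an `async_copy<16>` helper is written with inline PTX `cp.async` on sm_80+ hardware; the names `async_copy_sketch` and `copy_16B` are hypothetical, and the real helper in `copy_atom.hpp` may be implemented differently.

#include <cstdint>
#include <cuda_fp16.h>

// Hypothetical helper mirroring what `async_copy<16>` is assumed to do:
// issue a non-blocking 16-byte global-to-shared copy with cp.async (sm_80+).
template <int kBytes>
__device__ __forceinline__ void async_copy_sketch(const void* global_src,
                                                   uint32_t shared_dst) {
    static_assert(kBytes == 16, "cp.async.cg supports only 16-byte transfers");
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(shared_dst),
                 "l"(global_src), "n"(kBytes));
}

// Usage analogous to the simplified copy() above: one thread moves one
// 16-byte chunk (8 half-precision elements) per call.
__device__ void copy_16B(const __half* src, __half* dst_in_shared) {
    async_copy_sketch<16>(
        src, static_cast<uint32_t>(__cvta_generic_to_shared(dst_in_shared)));
}

The data only becomes visible in shared memory after the asynchronous copy group is committed and waited on; in test_g2s_load.cu this presumably happens through `__copy_async()` followed by `__syncthreads()`, which is why `storer(inter, dst)` runs only after that synchronization point.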