diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index c3a66ba3e..81fbd9632 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; - constexpr int kBlockPerCu = 1; // This part comes from the Codegen @@ -39,11 +35,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; - using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -51,26 +42,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using TilePartitioner = ck_tile::GemmTile2DPartitioner; - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; - using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::GemmKernel; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 5d2bd2df3..fb43e6f50 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -60,9 +60,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::sequence>; using TilePartitioner = ck_tile::GemmTile2DPartitioner; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile:: TileGemmUniversalTraits; @@ -95,6 +92,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 720802236..2a1cd5825 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -19,12 +19,9 @@ template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; @@ -41,11 +38,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; - using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -53,26 +45,24 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using TilePartitioner = ck_tile::GemmTile2DPartitioner; - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; - using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::BatchedGemmKernel; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index bb4bdbf51..c32fac6c0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -20,12 +20,9 @@ namespace { struct GroupedGemmKernelParam { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - static const bool kTilePermute = false; - - static const ck_tile::index_t kOutputRank = 2; + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; static const int kBlockPerCu = 1; static const ck_tile::index_t M_Tile = 128; @@ -54,24 +51,6 @@ using CodegenGemmShape = using TilePartitioner = ck_tile::GemmTile1DPartitioner; -template -using GemmEpilogue = std::conditional_t< - std::is_same_v, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue>>; - template using CodegenGemmTraits = ck_tile::TileGemmTraits using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1>; +template +using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemmKernelParam::M_Warp, + GroupedGemmKernelParam::N_Warp, + GroupedGemmKernelParam::M_Warp_Tile, + GroupedGemmKernelParam::N_Warp_Tile, + GroupedGemmKernelParam::K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; + template using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; + GemmEpilogue>; }; // namespace std::size_t get_workspace_size(const std::vector& gemm_descs) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 01105d2a8..4aba3d7ec 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -1,194 +1,189 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" - -#define CK_TILE_MAX_RANK 5 +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" namespace ck_tile { -// this epilogue aiming to store a matrix with different layout from the shared memory to the global -// memory. template + typename CLayout_, + index_t kBlockSize_, + index_t kM_, + index_t kN_, + index_t kMWave_, + index_t kNWave_, + index_t kMPerXdl_, + index_t kNPerXdl_, + index_t kKPerXdl_, + bool isCTransposed_> struct CShuffleEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kTilePermute = kTilePermute_; - static constexpr index_t kRank = kRank_; - static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4}; - static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = { - TileSize0, TileSize1, TileSize2, TileSize3, TileSize4}; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t kMWave = kMWave_; + static constexpr index_t kNWave = kNWave_; + static constexpr index_t kMPerXdl = kMPerXdl_; + static constexpr index_t kNPerXdl = kNPerXdl_; + static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; }; template struct CShuffleEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - const index_t* kPerm = Problem::kPerm; - static constexpr bool kTilePermute = Problem::kTilePermute; - static constexpr index_t kRank = Problem::kRank; - const index_t* tile_sizes = Problem::tile_sizes; - - // No additional shared memory needed - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } - - CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t kMWave = Problem::kMWave; + static constexpr index_t kNWave = Problem::kNWave; + static constexpr index_t kMPerXdl = Problem::kMPerXdl; + static constexpr index_t kNPerXdl = Problem::kNPerXdl; + static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr index_t kMPerIteration = kMPerXdl * kMWave; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave; + + using WG = WarpGemmMfmaDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + /** + * @brief Get the vector store size for C tensor. + * + * @note The vector store size for output C tensor would depend on multiple factors + * like its data layout and warp gemm C transposition. In general it would + * be the number of consecutive elements in contiguous C dimension hold by + * single thread. + * + * @return The vector store size for C tensor. + */ + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { - // TODO: At now CShuffle doesn't allow to vector store after permute. - // It should be fixed and this function should return true. - return false; + constexpr index_t MaxVectorStoreSize = 16; + return MaxVectorStoreSize / sizeof(ODataType); } - template - CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { - using DataType = typename OAccTile::DataType; - - // Get thread buffer - auto& thread_buf = o_acc_tile.get_thread_buffer(); - - // Create a temporary buffer to hold the permuted data - thread_buffer permuted_thread_buf; - - // Get the lengths of each dimension - auto thread_tensor_lengths = o_acc_tile.get_lengths(); - - // Total number of elements - index_t total_elements = OAccTile::kThreadElementSpaceSize; - - // Iterate over all elements - for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) + // N is contiguous dimension + if constexpr(std::is_same_v) { - // Convert linear index to multi-dimensional indices - array indices; - index_t remaining = linear_idx; - static_for<0, kRank, 1>{}([&](auto i) { - constexpr auto rev_i = kRank - 1 - i; - indices(rev_i) = remaining % thread_tensor_lengths.get(number{}); - remaining /= thread_tensor_lengths.get(number{}); - }); - - // Apply the permutation - array permuted_indices; - static_for<0, kRank, 1>{}( - [&](auto i) { permuted_indices(i) = indices.get(number{}); }); - - // Compute offsets - index_t dst_offset = 0; - index_t stride = 1; - - static_for<0, kRank, 1>{}([&](auto i) { - constexpr auto rev_i = kRank - 1 - i; - dst_offset += permuted_indices[rev_i] * stride; - stride *= thread_tensor_lengths.get(number{}); - }); - - // Move the data - permuted_thread_buf(dst_offset) = thread_buf[linear_idx]; + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{})); } - - // Copy the permuted data back to the original thread buffer - for(index_t i = 0; i < total_elements; ++i) + // M is contiguous dimension + else if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number<1>{}, number{})); + } + else { - thread_buf.set_as(i, permuted_thread_buf.get(i)); + static_assert(false, "Unsupported CLayout!"); } } - template - CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) + CK_TILE_DEVICE auto + operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { - const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); - - // Compute the tile coordinates by dividing the window origin by the tile sizes - index_t tile_coords[CK_TILE_MAX_RANK] = {0}; - for(index_t i = 0; i < kRank; ++i) - { - tile_coords[i] = current_window_origin[i] / tile_sizes[i]; - // printf("The tile_coord is: %d", tile_coords[i]); - } - - // Apply the permutation to the tile coordinates - index_t permuted_tile_coords[CK_TILE_MAX_RANK]; - for(index_t i = 0; i < kRank; ++i) - { - permuted_tile_coords[i] = tile_coords[kPerm[i]]; - // printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]); - } - // Compute the permuted window origin - index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0}; - for(index_t i = 0; i < kRank; ++i) - { - permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i]; - // printf("The new permuted_window_origin is: %d", permuted_window_origin[i]); - } - - typename ODramWindowTmp::BottomTensorIndex step = {}; - for(index_t i = 0; i < kRank; ++i) - { - step[i] = permuted_window_origin[i] - current_window_origin[i]; - } + const index_t iMWarp = get_warp_id() / kNWave; + const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto o_lds_block = make_tensor_view( + static_cast(p_smem), lds_block_desc); + auto in_lds_window = + make_tile_window(o_lds_block, + make_tuple(number{}, number{}), + {number{} * iMWarp, number{} * iNWarp}); + auto out_lds_window = + make_tile_window(o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + constexpr index_t num_access = SFC::get_num_of_access(); + + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + CWarpTensor c_warp_in_tensor; + static_for<0, num_access, 1>{}([&](auto iAccess) { + constexpr auto idx_y_start = SFC::get_index(iAccess); + + constexpr auto mIter = number{}) / (kMPerXdl * kMWave)>{}; + constexpr auto nIter = number{}) / (kNPerXdl * kNWave)>{}; + + c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + const auto c_warp_in_tensor_casted = cast_tile(c_warp_in_tensor); + + block_sync_lds(); + store_tile(in_lds_window, c_warp_in_tensor_casted); + block_sync_lds(); + + const auto c_out_tensor = + load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - // Move the window - move_tile_window(o_dram_window_tmp, step); - - // Permute the data within the tile if necessary - if constexpr(kTilePermute) - { - permute_tile_data(o_acc_tile); - } - - // Store the tile data to the permuted location - if constexpr(kPadM || kPadN) - { if constexpr(out_memory_data_op == memory_operation_enum::set) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + store_tile(out_dram_window, c_out_tensor); } else { - update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + update_tile(out_dram_window, c_out_tensor); } - buffer_store_fence(); - } - else - { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(iAccess != num_access - 1) { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + constexpr auto step = SFC::get_forward_step(iAccess); + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); } - else - { - update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); - } - } + }); } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 177573de3..6e290fe6d 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" namespace ck_tile { @@ -23,6 +25,26 @@ struct Default2DEpilogueProblem static constexpr bool UseRawStore = UseRawStore_; }; +template +struct DefaultGemm2DEpilogueProblem + : public Default2DEpilogueProblem +{ + using CLayout = remove_cvref_t; + static constexpr index_t kMPerXdl = kMPerXdl_; + static constexpr index_t kNPerXdl = kNPerXdl_; + static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; +}; + template struct Default2DEpilogue { @@ -35,14 +57,13 @@ struct Default2DEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } - CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; } - // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? template - CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) + CK_TILE_DEVICE auto + operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) { // TODO: this is ugly @@ -71,4 +92,76 @@ struct Default2DEpilogue } } }; + +template +struct DefaultGemm2DEpilogue : public Default2DEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kMPerXdl = Problem::kMPerXdl; + static constexpr index_t kNPerXdl = Problem::kNPerXdl; + static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + + using WG = WarpGemmMfmaDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() + { + // N is contiguous dimension + if constexpr(std::is_same_v) + { + if constexpr(isCTransposed) + { + // In this case each thread has multiple consecutive elements in + // N dimension, however consecutive threads' elements have stride. + constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + else + { + // In this case each thread has just a single item in Ndim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + } + // M is contiguous dimension + else if constexpr(std::is_same_v) + { + if constexpr(isCTransposed) + { + // In this case each thread has just a single item in Mdim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + else + { + // In this case each thread has multiple consecutive elements in + // M dimension, however consecutive threads' elements have stride. + constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + } + else + { + static_assert(false, "Unsupported CLayout!"); + } + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 8d640831d..774736e1f 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -159,12 +159,8 @@ struct GemmKernel CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { - constexpr bool is_output_c_reg_transposed = - EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); - if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 && - std::is_same_v && - is_output_c_reg_transposed) || - !(std::is_same_v || std::is_same_v))) + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) { if(kargs.KBatch != 1) { @@ -182,7 +178,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.K % GemmPipeline::VectorSizeA != 0) + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) { std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; return false; @@ -197,7 +193,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.M % GemmPipeline::VectorSizeA != 0) + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) { std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; return false; @@ -213,7 +209,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.N % GemmPipeline::VectorSizeB != 0) + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) { std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; return false; @@ -228,7 +224,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.K % GemmPipeline::VectorSizeB != 0) + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) { std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; return false; @@ -244,7 +240,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.N % GemmPipeline::VectorSizeC != 0) + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) { std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; return false; @@ -259,7 +255,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.M % GemmPipeline::VectorSizeC != 0) + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) { std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; return false; @@ -275,14 +271,6 @@ struct GemmKernel const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { - // const auto idxs = TilePartitioner{}(); - // const auto i_m = idxs.at(number<0>{}); - // const auto i_n = idxs.at(number<1>{}); - // // options - // const ADataType* a_start = static_cast(kargs.a_ptr); - // const BDataType* b_start = static_cast(kargs.b_ptr); - // // Convert pointers to tensor views - // auto a_tensor_view = [&]() { const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -290,7 +278,7 @@ struct GemmKernel a_ptr, make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(kargs.stride_A, 1), - number{}, + number{}, number<1>{}); } else @@ -299,7 +287,7 @@ struct GemmKernel a_ptr, make_tuple(splitk_batch_offset.splitted_k, kargs.M), make_tuple(kargs.stride_A, 1), - number{}, + number{}, number<1>{}); } }(); @@ -311,7 +299,7 @@ struct GemmKernel b_ptr, make_tuple(splitk_batch_offset.splitted_k, kargs.N), make_tuple(kargs.stride_B, 1), - number{}, + number{}, number<1>{}); } else @@ -320,7 +308,7 @@ struct GemmKernel b_ptr, make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(kargs.stride_B, 1), - number{}, + number{}, number<1>{}); } }(); @@ -333,7 +321,7 @@ struct GemmKernel c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), - number{}, + number{}, number<1>{}); } else @@ -501,16 +489,13 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - constexpr bool is_output_c_reg_transposed = - EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); - if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || - (GemmPipeline::VectorSizeC % 2 == 0 && - std::is_same_v && - is_output_c_reg_transposed)) + if constexpr(DstInMemOp == memory_operation_enum::set || + !(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { EpiloguePipeline{} .template operator()( - c_block_window, c_block_tile); + c_block_window, c_block_tile, smem_ptr); } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 6acc547db..c08fe4546 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + template CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, SrcTileWindow& dram_tile_window, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 70de4014c..0bd780723 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrCompV3 static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; @@ -62,9 +64,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA(); - static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); - static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC(); + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; @@ -81,11 +83,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Policy::template IsTransposeC(); - } - template struct PipelineImpl : public PipelineImplBase { @@ -110,9 +107,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t B_LDS_Read_Width = KPerXDL; constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * VectorSizeA); + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * VectorSizeB); + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 1d6a9a0b8..38c663f4c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -113,9 +115,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA(); - static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); - static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC(); + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; @@ -133,11 +135,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Policy::template IsTransposeC(); - } - template struct PipelineImpl : public PipelineImplBase { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index ccb2f81d4..d9f04a87c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -31,21 +31,21 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; - static constexpr index_t VectorSizeA = Problem::VectorSizeA; - static constexpr index_t VectorSizeB = Problem::VectorSizeB; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } - template {}; static constexpr auto I2 = number<2>{}; - static constexpr bool TransposeC = true; - // 3d + padding template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { @@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - TransposeC>; + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy().get_element_space_size(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } - template ; using CLayout = remove_cvref_t; + static constexpr bool TransposeC = Traits::TransposeC; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr bool kPadM = Traits::kPadM; @@ -111,7 +113,6 @@ struct GemmPipelineProblemBase return kPadK ? 1 : GetAlignmentB(); } }(); - static constexpr index_t VectorSizeC = []() { if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 31a837aa4..33f105a43 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -549,12 +549,6 @@ struct UniversalGemmPipelineAgBgCrPolicy return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } - template - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Problem::TransposeC; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index ab534ffcf..047e0a293 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -29,12 +29,9 @@ class TestCkTileBatchedGemm : public ::testing::Test const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; @@ -51,11 +48,6 @@ class TestCkTileBatchedGemm : public ::testing::Test constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; - using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -63,21 +55,6 @@ class TestCkTileBatchedGemm : public ::testing::Test using TilePartitioner = ck_tile::GemmTile2DPartitioner; - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; - using CodegenGemmTraits = ck_tile::TileGemmTraits; @@ -88,6 +65,20 @@ class TestCkTileBatchedGemm : public ::testing::Test CodegenGemmTraits>; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::BatchedGemmKernel; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 147449872..647b54cb8 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include @@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::sequence>; using TilePartitioner = ck_tile::GemmTile2DPartitioner; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile:: TileGemmUniversalTraits; @@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test public: std::vector k_batches_; - void SetUp() override { k_batches_ = {1}; } + void SetUp() override { k_batches_ = {1, 2}; } template void Run(const int M, diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index a1b767d85..6b9bf0c6f 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include @@ -26,12 +26,9 @@ class TestCkTileGroupedGemm : public ::testing::Test struct GroupedGemKernelParam { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - static const bool kTilePermute = false; - - static const ck_tile::index_t kOutputRank = 2; + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; static const int kBlockPerCu = 1; static const ck_tile::index_t M_Tile = 128; @@ -60,26 +57,6 @@ class TestCkTileGroupedGemm : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; - template - using GemmEpilogue = - std::conditional_t, - ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; - template using CodegenGemmTraits = ck_tile::TileGemmTraits>; + template + using GemmEpilogue = ck_tile::CShuffleEpilogue::BlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; + template using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; + GemmEpilogue>; using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; std::size_t GetWorkspaceSize(const std::vector& gemm_descs)