diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,14 +9,190 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = 
ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + template ; // Persistence + GemmConfig::Persistent>; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; - - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..0317685770 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -64,6 +64,7 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -83,10 +84,11 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; }; -template -struct GemmConfigComputeV3_2 : public GemmConfigBase +template +struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = 
GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)"); + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..37832b54ba 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = + grouped_gemm(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. 
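The NOTE above says the workgroups can figure out which tiles to work on from the device-side kernel-argument buffer; the mechanism for that is the per-group [block_start, block_end) lookup added as `FindGroupId` further down in this diff. A minimal, host-testable sketch of that lookup, with illustrative names (`GroupRange`, `find_group_id`) that are not ck_tile APIs:

```cpp
// Host-side sketch of the group lookup a workgroup performs: given a flat
// block id and per-group [block_start, block_end) ranges, find the owning
// group by binary search. Types/names are illustrative, not ck_tile APIs.
#include <cassert>
#include <cstdint>
#include <vector>

struct GroupRange
{
    std::int32_t block_start; // first flat block id owned by this group
    std::int32_t block_end;   // one past the last flat block id
};

inline std::int32_t find_group_id(const std::vector<GroupRange>& groups, std::int32_t block_id)
{
    std::int32_t left = 0, right = static_cast<std::int32_t>(groups.size()) - 1;
    while(left <= right)
    {
        const std::int32_t mid = (left + right) / 2;
        if(block_id < groups[mid].block_start)
            right = mid - 1;
        else if(block_id >= groups[mid].block_end)
            left = mid + 1;
        else
            return mid; // block_id falls inside this group's range
    }
    return -1; // not found (block id outside the sum of per-group tiles)
}

int main()
{
    // Three groups covering flat block ids [0,4), [4,10), [10,12).
    const std::vector<GroupRange> groups{{0, 4}, {4, 10}, {10, 12}};
    assert(find_group_id(groups, 0) == 0);
    assert(find_group_id(groups, 7) == 1);
    assert(find_group_id(groups, 11) == 2);
    return 0;
}
```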
+ std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( 
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { + return run_grouped_gemm_example_with_layouts typename GemmConfig> +template +int run_gemm_example_persistency( + std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +{ + if(persistent) + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +} + int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); if(data_type == "fp8") { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { @@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode 
== "bquant") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..726f678d37 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel auto karg = QuantGroupedGemmKernelArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].e_ptr), type_convert(gemm_descs[i].aq_ptr), type_convert(gemm_descs[i].bq_ptr), - gemm_descs[i].k_batch, + type_convert(gemm_descs[i].e_ptr), M, N, K, @@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel stride_b, stride_e, gemm_descs[i].stride_AQ, - gemm_descs[i].stride_BQ}; + gemm_descs[i].stride_BQ, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel else { - RunGemmWithPipelineSelection(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline @@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel } } + CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // 
For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + // For persistent kernels template , diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 2bd2571993..7a7ae77730 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 551989421f..6a1a28884a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,32 +18,41 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = 
std::integral_constant; +using AQuant = std::integral_constant; using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, 
F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..8dcd6d017d --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using AQuant = std::integral_constant; + +// clang-format off +using KernelTypes_AQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 4f44acf4c4..6c0ad545b7 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -20,9 +20,14 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, 
False, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp index 48720aeebf..cc1b32fb20 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index f59fa29ec2..e446f7b168 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 68b6735655..9941066c3e 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using AQLayout = Row; using BQLayout = Col; - static constexpr bool Persistent = true; static constexpr bool PreshuffleB = 
std::tuple_element_t<10, Tuple>::value; - - template - static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() - { -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif - } + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; struct GroupedGemKernelParam_Mfma { @@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t N_Warp = 2; static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = - TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 32; }; struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma @@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + template + float invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B + + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + + using QuantGroupSize = ck_tile::QuantGroupShape>; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + 
ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BQLayout, TransposeC, DoubleSmemBuffer, - true>; + Persistent>; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. 
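For contrast, the non-persistent paths earlier in this diff resolve the hot-loop / tail-number decision on the host and hand it to a generic `Run` lambda as compile-time constants via `BaseGemmPipeline::TailHandler`. A standalone sketch of that runtime-to-compile-time dispatch idiom, using made-up names (`dispatch_hot_loop`) rather than ck_tile ones:

```cpp
// Standalone sketch of the runtime -> compile-time dispatch used by the
// non-persistent paths: a runtime bool is converted into a std::bool_constant
// so the callee can instantiate a different pipeline per branch.
// Names are illustrative, not ck_tile APIs.
#include <iostream>
#include <type_traits>

template <typename F>
void dispatch_hot_loop(bool has_hot_loop, F&& run)
{
    if(has_hot_loop)
        run(std::bool_constant<true>{});  // instantiates the hot-loop variant
    else
        run(std::bool_constant<false>{}); // instantiates the short-K variant
}

int main()
{
    const int num_loop      = 3; // e.g. result of TilePartitioner::GetLoopNum(K_split)
    const bool has_hot_loop = num_loop > 2;

    dispatch_hot_loop(has_hot_loop, [](auto has_hot_loop_) {
        constexpr bool hot = decltype(has_hot_loop_)::value;
        if constexpr(hot)
            std::cout << "launch pipeline specialized for a hot loop\n";
        else
            std::cout << "launch pipeline specialized for few K iterations\n";
    });
    return 0;
}
```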
- using QuantGemmProblem = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantType == ck_tile::QuantType::RowColQuant || - QuantType == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / 128; // Group quantization: BQK = K / GroupSize - if(K % 128 != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, 0, stride_BQs[i], is_row_major(BQLayout())))); + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + 0, 0, stride_AQs[i], is_row_major(AQLayout{})))); bq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); @@ -429,11 
+637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if constexpr(Persistent) { // Generate kernel arguments std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); assert(gemm_descs[0].k_batch == 1); for(const auto& arg : gemm_descs) { @@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test } else { - GTEST_FAIL() << "Non-persistent kernel not implemented yet"; + const auto stream = ck_tile::stream_config{nullptr, false, 1}; +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#else + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#endif } // Copy results back to host for validation @@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; +template +using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; + template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant;
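The AQuantGrouped cases added above carry one float scale per (row, K-group) of A, with AQK = K / QuantGroupSize::kK and an AQ tensor of shape [M, AQK]. A small host sketch of that dequantization, assuming the scale multiplies each A element of its K-group before accumulation; the exact placement inside ck_tile::reference_gemm_quant is not shown in this diff:

```cpp
// Naive host reference for A-grouped quantization (AQuantGrouped):
// AQ has shape [M, K / group_size]; each A[m, k] is scaled by the float
// belonging to its K-group before accumulation. The scale placement is an
// assumption for illustration, not the ck_tile reference implementation.
#include <cassert>
#include <vector>

void gemm_aquant_grouped(const std::vector<float>& A,  // [M x K], already decoded from fp8
                         const std::vector<float>& AQ, // [M x (K / group_size)]
                         const std::vector<float>& B,  // [K x N]
                         std::vector<float>& C,        // [M x N]
                         int M, int N, int K, int group_size)
{
    assert(K % group_size == 0);
    const int AQK = K / group_size;
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const float scale = AQ[m * AQK + k / group_size];
                acc += (A[m * K + k] * scale) * B[k * N + n];
            }
            C[m * N + n] = acc;
        }
}

int main()
{
    // 1x1 output, K = 4, group_size = 2: two scales for the single row of A.
    const std::vector<float> A{1.f, 2.f, 3.f, 4.f};
    const std::vector<float> AQ{0.5f, 2.f};
    const std::vector<float> B{1.f, 1.f, 1.f, 1.f};
    std::vector<float> C(1);
    gemm_aquant_grouped(A, AQ, B, C, 1, 1, 4, 2);
    assert(C[0] == (1.f + 2.f) * 0.5f + (3.f + 4.f) * 2.f); // = 15.5
    return 0;
}
```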