From 0f3806661a0eef2f953d5edbefff742b3938c982 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Thu, 27 Nov 2025 15:58:40 +0000 Subject: [PATCH 01/14] wip: add aquant to grouped gemm quant example --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 52 ++++-- .../17_grouped_gemm/quant_grouped_gemm.hpp | 3 +- .../quant_run_grouped_gemm_example.inc | 170 ++++++++++++------ .../kernel/grouped_gemm_quant_kernel.hpp | 19 +- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 45 +++++ 5 files changed, 219 insertions(+), 70 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..bc7bcce34c 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" @@ -59,7 +60,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BQLayout, GemmConfig::TransposeC, GemmConfig::DoubleSmemBuffer, - true>; // Persistence + GemmConfig::Persistent>; float ave_time{0}; @@ -68,15 +69,27 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, constexpr auto memory_operation = memory_operation_.value; constexpr bool transpose_c = false; - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); /* || + run_grouped_gemm_example(argc, argv);*/ return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..554584c2bd 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -83,6 +83,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = true; }; template @@ -148,7 +149,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") .insert("init", "0", "0. Random, 2. One(s) (Constant)"); bool result = arg_parser.parse(argc, argv); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..f2d7f628a5 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,79 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = grouped_gemm_tileloop( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + assert(args[0].k_batch == 1); + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,6 +282,15 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error("K must be divisible by 128 for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization @@ -284,6 +316,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +349,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +489,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -541,6 +597,13 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "aquant") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::QuantType::AQuantGrouped>( + a_layout, b_layout, argc, argv); + } else if(quant_mode == "bquant") { return run_gemm_example_prec_type, @@ -569,6 +632,13 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "aquant") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::AQuantGrouped>( + a_layout, b_layout, argc, argv); + } else if(quant_mode == "bquant") { return run_gemm_example_prec_type, diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..032ae70f1a 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -451,7 +451,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..3671107fa1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -439,6 +439,51 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile From 5d4a91a09ba4208215ac8cff8ed44bfbae5643a6 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 28 Nov 2025 11:55:25 +0000 Subject: [PATCH 02/14] fix: properly handle hot loop count in aquant pipeline --- .../gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 3671107fa1..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Date: Fri, 28 Nov 2025 15:06:10 +0000 Subject: [PATCH 03/14] fix: add separate GemmConfig structs for AQuant, automatically select the correct one --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 8 ++- .../17_grouped_gemm/quant_grouped_gemm.hpp | 50 +++++++++++++++++++ .../quant_run_grouped_gemm_example.inc | 42 +++++----------- 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index bc7bcce34c..ed7cadc41c 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -67,7 +67,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -82,7 +81,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmShape, GemmUniversalTraits, QuantGroupSize, - transpose_c>, + GemmConfig::TransposeC>, ck_tile::GemmBQuantPipelineProblem>; @@ -161,7 +160,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, int main(int argc, char* argv[]) { - int result1 = run_grouped_gemm_example(argc, argv); /* || - run_grouped_gemm_example(argc, argv);*/ + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 554584c2bd..eb8ccd86b9 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -102,6 +102,24 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfig_Aquant : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool TransposeC = true; +}; + template struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { @@ -118,10 +136,42 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_from_preshuffled_warp_tile(); + static constexpr bool TransposeC = false; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfig_Aquant; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index f2d7f628a5..3628247f44 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -533,12 +533,13 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - using Types = GemmTypeConfig; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using GemmConfig = GemmQuantConfig::template GemmConfig; + using Types = GemmTypeConfig; // Specific type aliases for easy access using ADataType = typename Types::ADataType; using BDataType = typename Types::BDataType; @@ -567,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } -template