diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
index d97145cbc3..628e9194ae 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase
     using ADataType       = remove_cvref_t<typename Problem::ADataType>;
     using BDataType       = remove_cvref_t<typename Problem::BDataType>;
     using BQDataType      = remove_cvref_t<typename Problem::BQDataType>;
+    using BQLayout        = remove_cvref_t<typename Problem::BQLayout>;
     using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
     using CDataType       = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape  = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase
         using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
         using CDataType       = remove_cvref_t<typename Problem::CDataType>;
 
+        // BDataType gets converted from PkInt4 during loading
+        using OverrideBDataType =
+            std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
+
         using Base     = BlockGemmBQuantBase<Problem>;
         using WarpGemm = remove_cvref_t<typename Base::WarpGemm>;
@@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase
         ALdsTile a_warp_tile_;
         BLdsTile b_warp_tile_;
 
-        template <typename ASmemBlockWindow, typename BSmemBlockWindow>
+        template <typename ASmemBlockWindow,
+                  typename BSmemBlockWindow,
+                  bool ALoadTr = false,
+                  bool BLoadTr = false>
         CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
-                                          const BSmemBlockWindow& b_block_window)
+                                          const BSmemBlockWindow& b_block_window,
+                                          bool_constant<ALoadTr> = {},
+                                          bool_constant<BLoadTr> = {})
         {
-            load_int4_tile(a_warp_tile_, a_block_window);
-            load_int4_tile(b_warp_tile_, b_block_window);
+            load_int4_tile<ALoadTr>(
+                a_warp_tile_, a_block_window);
+            // If BDataType is PkInt4 it has already been converted before being stored in LDS
+            load_int4_tile<BLoadTr>(
+                b_warp_tile_, b_block_window);
         }
 
         // C += A * B
@@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase
             MakeCBlockTile();
         }
 
-        template <typename ASmemBlockWindow, typename BSmemBlockWindow>
+        template <typename ASmemBlockWindow,
+                  typename BSmemBlockWindow,
+                  bool ALoadTr = false,
+                  bool BLoadTr = false>
         CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
-                                          const BSmemBlockWindow& b_block_window)
+                                          const BSmemBlockWindow& b_block_window,
+                                          bool_constant<ALoadTr> a_load_tr = {},
+                                          bool_constant<BLoadTr> b_load_tr = {})
         {
-            block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
+            block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
         }
 
         // C += A * B
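Note on the OverrideBDataType alias added above: when B arrives as packed int4, everything downstream of the global load (LDS staging, shuffles, the warp GEMM) operates on the A element type instead. A minimal compile-time sketch of that selection; fp8_t and pk_int4_t here are stand-in names, not the ck_tile definitions:

```cpp
#include <type_traits>

struct fp8_t {};     // stand-in for the A/compute element type
struct pk_int4_t {}; // stand-in for two 4-bit values packed per byte

// Mirrors the alias in the diff: B keeps its own type unless it is packed
// int4, in which case the wider A type is used for all register/LDS tiles.
template <typename ADataType, typename BDataType>
using OverrideBDataType =
    std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;

static_assert(std::is_same_v<OverrideBDataType<fp8_t, pk_int4_t>, fp8_t>);
static_assert(std::is_same_v<OverrideBDataType<fp8_t, fp8_t>, fp8_t>);
```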
diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
index dd85705cf2..203b79aec6 100644
--- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
+++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
@@ -426,7 +426,6 @@ struct QuantGemmKernel
         if constexpr(kQuantType == QuantType::BQuantGrouped)
         {
-            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
             if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
             {
                 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -781,7 +780,9 @@ struct QuantGemmKernel
         {
             if constexpr(PreshuffleQuant)
             {
-                static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
+                static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
+                              "PreshuffleQuant with BQuantGrouped currently only supports "
+                              "ColumnMajor BQ layout");
 
                 return MakePreshuffledQuantTensorView<
                     GemmPipeline::KPerBlockBQ,
@@ -791,14 +792,35 @@ struct QuantGemmKernel
             }
             else
             {
-                static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
                 using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
-                return make_naive_tensor_view<address_space_enum::global>(
-                    bq_ptr,
-                    make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
-                    make_tuple(kargs.stride_BQ, 1),
-                    number<GemmPipeline::GetVectorSizeBQ()>{},
-                    number<1>{});
+
+                if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
+                    // Dimensions: [K/QuantGroupK, N/QuantGroupN]
+                    // Strides:    [N/QuantGroupN, 1]
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        bq_ptr,
+                        make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
+                                   integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
+                        make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
+                        number<GemmPipeline::GetVectorSizeBQ()>{},
+                        number<1>{});
+                }
+                else
+                {
+                    static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
+                    // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
+                    // Dimensions: [N/QuantGroupN, K/QuantGroupK]
+                    // Strides:    [K/QuantGroupK, 1]
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        bq_ptr,
+                        make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
+                                   integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
+                        make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
+                        number<GemmPipeline::GetVectorSizeBQ()>{},
+                        number<1>{});
+                }
             }
         }
         else
@@ -1023,10 +1045,10 @@ struct QuantGemmKernel
         else if constexpr(kQuantType == QuantType::BQuantGrouped)
         {
+            using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
             if constexpr(PreshuffleQuant)
             {
                 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
-                using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
                 constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
                 constexpr auto warp_n  = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
                 constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
@@ -1042,13 +1064,23 @@ struct QuantGemmKernel
             }
             else
             {
-                static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
-                using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
-                return make_tile_window(
-                    bq_pad_view,
-                    make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
-                               number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
-                    {i_n / QuantGroupSize::kN, 0});
+                if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
+                {
+                    return make_tile_window(
+                        bq_pad_view,
+                        make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
+                                   number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
+                        {0, i_n / QuantGroupSize::kN});
+                }
+                else
+                {
+                    static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
+                    return make_tile_window(
+                        bq_pad_view,
+                        make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
+                                   number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
+                        {i_n / QuantGroupSize::kN, 0});
+                }
             }
         }
         else
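To make the two BQ storage orders concrete, this host-side sketch evaluates the dimensions and strides that the kernel's tensor view builds for each layout; the problem sizes and group shape are made-up example values:

```cpp
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int N = 4096, K = 2048; // GEMM extents (example values)
    constexpr int kN = 32, kK = 128;  // QuantGroupSize::kN / ::kK (example values)

    // RowMajor BQ: scales laid out [K/kK][N/kN], row stride = N/kN
    std::printf("RowMajor:    dims [%d, %d], strides [%d, 1]\n",
                ceil_div(K, kK), ceil_div(N, kN), ceil_div(N, kN));

    // ColumnMajor BQ: scales laid out [N/kN][K/kK], row stride = K/kK
    std::printf("ColumnMajor: dims [%d, %d], strides [%d, 1]\n",
                ceil_div(N, kN), ceil_div(K, kK), ceil_div(K, kK));
}
```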
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp
index 4cd343e640..c570d4a131 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp
@@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase
-        static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
-
-        using YPerTile = number<NPerBlockBQ>;
-        using XPerTile = number<KPerBlockBQ>;
+        using YPerTile =
+            std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
+                               number<NPerBlockBQ>,
+                               number<KPerBlockBQ>>;
+        using XPerTile =
+            std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
+                               number<KPerBlockBQ>,
+                               number<NPerBlockBQ>>;
 
         auto bq_copy_dram_window =
             make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
-                             make_tuple(YPerTile(), XPerTile()),
+                             make_tuple(YPerTile{}, XPerTile{}),
                              bq_dram_block_window_tmp.get_window_origin(),
                              Policy::template MakeBQDramTileDistribution<Problem>());
         return bq_copy_dram_window;
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp
index 870326cb9d..154d068f0a 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp
@@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
         constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
         constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
 
-        static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
-        return GetABQGlobalVectorLoadSize();
+        // Support both RowMajor and ColumnMajor layouts for BQ
+        if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
+        {
+            return GetABQGlobalVectorLoadSize();
+        }
+        else
+        {
+            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
+            return GetABQGlobalVectorLoadSize();
+        }
     }
 
     template <typename Problem>
@@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
                                               WarpTile::at(I2),
                                               Problem::TransposeC>;
 
-        static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
         if constexpr(PreshuffleQuant)
         {
             using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
@@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
                 NPerBlock / WarpGemm::kN,
                 ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
                 VecLoadSize,
+                BQLayout,
                 PreshuffleQuant>;
             return TileEncodingPattern::make_2d_static_tile_distribution();
        }
        else
        {
+            // KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups)
            using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
-                KPerBlockBQ,
-                NPerBlockBQ,
-                Problem::QuantGroupSize::kN>;
+                KPerBlockBQ, // Logical K dimension
+                NPerBlockBQ, // Logical N dimension
+                Problem::QuantGroupSize::kN,
+                BQLayout>;
            return TileEncodingPattern::make_2d_static_tile_distribution();
        }
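GetVectorSizeBQ now has to pick the load width from whichever dimension is contiguous: K groups for ColumnMajor BQ, N groups for RowMajor BQ. A rough stand-alone model of that decision; the 16-byte load cap and 4-byte scale element are assumptions, not values taken from the policy:

```cpp
// Largest power-of-two vector (in elements) that divides the contiguous run.
constexpr int max_vector_elems(int contiguous_elems, int elem_bytes)
{
    int v = 16 / elem_bytes; // assume a 16-byte (dwordx4) global load cap
    while(v > 1 && contiguous_elems % v != 0)
        v /= 2;
    return v;
}

// ColumnMajor: contiguous dim is KPerBlockBQ; RowMajor: it is NPerBlockBQ.
static_assert(max_vector_elems(/*KPerBlockBQ*/ 4, /*bytes*/ 4) == 4);
static_assert(max_vector_elems(/*NPerBlockBQ*/ 2, /*bytes*/ 4) == 2);
```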
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
index 4883a30f57..2c191cc2b4 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
@@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
     using BQDataType     = remove_cvref_t<typename Problem::BQDataType>;
     using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
 
+    // BDataType gets converted from PkInt4 during loading
+    using OverrideBDataType =
+        std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
+
     static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
     using I0 = number<0>;
     using I1 = number<1>;
@@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
 
+    static constexpr auto is_a_load_tr_v = bool_constant<false>{};
+    static constexpr auto is_b_load_tr_v = bool_constant<false>{};
+
     using Base::PrefetchStages;
 
     [[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
 
+    template <typename BBlockTile_, typename BDramWindow>
+    CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
+                                                   const BDramWindow& b_dram_window)
+    {
+        using DestDataType = typename BBlockTile_::DataType;
+        using SrcDataType  = typename BDramWindow::Base::TileWindowBase::DataType;
+        constexpr index_t UnaryOpSize = 8;
+        load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
+    }
+
+    template <typename ADramBlockWindowTmp,
+              typename BDramBlockWindowTmp,
+              typename BQDramBlockWindowTmp,
+              typename AElementFunction,
+              typename BElementFunction>
@@ -186,9 +206,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
         constexpr bool is_a_col_major =
             std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
-        constexpr bool is_bq_col_major =
-            std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
         constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
-
-        static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
+        constexpr bool is_bq_row_major =
+            std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
 
         static_assert(is_a_col_major
                           ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
             (p_smem);
         constexpr auto a_lds_load_tile_distr =
             make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
             decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
         using BBlockTile =
-            decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
+            decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
         using BQBlockTile =
             decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
                               {}), 0)
-                            : is_bq_col_major ? make_array(0, KPerBlockBQ)
-                                              : make_array(KPerBlockBQ, 0);
+                            : is_bq_row_major ? make_array(KPerBlockBQ, 0)
+                                              : make_array(0, KPerBlockBQ);
 
         // DRAM prefetch (global read 0)
         Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
-        Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
+        // B tile gets converted to A datatype during loading
+        LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
+        move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
         Base::GlobalPrefetch(
             bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
 
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
 
-        if constexpr(is_a_col_major)
+        if constexpr(is_a_col_major && !is_a_load_tr_v())
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
         if constexpr(is_b_row_major)
         {
-            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+            // B datatype is converted to A datatype during loading
+            auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
                 Policy::template MakeShuffledBRegTileDistribution<Problem>());
             transpose_tile2d(b_shuffle_tmp, b_block_tile);
             Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
             Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
-            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
+            // B tile gets converted to A datatype during loading
+            LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
+            move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
 
-            if constexpr(is_a_col_major)
+            if constexpr(is_a_col_major && !is_a_load_tr_v())
             {
                 auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                     Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
             if constexpr(is_b_row_major)
             {
-                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                // Note: BDataType PkInt4 gets converted during loading earlier
+                auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
                     Policy::template MakeShuffledBRegTileDistribution<Problem>());
                 transpose_tile2d(b_shuffle_tmp, b_block_tile);
                 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
             if constexpr(is_b_row_major)
             {
-                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                // Note: BDataType gets converted during loading from PkInt4
+                auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
                     Policy::template MakeShuffledBRegTileDistribution<Problem>());
                 transpose_tile2d(b_shuffle_tmp, b_block_tile);
                 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
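LoadAndConvertBTile folds the PkInt4 widening into the global read, so LDS and the warp GEMM never see packed nibbles (the diff does this with an 8-wide unary op). A scalar host-side model of the unpack step; the low-nibble-first order and the float destination type are assumptions for illustration:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Sign-extend one 4-bit lane (0..15 -> -8..7).
inline float int4_to_float(int v) { return static_cast<float>(v > 7 ? v - 16 : v); }

// Each packed byte expands to two wider elements on the way to registers.
template <std::size_t N>
std::array<float, 2 * N> load_and_convert(const std::array<std::uint8_t, N>& packed)
{
    std::array<float, 2 * N> out{};
    for(std::size_t i = 0; i < N; ++i)
    {
        out[2 * i]     = int4_to_float(packed[i] & 0xF); // low nibble first (assumed)
        out[2 * i + 1] = int4_to_float(packed[i] >> 4);
    }
    return out;
}
```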
 struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
 {
@@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
     /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
     ///
     /// This function determines the optimal thread distribution pattern for loading and applying
-    /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
+    /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
     /// to warp dimensions.
     ///
     /// Three distinct distribution patterns are handled:
     ///
-    /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
+    /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
     ///    - Multiple quantization groups exist within a single warp's N-dimension
-    ///    - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
-    ///    - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
-    ///    - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
+    ///    - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
+    ///    - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
+    ///    - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
     ///
-    /// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
+    /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
     ///    - Each warp handles exactly one quantization scale
-    ///    - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
-    ///    - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
+    ///    - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
+    ///    - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
     ///
-    /// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
+    /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
     ///    - Quantization group spans multiple warps
     ///    - All warps share the same scale value
-    ///    - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
+    ///    - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
     ///
     /// @return A static tile distribution encoding for the BQ scale tensor
     CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
     {
+        // Preshuffle only supported for ColumnMajor currently
+        static_assert(!(PreshuffleQuant &&
+                        std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
+                      "PreshuffleQuant only supported for ColumnMajor BQLayout");
+
         if constexpr(PreshuffleQuant)
         {
+            // ColumnMajor only for preshuffle
             constexpr index_t X1 = warp_size;
-            constexpr index_t X0 = XPerTile / warp_size;
+            constexpr index_t X0 = NPerTile / warp_size;
             constexpr index_t Y1 = NWarps;
-            constexpr index_t Y0 = YPerTile / Y1;
+            constexpr index_t Y0 = KPerTile / Y1;
 
             return make_static_tile_distribution(
                 tile_distribution_encoding,
                     tuple, sequence>,
@@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
         }
         else
         {
-            if constexpr(YPerQ < WarpGemm::kN)
+            if constexpr(NPerQ < WarpGemm::kN)
             {
                 // Case 1: Fine-grained - multiple quantization scales within a single warp
-                constexpr index_t X  = XPerTile;             // Full X dimension of tile
-                constexpr index_t XR = 1;                    // No Y replication needed
-                constexpr index_t Y0 = NIterPerWarp;         // Iterations per warp in N-dim
-                constexpr index_t Y1 = NWarps;               // Number of warps in N-dim
-                constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
-                constexpr index_t YR = YPerQ;                // Elements per quantization group
-
-                static_assert(Y0 * Y1 * Y2 == YPerTile,
-                              "Y0, Y1, Y2 must cover the blocktile along Y.");
-
-                return make_static_tile_distribution(
-                    tile_distribution_encoding,
-                        tuple, sequence>,
-                        tuple, sequence<0, 1, 0>>,
-                        tuple, sequence<1, 2, 2>>,
-                        sequence<1, 2>,
-                        sequence<0, 0>>{});
+                // N dimension needs to be partitioned the same way regardless of layout
+                constexpr index_t NR = 1;                    // No N replication needed
+                constexpr index_t N0 = NIterPerWarp;         // Iterations per warp in N-dim
+                constexpr index_t N1 = NWarps;               // Number of warps in N-dim
+                constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp
+
+                static_assert(N0 * N1 * N2 == NPerTile,
+                              "N0, N1, N2 must cover the blocktile along N dimension.");
+
+                if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
+                {
+                    // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0, 1, 0>>,
+                            tuple, sequence<1, 2, 2>>,
+                            sequence<1, 2>,
+                            sequence<0, 0>>{});
+                }
+                else
+                {
+                    // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0, 2, 0>>,
+                            tuple, sequence<1, 2, 2>>,
+                            sequence<2, 1>,
+                            sequence<0, 0>>{});
+                }
             }
-            else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
+            else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
             {
                 // Case 2: Medium-grained - one quantization scale per warp
-                constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor
-                constexpr auto Y1 = NWarps / YR;          // Warps per unique scale
-                constexpr auto Y0 = YPerTile / Y1;        // Iterations to cover X dimension
-                return make_static_tile_distribution(
-                    tile_distribution_encoding,
-                        tuple, sequence>,
-                        tuple, sequence<0>>,
-                        tuple, sequence<2>>,
-                        sequence<1, 2>,
-                        sequence<0, 0>>{});
+                constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor
+                constexpr auto N1 = NWarps / NR;          // Warps per unique scale
+                constexpr auto N0 = NPerTile / N1;        // Iterations to cover N dimension
+
+                if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
+                {
+                    // ColumnMajor: [(N0, N1), K] - N on Y-axis
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0>>,
+                            tuple, sequence<2>>,
+                            sequence<1, 2>,
+                            sequence<0, 0>>{});
+                }
+                else
+                {
+                    // RowMajor: [K, (N0, N1)] - N on X-axis
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0>>,
+                            tuple, sequence<2>>,
+                            sequence<2, 1>,
+                            sequence<0, 0>>{});
+                }
             }
-            else // XPerQ > WarpGemm::kN * NWarps
+            else // NPerQ > WarpGemm::kN * NWarps
             {
                 // Case 3: Coarse-grained - quantization group spans all warps
                 // All warps in N-dimension share the same quantization scale
-                return make_static_tile_distribution(
-                    tile_distribution_encoding,
-                        tuple, sequence>,
-                        tuple, sequence<0>>,
-                        tuple, sequence<2>>,
-                        sequence<2, 1>,
-                        sequence<0, 0>>{});
+                if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
+                {
+                    // ColumnMajor: [N, K]
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0>>,
+                            tuple, sequence<2>>,
+                            sequence<1, 2>,
+                            sequence<0, 0>>{});
+                }
+                else
+                {
+                    // RowMajor: [K, N]
+                    return make_static_tile_distribution(
+                        tile_distribution_encoding,
+                            tuple, sequence>,
+                            tuple, sequence<0>>,
+                            tuple, sequence<2>>,
+                            sequence<2, 1>,
+                            sequence<0, 0>>{});
+                }
             }
         }
     }
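The three cases documented above reduce to one comparison chain on NPerQ. Restated as plain constexpr arithmetic with the doc comment's example values (WarpGemm::kN = 16, NWarps = 4):

```cpp
constexpr int warp_kn = 16, n_warps = 4; // example WarpGemm::kN and NWarps

// 1 = fine (several scales per warp), 2 = medium (one scale per warp),
// 3 = coarse (one scale shared across all warps)
constexpr int classify(int n_per_q)
{
    return n_per_q < warp_kn ? 1 : n_per_q <= warp_kn * n_warps ? 2 : 3;
}

static_assert(classify(8) == 1);   // 16 / 8 = 2 scales inside one warp
static_assert(classify(64) == 2);  // replication factor XR = 64 / 16 = 4
static_assert(classify(128) == 3); // 128 > 16 * 4, group spans every warp
```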
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
index 38bd59b882..39a7c66f38 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
@@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
     using TilePartitioner = ck_tile::GemmTile1DPartitioner;
 
-    // BQLayout is always ColumnMajor for BQuant
-    using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    // Re-use the AQLayout for BQLayout
+    using BQLayout = AQLayout;
 
     using CodegenGemmTraits = ck_tile::TileGemmQuantTraits>;
 using GroupSize2D128N = ck_tile::QuantGroupShape>;
 
 // Type combinations for BQuant tests (without PreshuffleB)
-// Tuple format:
 // clang-format off
 using BQuantTypes = ::testing::Types<
-    // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant)
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    // 1d cases with grouping only on k axis
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
     // 2d cases with grouping also on the n axis
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+
+    // some cases with transpose layouts
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
+    std::tuple,
+    std::tuple,
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
+    std::tuple,
+    std::tuple,
+
+    // pkint4 + transpose cases
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
+    std::tuple,
+    std::tuple,
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
+    std::tuple,
+    std::tuple
 >;
 // clang-format on
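For reference checking, the only thing the BQ layout changes on the host side is how the scale index is flattened. A hypothetical helper (not part of the fixtures) showing the lookup for one B element at (k, n):

```cpp
// bq points at the flat scale array; kK/kN are the quant-group extents.
inline float dequant_b(float b_val, const float* bq, int k, int n,
                       int K, int N, int kK, int kN, bool bq_row_major)
{
    const int qk  = k / kK, qn = n / kN;
    const int idx = bq_row_major ? qk * (N / kN) + qn  // [K/kK][N/kN] order
                                 : qn * (K / kK) + qk; // [N/kN][K/kK] order
    return b_val * bq[idx];
}
```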
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp
index 6cde4bded5..3a62fc091a 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp
@@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape>;
 using GroupSize2D64N = ck_tile::QuantGroupShape>;
 
 // Type combinations for BQuant tests with PreshuffleB
-// Tuple format:
 // clang-format off
 using BPreshuffleBQuantTypes = ::testing::Types<
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
     // //2d cases with preshuffle B
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
 
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple,
-    std::tuple
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple
 >;
 // clang-format on
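Each row of these Types<> lists is one tuple of template parameters that the typed fixture unpacks by position. A sketch of that extraction; the element order (layouts first, quant group size last) is inferred from the rows that spell out their arguments above and is an assumption:

```cpp
#include <tuple>

template <typename Tuple>
struct ExtractParams
{
    using ALayout   = std::tuple_element_t<0, Tuple>; // assumed position
    using BLayout   = std::tuple_element_t<1, Tuple>; // assumed position
    using BQLayout  = std::tuple_element_t<3, Tuple>; // assumed position
    using GroupSize = std::tuple_element_t<std::tuple_size_v<Tuple> - 1, Tuple>;
};
```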
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 7b16529aa8..bf9c7a138d 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase
+        stride_BQ = this->is_row_major(BQLayout{}) ? BQN : BQK;
 
         // Generate test data
         ck_tile::HostTensor<ADataType> a_m_k(
             ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
         ck_tile::HostTensor<BDataType> b_k_n(
             ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
-        // BQ is always ColumnMajor
         ck_tile::HostTensor<BQDataType> bq_bqk_bqn(
-            ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant<false>{}));
+            ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
 
         // Initialize data with random values
         ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
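The fixture's default BQ stride is just the row length of whichever storage order is active. With example extents K = 2048, N = 4096 and group shape kK = 128, kN = 32:

```cpp
constexpr int BQK = 2048 / 128; // 16 K-groups
constexpr int BQN = 4096 / 32;  // 128 N-groups

constexpr int default_stride_bq(bool row_major) { return row_major ? BQN : BQK; }

static_assert(default_stride_bq(true) == 128); // RowMajor rows hold BQN scales
static_assert(default_stride_bq(false) == 16); // ColumnMajor rows hold BQK scales
```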