Skip to content

Commit e78bcbb

Browse files
committed
feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter
1 parent 76897c5 commit e78bcbb

File tree

2 files changed

+56
-54
lines changed

2 files changed

+56
-54
lines changed

example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
6464
using CDataType = ck_tile::half_t;
6565
};
6666

67+
template <bool Persistent_>
6768
struct GemmConfigBase
6869
{
6970
static constexpr bool kPadM = false;
@@ -83,11 +84,11 @@ struct GemmConfigBase
8384
static constexpr ck_tile::index_t NumWaveGroups = 1;
8485
static constexpr bool DoubleSmemBuffer = false;
8586
static constexpr bool PreshuffleB = false;
86-
static constexpr bool Persistent = false;
87+
static constexpr bool Persistent = Persistent_;
8788
};
8889

89-
template <typename PrecType>
90-
struct GemmConfigComputeV3_2 : public GemmConfigBase
90+
template <typename PrecType, bool Persistent>
91+
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
9192
{
9293
static constexpr ck_tile::index_t M_Tile = 128;
9394
static constexpr ck_tile::index_t N_Tile = 128;
@@ -102,26 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
102103
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
103104
};
104105

105-
template <typename PrecType>
106-
struct GemmConfig_Aquant : public GemmConfigBase
107-
{
108-
static constexpr ck_tile::index_t M_Tile = 128;
109-
static constexpr ck_tile::index_t N_Tile = 128;
110-
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
111-
112-
static constexpr ck_tile::index_t M_Warp = 4;
113-
static constexpr ck_tile::index_t N_Warp = 1;
114-
static constexpr ck_tile::index_t K_Warp = 1;
115-
116-
static constexpr ck_tile::index_t M_Warp_Tile = 32;
117-
static constexpr ck_tile::index_t N_Warp_Tile = 32;
118-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
119-
120-
static constexpr bool TransposeC = true;
121-
};
122-
123-
template <typename PrecType>
124-
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
106+
template <typename PrecType, bool Persistent>
107+
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
125108
{
126109
static constexpr ck_tile::index_t M_Tile = 128;
127110
static constexpr ck_tile::index_t N_Tile = 128;
@@ -136,7 +119,6 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
136119
static constexpr ck_tile::index_t K_Warp_Tile =
137120
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
138121

139-
static constexpr bool TransposeC = false;
140122
static constexpr bool PreshuffleB = true;
141123
static constexpr bool DoubleSmemBuffer = true;
142124
};
@@ -147,8 +129,8 @@ struct GemmQuantConfig;
147129
template <>
148130
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
149131
{
150-
template <typename PrecType>
151-
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
132+
template <typename PrecType, bool Persistent>
133+
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
152134

153135
template <typename GemmProblem, bool PreshuffleB = false>
154136
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
@@ -160,8 +142,8 @@ struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
160142
template <>
161143
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
162144
{
163-
template <typename PrecType>
164-
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
145+
template <typename PrecType, bool Persistent>
146+
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
165147

166148
template <typename GemmProblem, bool PreshuffleB = false>
167149
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
@@ -173,8 +155,8 @@ struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
173155
template <>
174156
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
175157
{
176-
template <typename PrecType>
177-
using GemmConfig = GemmConfig_Aquant<PrecType>;
158+
template <typename PrecType, bool Persistent>
159+
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
178160

179161
template <typename GemmProblem, bool PreshuffleB = false>
180162
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
@@ -186,8 +168,8 @@ struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
186168
template <>
187169
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
188170
{
189-
template <typename PrecType>
190-
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType>;
171+
template <typename PrecType, bool Persistent>
172+
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;
191173

192174
template <typename GemmProblem, bool PreshuffleB = false>
193175
using GemmPipeline = std::conditional_t<PreshuffleB == true,
@@ -229,7 +211,8 @@ auto create_args(int argc, char* argv[])
229211
.insert("group_count", "8", "group count.")
230212
.insert("kbatch", "1", "kbatch for SplitK")
231213
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
232-
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
214+
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
215+
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
233216

234217
bool result = arg_parser.parse(argc, argv);
235218
return std::make_tuple(result, arg_parser);

example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
533533
return pass;
534534
}
535535

536-
template <typename PrecType, ck_tile::QuantType QuantMode>
536+
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
537537
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
538538
{
539-
using Row = ck_tile::tensor_layout::gemm::RowMajor;
540-
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
541-
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType>;
542-
using Types = GemmTypeConfig<PrecType>;
539+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
540+
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
541+
using Types = GemmTypeConfig<PrecType>;
543542
// Specific type aliases for easy access
544543
using ADataType = typename Types::ADataType;
545544
using BDataType = typename Types::BDataType;
@@ -551,6 +550,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
551550

552551
if(a_layout == "R" && b_layout == "C")
553552
{
553+
554554
return run_grouped_gemm_example_with_layouts<GemmConfig,
555555
ADataType,
556556
AQDataType,
@@ -568,6 +568,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
568568
}
569569
}
570570

571+
template <typename PrecType, ck_tile::QuantType QuantMode>
572+
int run_gemm_example_persistency(
573+
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
574+
{
575+
if(persistent)
576+
{
577+
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
578+
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
579+
a_layout, b_layout, argc, argv);
580+
}
581+
else
582+
{
583+
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
584+
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
585+
a_layout, b_layout, argc, argv);
586+
}
587+
}
588+
571589
int run_grouped_gemm_example(int argc, char* argv[])
572590
{
573591
auto [result, arg_parser] = create_args(argc, argv);
@@ -580,28 +598,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
580598
const std::string b_layout = arg_parser.get_str("b_layout");
581599
const std::string data_type = arg_parser.get_str("prec");
582600
std::string quant_mode = arg_parser.get_str("quant_mode");
601+
bool persistent = arg_parser.get_bool("persistent");
583602

584603
if(data_type == "fp8")
585604
{
586605
if(quant_mode == "tensor")
587606
{
588-
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
589-
a_layout, b_layout, argc, argv);
607+
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
608+
a_layout, b_layout, persistent, argc, argv);
590609
}
591610
else if(quant_mode == "rowcol")
592611
{
593-
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
594-
a_layout, b_layout, argc, argv);
612+
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
613+
a_layout, b_layout, persistent, argc, argv);
595614
}
596615
else if(quant_mode == "aquant")
597616
{
598-
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
599-
a_layout, b_layout, argc, argv);
617+
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
618+
a_layout, b_layout, persistent, argc, argv);
600619
}
601620
else if(quant_mode == "bquant")
602621
{
603-
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
604-
a_layout, b_layout, argc, argv);
622+
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
623+
a_layout, b_layout, persistent, argc, argv);
605624
}
606625
else
607626
{
@@ -612,23 +631,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
612631
{
613632
if(quant_mode == "tensor")
614633
{
615-
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
616-
a_layout, b_layout, argc, argv);
634+
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
635+
a_layout, b_layout, persistent, argc, argv);
617636
}
618637
else if(quant_mode == "rowcol")
619638
{
620-
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
621-
a_layout, b_layout, argc, argv);
639+
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
640+
a_layout, b_layout, persistent, argc, argv);
622641
}
623642
else if(quant_mode == "aquant")
624643
{
625-
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
626-
a_layout, b_layout, argc, argv);
644+
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
645+
a_layout, b_layout, persistent, argc, argv);
627646
}
628647
else if(quant_mode == "bquant")
629648
{
630-
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
631-
a_layout, b_layout, argc, argv);
649+
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
650+
a_layout, b_layout, persistent, argc, argv);
632651
}
633652
else
634653
{

0 commit comments

Comments
 (0)