@@ -533,13 +533,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
533533 return pass;
534534}
535535
536- template <typename PrecType, ck_tile::QuantType QuantMode>
536+ template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig >
537537int run_gemm_example_prec_type (std::string a_layout, std::string b_layout, int argc, char * argv[])
538538{
539- using Row = ck_tile::tensor_layout::gemm::RowMajor;
540- using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
541- using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType>;
542- using Types = GemmTypeConfig<PrecType>;
539+ using Row = ck_tile::tensor_layout::gemm::RowMajor;
540+ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
541+ using Types = GemmTypeConfig<PrecType>;
543542 // Specific type aliases for easy access
544543 using ADataType = typename Types::ADataType;
545544 using BDataType = typename Types::BDataType;
@@ -551,6 +550,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
551550
552551 if (a_layout == " R" && b_layout == " C" )
553552 {
553+
554554 return run_grouped_gemm_example_with_layouts<GemmConfig,
555555 ADataType,
556556 AQDataType,
@@ -568,6 +568,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
568568 }
569569}
570570
571+ template <typename PrecType, ck_tile::QuantType QuantMode>
572+ int run_gemm_example_persistency (
573+ std::string a_layout, std::string b_layout, bool persistent, int argc, char * argv[])
574+ {
575+ if (persistent)
576+ {
577+ using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true >;
578+ return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
579+ a_layout, b_layout, argc, argv);
580+ }
581+ else
582+ {
583+ using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false >;
584+ return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
585+ a_layout, b_layout, argc, argv);
586+ }
587+ }
588+
571589int run_grouped_gemm_example (int argc, char * argv[])
572590{
573591 auto [result, arg_parser] = create_args (argc, argv);
@@ -580,28 +598,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
580598 const std::string b_layout = arg_parser.get_str (" b_layout" );
581599 const std::string data_type = arg_parser.get_str (" prec" );
582600 std::string quant_mode = arg_parser.get_str (" quant_mode" );
601+ bool persistent = arg_parser.get_bool (" persistent" );
583602
584603 if (data_type == " fp8" )
585604 {
586605 if (quant_mode == " tensor" )
587606 {
588- return run_gemm_example_prec_type <ck_tile::fp8_t , ck_tile::QuantType::TensorQuant>(
589- a_layout, b_layout, argc, argv);
607+ return run_gemm_example_persistency <ck_tile::fp8_t , ck_tile::QuantType::TensorQuant>(
608+ a_layout, b_layout, persistent, argc, argv);
590609 }
591610 else if (quant_mode == " rowcol" )
592611 {
593- return run_gemm_example_prec_type <ck_tile::fp8_t , ck_tile::QuantType::RowColQuant>(
594- a_layout, b_layout, argc, argv);
612+ return run_gemm_example_persistency <ck_tile::fp8_t , ck_tile::QuantType::RowColQuant>(
613+ a_layout, b_layout, persistent, argc, argv);
595614 }
596615 else if (quant_mode == " aquant" )
597616 {
598- return run_gemm_example_prec_type <ck_tile::fp8_t , ck_tile::QuantType::AQuantGrouped>(
599- a_layout, b_layout, argc, argv);
617+ return run_gemm_example_persistency <ck_tile::fp8_t , ck_tile::QuantType::AQuantGrouped>(
618+ a_layout, b_layout, persistent, argc, argv);
600619 }
601620 else if (quant_mode == " bquant" )
602621 {
603- return run_gemm_example_prec_type <ck_tile::fp8_t , ck_tile::QuantType::BQuantGrouped>(
604- a_layout, b_layout, argc, argv);
622+ return run_gemm_example_persistency <ck_tile::fp8_t , ck_tile::QuantType::BQuantGrouped>(
623+ a_layout, b_layout, persistent, argc, argv);
605624 }
606625 else
607626 {
@@ -612,23 +631,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
612631 {
613632 if (quant_mode == " tensor" )
614633 {
615- return run_gemm_example_prec_type <ck_tile::bf8_t , ck_tile::QuantType::TensorQuant>(
616- a_layout, b_layout, argc, argv);
634+ return run_gemm_example_persistency <ck_tile::bf8_t , ck_tile::QuantType::TensorQuant>(
635+ a_layout, b_layout, persistent, argc, argv);
617636 }
618637 else if (quant_mode == " rowcol" )
619638 {
620- return run_gemm_example_prec_type <ck_tile::bf8_t , ck_tile::QuantType::RowColQuant>(
621- a_layout, b_layout, argc, argv);
639+ return run_gemm_example_persistency <ck_tile::bf8_t , ck_tile::QuantType::RowColQuant>(
640+ a_layout, b_layout, persistent, argc, argv);
622641 }
623642 else if (quant_mode == " aquant" )
624643 {
625- return run_gemm_example_prec_type <ck_tile::bf8_t , ck_tile::QuantType::AQuantGrouped>(
626- a_layout, b_layout, argc, argv);
644+ return run_gemm_example_persistency <ck_tile::bf8_t , ck_tile::QuantType::AQuantGrouped>(
645+ a_layout, b_layout, persistent, argc, argv);
627646 }
628647 else if (quant_mode == " bquant" )
629648 {
630- return run_gemm_example_prec_type <ck_tile::bf8_t , ck_tile::QuantType::BQuantGrouped>(
631- a_layout, b_layout, argc, argv);
649+ return run_gemm_example_persistency <ck_tile::bf8_t , ck_tile::QuantType::BQuantGrouped>(
650+ a_layout, b_layout, persistent, argc, argv);
632651 }
633652 else
634653 {
0 commit comments