@@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
152152
153153 // Select M/N block sizes
154154 // TODO: support `% 16 == 8` block size on SM90
155- const auto & block_ms = gemm_type == GemmType::MGroupedContiguous ?
156- std::vector{get_mk_alignment_for_contiguous_layout ()} : std::vector{64 , 128 , 256 };
155+ auto block_ms = std::vector{64 , 128 , 256 };
156+ if (gemm_type == GemmType::MGroupedContiguous)
157+ block_ms = std::vector{get_mk_alignment_for_contiguous_layout ()};
158+ if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
159+ block_ms = std::vector{64 , 128 };
157160 std::vector<int > block_ns;
158161 for (int i = 16 ; i <= 256 ; i += 16 )
159162 block_ns.push_back (i);
@@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
214217 MulticastConfig best_multicast_config = {1 , true };
215218 const auto & [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality (
216219 gemm_type, m, n, best_block_m, best_block_n, num_sms);
217- const bool is_legal[2 ] = {is_legal_on_a, is_legal_on_b };
220+ const bool is_legal[2 ] = {is_legal_on_b, is_legal_on_a };
218221 bool order[2 ] = {false , true };
219222 if (best_block_m > best_block_n)
220223 std::swap (order[0 ], order[1 ]);
0 commit comments