Skip to content

Commit 79f48ee

Browse files
yukuai26, Kuai Yu, LyricZhao
authored
Fix multicast bug and optimize masked GEMM (#193)
* Fix multicast bug and profile masked GEMM
* Updates and lint

---------

Co-authored-by: Kuai Yu <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent ea9c5d9 commit 79f48ee

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

csrc/jit_kernels/heuristics/common.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
152152

153153
// Select M/N block sizes
154154
// TODO: support `% 16 == 8` block size on SM90
155-
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
156-
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
155+
auto block_ms = std::vector{64, 128, 256};
156+
if (gemm_type == GemmType::MGroupedContiguous)
157+
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
158+
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
159+
block_ms = std::vector{64, 128};
157160
std::vector<int> block_ns;
158161
for (int i = 16; i <= 256; i += 16)
159162
block_ns.push_back(i);
@@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
214217
MulticastConfig best_multicast_config = {1, true};
215218
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
216219
gemm_type, m, n, best_block_m, best_block_n, num_sms);
217-
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
220+
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
218221
bool order[2] = {false, true};
219222
if (best_block_m > best_block_n)
220223
std::swap(order[0], order[1]);

csrc/jit_kernels/heuristics/sm100.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ struct SM100ArchSpec {
9191
const int& num_sms) {
9292
// TODO: support other layouts
9393
return {
94-
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
9594
false,
95+
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
9696
};
9797
}
9898

csrc/jit_kernels/heuristics/sm90.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ struct SM90ArchSpec {
7171
const int& num_sms) {
7272
return {
7373
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
74-
is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
74+
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
75+
is_multicast_legal(m, block_m, 2, num_sms, false)
76+
and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
7577
};
7678
}
7779

0 commit comments

Comments
 (0)