From 62a312233b28e7cce4ddc762d7bdd9f8475d3b46 Mon Sep 17 00:00:00 2001 From: wangguoteng <877825076@qq.com> Date: Tue, 4 Nov 2025 06:10:31 +0000 Subject: [PATCH] fix: prevent int32 overflow in k-grouped GEMM size calculations --- csrc/apis/gemm.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 8d062921..076179c4 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); - const auto& sum_mk = a.first.numel(); - const auto& sum_nk = b.first.numel(); - int sum_k = 0; + const auto& sum_mk = static_cast<uint64_t>(a.first.numel()); + const auto& sum_nk = static_cast<uint64_t>(b.first.numel()); + uint64_t sum_k = 0; for (const auto& k: ks) - sum_k += k; - DG_HOST_ASSERT(sum_mk == m * sum_k); - DG_HOST_ASSERT(sum_nk == n * sum_k); + sum_k += static_cast<uint64_t>(k); + DG_HOST_ASSERT(sum_mk == static_cast<uint64_t>(m) * sum_k); + DG_HOST_ASSERT(sum_nk == static_cast<uint64_t>(n) * sum_k); // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous());