fix: prevent int32 overflow in k-grouped GEMM size calculations #226
Fix Int32 Overflow: Prevent overflow in k-grouped GEMM size validation
Description
Fix integer overflow in k-grouped GEMM when validating tensor sizes with large dimensions.
While training the DeepSeek-V3 model using DeepGEMM on an H100 machine, we encountered an error in the grouped GEMM kernel when using long sequences ($>256\text{k}$). Upon investigation, we found the cause was an overflow when computing the product of `sum_k` and `hidden_size`.
Root cause: When `m` or `n` are large and multiplied with `sum_k`, the result exceeds int32 max value (2,147,483,647), causing incorrect validation or undefined behavior.
Solution: Cast `m` and `n` to `uint64_t` before multiplication to safely handle large matrix dimensions.
Changes
csrc/apis/gemm.hpp:289-290: Cast to `uint64_t` in size assertions
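
For reviewers who want to see the arithmetic, here is a minimal sketch of the widening cast. It is not the code in `csrc/apis/gemm.hpp`: the helper names `expected_numel` and `overflows_int32` are hypothetical, and the sample values in the usage note (a `sum_k` of 300,000 and a dimension of 7,168) are assumptions chosen only to cross the int32 limit. It also assumes both inputs are non-negative.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper (not part of DeepGEMM): computes sum_k * dim in 64 bits.
// Each operand is widened BEFORE the multiplication, so the product is never
// evaluated in 32-bit arithmetic. Assumes both inputs are non-negative.
constexpr uint64_t expected_numel(int sum_k, int dim) {
    return static_cast<uint64_t>(sum_k) * static_cast<uint64_t>(dim);
}

// Hypothetical helper: true when the exact product does not fit in int32,
// i.e. when an unwidened `sum_k * dim` would overflow.
constexpr bool overflows_int32(int sum_k, int dim) {
    return expected_numel(sum_k, dim) >
           static_cast<uint64_t>(std::numeric_limits<int32_t>::max());
}
```

With the assumed values, `expected_numel(300000, 7168)` yields 2,150,400,000, which is above 2,147,483,647, so `overflows_int32(300000, 7168)` is true. A size assertion that compares a tensor's element count against the widened product stays correct in this range, whereas the same comparison against an `int` product would not.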