From 7353b748e3b81c16df52c6bf09a28abe742f7884 Mon Sep 17 00:00:00 2001 From: Samaresh Kumar Singh Date: Wed, 12 Nov 2025 11:14:04 -0600 Subject: [PATCH] Fix FP8 linear layer dimension check to prevent runtime error Fixes #6390 The issue occurs when use_fp8=True is enabled and the model has output layers with dimensions not divisible by 16 (e.g., binary classification with 2 outputs). torch._scaled_mm requires BOTH dimensions of mat2 (weight matrix) to be divisible by 16. The previous check only validated input dimensions but not the weight output dimension (weight.shape[0]). When using GPT2ForSequenceClassification with num_labels=2, the score layer has weight shape [768, 2], causing the error: 'Expected both dimensions of mat2 to be divisible by 16 but got torch.Size([768, 2])' This fix adds a check for weight.shape[0] % 16 != 0 to fallback to regular F.linear when the output dimension is not compatible with FP8. --- colossalai/quantization/fp8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e23da5cccd4d..e02e56aa791f 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -840,7 +840,9 @@ def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch. def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0: + # torch._scaled_mm requires both dimensions of matrices to be divisible by 16 + # Check input dimensions and weight output dimension + if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0: return F.linear(input, weight, bias) out = _linear_fp8(input, weight, bias) return out