diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index e9e522e0b..89bc6c647 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -457,12 +457,14 @@ struct CollectiveMma< #pragma unroll for (int i = 0; i < size(tApA_ScaleA); ++i) { - tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; + tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < + std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); } #pragma unroll for (int i = 0; i < size(tApA_ScaleA); ++i) { - tBpB_ScaleB(i) = get<0>(tBcB_ScaleB(i)) < scales_n; + tBpB_ScaleB(i) = get<0>(tBcB_ScaleB(i)) < + std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); } // Mainloop