diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu index 5a3d98da6ddc..337a5ad69ac6 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu @@ -129,9 +129,8 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor template void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b) { - // TODO: switch different tileshape and clustershape with different m and n using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; + using ClusterShape = Shape<_1, _1, _1>; launch_sm90_fp8_blockwise_scaled_mm(out, a, b, scales_a, scales_b); }