@@ -381,10 +381,6 @@ def fp8_gemm_sm100(
     if CUDNN_AVAILABLE and "cudnn" in runner_names:
         runners.append(_cudnn_gemm_fp8_runner())
 
-    if len(runners) == 0:
-        major, minor = get_compute_capability(torch.device("cuda"))
-        raise ValueError(f"No valid runner found for current device sm{major}{minor}")
-
     tuner = AutoTuner.get()
     a_tensor_index = 0
     out_tensor_index = 4
@@ -2009,6 +2005,70 @@ def mm_fp4(
     return out
 
 
+def _check_bmm_fp8_problem_size(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    _validate_fp8_output_dtype(dtype)
+    return True
+
+
+@supported_compute_capability([89, 90, 100, 103, 120])
+def _cudnn_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    _check_cudnn_availability()
+    return True
+
+
+@supported_compute_capability([89, 90, 100, 103, 120])
+def _cublas_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    return True
+
+
+@supported_compute_capability([100, 103, 110, 120, 121])
+def _cutlass_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
+        raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend")
+    return True
+
+
+@backend_requirement(
+    {
+        "cudnn": _cudnn_bmm_fp8_requirement,
+        "cublas": _cublas_bmm_fp8_requirement,
+        "cutlass": _cutlass_bmm_fp8_requirement,
+        "auto": _cublas_bmm_fp8_requirement,  # cublas default
+    },
+    common_check=_check_bmm_fp8_problem_size,
+)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -2073,7 +2133,6 @@ def bmm_fp8(
     >>> out.dtype
     torch.bfloat16
     """
-    _validate_fp8_output_dtype(dtype)
 
     if out is None:
         out = torch.empty(
@@ -2091,8 +2150,6 @@ def bmm_fp8(
     elif backend == "cublas":
         backends = ["cublas"]
     elif backend == "cutlass":
-        if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
-            raise ValueError("e5m2 is not supported for cutlass backend")
         backends = ["cutlass"]
     elif backend == "auto":
         backends = ["cutlass", "cublas", "cudnn"]
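# --- Usage sketch (not part of the diff) ---
# A minimal, hypothetical call into the decorated bmm_fp8 entry point above.
# The shapes, the column-major layout of B, the unit scales, and the import
# path are assumptions for illustration; only the argument names and backend
# choices come from this commit.
import torch

from flashinfer import bmm_fp8  # assumed import path

A = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
# Column-major second operand, as fp8 GEMM backends commonly expect (assumption).
B = (
    torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16)
    .transpose(-2, -1)
    .to(torch.float8_e4m3fn)
)
A_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
B_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# e4m3 inputs pass every per-backend requirement function; e5m2 inputs would be
# rejected by _cutlass_bmm_fp8_requirement when the cutlass backend is selected.
out = bmm_fp8(A, B, A_scale, B_scale, dtype=torch.bfloat16, backend="auto")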