Commit 604cfd4

bmm_fp8

1 parent 9ce1af7 commit 604cfd4
1 file changed: flashinfer/gemm.py (64 additions, 7 deletions)

@@ -381,10 +381,6 @@ def fp8_gemm_sm100(
     if CUDNN_AVAILABLE and "cudnn" in runner_names:
         runners.append(_cudnn_gemm_fp8_runner())
 
-    if len(runners) == 0:
-        major, minor = get_compute_capability(torch.device("cuda"))
-        raise ValueError(f"No valid runner found for current device sm{major}{minor}")
-
     tuner = AutoTuner.get()
     a_tensor_index = 0
     out_tensor_index = 4
@@ -2009,6 +2005,70 @@ def mm_fp4(
     return out
 
 
+def _check_bmm_fp8_problem_size(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    _validate_fp8_output_dtype(dtype)
+    return True
+
+
+@supported_compute_capability([89, 90, 100, 103, 120])
+def _cudnn_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    _check_cudnn_availability()
+    return True
+
+
+@supported_compute_capability([89, 90, 100, 103, 120])
+def _cublas_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    return True
+
+
+@supported_compute_capability([100, 103, 110, 120, 121])
+def _cutlass_bmm_fp8_requirement(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+    backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
+):
+    if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
+        raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend")
+    return True
+
+
+@backend_requirement(
+    {
+        "cudnn": _cudnn_bmm_fp8_requirement,
+        "cublas": _cublas_bmm_fp8_requirement,
+        "cutlass": _cutlass_bmm_fp8_requirement,
+        "auto": _cublas_bmm_fp8_requirement,  # cublas default
+    },
+    common_check=_check_bmm_fp8_problem_size,
+)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
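
The decorator stack in the hunk above moves validation out of the bmm_fp8 body: a shared common_check plus one requirement function per backend runs before dispatch. As a rough illustration only (this is a minimal sketch with hypothetical names, not flashinfer's actual backend_requirement implementation), a decorator of this shape could look like:

    # Minimal sketch, assuming a simplified hypothetical decorator; NOT the
    # real flashinfer.backend_requirement, only the validation flow it implies.
    import functools
    from typing import Callable, Dict, Optional


    def backend_requirement_sketch(
        backend_checks: Dict[str, Callable[..., bool]],
        common_check: Optional[Callable[..., bool]] = None,
    ):
        """Wrap a kernel entry point with common and per-backend validation."""

        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, backend: str = "auto", **kwargs):
                if common_check is not None:
                    # Shared checks (e.g. output dtype) run for every backend.
                    common_check(*args, backend=backend, **kwargs)
                check = backend_checks.get(backend)
                if check is None:
                    raise ValueError(f"Unsupported backend: {backend}")
                # Backend-specific checks (library availability, dtype limits).
                check(*args, backend=backend, **kwargs)
                return fn(*args, backend=backend, **kwargs)

            return wrapper

        return decorator

In the commit itself the per-backend predicates additionally carry @supported_compute_capability, so unsupported devices are rejected alongside the dtype and availability checks before any kernel path is chosen.
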
@@ -2073,7 +2133,6 @@ def bmm_fp8(
     >>> out.dtype
     torch.bfloat16
     """
-    _validate_fp8_output_dtype(dtype)
 
     if out is None:
         out = torch.empty(
@@ -2091,8 +2150,6 @@ def bmm_fp8(
     elif backend == "cublas":
         backends = ["cublas"]
     elif backend == "cutlass":
-        if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
-            raise ValueError("e5m2 is not supported for cutlass backend")
        backends = ["cutlass"]
    elif backend == "auto":
        backends = ["cutlass", "cublas", "cudnn"]
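
With the dtype and e5m2/cutlass checks now handled by the requirement functions, call sites do not change. A hypothetical usage sketch follows; shapes, layouts, and scale conventions are illustrative assumptions, not taken from the commit, so consult the bmm_fp8 docstring for authoritative usage:

    import torch

    from flashinfer.gemm import bmm_fp8

    # Illustrative FP8 batched inputs (assumed shapes and unit scales).
    # B is created transposed so its underlying layout is column-major;
    # the library's actual layout requirements may differ.
    A = torch.randn(8, 128, 256, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    B = (
        torch.randn(8, 64, 256, device="cuda", dtype=torch.bfloat16)
        .to(torch.float8_e5m2)
        .transpose(-2, -1)
    )
    A_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    B_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    # cublas (also the "auto" default per this commit) accepts the e5m2
    # operand; the cutlass requirement above would raise for it.
    out = bmm_fp8(A, B, A_scale, B_scale, dtype=torch.bfloat16, backend="cublas")
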
