Skip to content

Commit 8b6207b

Browse files
committed
Final cleanup
1 parent 9fc4a57 commit 8b6207b

File tree

1 file changed: +5 additions, −11 deletions

flashinfer/gemm/gemm_base.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,21 +2003,15 @@ def _heuristic_func_mm_fp4(
20032003
"""
20042004
cuda_major, _ = get_cuda_version()
20052005
# If cuda version is 13 or greater:
2006-
# cudnn is more performant if cudnn version is 9.14 or greater.
2007-
if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
2006+
# cudnn is more performant if cudnn version is 9.15 or greater.
2007+
if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500:
20082008
candidate_backends = ("cudnn", "cutlass")
20092009
# Otherwise, prioritize cutlass
20102010
else:
20112011
candidate_backends = ("cutlass", "cudnn")
20122012

2013-
# Filter to only supported backends for this compute capability
2014-
# Note: The requirement function already validated that at least one backend is supported
2015-
heuristic_backends = []
2016-
for candidate in candidate_backends:
2017-
# mypy requires explicit type casting for the backend literal
2018-
if candidate in suitable_backends:
2019-
heuristic_backends.append(candidate)
2020-
return heuristic_backends
2013+
# Filter and return only supported backends
2014+
return [c for c in candidate_backends if c in suitable_backends]
20212015

20222016

20232017
@backend_requirement(
@@ -2027,7 +2021,7 @@ def _heuristic_func_mm_fp4(
20272021
"cutlass": _cutlass_gemm_fp4_requirement,
20282022
},
20292023
common_check=_check_mm_fp4_problem_size,
2030-
heuristic_func=_heuristic_func_mm_fp4,
2024+
heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends
20312025
)
20322026
def mm_fp4(
20332027
a: torch.Tensor,

0 commit comments

Comments
 (0)