@@ -2003,21 +2003,15 @@ def _heuristic_func_mm_fp4(
20032003 """
20042004 cuda_major , _ = get_cuda_version ()
20052005 # If cuda version is 13 or greater:
2006- # cudnn is more performant if cudnn version is 9.14 or greater.
2007- if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn .backend_version () >= 91400 :
2006+ # cudnn is more performant if cudnn version is 9.15 or greater.
2007+ if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn .backend_version () >= 91500 :
20082008 candidate_backends = ("cudnn" , "cutlass" )
20092009 # Otherwise, prioritize cutlass
20102010 else :
20112011 candidate_backends = ("cutlass" , "cudnn" )
20122012
2013- # Filter to only supported backends for this compute capability
2014- # Note: The requirement function already validated that at least one backend is supported
2015- heuristic_backends = []
2016- for candidate in candidate_backends :
2017- # mypy requires explicit type casting for the backend literal
2018- if candidate in suitable_backends :
2019- heuristic_backends .append (candidate )
2020- return heuristic_backends
2013+ # Keep only the supported backends, preserving the priority order of candidate_backends
2014+ return [c for c in candidate_backends if c in suitable_backends ]
20212015
20222016
20232017@backend_requirement (
@@ -2027,7 +2021,7 @@ def _heuristic_func_mm_fp4(
20272021 "cutlass" : _cutlass_gemm_fp4_requirement ,
20282022 },
20292023 common_check = _check_mm_fp4_problem_size ,
2030- heuristic_func = _heuristic_func_mm_fp4 ,
2024+ heuristic_func = _heuristic_func_mm_fp4 , # result stored in mm_fp4.suitable_auto_backends
20312025)
20322026def mm_fp4 (
20332027 a : torch .Tensor ,
0 commit comments