Commit f2bf2ec
Address comments
1 parent 8e335eb commit f2bf2ec

2 files changed: +33 -31 lines

flashinfer/gemm/gemm_base.py (29 additions, 28 deletions)
@@ -1114,7 +1114,6 @@ def _check_cudnn_fp4_availability():
 
 def _is_cublas_fp4_available_in_cudnn():
     """Check if cuBLAS backend for FP4 GEMM is available in cuDNN."""
-    _check_cudnn_availability()
 
     # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13)
     backend_version = cudnn.backend_version()
@@ -1166,7 +1165,6 @@ def create_cudnn_execution_plans_fp4_gemm(
     alpha_is_not_none,
     use_nvfp4,
 ):
-    _check_cudnn_availability()
     stream = torch.cuda.current_stream(device)
     with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
         scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0
@@ -1269,6 +1267,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     use_nvfp4,
     tactic: int = -1,
 ):
+    # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
         a_stride,
@@ -1674,7 +1673,6 @@ def _get_cudnn_fp4_gemm_graph(
     use_nvfp4: bool = True,
     tactic: int = -1,
 ):
-    _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
     # here we need to get the 3d shape and stride including the
     # batch dimension for both input and block scale tensors.
@@ -1689,6 +1687,7 @@ def _get_cudnn_fp4_gemm_graph(
     )
 
     # build the fp4 cudnn graph
+    # Constructed graph is cached, via @functools.cache decorator.
     graph = build_plans_cudnn_fp4_gemm_graph(
         real_a_shape,
         real_a_stride,
@@ -1722,6 +1721,7 @@ def _cudnn_gemm_fp4(
     workspace_buffer: torch.Tensor = None,
     tactic: int = -1,
 ):
+    # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement
    graph = _get_cudnn_fp4_gemm_graph(
        a=a,
        b=b,
@@ -1748,7 +1748,6 @@ def get_valid_tactics(
     profile: OptimizationProfile,
 ) -> List[int]:
     # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
-    _check_cudnn_availability()
     (
         a,
         b,
@@ -1762,6 +1761,7 @@ def get_valid_tactics(
         workspace_buffer,
     ) = inputs
 
+    # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement
     graph = _get_cudnn_fp4_gemm_graph(
         a=a,
         b=b,
@@ -1821,10 +1821,10 @@ def _check_mm_fp4_problem_size(
     b_descale: torch.Tensor,
     alpha: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,  # unused
     block_size: int = 16,
-    use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    use_8x4_sf_layout: bool = False,  # unused
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     # Generic checks
@@ -1878,10 +1878,10 @@ def _cudnn_gemm_fp4_requirement(
     b_descale: torch.Tensor,
     alpha: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,  # unused
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if use_8x4_sf_layout:
@@ -1908,7 +1908,8 @@ def _cudnn_gemm_fp4_requirement(
         _expand_block_scale_tensor_shape(b_descale, batch)
     )
 
-    # build the fp4 cudnn graph
+    # build the fp4 cudnn graph. This graph will be cached & reused in mm_fp4()
+    # because the graph is constructed with @functools.cache decorator
     graph = create_cudnn_execution_plans_fp4_gemm(
         real_a_shape,
         real_a_stride,
@@ -1932,16 +1933,16 @@ def _cudnn_gemm_fp4_requirement(
 
 @supported_compute_capability([100, 103])
 def _trtllm_gemm_fp4_requirement(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_descale: torch.Tensor,
-    b_descale: torch.Tensor,
-    alpha: Optional[torch.Tensor] = None,
+    a: torch.Tensor,  # unused
+    b: torch.Tensor,  # unused
+    a_descale: torch.Tensor,  # unused
+    b_descale: torch.Tensor,  # unused
+    alpha: Optional[torch.Tensor] = None,  # unused
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
-    block_size: int = 16,
-    use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    out: Optional[torch.Tensor] = None,  # unused
+    block_size: int = 16,  # unused
+    use_8x4_sf_layout: bool = False,  # unused
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if not use_nvfp4:
@@ -1956,16 +1957,16 @@ def _trtllm_gemm_fp4_requirement(
 
 @supported_compute_capability([100, 103, 110, 120, 121])
 def _cutlass_gemm_fp4_requirement(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_descale: torch.Tensor,
-    b_descale: torch.Tensor,
-    alpha: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
-    block_size: int = 16,
+    a: torch.Tensor,  # unused
+    b: torch.Tensor,  # unused
+    a_descale: torch.Tensor,  # unused
+    b_descale: torch.Tensor,  # unused
+    alpha: Optional[torch.Tensor] = None,  # unused
+    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    out: Optional[torch.Tensor] = None,  # unused
+    block_size: int = 16,  # unused
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if use_8x4_sf_layout:
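
The comments added in this file all point at one design: the cuDNN FP4 graph is constructed once while _cudnn_gemm_fp4_requirement runs, and because the graph builders are wrapped in @functools.cache, the later calls on the execution path (_cudnn_gemm_fp4, get_valid_tactics) hit that cache instead of rebuilding. Below is a minimal sketch of the pattern, with hypothetical names (build_graph, requirement_check, run_gemm) standing in for the real builders; it is an illustration of the caching behaviour, not flashinfer code.

import functools

@functools.cache
def build_graph(a_shape: tuple, b_shape: tuple, use_nvfp4: bool):
    # Stand-in for the expensive cuDNN graph construction; runs once per
    # distinct argument combination because of @functools.cache.
    print(f"building graph for {a_shape} x {b_shape}")
    return ("graph", a_shape, b_shape, use_nvfp4)

def requirement_check(a_shape, b_shape, use_nvfp4=True):
    # Called up front (analogous to _cudnn_gemm_fp4_requirement); the graph
    # it builds stays in the functools cache.
    build_graph(a_shape, b_shape, use_nvfp4)
    return True

def run_gemm(a_shape, b_shape, use_nvfp4=True):
    # Same hashable arguments -> cache hit, no rebuild on the hot path.
    return build_graph(a_shape, b_shape, use_nvfp4)

requirement_check((128, 64), (64, 256))  # prints "building graph ..."
run_gemm((128, 64), (64, 256))           # served from the cache, no print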

flashinfer/utils.py (4 additions, 3 deletions)
@@ -922,7 +922,8 @@ def backend_requirement(
         True if requirements are met, False otherwise.
         In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities.
     heuristic_func : callable, optional
-        An optional function that performs heuristic backend selection when backend is "auto". Does not do anything if backend is not "auto".
+        A function that performs heuristic backend selection when backend is "auto".
+        Must be provided if backend is "auto". Does not do anything if backend is not "auto".
         Should accept the same arguments as the decorated function.
         Should return an ordered list of runnable backends with the most preferred backend first.
         When decorated function is not autotuned, the first backend in the heuristic list will be run.
@@ -1082,8 +1083,8 @@ def suitable_auto_backends(cc, *args, **kwargs):
             except ValueError:
                 continue
         # If a heuristic function is provided, filter the suitable backends based on the heuristic function
-        if heuristic_func is not None:
-            suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
+        assert heuristic_func is not None, "Heuristic function must be provided"
+        suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
         if not suitable_backends:
             return False
         wrapper.suitable_auto_backends = suitable_backends
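
The utils.py hunks tighten the contract around heuristic_func: on the backend="auto" path it is now asserted to be present and is always called with the list of suitable backends followed by the decorated function's own arguments. A hedged sketch of what such a function might look like; the name fp4_gemm_heuristic and its preference order are illustrative, not part of flashinfer's API.

from typing import List

def fp4_gemm_heuristic(suitable_backends: List[str], *args, **kwargs) -> List[str]:
    # Return the runnable backends ordered by preference; when the decorated
    # function is not being autotuned, the first entry is the one that runs.
    preferred = [be for be in ("cudnn", "cutlass", "trtllm") if be in suitable_backends]
    # Keep any remaining suitable backends at the end, in their original order.
    return preferred + [be for be in suitable_backends if be not in preferred]

With the assert in place, decorating a function that accepts backend="auto" without supplying a heuristic_func now fails loudly in suitable_auto_backends instead of silently skipping the filtering step.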
