@@ -1114,7 +1114,6 @@ def _check_cudnn_fp4_availability():
 
 def _is_cublas_fp4_available_in_cudnn():
     """Check if cuBLAS backend for FP4 GEMM is available in cuDNN."""
-    _check_cudnn_availability()
 
     # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13)
     backend_version = cudnn.backend_version()
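The version gate in this hunk reduces to a simple predicate. A minimal sketch, assuming cuDNN 9.x encodes `backend_version()` as `MAJOR * 10000 + MINOR * 100 + PATCH` (so 9.11.1 -> 91101, 9.13.0 -> 91300); verify the encoding against your cudnn-frontend build:

```python
# Hedged sketch of the FP4 version gate; the integer encoding of
# cudnn.backend_version() is an assumption (9.11.1 -> 91101, 9.13 -> 91300).
def _fp4_version_ok(backend_version: int) -> bool:
    return backend_version == 91101 or backend_version >= 91300
```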
@@ -1166,7 +1165,6 @@ def create_cudnn_execution_plans_fp4_gemm(
     alpha_is_not_none,
     use_nvfp4,
 ):
-    _check_cudnn_availability()
     stream = torch.cuda.current_stream(device)
     with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
         scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0
@@ -1269,6 +1267,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     use_nvfp4,
     tactic: int = -1,
 ):
+    # The graph should already have been cached when we ran _cudnn_gemm_fp4_requirement.
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
         a_stride,
@@ -1674,7 +1673,6 @@ def _get_cudnn_fp4_gemm_graph(
     use_nvfp4: bool = True,
     tactic: int = -1,
 ):
-    _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
     # here we need to get the 3d shape and stride including the
     # batch dimension for both input and block scale tensors.
@@ -1689,6 +1687,7 @@ def _get_cudnn_fp4_gemm_graph(
     )
 
     # build the fp4 cudnn graph
+    # The constructed graph is cached via the @functools.cache decorator.
     graph = build_plans_cudnn_fp4_gemm_graph(
         real_a_shape,
         real_a_stride,
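The caching the new comments rely on is plain `@functools.cache` memoization: the call is keyed on its (hashable) arguments, so a graph is built once per unique shape/stride combination and every later call is a cache hit. A minimal, hypothetical sketch of the pattern (names and body are illustrative, not the repo's):

```python
import functools

@functools.cache
def build_graph(a_shape: tuple, b_shape: tuple) -> dict:
    # Stand-in for expensive cuDNN graph construction.
    print(f"building graph for {a_shape} x {b_shape}")
    return {"a_shape": a_shape, "b_shape": b_shape}

build_graph((128, 64), (64, 256))  # builds and caches
build_graph((128, 64), (64, 256))  # cache hit: construction is skipped
```

This is why the requirement check can afford to build the graph up front: the later `mm_fp4()` call pays nothing for reconstruction.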
@@ -1722,6 +1721,7 @@ def _cudnn_gemm_fp4(
     workspace_buffer: torch.Tensor = None,
     tactic: int = -1,
 ):
+    # The graph should already have been cached when we ran _cudnn_gemm_fp4_requirement.
     graph = _get_cudnn_fp4_gemm_graph(
         a=a,
         b=b,
@@ -1748,7 +1748,6 @@ def get_valid_tactics(
         profile: OptimizationProfile,
     ) -> List[int]:
         # cudnn has a heuristic for fp4 gemm, so we only need to use the default tactic
-        _check_cudnn_availability()
         (
             a,
             b,
@@ -1762,6 +1761,7 @@ def get_valid_tactics(
             workspace_buffer,
         ) = inputs
 
+        # The graph should already have been cached when we ran _cudnn_gemm_fp4_requirement.
         graph = _get_cudnn_fp4_gemm_graph(
             a=a,
             b=b,
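Because cuDNN supplies its own kernel heuristic, the tuner only has to expose the default tactic. A trivial illustrative sketch of that contract (not the class's actual method, which as shown above also warms the graph cache):

```python
from typing import List

def default_tactics_only() -> List[int]:
    # cuDNN's heuristic picks the kernel, so -1 ("default") is the only tactic.
    return [-1]
```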
@@ -1821,10 +1821,10 @@ def _check_mm_fp4_problem_size(
     b_descale: torch.Tensor,
     alpha: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,  # unused
     block_size: int = 16,
-    use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    use_8x4_sf_layout: bool = False,  # unused
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     # Generic checks
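The `# unused` markers are deliberate: every per-backend requirement checker accepts the full `mm_fp4` argument set, so a dispatcher can probe each backend with identical keyword arguments. A hypothetical sketch of that pattern (the registry and checker bodies are illustrative, not the repo's actual dispatch code):

```python
# Hypothetical uniform-signature dispatch; checker logic is illustrative.
from typing import Callable, Dict

def _cudnn_ok(*, use_8x4_sf_layout: bool = False, **_unused) -> bool:
    return not use_8x4_sf_layout  # assumed: cudnn path rejects the 8x4 layout

def _trtllm_ok(*, use_nvfp4: bool = True, **_unused) -> bool:
    return use_nvfp4  # assumed: trtllm path requires nvfp4

CHECKERS: Dict[str, Callable[..., bool]] = {"cudnn": _cudnn_ok, "trtllm": _trtllm_ok}

def pick_backend(**kwargs) -> str:
    # Every checker takes the same keyword set, so "auto" can try them in order.
    for name, ok in CHECKERS.items():
        if ok(**kwargs):
            return name
    raise ValueError("no backend supports this problem")
```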
@@ -1878,10 +1878,10 @@ def _cudnn_gemm_fp4_requirement(
     b_descale: torch.Tensor,
     alpha: Optional[torch.Tensor] = None,
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,  # unused
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if use_8x4_sf_layout:
@@ -1908,7 +1908,8 @@ def _cudnn_gemm_fp4_requirement(
         _expand_block_scale_tensor_shape(b_descale, batch)
     )
 
-    # build the fp4 cudnn graph
+    # Build the fp4 cudnn graph. It will be cached and reused in mm_fp4(),
+    # since graph construction is memoized with the @functools.cache decorator.
     graph = create_cudnn_execution_plans_fp4_gemm(
         real_a_shape,
         real_a_stride,
@@ -1932,16 +1933,16 @@ def _cudnn_gemm_fp4_requirement(
 
 @supported_compute_capability([100, 103])
 def _trtllm_gemm_fp4_requirement(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_descale: torch.Tensor,
-    b_descale: torch.Tensor,
-    alpha: Optional[torch.Tensor] = None,
+    a: torch.Tensor,  # unused
+    b: torch.Tensor,  # unused
+    a_descale: torch.Tensor,  # unused
+    b_descale: torch.Tensor,  # unused
+    alpha: Optional[torch.Tensor] = None,  # unused
     out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
-    block_size: int = 16,
-    use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    out: Optional[torch.Tensor] = None,  # unused
+    block_size: int = 16,  # unused
+    use_8x4_sf_layout: bool = False,  # unused
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if not use_nvfp4:
@@ -1956,16 +1957,16 @@ def _cutlass_gemm_fp4_requirement(
 
 @supported_compute_capability([100, 103, 110, 120, 121])
 def _cutlass_gemm_fp4_requirement(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_descale: torch.Tensor,
-    b_descale: torch.Tensor,
-    alpha: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.bfloat16,
-    out: Optional[torch.Tensor] = None,
-    block_size: int = 16,
+    a: torch.Tensor,  # unused
+    b: torch.Tensor,  # unused
+    a_descale: torch.Tensor,  # unused
+    b_descale: torch.Tensor,  # unused
+    alpha: Optional[torch.Tensor] = None,  # unused
+    out_dtype: torch.dtype = torch.bfloat16,  # unused
+    out: Optional[torch.Tensor] = None,  # unused
+    block_size: int = 16,  # unused
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",  # unused
     use_nvfp4: bool = True,
 ):
     if use_8x4_sf_layout:
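The `@supported_compute_capability([...])` values encode SM versions (100 = SM 10.0, 121 = SM 12.1). A hypothetical sketch of such a gating decorator, assuming a `major * 10 + minor` encoding and a boolean-returning checker; the repo's real decorator may differ in signature and behavior:

```python
# Hypothetical capability-gating decorator; illustrative only.
import functools
from typing import Callable, List

import torch

def supported_compute_capability_sketch(caps: List[int]) -> Callable:
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            major, minor = torch.cuda.get_device_capability()
            if major * 10 + minor not in caps:  # e.g. SM 10.0 -> 100
                return False  # assumed failure convention for requirement checks
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```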