diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 4da91750fd..0178a4d174 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -45,7 +45,12 @@ from .cuda_utils import checkCudaErrors from .jit.cubin_loader import get_cubin from .jit.env import FLASHINFER_CUBIN_DIR -from .utils import ceil_div, round_up +from .utils import ( + ceil_div, + round_up, + supported_compute_capability, + backend_requirement, +) class GemmType(enum.Enum): @@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x( runtime(**all_kwargs) -def m_grouped_fp8_gemm_nt_contiguous( +@supported_compute_capability([100, 103]) +def _check_group_deepgemm_fp8_nt_contiguous_problem_size( a_fp8: Tuple[torch.Tensor, torch.Tensor], b_fp8: Tuple[torch.Tensor, torch.Tensor], d: torch.Tensor, m_indices: torch.Tensor, recipe: Optional[Tuple[int, int, int]] = None, compiled_dims: str = "nk", -) -> None: - # Compiled dims can be upper cases - compiled_dims = compiled_dims.lower() - +) -> bool: # NOTES: shape must be `[M, K] @ [G, N, K].mT` major_a = get_major_type_ab(a_fp8[0]) major_b = get_major_type_ab(b_fp8[0]) - assert major_a == MajorTypeAB.KMajor - if must_be_k_major(): - assert major_b == MajorTypeAB.KMajor - assert m_indices.is_contiguous() + if major_a != MajorTypeAB.KMajor: + raise ValueError(f"major_a must be KMajor, but got {major_a}") + if must_be_k_major() and (major_b != MajorTypeAB.KMajor): + raise ValueError(f"major_b must be KMajor, but got {major_b}") + + if not m_indices.is_contiguous(): + raise ValueError( + f"m_indices must be contiguous, but got {m_indices.is_contiguous()}" + ) a, sfa = a_fp8 b, sfb = b_fp8 @@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous( m__ = m_indices.numel() # Type and shape checks - assert m == m_ == m__ and n == n_ and k == k_ - assert n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert m_indices.dtype == torch.int32 + if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__: + raise ValueError( + f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}" + ) + if a.dtype != torch.float8_e4m3fn: + raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}") + if b.dtype != torch.float8_e4m3fn: + raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}") + if d.dtype != torch.bfloat16: + raise ValueError(f"d must be bfloat16, but got {d.dtype}") + if m_indices.dtype != torch.int32: + raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}") # D must be N-major - assert get_major_type_cd(d) == MajorTypeCD.NMajor + if get_major_type_cd(d) != MajorTypeCD.NMajor: + raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}") + + return True + + +@backend_requirement( + {}, + common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size, +) +def m_grouped_fp8_gemm_nt_contiguous( + a_fp8: Tuple[torch.Tensor, torch.Tensor], + b_fp8: Tuple[torch.Tensor, torch.Tensor], + d: torch.Tensor, + m_indices: torch.Tensor, + recipe: Optional[Tuple[int, int, int]] = None, + compiled_dims: str = "nk", +) -> None: + # Compiled dims can be upper cases + compiled_dims = compiled_dims.lower() + + major_a = get_major_type_ab(a_fp8[0]) + major_b = get_major_type_ab(b_fp8[0]) + + a, sfa = a_fp8 + b, sfb = b_fp8 + m, k = a.shape + num_groups, n, k_ = b.shape # Do nothing if the problem is empty if m == 0: @@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous( impl(a, sfa, b, sfb, d, m_indices) +@supported_compute_capability([100, 103]) +def _check_m_grouped_fp8_gemm_nt_masked_problem_size( + a_fp8: Tuple[torch.Tensor, torch.Tensor], + b_fp8: Tuple[torch.Tensor, torch.Tensor], + d: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + recipe: Optional[Tuple[int, int, int]] = None, + compiled_dims: str = "nk", +) -> bool: + major_a = get_major_type_ab(a_fp8[0]) + major_b = get_major_type_ab(b_fp8[0]) + if major_a != MajorTypeAB.KMajor: + raise ValueError(f"major_a must be KMajor, but got {major_a}") + if major_b != MajorTypeAB.KMajor: + raise ValueError(f"major_b must be KMajor, but got {major_b}") + + if not masked_m.is_contiguous(): + raise ValueError( + f"masked_m must be contiguous, but got {masked_m.is_contiguous()}" + ) + + a, sfa = a_fp8 + b, sfb = b_fp8 + num_groups, m, k = a.shape + num_groups_, n, k_ = b.shape + num_groups__, m_, n_ = d.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + if ( + num_groups != num_groups_ + or num_groups != num_groups__ + or num_groups != num_groups___ + ): + raise ValueError( + f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}" + ) + if m != m_ or n != n_ or k != k_: + raise ValueError( + f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}" + ) + if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0: + raise ValueError( + f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}" + ) + if a.dtype != torch.float8_e4m3fn: + raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}") + if b.dtype != torch.float8_e4m3fn: + raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}") + if d.dtype != torch.bfloat16: + raise ValueError(f"d must be bfloat16, but got {d.dtype}") + if masked_m.dtype != torch.int32: + raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}") + + # D must be N-major + if get_major_type_cd(d) != MajorTypeCD.NMajor: + raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}") + + return True + + +@backend_requirement( + {}, + common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size, +) def m_grouped_fp8_gemm_nt_masked( a_fp8: Tuple[torch.Tensor, torch.Tensor], b_fp8: Tuple[torch.Tensor, torch.Tensor], @@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked( b, sfb = b_fp8 num_groups, m, k = a.shape num_groups_, n, k_ = b.shape - num_groups__, m_, n_ = d.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - - # D must be N-major - assert get_major_type_cd(d) == MajorTypeCD.NMajor # Transform SFA and SFB into compute-required layout recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 63a2f7e211..c9f61b6d92 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -350,6 +350,7 @@ def forward( ) +# This is just helper for bmm_fp8 def fp8_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -381,10 +382,6 @@ def fp8_gemm_sm100( if CUDNN_AVAILABLE and "cudnn" in runner_names: runners.append(_cudnn_gemm_fp8_runner()) - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) - raise ValueError(f"No valid runner found for current device sm{major}{minor}") - tuner = AutoTuner.get() a_tensor_index = 0 out_tensor_index = 4 @@ -2009,6 +2006,70 @@ def mm_fp4( return out +@supported_compute_capability([89, 90, 100, 103, 120]) +def _cudnn_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _check_cudnn_availability() + return True + + +@supported_compute_capability([89, 90, 100, 103, 120]) +def _cublas_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + return True + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _cutlass_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: + raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend") + return True + + +def _check_bmm_fp8_problem_size( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _validate_fp8_output_dtype(dtype) + return True + + +@backend_requirement( + { + "cudnn": _cudnn_bmm_fp8_requirement, + "cublas": _cublas_bmm_fp8_requirement, + "cutlass": _cutlass_bmm_fp8_requirement, + "auto": _cublas_bmm_fp8_requirement, # cublas default + }, + common_check=_check_bmm_fp8_problem_size, +) def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2073,7 +2134,6 @@ def bmm_fp8( >>> out.dtype torch.bfloat16 """ - _validate_fp8_output_dtype(dtype) if out is None: out = torch.empty( @@ -2091,8 +2151,6 @@ def bmm_fp8( elif backend == "cublas": backends = ["cublas"] elif backend == "cutlass": - if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: - raise ValueError("e5m2 is not supported for cutlass backend") backends = ["cutlass"] elif backend == "auto": backends = ["cutlass", "cublas", "cudnn"] @@ -2103,6 +2161,78 @@ def bmm_fp8( return out +@supported_compute_capability([100, 103, 120, 121]) +def _cutlass_gemm_fp8_nt_groupwise_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if scale_major_mode is None: + raise ValueError("scale_major_mode is required in CUTLASS") + + return True + + +@supported_compute_capability([100, 103]) +def _trtllm_gemm_fp8_nt_groupwise_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if scale_granularity_mnk != (1, 128, 128): + raise ValueError("scale_granularity_mnk must be (1, 128, 128) in TRTLLM") + if a.shape[1] < 256: + raise ValueError("a.shape[1] must be >= 256 in TRTLLM") + + return True + + +def _check_gemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") + + if a.shape[1] != b.shape[1]: + raise ValueError( + f"Shape mismatch. a.shape[1] = {a.shape[1]}, b.shape[1] = {b.shape[1]}" + ) + + _validate_fp8_output_dtype(out_dtype) + + return True + + +@backend_requirement( + { + "cutlass": _cutlass_gemm_fp8_nt_groupwise_requirement, + "trtllm": _trtllm_gemm_fp8_nt_groupwise_requirement, + }, + common_check=_check_gemm_fp8_nt_groupwise_problem_size, +) def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2176,27 +2306,16 @@ def gemm_fp8_nt_groupwise( ----- The ``m`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): - raise ValueError("TRTLLM FP8 GEMM is not supported on SM110.") workspace_buffer = _get_cache_buf( "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") - - if a.shape[1] != b.shape[1]: - raise ValueError( - f"Shape mismatch. a.shape[1] = {a.shape[1]}, b.shape[1] = {b.shape[1]}" - ) if out is None: out_dtype = out_dtype or torch.bfloat16 else: out_dtype = out.dtype - _validate_fp8_output_dtype(out_dtype) - # NOTE(Zihao): (out_specified, need_padding) # (False, False) -> create out_padded tensor explicitly # (False, True) -> create out_padded tensor explicitly @@ -2212,18 +2331,6 @@ def gemm_fp8_nt_groupwise( ) if backend == "cutlass": - if not _match_sm_version(a.device, ["100", "103", "110", "120", "121"]): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121 in cutlass backend." - ) - elif backend == "trtllm": - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend." - ) - - if backend == "cutlass": - assert scale_major_mode is not None if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): # SM120/121 doesn't use mma_sm parameter get_gemm_sm120_module().gemm_fp8_nt_groupwise( @@ -2251,8 +2358,6 @@ def gemm_fp8_nt_groupwise( else: raise ValueError(f"Unsupported device for FP8 GEMM: {a.device}") elif backend == "trtllm": - assert scale_granularity_mnk == (1, 128, 128) - assert a.shape[1] >= 256 # mma_sm is ignored get_trtllm_gemm_module().trtllm_gemm( workspace_buffer, @@ -2411,6 +2516,48 @@ def pad_up(x, y): ) +@supported_compute_capability([100, 103, 120, 121]) +def _check_gemm_fp8_nt_blockscaled_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = "MN", + mma_sm: int = 1, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + _check_gemm_fp8_nt_groupwise_problem_size( + a, + b, + a_scale, + b_scale, + scale_major_mode, + mma_sm, + out, + out_dtype, + backend="cutlass", + ) + + _cutlass_gemm_fp8_nt_groupwise_requirement( + a, + b, + a_scale, + b_scale, + scale_major_mode, + mma_sm, + out, + out_dtype, + backend="cutlass", + ) + + return True + + +@backend_requirement( + {}, + common_check=_check_gemm_fp8_nt_blockscaled_problem_size, +) def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2439,6 +2586,79 @@ def gemm_fp8_nt_blockscaled( ) +@supported_compute_capability([100, 120, 121]) +def _check_group_gemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indptr: torch.Tensor, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + scale_major_mode: Literal["MN", "K"] = "MN", + mma_sm: int = 1, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + if a.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError(f"a must be a float8 tensor, but got {a.dtype}") + if b.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError(f"b must be a float8 tensor, but got {b.dtype}") + if a_scale.dtype not in [torch.float32]: + raise ValueError(f"a_scale must be a float32 tensor, but got {a_scale.dtype}") + if b_scale.dtype not in [torch.float32]: + raise ValueError(f"b_scale must be a float32 tensor, but got {b_scale.dtype}") + if m_indptr.dtype not in [torch.int32]: + raise ValueError(f"m_indptr must be a int32 tensor, but got {m_indptr.dtype}") + if scale_major_mode not in ["MN", "K"]: + raise ValueError( + f"scale_major_mode must be either 'MN' or 'K', but got {scale_major_mode}" + ) + if mma_sm not in [1, 2]: + raise ValueError(f"mma_sm must be either 1 or 2, but got {mma_sm}") + + # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance + n = b.shape[1] + k = b.shape[2] + + if out is None: + if out_dtype is None: + out_dtype = torch.bfloat16 + else: + if out_dtype is None: + out_dtype = out.dtype + if out.shape != (a.shape[0], n): + raise ValueError( + f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}" + ) + if out.dtype != out_dtype: + raise ValueError( + f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}" + ) + + _validate_fp8_output_dtype(out_dtype) + + if a.shape[1] != k: + raise ValueError(f"Shape mismatch. a.shape[1] = {a.shape[1]}, k = {k}") + if n % 8 != 0: + raise ValueError(f"n must be a multiple of 8, but got {n}") + if k % 16 != 0: + raise ValueError(f"k must be a multiple of 16, but got {k}") + + num_groups = m_indptr.shape[0] - 1 + + if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): + if num_groups > 1: + raise RuntimeError( + "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121" + ) + + return True + + +@backend_requirement( + {}, + common_check=_check_group_gemm_fp8_nt_groupwise_problem_size, +) def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2503,19 +2723,6 @@ def group_gemm_fp8_nt_groupwise( Each value in ``m_indptr`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if ( - not is_sm100a_supported(a.device) - and not is_sm120a_supported(a.device) - and not is_sm121a_supported(a.device) - ): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM120, and SM121." - ) - if not (_match_sm_version(a.device, ["100", "103", "110", "120", "121"])): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121." - ) - int_workspace_buffer = _get_cache_buf( "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) @@ -2523,46 +2730,21 @@ def group_gemm_fp8_nt_groupwise( "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert a_scale.dtype == torch.float32 - assert b_scale.dtype == torch.float32 - assert m_indptr.dtype == torch.int32 - assert scale_major_mode in ["MN", "K"] - assert mma_sm in [1, 2] if out is None: if out_dtype is None: out_dtype = torch.bfloat16 else: if out_dtype is None: out_dtype = out.dtype - _validate_fp8_output_dtype(out_dtype) - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups n = b.shape[1] k = b.shape[2] - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k - align_n = 8 - align_k = 16 - assert n % align_n == 0 - assert k % align_k == 0 - out_shape = (a.shape[0], n) if out is None: out = torch.empty(out_shape, dtype=out_dtype, device=a.device) - else: - assert out.shape == out_shape - assert out.dtype == out_dtype if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): - # it has correctness issues for num_groups > 1 - if num_groups > 1: - raise RuntimeError( - "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121" - ) # SM120/121 doesn't use mma_sm parameter get_gemm_sm120_module().group_gemm_fp8_nt_groupwise( int_workspace_buffer, @@ -2594,13 +2776,96 @@ def group_gemm_fp8_nt_groupwise( scale_major_mode, mma_sm, ) + return out + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indptr: torch.Tensor, + mma_sm: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + swap_ab: bool = True, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + if a.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError( + f"a must be a float8_e4m3fn or float8_e5m2 tensor, but got {a.dtype}" + ) + if b.dtype != torch.uint8: + raise ValueError(f"b must be a uint8 tensor, but got {b.dtype}") + if a_scale.dtype != torch.uint8: + raise ValueError(f"a_scale must be a uint8 tensor, but got {a_scale.dtype}") + if b_scale.dtype != torch.uint8: + raise ValueError(f"b_scale must be a uint8 tensor, but got {b_scale.dtype}") + if m_indptr.dtype != torch.int32: + raise ValueError(f"m_indptr must be a int32 tensor, but got {m_indptr.dtype}") + if mma_sm not in [1, 2]: + raise ValueError(f"mma_sm must be either 1 or 2, but got {mma_sm}") + if tile_m not in [128]: + raise ValueError(f"tile_m must be 128, but got {tile_m}") + if tile_n not in [64, 128, 192, 256]: + raise ValueError(f"tile_n must be one of [64, 128, 192, 256], but got {tile_n}") + if tile_k not in [128, 256]: + raise ValueError(f"tile_k must be either 128 or 256, but got {tile_k}") + if swap_ab not in [True, False]: + raise ValueError(f"swap_ab must be a boolean value, but got {swap_ab}") + + # Determine out_dtype if not specified + if out is None: + if out_dtype is None: + out_dtype = torch.bfloat16 else: + if out_dtype is None: + out_dtype = out.dtype + + if out_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( - f"group_gemm_fp8_nt_groupwise requires SM100, SM120, or SM121, but got {a.device}" + f"out_dtype must be either torch.bfloat16 or torch.float16, but got {out_dtype}" + ) + + num_groups = m_indptr.shape[0] - 1 + if b.shape[0] != num_groups: + raise ValueError( + f"b.shape[0] must equal num_groups (m_indptr.shape[0] - 1), but got b.shape[0]={b.shape[0]}, num_groups={num_groups}" ) - return out + n = b.shape[1] + k = b.shape[2] * 2 # Multiply by 2 because b is e2m1 packed as uint8 + # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance + if a.shape[1] != k: + raise ValueError( + f"a.shape[1] must equal k, but got a.shape[1]={a.shape[1]}, k={k}" + ) + + align_n = 8 + align_k = 128 + if n % align_n != 0: + raise ValueError(f"n must be a multiple of {align_n}, but got n={n}") + if k % align_k != 0: + raise ValueError(f"k must be a multiple of {align_k}, but got k={k}") + + out_shape = (a.shape[0], n) + if out is not None: + if out.shape != out_shape: + raise ValueError(f"out.shape must be {out_shape}, but got {out.shape}") + if out.dtype != out_dtype: + raise ValueError(f"out.dtype must be {out_dtype}, but got {out.dtype}") + + return True + + +@backend_requirement( + {}, + common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size, +) def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2677,43 +2942,20 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( DEFAULT_WORKSPACE_SIZE, a.device, ) - - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype == torch.uint8 - assert a_scale.dtype == torch.uint8 - assert b_scale.dtype == torch.uint8 - assert m_indptr.dtype == torch.int32 - assert mma_sm in [1, 2] - assert tile_m in [128] - assert tile_n in [64, 128, 192, 256] - assert tile_k in [128, 256] - assert swap_ab in [True, False] + # Determine out_dtype if not specified if out is None: if out_dtype is None: out_dtype = torch.bfloat16 else: if out_dtype is None: out_dtype = out.dtype - assert out_dtype in [torch.bfloat16, torch.float16] - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups n = b.shape[1] k = b.shape[2] * 2 # Multiply by 2 because b is e2m1 packed as uint8 - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k - align_n = 8 - align_k = 128 - assert n % align_n == 0 - assert k % align_k == 0 - out_shape = (a.shape[0], n) if out is None: out = torch.empty(out_shape, dtype=out_dtype, device=a.device) - else: - assert out.shape == out_shape - assert out.dtype == out_dtype get_gemm_sm100_module().group_gemm_mxfp4_nt_groupwise( int_workspace_buffer, @@ -2768,6 +3010,30 @@ def get_deepgemm_sm100_module(): return module +@supported_compute_capability([100, 103]) +def _check_group_deepgemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indices: torch.Tensor, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> bool: + from flashinfer.deep_gemm import ( + _check_group_deepgemm_fp8_nt_contiguous_problem_size, + ) + + return _check_group_deepgemm_fp8_nt_contiguous_problem_size( + (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk + ) + + +@backend_requirement( + {}, + common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size, +) def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2882,11 +3148,6 @@ def group_deepgemm_fp8_nt_groupwise( """ from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "m_grouped_fp8_gemm_nt_contiguous is only supported on SM100, SM100, SM103." - ) - if out is None: out_dtype = out_dtype or torch.bfloat16 out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device) @@ -2898,6 +3159,29 @@ def group_deepgemm_fp8_nt_groupwise( return out +@supported_compute_capability([100, 103]) +def _check_batch_deepgemm_fp8_nt_groupwise( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> bool: + from flashinfer.deep_gemm import _check_m_grouped_fp8_gemm_nt_masked_problem_size + + return _check_m_grouped_fp8_gemm_nt_masked_problem_size( + (a, a_scale), (b, b_scale), out, masked_m, expected_m, scale_granularity_mnk + ) + + +@backend_requirement( + {}, + common_check=_check_batch_deepgemm_fp8_nt_groupwise, +) def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3015,11 +3299,6 @@ def batch_deepgemm_fp8_nt_groupwise( """ from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "m_grouped_fp8_gemm_nt_masked is only supported on SM100, SM103." - ) - if out is None: out_dtype = out_dtype or torch.bfloat16 out = torch.empty(