 from .cuda_utils import checkCudaErrors
 from .jit.cubin_loader import get_cubin
 from .jit.env import FLASHINFER_CUBIN_DIR
-from .utils import ceil_div, round_up
+from .utils import (
+    ceil_div,
+    round_up,
+    supported_compute_capability,
+    backend_requirement,
+)


 class GemmType(enum.Enum):
@@ -1357,26 +1362,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
     runtime = load("fp8_m_grouped_gemm", code)
     runtime(**all_kwargs)

-
-def m_grouped_fp8_gemm_nt_contiguous(
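+# Problem-size validator for m_grouped_fp8_gemm_nt_contiguous; raises ValueError
+# instead of relying on asserts (which are stripped under `python -O`).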
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
     recipe: Optional[Tuple[int, int, int]] = None,
     compiled_dims: str = "nk",
-) -> None:
-    # Compiled dims can be upper cases
-    compiled_dims = compiled_dims.lower()
+) -> bool:

     # NOTES: shape must be `[M, K] @ [G, N, K].mT`
     major_a = get_major_type_ab(a_fp8[0])
     major_b = get_major_type_ab(b_fp8[0])
-    assert major_a == MajorTypeAB.KMajor
-    if must_be_k_major():
-        assert major_b == MajorTypeAB.KMajor
-    assert m_indices.is_contiguous()
-
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not m_indices.is_contiguous():
+        raise ValueError("m_indices must be contiguous")
+
     a, sfa = a_fp8
     b, sfb = b_fp8
     m, k = a.shape
@@ -1385,15 +1391,44 @@ def m_grouped_fp8_gemm_nt_contiguous(
     m__ = m_indices.numel()

     # Type and shape checks
-    assert m == m_ == m__ and n == n_ and k == k_
-    assert n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert m_indices.dtype == torch.int32
-
+    if m != m_ or m != m__ or n != n_ or k != k_:
+        raise ValueError(f"Shape mismatch. m = {m}, m_ = {m_}, m__ = {m__}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}")
+    if n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(f"n, k, num_groups must be greater than 0, but got n = {n}, k = {k}, num_groups = {num_groups}")
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if m_indices.dtype != torch.int32:
+        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")
+
     # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
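+# Register the shared validator as the common_check for the public contiguous entry point.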
+@backend_requirement(
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
+def m_grouped_fp8_gemm_nt_contiguous(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    m_indices: torch.Tensor,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> None:
+    # Compiled dims can be upper cases
+    compiled_dims = compiled_dims.lower()
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    m, k = a.shape
+    num_groups, n, k_ = b.shape

     # Do nothing if the problem is empty
     if m == 0:
@@ -1423,6 +1458,59 @@ def m_grouped_fp8_gemm_nt_contiguous(
     impl(a, sfa, b, sfb, d, m_indices)


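+# Validator for the masked grouped FP8 GEMM; mirrors the contiguous variant above.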
+@supported_compute_capability([100, 103])
+def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> bool:
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if major_b != MajorTypeAB.KMajor:
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not masked_m.is_contiguous():
+        raise ValueError("masked_m must be contiguous")
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    num_groups, m, k = a.shape
+    num_groups_, n, k_ = b.shape
+    num_groups__, m_, n_ = d.shape
+    num_groups___ = masked_m.numel()
+
+    # Type and shape checks
+    if num_groups != num_groups_ or num_groups != num_groups__ or num_groups != num_groups___:
+        raise ValueError(f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}")
+    if m != m_ or n != n_ or k != k_:
+        raise ValueError(f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}")
+    if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}")
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if masked_m.dtype != torch.int32:
+        raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")
+
+    # D must be N-major
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
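+# Same wiring for the masked entry point below.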
+@backend_requirement(
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
 def m_grouped_fp8_gemm_nt_masked(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1445,20 +1533,6 @@ def m_grouped_fp8_gemm_nt_masked(
     b, sfb = b_fp8
     num_groups, m, k = a.shape
     num_groups_, n, k_ = b.shape
-    num_groups__, m_, n_ = d.shape
-    num_groups___ = masked_m.numel()
-
-    # Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
-    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert masked_m.dtype == torch.int32
-
-    # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor

     # Transform SFA and SFB into compute-required layout
     recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe