Commit 92096da

gemm.py refactor
1 parent 499dcc5 commit 92096da

File tree

2 files changed: +418 -143 lines changed

flashinfer/deep_gemm.py

Lines changed: 107 additions & 33 deletions
@@ -45,7 +45,12 @@
 from .cuda_utils import checkCudaErrors
 from .jit.cubin_loader import get_cubin
 from .jit.env import FLASHINFER_CUBIN_DIR
-from .utils import ceil_div, round_up
+from .utils import (
+    ceil_div,
+    round_up,
+    supported_compute_capability,
+    backend_requirement,
+)
 
 
 class GemmType(enum.Enum):
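
The two new imports are used later in this diff as decorators: supported_compute_capability tags a problem-size check with the SM versions it covers, and backend_requirement attaches that check to the public kernel entry point through its common_check= argument. Their actual implementations live in flashinfer/utils.py and are not part of this commit; the toy sketch below only illustrates the composition pattern the diff relies on, and the functools.wraps wrapper plus the _supported_ccs attribute are illustrative assumptions rather than the library's API.

import functools
from typing import Callable, Iterable

def supported_compute_capability(ccs: Iterable[int]):
    # Toy stand-in: remember which compute capabilities the check covers.
    def decorate(check_fn: Callable[..., bool]):
        check_fn._supported_ccs = set(ccs)  # assumed attribute, illustration only
        return check_fn
    return decorate

def backend_requirement(common_check: Callable[..., bool]):
    # Toy stand-in: run the problem-size check before invoking the kernel.
    def decorate(kernel_fn):
        @functools.wraps(kernel_fn)
        def wrapper(*args, **kwargs):
            common_check(*args, **kwargs)  # raises ValueError on an unsupported problem
            return kernel_fn(*args, **kwargs)
        return wrapper
    return decorate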
@@ -1357,26 +1362,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
     runtime = load("fp8_m_grouped_gemm", code)
     runtime(**all_kwargs)
 
-
-def m_grouped_fp8_gemm_nt_contiguous(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
     recipe: Optional[Tuple[int, int, int]] = None,
     compiled_dims: str = "nk",
-) -> None:
-    # Compiled dims can be upper cases
-    compiled_dims = compiled_dims.lower()
+) -> bool:
 
     # NOTES: shape must be `[M, K] @ [G, N, K].mT`
     major_a = get_major_type_ab(a_fp8[0])
     major_b = get_major_type_ab(b_fp8[0])
-    assert major_a == MajorTypeAB.KMajor
-    if must_be_k_major():
-        assert major_b == MajorTypeAB.KMajor
-    assert m_indices.is_contiguous()
-
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not m_indices.is_contiguous():
+        raise ValueError(f"m_indices must be contiguous, but got {m_indices.is_contiguous()}")
+
     a, sfa = a_fp8
     b, sfb = b_fp8
     m, k = a.shape
@@ -1385,15 +1391,44 @@ def m_grouped_fp8_gemm_nt_contiguous(
     m__ = m_indices.numel()
 
     # Type and shape checks
-    assert m == m_ == m__ and n == n_ and k == k_
-    assert n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert m_indices.dtype == torch.int32
-
+    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
+        raise ValueError(f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}")
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if m_indices.dtype != torch.int32:
+        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")
+
     # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+@backend_requirement(
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
+def m_grouped_fp8_gemm_nt_contiguous(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    m_indices: torch.Tensor,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> None:
+    # Compiled dims can be upper cases
+    compiled_dims = compiled_dims.lower()
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    m, k = a.shape
+    num_groups, n, k_ = b.shape
 
     # Do nothing if the problem is empty
     if m == 0:
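
A practical consequence of replacing assert with explicit raise ValueError(...) here (a general Python behavior, not something stated in the commit): assert statements are stripped when the interpreter runs with -O, so the old checks could silently vanish in optimized deployments, whereas the new checks always execute and report which argument was wrong. A minimal standalone illustration, not flashinfer code:

import torch

def old_style_check(m_indices: torch.Tensor) -> None:
    assert m_indices.dtype == torch.int32  # skipped entirely under `python -O`

def new_style_check(m_indices: torch.Tensor) -> None:
    if m_indices.dtype != torch.int32:
        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")

bad = torch.zeros(4, dtype=torch.int64)
try:
    new_style_check(bad)
except ValueError as e:
    print(e)  # always raised: "m_indices must be int32, but got torch.int64"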
@@ -1423,6 +1458,59 @@ def m_grouped_fp8_gemm_nt_contiguous(
     impl(a, sfa, b, sfb, d, m_indices)
 
 
+@supported_compute_capability([100, 103])
+def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> bool:
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if major_b != MajorTypeAB.KMajor:
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not masked_m.is_contiguous():
+        raise ValueError(f"masked_m must be contiguous, but got {masked_m.is_contiguous()}")
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    num_groups, m, k = a.shape
+    num_groups_, n, k_ = b.shape
+    num_groups__, m_, n_ = d.shape
+    num_groups___ = masked_m.numel()
+
+    # Type and shape checks
+    if num_groups != num_groups_ or num_groups != num_groups__ or num_groups != num_groups___:
+        raise ValueError(f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}")
+    if m != m_ or n != n_ or k != k_:
+        raise ValueError(f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}")
+    if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}")
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if masked_m.dtype != torch.int32:
+        raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")
+
+    # D must be N-major
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+@backend_requirement(
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
 def m_grouped_fp8_gemm_nt_masked(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1445,20 +1533,6 @@ def m_grouped_fp8_gemm_nt_masked(
     b, sfb = b_fp8
     num_groups, m, k = a.shape
     num_groups_, n, k_ = b.shape
-    num_groups__, m_, n_ = d.shape
-    num_groups___ = masked_m.numel()
-
-    # Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
-    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert masked_m.dtype == torch.int32
-
-    # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
 
     # Transform SFA and SFB into compute-required layout
     recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
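
For reference, the masked check above enforces a [G, M, K] @ [G, N, K].mT -> [G, M, N] layout across a_fp8, b_fp8, d, and masked_m. The snippet below builds tensors with made-up sizes that would pass the dtype and shape checks; it omits the FP8 scale factors and the kernel call itself, and real inputs would additionally need to sit on a GPU covered by supported_compute_capability([100, 103]).

import torch

G, M, N, K = 8, 128, 4096, 7168  # hypothetical sizes, for illustration only
a = torch.empty(G, M, K, dtype=torch.float8_e4m3fn)  # [G, M, K]; contiguous, K innermost
b = torch.empty(G, N, K, dtype=torch.float8_e4m3fn)  # [G, N, K]; contiguous, K innermost
d = torch.empty(G, M, N, dtype=torch.bfloat16)       # [G, M, N]; N-major output
masked_m = torch.full((G,), 64, dtype=torch.int32)   # one valid-row count per group, contiguous
expected_m = 64                                      # must be positive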
