Add CUTLASS-based row-wise scaled sparse FP8 kernel
alexsamardzic committed Feb 25, 2025
1 parent 38e36de commit a7197f7
Showing 30 changed files with 1,622 additions and 92 deletions.
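
A sketch of how the new ops are exercised, assembled from the benchmark and test files in this commit (the shapes are illustrative; the kernel targets SM90/H100-class GPUs and a torchao build that includes the new CUTLASS sources):

# Assumes CUDA is available and torchao was built with the new CUTLASS kernels.
import torch

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)

m, n, k = 32, 128, 256
dev = torch.device("cuda")

# FP8 activations with one scale per row.
A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)

# FP8 weights, compressed to the CUTLASS 2:4 semi-structured format
# (compressed values plus metadata), with one scale per output row.
B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)

# Row-wise scaled linear on the sparse weight; an optional bias can be passed
# as a trailing argument (see the new test file below).
out = rowwise_scaled_linear_sparse_cutlass_f8f8(A, A_scale, B_sp, B_meta, B_scale)
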
11 changes: 5 additions & 6 deletions benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -23,9 +23,8 @@ def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
-128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
-    C = None

-    return A, A_scale, B, B_scale, C
+    return A, A_scale, B, B_scale


def benchmark(m: int, k: int, n: int):
@@ -34,14 +33,14 @@ def benchmark(m: int, k: int, n: int):
B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
+    A, A_scale, B, B_scale = get_problem(m, n, k, 8, 4)
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale
)

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
+    A, A_scale, B, B_scale = get_problem(m, n, k, 4, 4)
rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
+        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale
)

return {
61 changes: 61 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,61 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
rowwise_scaled_linear_sparse_cutlass_f8f8,
to_sparse_semi_structured_cutlass_sm9x_f8,
)


def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
dev = torch.device("cuda")

A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)

return A, A_scale, B_sp, B_meta, B_scale


def benchmark(m: int, k: int, n: int):
dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k)
rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale
)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
"f8f8 speedup (d/s)": fp16_time
/ rowwise_scaled_linear_sparse_cutlass_f8f8_time,
}


if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

df = pd.DataFrame(results)
df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))
1 change: 1 addition & 0 deletions docs/source/api_ref_dtypes.rst
@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
MarlinQQQLayout
Int4CPULayout
CutlassInt4PackedLayout
CutlassSemiSparseLayout

Quantization techniques
-----------------------
56 changes: 56 additions & 0 deletions setup.py
@@ -3,6 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import glob
import os
import subprocess
@@ -73,6 +74,7 @@ def use_debug_mode():
BuildExtension,
CppExtension,
CUDAExtension,
_get_cuda_arch_flags,
)

# Constant known variables used throughout this file
@@ -251,6 +253,7 @@ def get_extensions():
sources += cuda_sources

use_cutlass = False
cutlass_90a_sources = None
if use_cuda and not IS_WINDOWS:
use_cutlass = True
cutlass_dir = os.path.join(third_party_path, "cutlass")
@@ -266,8 +269,46 @@
"-I" + cutlass_include_dir,
"-I" + cutlass_tools_include_dir,
"-I" + cutlass_extensions_include_dir,
"-DNDEBUG" if not debug_mode else "",
"-DCUTE_USE_PACKED_TUPLE=1",
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"--use_fast_math",
"--ftemplate-backtrace-limit=0",
# "--keep",
# "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
# "--resource-usage",
# "-lineinfo",
# "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
]
)

cuda_arch_flags = _get_cuda_arch_flags()
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
if build_for_sm90 and not build_for_sm90a:
cutlass_90a_sources = [
os.path.join(
extensions_cuda_dir,
"rowwise_scaled_linear_sparse_cutlass",
"rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
),
os.path.join(
extensions_cuda_dir,
"to_sparse_semi_structured_cutlass_sm9x",
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
]
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
cutlass_90a_sources.append(
os.path.join(
extensions_cuda_dir,
"rowwise_scaled_linear_sparse_cutlass",
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
)
)
sources = [s for s in sources if s not in cutlass_90a_sources]
else:
# Remove CUTLASS-based kernels from the cuda_sources list. An
# assumption is that these files will have "cutlass" in its
@@ -291,6 +332,21 @@
)
)

if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
cutlass_90a_extra_compile_args["nvcc"].extend(
cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
)
ext_modules.append(
extension(
"torchao._C",
cutlass_90a_sources,
py_limited_api=True,
extra_compile_args=cutlass_90a_extra_compile_args,
extra_link_args=extra_link_args,
)
)

if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
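
The setup.py change above splits the CUTLASS sparse-FP8 sources into a separate extension compiled with -gencode=arch=compute_90a,code=sm_90a whenever torch's regular arch flags request sm_90 but not sm_90a. A minimal sketch (assumes a CUDA-enabled PyTorch build) of how to inspect the flags driving that decision, using the same private _get_cuda_arch_flags helper the diff imports, so behavior may vary across torch versions:

import os
from torch.utils.cpp_extension import _get_cuda_arch_flags

# Ask torch which gencode flags it would emit for a given arch list. With a
# plain "9.0" the result contains the compute_90/sm_90 flag but not the 90a
# variant, which is exactly the condition under which setup.py builds the
# sm_90a-only CUTLASS sources as their own extension.
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
print(_get_cuda_arch_flags())
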
9 changes: 8 additions & 1 deletion test/test_rowwise_scaled_linear_cutlass.py
@@ -8,6 +8,7 @@
rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
@@ -40,6 +41,12 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)
w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.xfail("torch.nn.RMSNorm not supported")
rms_norm = torch.nn.RMSNorm(size_k).to("cuda")
x = rms_norm(x).to(dtype)
w = rms_norm(w).to(dtype)

x_2d = x.view(-1, x.shape[-1])
xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
x_2d, xq_bits, size_k, dtype
@@ -57,7 +64,7 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)
)
assert torch.all(wq_zeros == 0)
if wq_bits == 4:
-        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
+        wq = (wq_s8[..., 1::2] << 4) | (wq_s8[..., 0::2] & 0xF)
else:
wq = wq_s8

137 changes: 137 additions & 0 deletions test/test_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,137 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

from torchao.dtypes import (
Float8Layout,
to_affine_quantized_floatx,
)
from torchao.ops import (
rowwise_scaled_linear_sparse_cutlass_f8f8,
to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.utils import _get_per_token_block_size
from torchao.sparsity.utils import create_semi_structured_tensor

X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
(torch.float8_e4m3fn, torch.float8_e4m3fn),
(torch.float8_e4m3fn, torch.float8_e5m2),
(torch.float8_e5m2, torch.float8_e4m3fn),
(torch.float8_e5m2, torch.float8_e5m2),
]
BATCH_SIZE = [1, 4]
SIZE_MNK = [
(2, 128, 256),
(3, 128, 256),
(13, 128, 256),
(27, 128, 128),
(33, 128, 64),
(65, 128, 32),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
itertools.product(
X_W_DTYPES,
XQ_WQ_DTYPES,
BATCH_SIZE,
SIZE_MNK,
USE_BIAS,
BIAS_DTYPE,
)
)


def run_test_for_op(
op,
x_dtype,
w_dtype,
xq_dtype,
wq_dtype,
batch_size,
size_mnk,
use_bias,
bias_dtype,
):
random.seed(0) ## for create_semi_structured_tensor()

size_m, size_n, size_k = size_mnk

x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype)
bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.xfail("torch.nn.RMSNorm not supported")
rms_norm = torch.nn.RMSNorm(size_k).to("cuda")
x = rms_norm(x).to(x_dtype)
w = rms_norm(w).to(w_dtype)

x_aqt = to_affine_quantized_floatx(
input_float=x,
target_dtype=xq_dtype,
block_size=_get_per_token_block_size(x),
_layout=Float8Layout(mm_config=None),
)
xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
assert zero_points is None

w_aqt = to_affine_quantized_floatx(
input_float=w,
target_dtype=wq_dtype,
block_size=_get_per_token_block_size(w),
_layout=Float8Layout(mm_config=None),
)
wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
assert zero_points is None
wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
wq_sp_scales = wq_scales

xq_2d = xq.view(-1, xq.shape[-1])
size_m_2d = xq_2d.shape[0]
output_ref = (
(xq_2d.float() @ wq.float().T)
* xq_scales.view(size_m_2d, 1)
* wq_scales.view(1, size_n)
)
if bias is not None:
output_ref += bias
output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
try:
output = op(*fn_inputs)
except NotImplementedError:
pytest.xfail("operator not implemented")

torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
"x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
x_w_dtypes,
xq_wq_dtypes,
batch_size,
size_mnk,
use_bias,
bias_dtype,
):
run_test_for_op(
rowwise_scaled_linear_sparse_cutlass_f8f8,
*x_w_dtypes,
*xq_wq_dtypes,
batch_size,
size_mnk,
use_bias,
bias_dtype,
)
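
The XQ_WQ_DTYPES list above covers every pairing of the two FP8 formats. As a side note (not part of the commit), the two formats trade dynamic range for precision, which is easy to confirm:

import torch

# float8_e4m3fn: 4 exponent / 3 mantissa bits -- smaller range, finer steps.
# float8_e5m2:   5 exponent / 2 mantissa bits -- larger range, coarser steps.
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    fi = torch.finfo(dt)
    print(dt, "max =", fi.max, "eps =", fi.eps)
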