Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671

Draft: wants to merge 1 commit into base: main
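This draft adds a CUTLASS-based row-wise scaled sparse FP8 linear kernel for SM90a, together with a benchmark, build wiring, a docs entry, and test updates; it also changes the existing rowwise_scaled_linear_cutlass ops to take bias and out_dtype arguments in place of the old C argument. As orientation, here is a minimal sketch of the intended call pattern for the new sparse op, mirroring the shapes and dtypes used in the benchmark added below; since this is a draft PR, treat the exact five-argument signature as an assumption taken from that benchmark.

import torch
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)

# Minimal sketch; shapes/dtypes mirror the benchmark below, and the
# five-argument signature is assumed from how the benchmark calls the op.
m, n, k = 16, 8192, 8192
dev = torch.device("cuda")
A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)  # per-row scales
B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)  # compress weights
B_scale = torch.randn((n,), dtype=torch.half, device=dev)  # per-output scales
out = rowwise_scaled_linear_sparse_cutlass_f8f8(A, A_scale, B_sp, B_meta, B_scale)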
17 changes: 9 additions & 8 deletions benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -18,14 +18,15 @@ def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):

     dev = torch.device("cuda")
     A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
-    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
+    A_scale = torch.randn((m,), dtype=torch.float32, device=dev)
     B = torch.randint(
         -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
     )
-    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
-    C = None
+    B_scale = torch.randn((n,), dtype=torch.float32, device=dev)
+    bias = None
+    out_dtype = torch.bfloat16

-    return A, A_scale, B, B_scale, C
+    return A, A_scale, B, B_scale, bias, out_dtype


 def benchmark(m: int, k: int, n: int):
@@ -34,14 +35,14 @@ def benchmark(m: int, k: int, n: int):
     B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
     fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
+    A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 8, 4)
     rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, bias, out_dtype
     )

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
+    A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 4, 4)
     rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
+        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, bias, out_dtype
     )

     return {
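The change above replaces the unused C argument with an explicit bias and out_dtype, so the dense ops now take six arguments. A minimal sketch of the updated s8s4 call, assuming the signature exercised by the benchmark; note that 4-bit operands are packed two per int8 byte, which is why get_problem allocates k * B_nbits // 8 columns:

import torch
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4

# Sketch of the updated six-argument call; signature assumed from the
# benchmark above. A holds s8 values, B holds s4 values packed two per byte.
m, n, k = 16, 8192, 8192
dev = torch.device("cuda")
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.float32, device=dev)
B = torch.randint(-128, 127, (n, k // 2), dtype=torch.int8, device=dev)
B_scale = torch.randn((n,), dtype=torch.float32, device=dev)
out = rowwise_scaled_linear_cutlass_s8s4(A, A_scale, B, B_scale, None, torch.bfloat16)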
61 changes: 61 additions & 0 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,61 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
    dev = torch.device("cuda")

    A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)

    return A, A_scale, B_sp, B_meta, B_scale


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (us)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        "f8f8 speedup (d/s)": fp16_time
        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
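to_sparse_semi_structured_cutlass_sm9x_f8 returns the compressed weight plus a metadata tensor. Assuming the usual CUTLASS 2:4 semi-structured scheme (two values kept out of every four, with metadata recording their positions; the PR text does not spell the layout out, so treat this as an assumption), the compressed tensor should keep half of k:

import torch
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8

# Shape-check sketch under an assumed 2:4 semi-structured layout.
n, k = 128, 256
B = torch.randn((n, k), dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
print(B_sp.shape)   # expected (n, k // 2) if 2 of every 4 values are kept
print(B_meta.shape) # positions of kept values; exact dtype/shape not
                    # documented in this PR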
1 change: 1 addition & 0 deletions docs/source/api_ref_dtypes.rst
@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
     MarlinQQQLayout
     Int4CPULayout
     CutlassInt4PackedLayout
+    CutlassSemiSparseLayout

 Quantization techniques
 -----------------------
56 changes: 56 additions & 0 deletions setup.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import glob
 import os
 import subprocess
@@ -73,6 +74,7 @@ def use_debug_mode():
     BuildExtension,
     CppExtension,
     CUDAExtension,
+    _get_cuda_arch_flags,
 )

 # Constant known variables used throughout this file
@@ -251,6 +253,7 @@ def get_extensions():
         sources += cuda_sources

     use_cutlass = False
+    cutlass_90a_sources = None
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
@@ -266,8 +269,46 @@ def get_extensions():
                 "-I" + cutlass_include_dir,
                 "-I" + cutlass_tools_include_dir,
                 "-I" + cutlass_extensions_include_dir,
+                "-DNDEBUG" if not debug_mode else "",
+                "-DCUTE_USE_PACKED_TUPLE=1",
+                "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+                "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
+                "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                "--use_fast_math",
+                "--ftemplate-backtrace-limit=0",
+                # "--keep",
+                # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
+                # "--resource-usage",
+                # "-lineinfo",
+                # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
             ]
         )
+
+        cuda_arch_flags = _get_cuda_arch_flags()
+        build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
+        build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
+        if build_for_sm90 and not build_for_sm90a:
+            cutlass_90a_sources = [
+                os.path.join(
+                    extensions_cuda_dir,
+                    "rowwise_scaled_linear_sparse_cutlass",
+                    "rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
+                ),
+                os.path.join(
+                    extensions_cuda_dir,
+                    "to_sparse_semi_structured_cutlass_sm9x",
+                    "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
+                ),
+            ]
+            for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
+                cutlass_90a_sources.append(
+                    os.path.join(
+                        extensions_cuda_dir,
+                        "rowwise_scaled_linear_sparse_cutlass",
+                        "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
+                    )
+                )
+            sources = [s for s in sources if s not in cutlass_90a_sources]
     else:
         # Remove CUTLASS-based kernels from the cuda_sources list. An
         # assumption is that these files will have "cutlass" in its
@@ -291,6 +332,21 @@ def get_extensions():
             )
         )

+    if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
+        cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
+        cutlass_90a_extra_compile_args["nvcc"].extend(
+            cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
+        )
+        ext_modules.append(
+            extension(
+                "torchao._C",
+                cutlass_90a_sources,
+                py_limited_api=True,
+                extra_compile_args=cutlass_90a_extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        )
+
     if build_torchao_experimental:
         ext_modules.append(
             CMakeExtension(
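The build changes compile the new kernels into a separate extension with -gencode=arch=compute_90a,code=sm_90a, but only when the main build targets sm_90 without already targeting sm_90a, since the sparse kernels rely on sm_90a features that plain sm_90 code objects cannot use. A small sketch for checking which arch flags your environment would produce (TORCH_CUDA_ARCH_LIST is standard PyTorch build machinery; the exact flag strings matched above are taken from this patch):

import os
from torch.utils.cpp_extension import _get_cuda_arch_flags

# Sketch: see which -gencode flags the PyTorch helper emits for a given
# TORCH_CUDA_ARCH_LIST; "9.0a" would request sm_90a code generation instead.
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
print(_get_cuda_arch_flags())
# If this includes "-gencode=arch=compute_90,code=sm_90" and no 90a entry,
# the patch above routes the CUTLASS sparse kernels into the separately
# compiled sm_90a extension.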
90 changes: 42 additions & 48 deletions test/test_rowwise_scaled_linear_cutlass.py
@@ -7,75 +7,73 @@
     rowwise_scaled_linear_cutlass_s4s4,
     rowwise_scaled_linear_cutlass_s8s4,
 )
-from torchao.quantization.utils import group_quantize_tensor_symmetric
+from torchao.quantization.quant_api import (
+    _int4_symm_per_token_quant_cutlass,
+    _int8_symm_per_token_quant_cutlass,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+from torchao.quantization.utils import _get_per_token_block_size

-ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
-ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
-ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
+DTYPES = [torch.float16, torch.bfloat16]
+BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+SIZE_MNK = [
     (2, 512, 128),
     (3, 2048, 2048),
     (4, 3584, 640),
     (13, 8704, 8576),
     (26, 18944, 1664),
     (67, 6656, 1408),
 ]
-ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
-ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
+USE_BIAS = [False, True]
+TEST_PARAMS = list(
     itertools.product(
-        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
-        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
-        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
-        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
+        DTYPES,
+        BATCH_SIZE,
+        SIZE_MNK,
+        USE_BIAS,
     )
 )


-def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
-    assert xq_bits in [4, 8]
-    assert wq_bits in [4, 8]
-
+def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
     size_m, size_n, size_k = size_mnk

     x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
     w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
     bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

-    x_2d = x.view(-1, x.shape[-1])
-    xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
-        x_2d, xq_bits, size_k, dtype
-    )
-    assert torch.all(xq_2d_zeros == 0)
-    xq_s8 = xq_2d_s8.reshape(x.shape)
-    if xq_bits == 4:
-        xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
-    else:
-        xq = xq_s8
-    xq_scales = xq_2d_scales.reshape(x.shape[:-1])
-
-    wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
-        w, wq_bits, size_n, dtype
-    )
-    assert torch.all(wq_zeros == 0)
-    if wq_bits == 4:
-        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
-    else:
-        wq = wq_s8
+    xq_bits = 4 if op == rowwise_scaled_linear_cutlass_s4s4 else 8
+    pack_s4 = lambda x: (x[..., 1::2] << 4) | (x[..., 0::2] & 0xF)
+
+    x_quant_func = (
+        _int4_symm_per_token_quant_cutlass
+        if xq_bits == 4
+        else _int8_symm_per_token_quant_cutlass
+    )
+    x_aqt = x_quant_func(x)
+    xq_s8, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
+    assert zero_points is None
+    xq = pack_s4(xq_s8) if xq_bits == 4 else xq_s8
+
+    w_quant_func = _int4_symm_per_token_quant_cutlass
+    w_aqt = w_quant_func(w)
+    wq_s8, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
+    assert zero_points is None
+    wq = pack_s4(wq_s8)

     # If torch.nn.functional.linear(x, w, bias) used as reference, the
     # error would be too big. The calculation below is approximately
     # what rowwise_scaled_linear_cutlass kernel is doing (except that
     # matrix multiplication is over integers there).
-    size_m_2d = x_2d.shape[0]
-    output_ref = (
-        (xq_2d_s8.float() @ wq_s8.float().T)
-        * xq_2d_scales.view(size_m_2d, 1)
-        * wq_scales.view(1, size_n)
-    )
+    output_ref = (xq_s8.float() @ wq_s8.float().T) * xq_scales[..., None] * wq_scales
     if bias is not None:
         output_ref += bias
     output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))

-    fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
+    fn_inputs = (xq, xq_scales, wq, wq_scales, bias, dtype)
     try:
         output = op(*fn_inputs)
     except NotImplementedError:
@@ -85,20 +83,16 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize(
-    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
-)
+@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
 def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
     run_test_for_op(
-        rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
+        rowwise_scaled_linear_cutlass_s4s4, dtype, batch_size, size_mnk, use_bias
     )


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize(
-    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
-)
+@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
 def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
     run_test_for_op(
-        rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
+        rowwise_scaled_linear_cutlass_s8s4, dtype, batch_size, size_mnk, use_bias
     )
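The updated reference computation makes the operator's contract explicit: the quantized operands are multiplied as integers (emulated in float in the test), then scaled per activation row and per weight row, with the optional bias added afterwards, and int4 operands are packed two signed nibbles per int8 byte. A standalone sketch of that contract in plain PyTorch, with illustrative names that are not from the PR:

import torch

# Standalone sketch of the reference math from the test above; all names
# here are illustrative, and quantization/rounding concerns are ignored.
def rowwise_scaled_linear_ref(xq_s8, x_scale, wq_s8, w_scale, bias=None):
    out = (xq_s8.float() @ wq_s8.float().T) * x_scale[..., None] * w_scale
    if bias is not None:
        out += bias
    return out

def pack_s4(t):  # same packing as the test: two signed nibbles per int8 byte
    return (t[..., 1::2] << 4) | (t[..., 0::2] & 0xF)

m, n, k = 4, 8, 32
xq = torch.randint(-8, 8, (m, k), dtype=torch.int8)
wq = torch.randint(-8, 8, (n, k), dtype=torch.int8)
x_scale = torch.rand(m)
w_scale = torch.rand(n)
print(rowwise_scaled_linear_ref(xq, x_scale, wq, w_scale).shape)  # (4, 8)
print(pack_s4(wq).shape)  # (8, 16): packing halves k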