From d1d96f76bf85852a7c1387da93c3f3a1de9c07f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Tue, 4 Feb 2025 21:01:48 +0100 Subject: [PATCH] Add CUTLASS-based row-wise scaled sparse FP8 kernel --- ...benchmark_rowwise_scaled_linear_cutlass.py | 17 +- ...rk_rowwise_scaled_linear_sparse_cutlass.py | 61 ++ docs/source/api_ref_dtypes.rst | 1 + setup.py | 56 ++ test/test_rowwise_scaled_linear_cutlass.py | 90 ++- ...st_rowwise_scaled_linear_sparse_cutlass.py | 130 +++++ torchao/_models/llama/generate.py | 62 ++- .../rowwise_scaled_linear_cutlass.cuh | 523 +++++++++--------- .../rowwise_scaled_linear_cutlass_s4s4.cu | 29 +- .../rowwise_scaled_linear_cutlass_s8s4.cu | 23 +- .../rowwise_scaled_linear_sparse_cutlass.cuh | 496 +++++++++++++++++ ...e_scaled_linear_sparse_cutlass_e4m3e4m3.cu | 28 + ...se_scaled_linear_sparse_cutlass_e4m3e4m3.h | 14 + ...e_scaled_linear_sparse_cutlass_e4m3e5m2.cu | 28 + ...se_scaled_linear_sparse_cutlass_e4m3e5m2.h | 14 + ...e_scaled_linear_sparse_cutlass_e5m2e4m3.cu | 28 + ...se_scaled_linear_sparse_cutlass_e5m2e4m3.h | 14 + ...e_scaled_linear_sparse_cutlass_e5m2e5m2.cu | 28 + ...se_scaled_linear_sparse_cutlass_e5m2e5m2.h | 14 + ...wwise_scaled_linear_sparse_cutlass_f8f8.cu | 48 ++ ...to_sparse_semi_structured_cutlass_sm9x.cuh | 174 ++++++ ..._sparse_semi_structured_cutlass_sm9x_f8.cu | 35 ++ torchao/dtypes/__init__.py | 2 + torchao/dtypes/affine_quantized_tensor_ops.py | 8 + torchao/dtypes/floatx/__init__.py | 4 + .../floatx/cutlass_semi_sparse_layout.py | 178 ++++++ .../uintx/cutlass_int4_packed_layout.py | 37 +- torchao/ops.py | 124 ++++- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 109 +++- 30 files changed, 1957 insertions(+), 420 deletions(-) create mode 100644 benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py create mode 100644 test/test_rowwise_scaled_linear_sparse_cutlass.py create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu create mode 100644 torchao/dtypes/floatx/cutlass_semi_sparse_layout.py diff --git a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py 
b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py index c4c9c099be..31b94c08db 100644 --- a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py +++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py @@ -18,14 +18,15 @@ def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): dev = torch.device("cuda") A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) - A_scale = torch.randn((m,), dtype=torch.half, device=dev) + A_scale = torch.randn((m,), dtype=torch.float32, device=dev) B = torch.randint( -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev ) - B_scale = torch.randn((n,), dtype=torch.half, device=dev) - C = None + B_scale = torch.randn((n,), dtype=torch.float32, device=dev) + bias = None + out_dtype = torch.bfloat16 - return A, A_scale, B, B_scale, C + return A, A_scale, B, B_scale, bias, out_dtype def benchmark(m: int, k: int, n: int): @@ -34,14 +35,14 @@ def benchmark(m: int, k: int, n: int): B_ref = torch.randn((n, k), dtype=torch.half, device=dev) fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) - A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) + A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 8, 4) rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( - rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, bias, out_dtype ) - A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) + A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 4, 4) rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( - rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C + rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, bias, out_dtype ) return { diff --git a/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py new file mode 100644 index 0000000000..5c544daedb --- /dev/null +++ b/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py @@ -0,0 +1,61 @@ +import pandas as pd +import torch +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) + + +def benchmark_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + + +def get_problem(m: int, n: int, k: int): + dev = torch.device("cuda") + + A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn) + B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + + return A, A_scale, B_sp, B_meta, B_scale + + +def benchmark(m: int, k: int, n: int): + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((n, k), dtype=torch.half, device=dev) + fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) + + A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k) + rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds( + rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale + ) + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (ms)": fp16_time, + "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time, + 
"f8f8 speedup (d/s)": fp16_time + / rowwise_scaled_linear_sparse_cutlass_f8f8_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n)) + + df = pd.DataFrame(results) + df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 26e1266c09..6cbec7465e 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -28,6 +28,7 @@ Layouts and Tensor Subclasses MarlinQQQLayout Int4CPULayout CutlassInt4PackedLayout + CutlassSemiSparseLayout Quantization techniques ----------------------- diff --git a/setup.py b/setup.py index ee3ebbf453..9e0e8bed77 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import copy import glob import os import subprocess @@ -73,6 +74,7 @@ def use_debug_mode(): BuildExtension, CppExtension, CUDAExtension, + _get_cuda_arch_flags, ) # Constant known variables used throughout this file @@ -251,6 +253,7 @@ def get_extensions(): sources += cuda_sources use_cutlass = False + cutlass_90a_sources = None if use_cuda and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") @@ -266,8 +269,46 @@ def get_extensions(): "-I" + cutlass_include_dir, "-I" + cutlass_tools_include_dir, "-I" + cutlass_extensions_include_dir, + "-DNDEBUG" if not debug_mode else "", + "-DCUTE_USE_PACKED_TUPLE=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "--use_fast_math", + "--ftemplate-backtrace-limit=0", + # "--keep", + # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", + # "--resource-usage", + # "-lineinfo", + # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md ] ) + + cuda_arch_flags = _get_cuda_arch_flags() + build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags + build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags + if build_for_sm90 and not build_for_sm90a: + cutlass_90a_sources = [ + os.path.join( + extensions_cuda_dir, + "rowwise_scaled_linear_sparse_cutlass", + "rowwise_scaled_linear_sparse_cutlass_f8f8.cu", + ), + os.path.join( + extensions_cuda_dir, + "to_sparse_semi_structured_cutlass_sm9x", + "to_sparse_semi_structured_cutlass_sm9x_f8.cu", + ), + ] + for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: + cutlass_90a_sources.append( + os.path.join( + extensions_cuda_dir, + "rowwise_scaled_linear_sparse_cutlass", + "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu", + ) + ) + sources = [s for s in sources if s not in cutlass_90a_sources] else: # Remove CUTLASS-based kernels from the cuda_sources list. 
An # assumption is that these files will have "cutlass" in its @@ -291,6 +332,21 @@ def get_extensions(): ) ) + if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0: + cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args) + cutlass_90a_extra_compile_args["nvcc"].extend( + cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"] + ) + ext_modules.append( + extension( + "torchao._C", + cutlass_90a_sources, + py_limited_api=True, + extra_compile_args=cutlass_90a_extra_compile_args, + extra_link_args=extra_link_args, + ) + ) + if build_torchao_experimental: ext_modules.append( CMakeExtension( diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py index d6203ab9a4..fd0060e495 100644 --- a/test/test_rowwise_scaled_linear_cutlass.py +++ b/test/test_rowwise_scaled_linear_cutlass.py @@ -7,11 +7,19 @@ rowwise_scaled_linear_cutlass_s4s4, rowwise_scaled_linear_cutlass_s8s4, ) -from torchao.quantization.utils import group_quantize_tensor_symmetric +from torchao.quantization.quant_api import ( + _int4_symm_per_token_quant_cutlass, + _int8_symm_per_token_quant_cutlass, +) +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) +from torchao.quantization.utils import _get_per_token_block_size -ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] -ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ +DTYPES = [torch.float16, torch.bfloat16] +BATCH_SIZE = [1, 4, 8, 16, 32, 64] +SIZE_MNK = [ (2, 512, 128), (3, 2048, 2048), (4, 3584, 640), @@ -19,63 +27,53 @@ (26, 18944, 1664), (67, 6656, 1408), ] -ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] -ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( +USE_BIAS = [False, True] +TEST_PARAMS = list( itertools.product( - ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, - ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, - ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, - ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, + DTYPES, + BATCH_SIZE, + SIZE_MNK, + USE_BIAS, ) ) -def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias): - assert xq_bits in [4, 8] - assert wq_bits in [4, 8] - +def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias): size_m, size_n, size_k = size_mnk x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None - x_2d = x.view(-1, x.shape[-1]) - xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( - x_2d, xq_bits, size_k, dtype - ) - assert torch.all(xq_2d_zeros == 0) - xq_s8 = xq_2d_s8.reshape(x.shape) - if xq_bits == 4: - xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF) - else: - xq = xq_s8 - xq_scales = xq_2d_scales.reshape(x.shape[:-1]) - - wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric( - w, wq_bits, size_n, dtype + xq_bits = 4 if op == rowwise_scaled_linear_cutlass_s4s4 else 8 + pack_s4 = lambda x: (x[..., 1::2] << 4) | (x[..., 0::2] & 0xF) + + x_quant_func = ( + _int4_symm_per_token_quant_cutlass + if xq_bits == 4 + else _int8_symm_per_token_quant_cutlass ) - assert torch.all(wq_zeros == 0) - if wq_bits == 4: - wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) - else: - wq = wq_s8 + x_aqt = x_quant_func(x) + xq_s8, xq_scales, zero_points = x_aqt.tensor_impl.get_plain() + assert zero_points is None + xq = pack_s4(xq_s8) if xq_bits == 4 else xq_s8 + + w_quant_func = 
_int4_symm_per_token_quant_cutlass + w_aqt = w_quant_func(w) + wq_s8, wq_scales, zero_points = w_aqt.tensor_impl.get_plain() + assert zero_points is None + wq = pack_s4(wq_s8) # If torch.nn.functional.linear(x, w, bias) used as reference, the # error would be too big. The calculation below is approximately # what rowwise_scaled_linear_cutlass kernel is doing (except that # matrix multiplication is over integers there). - size_m_2d = x_2d.shape[0] - output_ref = ( - (xq_2d_s8.float() @ wq_s8.float().T) - * xq_2d_scales.view(size_m_2d, 1) - * wq_scales.view(1, size_n) - ) + output_ref = (xq_s8.float() @ wq_s8.float().T) * xq_scales[..., None] * wq_scales if bias is not None: output_ref += bias output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,)) - fn_inputs = (xq, xq_scales, wq, wq_scales, bias) + fn_inputs = (xq, xq_scales, wq, wq_scales, bias, dtype) try: output = op(*fn_inputs) except NotImplementedError: @@ -85,20 +83,16 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS -) +@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( - rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias + rowwise_scaled_linear_cutlass_s4s4, dtype, batch_size, size_mnk, use_bias ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS -) +@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( - rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias + rowwise_scaled_linear_cutlass_s8s4, dtype, batch_size, size_mnk, use_bias ) diff --git a/test/test_rowwise_scaled_linear_sparse_cutlass.py b/test/test_rowwise_scaled_linear_sparse_cutlass.py new file mode 100644 index 0000000000..c983edc7a7 --- /dev/null +++ b/test/test_rowwise_scaled_linear_sparse_cutlass.py @@ -0,0 +1,130 @@ +import itertools +import random + +import pytest +import torch +from torch.testing._internal.common_cuda import SM90OrLater + +from torchao.dtypes import ( + Float8Layout, + to_affine_quantized_floatx, +) +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) +from torchao.quantization.utils import _get_per_token_block_size +from torchao.sparsity.utils import create_semi_structured_tensor + +X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)] +XQ_WQ_DTYPES = [ + (torch.float8_e4m3fn, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.float8_e5m2), + (torch.float8_e5m2, torch.float8_e4m3fn), + (torch.float8_e5m2, torch.float8_e5m2), +] +BATCH_SIZE = [1, 4] +SIZE_MNK = [ + (2, 128, 256), + (3, 128, 256), + (13, 128, 256), + (27, 128, 128), + (33, 128, 64), + (65, 128, 32), +] +USE_BIAS = [False, True] +BIAS_DTYPE = [torch.float16] +TEST_PARAMS = list( + itertools.product( + X_W_DTYPES, + XQ_WQ_DTYPES, + BATCH_SIZE, + SIZE_MNK, + USE_BIAS, + BIAS_DTYPE, + ) +) + + +def run_test_for_op( + op, + x_dtype, + w_dtype, + xq_dtype, + wq_dtype, + batch_size, + size_mnk, + use_bias, + 
bias_dtype, +): + random.seed(0) ## for create_semi_structured_tensor() + + size_m, size_n, size_k = size_mnk + + x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda") + w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype) + bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None + + x_aqt = to_affine_quantized_floatx( + input_float=x, + target_dtype=xq_dtype, + block_size=_get_per_token_block_size(x), + _layout=Float8Layout(mm_config=None), + ) + xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain() + assert zero_points is None + + w_aqt = to_affine_quantized_floatx( + input_float=w, + target_dtype=wq_dtype, + block_size=_get_per_token_block_size(w), + _layout=Float8Layout(mm_config=None), + ) + wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain() + assert zero_points is None + wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq) + wq_sp_scales = wq_scales + + xq_2d = xq.view(-1, xq.shape[-1]) + size_m_2d = xq_2d.shape[0] + output_ref = ( + (xq_2d.float() @ wq.float().T) + * xq_scales.view(size_m_2d, 1) + * wq_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,)) + + fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias) + try: + output = op(*fn_inputs) + except NotImplementedError: + pytest.xfail("operator not implemented") + + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices") +@pytest.mark.parametrize( + "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype", + TEST_PARAMS, +) +def test_rowwise_scaled_linear_sparse_cutlass_f8f8( + x_w_dtypes, + xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, +): + run_test_for_op( + rowwise_scaled_linear_sparse_cutlass_f8f8, + *x_w_dtypes, + *xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, + ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 0958a5207c..e225f8068b 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -334,11 +334,13 @@ def ffn_or_attn_only(mod, fqn): if quantization: from torchao.quantization import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, autoquant, float8_dynamic_activation_float8_weight, float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -434,18 +436,30 @@ def ffn_or_attn_only(mod, fqn): ] ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) - elif "int8adq-int4w-symm" in quantization: + elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout - quantize_( - model, - int8_dynamic_activation_int4_weight( - group_size=None, - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, - layout=CutlassInt4PackedLayout(), - ), - ) + nbits = int(quantization.removeprefix("int4dq-")) + assert nbits == 4 or nbits == 8 + if nbits == 4: + quantize_( + model, + int4_dynamic_activation_int4_weight( + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) + elif nbits == 8: + 
quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout @@ -564,16 +578,24 @@ def ffn_or_attn_only(mod, fqn): elif "float8wo" in quantization: quantize_(model, float8_weight_only()) elif "float8dq" in quantization: - granularity = str(quantization.split("-")[-1]) - if granularity == "tensor": - granularity = PerTensor() - elif granularity == "row": - granularity = PerRow() + if sparsity and "semi" in sparsity: + quantize_( + model, + Float8DynamicActivationFloat8SemiSparseWeightConfig(), + filter_fn=ffn_only, + ) else: - granularity = PerTensor() - quantize_( - model, float8_dynamic_activation_float8_weight(granularity=granularity) - ) + granularity = str(quantization.split("-")[-1]) + if granularity == "tensor": + granularity = PerTensor() + elif granularity == "row": + granularity = PerRow() + else: + granularity = PerTensor() + quantize_( + model, + float8_dynamic_activation_float8_weight(granularity=granularity), + ) elif "autoquant_v2" in quantization: from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model @@ -1130,7 +1152,7 @@ def callback(x): help=( "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx--, uintx---hqq, sparse-marlin, spinquant, " - + "embed-int8wo, marlin_qqq, gemlite---, int8adq-int4w-symm" + + "embed-int8wo, marlin_qqq, gemlite---, float8dq, int4dq-" ), ) parser.add_argument( diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh index 0117f12e27..bfc6e69cc1 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -11,6 +13,8 @@ #endif #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) +#include +#include #include #include #include @@ -25,60 +29,61 @@ namespace torchao { #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) template< + typename DtypeXq, + typename DtypeWq, + typename DtypeY, + typename UseBias, + typename DtypeBias, + typename DtypeXScale, + typename DtypeWScale, typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadblockSwizzle, - int NumStages, - typename ElementA, - typename ElementB, - typename ElementOutput, - typename ElementC, - typename UseTensorC, - typename ElementAScale, - typename ElementBScale> + int NumStages> void rowwise_scaled_linear_kernel_cutlass_sm8x( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const at::Tensor& bias, at::Tensor& Y) { + using SmArch = cutlass::arch::Sm80; + + // Use CUTLASS naming conventions for naming datatypes. 
+ using ElementA = DtypeXq; + using ElementB = DtypeWq; + using ElementD = DtypeY; + using ElementAScale = DtypeXScale; + using ElementBScale = DtypeWScale; + using ElementBias = DtypeBias; + static_assert((cutlass::sizeof_bits::value >= 8 || 8 % cutlass::sizeof_bits::value == 0) && (cutlass::sizeof_bits::value >= 8 || 8 % cutlass::sizeof_bits::value == 0)); - using SmArch = cutlass::arch::Sm80; + using LayoutTagA = cutlass::layout::RowMajor; + using LayoutTagB = cutlass::layout::ColumnMajor; + using LayoutTagD = cutlass::layout::RowMajor; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - // TODO: use FP32 if either ElementA/B is FP + // TODO: use FP32 if either ElementA/ElementB is FP using ElementAccumulator = int32_t; using Operator = std::conditional_t::value, cutlass::arch::OpMultiplyAddSaturate, cutlass::arch::OpMultiplyAddMixedInputUpcast>; - using ElementEpilogue = float; + using ElementCompute = float; constexpr auto NumEVTEpilogueStages = 1; - const int m = tensor_a.size(0); - const int n = tensor_b.size(0); - int k = tensor_a.size(1); + const int m = Xq.size(0); + const int n = Wq.size(0); + int k = Xq.size(1); if constexpr (cutlass::sizeof_bits::value < 8) { k *= 8 / cutlass::sizeof_bits::value; } constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Check for current CUTLASS limitations w.r.t. alignments. 
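+ // The GEMM is instantiated with 128-bit vectorized accesses, i.e.
+ // 128 / sizeof_bits(Element) elements per access: 16 for int8 operands,
+ // 32 for packed int4 operands, and 8 for the 16-bit output. The checks
+ // below require K and N to be multiples of these alignments.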
TORCH_CHECK(k % AlignmentA == 0, OPERATOR_NAME, @@ -87,16 +92,16 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( TORCH_CHECK(k % AlignmentB == 0, OPERATOR_NAME, " : Number of columns of tensor B must be divisible by ", AlignmentB); - TORCH_CHECK(n % AlignmentC == 0, OPERATOR_NAME, - " : Number of columns of tensor C must be divisible by ", - AlignmentC); + TORCH_CHECK(n % AlignmentD == 0, OPERATOR_NAME, + " : Number of columns of output tensor must be divisible by ", + AlignmentD); using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, WarpShape, - ElementOutput, - AlignmentOutput, + ElementD, + AlignmentD, NumEVTEpilogueStages>; using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; @@ -116,18 +121,18 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( using TensorBScaleArguments = typename TensorBScale::Arguments; using TensorCScalar = - cutlass::epilogue::threadblock::VisitorScalarBroadcast; + cutlass::epilogue::threadblock::VisitorScalarBroadcast; using TensorCTensor = cutlass::epilogue::threadblock::VisitorRowBroadcast< OutputTileThreadMap, - ElementC, + ElementBias, cute::Stride>; using TensorC = - std::conditional_t; + std::conditional_t; using TensorCArguments = typename TensorC::Arguments; using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest >; using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< @@ -136,7 +141,7 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( TensorAScale>; using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest >; using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< @@ -145,7 +150,7 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( TensorBScale>; using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::plus, ElementEpilogue, ElementEpilogue, + cutlass::plus, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest >; using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< @@ -154,7 +159,7 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( TensorC>; using Output = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementOutput, + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, cute::Stride // StrideMNL >; @@ -165,11 +170,11 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( using EVTKernel = torchao::enable_2x_kernel_for_sm80_or_later< typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, - ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementOutput, LayoutOutput, AlignmentOutput, + ElementA, LayoutTagA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutTagB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementD, LayoutTagD, AlignmentD, ElementAccumulator, - ElementEpilogue, + ElementCompute, cutlass::arch::OpClassTensorOp, SmArch, ThreadblockShape, @@ -187,55 +192,55 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( cutlass::gemm::GemmCoord problem_size(m, n, k); constexpr auto SplitKFactor = 1; - TensorAScaleArguments tensor_a_scale_arguments{ - (ElementAScale*)tensor_a_scale.data_ptr(), + TensorAScaleArguments X_scale_arguments{ + 
(ElementAScale*)X_scale.data_ptr(), ElementAScale(1), {cute::_1{}, cute::_0{}, problem_size.m()} }; - TensorBScaleArguments tensor_b_scale_arguments{ - (ElementBScale*)tensor_b_scale.data_ptr(), + TensorBScaleArguments W_scale_arguments{ + (ElementBScale*)W_scale.data_ptr(), ElementBScale(1), {cute::_0{}, cute::_1{}, problem_size.n()} }; - TensorCArguments tensor_c_arguments{ + TensorCArguments bias_arguments{ [&]() -> TensorCArguments { - if constexpr (UseTensorC::value) { - return {(ElementC*)tensor_c.data_ptr(), - ElementC(0), + if constexpr (UseBias::value) { + return {(ElementBias*)bias.data_ptr(), + ElementBias(0), {cute::_0{}, cute::_1{}, problem_size.n()}}; } else { - return {ElementC(0)}; + return {ElementBias(0)}; } }() }; typename Output::Arguments output_arguments{ - (ElementOutput*)tensor_d.data_ptr(), + (ElementD*)Y.data_ptr(), {problem_size.n(), cute::_1{}, problem_size.mn().product()} }; typename EVTOutput::Arguments callback_arguments{ { { { - {}, // Accum - tensor_a_scale_arguments, // TensorAScale - {} // ApplyAScale - }, // EVTApplyAScale - tensor_b_scale_arguments, // TensorBScale - {}, // ApplyBScale - }, // EVTApplyBScale - tensor_c_arguments, // TensorC - {} // ApplySum - }, // EVTApplySum - output_arguments // Output - }; // EVTOutput + {}, // Accum + X_scale_arguments, // TensorAScale + {} // ApplyAScale + }, // EVTApplyAScale + W_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, // EVTApplyBScale + bias_arguments, // TensorC + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput typename Gemm::Arguments arguments( cutlass::gemm::GemmUniversalMode::kGemm, problem_size, SplitKFactor, callback_arguments, // arguments of EVT callbacks - (ElementA*)tensor_a.data_ptr(), - (ElementB*)tensor_b.data_ptr(), + (ElementA*)Xq.data_ptr(), + (ElementB*)Wq.data_ptr(), nullptr, // ptr C (unused) nullptr, // ptr D (unused) problem_size.mk().product(), // batch stride A @@ -259,8 +264,8 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. const auto workspace_size = Gemm::get_workspace_size(arguments); - auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, - at::TensorOptions().dtype(at::kByte)); + auto workspace = Xq.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); // Initialize CUTLASS mixed datatypes GEMM object. 
status = gemm_op.initialize(arguments, workspace.data_ptr(), @@ -274,95 +279,87 @@ void rowwise_scaled_linear_kernel_cutlass_sm8x( C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template static void select_config( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const at::Tensor& bias, at::Tensor& Y) { const auto dprops = at::cuda::getCurrentDeviceProperties(); const auto is_sm8x = dprops->major == 8; if (is_sm8x) { - if constexpr (std::is_same::value && - std::is_same::value) { + if constexpr (std::is_same::value && + std::is_same::value) { using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; // some basic tuning - if (tensor_a.size(0) <= 16) { + if (Xq.size(0) <= 16) { using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 256>; using WarpShape = cutlass::gemm::GemmShape<16, 32, 256>; constexpr auto NumStages = 5; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); + } else if (Xq.size(0) <= 32) { using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 256>; using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; constexpr auto NumStages = 4; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 128) { + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); + } else if (Xq.size(0) <= 128) { using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; constexpr auto NumStages = 4; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); } else { using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; constexpr auto NumStages = 4; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); } return; - } else if constexpr (std::is_same::value && - std::is_same::value) { + } else if constexpr (std::is_same::value && + std::is_same::value) { using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; using ThreadblockSwizzle = 
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // A minimal heuristic to improve performance for small number // of inputs cases. - if (tensor_a.size(0) <= 16) { + if (Xq.size(0) <= 16) { using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; constexpr auto NumStages = 6; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); + } else if (Xq.size(0) <= 32) { using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; constexpr auto NumStages = 5; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); } else { using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; constexpr auto NumStages = 4; rowwise_scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, - NumStages, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + DtypeXq, DtypeWq, Types..., ThreadblockShape, WarpShape, + InstructionShape, ThreadblockSwizzle, NumStages>( + Xq, X_scale, Wq, W_scale, bias, Y); } return; } @@ -372,115 +369,124 @@ static void select_config( dprops->major, ".", dprops->minor, " for given operands"); } -template< - typename ElementA, - typename ElementB, - typename ElementOutput, - typename... 
Types> +template static void -dispatch_on_tensor_c( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_c.numel() == 0) { - using ElementC = ElementOutput; - using UseTensorC = std::false_type; - select_config< - ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); +dispatch_on_X_scale_and_W_scale( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const at::Tensor& bias, at::Tensor& Y) { + if (X_scale.scalar_type() == at::ScalarType::Half && + W_scale.scalar_type() == at::ScalarType::Half) { + using DtypeXScale = cutlass::half_t; + using DtypeWScale = cutlass::half_t; + select_config( + Xq, X_scale, Wq, W_scale, bias, Y); return; - } - - using UseTensorC = std::true_type; - if (tensor_c.scalar_type() == at::ScalarType::Half) { - using ElementC = cutlass::half_t; - select_config< - ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + } else if (X_scale.scalar_type() == at::ScalarType::BFloat16 && + W_scale.scalar_type() == at::ScalarType::BFloat16) { + using DtypeXScale = cutlass::bfloat16_t; + using DtypeWScale = cutlass::bfloat16_t; + select_config( + Xq, X_scale, Wq, W_scale, bias, Y); return; - } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { - using ElementC = cutlass::bfloat16_t; - select_config< - ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); + } else if (X_scale.scalar_type() == at::ScalarType::Float && + W_scale.scalar_type() == at::ScalarType::Float) { + using DtypeXScale = float; + using DtypeWScale = float; + select_config( + Xq, X_scale, Wq, W_scale, bias, Y); return; } - TORCH_CHECK(false, OPERATOR_NAME, " : Operator not supported for datatype ", - tensor_c.scalar_type(), " for addend"); + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for combination of data types ", + X_scale.scalar_type(), " for first operand scale and ", + W_scale.scalar_type(), " for second operand scale"); } -template +template static void -dispatch_on_tensor_a_scale_and_tensor_b_scale( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), - OPERATOR_NAME, " : Operator not supported for output datatype ", - tensor_d.scalar_type(), " as it's different from the first ", - " operand scale datatype ", tensor_a_scale.scalar_type()); - - if (tensor_a_scale.scalar_type() == at::ScalarType::Half && - tensor_b_scale.scalar_type() == at::ScalarType::Half) { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementOutput = cutlass::half_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); +dispatch_on_bias( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const std::optional& bias_opt, + at::Tensor& Y) { + if (bias_opt.has_value()) { + const auto bias = *bias_opt; + TORCH_CHECK(bias.scalar_type() == Y.scalar_type(), + OPERATOR_NAME, " : Operator not supported for bias datatype ", 
+ bias.scalar_type(), " as it's different from the output ", + " datatype ", Y.scalar_type()); + } + + using DtypeBias = DtypeY; + if (bias_opt.has_value()) { + using UseBias = std::true_type; + const auto bias = *bias_opt; + dispatch_on_X_scale_and_W_scale< + DtypeXq, DtypeWq, DtypeY, UseBias, DtypeBias>( + Xq, X_scale, Wq, W_scale, bias, Y); + } else { + using UseBias = std::false_type; + dispatch_on_X_scale_and_W_scale< + DtypeXq, DtypeWq, DtypeY, UseBias, DtypeBias>( + Xq, X_scale, Wq, W_scale, Y, Y); + } +} + +template +static void +dispatch_on_Y( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale,const std::optional& bias_opt, + at::Tensor& Y) { + if (Y.scalar_type() == at::ScalarType::Half) { + using DtypeY = cutlass::half_t; + dispatch_on_bias( + Xq, X_scale, Wq, W_scale, bias_opt, Y); return; - } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && - tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementOutput = cutlass::bfloat16_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + } else if (Y.scalar_type() == at::ScalarType::BFloat16) { + using DtypeY = cutlass::bfloat16_t; + dispatch_on_bias( + Xq, X_scale, Wq, W_scale, bias_opt, Y); return; } TORCH_CHECK(false, OPERATOR_NAME, - " : Operator not supported for combination of data types ", - tensor_a_scale.scalar_type(), " for first operand scale and ", - tensor_b_scale.scalar_type(), " for second operand scale"); + " : Operator not supported for output data type ", + Y.scalar_type()); } -template +template void -rowwise_scaled_linear_cutlass_check_inputs( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias){ +check_inputs( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const std::optional& bias_opt) { // Validate layouts of arguments. 
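+ // Shape expectations enforced below: Xq is a 2-D or higher strided tensor,
+ // X_scale has exactly one dimension fewer than Xq, Wq is a 2-D strided
+ // tensor, W_scale is 1-D or 2-D, and the optional bias is 1-D.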
- TORCH_CHECK(xq.dim() >= 2, OPERATOR_NAME, - " : Expected xq argument to be 2D or higher-dimensional tensor, " - "got ", xq.dim(), " dims"); - TORCH_CHECK(xq.layout() == at::Layout::Strided, OPERATOR_NAME, - " : Expected xq argument to be strided, got layout ", - xq.layout()); - TORCH_CHECK(x_scale.dim() == xq.dim() - 1, OPERATOR_NAME, - " : Expected xq scale argument to be ", xq.dim() - 1, - "D tensor, got ", x_scale.dim(), " dims"); - TORCH_CHECK(x_scale.layout() == at::Layout::Strided, OPERATOR_NAME, - " : Expected xq scale argument to be strided, got layout ", - x_scale.layout()); - TORCH_CHECK(wq.dim() == 2, OPERATOR_NAME, - " : Expected wq argument to be 2D tensor, got ", wq.dim(), + TORCH_CHECK(Xq.dim() >= 2, OPERATOR_NAME, + " : Expected Xq argument to be 2D or higher-dimensional tensor, " + "got ", Xq.dim(), " dims"); + TORCH_CHECK(Xq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq argument to be strided, got layout ", + Xq.layout()); + TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME, + " : Expected Xq scale argument to be ", Xq.dim() - 1, + "D tensor, got ", X_scale.dim(), " dims"); + TORCH_CHECK(X_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq scale argument to be strided, got layout ", + X_scale.layout()); + TORCH_CHECK(Wq.dim() == 2, OPERATOR_NAME, + " : Expected Wq argument to be 2D tensor, got ", Wq.dim(), " dims"); - TORCH_CHECK(wq.layout() == at::Layout::Strided, OPERATOR_NAME, - " : Expected wq argument to be strided, got layout ", - wq.layout()); - TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, OPERATOR_NAME, - " : Expected wq scale argument to be 1D or 2D tensor, ", "got ", - w_scale.dim(), " dims"); - TORCH_CHECK(w_scale.layout() == at::Layout::Strided, OPERATOR_NAME, - " : Expected wq scale argument to be strided, got layout ", - w_scale.layout()); - if (bias.numel() > 0) { + TORCH_CHECK(Wq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq argument to be strided, got layout ", + Wq.layout()); + TORCH_CHECK(W_scale.dim() == 1 || W_scale.dim() == 2, OPERATOR_NAME, + " : Expected Wq scale argument to be 1D or 2D tensor, ", "got ", + W_scale.dim(), " dims"); + TORCH_CHECK(W_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq scale argument to be strided, got layout ", + W_scale.layout()); + if (bias_opt.has_value()) { + const auto bias = *bias_opt; TORCH_CHECK(bias.dim() == 1, OPERATOR_NAME, " : Expected bias argument to be 1D tensor, got ", bias.dim(), " dims"); @@ -490,43 +496,45 @@ rowwise_scaled_linear_cutlass_check_inputs( } // Validate sizes of arguments. 
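+ // When Wq holds packed 4-bit values while Xq does not (the s8s4 case), each
+ // byte of Wq stores two values, so Xq's last dimension is twice Wq's column
+ // count; otherwise the two must match. X_scale mirrors Xq's leading sizes,
+ // while W_scale and the optional bias have one element per row of Wq.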
- const auto xq_sizes = xq.sizes().vec(); - TORCH_CHECK(xq_sizes.back() == wq.size(1) || - xq_sizes.back() == 2 * wq.size(1), - OPERATOR_NAME, " : Expected xq argument to have ", wq.size(1), - " or ", 2 * wq.size(1), " columns, but got ", xq_sizes.back()); - const auto x_scale_sizes = x_scale.sizes().vec(); - for (auto i = 0; i < x_scale_sizes.size(); ++i) - TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], OPERATOR_NAME, - " : Expected xq scale argument size at position ", i, " to be ", - xq_sizes[i], ", but got ", x_scale_sizes[i]); - TORCH_CHECK(w_scale.numel() == wq.size(0), OPERATOR_NAME, - " : Expected wq scale argument to have ", wq.size(0), - " elements, got ", w_scale.numel(), " elements"); - if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == wq.size(0), OPERATOR_NAME, - " : Expected bias argument to have ", wq.size(0), + const auto Xq_sizes = Xq.sizes().vec(); + TORCH_CHECK(Xq_sizes.back() == Wq.size(1) || + Xq_sizes.back() == 2 * Wq.size(1), + OPERATOR_NAME, " : Expected Xq argument to have ", Wq.size(1), + " or ", 2 * Wq.size(1), " columns, but got ", Xq_sizes.back()); + const auto X_scale_sizes = X_scale.sizes().vec(); + for (auto i = 0; i < X_scale_sizes.size(); ++i) + TORCH_CHECK(X_scale_sizes[i] == Xq_sizes[i], OPERATOR_NAME, + " : Expected Xq scale argument size at position ", i, " to be ", + Xq_sizes[i], ", but got ", X_scale_sizes[i]); + TORCH_CHECK(W_scale.numel() == Wq.size(0), OPERATOR_NAME, + " : Expected Wq scale argument to have ", Wq.size(0), + " elements, got ", W_scale.numel(), " elements"); + if (bias_opt.has_value()) { + const auto bias = *bias_opt; + TORCH_CHECK(bias.numel() == Wq.size(0), OPERATOR_NAME, + " : Expected bias argument to have ", Wq.size(0), " elements, got ", bias.numel(), " elements"); } // Validate strides of arguments. 
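+ // Xq must be row-major across all of its leading dimensions (the loop below
+ // walks the strides from the innermost dimension outwards), Wq must be
+ // row-major, and the scales and the optional bias must be contiguous.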
- const auto xq_strides = xq.strides(); - TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, OPERATOR_NAME, - " : Expected xq argument in row-major layout"); - auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; - for (int i = xq_strides.size() - 3; i >= 0; --i) { - xq_stride_expected *= xq_sizes[i + 1]; - TORCH_CHECK(xq_strides[i] == xq_stride_expected, OPERATOR_NAME, - " : Expected xq argument in row-major layout"); + const auto Xq_strides = Xq.strides(); + TORCH_CHECK(Xq_strides[Xq_strides.size() - 1] == 1, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + auto Xq_stride_expected = Xq_strides[Xq_strides.size() - 2]; + for (int i = Xq_strides.size() - 3; i >= 0; --i) { + Xq_stride_expected *= Xq_sizes[i + 1]; + TORCH_CHECK(Xq_strides[i] == Xq_stride_expected, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); } - TORCH_CHECK(x_scale.is_contiguous(), OPERATOR_NAME, - " : Expected xq scale argument to be contiguous"); - const auto wq_strides = wq.strides(); - TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, OPERATOR_NAME, - " : Expected wq argument in row-major layout"); - TORCH_CHECK(w_scale.is_contiguous(), OPERATOR_NAME, - " : Expected wq scale argument to be contiguous"); - if (bias.numel() > 0) { + TORCH_CHECK(X_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Xq scale argument to be contiguous"); + const auto Wq_strides = Wq.strides(); + TORCH_CHECK(Wq_strides[0] >= 1 && Wq_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq argument in row-major layout"); + TORCH_CHECK(W_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Wq scale argument to be contiguous"); + if (bias_opt.has_value()) { + const auto bias = *bias_opt; const auto bias_strides = bias.strides(); TORCH_CHECK(bias_strides[0] == 1, OPERATOR_NAME, " : Expected bias argument to be contiguous"); @@ -536,42 +544,41 @@ rowwise_scaled_linear_cutlass_check_inputs( // Perform linear operation, using corresponding CUTLASS datatypes // GEMM kernel, to given arguments - result produced is: -// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c -// -// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors. -// The "tensor_a_scale" tensor is expected to be a vector, of size -// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale" -// tensor is expected to be a vector, of size equal to number of rows -// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a -// vector, of size equal to number of rows of "tensor_b" tensor. -template +// (Xq * X_scale) @ (Wq * W_scale).T + bias +template at::Tensor rowwise_scaled_linear_cutlass( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, const std::optional& bias_opt, + const std::optional out_dtype_opt) { #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) - // Check inputs. - rowwise_scaled_linear_cutlass_check_inputs( - xq, x_scale, wq, w_scale, bias); + // Check inputs. Note that data types are checked in the + // corresponding dispatch methods. The limitations on data types + // there are mostly to control the number of templates to + // instantiate - the number of data type combinations that could be + // supported is actually much larger. + check_inputs(Xq, X_scale, Wq, W_scale, bias_opt); // Squash the input tensors as appropriate. 
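+ // All leading dimensions of Xq are flattened into the GEMM "m" dimension
+ // (e.g. an input of shape (B, M, K) runs as a (B * M, K) matrix) and the
+ // result is reshaped back to Xq's leading sizes plus Wq.size(0). The output
+ // dtype is out_dtype_opt when provided, otherwise the dtype of X_scale.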
- const auto xq_sizes = xq.sizes().vec(); - const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); - const auto x_scale_1d = x_scale.reshape({-1}); - const auto w_scale_1d = w_scale.reshape({-1}); + const auto Xq_sizes = Xq.sizes().vec(); + const auto Xq_2d = Xq.reshape({-1, Xq_sizes.back()}); + const auto X_scale_1d = X_scale.reshape({-1}); + const auto W_scale_1d = W_scale.reshape({-1}); // Create result tensor. - at::Tensor result = - x_scale.new_empty({xq_2d.size(0), wq.size(0)}); + const auto options = out_dtype_opt.has_value() + ? X_scale.options().dtype(*out_dtype_opt) + : X_scale.options(); + at::Tensor Y = at::empty({Xq_2d.size(0), Wq.size(0)}, options); // Dispatch to appropriate kernel template. - dispatch_on_tensor_a_scale_and_tensor_b_scale( - xq_2d, x_scale_1d, wq, w_scale_1d, bias, result); + dispatch_on_Y( + Xq_2d, X_scale_1d, Wq, W_scale_1d, bias_opt, Y); // Reshape and return result tensor. - auto result_sizes = xq_sizes; - result_sizes.back() = wq.size(0); - return result.reshape(result_sizes); + auto Y_sizes = Xq_sizes; + Y_sizes.back() = Wq.size(0); + return Y.reshape(Y_sizes); #else TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); return at::Tensor{}; diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu index cc1b5ca123..59f7951a63 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -6,22 +6,25 @@ namespace torchao { at::Tensor rowwise_scaled_linear_cutlass_s4s4( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, + const std::optional& bias_opt = std::nullopt, + const std::optional out_dtype_opt = std::nullopt) { // Validate input datatypes. - TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, - __func__, " : The input datatypes combination ", xq.dtype(), - " for xq and ", wq.dtype(), " for wq is not supported"); + TORCH_CHECK(Xq.dtype() == at::kChar && Wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", Xq.dtype(), + " for Xq and ", Wq.dtype(), " for Wq is not supported"); +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) // Dispatch to appropriate kernel template. 
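+ // For s4s4 both operands are cutlass::int4b_t: Xq and Wq are int8 tensors in
+ // which every byte packs two signed 4-bit values, so their last dimensions
+ // are half of the logical K.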
- #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) - // We get ElementA/ElementB types from the header - return rowwise_scaled_linear_cutlass( - xq, x_scale, wq, w_scale, bias); - #else - TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s4s4 not available"); - return at::Tensor{}; - #endif + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + Xq, X_scale, Wq, W_scale, bias_opt, out_dtype_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu index 29f30d08fc..419eed1073 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -1,25 +1,28 @@ #include + #include "rowwise_scaled_linear_cutlass.cuh" namespace torchao { at::Tensor rowwise_scaled_linear_cutlass_s8s4( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_scale, + const std::optional& bias_opt = std::nullopt, + const std::optional out_dtype_opt = std::nullopt) { // Validate input datatypes. - TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, - __func__, " : The input datatypes combination ", xq.dtype(), - " for xq and ", wq.dtype(), " for wq is not supported"); + TORCH_CHECK(Xq.dtype() == at::kChar && Wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", Xq.dtype(), + " for Xq and ", Wq.dtype(), " for Wq is not supported"); #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) - // Define ElementA as int8_t since it's a standard type + // Dispatch to appropriate kernel template. 
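+ // s8s4 is a mixed-input GEMM: Xq holds one signed 8-bit value per byte,
+ // while Wq packs two signed 4-bit values per byte. The shared header selects
+ // the mixed-input-upcast MMA operator when the two element types differ.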
using ElementA = int8_t; - // ElementB comes from cutlass header - return rowwise_scaled_linear_cutlass( - xq, x_scale, wq, w_scale, bias); + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + Xq, X_scale, Wq, W_scale, bias_opt, out_dtype_opt); #else - TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s8s4 not available"); + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); return at::Tensor{}; #endif } diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh new file mode 100644 index 0000000000..2b0a6f052b --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh @@ -0,0 +1,496 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS +#endif + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "rowwise_scaled_linear_sparse_cutlass" + +namespace torchao { + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +template< + typename DtypeXq, + typename DtypeWq, + typename DtypeY, + typename UseBias, + typename DtypeBias, + typename DtypeXScale, + typename DtypeWScale, + typename TileShape, + typename ClusterShape> +void rowwise_scaled_linear_sparse_kernel_cutlass_sm9x( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + // For CUTLASS, sparsified tensor must be the first operand, thus + // the result will be calculated as: + // ((Wq @ Xq.T) * W_scale * X_scale.T + bias.T).T + + using SmArch = cutlass::arch::Sm90; + + // Use CUTLASS naming conventions for naming datatypes. + using ElementA = DtypeWq; + using ElementB = DtypeXq; + using ElementD = DtypeY; + using ElementAScale = DtypeWScale; + using ElementBScale = DtypeXScale; + using ElementBias = DtypeBias; + + using LayoutTagA = cutlass::layout::RowMajor; + using LayoutTagB = cutlass::layout::ColumnMajor; + using LayoutTagD = cutlass::layout::ColumnMajor; + + constexpr auto AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentD = 128 / cutlass::sizeof_bits::value; + + // TODO: use different accumulator datatype if inputs are not float. + using ElementAccumulator = float; + using ElementCompute = float; + + using ProblemShape = cute::Shape; + + // If KernelTmaWarpSpecializedPingpong used for kernel schedule, the + // performance is really bad; on the other side, using + // KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to + // affect the precision much - thus, sticking with it. 
+ using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using AScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementAScale>; + using ApplyAScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + Accum, + AScale>; + using BScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBScale>; + using ApplyBScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + ApplyAScale, + BScale>; + using BiasScalar = + cutlass::epilogue::fusion::Sm90ScalarBroadcast; + using BiasTensor = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementBias>; + using Bias = std::conditional_t; + using ApplyBias = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, ElementCompute, ElementCompute, RoundStyle>, + ApplyBScale, + Bias>; + using EVT = ApplyBias; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementD, LayoutTagD, AlignmentD, + ElementD, LayoutTagD, AlignmentD, + EpilogueSchedule, + EVT>::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = enable_3x_kernel_for_sm90_or_later< + cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue>>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideE = StrideA; + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = + typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + const int m = (int)Wq.size(0); + const int n = (int)Xq.size(0); + const int k = (int)Xq.size(1); + + ProblemShape problem_shape(m, n, k, 1); + const auto layout_A = SparseConfig::fill_layoutA(problem_shape); + const auto layout_E = SparseConfig::fill_layoutE(problem_shape); + const auto stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + const auto stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { + (ElementA*)Wq.data_ptr(), layout_A, (ElementB*)Xq.data_ptr(), stride_B, + (ElementE*)W_meta.data_ptr(), layout_E + }, + { + {}, + (ElementD*)Y.data_ptr(), stride_D, (ElementD*)Y.data_ptr(), stride_D + } + }; + + const typename AScale::Arguments A_scale_arguments{ + (ElementAScale*)W_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, cute::_0{}} 
+ }; + const typename BScale::Arguments B_scale_arguments{ + (ElementBScale*)X_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, cute::_0{}} + }; + const auto bias_arguments{ + [&]() -> typename Bias::Arguments { + if constexpr (UseBias::value) { + return { + (ElementBias*)bias.data_ptr(), + ElementBias(0), + {cute::_1{}, cute::_0{}, cute::_0{}} + }; + } else { + return {ElementBias(0)}; + } + }() + }; + arguments.epilogue.thread = { + { + { + {}, // Accum + A_scale_arguments, // AScale + {} // ApplyAScale + }, + B_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, + bias_arguments, // Bias + {} // ApplyBiass + }; + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = Xq.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform mixed datatypes GEMM operation. + status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +static void select_config( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm9x = dprops->major == 9; + + if (is_sm9x) { + if constexpr ((std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value)) { + // TODO: add proper tuning here. 
+ using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + rowwise_scaled_linear_sparse_kernel_cutlass_sm9x< + DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_bias( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + if (bias.numel() == 0) { + using UseBias = std::false_type; + using DtypeBias = DtypeY; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + using UseBias = std::true_type; + if (bias.scalar_type() == at::ScalarType::Half) { + using DtypeBias = cutlass::half_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (bias.scalar_type() == at::ScalarType::BFloat16) { + using DtypeBias = cutlass::bfloat16_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for datatype ", bias.scalar_type(), + " for bias"); +} + +template + static void +dispatch_on_X_scale_and_W_scale( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + TORCH_CHECK(Y.scalar_type() == X_scale.scalar_type(), + OPERATOR_NAME, " : Operator not supported for Y datatype ", + Y.scalar_type(), " as it's different from the first ", + " operand scale datatype ", X_scale.scalar_type()); + + if (X_scale.scalar_type() == at::ScalarType::Half && + W_scale.scalar_type() == at::ScalarType::Half) { + using DtypeXScale = cutlass::half_t; + using DtypeWScale = cutlass::half_t; + using DtypeY = cutlass::half_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (X_scale.scalar_type() == at::ScalarType::BFloat16 && + W_scale.scalar_type() == at::ScalarType::BFloat16) { + using DtypeXScale = cutlass::bfloat16_t; + using DtypeWScale = cutlass::bfloat16_t; + using DtypeY = cutlass::bfloat16_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for combination of datatypes ", + X_scale.scalar_type(), " for first operand scale and ", + W_scale.scalar_type(), " for second operand scale"); +} + +template +void +check_inputs( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const at::Tensor& bias) { + // Validate metadata datatype. + TORCH_CHECK(W_meta.dtype() == at::kByte, OPERATOR_NAME, + " : Expected Wq meta argument to be of torch.uint8 datatype got ", + Wq.dtype()); + + // Validate layouts of arguments. 
+ TORCH_CHECK(Xq.dim() >= 2, OPERATOR_NAME, + " : Expected Xq argument to be 2D or higher-dimensional tensor, " + " got ", Xq.dim(), " dims"); + TORCH_CHECK(Xq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq argument to be strided, got layout ", + Xq.layout()); + TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME, + " : Expected Xq scale argument to be ", Xq.dim() - 1, + "D tensor, got ", X_scale.dim(), " dims"); + TORCH_CHECK(X_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq scale argument to be strided, got layout ", + X_scale.layout()); + TORCH_CHECK(Wq.dim() == 2, OPERATOR_NAME, + " : Expected Wq argument to be 2D tensor, got ", Wq.dim(), + " dims"); + TORCH_CHECK(Wq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq argument to be strided, got layout ", + Wq.layout()); + TORCH_CHECK(W_meta.dim() == 2, OPERATOR_NAME, + " : Expected Wq meta argument to be 2D tensor, got ", + W_meta.dim(), " dims"); + TORCH_CHECK(W_meta.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq meta argument to be strided, got layout ", + W_meta.layout()); + TORCH_CHECK(W_scale.dim() == 1 || W_scale.dim() == 2, OPERATOR_NAME, + " : Expected Wq scale argument to be 1D or 2D tensor, ", + "got ", W_scale.dim(), " dims"); + TORCH_CHECK(W_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq scale argument to be strided, got layout ", + W_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, OPERATOR_NAME, + " : Expected bias argument to be 1D tensor, got ", bias.dim(), + " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected bias argument to be strided, got layout ", + bias.layout()); + } + + // Validate sizes of arguments. + const auto Xq_sizes = Xq.sizes().vec(); + const auto Wq_sizes = Wq.sizes().vec(); + TORCH_CHECK(Xq_sizes.back() % 32 == 0, OPERATOR_NAME, + " : For alignment purpose, Xq argument must have number of " + "columns divisible by ", 32, ", got ", Xq_sizes.back(), + " columns"); + TORCH_CHECK(Wq_sizes[0] % 8 == 0, OPERATOR_NAME, + " : For alignment purpose, Wq argument to have number of rows " + "divisible by ", 8, ", but got ", Wq_sizes[0], " rows"); + TORCH_CHECK(Xq_sizes.back() == 2 * Wq_sizes[1], OPERATOR_NAME, + " : Expected Xq argument to have ", 2 * Wq_sizes[1], + " columns, but got ", Xq_sizes.back()); + const auto X_scale_sizes = X_scale.sizes().vec(); + for (auto i = 0; i < X_scale_sizes.size(); ++i) + TORCH_CHECK(X_scale_sizes[i] == Xq_sizes[i], OPERATOR_NAME, + " : Expected Xq scale argument size at position ", i, " to be ", + Xq_sizes[i], ", but got ", X_scale_sizes[i]); + TORCH_CHECK(Wq_sizes[1] % 8 == 0, OPERATOR_NAME, + " : Expected Wq argument to have number of columns divisible by ", + " 8, got ", Wq_sizes[1]); + // W_meta may be padded, thus expected shape calculations for this + // tensor are as follows. 
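+ // (Editorial note: CUTLASS pads the metadata tensor to at least 64 rows and + // at least 16 bytes per row, hence the std::max clamping in the expected + // sizes computed below.)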
+ const auto W_meta_size_0_expected = std::max((int)Wq_sizes[0], 64); + const auto W_meta_size_1_expected = std::max((int)Wq_sizes[1] / 4, 16); + TORCH_CHECK(W_meta.size(0) == W_meta_size_0_expected, OPERATOR_NAME, + " : Expected Wq meta argument to have ", W_meta_size_0_expected, + " rows, got ", W_meta.size(0), " rows"); + TORCH_CHECK(W_meta.size(1) == W_meta_size_1_expected, OPERATOR_NAME, + " : Expected Wq meta argument to hold ", W_meta_size_1_expected, + " bytes per row to encode sparsity of Wq argument, got ", + W_meta.size(1), " bytes"); + TORCH_CHECK(W_scale.numel() == Wq_sizes[0], OPERATOR_NAME, + " : Expected Wq scale argument to have ", Wq_sizes[0], + " elements, got ", W_scale.numel(), " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == Wq_sizes[0], OPERATOR_NAME, + " : Expected bias argument to have ", Wq_sizes[0], + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. + const auto Xq_strides = Xq.strides(); + TORCH_CHECK(Xq_strides[Xq_strides.size() - 1] == 1, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + auto Xq_stride_expected = Xq_strides[Xq_strides.size() - 2]; + for (int i = Xq_strides.size() - 3; i >= 0; --i) { + Xq_stride_expected *= Xq_sizes[i + 1]; + TORCH_CHECK(Xq_strides[i] == Xq_stride_expected, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + } + TORCH_CHECK(X_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Xq scale argument to be contiguous"); + const auto Wq_strides = Wq.strides(); + TORCH_CHECK(Wq_strides[0] >= 1 && Wq_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq argument in row-major layout"); + const auto W_meta_strides = W_meta.strides(); + TORCH_CHECK(W_meta_strides[0] >= 1 && W_meta_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq meta argument in row-major layout"); + TORCH_CHECK(W_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Wq scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, OPERATOR_NAME, + " : Expected bias argument to be contiguous"); + } +} +#endif + +template +at::Tensor +rowwise_scaled_linear_sparse_cutlass( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + // Create bias tensor. + const auto bias = bias_opt.has_value() ? *bias_opt : at::Tensor{}; + + // Check inputs. Note that data types are checked in the + // corresponding dispatch methods. The limitations on data types + // there are mostly to control the number of templates to + // instantiate - the number of data type combinations that could be + // supported is actually much larger. + check_inputs(Xq, X_scale, Wq, W_meta, W_scale, bias); + + // Squash the input tensors as appropriate. + const auto Xq_sizes = Xq.sizes().vec(); + const auto Xq_2d = Xq.reshape({-1, Xq_sizes.back()}); + const auto X_scale_1d = X_scale.reshape({-1}); + const auto W_scale_1d = W_scale.reshape({-1}); + + // Create result tensor. + at::Tensor Y = X_scale.new_empty({Xq_2d.size(0), Wq.size(0)}); + + // Dispatch to appropriate kernel template. + dispatch_on_X_scale_and_W_scale( + Xq_2d, X_scale_1d, Wq, W_meta, W_scale_1d, bias, Y); + + // Reshape and return Y tensor. 
+ auto Y_sizes = Xq_sizes; + Y_sizes.back() = Wq.size(0); + return Y.reshape(Y_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu new file mode 100644 index 0000000000..c2da7b6dc3 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. + TORCH_CHECK( + Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e4m3_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h new file mode 100644 index 0000000000..1ccf780487 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu new file mode 100644 index 0000000000..dcc92afa71 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e5m2, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e4m3_t; + using DtypeWq = cutlass::float_e5m2_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h new file mode 100644 index 0000000000..8ca6b4bd37 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu new file mode 100644 index 0000000000..185ff8586c --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e5m2_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h new file mode 100644 index 0000000000..61e04ab051 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu new file mode 100644 index 0000000000..bda18ea86d --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e5m2_t; + using DtypeWq = cutlass::float_e5m2_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h new file mode 100644 index 0000000000..4fee072f16 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu new file mode 100644 index 0000000000..7a651cb3dd --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu @@ -0,0 +1,48 @@ +#include + +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_f8f8( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt = std::nullopt) { + // Validate input datatypes. + TORCH_CHECK( + (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn) || + (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e5m2) || + (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn) || + (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2), + __func__, " : The input datatypes combination ", Xq.dtype(), + " for Xq and ", Wq.dtype(), " for Wq is not supported"); + + // Dispatch to appropriate kernel template. 
+ if (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn) { + return rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e4m3fn && + Wq.dtype() == at::kFloat8_e5m2) { + return rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e5m2 && + Wq.dtype() == at::kFloat8_e4m3fn) { + return rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2) { + return rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } + return at::Tensor{}; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8", + &rowwise_scaled_linear_sparse_cutlass_f8f8); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh new file mode 100644 index 0000000000..24a9371593 --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X +#endif + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "to_sparse_semi_structured_cutlass_sm9x" + +namespace torchao { + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +template +std::tuple +to_sparse_semi_structured_kernel_cutlass_sm9x(const at::Tensor& W) { + // The kernel doesn't check, but assumes instead, that the input + // tensor is a structured sparse tensor. + + static_assert(std::is_same_v || + std::is_same_v); + + using SmArch = cutlass::arch::Sm90; + + using ProblemShape = cute::Shape; + + using LayoutTagW = cutlass::layout::RowMajor; + using StrideW = cutlass::gemm::TagToStrideA_t; + + using DtypeMeta = unsigned char; + + // CUTLASS derives the sparse config from the mainloop. In order + // not to instantiate the whole mainloop here, the sparse config for + // FP8 case is hard-coded below. The config template arguments are + // found by changing input data types for CUTLASS example 62 to FP8, + // and then printing out the sparse config data type. 
+ using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, DtypeW>, + cute::GMMA::Major::K, + cute::sparse_elem<8, unsigned char>, + cute::_128>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, DtypeW, LayoutTagW, SparseConfig>; + using CompressorKernel = enable_3x_kernel_for_sm90_or_later< + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, DtypeW, LayoutTagW, SparseConfig, SmArch>>; + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + const int m = W.size(0); + const int k = W.size(1); + + ProblemShape problem_shape(m, 1, k, 1); + + StrideW stride_W = + cutlass::make_cute_packed_stride(StrideW{}, cute::make_shape(m, k, 1)); + CompressorUtility compressor_utility(problem_shape, stride_W); + int k_compressed = compressor_utility.get_tensorA_k_physical(); + int m_meta = compressor_utility.get_metadata_m_physical(); + int k_meta = compressor_utility.get_metadata_k_physical(); + + // Create result tensors. + at::Tensor W_compressed = W.new_empty({m, k_compressed}); + at::Tensor W_meta = + W.new_empty({m_meta, k_meta}, at::TensorOptions().dtype(at::kByte)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + { + (DtypeW*)W.data_ptr(), stride_W, (DtypeW*)W_compressed.data_ptr(), + (DtypeMeta*)W_meta.data_ptr() + }, + {hw_info}}; + + Compressor compressor_op; + + cutlass::Status status; + + // Verify that compression operation with given arguments can be + // performed by CUTLASS. + status = compressor_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for the compressor. + const auto workspace_size = Compressor::get_workspace_size(arguments); + auto workspace = W.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize compressor. + status = compressor_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform compression. + status = compressor_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(W_compressed, W_meta); +} + +template +void +to_sparse_semi_structured_cutlass_sm9x_check_inputs(const at::Tensor& W) { + // Validate the input tensor layout. + TORCH_CHECK(W.dim() == 2, OPERATOR_NAME, + " : Expected W argument to be 2D tensor, got ", W.dim(), + " dims"); + TORCH_CHECK(W.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected W argument to be strided, got layout ",W.layout()); + + // Validate the input tensor shape. + const auto W_sizes = W.sizes().vec(); + TORCH_CHECK(W_sizes[1] % 8 == 0, OPERATOR_NAME, + " : Expected number of columns of the W argument to be ", + "divisible by 8, got ", W_sizes[1], " columns"); + + // Validate the input tensor strides. 
+ const auto W_strides = W.strides(); + TORCH_CHECK(W_strides[1] == 1, OPERATOR_NAME, + " : Expected W argument in row-major layout"); +} +#endif + +template +std::tuple +to_sparse_semi_structured_cutlass_sm9x(const at::Tensor& W) { +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm9x = dprops->major == 9; + if (!is_sm9x) { + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); + } + + // Check inputs. + to_sparse_semi_structured_cutlass_sm9x_check_inputs(W); + + // Call the kernel. + return to_sparse_semi_structured_kernel_cutlass_sm9x(W); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return std::make_tuple(at::Tensor{}, at::Tensor{}); +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu new file mode 100644 index 0000000000..1a4ab285b6 --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu @@ -0,0 +1,35 @@ +#include + +#include "to_sparse_semi_structured_cutlass_sm9x.cuh" + +namespace torchao { + +std::tuple +to_sparse_semi_structured_cutlass_sm9x_f8(const at::Tensor& W) { + // Validate input datatypes. + TORCH_CHECK(W.dtype() == at::kFloat8_e5m2 || W.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatype ", W.dtype(), + " is not supported"); + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) + // Dispatch to appropriate kernel template. + if (W.dtype() == at::kFloat8_e5m2) { + using DtypeW = cutlass::float_e5m2_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } else if (W.dtype() == at::kFloat8_e4m3fn) { + using DtypeW = cutlass::float_e4m3_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } + return std::tuple(at::Tensor{}, at::Tensor{}); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return std::tuple(at::Tensor{}, at::Tensor{}); +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::to_sparse_semi_structured_cutlass_sm9x_f8", + &to_sparse_semi_structured_cutlass_sm9x_f8); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 9cbd4cd2a0..9224f8efa3 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -9,6 +9,7 @@ to_affine_quantized_intx_static, ) from .floatx import ( + CutlassSemiSparseLayout, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 @@ -52,4 +53,5 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", + "CutlassSemiSparseLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 54f4a72811..3719bccb89 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -6,6 +6,10 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, ) +from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( + _linear_fp8_act_fp8_weight_sparse_cutlass_check, + _linear_fp8_act_fp8_weight_sparse_cutlass_impl, +) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl, @@ -161,6 +165,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int4_act_int4_weight_cutlass_check, 
_linear_int4_act_int4_weight_cutlass_impl, ), + ( + _linear_fp8_act_fp8_weight_sparse_cutlass_check, + _linear_fp8_act_fp8_weight_sparse_cutlass_impl, + ), ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 3f0a1ccd5c..7e634a5211 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,3 +1,6 @@ +from .cutlass_semi_sparse_layout import ( + CutlassSemiSparseLayout, +) from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, @@ -10,4 +13,5 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", + "CutlassSemiSparseLayout", ] diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py new file mode 100644 index 0000000000..3b81d9a021 --- /dev/null +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -0,0 +1,178 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class CutlassSemiSparseLayout(Layout): + """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" + + def pre_process(self, dense: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + from torchao.sparsity.utils import mask_creator + + return dense * mask_creator(dense).bool() + + +@register_layout(CutlassSemiSparseLayout) +class CutlassSemiSparseTensorImpl(AQTTensorImpl): + @staticmethod + def __new__( + cls, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = sparse.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else sparse.layout + ) + kwargs["dtype"] = sparse.dtype + kwargs["requires_grad"] = False + shape = (sparse.shape[0], 2 * sparse.shape[-1]) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + self.sparse = sparse + self.meta = meta + self.scale = scale + self._layout = _layout + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["sparse", "meta", "scale"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + sparse = tensor_data_dict["sparse"] + meta = tensor_data_dict["meta"] + scale = tensor_data_dict["scale"] + (_layout,) = tensor_attributes + return cls(sparse, meta, scale, _layout) + + def get_plain(self): + # No support in CUTLASS to convert back to dense from sparse + # semi-structured format, so multiplying with identity matrix + # 
for the conversion. + cols = 2 * self.sparse.shape[1] + input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) + input_scale = torch.ones( + (cols,), dtype=self.scale.dtype, device=self.sparse.device + ) + dense = ( + rowwise_scaled_linear_sparse_cutlass_f8f8( + input, input_scale, self.sparse, self.meta, self.scale + ) + .t() + .contiguous() + ) + + return dense, self.scale, None + + @classmethod + def from_plain( + cls, + dense: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert zero_point is None or torch.all(zero_point == 0) + + # FIXME: remove this when CUTLASS PR #2110 merged. + dtype = dense.dtype + dense = dense.view(torch.uint8) + dense[dense == 128] = 0 + dense = dense.view(dtype) + + sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense) + return cls( + sparse, + meta, + scale, + _layout, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.sparse = fn(self.sparse) + self.meta = fn(self.meta) + self.scale = fn(self.scale) + return self + + +def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import Float8Layout + + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, Float8Layout) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassSemiSparseLayout) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + + +def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 + + input = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale + weight = weight_tensor.tensor_impl.sparse + weight_meta = weight_tensor.tensor_impl.meta + weight_scale = weight_tensor.tensor_impl.scale + + out = rowwise_scaled_linear_sparse_cutlass_f8f8( + input, input_scale, weight, weight_meta, weight_scale, bias + ) + + return out diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index ae8ea78ceb..e5b337dd23 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -11,9 +11,9 @@ register_layout, ) from torchao.dtypes.uintx.plain_layout import ( - _aqt_is_int8_reduced_range, + _aqt_is_int8, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout aten = torch.ops.aten @@ -82,9 +82,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale"], [ - self._layout, - ] + return ["int_data", "scale"], [self._layout] @classmethod def __tensor_unflatten__( @@ -92,13 +90,13 @@ def __tensor_unflatten__( ): int_data = tensor_data_dict["int_data"] scale = tensor_data_dict["scale"] - _layout = tensor_attributes + (_layout,) = tensor_attributes return 
cls(int_data, scale, _layout) def get_plain(self): int_data = torch.stack( - ((self.int_data << 4) >> 4, self.int_data >> 4), dim=2 - ).view((self.int_data.shape[0], 2 * self.int_data.shape[1])) + ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 + ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) return int_data, self.scale, None @classmethod @@ -110,8 +108,7 @@ def from_plain( _layout: Layout, ): assert zero_point is None or torch.all(zero_point == 0) - - int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF) + int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) return cls( int_data_s4, scale, @@ -130,16 +127,18 @@ def _apply_fn_to_data(self, fn): def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): return ( isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(input_tensor._layout, PlainLayout) + and _aqt_is_int8(input_tensor) and input_tensor.dtype in (torch.float16, torch.bfloat16) and len(input_tensor.shape) >= 2 - and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and input_tensor.tensor_impl.scale.dtype == torch.float32 and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(weight_tensor) and weight_tensor.dtype == input_tensor.dtype and len(weight_tensor.shape) == 2 - and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and weight_tensor.tensor_impl.scale.dtype == torch.float32 and len(weight_tensor.tensor_impl.scale.shape) == 1 and (bias is None or bias.dtype == input_tensor.dtype) and (bias is None or len(bias.shape) == 1) @@ -153,9 +152,10 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale + out_dtype = input_tensor.dtype out = rowwise_scaled_linear_cutlass_s8s4( - input, input_scale, weight, weight_scale, bias + input, input_scale, weight, weight_scale, bias, out_dtype ) return out @@ -164,16 +164,18 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): return ( isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(input_tensor) and input_tensor.dtype in (torch.float16, torch.bfloat16) and len(input_tensor.shape) >= 2 - and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and input_tensor.tensor_impl.scale.dtype == torch.float32 and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(weight_tensor) and weight_tensor.dtype == input_tensor.dtype and len(weight_tensor.shape) == 2 - and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and weight_tensor.tensor_impl.scale.dtype == torch.float32 and len(weight_tensor.tensor_impl.scale.shape) == 1 ) @@ -185,9 +187,10 @@ def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale + out_dtype = input_tensor.dtype out = 
rowwise_scaled_linear_cutlass_s4s4( - input, input_scale, weight, weight_scale, bias + input, input_scale, weight, weight_scale, bias, out_dtype ) return out diff --git a/torchao/ops.py b/torchao/ops.py index a3aee761b9..b8e5fd47e0 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,4 +1,5 @@ import functools +from typing import Optional import torch from torch import Tensor @@ -22,10 +23,16 @@ "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" ) lib.define( - "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor" ) lib.define( - "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor" +) +lib.define( + "rowwise_scaled_linear_sparse_cutlass_f8f8(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_meta, Tensor weight_scale, Tensor? bias=None) -> Tensor" +) +lib.define( + "to_sparse_semi_structured_cutlass_sm9x_f8(Tensor weight) -> (Tensor, Tensor)" ) lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") @@ -536,7 +543,8 @@ def rowwise_scaled_linear_cutlass_s8s4( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> Tensor: """ CUTLASS-based row-wise scaled W4A8 linear operator. @@ -545,13 +553,19 @@ def rowwise_scaled_linear_cutlass_s8s4( input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). - bias: a vector of size equal to number of rows of weight tensor, or None. + bias: an optional vector of size equal to number of rows of weight tensor, or None. + out_dtype: optional data type for output tensor. Returns: output: result tensor, in row-major layout. """ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default( - input, input_scale, weight, weight_scale, bias + input, + input_scale, + weight, + weight_scale, + bias, + out_dtype, ) @@ -561,16 +575,15 @@ def _( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> Tensor: # No checks here, as detailed checks are performed by the # operator itself. 
- return torch.empty( - (*input.shape[:-1], weight.shape[0]), - dtype=input_scale.dtype, - device=input.device, - ) + dtype = out_dtype if out_dtype is not None else input_scale.dtype + device = input.device + return torch.empty((*input.shape[:-1], weight.shape[0]), dtype=dtype, device=device) def rowwise_scaled_linear_cutlass_s4s4( @@ -578,7 +591,8 @@ def rowwise_scaled_linear_cutlass_s4s4( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> Tensor: """ CUTLASS-based row-wise scaled W4A4 linear operator. @@ -587,13 +601,14 @@ def rowwise_scaled_linear_cutlass_s4s4( input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). - bias: a vector of size equal to number of rows of weight tensor, or None. + bias: an optional vector of size equal to number of rows of weight tensor, or None. + out_dtype: optional data type for output tensor. Returns: output: result tensor, in row-major layout. """ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default( - input, input_scale, weight, weight_scale, bias + input, input_scale, weight, weight_scale, bias, out_dtype ) @@ -603,9 +618,86 @@ def _( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + # No checks here, as detailed checks are performed by the + # operator itself. + + dtype = out_dtype if out_dtype is not None else input_scale.dtype + device = input.device + return torch.empty((*input.shape[:-1], weight.shape[0]), dtype=dtype, device=device) + + +def rowwise_scaled_linear_sparse_cutlass_f8f8( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + """ + CUTLASS-based row-wise scaled F8F8 linear operator, for sparsified weight case. + Args: + input: quantized input tensor, in row-major layout. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: sparsified quantized weight matrix, in row-major layout. + weight_meta: sparsify metadata for weight tensor. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: an optional vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. + """ + + return torch.ops.torchao.rowwise_scaled_linear_sparse_cutlass_f8f8.default( + input, input_scale, weight, weight_meta, weight_scale, bias + ) + + +@register_custom_op("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Optional[Tensor] = None, ) -> Tensor: - return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + # No checks here, as detailed checks are performed by the + # operator itself. 
+ + dtype = input_scale.dtype + device = input.device + return torch.empty((*input.shape[:-1], weight.shape[0]), dtype=dtype, device=device) + + +def to_sparse_semi_structured_cutlass_sm9x_f8( + weight: Tensor, +) -> (Tensor, Tensor): + """ + CUTLASS-based conversion from sparsified input tensor to corresponding compressed tensor, along with corresponding metadata tensor. + Args: + weight: input tensor, in row-major layout. + Returns: + weight_compressed: compressed weight tensor, with sparsity eliminated, in row-major layout. + weight_meta: metadata tensor, describing the sparsity structure of the input tensor, also in row-major layout. + """ + + return torch.ops.torchao.to_sparse_semi_structured_cutlass_sm9x_f8.default(weight) + + +@register_custom_op("torchao::to_sparse_semi_structured_cutlass_sm9x_f8") +def _( + weight: Tensor, +) -> (Tensor, Tensor): + # No checks here, as detailed checks are performed by the + # operator itself. + + return ( + weight.new_empty(weight[0], weight[1] // 2), + weight.new_empty(weight[0], max(weight[1] // 8, 16), dtype=torch.char), + ) @functools.lru_cache() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 5f15a6bbbe..b816eb585e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -46,6 +46,7 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -138,6 +139,7 @@ "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", + "Float8DynamicActivationFloat8SemiSparseWeightConfig", "UIntXWeightOnlyConfig", "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 0e7cda16f0..74a83de58a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -30,6 +30,7 @@ from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, + CutlassSemiSparseLayout, Float8Layout, Int4CPULayout, MarlinQQQLayout, @@ -122,6 +123,7 @@ "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", + "Float8DynamicActivationFloat8SemiSparseWeightConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -654,7 +656,7 @@ def _int8_dynamic_activation_int4_weight_transform( if isinstance(layout, MarlinQQQLayout): input_quant_func = _int8_symm_per_token_quant elif isinstance(layout, CutlassInt4PackedLayout): - input_quant_func = _int8_symm_per_token_reduced_range_quant_cutlass + input_quant_func = _int8_symm_per_token_quant_cutlass else: input_quant_func = _int8_symm_per_token_quant else: @@ -664,6 +666,8 @@ def _int8_dynamic_activation_int4_weight_transform( weight = to_marlinqqq_quantized_intx( weight, block_size, quant_min, quant_max, _layout=layout ) + elif isinstance(layout, CutlassInt4PackedLayout): + weight = _int4_symm_per_token_quant_cutlass(weight) else: weight = to_affine_quantized_intx( weight, @@ -718,17 +722,7 @@ def _int4_dynamic_activation_int4_weight_transform( if act_mapping_type != MappingType.SYMMETRIC: raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") - weight = to_affine_quantized_intx( - weight, - mapping_type=mapping_type, - block_size=(1, weight.shape[1]), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=torch.finfo(torch.float32).eps, - 
zero_point_domain=ZeroPointDomain.NONE, - _layout=layout, - ) + weight = _int4_symm_per_token_quant_cutlass(weight) weight = to_linear_activation_quantized( weight, _int4_symm_per_token_quant_cutlass, @@ -972,24 +966,15 @@ def _int8_symm_per_token_reduced_range_quant_noop_decode( ) -def _int8_symm_per_token_reduced_range_quant_cutlass( - x: torch.Tensor, -) -> torch.Tensor: - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = 1e-5 - quant_min = -127 - quant_max = 127 +def _int8_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: return to_affine_quantized_intx( x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - eps=eps, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + scale_dtype=torch.float32, + eps=torch.finfo(torch.float32).eps, zero_point_domain=ZeroPointDomain.NONE, - quant_min=quant_min, - quant_max=quant_max, - scale_dtype=torch.float16 if x.dtype == torch.float16 else None, ) @@ -1001,7 +986,8 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: target_dtype=torch.int8, quant_min=-8, quant_max=7, - eps=1e-5, + scale_dtype=torch.float32, + eps=torch.finfo(torch.float32).eps, zero_point_domain=ZeroPointDomain.NONE, _layout=CutlassInt4PackedLayout(), ) @@ -1327,6 +1313,69 @@ def _float8_dynamic_activation_float8_weight_transform( return module +@dataclass +class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig): + """ + Applies float8 dynamic quantization to activations and float8 quantization followed by compression to sparse semi-structured tensor to weights of linear layers. + + Args: + `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment. + `activation_dtype`: data type for quantized activation tensor. + `weight_dtype`: data type for quantized weight tensor. + """ + + layout: Layout = CutlassSemiSparseLayout() + activation_dtype: torch.dtype = torch.float8_e5m2 + weight_dtype: torch.dtype = torch.float8_e4m3fn + + +def _float8_dynamic_activation_quant_func( + input: torch.Tensor, + activation_dtype: torch.dtype, +): + return to_affine_quantized_floatx( + input_float=input, + target_dtype=activation_dtype, + block_size=_get_per_token_block_size(input), + _layout=Float8Layout(mm_config=None), + ) + + +@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig) +def _float8_dynamic_activation_float8_semi_sparse_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8SemiSparseWeightConfig +): + assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0" + + layout = config.layout + if not isinstance(layout, CutlassSemiSparseLayout): + raise NotImplementedError( + f"Only CutlassSemiSparseLayout layout is supported. Received {layout}." 
+ ) + + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + weight = module.weight + + weight_sparse = to_affine_quantized_floatx( + input_float=weight, + target_dtype=weight_dtype, + block_size=_get_per_token_block_size(weight), + _layout=CutlassSemiSparseLayout(), + ) + + input_quant_func = _float8_dynamic_activation_quant_func + input_quant_kwargs = {"activation_dtype": activation_dtype} + + weight = to_linear_activation_quantized( + weight_sparse, input_quant_func, quant_kwargs=input_quant_kwargs + ) + + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + @dataclass class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): """ @@ -1552,8 +1601,8 @@ def _fpx_weight_only_transform( [ _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, - _int8_symm_per_token_reduced_range_quant_cutlass, - _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, + _int4_symm_per_token_quant_cutlass, + _int8_symm_per_token_quant_cutlass, ] )
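
Usage sketch (editorial addition, not part of the patch): the new Float8DynamicActivationFloat8SemiSparseWeightConfig is intended to be applied through the usual quantize_ API. A minimal sketch, assuming an SM90 (H100-class) GPU, a torchao build with the CUTLASS kernels enabled, and layer sizes that satisfy the alignment checks above (in_features divisible by 32, out_features divisible by 8):

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(
    torch.nn.Linear(2048, 4096, dtype=torch.bfloat16, device="cuda"),
)

# With the config defaults shown above, each linear weight is pruned to 2:4,
# quantized to float8_e4m3fn and compressed to the CUTLASS semi-structured
# format, while activations are dynamically quantized to float8_e5m2 at runtime.
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

x = torch.randn(16, 2048, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)  # dispatches to rowwise_scaled_linear_sparse_cutlass_f8f8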
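
For reference, the same path can be exercised directly through the new custom ops. The sketch below is again an editorial illustration, with ad-hoc per-row scaling standing in for the real to_affine_quantized_floatx flow: it prunes and quantizes a weight by hand, compresses it with the SM90 compressor, and calls the sparse GEMM.

import torch

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import mask_creator

m, k, n = 16, 256, 128  # k must be divisible by 32 and n by 8 (see check_inputs above)
dev = torch.device("cuda")

X = torch.randn(m, k, dtype=torch.bfloat16, device=dev)
W = torch.randn(n, k, dtype=torch.bfloat16, device=dev)
W = W * mask_creator(W).bool()  # enforce 2:4 sparsity, as CutlassSemiSparseLayout.pre_process does

# Crude per-row scales; both scale tensors must share the same half/bfloat16
# dtype, which also becomes the output dtype.
X_scale = X.abs().amax(dim=-1) / torch.finfo(torch.float8_e5m2).max
W_scale = W.abs().amax(dim=-1) / torch.finfo(torch.float8_e4m3fn).max
Xq = (X / X_scale[:, None]).to(torch.float8_e5m2)
Wq = (W / W_scale[:, None]).to(torch.float8_e4m3fn)

# Compress the pruned weight: Wq_sp keeps the two surviving values out of every
# four, Wq_meta is the packed (possibly padded) metadata tensor.
Wq_sp, Wq_meta = to_sparse_semi_structured_cutlass_sm9x_f8(Wq)

Y = rowwise_scaled_linear_sparse_cutlass_f8f8(Xq, X_scale, Wq_sp, Wq_meta, W_scale)
print(Y.shape, Y.dtype)  # torch.Size([16, 128]) torch.bfloat16

Note that, as the comment in the kernel header explains, CUTLASS requires the sparse operand to be the first GEMM operand, so the result is internally computed as ((Wq @ Xq.T) * W_scale * X_scale.T + bias.T).T, with the epilogue visitor tree applying both row-wise scales and the optional bias in a single pass.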