From a7197f723d01aba34f90d28a43eb4b2834204ec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Tue, 4 Feb 2025 21:01:48 +0100 Subject: [PATCH] Add CUTLASS-based row-wise scaled sparse FP8 kernel --- ...benchmark_rowwise_scaled_linear_cutlass.py | 11 +- ...rk_rowwise_scaled_linear_sparse_cutlass.py | 61 +++ docs/source/api_ref_dtypes.rst | 1 + setup.py | 56 ++ test/test_rowwise_scaled_linear_cutlass.py | 9 +- ...st_rowwise_scaled_linear_sparse_cutlass.py | 137 +++++ torchao/_models/llama/generate.py | 62 ++- .../rowwise_scaled_linear_cutlass.cuh | 9 +- .../rowwise_scaled_linear_cutlass_s4s4.cu | 20 +- .../rowwise_scaled_linear_cutlass_s8s4.cu | 14 +- .../rowwise_scaled_linear_sparse_cutlass.cuh | 493 ++++++++++++++++++ ...e_scaled_linear_sparse_cutlass_e4m3e4m3.cu | 28 + ...se_scaled_linear_sparse_cutlass_e4m3e4m3.h | 14 + ...e_scaled_linear_sparse_cutlass_e4m3e5m2.cu | 28 + ...se_scaled_linear_sparse_cutlass_e4m3e5m2.h | 14 + ...e_scaled_linear_sparse_cutlass_e5m2e4m3.cu | 28 + ...se_scaled_linear_sparse_cutlass_e5m2e4m3.h | 14 + ...e_scaled_linear_sparse_cutlass_e5m2e5m2.cu | 28 + ...se_scaled_linear_sparse_cutlass_e5m2e5m2.h | 14 + ...wwise_scaled_linear_sparse_cutlass_f8f8.cu | 48 ++ ...to_sparse_semi_structured_cutlass_sm9x.cuh | 174 +++++++ ..._sparse_semi_structured_cutlass_sm9x_f8.cu | 35 ++ torchao/dtypes/__init__.py | 2 + torchao/dtypes/affine_quantized_tensor_ops.py | 8 + torchao/dtypes/floatx/__init__.py | 4 + .../floatx/cutlass_semi_sparse_layout.py | 178 +++++++ .../uintx/cutlass_int4_packed_layout.py | 23 +- torchao/ops.py | 92 +++- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 107 ++-- 30 files changed, 1622 insertions(+), 92 deletions(-) create mode 100644 benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py create mode 100644 test/test_rowwise_scaled_linear_sparse_cutlass.py create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu create mode 100644 torchao/dtypes/floatx/cutlass_semi_sparse_layout.py diff --git a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py index 
c4c9c099be..ba31bc288d 100644 --- a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py +++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py @@ -23,9 +23,8 @@ def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev ) B_scale = torch.randn((n,), dtype=torch.half, device=dev) - C = None - return A, A_scale, B, B_scale, C + return A, A_scale, B, B_scale def benchmark(m: int, k: int, n: int): @@ -34,14 +33,14 @@ def benchmark(m: int, k: int, n: int): B_ref = torch.randn((n, k), dtype=torch.half, device=dev) fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) - A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) + A, A_scale, B, B_scale = get_problem(m, n, k, 8, 4) rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( - rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale ) - A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) + A, A_scale, B, B_scale = get_problem(m, n, k, 4, 4) rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( - rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C + rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale ) return { diff --git a/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py new file mode 100644 index 0000000000..5c544daedb --- /dev/null +++ b/benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py @@ -0,0 +1,61 @@ +import pandas as pd +import torch +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) + + +def benchmark_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + + +def get_problem(m: int, n: int, k: int): + dev = torch.device("cuda") + + A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn) + B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + + return A, A_scale, B_sp, B_meta, B_scale + + +def benchmark(m: int, k: int, n: int): + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((n, k), dtype=torch.half, device=dev) + fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) + + A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k) + rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds( + rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale + ) + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (ms)": fp16_time, + "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time, + "f8f8 speedup (d/s)": fp16_time + / rowwise_scaled_linear_sparse_cutlass_f8f8_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n)) + + df = pd.DataFrame(results) + df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git 
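For reference, a minimal standalone sketch of how the two new ops compose, mirroring get_problem() above (sizes are illustrative, and in real use the weight would already be 2:4-pruned, e.g. by CutlassSemiSparseLayout.pre_process, before compression):

import torch

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)

m, n, k = 16, 8192, 8192  # illustrative sizes
X = torch.randn(m, k, dtype=torch.half, device="cuda").to(torch.float8_e5m2)
X_scale = torch.randn(m, dtype=torch.half, device="cuda")
W = torch.randn(n, k, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
W_sp, W_meta = to_sparse_semi_structured_cutlass_sm9x_f8(W)  # compress once, offline
W_scale = torch.randn(n, dtype=torch.half, device="cuda")

# Y approximates (X @ W.T) * X_scale[:, None] * W_scale[None, :], returned in half precision
Y = rowwise_scaled_linear_sparse_cutlass_f8f8(X, X_scale, W_sp, W_meta, W_scale)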
a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 26e1266c09..6cbec7465e 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -28,6 +28,7 @@ Layouts and Tensor Subclasses MarlinQQQLayout Int4CPULayout CutlassInt4PackedLayout + CutlassSemiSparseLayout Quantization techniques ----------------------- diff --git a/setup.py b/setup.py index ee3ebbf453..9e0e8bed77 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import copy import glob import os import subprocess @@ -73,6 +74,7 @@ def use_debug_mode(): BuildExtension, CppExtension, CUDAExtension, + _get_cuda_arch_flags, ) # Constant known variables used throughout this file @@ -251,6 +253,7 @@ def get_extensions(): sources += cuda_sources use_cutlass = False + cutlass_90a_sources = None if use_cuda and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") @@ -266,8 +269,46 @@ def get_extensions(): "-I" + cutlass_include_dir, "-I" + cutlass_tools_include_dir, "-I" + cutlass_extensions_include_dir, + "-DNDEBUG" if not debug_mode else "", + "-DCUTE_USE_PACKED_TUPLE=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "--use_fast_math", + "--ftemplate-backtrace-limit=0", + # "--keep", + # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", + # "--resource-usage", + # "-lineinfo", + # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md ] ) + + cuda_arch_flags = _get_cuda_arch_flags() + build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags + build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags + if build_for_sm90 and not build_for_sm90a: + cutlass_90a_sources = [ + os.path.join( + extensions_cuda_dir, + "rowwise_scaled_linear_sparse_cutlass", + "rowwise_scaled_linear_sparse_cutlass_f8f8.cu", + ), + os.path.join( + extensions_cuda_dir, + "to_sparse_semi_structured_cutlass_sm9x", + "to_sparse_semi_structured_cutlass_sm9x_f8.cu", + ), + ] + for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: + cutlass_90a_sources.append( + os.path.join( + extensions_cuda_dir, + "rowwise_scaled_linear_sparse_cutlass", + "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu", + ) + ) + sources = [s for s in sources if s not in cutlass_90a_sources] else: # Remove CUTLASS-based kernels from the cuda_sources list. 
An # assumption is that these files will have "cutlass" in its @@ -291,6 +332,21 @@ def get_extensions(): ) ) + if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0: + cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args) + cutlass_90a_extra_compile_args["nvcc"].extend( + cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"] + ) + ext_modules.append( + extension( + "torchao._C", + cutlass_90a_sources, + py_limited_api=True, + extra_compile_args=cutlass_90a_extra_compile_args, + extra_link_args=extra_link_args, + ) + ) + if build_torchao_experimental: ext_modules.append( CMakeExtension( diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py index d6203ab9a4..aac40d72f4 100644 --- a/test/test_rowwise_scaled_linear_cutlass.py +++ b/test/test_rowwise_scaled_linear_cutlass.py @@ -8,6 +8,7 @@ rowwise_scaled_linear_cutlass_s8s4, ) from torchao.quantization.utils import group_quantize_tensor_symmetric +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] @@ -40,6 +41,12 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias) w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + if not TORCH_VERSION_AT_LEAST_2_4: + pytest.xfail("torch.nn.RMSNorm not supported") + rms_norm = torch.nn.RMSNorm(size_k).to("cuda") + x = rms_norm(x).to(dtype) + w = rms_norm(w).to(dtype) + x_2d = x.view(-1, x.shape[-1]) xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( x_2d, xq_bits, size_k, dtype @@ -57,7 +64,7 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias) ) assert torch.all(wq_zeros == 0) if wq_bits == 4: - wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) + wq = (wq_s8[..., 1::2] << 4) | (wq_s8[..., 0::2] & 0xF) else: wq = wq_s8 diff --git a/test/test_rowwise_scaled_linear_sparse_cutlass.py b/test/test_rowwise_scaled_linear_sparse_cutlass.py new file mode 100644 index 0000000000..71d4771b2c --- /dev/null +++ b/test/test_rowwise_scaled_linear_sparse_cutlass.py @@ -0,0 +1,137 @@ +import itertools +import random + +import pytest +import torch +from torch.testing._internal.common_cuda import SM90OrLater +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 + +from torchao.dtypes import ( + Float8Layout, + to_affine_quantized_floatx, +) +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) +from torchao.quantization.utils import _get_per_token_block_size +from torchao.sparsity.utils import create_semi_structured_tensor + +X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)] +XQ_WQ_DTYPES = [ + (torch.float8_e4m3fn, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.float8_e5m2), + (torch.float8_e5m2, torch.float8_e4m3fn), + (torch.float8_e5m2, torch.float8_e5m2), +] +BATCH_SIZE = [1, 4] +SIZE_MNK = [ + (2, 128, 256), + (3, 128, 256), + (13, 128, 256), + (27, 128, 128), + (33, 128, 64), + (65, 128, 32), +] +USE_BIAS = [False, True] +BIAS_DTYPE = [torch.float16] +TEST_PARAMS = list( + itertools.product( + X_W_DTYPES, + XQ_WQ_DTYPES, + BATCH_SIZE, + SIZE_MNK, + USE_BIAS, + BIAS_DTYPE, + ) +) + + +def run_test_for_op( + op, + x_dtype, + w_dtype, + xq_dtype, + wq_dtype, + batch_size, + size_mnk, + use_bias, + bias_dtype, +): + random.seed(0) ## for 
create_semi_structured_tensor() + + size_m, size_n, size_k = size_mnk + + x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda") + w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype) + bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None + + if not TORCH_VERSION_AT_LEAST_2_4: + pytest.xfail("torch.nn.RMSNorm not supported") + rms_norm = torch.nn.RMSNorm(size_k).to("cuda") + x = rms_norm(x).to(x_dtype) + w = rms_norm(w).to(w_dtype) + + x_aqt = to_affine_quantized_floatx( + input_float=x, + target_dtype=xq_dtype, + block_size=_get_per_token_block_size(x), + _layout=Float8Layout(mm_config=None), + ) + xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain() + assert zero_points is None + + w_aqt = to_affine_quantized_floatx( + input_float=w, + target_dtype=wq_dtype, + block_size=_get_per_token_block_size(w), + _layout=Float8Layout(mm_config=None), + ) + wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain() + assert zero_points is None + wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq) + wq_sp_scales = wq_scales + + xq_2d = xq.view(-1, xq.shape[-1]) + size_m_2d = xq_2d.shape[0] + output_ref = ( + (xq_2d.float() @ wq.float().T) + * xq_scales.view(size_m_2d, 1) + * wq_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,)) + + fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias) + try: + output = op(*fn_inputs) + except NotImplementedError: + pytest.xfail("operator not implemented") + + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices") +@pytest.mark.parametrize( + "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype", + TEST_PARAMS, +) +def test_rowwise_scaled_linear_sparse_cutlass_f8f8( + x_w_dtypes, + xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, +): + run_test_for_op( + rowwise_scaled_linear_sparse_cutlass_f8f8, + *x_w_dtypes, + *xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, + ) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 0958a5207c..e225f8068b 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -334,11 +334,13 @@ def ffn_or_attn_only(mod, fqn): if quantization: from torchao.quantization import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, autoquant, float8_dynamic_activation_float8_weight, float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -434,18 +436,30 @@ def ffn_or_attn_only(mod, fqn): ] ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) - elif "int8adq-int4w-symm" in quantization: + elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout - quantize_( - model, - int8_dynamic_activation_int4_weight( - group_size=None, - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, - layout=CutlassInt4PackedLayout(), - ), - ) + nbits = int(quantization.removeprefix("int4dq-")) + assert nbits == 4 or nbits == 8 + if nbits == 4: + quantize_( + model, + 
int4_dynamic_activation_int4_weight( + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) + elif nbits == 8: + quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout @@ -564,16 +578,24 @@ def ffn_or_attn_only(mod, fqn): elif "float8wo" in quantization: quantize_(model, float8_weight_only()) elif "float8dq" in quantization: - granularity = str(quantization.split("-")[-1]) - if granularity == "tensor": - granularity = PerTensor() - elif granularity == "row": - granularity = PerRow() + if sparsity and "semi" in sparsity: + quantize_( + model, + Float8DynamicActivationFloat8SemiSparseWeightConfig(), + filter_fn=ffn_only, + ) else: - granularity = PerTensor() - quantize_( - model, float8_dynamic_activation_float8_weight(granularity=granularity) - ) + granularity = str(quantization.split("-")[-1]) + if granularity == "tensor": + granularity = PerTensor() + elif granularity == "row": + granularity = PerRow() + else: + granularity = PerTensor() + quantize_( + model, + float8_dynamic_activation_float8_weight(granularity=granularity), + ) elif "autoquant_v2" in quantization: from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model @@ -1130,7 +1152,7 @@ def callback(x): help=( "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx--, uintx---hqq, sparse-marlin, spinquant, " - + "embed-int8wo, marlin_qqq, gemlite---, int8adq-int4w-symm" + + "embed-int8wo, marlin_qqq, gemlite---, float8dq, int4dq-" ), ) parser.add_argument( diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh index 0117f12e27..994b0cfdc5 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -11,6 +13,8 @@ #endif #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) +#include +#include #include #include #include @@ -548,8 +552,11 @@ template at::Tensor rowwise_scaled_linear_cutlass( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& w_scale, const std::optional& bias_opt) { #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) + // Create bias tensor. + const auto bias = bias_opt.has_value() ? *bias_opt : at::Tensor{}; + // Check inputs. 
rowwise_scaled_linear_cutlass_check_inputs( xq, x_scale, wq, w_scale, bias); diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu index cc1b5ca123..a4c996a57d 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -7,21 +7,23 @@ namespace torchao { at::Tensor rowwise_scaled_linear_cutlass_s4s4( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& w_scale, + const std::optional& bias_opt = std::nullopt) { // Validate input datatypes. TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, __func__, " : The input datatypes combination ", xq.dtype(), " for xq and ", wq.dtype(), " for wq is not supported"); +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) // Dispatch to appropriate kernel template. - #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) - // We get ElementA/ElementB types from the header - return rowwise_scaled_linear_cutlass( - xq, x_scale, wq, w_scale, bias); - #else - TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s4s4 not available"); - return at::Tensor{}; - #endif + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu index 29f30d08fc..d823956f24 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -1,4 +1,5 @@ #include + #include "rowwise_scaled_linear_cutlass.cuh" namespace torchao { @@ -6,20 +7,21 @@ namespace torchao { at::Tensor rowwise_scaled_linear_cutlass_s8s4( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& w_scale, + const std::optional& bias_opt = std::nullopt) { // Validate input datatypes. TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, __func__, " : The input datatypes combination ", xq.dtype(), " for xq and ", wq.dtype(), " for wq is not supported"); #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) - // Define ElementA as int8_t since it's a standard type + // Dispatch to appropriate kernel template. 
using ElementA = int8_t; - // ElementB comes from cutlass header - return rowwise_scaled_linear_cutlass( - xq, x_scale, wq, w_scale, bias); + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias_opt); #else - TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s8s4 not available"); + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); return at::Tensor{}; #endif } diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh new file mode 100644 index 0000000000..89a39ecfb8 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh @@ -0,0 +1,493 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS +#endif + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "rowwise_scaled_linear_sparse_cutlass" + +namespace torchao { + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +template< + typename DtypeXq, + typename DtypeWq, + typename DtypeY, + typename UseBias, + typename DtypeBias, + typename DtypeXScale, + typename DtypeWScale, + typename TileShape, + typename ClusterShape> +void rowwise_scaled_linear_sparse_kernel_cutlass_sm9x( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + // For CUTLASS, sparsified tensor must be the first operand, thus + // the result will be calculated as: + // ((Wq @ Xq.T) * W_scale * X_scale.T + bias.T).T + + using SmArch = cutlass::arch::Sm90; + + // Use CUTLASS naming conventions for naming datatypes. + using ElementA = DtypeWq; + using ElementB = DtypeXq; + using ElementD = DtypeY; + using ElementAScale = DtypeWScale; + using ElementBScale = DtypeXScale; + using ElementBias = DtypeBias; + + using LayoutTagA = cutlass::layout::RowMajor; + using LayoutTagB = cutlass::layout::ColumnMajor; + using LayoutTagD = cutlass::layout::ColumnMajor; + + constexpr auto AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentD = 128 / cutlass::sizeof_bits::value; + + // TODO: use different accumulator datatype if inputs are not float. + using ElementAccumulator = float; + using ElementCompute = float; + + using ProblemShape = cute::Shape; + + // If KernelTmaWarpSpecializedPingpong used for kernel schedule, the + // performance is really bad; on the other side, using + // KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to + // affect the precision much - thus, sticking with it. 
+ using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using AScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementAScale>; + using ApplyAScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + Accum, + AScale>; + using BScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBScale>; + using ApplyBScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + ApplyAScale, + BScale>; + using BiasScalar = + cutlass::epilogue::fusion::Sm90ScalarBroadcast; + using BiasTensor = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementBias>; + using Bias = std::conditional_t; + using ApplyBias = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, ElementCompute, ElementCompute, RoundStyle>, + ApplyBScale, + Bias>; + using EVT = ApplyBias; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementD, LayoutTagD, AlignmentD, + ElementD, LayoutTagD, AlignmentD, + EpilogueSchedule, + EVT>::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = enable_3x_kernel_for_sm90_or_later< + cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue>>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideE = StrideA; + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = + typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + const int m = (int)Wq.size(0); + const int n = (int)Xq.size(0); + const int k = (int)Xq.size(1); + + ProblemShape problem_shape(m, n, k, 1); + const auto layout_A = SparseConfig::fill_layoutA(problem_shape); + const auto layout_E = SparseConfig::fill_layoutE(problem_shape); + const auto stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + const auto stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { + (ElementA*)Wq.data_ptr(), layout_A, (ElementB*)Xq.data_ptr(), stride_B, + (ElementE*)W_meta.data_ptr(), layout_E + }, + { + {}, + (ElementD*)Y.data_ptr(), stride_D, (ElementD*)Y.data_ptr(), stride_D + } + }; + + const typename AScale::Arguments A_scale_arguments{ + (ElementAScale*)W_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, cute::_0{}} 
+ }; + const typename BScale::Arguments B_scale_arguments{ + (ElementBScale*)X_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, cute::_0{}} + }; + const auto bias_arguments{ + [&]() -> typename Bias::Arguments { + if constexpr (UseBias::value) { + return { + (ElementBias*)bias.data_ptr(), + ElementBias(0), + {cute::_1{}, cute::_0{}, cute::_0{}} + }; + } else { + return {ElementBias(0)}; + } + }() + }; + arguments.epilogue.thread = { + { + { + {}, // Accum + A_scale_arguments, // AScale + {} // ApplyAScale + }, + B_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, + bias_arguments, // Bias + {} // ApplyBiass + }; + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = Xq.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform mixed datatypes GEMM operation. + status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +static void select_config( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm9x = dprops->major == 9; + + if (is_sm9x) { + if constexpr ((std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value)) { + // TODO: add proper tuning here. 
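+      // (A single tile/cluster configuration, defined just below, is
+      // currently used for all problem shapes; a tuned version would
+      // select among several TileShape/ClusterShape pairs based on the
+      // GEMM m/n/k.)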
+ using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + rowwise_scaled_linear_sparse_kernel_cutlass_sm9x< + DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_bias( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + if (bias.numel() == 0) { + using UseBias = std::false_type; + using DtypeBias = DtypeY; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + using UseBias = std::true_type; + if (bias.scalar_type() == at::ScalarType::Half) { + using DtypeBias = cutlass::half_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (bias.scalar_type() == at::ScalarType::BFloat16) { + using DtypeBias = cutlass::bfloat16_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for datatype ", bias.scalar_type(), + " for bias"); +} + +template + static void +dispatch_on_X_scale_and_W_scale( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + TORCH_CHECK(Y.scalar_type() == X_scale.scalar_type(), + OPERATOR_NAME, " : Operator not supported for Y datatype ", + Y.scalar_type(), " as it's different from the first ", + " operand scale datatype ", X_scale.scalar_type()); + + if (X_scale.scalar_type() == at::ScalarType::Half && + W_scale.scalar_type() == at::ScalarType::Half) { + using DtypeXScale = cutlass::half_t; + using DtypeWScale = cutlass::half_t; + using DtypeY = cutlass::half_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (X_scale.scalar_type() == at::ScalarType::BFloat16 && + W_scale.scalar_type() == at::ScalarType::BFloat16) { + using DtypeXScale = cutlass::bfloat16_t; + using DtypeWScale = cutlass::bfloat16_t; + using DtypeY = cutlass::bfloat16_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for combination of datatypes ", + X_scale.scalar_type(), " for first operand scale and ", + W_scale.scalar_type(), " for second operand scale"); +} + +template +void +rowwise_scaled_linear_sparse_cutlass_check_inputs( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const at::Tensor& bias) { + // Validate metadata datatype. + TORCH_CHECK(W_meta.dtype() == at::kByte, OPERATOR_NAME, + " : Expected Wq meta argument to be of torch.uint8 datatype got ", + Wq.dtype()); + + // Validate layouts of arguments. 
+ TORCH_CHECK(Xq.dim() >= 2, OPERATOR_NAME, + " : Expected Xq argument to be 2D or higher-dimensional tensor, " + " got ", Xq.dim(), " dims"); + TORCH_CHECK(Xq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq argument to be strided, got layout ", + Xq.layout()); + TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME, + " : Expected Xq scale argument to be ", Xq.dim() - 1, + "D tensor, got ", X_scale.dim(), " dims"); + TORCH_CHECK(X_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq scale argument to be strided, got layout ", + X_scale.layout()); + TORCH_CHECK(Wq.dim() == 2, OPERATOR_NAME, + " : Expected Wq argument to be 2D tensor, got ", Wq.dim(), + " dims"); + TORCH_CHECK(Wq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq argument to be strided, got layout ", + Wq.layout()); + TORCH_CHECK(W_meta.dim() == 2, OPERATOR_NAME, + " : Expected Wq meta argument to be 2D tensor, got ", + W_meta.dim(), " dims"); + TORCH_CHECK(W_meta.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq meta argument to be strided, got layout ", + W_meta.layout()); + TORCH_CHECK(W_scale.dim() == 1 || W_scale.dim() == 2, OPERATOR_NAME, + " : Expected Wq scale argument to be 1D or 2D tensor, ", + "got ", W_scale.dim(), " dims"); + TORCH_CHECK(W_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq scale argument to be strided, got layout ", + W_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, OPERATOR_NAME, + " : Expected bias argument to be 1D tensor, got ", bias.dim(), + " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected bias argument to be strided, got layout ", + bias.layout()); + } + + // Validate sizes of arguments. + const auto Xq_sizes = Xq.sizes().vec(); + const auto Wq_sizes = Wq.sizes().vec(); + TORCH_CHECK(Xq_sizes.back() % 32 == 0, OPERATOR_NAME, + " : For alignment purpose, Xq argument must have number of " + "columns divisible by ", 32, ", got ", Xq_sizes.back(), + " columns"); + TORCH_CHECK(Wq_sizes[0] % 8 == 0, OPERATOR_NAME, + " : For alignment purpose, Wq argument to have number of rows " + "divisible by ", 8, ", but got ", Wq_sizes[0], " rows"); + TORCH_CHECK(Xq_sizes.back() == 2 * Wq_sizes[1], OPERATOR_NAME, + " : Expected Xq argument to have ", 2 * Wq_sizes[1], + " columns, but got ", Xq_sizes.back()); + const auto X_scale_sizes = X_scale.sizes().vec(); + for (auto i = 0; i < X_scale_sizes.size(); ++i) + TORCH_CHECK(X_scale_sizes[i] == Xq_sizes[i], OPERATOR_NAME, + " : Expected Xq scale argument size at position ", i, " to be ", + Xq_sizes[i], ", but got ", X_scale_sizes[i]); + TORCH_CHECK(Wq_sizes[1] % 8 == 0, OPERATOR_NAME, + " : Expected Wq argument to have number of columns divisible by ", + " 8, got ", Wq_sizes[1]); + // W_meta may be padded, thus expected shape calculations for this + // tensor are as follows. 
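+  // For example: a dense 128x256 FP8 weight compresses to a 128x128 Wq,
+  // and its metadata is expected as 128 rows by 32 bytes (one byte per 8
+  // dense elements); for small weights the metadata is padded up to at
+  // least 64 rows by 16 bytes, per the max() calculations below.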
+ const auto W_meta_size_0_expected = std::max((int)Wq_sizes[0], 64); + const auto W_meta_size_1_expected = std::max((int)Wq_sizes[1] / 4, 16); + TORCH_CHECK(W_meta.size(0) == W_meta_size_0_expected, OPERATOR_NAME, + " : Expected Wq meta argument to have ", W_meta_size_0_expected, + " rows, got ", W_meta.size(0), " rows"); + TORCH_CHECK(W_meta.size(1) == W_meta_size_1_expected, OPERATOR_NAME, + " : Expected Wq meta argument to hold ", W_meta_size_1_expected, + " bytes per row to encode sparsity of Wq argument, got ", + W_meta.size(1), " bytes"); + TORCH_CHECK(W_scale.numel() == Wq_sizes[0], OPERATOR_NAME, + " : Expected Wq scale argument to have ", Wq_sizes[0], + " elements, got ", W_scale.numel(), " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == Wq_sizes[0], OPERATOR_NAME, + " : Expected bias argument to have ", Wq_sizes[0], + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. + const auto Xq_strides = Xq.strides(); + TORCH_CHECK(Xq_strides[Xq_strides.size() - 1] == 1, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + auto Xq_stride_expected = Xq_strides[Xq_strides.size() - 2]; + for (int i = Xq_strides.size() - 3; i >= 0; --i) { + Xq_stride_expected *= Xq_sizes[i + 1]; + TORCH_CHECK(Xq_strides[i] == Xq_stride_expected, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + } + TORCH_CHECK(X_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Xq scale argument to be contiguous"); + const auto Wq_strides = Wq.strides(); + TORCH_CHECK(Wq_strides[0] >= 1 && Wq_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq argument in row-major layout"); + const auto W_meta_strides = W_meta.strides(); + TORCH_CHECK(W_meta_strides[0] >= 1 && W_meta_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq meta argument in row-major layout"); + TORCH_CHECK(W_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Wq scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, OPERATOR_NAME, + " : Expected bias argument to be contiguous"); + } +} +#endif + +template +at::Tensor +rowwise_scaled_linear_sparse_cutlass( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + // Create bias tensor. + const auto bias = bias_opt.has_value() ? *bias_opt : at::Tensor{}; + + // Check inputs. + rowwise_scaled_linear_sparse_cutlass_check_inputs( + Xq, X_scale, Wq, W_meta, W_scale, bias); + + // Squash the input tensors as appropriate. + const auto Xq_sizes = Xq.sizes().vec(); + const auto Xq_2d = Xq.reshape({-1, Xq_sizes.back()}); + const auto X_scale_1d = X_scale.reshape({-1}); + const auto W_scale_1d = W_scale.reshape({-1}); + + // Create result tensor. + at::Tensor Y = X_scale.new_empty({Xq_2d.size(0), Wq.size(0)}); + + // Dispatch to appropriate kernel template. + dispatch_on_X_scale_and_W_scale( + Xq_2d, X_scale_1d, Wq, W_meta, W_scale_1d, bias, Y); + + // Reshape and return Y tensor. 
+ auto Y_sizes = Xq_sizes; + Y_sizes.back() = Wq.size(0); + return Y.reshape(Y_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu new file mode 100644 index 0000000000..c2da7b6dc3 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. + TORCH_CHECK( + Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e4m3_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h new file mode 100644 index 0000000000..1ccf780487 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu new file mode 100644 index 0000000000..dcc92afa71 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e5m2, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e4m3_t; + using DtypeWq = cutlass::float_e5m2_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h new file mode 100644 index 0000000000..8ca6b4bd37 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu new file mode 100644 index 0000000000..185ff8586c --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e5m2_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h new file mode 100644 index 0000000000..61e04ab051 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu new file mode 100644 index 0000000000..bda18ea86d --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu @@ -0,0 +1,28 @@ +#include "rowwise_scaled_linear_sparse_cutlass.cuh" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt) { + // Validate input datatypes. 
+ TORCH_CHECK( + Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2, + __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", + Wq.dtype(), " for Wq is not supported"); + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + using DtypeXq = cutlass::float_e5m2_t; + using DtypeWq = cutlass::float_e5m2_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h new file mode 100644 index 0000000000..4fee072f16 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt); + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu new file mode 100644 index 0000000000..7a651cb3dd --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu @@ -0,0 +1,48 @@ +#include + +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h" +#include "rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h" +#include "rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_f8f8( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const std::optional& bias_opt = std::nullopt) { + // Validate input datatypes. + TORCH_CHECK( + (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn) || + (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e5m2) || + (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn) || + (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2), + __func__, " : The input datatypes combination ", Xq.dtype(), + " for Xq and ", Wq.dtype(), " for Wq is not supported"); + + // Dispatch to appropriate kernel template. 
+ if (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn) { + return rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e4m3fn && + Wq.dtype() == at::kFloat8_e5m2) { + return rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e5m2 && + Wq.dtype() == at::kFloat8_e4m3fn) { + return rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } else if (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2) { + return rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( + Xq, X_scale, Wq, W_meta, W_scale, bias_opt); + } + return at::Tensor{}; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8", + &rowwise_scaled_linear_sparse_cutlass_f8f8); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh new file mode 100644 index 0000000000..24a9371593 --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X +#endif + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "to_sparse_semi_structured_cutlass_sm9x" + +namespace torchao { + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +template +std::tuple +to_sparse_semi_structured_kernel_cutlass_sm9x(const at::Tensor& W) { + // The kernel doesn't check, but assumes instead, that the input + // tensor is a structured sparse tensor. + + static_assert(std::is_same_v || + std::is_same_v); + + using SmArch = cutlass::arch::Sm90; + + using ProblemShape = cute::Shape; + + using LayoutTagW = cutlass::layout::RowMajor; + using StrideW = cutlass::gemm::TagToStrideA_t; + + using DtypeMeta = unsigned char; + + // CUTLASS derives the sparse config from the mainloop. In order + // not to instantiate the whole mainloop here, the sparse config for + // FP8 case is hard-coded below. The config template arguments are + // found by changing input data types for CUTLASS example 62 to FP8, + // and then printing out the sparse config data type. 
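+  // In effect, for FP8 the W operand is stored 2:4 along K (two logical
+  // elements per stored element), with each metadata byte covering eight
+  // logical elements; this matches the k/8 metadata bytes per row that
+  // the rowwise_scaled_linear_sparse_cutlass input checks expect.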
+ using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, DtypeW>, + cute::GMMA::Major::K, + cute::sparse_elem<8, unsigned char>, + cute::_128>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, DtypeW, LayoutTagW, SparseConfig>; + using CompressorKernel = enable_3x_kernel_for_sm90_or_later< + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, DtypeW, LayoutTagW, SparseConfig, SmArch>>; + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + const int m = W.size(0); + const int k = W.size(1); + + ProblemShape problem_shape(m, 1, k, 1); + + StrideW stride_W = + cutlass::make_cute_packed_stride(StrideW{}, cute::make_shape(m, k, 1)); + CompressorUtility compressor_utility(problem_shape, stride_W); + int k_compressed = compressor_utility.get_tensorA_k_physical(); + int m_meta = compressor_utility.get_metadata_m_physical(); + int k_meta = compressor_utility.get_metadata_k_physical(); + + // Create result tensors. + at::Tensor W_compressed = W.new_empty({m, k_compressed}); + at::Tensor W_meta = + W.new_empty({m_meta, k_meta}, at::TensorOptions().dtype(at::kByte)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + { + (DtypeW*)W.data_ptr(), stride_W, (DtypeW*)W_compressed.data_ptr(), + (DtypeMeta*)W_meta.data_ptr() + }, + {hw_info}}; + + Compressor compressor_op; + + cutlass::Status status; + + // Verify that compression operation with given arguments can be + // performed by CUTLASS. + status = compressor_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for the compressor. + const auto workspace_size = Compressor::get_workspace_size(arguments); + auto workspace = W.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize compressor. + status = compressor_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform compression. + status = compressor_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(W_compressed, W_meta); +} + +template +void +to_sparse_semi_structured_cutlass_sm9x_check_inputs(const at::Tensor& W) { + // Validate the input tensor layout. + TORCH_CHECK(W.dim() == 2, OPERATOR_NAME, + " : Expected W argument to be 2D tensor, got ", W.dim(), + " dims"); + TORCH_CHECK(W.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected W argument to be strided, got layout ",W.layout()); + + // Validate the input tensor shape. + const auto W_sizes = W.sizes().vec(); + TORCH_CHECK(W_sizes[1] % 8 == 0, OPERATOR_NAME, + " : Expected number of columns of the W argument to be ", + "divisible by 8, got ", W_sizes[1], " columns"); + + // Validate the input tensor strides. 
+ const auto W_strides = W.strides(); + TORCH_CHECK(W_strides[1] == 1, OPERATOR_NAME, + " : Expected W argument in row-major layout"); +} +#endif + +template +std::tuple +to_sparse_semi_structured_cutlass_sm9x(const at::Tensor& W) { +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm9x = dprops->major == 9; + if (!is_sm9x) { + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); + } + + // Check inputs. + to_sparse_semi_structured_cutlass_sm9x_check_inputs(W); + + // Call the kernel. + return to_sparse_semi_structured_kernel_cutlass_sm9x(W); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return std::make_tuple(at::Tensor{}, at::Tensor{}); +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu new file mode 100644 index 0000000000..1a4ab285b6 --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu @@ -0,0 +1,35 @@ +#include + +#include "to_sparse_semi_structured_cutlass_sm9x.cuh" + +namespace torchao { + +std::tuple +to_sparse_semi_structured_cutlass_sm9x_f8(const at::Tensor& W) { + // Validate input datatypes. + TORCH_CHECK(W.dtype() == at::kFloat8_e5m2 || W.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatype ", W.dtype(), + " is not supported"); + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) + // Dispatch to appropriate kernel template. + if (W.dtype() == at::kFloat8_e5m2) { + using DtypeW = cutlass::float_e5m2_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } else if (W.dtype() == at::kFloat8_e4m3fn) { + using DtypeW = cutlass::float_e4m3_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } + return std::tuple(at::Tensor{}, at::Tensor{}); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return std::tuple(at::Tensor{}, at::Tensor{}); +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::to_sparse_semi_structured_cutlass_sm9x_f8", + &to_sparse_semi_structured_cutlass_sm9x_f8); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 9cbd4cd2a0..9224f8efa3 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -9,6 +9,7 @@ to_affine_quantized_intx_static, ) from .floatx import ( + CutlassSemiSparseLayout, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 @@ -52,4 +53,5 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", + "CutlassSemiSparseLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 54f4a72811..3719bccb89 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -6,6 +6,10 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, ) +from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( + _linear_fp8_act_fp8_weight_sparse_cutlass_check, + _linear_fp8_act_fp8_weight_sparse_cutlass_impl, +) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl, @@ -161,6 +165,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int4_act_int4_weight_cutlass_check, 
_linear_int4_act_int4_weight_cutlass_impl, ), + ( + _linear_fp8_act_fp8_weight_sparse_cutlass_check, + _linear_fp8_act_fp8_weight_sparse_cutlass_impl, + ), ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 3f0a1ccd5c..7e634a5211 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,3 +1,6 @@ +from .cutlass_semi_sparse_layout import ( + CutlassSemiSparseLayout, +) from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, @@ -10,4 +13,5 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", + "CutlassSemiSparseLayout", ] diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py new file mode 100644 index 0000000000..3b81d9a021 --- /dev/null +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -0,0 +1,178 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class CutlassSemiSparseLayout(Layout): + """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" + + def pre_process(self, dense: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + from torchao.sparsity.utils import mask_creator + + return dense * mask_creator(dense).bool() + + +@register_layout(CutlassSemiSparseLayout) +class CutlassSemiSparseTensorImpl(AQTTensorImpl): + @staticmethod + def __new__( + cls, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = sparse.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else sparse.layout + ) + kwargs["dtype"] = sparse.dtype + kwargs["requires_grad"] = False + shape = (sparse.shape[0], 2 * sparse.shape[-1]) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + self.sparse = sparse + self.meta = meta + self.scale = scale + self._layout = _layout + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["sparse", "meta", "scale"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + sparse = tensor_data_dict["sparse"] + meta = tensor_data_dict["meta"] + scale = tensor_data_dict["scale"] + (_layout,) = tensor_attributes + return cls(sparse, meta, scale, _layout) + + def get_plain(self): + # No support in CUTLASS to convert back to dense from sparse + # semi-structured format, so multiplying with identity matrix + # 
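CutlassSemiSparseLayout.pre_process prunes the dense weight to a 2:4 pattern with mask_creator before quantization, and the resulting impl stores the compressed values, the metadata and the row-wise scales while reporting the original logical shape (rows, 2 * compressed columns). A short sketch of the pruning step on its own, assuming mask_creator behaves as it is used here (True entries mark the values to keep):

import torch
from torchao.sparsity.utils import mask_creator

dense = torch.randn(8, 64)
mask = mask_creator(dense).bool()          # 2:4 mask: two entries kept per group of four
pruned = dense * mask
print(pruned.ne(0).float().mean().item())  # roughly 0.5: about half of the entries survive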
for the conversion. + cols = 2 * self.sparse.shape[1] + input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) + input_scale = torch.ones( + (cols,), dtype=self.scale.dtype, device=self.sparse.device + ) + dense = ( + rowwise_scaled_linear_sparse_cutlass_f8f8( + input, input_scale, self.sparse, self.meta, self.scale + ) + .t() + .contiguous() + ) + + return dense, self.scale, None + + @classmethod + def from_plain( + cls, + dense: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert zero_point is None or torch.all(zero_point == 0) + + # FIXME: remove this when CUTLASS PR #2110 merged. + dtype = dense.dtype + dense = dense.view(torch.uint8) + dense[dense == 128] = 0 + dense = dense.view(dtype) + + sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense) + return cls( + sparse, + meta, + scale, + _layout, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.sparse = fn(self.sparse) + self.meta = fn(self.meta) + self.scale = fn(self.scale) + return self + + +def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import Float8Layout + + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, Float8Layout) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassSemiSparseLayout) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + + +def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 + + input = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale + weight = weight_tensor.tensor_impl.sparse + weight_meta = weight_tensor.tensor_impl.meta + weight_scale = weight_tensor.tensor_impl.scale + + out = rowwise_scaled_linear_sparse_cutlass_f8f8( + input, input_scale, weight, weight_meta, weight_scale, bias + ) + + return out diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index ae8ea78ceb..9544fdfe71 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -11,9 +11,9 @@ register_layout, ) from torchao.dtypes.uintx.plain_layout import ( - _aqt_is_int8_reduced_range, + _aqt_is_int8, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout aten = torch.ops.aten @@ -82,9 +82,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale"], [ - self._layout, - ] + return ["int_data", "scale"], [self._layout] @classmethod def __tensor_unflatten__( @@ -92,13 +90,13 @@ def __tensor_unflatten__( ): int_data = tensor_data_dict["int_data"] scale = tensor_data_dict["scale"] - _layout = tensor_attributes + (_layout,) = tensor_attributes return 
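The packed-int4 layout above stores two signed 4-bit values per int8 byte and, with this change, packs and unpacks along the last dimension so it also works for batched tensors. A standalone sketch of that round trip, using the same shift arithmetic as get_plain and from_plain:

import torch

x = torch.randint(-8, 8, (3, 8), dtype=torch.int8)   # values already within the int4 range

# Pack: even-indexed elements fill the low nibble, odd-indexed ones the high nibble.
packed = ((x[..., 1::2] & 0xF) << 4) | (x[..., 0::2] & 0xF)

# Unpack: arithmetic shifts sign-extend both nibbles back to int8.
unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(x.shape)

assert torch.equal(unpacked, x)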
cls(int_data, scale, _layout) def get_plain(self): int_data = torch.stack( - ((self.int_data << 4) >> 4, self.int_data >> 4), dim=2 - ).view((self.int_data.shape[0], 2 * self.int_data.shape[1])) + ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 + ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) return int_data, self.scale, None @classmethod @@ -110,8 +108,7 @@ def from_plain( _layout: Layout, ): assert zero_point is None or torch.all(zero_point == 0) - - int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF) + int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) return cls( int_data_s4, scale, @@ -130,12 +127,14 @@ def _apply_fn_to_data(self, fn): def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): return ( isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(input_tensor._layout, PlainLayout) + and _aqt_is_int8(input_tensor) and input_tensor.dtype in (torch.float16, torch.bfloat16) and len(input_tensor.shape) >= 2 and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(weight_tensor) and weight_tensor.dtype == input_tensor.dtype and len(weight_tensor.shape) == 2 @@ -164,12 +163,14 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): return ( isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(input_tensor) and input_tensor.dtype in (torch.float16, torch.bfloat16) and len(input_tensor.shape) >= 2 and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) and _aqt_is_int4(weight_tensor) and weight_tensor.dtype == input_tensor.dtype and len(weight_tensor.shape) == 2 diff --git a/torchao/ops.py b/torchao/ops.py index a3aee761b9..e7586d6a59 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,4 +1,5 @@ import functools +from typing import Optional import torch from torch import Tensor @@ -22,10 +23,16 @@ "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" ) lib.define( - "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor? bias=None) -> Tensor" ) lib.define( - "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor? bias=None) -> Tensor" +) +lib.define( + "rowwise_scaled_linear_sparse_cutlass_f8f8(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_meta, Tensor weight_scale, Tensor? 
bias=None) -> Tensor" +) +lib.define( + "to_sparse_semi_structured_cutlass_sm9x_f8(Tensor weight) -> (Tensor, Tensor)" ) lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") @@ -536,7 +543,7 @@ def rowwise_scaled_linear_cutlass_s8s4( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, ) -> Tensor: """ CUTLASS-based row-wise scaled W4A8 linear operator. @@ -545,7 +552,7 @@ def rowwise_scaled_linear_cutlass_s8s4( input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). - bias: a vector of size equal to number of rows of weight tensor, or None. + bias: an optional vector of size equal to number of rows of weight tensor, or None. Returns: output: result tensor, in row-major layout. """ @@ -561,7 +568,7 @@ def _( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, ) -> Tensor: # No checks here, as detailed checks are performed by the # operator itself. @@ -578,7 +585,7 @@ def rowwise_scaled_linear_cutlass_s4s4( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, ) -> Tensor: """ CUTLASS-based row-wise scaled W4A4 linear operator. @@ -587,7 +594,7 @@ def rowwise_scaled_linear_cutlass_s4s4( input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). - bias: a vector of size equal to number of rows of weight tensor, or None. + bias: an optional vector of size equal to number of rows of weight tensor, or None. Returns: output: result tensor, in row-major layout. """ @@ -603,11 +610,80 @@ def _( input_scale: Tensor, weight: Tensor, weight_scale: Tensor, - bias: Tensor, + bias: Optional[Tensor] = None, ) -> Tensor: + # No checks here, as detailed checks are performed by the + # operator itself. + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) +def rowwise_scaled_linear_sparse_cutlass_f8f8( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + """ + CUTLASS-based row-wise scaled F8F8 linear operator, for sparsified weight case. + Args: + input: quantized input tensor, in row-major layout. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: sparsified quantized weight matrix, in row-major layout. + weight_meta: sparsify metadata for weight tensor. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: an optional vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. 
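Put together, a hedged usage sketch of the sparse F8F8 operator: the weight is pruned to 2:4, cast to float8_e4m3fn and compressed, the activation is cast to float8_e5m2, and both carry one scale per row. This assumes an SM90 GPU and a torchao build with these kernels, and uses mask_creator only as a convenient way to obtain a 2:4 pattern:

import torch
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.sparsity.utils import mask_creator

m, k, n = 32, 128, 64
dev = "cuda"

Xq = torch.randn(m, k, dtype=torch.half, device=dev).to(torch.float8_e5m2)
X_scale = torch.rand(m, dtype=torch.half, device=dev) + 0.5       # one scale per input row

W = torch.randn(n, k, dtype=torch.half, device=dev)
W = W.masked_fill(~mask_creator(W).bool(), 0)                     # 2:4 prune for illustration (explicit +0.0 zeros)
W_sp, W_meta = to_sparse_semi_structured_cutlass_sm9x_f8(W.to(torch.float8_e4m3fn))
W_scale = torch.rand(n, dtype=torch.half, device=dev) + 0.5       # one scale per weight row

bias = torch.randn(n, dtype=torch.half, device=dev)
out = rowwise_scaled_linear_sparse_cutlass_f8f8(Xq, X_scale, W_sp, W_meta, W_scale, bias)
print(out.shape, out.dtype)                                       # (m, n), dtype follows the scales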
+ """ + + return torch.ops.torchao.rowwise_scaled_linear_sparse_cutlass_f8f8.default( + input, input_scale, weight, weight_meta, weight_scale, bias + ) + + +@register_custom_op("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Optional[Tensor] = None, +) -> Tensor: + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def to_sparse_semi_structured_cutlass_sm9x_f8( + weight: Tensor, +) -> (Tensor, Tensor): + """ + CUTLASS-based conversion from sparsified input tensor to corresponding compressed tensor, along with corresponding metadata tensor. + Args: + weight: input tensor, in row-major layout. + Returns: + weight_compressed: compressed weight tensor, with sparsity eliminated, in row-major layout. + weight_meta: metadata tensor, describing the sparsity structure of the input tensor, also in row-major layout. + """ + + return torch.ops.torchao.to_sparse_semi_structured_cutlass_sm9x_f8.default(weight) + + +@register_custom_op("torchao::to_sparse_semi_structured_cutlass_sm9x_f8") +def _( + weight: Tensor, +) -> (Tensor, Tensor): + # No checks here, as detailed checks are performed by the + # operator itself. + + return ( + weight.new_empty(weight[0], weight[1] // 2), + weight.new_empty(weight[0], max(weight[1] // 8, 16), dtype=torch.char), + ) + + @functools.lru_cache() def _get_dtypes(): """TODO: when e8m0 is hardened and major release lets remove uint8 support""" diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 5f15a6bbbe..b816eb585e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -46,6 +46,7 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -138,6 +139,7 @@ "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", + "Float8DynamicActivationFloat8SemiSparseWeightConfig", "UIntXWeightOnlyConfig", "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 0e7cda16f0..796c8c7930 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -30,6 +30,7 @@ from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, + CutlassSemiSparseLayout, Float8Layout, Int4CPULayout, MarlinQQQLayout, @@ -122,6 +123,7 @@ "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", + "Float8DynamicActivationFloat8SemiSparseWeightConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -654,7 +656,7 @@ def _int8_dynamic_activation_int4_weight_transform( if isinstance(layout, MarlinQQQLayout): input_quant_func = _int8_symm_per_token_quant elif isinstance(layout, CutlassInt4PackedLayout): - input_quant_func = _int8_symm_per_token_reduced_range_quant_cutlass + input_quant_func = _int8_symm_per_token_quant_cutlass else: input_quant_func = _int8_symm_per_token_quant else: @@ -664,6 +666,8 @@ def _int8_dynamic_activation_int4_weight_transform( weight = to_marlinqqq_quantized_intx( weight, block_size, quant_min, quant_max, _layout=layout ) + elif isinstance(layout, CutlassInt4PackedLayout): + weight = _int4_symm_per_token_quant_cutlass(weight) else: weight = to_affine_quantized_intx( weight, @@ -718,17 
+722,7 @@ def _int4_dynamic_activation_int4_weight_transform( if act_mapping_type != MappingType.SYMMETRIC: raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") - weight = to_affine_quantized_intx( - weight, - mapping_type=mapping_type, - block_size=(1, weight.shape[1]), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=torch.finfo(torch.float32).eps, - zero_point_domain=ZeroPointDomain.NONE, - _layout=layout, - ) + weight = _int4_symm_per_token_quant_cutlass(weight) weight = to_linear_activation_quantized( weight, _int4_symm_per_token_quant_cutlass, @@ -972,24 +966,14 @@ def _int8_symm_per_token_reduced_range_quant_noop_decode( ) -def _int8_symm_per_token_reduced_range_quant_cutlass( - x: torch.Tensor, -) -> torch.Tensor: - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = 1e-5 - quant_min = -127 - quant_max = 127 +def _int8_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: return to_affine_quantized_intx( x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - eps=eps, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + eps=torch.finfo(torch.float32).eps, zero_point_domain=ZeroPointDomain.NONE, - quant_min=quant_min, - quant_max=quant_max, - scale_dtype=torch.float16 if x.dtype == torch.float16 else None, ) @@ -1001,7 +985,7 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: target_dtype=torch.int8, quant_min=-8, quant_max=7, - eps=1e-5, + eps=torch.finfo(torch.float32).eps, zero_point_domain=ZeroPointDomain.NONE, _layout=CutlassInt4PackedLayout(), ) @@ -1327,6 +1311,69 @@ def _float8_dynamic_activation_float8_weight_transform( return module +@dataclass +class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig): + """ + Applies float8 dynamic quantization to activations and float8 quantization followed by compression to sparse semi-structured tensor to weights of linear layers. + + Args: + `layout`: layout type for quantized weight tensor, only supports `CutlassSemiSparseLayout` at the moment. + `activation_dtype`: data type for quantized activation tensor. + `weight_dtype`: data type for quantized weight tensor. + """ + + layout: Layout = CutlassSemiSparseLayout() + activation_dtype: torch.dtype = torch.float8_e5m2 + weight_dtype: torch.dtype = torch.float8_e4m3fn + + +def _float8_dynamic_activation_quant_func( + input: torch.Tensor, + activation_dtype: torch.dtype, +): + return to_affine_quantized_floatx( + input_float=input, + target_dtype=activation_dtype, + block_size=_get_per_token_block_size(input), + _layout=Float8Layout(mm_config=None), + ) + + +@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig) +def _float8_dynamic_activation_float8_semi_sparse_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8SemiSparseWeightConfig +): + assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0" + + layout = config.layout + if not isinstance(layout, CutlassSemiSparseLayout): + raise NotImplementedError( + f"Only CutlassSemiSparseLayout layout is supported. Received {layout}." 
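In practice the new config is applied through torchao's quantize_ API; a hedged end-to-end sketch, assuming an SM90 GPU, with the model and shapes purely illustrative:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 2048)).half().cuda()
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

x = torch.randn(16, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)   # activations dynamically quantized to e5m2, weights 2:4-sparse e4m3fn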
+ ) + + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + weight = module.weight + + weight_sparse = to_affine_quantized_floatx( + input_float=weight, + target_dtype=weight_dtype, + block_size=_get_per_token_block_size(weight), + _layout=CutlassSemiSparseLayout(), + ) + + input_quant_func = _float8_dynamic_activation_quant_func + input_quant_kwargs = {"activation_dtype": activation_dtype} + + weight = to_linear_activation_quantized( + weight_sparse, input_quant_func, quant_kwargs=input_quant_kwargs + ) + + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + @dataclass class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): """ @@ -1552,8 +1599,8 @@ def _fpx_weight_only_transform( [ _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, - _int8_symm_per_token_reduced_range_quant_cutlass, - _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, + _int4_symm_per_token_quant_cutlass, + _int8_symm_per_token_quant_cutlass, ] )
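For reference, the *_symm_per_token_quant_cutlass helpers above all configure symmetric per-token quantization through to_affine_quantized_intx; the following is a plain-PyTorch illustration of what that scheme computes, not the torchao implementation:

import torch

def symm_per_token_quant(x: torch.Tensor, quant_max: int):
    # One scale per token: each row of the last dimension is scaled by its absolute maximum.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / quant_max).clamp(min=torch.finfo(torch.float32).eps)
    q = torch.clamp(torch.round(x.float() / scale), -quant_max - 1, quant_max).to(torch.int8)
    return q, scale

x = torch.randn(4, 16, dtype=torch.float16)
q8, s8 = symm_per_token_quant(x, 127)   # int8 values (activation side of the W4A8 path)
q4, s4 = symm_per_token_quant(x, 7)     # int4 values, later packed two per byte
x_hat = (q8.float() * s8).to(x.dtype)
print((x - x_hat).abs().max())          # quantization error stays small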