Add CUTLASS-based row-wise scaled sparse FP8 kernel
Commit a7197f7 (parent: 38e36de). Showing 30 changed files with 1,622 additions and 92 deletions.
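In broad strokes, the new kernel computes a linear layer on FP8 operands where the weight is 2:4 semi-structured sparse, with one scale per activation row and one per weight row. A dense sketch of the intended math, mirroring the reference computation in the test file below (the function name here is illustrative, not part of the commit):

import torch

def rowwise_scaled_linear_ref(xq, x_scale, wq, w_scale, bias=None):
    # xq: (m, k) FP8 activations; x_scale: (m,) per-row scales
    # wq: (n, k) FP8 weights;     w_scale: (n,) per-row scales
    y = (xq.float() @ wq.float().T) * x_scale.view(-1, 1) * w_scale.view(1, -1)
    if bias is not None:
        y = y + bias
    return y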
benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py (61 additions, 0 deletions)
@@ -0,0 +1,61 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)


def benchmark_microseconds(f, *args):
    # do_bench reports milliseconds; scale by 1e3 to get microseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
    dev = torch.device("cuda")

    A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)

    return A, A_scale, B_sp, B_meta, B_scale


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale
    )

    # Speedup is dense (fp16) latency over sparse f8f8 latency.
    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (us)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        "f8f8 speedup (d/s)": fp16_time
        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
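For a quick smoke test outside the timing harness, the two ops can also be called directly; a minimal sketch following the same conventions as get_problem above (small arbitrary shapes, SM90+ GPU assumed; inputs are random, so this only exercises shapes and kernel launch, not numerics):

import torch
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)

dev = torch.device("cuda")
m, n, k = 32, 128, 256  # arbitrary small problem
A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)

out = rowwise_scaled_linear_sparse_cutlass_f8f8(A, A_scale, B_sp, B_meta, B_scale)
print(out.shape)  # expected (m, n), i.e. (32, 128)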
@@ -0,0 +1,137 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.utils import _get_per_token_block_size
from torchao.sparsity.utils import create_semi_structured_tensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e5m2),
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e5m2, torch.float8_e5m2),
]
BATCH_SIZE = [1, 4]
SIZE_MNK = [
    (2, 128, 256),
    (3, 128, 256),
    (13, 128, 256),
    (27, 128, 128),
    (33, 128, 64),
    (65, 128, 32),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    random.seed(0)  # create_semi_structured_tensor() draws its sparsity pattern from random

    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype)
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

    if not TORCH_VERSION_AT_LEAST_2_4:
        pytest.xfail("torch.nn.RMSNorm not supported")
    # Normalize inputs so their magnitudes are well-scaled before FP8 quantization.
    rms_norm = torch.nn.RMSNorm(size_k).to("cuda")
    x = rms_norm(x).to(x_dtype)
    w = rms_norm(w).to(w_dtype)

    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=_get_per_token_block_size(x),
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=_get_per_token_block_size(w),
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales

    # Dense reference: upcast to float, matmul, then apply row-wise scales.
    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )
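The weight operand relies on 2:4 semi-structured sparsity, i.e. at most two nonzeros in every contiguous group of four values along the reduction dimension; this is the pattern create_semi_structured_tensor produces before quantization and compression. A hypothetical helper (not part of the commit) for asserting that invariant on a dense weight:

import torch

def is_2_to_4_sparse(t: torch.Tensor) -> bool:
    # Group the reduction (last) dimension into chunks of four and count
    # nonzeros per chunk; 2:4 sparsity allows at most two per chunk.
    assert t.shape[-1] % 4 == 0
    groups = t.float().reshape(*t.shape[:-1], -1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())

Under these assumptions, is_2_to_4_sparse(create_semi_structured_tensor(128, 256, dtype=torch.float16)) should return True.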