
Commit dedcc48

Merge pull request #49 from flash-algo/add-dot-test-script
[PERFORMANCE OPTIMIZATION] add dot test script
2 parents 59f9e3b + a7f9fb2

File tree

2 files changed: +85 -1 lines changed

- README.md
- tests/test_dot.py

README.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
  | [swap](./docs/swap.md) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [](./kernel_course/python_ops/swap.py) | [](./kernel_course/pytorch_ops/swap.py) | [](./kernel_course/triton_ops/swap.py) || [](./tests/test_swap.py) |
  | [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [](./kernel_course/python_ops/scal.py) | [](./kernel_course/pytorch_ops/scal.py) | [](./kernel_course/triton_ops/scal.py) || [](./tests/test_scal.py) |
  | [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [](./kernel_course/python_ops/axpby.py) | [](./kernel_course/pytorch_ops/axpby.py) | [](./kernel_course/triton_ops/axpby.py) || [](./tests/test_axpby.py) |
- | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | [](./kernel_course/triton_ops/dot.py) || |
+ | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | [](./kernel_course/triton_ops/dot.py) || [](./tests/test_dot.py) |
  | gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ ||||||
  | geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ ||||||
  | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ ||||||
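
The dot row's cost columns follow directly from the definition: $z = x^\top y$ takes $n$ multiplies and $n$ adds, hence $2n$ FLOPs, while reading the two length-$n$ inputs, hence $2n$ elements of memory traffic. As an illustration only (the repository's actual kernel_course/python_ops/dot.py is not shown in this diff and may differ), a pure-Python reference could look like:

# Hypothetical reference implementation for illustration; the actual
# kernel_course/python_ops/dot.py may differ.
import torch


def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """z = x^T y over 1-D tensors: n multiplies + n adds = 2n FLOPs, 2n elements read."""
    assert x.ndim == 1 and x.shape == y.shape
    z = torch.zeros((), device=x.device, dtype=x.dtype)
    for xi, yi in zip(x, y):  # one multiply and one add per element
        z = z + xi * yi
    return z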

tests/test_dot.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@

import pytest
import torch

from kernel_course import testing
from kernel_course.python_ops import dot as python_dot

# Optional backends: fall back gracefully when a framework is not installed.
try:
    from kernel_course.pytorch_ops import dot as pytorch_dot

    HAS_PYTORCH = True
except Exception:
    pytorch_dot = None
    HAS_PYTORCH = False

try:
    from kernel_course.triton_ops import dot as triton_dot

    HAS_TRITON = True
except Exception:
    triton_dot = None
    HAS_TRITON = False

try:
    from kernel_course.cute_ops import dot as cute_dot

    HAS_CUTE = True
except Exception:
    cute_dot = None
    HAS_CUTE = False


def factory(
    numel: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
):
    # Build fresh (args, kwargs) inputs for each benchmark run.
    x = torch.linspace(0.0, 1.0, steps=numel, device=device, dtype=dtype)
    y = torch.linspace(0.0, 1.0, steps=numel, device=device, dtype=dtype)
    return (x, y), {}


@pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"),
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="requires CUDA"
            ),
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(
                not torch.backends.mps.is_available(), reason="requires MPS"
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
)
@pytest.mark.parametrize(
    "numel",
    [1 << 4, 1 << 8, 1 << 16],
)
def test_dot_benchmark(device: torch.device, dtype: torch.dtype, numel: int) -> None:
    # Collect whichever implementations are available on this machine.
    impls = testing.get_impls(
        python_impl=python_dot.dot,
        pytorch_impl=pytorch_dot.dot if HAS_PYTORCH else None,
        triton_impl=triton_dot.dot if HAS_TRITON else None,
        cute_impl=cute_dot.dot if HAS_CUTE else None,
    )

    # Benchmark each implementation: 3 warmup runs, then 1,000 timed repeats.
    config = testing.BenchmarkConfig(warmup=3, repeat=1_000)
    results = testing.run_benchmarks(
        impls,
        lambda: factory(numel, device, dtype),
        flops=2 * numel,  # a dot product performs n multiplies + n adds
        config=config,
    )

    testing.show_benchmarks(results)
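
The kernel_course.testing helpers (get_impls, BenchmarkConfig, run_benchmarks, show_benchmarks) are internal to the repository and not part of this diff. Purely as an assumption, here is a minimal sketch of the timing loop their signatures suggest: fresh inputs from the factory, a few warmup calls, then averaged timed repeats converted to GFLOP/s. The real harness may differ.

# Hypothetical sketch of the loop behind run_benchmarks; NOT the
# repository's actual implementation.
import time
from dataclasses import dataclass
from typing import Any, Callable

import torch


@dataclass
class BenchmarkConfig:
    warmup: int = 3
    repeat: int = 1_000


def run_benchmark(
    fn: Callable[..., Any],
    make_inputs: Callable[[], tuple[tuple, dict]],
    flops: int,
    config: BenchmarkConfig,
) -> float:
    """Return achieved GFLOP/s for a single implementation."""
    args, kwargs = make_inputs()
    for _ in range(config.warmup):  # warm caches and trigger any JIT compilation
        fn(*args, **kwargs)
    if args[0].is_cuda:
        torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(config.repeat):
        fn(*args, **kwargs)
    if args[0].is_cuda:
        torch.cuda.synchronize()  # wait for the last kernel before stopping the clock
    seconds_per_call = (time.perf_counter() - start) / config.repeat
    return flops / seconds_per_call / 1e9

Whatever the harness does internally, the parametrization above expands to 2 devices x 3 dtypes x 3 sizes = 18 benchmark cases, each skipped automatically when its device is unavailable, and they run under plain pytest (pytest tests/test_dot.py).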
