Commit 93aeb57

Merge pull request #48 from flash-algo/add-dot-triton-kernel
[PERFORMANCE OPTIMIZATION] add dot triton kernel
2 parents 3ac79a2 + e606b30 commit 93aeb57

2 files changed: +71 -1 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [swap](./docs/swap.md) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [](./kernel_course/python_ops/swap.py) | [](./kernel_course/pytorch_ops/swap.py) | [](./kernel_course/triton_ops/swap.py) || [](./tests/test_swap.py) |
| [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [](./kernel_course/python_ops/scal.py) | [](./kernel_course/pytorch_ops/scal.py) | [](./kernel_course/triton_ops/scal.py) || [](./tests/test_scal.py) |
| [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [](./kernel_course/python_ops/axpby.py) | [](./kernel_course/pytorch_ops/axpby.py) | [](./kernel_course/triton_ops/axpby.py) || [](./tests/test_axpby.py) |
-| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | |||
+| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | [](./kernel_course/triton_ops/dot.py) |||
| gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ ||||||
| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ ||||||
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ ||||||
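
As a quick sanity check on the $2n$ / $2n$ figures in the new `dot` row: counting one multiply and one add per element, and one read of each input vector, gives

$$
\text{flops} = n\ \text{multiplies} + (n-1)\ \text{adds} \approx 2n,
\qquad
\text{memory} = n\ (\text{read } x) + n\ (\text{read } y) = 2n.
$$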

kernel_course/triton_ops/dot.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=4, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # There are multiple programs, each processing a different block of data.
    # We identify which program we are in using program_id.
    pid = tl.program_id(axis=0)
    # This program processes inputs that are offset from the initial pointer.
    # For example, for a vector of size 256 and a block size of 64, the programs
    # would access the elements [0:64], [64:128], [128:192], and [192:256].
    # Note that `offsets` is a block of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is
    # not a multiple of the block size; masked-out lanes contribute 0 to the sum.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # Compute this program's partial dot product over its block.
    partial = tl.sum(x * y, axis=0)
    # Accumulate the partial sum into the single-element output in DRAM.
    tl.atomic_add(z_ptr, partial)


def dot(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the dot product of two tensors `x` and `y` using a Triton kernel.

    Args:
        x (torch.Tensor): First tensor.
        y (torch.Tensor): Second tensor.

    Returns:
        torch.Tensor: A one-element tensor containing the dot product of `x` and `y`.
    """

    # Calculate the number of elements in the input tensors
    n_elements = x.numel()

    # Allocate the output tensor, zero-initialized because each program
    # atomically accumulates its partial sum into it
    z = torch.zeros(1, device=x.device, dtype=x.dtype)

    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids and can be either Tuple[int] or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid whose size is the number of blocks needed to cover all elements.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    dot_kernel[grid](x, y, z, n_elements)

    return z
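
As a quick usage sketch (not part of this commit; the import path assumes the repository root is on `PYTHONPATH`, and the sizes and tolerances below are arbitrary), the wrapper can be checked against `torch.dot` on a CUDA device:

import torch

from kernel_course.triton_ops.dot import dot

x = torch.randn(1_000_000, device="cuda", dtype=torch.float32)
y = torch.randn(1_000_000, device="cuda", dtype=torch.float32)

z_triton = dot(x, y)       # one-element tensor
z_torch = torch.dot(x, y)  # zero-dim tensor

# Atomic accumulation order is non-deterministic, so allow a small tolerance.
assert torch.allclose(z_triton.squeeze(), z_torch, rtol=1e-4, atol=1e-4)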
