Skip to content

Commit 44b8b58

Browse files
committed
Adds Triton dot-product kernel
Introduces an autotuned Triton kernel to compute vector dot products on GPU, enabling efficient parallel execution through configurable block sizes and SPMD grid
1 parent 3ac79a2 commit 44b8b58

File tree

1 file changed

+70
-0
lines changed
  • kernel_course/triton_ops

1 file changed

+70
-0
lines changed

kernel_course/triton_ops/dot.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
6+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=4, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce one BLOCK_SIZE chunk of x*y and atomically accumulate the
    partial sum into the scalar at ``z_ptr``.

    The caller must zero-initialize the output before launch, since every
    program adds its partial result to it. Requires a dtype with atomic
    support on the target GPU (e.g. float32).
    """
    # There are multiple programs processing different blocks of data;
    # we identify which program we are in using program_id.
    pid = tl.program_id(axis=0)
    # This program processes elements [block_start : block_start + BLOCK_SIZE).
    # For example, with 256 elements and BLOCK_SIZE=64, the programs access
    # [0:64], [64:128], [128:192], [192:256]. Note that `offsets` is a
    # block of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask to guard memory operations against out-of-bounds accesses when
    # n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements
    # BUG FIX: masked-out lanes must read 0.0 (`other=0.0`); without it they
    # hold undefined values that would corrupt the sum-reduction below.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # BUG FIX: tl.dot is a matrix multiply over 2D blocks and is invalid for
    # these 1D vectors. A vector dot product is an elementwise multiply
    # followed by a sum-reduction.
    partial = tl.sum(x * y, axis=0)
    # BUG FIX: the output is a single scalar, so storing a whole block at
    # z_ptr + offsets wrote out of bounds. Atomically accumulate this
    # program's partial sum into the scalar result instead.
    tl.atomic_add(z_ptr, partial)
38+
39+
40+
def dot(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the dot product of two tensors `x` and `y` using a Triton kernel.

    Args:
        x (torch.Tensor): First tensor.
        y (torch.Tensor): Second tensor, with the same number of elements as `x`.

    Returns:
        torch.Tensor: A 1-element tensor holding the dot product of `x` and `y`.

    Raises:
        ValueError: If `x` and `y` have different numbers of elements.
    """

    # Calculate the number of elements in the input tensors.
    n_elements = x.numel()
    if y.numel() != n_elements:
        raise ValueError(
            f"x and y must have the same number of elements, "
            f"got {n_elements} and {y.numel()}"
        )

    # BUG FIX: the kernel accumulates partial sums into the output with
    # tl.atomic_add, so it must start at zero — torch.empty leaves
    # uninitialized garbage that would be folded into the result.
    z = torch.zeros(1, device=x.device, dtype=x.dtype)

    # The SPMD launch grid denotes the number of kernel instances that run
    # in parallel. It is analogous to CUDA launch grids and can be either
    # Tuple[int] or Callable(metaparameters) -> Tuple[int].
    # Here we use a 1D grid sized to the number of blocks needed to cover
    # all elements; BLOCK_SIZE is chosen by the autotuner, hence the callable.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    dot_kernel[grid](x, y, z, n_elements)

    return z

0 commit comments

Comments
 (0)