import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=4, num_warps=8),
    ],
    key=["n_elements"],
    # The kernel accumulates into z_ptr with atomic_add, so the output must be
    # reset between autotuning runs to avoid double-counting
    reset_to_zero=["z_ptr"],
)
@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # There are multiple programs processing different blocks of data.
    # We identify which program we are in using program_id.
    pid = tl.program_id(axis=0)
    # This program will process inputs that are offset from the initial pointer.
    # For example, with a vector of size 256 and a BLOCK_SIZE of 64, the programs
    # would access the elements [0:64], [64:128], [128:192], [192:256].
    # Note that offsets is a list of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input
    # is not a multiple of BLOCK_SIZE; masked lanes read 0.0 so they do not
    # contribute to the reduction.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # Compute this block's partial dot product. tl.dot is a matrix multiply and
    # expects 2D operands, so for 1D vectors we multiply elementwise and reduce
    # with tl.sum instead.
    z = tl.sum(x * y, axis=0)
    # Accumulate the partial sum into the scalar output. atomic_add is required
    # because every program writes to the same location; it also makes the
    # floating-point accumulation order non-deterministic.
    tl.atomic_add(z_ptr, z)


def dot(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the dot product of two 1D tensors `x` and `y` using a Triton kernel.

    Args:
        x (torch.Tensor): First input tensor.
        y (torch.Tensor): Second input tensor.

    Returns:
        torch.Tensor: The dot product of `x` and `y` as a 1-element tensor.
    """

    # Calculate the number of elements in the input tensors
    n_elements = x.numel()

    # Allocate the scalar output. It must be zero-initialized because the
    # kernel accumulates partial sums into it with atomic_add.
    z = torch.zeros(1, device=x.device, dtype=x.dtype)

    # The SPMD launch grid denotes the number of kernel instances that run in
    # parallel. It is analogous to CUDA launch grids, and can be either
    # Tuple[int] or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid whose size is the number of blocks needed
    # to cover all elements.

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    dot_kernel[grid](x, y, z, n_elements)

    return z
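

# A minimal usage sketch, assuming a CUDA device is available; the vector size
# and tolerances below are illustrative assumptions, not part of the module above.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(98432, device="cuda")
    y = torch.rand(98432, device="cuda")
    z_triton = dot(x, y)
    z_torch = torch.dot(x, y)
    # atomic_add makes the floating-point accumulation order non-deterministic,
    # so compare with a small tolerance rather than exact equality
    assert torch.allclose(z_triton, z_torch.reshape(1), rtol=1e-4, atol=1e-4)
    print(f"triton: {z_triton.item():.6f}  torch: {z_torch.item():.6f}")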