Skip to content

Commit 3ac79a2

Browse files
authored
Merge pull request #47 from flash-algo/add-dot-pytorch-kernel
[PERFORMANCE OPTIMIZATION] add dot pytorch kernel
2 parents 6c31b2b + 5d0bd04 commit 3ac79a2

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
2121
| [swap](./docs/swap.md) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [](./kernel_course/python_ops/swap.py) | [](./kernel_course/pytorch_ops/swap.py) | [](./kernel_course/triton_ops/swap.py) || [](./tests/test_swap.py) |
2222
| [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [](./kernel_course/python_ops/scal.py) | [](./kernel_course/pytorch_ops/scal.py) | [](./kernel_course/triton_ops/scal.py) || [](./tests/test_scal.py) |
2323
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [](./kernel_course/python_ops/axpby.py) | [](./kernel_course/pytorch_ops/axpby.py) | [](./kernel_course/triton_ops/axpby.py) || [](./tests/test_axpby.py) |
24-
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | ||||
24+
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) ||||
2525
| gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ ||||||
2626
| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ ||||||
2727
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ ||||||

kernel_course/pytorch_ops/dot.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
3+
4+
def dot(
5+
x: torch.Tensor,
6+
y: torch.Tensor,
7+
) -> torch.Tensor:
8+
"""
9+
Computes the dot product of two tensors using PyTorch operations.
10+
11+
Args:
12+
x (torch.Tensor): First tensor.
13+
y (torch.Tensor): Second tensor.
14+
15+
Returns:
16+
torch.Tensor: The dot product of `x` and `y`.
17+
"""
18+
19+
z = torch.sum(torch.mul(x, y))
20+
21+
return z

0 commit comments

Comments
 (0)