File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed
kernel_course/pytorch_ops Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change @@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
2121| [ swap] ( ./docs/swap.md ) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [ ✅] ( ./kernel_course/python_ops/swap.py ) | [ ✅] ( ./kernel_course/pytorch_ops/swap.py ) | [ ✅] ( ./kernel_course/triton_ops/swap.py ) | ❌ | [ ✅] ( ./tests/test_swap.py ) |
2222| [ scal] ( ./docs/scal.md ) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [ ✅] ( ./kernel_course/python_ops/scal.py ) | [ ✅] ( ./kernel_course/pytorch_ops/scal.py ) | [ ✅] ( ./kernel_course/triton_ops/scal.py ) | ❌ | [ ✅] ( ./tests/test_scal.py ) |
2323| [ axpby] ( ./docs/axpby.md ) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [ ✅] ( ./kernel_course/python_ops/axpby.py ) | [ ✅] ( ./kernel_course/pytorch_ops/axpby.py ) | [ ✅] ( ./kernel_course/triton_ops/axpby.py ) | ❌ | [ ✅] ( ./tests/test_axpby.py ) |
24- | [ dot] ( ./docs/dot.md ) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [ ✅] ( ./kernel_course/python_ops/dot.py ) | ❌ | ❌ | ❌ | ❌ |
24+ | [ dot] ( ./docs/dot.md ) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [ ✅] ( ./kernel_course/python_ops/dot.py ) | [ ✅ ] ( ./kernel_course/pytorch_ops/dot.py ) | ❌ | ❌ | ❌ |
2525| gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | ❌ | ❌ | ❌ | ❌ | ❌ |
2626| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ |
2727| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |
Original file line number Diff line number Diff line change 1+ import torch
2+
3+
4+ def dot (
5+ x : torch .Tensor ,
6+ y : torch .Tensor ,
7+ ) -> torch .Tensor :
8+ """
9+ Computes the dot product of two tensors using PyTorch operations.
10+
11+ Args:
12+ x (torch.Tensor): First tensor.
13+ y (torch.Tensor): Second tensor.
14+
15+ Returns:
16+ torch.Tensor: The dot product of `x` and `y`.
17+ """
18+
19+ z = torch .sum (torch .mul (x , y ))
20+
21+ return z
You can’t perform that action at this time.
0 commit comments