
Commit b869797

Merge pull request #52 from flash-algo/add-gemv-pytorch-kernel
[FEATURE SUPPORT] add gemv pytorch kernel
2 parents 52d0dd9 + e9824d0 commit b869797

File tree

3 files changed: +30 −2 lines changed
README.md

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
  | [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [](./kernel_course/python_ops/scal.py) | [](./kernel_course/pytorch_ops/scal.py) | [](./kernel_course/triton_ops/scal.py) || [](./tests/test_scal.py) |
  | [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [](./kernel_course/python_ops/axpby.py) | [](./kernel_course/pytorch_ops/axpby.py) | [](./kernel_course/triton_ops/axpby.py) || [](./tests/test_axpby.py) |
  | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [](./kernel_course/python_ops/dot.py) | [](./kernel_course/pytorch_ops/dot.py) | [](./kernel_course/triton_ops/dot.py) || [](./tests/test_dot.py) |
- | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [](./kernel_course/python_ops/gemv.py) | ||||
+ | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [](./kernel_course/python_ops/gemv.py) | [](./kernel_course/pytorch_ops/gemv.py) ||||
  | geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ ||||||
  | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ ||||||
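As a rough sanity check on the counts in the new row (not part of the diff itself): for $A \in \mathbb{R}^{m \times n}$, forming $Ax$ takes $mn$ multiplications and about $mn$ additions, so roughly $2mn$ FLOPs dominate; the data touched is the $mn$ elements of $A$, the $n$ elements of $x$, and $2m$ elements for reading and writing $y$, which matches the $mn + n + 2m$ memory estimate.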

kernel_course/python_ops/gemv.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def gemv(
          y (torch.Tensor): Vector tensor to be updated.
          alpha (float): Scaling factor for the product of `A` and `x`.
          beta (float): Scaling factor for `y`.
-
+
      Returns:
          torch.Tensor: The updated tensor `y`.
      """

kernel_course/pytorch_ops/gemv.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
import torch


def gemv(
    A: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float,
    beta: float,
) -> torch.Tensor:
    """
    Updates tensor `y` by adding the product of matrix `A` and vector `x`
    scaled by `alpha`, and `y` scaled by `beta` using PyTorch operations.

    Args:
        A (torch.Tensor): Matrix tensor.
        x (torch.Tensor): Vector tensor to be multiplied with `A`.
        y (torch.Tensor): Vector tensor to be updated.
        alpha (float): Scaling factor for the product of `A` and `x`.
        beta (float): Scaling factor for `y`.

    Returns:
        torch.Tensor: The updated tensor `y`.
    """

    y = torch.add(torch.mul(torch.matmul(A, x), alpha), torch.mul(y, beta))

    return y
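A minimal usage sketch for the new kernel (not part of this commit), assuming the repository root is on the Python path so that `kernel_course` is importable; it checks the result against `torch.addmv`, which computes $\beta y + \alpha A x$:

import torch

# Hypothetical import path, assuming `kernel_course` is importable from the repo root.
from kernel_course.pytorch_ops.gemv import gemv

m, n = 128, 64
A = torch.randn(m, n)
x = torch.randn(n)
y = torch.randn(m)
alpha, beta = 2.0, 0.5

# torch.addmv computes beta * y + alpha * (A @ x), the same update gemv performs.
out = gemv(A, x, y, alpha, beta)
ref = torch.addmv(y, A, x, beta=beta, alpha=alpha)
torch.testing.assert_close(out, ref)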

0 commit comments