From 0bf2375dddbf127a076d8356c192cb964915a590 Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Thu, 22 Jan 2026 16:23:55 +0800
Subject: [PATCH 1/6] add gemv

---
 top/kernels/gemm/__init__.py |   4 +-
 top/kernels/gemm/gemm.py     | 135 +++++++++++++++++++++++++++++++++++
 top/ops/gemm.py              |  11 +--
 3 files changed, 144 insertions(+), 6 deletions(-)

diff --git a/top/kernels/gemm/__init__.py b/top/kernels/gemm/__init__.py
index 16eec2d..ee84fd5 100644
--- a/top/kernels/gemm/__init__.py
+++ b/top/kernels/gemm/__init__.py
@@ -1,3 +1,3 @@
-from .gemm import GemmKernel
+from .gemm import GemmKernel, GemvKernel
 
-__all__ = ["GemmKernel"]
+__all__ = ["GemmKernel", "GemvKernel"]
diff --git a/top/kernels/gemm/gemm.py b/top/kernels/gemm/gemm.py
index 1b237d9..d40a098 100644
--- a/top/kernels/gemm/gemm.py
+++ b/top/kernels/gemm/gemm.py
@@ -189,4 +189,139 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                                     self.config["threads"], self.config["enable_rasteration"], a, b)
 
 
+def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
+    accum_dtype = "float"
+
+    @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
+    def _gemv_func(
+        block_n: int,
+        reduce_threads: int,
+        tile_k: int = 8,
+    ) -> Callable:
+
+        # MAX_TRANSACTION_SIZE_IN_BITS = 128
+        # tile_k = MAX_TRANSACTION_SIZE_IN_BITS // 16
+        block_k = reduce_threads * tile_k
+
+        @T.prim_func
+        def _gemv_main(
+                a: T.Buffer((k,), dtype),
+                b: T.Buffer((n, k), dtype),
+                c: T.Buffer((n,), dtype),
+        ):
+            with T.Kernel(T.ceildiv(n, block_n), threads=(block_n, reduce_threads)) as bn:
+                tn = T.get_thread_binding(0)
+                tk = T.get_thread_binding(1)
+                a_local = T.alloc_local((tile_k,), dtype)
+                b_local = T.alloc_local((tile_k,), dtype)
+                c_accum = T.alloc_local((1,), accum_dtype)
+
+                T.clear(c_accum)
+                for bk in T.serial(T.ceildiv(k, block_k)):
+                    for _k in T.vectorized(tile_k):
+                        a_local[_k] = a[bk * block_k + tk * tile_k + _k]
+                        b_local[_k] = b[bn * block_n + tn, bk * block_k + tk * tile_k + _k]
+                    for _k in T.serial(tile_k):
+                        c_accum[0] += a_local[_k].astype(accum_dtype) * b_local[_k].astype(
+                            accum_dtype)
+                c_reduced = T.alloc_local((1,), accum_dtype)
+                with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                ):
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            c_accum[0],
+                            True,
+                            c_reduced[0],
+                            tk,
+                            dtype="handle",
+                        ))
+
+                c[bn * block_n + tn] = c_reduced[0]
+
+        return _gemv_main
+
+    return _gemv_func
+
+
+@torch.library.custom_op("top::gemv_wrapped_kernel", mutates_args=())
+def _gemv_wrapped_kernel(
+    n: int,
+    k: int,
+    dtype: str,
+    block_n: int,
+    reduce_threads: int,
+    tile_k: int,
+    a: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads, tile_k)(a, b)
+
+
+@_gemv_wrapped_kernel.register_fake
+def _(n: int, k: int,  # noqa: U100
+      dtype: str, block_n: int, reduce_threads: int, tile_k: int,  # noqa: U100
+      *inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:  # noqa: U100
+    return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)
+
+
+class GemvKernel(Kernel):
+    supported_archs: list[int] = [90]
+
+    def __init__(self,
+                 n: int,
+                 k: int,
+                 dtype: torch.dtype,
+                 config: Optional[dict] = None,
+                 tune: bool = False) -> None:
+        super().__init__()
+        self.n = n
+        self.k = k
+        self.dtype = dtype
+
+        self.kernel = _gemv_kernel(n, k, self.dtype_str)
+
+        self.init_config(config, tune)
+
+    @property
+    def default_config(self) -> dict:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        sm_version = get_sm_version()
+
+        if sm_version in {90}:
+            return {
+                "block_n": 32,
+                "reduce_threads": 32,
+                "tile_k": 8,
+            }
+
+        return {
+            "block_n": 128,
+            "reduce_threads": 32,
+            "tile_k": 8,
+        }
+
+    @property
+    def autotune_configs(self) -> list[dict]:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        block_n = [64, 128, 256]
+        reduce_threads = [16, 32]
+        tile_k = [8, 16]
+        _configs = list(itertools.product(block_n, reduce_threads, tile_k))
+
+        return [{
+            'block_n': c[0],
+            'reduce_threads': c[1],
+            'tile_k': c[2],
+        } for c in _configs]
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        a = a.flatten().contiguous()
+        return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
+                                    self.config["reduce_threads"], self.config["tile_k"], a, b)
+
+
+# TODO: add persistent, split-k, stream-k...
diff --git a/top/ops/gemm.py b/top/ops/gemm.py
index d654be6..72727ed 100644
--- a/top/ops/gemm.py
+++ b/top/ops/gemm.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from top.kernels.gemm import GemmKernel
+from top.kernels.gemm import GemmKernel, GemvKernel
 from top.kernels.kernel import Kernel
 
 from .op import Op
@@ -28,12 +28,15 @@ def __init__(self,
         self.dtype = dtype
 
         self.dispatch_kernel(kernel_map)
-        self.kernel = self.kernel_map["gemm_kernel"](
-            m, n, k, self.dtype, tune=tune, trans_a=trans_a, trans_b=trans_b)
+        if m == 1:
+            self.kernel = self.kernel_map["gemv_kernel"](n, k, self.dtype, tune=tune)
+        else:
+            self.kernel = self.kernel_map["gemm_kernel"](
+                m, n, k, self.dtype, tune=tune, trans_a=trans_a, trans_b=trans_b)
 
     @property
     def default_kernel_map(self) -> Dict[str, Kernel]:
-        return {"gemm_kernel": GemmKernel}
+        return {"gemm_kernel": GemmKernel, "gemv_kernel": GemvKernel}
 
     def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         return self.kernel(a, b)

From a4eb656f345808656e0d22e02d96b7882ed32a6e Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Tue, 10 Feb 2026 20:45:19 +0800
Subject: [PATCH 2/6] fix bug in ops

---
 benchmarks/__init__.py       |   2 +
 top/kernels/gemm/__init__.py |   4 +-
 top/kernels/gemm/gemm.py     | 135 -----------------------------------
 top/ops/__init__.py          |   2 +
 top/ops/gemm.py              |  11 ++-
 5 files changed, 10 insertions(+), 144 deletions(-)

diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
index 07c26f9..c7ecad7 100644
--- a/benchmarks/__init__.py
+++ b/benchmarks/__init__.py
@@ -19,6 +19,7 @@
     MultiHeadAttentionDecodeBenchmark,
 )
 from .gemm import GemmBenchmark, MatMulBenchmark
+from .gemv import GemvBenchmark
 from .grouped_gemm import (
     GroupedGemmBenchmark,
     GroupedGemmNNBenchmark,
@@ -36,6 +37,7 @@
     "GroupQueryAttentionFwdBenchmark",
     "GroupQueryAttentionBwdBenchmark",
     "GemmBenchmark",
+    "GemvBenchmark",
     "MultiHeadAttentionDecodeBenchmark",
     "GroupQueryAttentionDecodeBenchmark",
     "MultiHeadLatentAttentionDecodeBenchmark",
diff --git a/top/kernels/gemm/__init__.py b/top/kernels/gemm/__init__.py
index ee84fd5..16eec2d 100644
--- a/top/kernels/gemm/__init__.py
+++ b/top/kernels/gemm/__init__.py
@@ -1,3 +1,3 @@
-from .gemm import GemmKernel, GemvKernel
+from .gemm import GemmKernel
 
-__all__ = ["GemmKernel", "GemvKernel"]
+__all__ = ["GemmKernel"]
diff --git a/top/kernels/gemm/gemm.py b/top/kernels/gemm/gemm.py
index d40a098..1b237d9 100644
--- a/top/kernels/gemm/gemm.py
+++ b/top/kernels/gemm/gemm.py
@@ -189,139 +189,4 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                                     self.config["threads"], self.config["enable_rasteration"], a, b)
 
 
-def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
-    accum_dtype = "float"
-
-    @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gemv_func(
-        block_n: int,
-        reduce_threads: int,
-        tile_k: int = 8,
-    ) -> Callable:
-
-        # MAX_TRANSACTION_SIZE_IN_BITS = 128
-        # tile_k = MAX_TRANSACTION_SIZE_IN_BITS // 16
-        block_k = reduce_threads * tile_k
-
-        @T.prim_func
-        def _gemv_main(
-                a: T.Buffer((k,), dtype),
-                b: T.Buffer((n, k), dtype),
-                c: T.Buffer((n,), dtype),
-        ):
-            with T.Kernel(T.ceildiv(n, block_n), threads=(block_n, reduce_threads)) as bn:
-                tn = T.get_thread_binding(0)
-                tk = T.get_thread_binding(1)
-                a_local = T.alloc_local((tile_k,), dtype)
-                b_local = T.alloc_local((tile_k,), dtype)
-                c_accum = T.alloc_local((1,), accum_dtype)
-
-                T.clear(c_accum)
-                for bk in T.serial(T.ceildiv(k, block_k)):
-                    for _k in T.vectorized(tile_k):
-                        a_local[_k] = a[bk * block_k + tk * tile_k + _k]
-                        b_local[_k] = b[bn * block_n + tn, bk * block_k + tk * tile_k + _k]
-                    for _k in T.serial(tile_k):
-                        c_accum[0] += a_local[_k].astype(accum_dtype) * b_local[_k].astype(
-                            accum_dtype)
-                c_reduced = T.alloc_local((1,), accum_dtype)
-                with T.attr(
-                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
-                        "reduce_scope",
-                        T.reinterpret(T.uint64(0), dtype="handle"),
-                ):
-                    T.evaluate(
-                        T.tvm_thread_allreduce(
-                            T.uint32(1),
-                            c_accum[0],
-                            True,
-                            c_reduced[0],
-                            tk,
-                            dtype="handle",
-                        ))
-
-                c[bn * block_n + tn] = c_reduced[0]
-
-        return _gemv_main
-
-    return _gemv_func
-
-
-@torch.library.custom_op("top::gemv_wrapped_kernel", mutates_args=())
-def _gemv_wrapped_kernel(
-    n: int,
-    k: int,
-    dtype: str,
-    block_n: int,
-    reduce_threads: int,
-    tile_k: int,
-    a: torch.Tensor,
-    b: torch.Tensor,
-) -> torch.Tensor:
-    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads, tile_k)(a, b)
-
-
-@_gemv_wrapped_kernel.register_fake
-def _(n: int, k: int,  # noqa: U100
-      dtype: str, block_n: int, reduce_threads: int, tile_k: int,  # noqa: U100
-      *inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:  # noqa: U100
-    return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)
-
-
-class GemvKernel(Kernel):
-    supported_archs: list[int] = [90]
-
-    def __init__(self,
-                 n: int,
-                 k: int,
-                 dtype: torch.dtype,
-                 config: Optional[dict] = None,
-                 tune: bool = False) -> None:
-        super().__init__()
-        self.n = n
-        self.k = k
-        self.dtype = dtype
-
-        self.kernel = _gemv_kernel(n, k, self.dtype_str)
-
-        self.init_config(config, tune)
-
-    @property
-    def default_config(self) -> dict:
-        # From tilelang/examples/gemm/example_gemm_autotune.py
-        sm_version = get_sm_version()
-
-        if sm_version in {90}:
-            return {
-                "block_n": 32,
-                "reduce_threads": 32,
-                "tile_k": 8,
-            }
-
-        return {
-            "block_n": 128,
-            "reduce_threads": 32,
-            "tile_k": 8,
-        }
-
-    @property
-    def autotune_configs(self) -> list[dict]:
-        # From tilelang/examples/gemm/example_gemm_autotune.py
-        block_n = [64, 128, 256]
-        reduce_threads = [16, 32]
-        tile_k = [8, 16]
-        _configs = list(itertools.product(block_n, reduce_threads, tile_k))
-
-        return [{
-            'block_n': c[0],
-            'reduce_threads': c[1],
-            'tile_k': c[2],
-        } for c in _configs]
-
-    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-        a = a.flatten().contiguous()
-        return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
-                                    self.config["reduce_threads"], self.config["tile_k"], a, b)
-
-
-# TODO: add persistent, split-k, stream-k...
diff --git a/top/ops/__init__.py b/top/ops/__init__.py
index 6260df2..d963a62 100644
--- a/top/ops/__init__.py
+++ b/top/ops/__init__.py
@@ -5,6 +5,7 @@
 from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp
 from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp
 from .gemm import GemmOp
+from .gemv import GemvOp
 from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp
 from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp
 from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp
@@ -23,6 +24,7 @@
     "GroupQueryAttentionFwdOp",
     "GroupQueryAttentionBwdOp",
     "GemmOp",
+    "GemvOp",
     "MultiHeadAttentionDecodeWithKVCacheOp",
     "MultiHeadAttentionDecodePagedWithKVCacheOp",
     "GroupQueryAttentionDecodeWithKVCacheOp",
diff --git a/top/ops/gemm.py b/top/ops/gemm.py
index 72727ed..d654be6 100644
--- a/top/ops/gemm.py
+++ b/top/ops/gemm.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from top.kernels.gemm import GemmKernel, GemvKernel
+from top.kernels.gemm import GemmKernel
 from top.kernels.kernel import Kernel
 
 from .op import Op
@@ -28,15 +28,12 @@ def __init__(self,
         self.dtype = dtype
 
         self.dispatch_kernel(kernel_map)
-        if m == 1:
-            self.kernel = self.kernel_map["gemv_kernel"](n, k, self.dtype, tune=tune)
-        else:
-            self.kernel = self.kernel_map["gemm_kernel"](
-                m, n, k, self.dtype, tune=tune, trans_a=trans_a, trans_b=trans_b)
+        self.kernel = self.kernel_map["gemm_kernel"](
+            m, n, k, self.dtype, tune=tune, trans_a=trans_a, trans_b=trans_b)
 
     @property
     def default_kernel_map(self) -> Dict[str, Kernel]:
-        return {"gemm_kernel": GemmKernel, "gemv_kernel": GemvKernel}
+        return {"gemm_kernel": GemmKernel}
 
     def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         return self.kernel(a, b)

From 43bc32144a08b1bfdce5543c31216641205370bd Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Wed, 11 Feb 2026 11:16:10 +0800
Subject: [PATCH 3/6] add gemv

---
 benchmarks/gemv/__init__.py  |   3 +
 benchmarks/gemv/gemv.py      |  39 +++++++++
 tests/ops/test_gemv.py       |  28 +++++++
 top/kernels/gemv/__init__.py |   3 +
 top/kernels/gemv/gemv.py     | 148 +++++++++++++++++++++++++++++++++++
 top/ops/gemv.py              |  34 ++++++++
 6 files changed, 255 insertions(+)
 create mode 100644 benchmarks/gemv/__init__.py
 create mode 100644 benchmarks/gemv/gemv.py
 create mode 100644 tests/ops/test_gemv.py
 create mode 100644 top/kernels/gemv/__init__.py
 create mode 100644 top/kernels/gemv/gemv.py
 create mode 100644 top/ops/gemv.py

diff --git a/benchmarks/gemv/__init__.py b/benchmarks/gemv/__init__.py
new file mode 100644
index 0000000..c0a5753
--- /dev/null
+++ b/benchmarks/gemv/__init__.py
@@ -0,0 +1,3 @@
+from .gemv import GemvBenchmark
+
+__all__ = ["GemvBenchmark"]
diff --git a/benchmarks/gemv/gemv.py b/benchmarks/gemv/gemv.py
new file mode 100644
index 0000000..6537bb0
--- /dev/null
+++ b/benchmarks/gemv/gemv.py
@@ -0,0 +1,39 @@
+from typing import Tuple
+
+import torch
+
+from benchmarks.benchmark import Benchmark
+from top.ops import GemvOp
+
+
+class GemvBenchmark(Benchmark):
+
+    op_type = GemvOp
+
+    def __init__(self, n: int, k: int, dtype: torch.dtype):
+        self.n = n
+        self.k = k
+        self.dtype = dtype
+
+    @property
+    def total_flops(self) -> float:
+        return 2.0 * self.n * self.k
+
+    @property
+    def total_memory(self) -> int:
+        return (self.k + self.k * self.n + self.n) * self.dtype.itemsize
+
+    def gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        shape_a = (self.k,)
+        a = torch.randn(*shape_a, device='cuda', dtype=self.dtype)
+        shape_b = (self.n, self.k)
+        b = torch.randn(*shape_b, device='cuda', dtype=self.dtype)
+        return a, b
+
+    def ref_program(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        # return torch.mv(b, a)
+        return b @ a
+
+    def baseline_profile(self, *inputs, warmup=100, rep=10, device="cuda:0") -> None:
+        return super().baseline_profile(
+            self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device)
diff --git a/tests/ops/test_gemv.py b/tests/ops/test_gemv.py
new file mode 100644
index 0000000..7faf8d3
--- /dev/null
+++ b/tests/ops/test_gemv.py
@@ -0,0 +1,28 @@
+import torch
+import pytest
+
+from benchmarks import GemvBenchmark
+from top.ops import GemvOp
+
+
+@pytest.mark.parametrize(
+    "n, k, dtype, tune",
+    [
+        (1024, 1024, torch.float16, False),
+        (64, 512, torch.float16, False),
+        (64, 64, torch.float16, False),
+    ],
+)
+def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
+    op = GemvOp(n, k, dtype=dtype, tune=tune)
+    benchmark = GemvBenchmark(n, k, dtype)
+
+    inputs = benchmark.gen_inputs()
+
+    benchmark.check(op, *inputs, atol=1e-5, rtol=1e-3)
+    benchmark.profile(op, *inputs)
+
+
+if __name__ == "__main__":
+    # Run tests with pytest
+    pytest.main([__file__, "-vvs"])
diff --git a/top/kernels/gemv/__init__.py b/top/kernels/gemv/__init__.py
new file mode 100644
index 0000000..658317b
--- /dev/null
+++ b/top/kernels/gemv/__init__.py
@@ -0,0 +1,3 @@
+from .gemv import GemvKernel
+
+__all__ = ["GemvKernel"]
diff --git a/top/kernels/gemv/gemv.py b/top/kernels/gemv/gemv.py
new file mode 100644
index 0000000..4e1156a
--- /dev/null
+++ b/top/kernels/gemv/gemv.py
@@ -0,0 +1,148 @@
+import itertools
+from typing import Callable, Optional
+
+import tilelang
+import tilelang.language as T
+import torch
+
+from top.kernels.kernel import Kernel
+from top.utils import get_sm_version
+
+__all__ = [
+    'GemvKernel',
+]
+
+
+def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
+    accum_dtype = "float"
+
+    @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
+    def _gemv_func(
+        block_n: int,
+        reduce_threads: int,
+        tile_k: int = 8,
+    ) -> Callable:
+
+        # MAX_TRANSACTION_SIZE_IN_BITS = 128
+        # tile_k = MAX_TRANSACTION_SIZE_IN_BITS // 16
+        block_k = reduce_threads * tile_k
+
+        @T.prim_func
+        def _gemv_main(
+                a: T.Buffer((k,), dtype),
+                b: T.Buffer((n, k), dtype),
+                c: T.Buffer((n,), dtype),
+        ):
+            with T.Kernel(T.ceildiv(n, block_n), threads=(block_n, reduce_threads)) as bn:
+                tn = T.get_thread_binding(0)
+                tk = T.get_thread_binding(1)
+                a_local = T.alloc_local((tile_k,), dtype)
+                b_local = T.alloc_local((tile_k,), dtype)
+                c_accum = T.alloc_local((1,), accum_dtype)
+
+                T.clear(c_accum)
+                for bk in T.serial(T.ceildiv(k, block_k)):
+                    for _k in T.vectorized(tile_k):
+                        a_local[_k] = a[bk * block_k + tk * tile_k + _k]
+                        b_local[_k] = b[bn * block_n + tn, bk * block_k + tk * tile_k + _k]
+                    for _k in T.serial(tile_k):
+                        c_accum[0] += a_local[_k].astype(accum_dtype) * b_local[_k].astype(
+                            accum_dtype)
+                c_reduced = T.alloc_local((1,), accum_dtype)
+                with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                ):
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            c_accum[0],
+                            True,
+                            c_reduced[0],
+                            tk,
+                            dtype="handle",
+                        ))
+
+                c[bn * block_n + tn] = c_reduced[0]
+
+        return _gemv_main
+
+    return _gemv_func
+
+
+@torch.library.custom_op("top::gemv_wrapped_kernel", mutates_args=())
+def _gemv_wrapped_kernel(
+    n: int,
+    k: int,
+    dtype: str,
+    block_n: int,
+    reduce_threads: int,
+    tile_k: int,
+    a: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads, tile_k)(a, b)
+
+
+@_gemv_wrapped_kernel.register_fake
+def _(n: int, k: int,  # noqa: U100
+      dtype: str, block_n: int, reduce_threads: int, tile_k: int,  # noqa: U100
+      *inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:  # noqa: U100
+    return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)
+
+
+class GemvKernel(Kernel):
+    supported_archs: list[int] = [90]
+
+    def __init__(self,
+                 n: int,
+                 k: int,
+                 dtype: torch.dtype,
+                 config: Optional[dict] = None,
+                 tune: bool = False) -> None:
+        super().__init__()
+        self.n = n
+        self.k = k
+        self.dtype = dtype
+
+        self.kernel = _gemv_kernel(n, k, self.dtype_str)
+
+        self.init_config(config, tune)
+
+    @property
+    def default_config(self) -> dict:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        sm_version = get_sm_version()
+
+        if sm_version in {90}:
+            return {
+                "block_n": 32,
+                "reduce_threads": 1,
+                "tile_k": 8,
+            }
+
+        return {
+            "block_n": 128,
+            "reduce_threads": 32,
+            "tile_k": 8,
+        }
+
+    @property
+    def autotune_configs(self) -> list[dict]:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        block_n = [64, 128, 256]
+        reduce_threads = [16, 32]
+        tile_k = [8, 16]
+        _configs = list(itertools.product(block_n, reduce_threads, tile_k))
+
+        return [{
+            'block_n': c[0],
+            'reduce_threads': c[1],
+            'tile_k': c[2],
+        } for c in _configs]
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        a = a.flatten().contiguous()
+        return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
+                                    self.config["reduce_threads"], self.config["tile_k"], a, b)
diff --git a/top/ops/gemv.py b/top/ops/gemv.py
new file mode 100644
index 0000000..15277af
--- /dev/null
+++ b/top/ops/gemv.py
@@ -0,0 +1,34 @@
+from typing import Dict, Optional
+
+import torch
+
+from top.kernels.gemv import GemvKernel
+from top.kernels.kernel import Kernel
+
+from .op import Op
+
+__all__ = ['GemvOp']
+
+
+class GemvOp(Op):
+
+    def __init__(self,
+                 n: int,
+                 k: int,
+                 dtype: torch.dtype = torch.float16,
+                 kernel_map: Optional[Dict[str, Kernel]] = None,
+                 tune: bool = False) -> None:
+        self.N = n
+        self.K = k
+
+        self.dtype = dtype
+
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map["gemv_kernel"](n, k, self.dtype, tune=tune)
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {"gemv_kernel": GemvKernel}
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        return self.kernel(a, b)

From 2278a47317c42ba2590b1e73d0bbfec1e3545290 Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Wed, 11 Feb 2026 11:42:11 +0800
Subject: [PATCH 4/6] fix gemv test

---
 tests/ops/test_gemv.py   | 6 +++---
 top/kernels/gemv/gemv.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/ops/test_gemv.py b/tests/ops/test_gemv.py
index 7faf8d3..86dfa31 100644
--- a/tests/ops/test_gemv.py
+++ b/tests/ops/test_gemv.py
@@ -9,8 +9,8 @@
     "n, k, dtype, tune",
     [
         (1024, 1024, torch.float16, False),
-        (64, 512, torch.float16, False),
-        (64, 64, torch.float16, False),
+        (7168, 16384, torch.float16, True),
+        (18432, 7168, torch.float16, True),
     ],
 )
 def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
@@ -19,7 +19,7 @@ def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
 
     inputs = benchmark.gen_inputs()
 
-    benchmark.check(op, *inputs, atol=1e-5, rtol=1e-3)
+    benchmark.check(op, *inputs, atol=1e-3, rtol=1e-3)
     benchmark.profile(op, *inputs)
 
 
diff --git a/top/kernels/gemv/gemv.py b/top/kernels/gemv/gemv.py
index 4e1156a..21fefbe 100644
--- a/top/kernels/gemv/gemv.py
+++ b/top/kernels/gemv/gemv.py
@@ -6,7 +6,7 @@
 import torch
 
 from top.kernels.kernel import Kernel
-from top.utils import get_sm_version
+from top.utils import get_sm_version, str2dtype
 
 __all__ = [
     'GemvKernel',
 ]
@@ -23,8 +23,8 @@ def _gemv_func(
         tile_k: int = 8,
     ) -> Callable:
 
-        # MAX_TRANSACTION_SIZE_IN_BITS = 128
-        # tile_k = MAX_TRANSACTION_SIZE_IN_BITS // 16
+        max_transaction_size_in_bits = 128
+        tile_k = max_transaction_size_in_bits // (str2dtype[dtype].itemsize * 8)
         block_k = reduce_threads * tile_k
 
         @T.prim_func

From f921abc3a1d02ad7c70b3439c9fd4b93b640af9a Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Wed, 11 Feb 2026 11:47:52 +0800
Subject: [PATCH 5/6] fix gemv tile_k

---
 top/kernels/gemv/gemv.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/top/kernels/gemv/gemv.py b/top/kernels/gemv/gemv.py
index 21fefbe..6eaa66c 100644
--- a/top/kernels/gemv/gemv.py
+++ b/top/kernels/gemv/gemv.py
@@ -20,7 +20,6 @@ def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
     def _gemv_func(
         block_n: int,
         reduce_threads: int,
-        tile_k: int = 8,
     ) -> Callable:
 
         max_transaction_size_in_bits = 128
@@ -78,16 +77,15 @@ def _gemv_wrapped_kernel(
     dtype: str,
     block_n: int,
     reduce_threads: int,
-    tile_k: int,
     a: torch.Tensor,
     b: torch.Tensor,
 ) -> torch.Tensor:
-    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads, tile_k)(a, b)
+    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads)(a, b)
 
 
 @_gemv_wrapped_kernel.register_fake
 def _(n: int, k: int,  # noqa: U100
-      dtype: str, block_n: int, reduce_threads: int, tile_k: int,  # noqa: U100
+      dtype: str, block_n: int, reduce_threads: int,  # noqa: U100
       *inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:  # noqa: U100
     return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)
 
@@ -118,14 +116,12 @@ def default_config(self) -> dict:
         if sm_version in {90}:
             return {
                 "block_n": 32,
-                "reduce_threads": 1,
-                "tile_k": 8,
+                "reduce_threads": 8,
             }
 
         return {
             "block_n": 128,
             "reduce_threads": 32,
-            "tile_k": 8,
         }
 
 @@ -133,16 +129,14 @@ def autotune_configs(self) -> list[dict]:
         # From tilelang/examples/gemm/example_gemm_autotune.py
         block_n = [64, 128, 256]
         reduce_threads = [16, 32]
-        tile_k = [8, 16]
-        _configs = list(itertools.product(block_n, reduce_threads, tile_k))
+        _configs = list(itertools.product(block_n, reduce_threads))
 
         return [{
             'block_n': c[0],
             'reduce_threads': c[1],
-            'tile_k': c[2],
         } for c in _configs]
 
     def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         a = a.flatten().contiguous()
         return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
-                                    self.config["reduce_threads"], self.config["tile_k"], a, b)
+                                    self.config["reduce_threads"], a, b)

From 1eca97dac4945b00ad9e5608177b5f5b67a72f69 Mon Sep 17 00:00:00 2001
From: RMLYC <472187190@qq.com>
Date: Wed, 11 Feb 2026 17:19:38 +0800
Subject: [PATCH 6/6] add bfloat16 test case

---
 tests/ops/test_gemv.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/tests/ops/test_gemv.py b/tests/ops/test_gemv.py
index 86dfa31..7ff7b7c 100644
--- a/tests/ops/test_gemv.py
+++ b/tests/ops/test_gemv.py
@@ -11,6 +11,9 @@
         (1024, 1024, torch.float16, False),
         (7168, 16384, torch.float16, True),
         (18432, 7168, torch.float16, True),
+        (1024, 1024, torch.bfloat16, False),
+        (7168, 16384, torch.bfloat16, True),
+        (18432, 7168, torch.bfloat16, True),
     ],
 )
 def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
@@ -19,7 +22,10 @@ def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
 
     inputs = benchmark.gen_inputs()
 
-    benchmark.check(op, *inputs, atol=1e-3, rtol=1e-3)
+    if dtype == torch.float16:
+        benchmark.check(op, *inputs, atol=1e-3, rtol=1e-3)
+    else:
+        benchmark.check(op, *inputs, atol=1.6e-2, rtol=1.6e-2)
     benchmark.profile(op, *inputs)
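
A note on the kernel's layout, for readers of this series: each thread block covers
block_n rows of b; within a row, reduce_threads threads each accumulate strided,
tile_k-wide vectorized chunks of the K axis into a private c_accum, and
T.tvm_thread_allreduce then sums the per-thread partials. Because the loads are
unguarded along K, k is effectively assumed to be a multiple of
block_k = reduce_threads * tile_k (all shapes in the tests above satisfy this).
Below is a host-side emulation of that decomposition; it is a sketch for
illustration only, with gemv_decomposed a hypothetical helper, not the generated
CUDA:

    import torch

    def gemv_decomposed(a: torch.Tensor, b: torch.Tensor,
                        reduce_threads: int = 8, tile_k: int = 8) -> torch.Tensor:
        # Mirror the kernel's loop nest: "thread" tk owns strided tile_k-wide
        # chunks (c_accum); the sum over tk plays the role of tvm_thread_allreduce.
        n, k = b.shape
        block_k = reduce_threads * tile_k
        assert k % block_k == 0, "kernel loads are unguarded along K"
        c = torch.zeros(n, dtype=torch.float32)
        for row in range(n):
            partial = torch.zeros(reduce_threads, dtype=torch.float32)
            for bk in range(0, k, block_k):
                for tk in range(reduce_threads):
                    s = bk + tk * tile_k
                    partial[tk] += (a[s:s + tile_k].float() *
                                    b[row, s:s + tile_k].float()).sum()
            c[row] = partial.sum()  # the cross-thread allreduce
        return c.to(b.dtype)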
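
End-to-end usage, a minimal sketch mirroring tests/ops/test_gemv.py (it assumes a
CUDA device and that the top package from this series is importable; op.forward is
called directly since that is the entry point the patches define, and the
tolerances follow the float16 branch of the test):

    import torch

    from top.ops import GemvOp

    n, k = 1024, 1024
    dtype = torch.float16

    a = torch.randn(k, device="cuda", dtype=dtype)     # length-k vector
    b = torch.randn(n, k, device="cuda", dtype=dtype)  # (n, k) matrix

    op = GemvOp(n, k, dtype=dtype, tune=False)
    c = op.forward(a, b)  # shape (n,), numerically close to b @ a

    torch.testing.assert_close(c, b @ a, atol=1e-3, rtol=1e-3)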