diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
index 07c26f9..c7ecad7 100644
--- a/benchmarks/__init__.py
+++ b/benchmarks/__init__.py
@@ -19,6 +19,7 @@
     MultiHeadAttentionDecodeBenchmark,
 )
 from .gemm import GemmBenchmark, MatMulBenchmark
+from .gemv import GemvBenchmark
 from .grouped_gemm import (
     GroupedGemmBenchmark,
     GroupedGemmNNBenchmark,
@@ -36,6 +37,7 @@
     "GroupQueryAttentionFwdBenchmark",
     "GroupQueryAttentionBwdBenchmark",
     "GemmBenchmark",
+    "GemvBenchmark",
     "MultiHeadAttentionDecodeBenchmark",
     "GroupQueryAttentionDecodeBenchmark",
     "MultiHeadLatentAttentionDecodeBenchmark",
diff --git a/benchmarks/gemv/__init__.py b/benchmarks/gemv/__init__.py
new file mode 100644
index 0000000..c0a5753
--- /dev/null
+++ b/benchmarks/gemv/__init__.py
@@ -0,0 +1,3 @@
+from .gemv import GemvBenchmark
+
+__all__ = ["GemvBenchmark"]
diff --git a/benchmarks/gemv/gemv.py b/benchmarks/gemv/gemv.py
new file mode 100644
index 0000000..6537bb0
--- /dev/null
+++ b/benchmarks/gemv/gemv.py
@@ -0,0 +1,39 @@
+from typing import Tuple
+
+import torch
+
+from benchmarks.benchmark import Benchmark
+from top.ops import GemvOp
+
+
+class GemvBenchmark(Benchmark):
+
+    op_type = GemvOp
+
+    def __init__(self, n: int, k: int, dtype: torch.dtype):
+        self.n = n
+        self.k = k
+        self.dtype = dtype
+
+    @property
+    def total_flops(self) -> float:
+        return 2.0 * self.n * self.k
+
+    @property
+    def total_memory(self) -> int:
+        return (self.k + self.k * self.n + self.n) * self.dtype.itemsize
+
+    def gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        shape_a = (self.k,)
+        a = torch.randn(*shape_a, device='cuda', dtype=self.dtype)
+        shape_b = (self.n, self.k)
+        b = torch.randn(*shape_b, device='cuda', dtype=self.dtype)
+        return a, b
+
+    def ref_program(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        # return torch.mv(b, a)
+        return b @ a
+
+    def baseline_profile(self, *inputs, warmup=100, rep=10, device="cuda:0") -> None:
+        return super().baseline_profile(
+            self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device)
diff --git a/tests/ops/test_gemv.py b/tests/ops/test_gemv.py
new file mode 100644
index 0000000..7ff7b7c
--- /dev/null
+++ b/tests/ops/test_gemv.py
@@ -0,0 +1,34 @@
+import torch
+import pytest
+
+from benchmarks import GemvBenchmark
+from top.ops import GemvOp
+
+
+@pytest.mark.parametrize(
+    "n, k, dtype, tune",
+    [
+        (1024, 1024, torch.float16, False),
+        (7168, 16384, torch.float16, True),
+        (18432, 7168, torch.float16, True),
+        (1024, 1024, torch.bfloat16, False),
+        (7168, 16384, torch.bfloat16, True),
+        (18432, 7168, torch.bfloat16, True),
+    ],
+)
+def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
+    op = GemvOp(n, k, dtype=dtype, tune=tune)
+    benchmark = GemvBenchmark(n, k, dtype)
+
+    inputs = benchmark.gen_inputs()
+
+    if dtype == torch.float16:
+        benchmark.check(op, *inputs, atol=1e-3, rtol=1e-3)
+    else:
+        benchmark.check(op, *inputs, atol=1.6e-2, rtol=1.6e-2)
+    benchmark.profile(op, *inputs)
+
+
+if __name__ == "__main__":
+    # Run tests with pytest
+    pytest.main([__file__, "-vvs"])
diff --git a/top/kernels/gemv/__init__.py b/top/kernels/gemv/__init__.py
new file mode 100644
index 0000000..658317b
--- /dev/null
+++ b/top/kernels/gemv/__init__.py
@@ -0,0 +1,3 @@
+from .gemv import GemvKernel
+
+__all__ = ["GemvKernel"]
diff --git a/top/kernels/gemv/gemv.py b/top/kernels/gemv/gemv.py
new file mode 100644
index 0000000..6eaa66c
--- /dev/null
+++ b/top/kernels/gemv/gemv.py
@@ -0,0 +1,142 @@
+import itertools
+from typing import Callable, Optional
+
+import tilelang
+import tilelang.language as T
+import torch
+
+from top.kernels.kernel import Kernel
+from top.utils import get_sm_version, str2dtype
+
+__all__ = [
+    'GemvKernel',
+]
+
+
+def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
+    accum_dtype = "float"
+
+    @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
+    def _gemv_func(
+        block_n: int,
+        reduce_threads: int,
+    ) -> Callable:
+
+        max_transaction_size_in_bits = 128
+        tile_k = max_transaction_size_in_bits // (str2dtype[dtype].itemsize * 8)
+        block_k = reduce_threads * tile_k
+
+        @T.prim_func
+        def _gemv_main(
+                a: T.Buffer((k,), dtype),
+                b: T.Buffer((n, k), dtype),
+                c: T.Buffer((n,), dtype),
+        ):
+            with T.Kernel(T.ceildiv(n, block_n), threads=(block_n, reduce_threads)) as bn:
+                tn = T.get_thread_binding(0)
+                tk = T.get_thread_binding(1)
+                a_local = T.alloc_local((tile_k,), dtype)
+                b_local = T.alloc_local((tile_k,), dtype)
+                c_accum = T.alloc_local((1,), accum_dtype)
+
+                T.clear(c_accum)
+                for bk in T.serial(T.ceildiv(k, block_k)):
+                    for _k in T.vectorized(tile_k):
+                        a_local[_k] = a[bk * block_k + tk * tile_k + _k]
+                        b_local[_k] = b[bn * block_n + tn, bk * block_k + tk * tile_k + _k]
+                    for _k in T.serial(tile_k):
+                        c_accum[0] += a_local[_k].astype(accum_dtype) * b_local[_k].astype(
+                            accum_dtype)
+                c_reduced = T.alloc_local((1,), accum_dtype)
+                with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                ):
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            c_accum[0],
+                            True,
+                            c_reduced[0],
+                            tk,
+                            dtype="handle",
+                        ))
+
+                c[bn * block_n + tn] = c_reduced[0]
+
+        return _gemv_main
+
+    return _gemv_func
+
+
+@torch.library.custom_op("top::gemv_wrapped_kernel", mutates_args=())
+def _gemv_wrapped_kernel(
+    n: int,
+    k: int,
+    dtype: str,
+    block_n: int,
+    reduce_threads: int,
+    a: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return _gemv_kernel(n, k, dtype)(block_n, reduce_threads)(a, b)
+
+
+@_gemv_wrapped_kernel.register_fake
+def _(n: int, k: int,  # noqa: U100
+      dtype: str, block_n: int, reduce_threads: int,  # noqa: U100
+      *inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:  # noqa: U100
+    return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)
+
+
+class GemvKernel(Kernel):
+    supported_archs: list[int] = [90]
+
+    def __init__(self,
+                 n: int,
+                 k: int,
+                 dtype: torch.dtype,
+                 config: Optional[dict] = None,
+                 tune: bool = False) -> None:
+        super().__init__()
+        self.n = n
+        self.k = k
+        self.dtype = dtype
+
+        self.kernel = _gemv_kernel(n, k, self.dtype_str)
+
+        self.init_config(config, tune)
+
+    @property
+    def default_config(self) -> dict:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        sm_version = get_sm_version()
+
+        if sm_version in {90}:
+            return {
+                "block_n": 32,
+                "reduce_threads": 8,
+            }
+
+        return {
+            "block_n": 128,
+            "reduce_threads": 32,
+        }
+
+    @property
+    def autotune_configs(self) -> list[dict]:
+        # From tilelang/examples/gemm/example_gemm_autotune.py
+        block_n = [64, 128, 256]
+        reduce_threads = [16, 32]
+        _configs = list(itertools.product(block_n, reduce_threads))
+
+        return [{
+            'block_n': c[0],
+            'reduce_threads': c[1],
+        } for c in _configs]
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        a = a.flatten().contiguous()
+        return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
+                                    self.config["reduce_threads"], a, b)
diff --git a/top/ops/__init__.py b/top/ops/__init__.py
index 6260df2..d963a62 100644
--- a/top/ops/__init__.py
+++ b/top/ops/__init__.py
@@ -5,6 +5,7 @@
 from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp
 from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp
 from .gemm import GemmOp
+from .gemv import GemvOp
 from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp
 from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp
 from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp
@@ -23,6 +24,7 @@
     "GroupQueryAttentionFwdOp",
     "GroupQueryAttentionBwdOp",
     "GemmOp",
+    "GemvOp",
     "MultiHeadAttentionDecodeWithKVCacheOp",
     "MultiHeadAttentionDecodePagedWithKVCacheOp",
     "GroupQueryAttentionDecodeWithKVCacheOp",
diff --git a/top/ops/gemv.py b/top/ops/gemv.py
new file mode 100644
index 0000000..15277af
--- /dev/null
+++ b/top/ops/gemv.py
@@ -0,0 +1,34 @@
+from typing import Dict, Optional
+
+import torch
+
+from top.kernels.gemv import GemvKernel
+from top.kernels.kernel import Kernel
+
+from .op import Op
+
+__all__ = ['GemvOp']
+
+
+class GemvOp(Op):
+
+    def __init__(self,
+                 n: int,
+                 k: int,
+                 dtype: torch.dtype = torch.float16,
+                 kernel_map: Optional[Dict[str, Kernel]] = None,
+                 tune: bool = False) -> None:
+        self.N = n
+        self.K = k
+
+        self.dtype = dtype
+
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map["gemv_kernel"](n, k, self.dtype, tune=tune)
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {"gemv_kernel": GemvKernel}
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        return self.kernel(a, b)