From cd2fbef6f412c721d72f00fc8ce863c9bb4bbb46 Mon Sep 17 00:00:00 2001
From: "jieneng.yu" <1033160740@qq.com>
Date: Tue, 3 Feb 2026 10:52:16 +0800
Subject: [PATCH 01/16] [Feat][NSA] Implement an NSA compression forward kernel.

---
 benchmarks/deepseek_nsa/deepseek_nsa.py | 276 +++++++++++++++++++++++-
 tests/ops/test_deepseek_nsa_cmp_fwd.py  |  80 +++++++
 top/kernels/deepseek_nsa/__init__.py    |   3 +-
 top/kernels/deepseek_nsa/nsa_cmp_fwd.py | 258 ++++++++++++++++++++++
 top/ops/__init__.py                     |   3 +-
 top/ops/deepseek_nsa.py                 |  59 ++++-
 6 files changed, 673 insertions(+), 6 deletions(-)
 create mode 100644 tests/ops/test_deepseek_nsa_cmp_fwd.py
 create mode 100644 top/kernels/deepseek_nsa/nsa_cmp_fwd.py

diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py
index fe004a7..f8b8a1f 100644
--- a/benchmarks/deepseek_nsa/deepseek_nsa.py
+++ b/benchmarks/deepseek_nsa/deepseek_nsa.py
@@ -1,10 +1,10 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Tuple
 
 import torch
 from einops import rearrange, repeat
 
 from benchmarks.benchmark import Benchmark
-from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp
+from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp
 
 from .utils import prepare_token_indices, prepare_chunk_offsets
 
@@ -686,3 +686,275 @@ def baseline_profile(
         print("===== Profiling FLA NSA_Topk backend =====")
         return super().baseline_profile(
             self.baseline_program, *inputs, backend="FLA", warmup=warmup, rep=rep, device=device)
+
+
+class NSACmpFwdVarlenBenchmark(Benchmark):
+    op_type = NSACmpFwdVarlenOp
+
+    def __init__(
+        self,
+        seq_num: int,
+        c_seq_len: int,
+        heads: int,
+        dim_k: int,
+        dim_v: int,
+        group: int,
+        scale: float,
+        bc: int,
+        bs: int,
+        bk: int,
+        bv: int,
+        dtype: torch.dtype,
+        accum_dtype: torch.dtype,
+        tune: bool = False,
+    ) -> None:
+        self.seq_num = seq_num
+        self.c_seq_len = c_seq_len
+        self.heads = heads
+        self.dim_k = dim_k
+        self.dim_v = dim_v
+        self.group = group
+        self.scale = scale
+        self.bc = bc
+        self.bs = bs
+        self.bk = bk
+        self.bv = bv
+        self.tune = tune
+
+        self.head_kv = self.heads // self.group
+        self.dtype = dtype
+        self.accum_dtype = accum_dtype
+
+    @property
+    def total_flops(self) -> int:
+        # Step 1 (LSE) + Step 2 (Scores)
+        return (2 * self.heads * self.dim_k * self.c_seq_len**2) // self.bs
+
+    @property
+    def total_memory(self) -> int:
+        # q: read once, k_cmp/v_cmp: read twice per preceding block per token
+        q_read = self.heads * self.c_seq_len * self.dim_k * self.dtype.itemsize
+        k_read = (self.head_kv * self.dim_k * self.c_seq_len**2 * self.dtype.itemsize) // self.bs
+        v_read = (self.head_kv * self.dim_v * self.c_seq_len**2 * self.dtype.itemsize) // self.bs
+        return q_read + k_read + v_read
+
+
+    def gen_inputs(self) -> tuple[torch.Tensor, ...]:
+        valid_range = self.c_seq_len - self.bs
+        rand_indices = torch.randperm(valid_range)[:self.seq_num - 1]
+        offsets = torch.cat([
+            torch.tensor([0]),
+            torch.arange(self.bs, self.c_seq_len)[rand_indices],
+            torch.tensor([self.c_seq_len])
+        ], 0).cuda().sort()[0].to(torch.int32)
+
+        chunk_offsets = prepare_chunk_offsets(offsets, self.bs).to(torch.int32)
+        token_indices = prepare_token_indices(offsets).to(torch.int32)
+        chunk_num = chunk_offsets[-1].item()
+
+        # float16, data Tie-breaking
+        q = torch.randn(
+            (self.c_seq_len, self.heads, self.dim_k), dtype=self.dtype, device="cuda")
+        k = torch.randn((chunk_num, self.head_kv, self.dim_k), dtype=self.dtype, device="cuda")
+
v = torch.randn((chunk_num, self.head_kv, self.dim_v), dtype=self.dtype, device="cuda") + + self.chunk_num = chunk_offsets[-1].item() + return ( + q, + k, + v, + offsets.to(torch.int32), + chunk_offsets.to(torch.int32), + token_indices.to(torch.int32), + ) + + + + def parallel_nsa_compression_fwd_pytorch( + self, + q: torch.Tensor, + k_cmp: torch.Tensor, + v_cmp: torch.Tensor, + block_size: int, + scale: float, + offsets: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch reference implementation on GPU""" + # Clone inputs to allocate fresh memory (use _ref suffix) + q_ref = q.clone().contiguous() + k_cmp_ref = k_cmp.clone().contiguous() + v_cmp_ref = v_cmp.clone().contiguous() + offsets_ref = offsets.clone().contiguous() + + seq_len, heads, dim_k = q_ref.shape + num_chunks, head_kv, _ = k_cmp_ref.shape + dim_v = v_cmp_ref.shape[-1] + group = heads // head_kv + device = q_ref.device + num_seq = len(offsets_ref) - 1 + + o = torch.zeros((seq_len, heads, dim_v), dtype=torch.float32, device=device) + lse = torch.full((seq_len, heads), float('-inf'), dtype=torch.float32, device=device) + + chunk_offsets = prepare_chunk_offsets(offsets_ref, block_size) + + for i_n in range(num_seq): + bos, eos = offsets_ref[i_n].item(), offsets_ref[i_n + 1].item() + boc = chunk_offsets[i_n].item() + + for i_t in range(eos - bos): + nc = (i_t + 1) // block_size + if nc == 0: + lse[bos + i_t] = 0.0 + continue + + # [HQ, dim_k] + q_curr = q_ref[bos + i_t].float() + # [nc, H, dim_k] -> [H, nc, dim_k] + k_curr = k_cmp_ref[boc : boc + nc].transpose(0, 1).float() + # [nc, H, dim_v] -> [H, nc, dim_v] + v_curr = v_cmp_ref[boc : boc + nc].transpose(0, 1).float() + + # Expand K/V for GQA + k_curr = k_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_k) + v_curr = v_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_v) + + # scores: [HQ, nc] + scores = torch.matmul(q_curr.unsqueeze(1), k_curr.transpose(-1, -2)).squeeze(1) * scale + + # LSE and Softmax + m = torch.max(scores, dim=-1, keepdim=True)[0] + exp_scores = torch.exp(scores - m) + sum_exp = torch.sum(exp_scores, dim=-1, keepdim=True) + + # probs: [HQ, nc] + probs = exp_scores / sum_exp + + # output: [HQ, dim_v] + out = torch.matmul(probs.unsqueeze(1), v_curr).squeeze(1) + + o[bos + i_t] = out + lse[bos + i_t] = (m + torch.log(sum_exp)).squeeze(-1) + + # Compare original inputs with cloned versions after computation + if not torch.equal(q, q_ref): + diff = (q - q_ref).abs().max().item() + print(f"⚠️ [REF DEBUG] q was modified! max diff: {diff}") + if not torch.equal(k_cmp, k_cmp_ref): + diff = (k_cmp - k_cmp_ref).abs().max().item() + print(f"⚠️ [REF DEBUG] k_cmp was modified! max diff: {diff}") + if not torch.equal(v_cmp, v_cmp_ref): + diff = (v_cmp - v_cmp_ref).abs().max().item() + print(f"⚠️ [REF DEBUG] v_cmp was modified! max diff: {diff}") + if not torch.equal(offsets, offsets_ref): + diff = (offsets.float() - offsets_ref.float()).abs().max().item() + print(f"⚠️ [REF DEBUG] offsets was modified! 
max diff: {diff}") + + return o.to(self.dtype), lse.to(self.dtype) + + def parallel_nsa_compression_fwd_pytorch_cpu( + self, + q: torch.Tensor, + k_cmp: torch.Tensor, + v_cmp: torch.Tensor, + block_size: int, + scale: float, + offsets: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch reference implementation - runs entirely on CPU to avoid GPU pollution""" + device = q.device + + # Move all inputs to CPU + q_cpu = q.cpu().float() + k_cmp_cpu = k_cmp.cpu().float() + v_cmp_cpu = v_cmp.cpu().float() + offsets_cpu = offsets.cpu() + + seq_len, heads, dim_k = q_cpu.shape + _, head_kv, _ = k_cmp_cpu.shape + dim_v = v_cmp_cpu.shape[-1] + group = heads // head_kv + num_seq = len(offsets_cpu) - 1 + + o = torch.zeros((seq_len, heads, dim_v), dtype=torch.float32) + lse = torch.full((seq_len, heads), float('-inf'), dtype=torch.float32) + + chunk_offsets = prepare_chunk_offsets(offsets_cpu, block_size) + + for i_n in range(num_seq): + bos, eos = offsets_cpu[i_n].item(), offsets_cpu[i_n + 1].item() + boc = chunk_offsets[i_n].item() + + for i_t in range(eos - bos): + nc = (i_t + 1) // block_size + if nc == 0: + lse[bos + i_t] = 0.0 + continue + + # q_curr: [heads, dim_k] + q_curr = q_cpu[bos + i_t] + # k_curr: [head_kv, nc, dim_k] + k_curr = k_cmp_cpu[boc : boc + nc].transpose(0, 1) + # v_curr: [head_kv, nc, dim_v] + v_curr = v_cmp_cpu[boc : boc + nc].transpose(0, 1) + + # Expand K/V for GQA: [head_kv, nc, dim] -> [heads, nc, dim] + k_curr = k_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_k) + v_curr = v_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_v) + + # scores: [heads, nc] + scores = torch.einsum('hd,hnd->hn', q_curr, k_curr) * scale + + # Softmax with numerical stability + m = scores.max(dim=-1, keepdim=True)[0] + exp_scores = torch.exp(scores - m) + sum_exp = exp_scores.sum(dim=-1, keepdim=True) + probs = exp_scores / sum_exp + + # output: [heads, dim_v] + out = torch.einsum('hn,hnd->hd', probs, v_curr) + + o[bos + i_t] = out + lse[bos + i_t] = (m + torch.log(sum_exp)).squeeze(-1) + + # Move results back to original device + return o.to(device=device, dtype=self.dtype), lse.to(device=device, dtype=self.dtype) + + + def ref_program( + self, + q: torch.Tensor, + k_cmp: torch.Tensor, + v_cmp: torch.Tensor, + offsets: torch.LongTensor, + chunk_offsets: torch.LongTensor, + token_indices: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + _ = chunk_offsets, token_indices + return self.parallel_nsa_compression_fwd_pytorch(q, k_cmp, v_cmp, self.bs, self.scale, offsets) + + + def baseline_program( + self, + q: torch.Tensor, + k_cmp: torch.Tensor, + v_cmp: torch.Tensor, + offsets: torch.LongTensor, + chunk_offsets: torch.LongTensor, + token_indices: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from native_sparse_attention.ops.parallel import parallel_nsa_compression_fwd + out, lse = parallel_nsa_compression_fwd(q.unsqueeze(0), k_cmp.unsqueeze(0), v_cmp.unsqueeze(0), self.bs, self.scale, offsets,token_indices) + return out.squeeze(0).to(self.dtype), lse.squeeze(0).to(self.dtype) + + + def baseline_profile( + self, + *inputs: tuple[torch.Tensor, ...], + warmup: int = 100, + rep: int = 100, + device: str = "cuda", + ) -> torch.Tensor: + print("===== Profiling FLA NSA_Compression backend =====") + return super().baseline_profile( + self.baseline_program, *inputs, backend="FLA NSA_Compression", warmup=warmup, rep=rep, device=device) \ No newline at end of file diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py 
b/tests/ops/test_deepseek_nsa_cmp_fwd.py new file mode 100644 index 0000000..cd51031 --- /dev/null +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -0,0 +1,80 @@ +import pytest +import torch + +from benchmarks.deepseek_nsa.deepseek_nsa import NSACmpFwdVarlenBenchmark +from top.ops import NSACmpFwdVarlenOp + + +@pytest.fixture(autouse=True) +def setup() -> None: + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + ("seq_num, c_seq_len, heads, dim_k, dim_v, chunk_num, group, scale, bc, bs, bk, bv, " + "dtype, accum_dtype, tune"), + [ + (5, 1024, 32, 128, 128, 1024 // 32, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False), + (3, 512, 32, 128, 128, 512 // 32, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False), + ], +) +def test_nsa_cmp_fwd_varlen_op( + seq_num: int, + c_seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + group: int, + scale: float, + bc: int, + bs: int, + bk: int, + bv: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + tune: bool, +) -> None: + + assert group % 16 == 0, "Group size must be a multiple of 16 in NSA" + + params = { + "seq_num": seq_num, + "c_seq_len": c_seq_len, + "heads": heads, + "dim_k": dim_k, + "dim_v": dim_v, + "group": group, + "scale": scale, + "bc": bc, + "bs": bs, + "bk": bk, + "bv": bv, + "dtype": dtype, + "accum_dtype": accum_dtype, + "tune": tune, + } + benchmark = NSACmpFwdVarlenBenchmark(**params) + inputs = benchmark.gen_inputs() + # Update chunk_num based on generated inputs + params["chunk_num"] = benchmark.chunk_num + op = NSACmpFwdVarlenOp(**params) + benchmark.check(op, *inputs) + + +if __name__ == "__main__": + test_nsa_cmp_fwd_varlen_op( + seq_num=9, + c_seq_len=8192, + heads=32, + dim_k=128, + dim_v=128, + group=16, + scale=128**-0.5, + bc=32, + bs=32, + bk=128, + bv=128, + dtype=torch.float16, + accum_dtype=torch.float32, + tune=False + ) diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py index 73077b1..81a4ce5 100644 --- a/top/kernels/deepseek_nsa/__init__.py +++ b/top/kernels/deepseek_nsa/__init__.py @@ -1,5 +1,6 @@ from .mean_pooling_fwd import MeanPoolingFwdKernel from .nsa_fwd import NSAFwdVarlenKernel from .nsa_topk import NSATopkVarlenKernel +from .nsa_cmp_fwd import NSACmpFwdVarlenKernel -__all__ = ["MeanPoolingFwdKernel", "NSAFwdVarlenKernel", "NSATopkVarlenKernel"] +__all__ = ["MeanPoolingFwdKernel", "NSAFwdVarlenKernel", "NSATopkVarlenKernel", "NSACmpFwdVarlenKernel"] diff --git a/top/kernels/deepseek_nsa/nsa_cmp_fwd.py b/top/kernels/deepseek_nsa/nsa_cmp_fwd.py new file mode 100644 index 0000000..a372a53 --- /dev/null +++ b/top/kernels/deepseek_nsa/nsa_cmp_fwd.py @@ -0,0 +1,258 @@ +import torch +from typing import Optional, Any, Callable, Tuple + +import tilelang +from tilelang import language as T + +from top.kernels.kernel import Kernel + + +def _nsa_cmp_fwd_varlen_kernel( + seq_num: int, + c_seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_num: int, + group: int, + scale: float, + bc: int, + bs: int, + bk: int, + bv: int, + dtype: str, + accum_dtype: str, +) -> Callable: + LOG2_E = 1.44269504 + scale_log2 = scale * LOG2_E + head_kv = heads // group + + q_shape = [c_seq_len, heads, dim_k] + k_cmp_shape = [chunk_num, head_kv, dim_k] + v_cmp_shape = [chunk_num, head_kv, dim_v] + lse_shape = [c_seq_len, heads] + offsets_shape = [seq_num + 1] + token_indices_shape = [c_seq_len, 2] + chunk_offsets_shape = [seq_num + 1] + o_shape = [c_seq_len, heads, dim_v] + + @tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + 
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def _nsa_cmp_fwd_varlen_func(threads: int): + @T.prim_func + def _parallel_nsa_cmp_fwd_varlen_main( + q: T.Tensor(q_shape, dtype), + k_cmp: T.Tensor(k_cmp_shape, dtype), + v_cmp: T.Tensor(v_cmp_shape, dtype), + offsets: T.Tensor(offsets_shape, T.int32), + chunk_offsets: T.Tensor(chunk_offsets_shape, T.int32), + token_indices: T.Tensor(token_indices_shape, T.int32), + output: T.Tensor(o_shape, dtype), + temp_lse: T.Tensor(lse_shape, dtype), + ): + with T.Kernel(c_seq_len, head_kv, threads=threads) as (bx, by): + q_shared = T.alloc_shared([group, bk], dtype) + k_shared = T.alloc_shared([bc, bk], dtype) + v_shared = T.alloc_shared([bc, bv], dtype) + + + i_c, i_h = bx, by + i_n, i_t = token_indices[i_c, 0], token_indices[i_c, 1] + + bos, eos = offsets[i_n], offsets[i_n + 1] + boc = chunk_offsets[i_n] + nc = (i_t + 1) // bs + + T.copy(q[bos + i_t, i_h * group : (i_h + 1) * group, :bk], q_shared) + + b_o = T.alloc_fragment([group, bv], dtype) + b_lse = T.alloc_fragment([group], dtype) + acc_s = T.alloc_fragment([group, bc], accum_dtype) + acc_s_cast = T.alloc_fragment([group, bc], dtype) + scores_max = T.alloc_fragment([group], accum_dtype) + scores_max_prev = T.alloc_fragment([group], accum_dtype) + scores_scale = T.alloc_fragment([group], accum_dtype) + scores_sum = T.alloc_fragment([group], accum_dtype) + logsum = T.alloc_fragment([group], accum_dtype) + + T.fill(b_o, 0.0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.fill(logsum, 0.0) + + for i_loop in T.Pipelined(T.ceildiv(nc, bc), num_stages=3): + curr_bc = T.min(bc, nc - i_loop * bc) + T.copy(k_cmp[boc + i_loop * bc : boc + i_loop * bc + curr_bc, i_h, :bk], k_shared[:curr_bc, :bk]) + T.copy(v_cmp[boc + i_loop * bc : boc + i_loop * bc + curr_bc, i_h, :bv], v_shared[:curr_bc, :bv]) + + for g_m, c_m in T.Parallel(group, bc): + acc_s[g_m, c_m] = T.if_then_else(c_m < curr_bc, 0.0, -T.infinity(accum_dtype)) + + T.gemm(q_shared, k_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + + for i in T.Parallel(group): + scores_scale[i] = T.if_then_else(scores_max[i] > -T.infinity(accum_dtype), + T.exp2(scores_max_prev[i] * scale_log2 - scores_max[i] * scale_log2), 0.0) + + for i, j in T.Parallel(group, bc): + acc_s[i, j] = T.if_then_else(acc_s[i, j] > -T.infinity(accum_dtype), + T.exp2(acc_s[i, j] * scale_log2 - scores_max[i] * scale_log2), 0.0) + + for i, k_idx in T.Parallel(group, bv): + b_o[i, k_idx] *= scores_scale[i] + + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, v_shared, b_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(group): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i, k_idx in T.Parallel(group, bv): + if nc > 0 and logsum[i] > 0: + b_o[i, k_idx] /= logsum[i] + + for i in T.Parallel(group): + if nc == 0 or logsum[i] <= 0: + b_lse[i] = 0.0 + else: + b_lse[i] = (scores_max[i] * scale_log2 + T.log2(logsum[i])) / LOG2_E + + T.copy(b_o, output[bos + i_t, i_h * group : (i_h + 1) * group, :dim_v]) + T.copy(b_lse, temp_lse[bos + i_t, i_h * group : (i_h + 1) * group]) + + return _parallel_nsa_cmp_fwd_varlen_main + + return _nsa_cmp_fwd_varlen_func + + +@torch.library.custom_op("top::nsa_cmp_fwd_varlen_wrapped_kernel", mutates_args=()) +def 
_nsa_cmp_fwd_varlen_wrapped_kernel( + seq_num: int, + c_seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_num: int, + group: int, + scale: float, + bc: int, + bs: int, + bk: int, + bv: int, + dtype: str, + accum_dtype: str, + threads: int, + q: torch.Tensor, + k_cmp: torch.Tensor, + v_cmp: torch.Tensor, + offsets: torch.Tensor, + chunk_offsets: torch.Tensor, + token_indices: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _nsa_cmp_fwd_varlen_kernel(seq_num, c_seq_len, heads, dim_k, dim_v, chunk_num, group, scale, + bc, bs, bk, bv, dtype, accum_dtype)(threads)(q, k_cmp, v_cmp, offsets, chunk_offsets, + token_indices) + + +@_nsa_cmp_fwd_varlen_wrapped_kernel.register_fake +def _( + seq_num: int, + c_seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_num: int, + group: int, + scale: float, + bc: int, + bs: int, + bk: int, + bv: int, + dtype: str, + accum_dtype: str, + threads: int, + *inputs: Any, +) -> Tuple[torch.Tensor, torch.Tensor]: + _ = (seq_num, dim_k, dim_v, chunk_num, group, scale, bc, bs, bk, bv, dtype, accum_dtype, threads) + return (torch.empty([c_seq_len, heads, dim_v], + dtype=inputs[0].dtype, + device=inputs[0].device), + torch.empty([c_seq_len, heads], + dtype=inputs[0].dtype, + device=inputs[0].device)) + + +class NSACmpFwdVarlenKernel(Kernel): + supported_archs: list[int] = [90] + + def __init__(self, + seq_num: int, + c_seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_num: int, + group: int, + scale: float, + bc: int, + bs: int, + bk: int, + bv: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.seq_num = seq_num + self.c_seq_len = c_seq_len + self.heads = heads + self.dim_k = dim_k + self.dim_v = dim_v + self.chunk_num = chunk_num + self.group = group + self.scale = scale + self.bc = bc + self.bs = bs + self.bk = bk + self.bv = bv + self.dtype_name = str(dtype).split('.')[-1] + self.accum_dtype_name = str(accum_dtype).split('.')[-1] + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "threads": 32, + } + + @property + def autotune_configs(self) -> list[dict]: + threads = [32] + return [{"threads": t} for t in threads] + + def forward(self, q: torch.Tensor, k_cmp: torch.Tensor, v_cmp: torch.Tensor, + offsets: torch.Tensor, chunk_offsets: torch.Tensor, + token_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return _nsa_cmp_fwd_varlen_wrapped_kernel(self.seq_num, self.c_seq_len, self.heads, self.dim_k, self.dim_v, + self.chunk_num, self.group, self.scale, + self.bc, self.bs, self.bk, self.bv, + self.dtype_name, self.accum_dtype_name, + self.config["threads"], + q.to(getattr(torch, self.dtype_name)), + k_cmp.to(getattr(torch, self.dtype_name)), + v_cmp.to(getattr(torch, self.dtype_name)), + offsets.to(torch.int32), + chunk_offsets.to(torch.int32), + token_indices.to(torch.int32)) + + diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 5a5f84e..c8848f5 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -1,6 +1,6 @@ from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp -from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp +from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp from .fp8_lighting_indexer import Fp8LightingIndexerOp from .topk_selector import TopkSelectorOp from .gemm 
import GemmOp
@@ -33,6 +33,7 @@
     "MeanPoolingForwardOp",
     "NSATopkVarlenOp",
     "NSAFwdVarlenOp",
+    "NSACmpFwdVarlenOp",
     "ManifoldConstrainedHyperConnectionPreOp",
     "ManifoldConstrainedHyperConnectionPostOp",
 ]
diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py
index 884636c..d7152ae 100644
--- a/top/ops/deepseek_nsa.py
+++ b/top/ops/deepseek_nsa.py
@@ -1,14 +1,15 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import torch
 
 from top.kernels.deepseek_nsa.mean_pooling_fwd import MeanPoolingFwdKernel
 from top.kernels.deepseek_nsa.nsa_fwd import NSAFwdVarlenKernel
 from top.kernels.deepseek_nsa.nsa_topk import NSATopkVarlenKernel
+from top.kernels.deepseek_nsa.nsa_cmp_fwd import NSACmpFwdVarlenKernel
 from top.kernels.kernel import Kernel
 from top.ops.op import Op
 
-__all__ = ["MeanPoolingForwardOp", "NSAFwdVarlenOp", "NSATopkVarlenOp"]
+__all__ = ["MeanPoolingForwardOp", "NSAFwdVarlenOp", "NSATopkVarlenOp", "NSACmpFwdVarlenOp"]
 
 
 class MeanPoolingForwardOp(Op):
@@ -158,3 +159,57 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 block_indices: torch.Tensor, block_counts: torch.Tensor, offsets: torch.Tensor,
                 token_indices: torch.Tensor) -> torch.Tensor:
         return self.kernel(q, k, v, block_indices, block_counts, offsets, token_indices)
+
+
+class NSACmpFwdVarlenOp(Op):
+
+    def __init__(
+        self,
+        seq_num: int,
+        c_seq_len: int,
+        heads: int,
+        dim_k: int,
+        dim_v: int,
+        chunk_num: int,
+        group: int,
+        scale: float,
+        bc: int,
+        bs: int,
+        bk: int,
+        bv: int,
+        dtype: torch.dtype,
+        accum_dtype: torch.dtype,
+        tune: bool = False,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+    ) -> None:
+        params = {
+            "seq_num": seq_num,
+            "c_seq_len": c_seq_len,
+            "heads": heads,
+            "dim_k": dim_k,
+            "dim_v": dim_v,
+            "chunk_num": chunk_num,
+            "group": group,
+            "scale": scale,
+            "bc": bc,
+            "bs": bs,
+            "bk": bk,
+            "bv": bv,
+            "dtype": dtype,
+            "accum_dtype": accum_dtype,
+            "tune": tune,
+        }
+        for key, value in params.items():
+            setattr(self, key, value)
+
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map["nsa_cmp_fwd_varlen_kernel"](**params)
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {"nsa_cmp_fwd_varlen_kernel": NSACmpFwdVarlenKernel}
+
+    def forward(self, q: torch.Tensor, k_cmp: torch.Tensor, v_cmp: torch.Tensor,
+                offsets: torch.Tensor, chunk_offsets: torch.Tensor,
+                token_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.kernel(q, k_cmp, v_cmp, offsets, chunk_offsets, token_indices)

From 15d6ab41fcefd8f31410ab3cceedd694e0e243f2 Mon Sep 17 00:00:00 2001
From: "jieneng.yu" <1033160740@qq.com>
Date: Tue, 3 Feb 2026 17:17:57 +0800
Subject: [PATCH 02/16] [Feat][NSA] Implement an NSA compression forward kernel.
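
This revision reformats the code added in PATCH 01, drops the CPU reference
and the clone-based mutation checks, and switches the kernel's K/V loop from
T.Pipelined to T.serial with explicit shared-memory fills so a partially
filled tile never contributes stale data.

Both the reference and the kernel accumulate attention with the same
streaming (online) softmax, carrying a running max, normalizer, and
unnormalized output across K/V tiles. A minimal PyTorch sketch of that
update rule for reviewers; the helper name is illustrative and not part of
this patch:

    import torch

    def streaming_softmax_attention(q, k, v, scale, chunk=32):
        # q: [h, d], k: [h, n, d], v: [h, n, dv]; walk K/V in tiles while
        # keeping a running max (m), normalizer (l), and unnormalized
        # output (o), like the kernel's scores_max/logsum/b_o registers.
        h, n, _ = k.shape
        m = torch.full((h, 1), float("-inf"))
        l = torch.zeros(h, 1)
        o = torch.zeros(h, v.shape[-1])
        for s in range(0, n, chunk):
            scores = torch.einsum("hd,hnd->hn", q, k[:, s:s + chunk]) * scale
            m_new = torch.maximum(m, scores.max(-1, keepdim=True).values)
            alpha = torch.exp(m - m_new)  # rescale the old accumulators
            p = torch.exp(scores - m_new)
            l = l * alpha + p.sum(-1, keepdim=True)
            o = o * alpha + torch.einsum("hn,hnd->hd", p, v[:, s:s + chunk])
            m = m_new
        lse = (m + torch.log(l)).squeeze(-1)  # what temp_lse stores, pre-cast
        return o / l, lse

The result matches a one-shot softmax over all n keys, which is why the
kernel needs only a single pass over k_cmp/v_cmp per query token.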
--- benchmarks/deepseek_nsa/deepseek_nsa.py | 171 ++++++------------------ tests/ops/test_deepseek_nsa_cmp_fwd.py | 10 +- top/kernels/deepseek_nsa/__init__.py | 4 +- top/kernels/deepseek_nsa/nsa_cmp_fwd.py | 120 +++++++++-------- 4 files changed, 111 insertions(+), 194 deletions(-) diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index f8b8a1f..a08becc 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -727,7 +727,6 @@ def __init__( @property def total_flops(self) -> int: - # Step 1 (LSE) + Step 2 (Scores) return (2 * self.heads * self.dim_k * self.c_seq_len**2) // self.bs @property @@ -738,13 +737,12 @@ def total_memory(self) -> int: v_read = (self.head_kv * self.dim_v * self.c_seq_len**2 * self.dtype.itemsize) // self.bs return q_read + k_read + v_read - def gen_inputs(self) -> tuple[torch.Tensor, ...]: valid_range = self.c_seq_len - self.bs rand_indices = torch.randperm(valid_range)[:self.seq_num - 1] offsets = torch.cat([ - torch.tensor([0]), - torch.arange(self.bs, self.c_seq_len)[rand_indices], + torch.tensor([0]), + torch.arange(self.bs, self.c_seq_len)[rand_indices], torch.tensor([self.c_seq_len]) ], 0).cuda().sort()[0].to(torch.int32) @@ -753,8 +751,7 @@ def gen_inputs(self) -> tuple[torch.Tensor, ...]: chunk_num = chunk_offsets[-1].item() # float16, data Tie-breaking - q = torch.randn( - (self.c_seq_len, self.heads, self.dim_k), dtype=self.dtype, device="cuda") + q = torch.randn((self.c_seq_len, self.heads, self.dim_k), dtype=self.dtype, device="cuda") k = torch.randn((chunk_num, self.head_kv, self.dim_k), dtype=self.dtype, device="cuda") v = torch.randn((chunk_num, self.head_kv, self.dim_v), dtype=self.dtype, device="cuda") @@ -768,8 +765,6 @@ def gen_inputs(self) -> tuple[torch.Tensor, ...]: token_indices.to(torch.int32), ) - - def parallel_nsa_compression_fwd_pytorch( self, q: torch.Tensor, @@ -780,147 +775,51 @@ def parallel_nsa_compression_fwd_pytorch( offsets: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """PyTorch reference implementation on GPU""" - # Clone inputs to allocate fresh memory (use _ref suffix) - q_ref = q.clone().contiguous() - k_cmp_ref = k_cmp.clone().contiguous() - v_cmp_ref = v_cmp.clone().contiguous() - offsets_ref = offsets.clone().contiguous() - - seq_len, heads, dim_k = q_ref.shape - num_chunks, head_kv, _ = k_cmp_ref.shape - dim_v = v_cmp_ref.shape[-1] + seq_len, heads, dim_k = q.shape + _, head_kv, _ = k_cmp.shape + dim_v = v_cmp.shape[-1] group = heads // head_kv - device = q_ref.device - num_seq = len(offsets_ref) - 1 - + device = q.device + num_seq = len(offsets) - 1 + o = torch.zeros((seq_len, heads, dim_v), dtype=torch.float32, device=device) lse = torch.full((seq_len, heads), float('-inf'), dtype=torch.float32, device=device) - - chunk_offsets = prepare_chunk_offsets(offsets_ref, block_size) - + + chunk_offsets = prepare_chunk_offsets(offsets, block_size) + for i_n in range(num_seq): - bos, eos = offsets_ref[i_n].item(), offsets_ref[i_n + 1].item() + bos, eos = offsets[i_n].item(), offsets[i_n + 1].item() boc = chunk_offsets[i_n].item() - + for i_t in range(eos - bos): nc = (i_t + 1) // block_size if nc == 0: lse[bos + i_t] = 0.0 continue - - # [HQ, dim_k] - q_curr = q_ref[bos + i_t].float() - # [nc, H, dim_k] -> [H, nc, dim_k] - k_curr = k_cmp_ref[boc : boc + nc].transpose(0, 1).float() - # [nc, H, dim_v] -> [H, nc, dim_v] - v_curr = v_cmp_ref[boc : boc + nc].transpose(0, 1).float() - - # Expand K/V for GQA + + q_curr = q[bos + 
i_t].float() + k_curr = k_cmp[boc:boc + nc].transpose(0, 1).float() + v_curr = v_cmp[boc:boc + nc].transpose(0, 1).float() + k_curr = k_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_k) v_curr = v_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_v) - - # scores: [HQ, nc] - scores = torch.matmul(q_curr.unsqueeze(1), k_curr.transpose(-1, -2)).squeeze(1) * scale - - # LSE and Softmax + + scores = torch.matmul(q_curr.unsqueeze(1), k_curr.transpose(-1, + -2)).squeeze(1) * scale + m = torch.max(scores, dim=-1, keepdim=True)[0] exp_scores = torch.exp(scores - m) sum_exp = torch.sum(exp_scores, dim=-1, keepdim=True) - - # probs: [HQ, nc] + probs = exp_scores / sum_exp - - # output: [HQ, dim_v] + out = torch.matmul(probs.unsqueeze(1), v_curr).squeeze(1) - - o[bos + i_t] = out - lse[bos + i_t] = (m + torch.log(sum_exp)).squeeze(-1) - - # Compare original inputs with cloned versions after computation - if not torch.equal(q, q_ref): - diff = (q - q_ref).abs().max().item() - print(f"⚠️ [REF DEBUG] q was modified! max diff: {diff}") - if not torch.equal(k_cmp, k_cmp_ref): - diff = (k_cmp - k_cmp_ref).abs().max().item() - print(f"⚠️ [REF DEBUG] k_cmp was modified! max diff: {diff}") - if not torch.equal(v_cmp, v_cmp_ref): - diff = (v_cmp - v_cmp_ref).abs().max().item() - print(f"⚠️ [REF DEBUG] v_cmp was modified! max diff: {diff}") - if not torch.equal(offsets, offsets_ref): - diff = (offsets.float() - offsets_ref.float()).abs().max().item() - print(f"⚠️ [REF DEBUG] offsets was modified! max diff: {diff}") - - return o.to(self.dtype), lse.to(self.dtype) - def parallel_nsa_compression_fwd_pytorch_cpu( - self, - q: torch.Tensor, - k_cmp: torch.Tensor, - v_cmp: torch.Tensor, - block_size: int, - scale: float, - offsets: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """PyTorch reference implementation - runs entirely on CPU to avoid GPU pollution""" - device = q.device - - # Move all inputs to CPU - q_cpu = q.cpu().float() - k_cmp_cpu = k_cmp.cpu().float() - v_cmp_cpu = v_cmp.cpu().float() - offsets_cpu = offsets.cpu() - - seq_len, heads, dim_k = q_cpu.shape - _, head_kv, _ = k_cmp_cpu.shape - dim_v = v_cmp_cpu.shape[-1] - group = heads // head_kv - num_seq = len(offsets_cpu) - 1 - - o = torch.zeros((seq_len, heads, dim_v), dtype=torch.float32) - lse = torch.full((seq_len, heads), float('-inf'), dtype=torch.float32) - - chunk_offsets = prepare_chunk_offsets(offsets_cpu, block_size) - - for i_n in range(num_seq): - bos, eos = offsets_cpu[i_n].item(), offsets_cpu[i_n + 1].item() - boc = chunk_offsets[i_n].item() - - for i_t in range(eos - bos): - nc = (i_t + 1) // block_size - if nc == 0: - lse[bos + i_t] = 0.0 - continue - - # q_curr: [heads, dim_k] - q_curr = q_cpu[bos + i_t] - # k_curr: [head_kv, nc, dim_k] - k_curr = k_cmp_cpu[boc : boc + nc].transpose(0, 1) - # v_curr: [head_kv, nc, dim_v] - v_curr = v_cmp_cpu[boc : boc + nc].transpose(0, 1) - - # Expand K/V for GQA: [head_kv, nc, dim] -> [heads, nc, dim] - k_curr = k_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_k) - v_curr = v_curr.unsqueeze(1).expand(-1, group, -1, -1).reshape(heads, nc, dim_v) - - # scores: [heads, nc] - scores = torch.einsum('hd,hnd->hn', q_curr, k_curr) * scale - - # Softmax with numerical stability - m = scores.max(dim=-1, keepdim=True)[0] - exp_scores = torch.exp(scores - m) - sum_exp = exp_scores.sum(dim=-1, keepdim=True) - probs = exp_scores / sum_exp - - # output: [heads, dim_v] - out = torch.einsum('hn,hnd->hd', probs, v_curr) - o[bos + i_t] = out lse[bos + 
i_t] = (m + torch.log(sum_exp)).squeeze(-1) - - # Move results back to original device - return o.to(device=device, dtype=self.dtype), lse.to(device=device, dtype=self.dtype) - + return o.to(self.dtype), lse.to(self.dtype) + def ref_program( self, q: torch.Tensor, @@ -931,8 +830,8 @@ def ref_program( token_indices: torch.LongTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: _ = chunk_offsets, token_indices - return self.parallel_nsa_compression_fwd_pytorch(q, k_cmp, v_cmp, self.bs, self.scale, offsets) - + return self.parallel_nsa_compression_fwd_pytorch(q, k_cmp, v_cmp, self.bs, self.scale, + offsets) def baseline_program( self, @@ -944,10 +843,11 @@ def baseline_program( token_indices: torch.LongTensor, ) -> Tuple[torch.Tensor, torch.Tensor]: from native_sparse_attention.ops.parallel import parallel_nsa_compression_fwd - out, lse = parallel_nsa_compression_fwd(q.unsqueeze(0), k_cmp.unsqueeze(0), v_cmp.unsqueeze(0), self.bs, self.scale, offsets,token_indices) + out, lse = parallel_nsa_compression_fwd( + q.unsqueeze(0), k_cmp.unsqueeze(0), v_cmp.unsqueeze(0), self.bs, self.scale, offsets, + token_indices) return out.squeeze(0).to(self.dtype), lse.squeeze(0).to(self.dtype) - def baseline_profile( self, *inputs: tuple[torch.Tensor, ...], @@ -957,4 +857,9 @@ def baseline_profile( ) -> torch.Tensor: print("===== Profiling FLA NSA_Compression backend =====") return super().baseline_profile( - self.baseline_program, *inputs, backend="FLA NSA_Compression", warmup=warmup, rep=rep, device=device) \ No newline at end of file + self.baseline_program, + *inputs, + backend="FLA NSA_Compression", + warmup=warmup, + rep=rep, + device=device) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index cd51031..a407fb7 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -11,11 +11,10 @@ def setup() -> None: @pytest.mark.parametrize( - ("seq_num, c_seq_len, heads, dim_k, dim_v, chunk_num, group, scale, bc, bs, bk, bv, " + ("seq_num, c_seq_len, heads, dim_k, dim_v, group, scale, bc, bs, bk, bv, " "dtype, accum_dtype, tune"), [ - (5, 1024, 32, 128, 128, 1024 // 32, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False), - (3, 512, 32, 128, 128, 512 // 32, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False), + (9, 8192, 32, 128, 128, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False), ], ) def test_nsa_cmp_fwd_varlen_op( @@ -55,7 +54,7 @@ def test_nsa_cmp_fwd_varlen_op( } benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs() - # Update chunk_num based on generated inputs + params["chunk_num"] = benchmark.chunk_num op = NSACmpFwdVarlenOp(**params) benchmark.check(op, *inputs) @@ -63,7 +62,7 @@ def test_nsa_cmp_fwd_varlen_op( if __name__ == "__main__": test_nsa_cmp_fwd_varlen_op( - seq_num=9, + seq_num=12, c_seq_len=8192, heads=32, dim_k=128, @@ -78,3 +77,4 @@ def test_nsa_cmp_fwd_varlen_op( accum_dtype=torch.float32, tune=False ) + diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py index 81a4ce5..820ed3b 100644 --- a/top/kernels/deepseek_nsa/__init__.py +++ b/top/kernels/deepseek_nsa/__init__.py @@ -3,4 +3,6 @@ from .nsa_topk import NSATopkVarlenKernel from .nsa_cmp_fwd import NSACmpFwdVarlenKernel -__all__ = ["MeanPoolingFwdKernel", "NSAFwdVarlenKernel", "NSATopkVarlenKernel", "NSACmpFwdVarlenKernel"] +__all__ = [ + "MeanPoolingFwdKernel", "NSAFwdVarlenKernel", "NSATopkVarlenKernel", "NSACmpFwdVarlenKernel" +] 
diff --git a/top/kernels/deepseek_nsa/nsa_cmp_fwd.py b/top/kernels/deepseek_nsa/nsa_cmp_fwd.py index a372a53..8114bf2 100644 --- a/top/kernels/deepseek_nsa/nsa_cmp_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_cmp_fwd.py @@ -26,7 +26,7 @@ def _nsa_cmp_fwd_varlen_kernel( LOG2_E = 1.44269504 scale_log2 = scale * LOG2_E head_kv = heads // group - + q_shape = [c_seq_len, heads, dim_k] k_cmp_shape = [chunk_num, head_kv, dim_k] v_cmp_shape = [chunk_num, head_kv, dim_v] @@ -44,31 +44,31 @@ def _nsa_cmp_fwd_varlen_kernel( tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) def _nsa_cmp_fwd_varlen_func(threads: int): + @T.prim_func def _parallel_nsa_cmp_fwd_varlen_main( - q: T.Tensor(q_shape, dtype), - k_cmp: T.Tensor(k_cmp_shape, dtype), - v_cmp: T.Tensor(v_cmp_shape, dtype), - offsets: T.Tensor(offsets_shape, T.int32), - chunk_offsets: T.Tensor(chunk_offsets_shape, T.int32), - token_indices: T.Tensor(token_indices_shape, T.int32), - output: T.Tensor(o_shape, dtype), - temp_lse: T.Tensor(lse_shape, dtype), + q: T.Tensor(q_shape, dtype), + k_cmp: T.Tensor(k_cmp_shape, dtype), + v_cmp: T.Tensor(v_cmp_shape, dtype), + offsets: T.Tensor(offsets_shape, T.int32), + chunk_offsets: T.Tensor(chunk_offsets_shape, T.int32), + token_indices: T.Tensor(token_indices_shape, T.int32), + output: T.Tensor(o_shape, dtype), + temp_lse: T.Tensor(lse_shape, dtype), ): with T.Kernel(c_seq_len, head_kv, threads=threads) as (bx, by): q_shared = T.alloc_shared([group, bk], dtype) k_shared = T.alloc_shared([bc, bk], dtype) v_shared = T.alloc_shared([bc, bv], dtype) - i_c, i_h = bx, by i_n, i_t = token_indices[i_c, 0], token_indices[i_c, 1] - bos, eos = offsets[i_n], offsets[i_n + 1] + bos = offsets[i_n] boc = chunk_offsets[i_n] nc = (i_t + 1) // bs - T.copy(q[bos + i_t, i_h * group : (i_h + 1) * group, :bk], q_shared) + T.copy(q[bos + i_t, i_h * group:(i_h + 1) * group, :bk], q_shared) b_o = T.alloc_fragment([group, bv], dtype) b_lse = T.alloc_fragment([group], dtype) @@ -79,44 +79,58 @@ def _parallel_nsa_cmp_fwd_varlen_main( scores_scale = T.alloc_fragment([group], accum_dtype) scores_sum = T.alloc_fragment([group], accum_dtype) logsum = T.alloc_fragment([group], accum_dtype) - + T.fill(b_o, 0.0) T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(logsum, 0.0) - - for i_loop in T.Pipelined(T.ceildiv(nc, bc), num_stages=3): + + for i_loop in T.serial(T.ceildiv(nc, bc)): curr_bc = T.min(bc, nc - i_loop * bc) - T.copy(k_cmp[boc + i_loop * bc : boc + i_loop * bc + curr_bc, i_h, :bk], k_shared[:curr_bc, :bk]) - T.copy(v_cmp[boc + i_loop * bc : boc + i_loop * bc + curr_bc, i_h, :bv], v_shared[:curr_bc, :bv]) - + # Initialize shared memory in each iteration to avoid stale data + T.fill(k_shared, 0.0) + T.fill(v_shared, 0.0) + T.copy(k_cmp[boc + i_loop * bc:boc + i_loop * bc + curr_bc, i_h, :bk], + k_shared[:curr_bc, :bk]) + T.copy(v_cmp[boc + i_loop * bc:boc + i_loop * bc + curr_bc, i_h, :bv], + v_shared[:curr_bc, :bv]) + for g_m, c_m in T.Parallel(group, bc): - acc_s[g_m, c_m] = T.if_then_else(c_m < curr_bc, 0.0, -T.infinity(accum_dtype)) + acc_s[g_m, c_m] = T.if_then_else(c_m < curr_bc, 0.0, + -T.infinity(accum_dtype)) - T.gemm(q_shared, k_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm( + q_shared, + k_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=True) for i in T.Parallel(group): - scores_scale[i] = T.if_then_else(scores_max[i] > 
-T.infinity(accum_dtype), - T.exp2(scores_max_prev[i] * scale_log2 - scores_max[i] * scale_log2), 0.0) - + scores_scale[i] = T.if_then_else( + scores_max[i] > -T.infinity(accum_dtype), + T.exp2(scores_max_prev[i] * scale_log2 - scores_max[i] * scale_log2), + 0.0) + for i, j in T.Parallel(group, bc): - acc_s[i, j] = T.if_then_else(acc_s[i, j] > -T.infinity(accum_dtype), - T.exp2(acc_s[i, j] * scale_log2 - scores_max[i] * scale_log2), 0.0) - + acc_s[i, j] = T.if_then_else( + acc_s[i, j] > -T.infinity(accum_dtype), + T.exp2(acc_s[i, j] * scale_log2 - scores_max[i] * scale_log2), 0.0) + for i, k_idx in T.Parallel(group, bv): b_o[i, k_idx] *= scores_scale[i] - + T.copy(acc_s, acc_s_cast) - + T.gemm(acc_s_cast, v_shared, b_o, policy=T.GemmWarpPolicy.FullRow) T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(group): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - + for i, k_idx in T.Parallel(group, bv): if nc > 0 and logsum[i] > 0: b_o[i, k_idx] /= logsum[i] @@ -126,9 +140,9 @@ def _parallel_nsa_cmp_fwd_varlen_main( b_lse[i] = 0.0 else: b_lse[i] = (scores_max[i] * scale_log2 + T.log2(logsum[i])) / LOG2_E - - T.copy(b_o, output[bos + i_t, i_h * group : (i_h + 1) * group, :dim_v]) - T.copy(b_lse, temp_lse[bos + i_t, i_h * group : (i_h + 1) * group]) + + T.copy(b_o, output[bos + i_t, i_h * group:(i_h + 1) * group, :dim_v]) + T.copy(b_lse, temp_lse[bos + i_t, i_h * group:(i_h + 1) * group]) return _parallel_nsa_cmp_fwd_varlen_main @@ -159,9 +173,10 @@ def _nsa_cmp_fwd_varlen_wrapped_kernel( chunk_offsets: torch.Tensor, token_indices: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - return _nsa_cmp_fwd_varlen_kernel(seq_num, c_seq_len, heads, dim_k, dim_v, chunk_num, group, scale, - bc, bs, bk, bv, dtype, accum_dtype)(threads)(q, k_cmp, v_cmp, offsets, chunk_offsets, - token_indices) + return _nsa_cmp_fwd_varlen_kernel(seq_num, c_seq_len, heads, dim_k, dim_v, chunk_num, group, + scale, bc, bs, bk, bv, dtype, + accum_dtype)(threads)(q, k_cmp, v_cmp, offsets, chunk_offsets, + token_indices) @_nsa_cmp_fwd_varlen_wrapped_kernel.register_fake @@ -183,13 +198,10 @@ def _( threads: int, *inputs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - _ = (seq_num, dim_k, dim_v, chunk_num, group, scale, bc, bs, bk, bv, dtype, accum_dtype, threads) - return (torch.empty([c_seq_len, heads, dim_v], - dtype=inputs[0].dtype, - device=inputs[0].device), - torch.empty([c_seq_len, heads], - dtype=inputs[0].dtype, - device=inputs[0].device)) + _ = (seq_num, dim_k, dim_v, chunk_num, group, scale, bc, bs, bk, bv, dtype, accum_dtype, + threads) + return (torch.empty([c_seq_len, heads, dim_v], dtype=inputs[0].dtype, device=inputs[0].device), + torch.empty([c_seq_len, heads], dtype=inputs[0].dtype, device=inputs[0].device)) class NSACmpFwdVarlenKernel(Kernel): @@ -243,16 +255,14 @@ def autotune_configs(self) -> list[dict]: def forward(self, q: torch.Tensor, k_cmp: torch.Tensor, v_cmp: torch.Tensor, offsets: torch.Tensor, chunk_offsets: torch.Tensor, token_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - return _nsa_cmp_fwd_varlen_wrapped_kernel(self.seq_num, self.c_seq_len, self.heads, self.dim_k, self.dim_v, - self.chunk_num, self.group, self.scale, - self.bc, self.bs, self.bk, self.bv, - self.dtype_name, self.accum_dtype_name, - self.config["threads"], - q.to(getattr(torch, self.dtype_name)), - k_cmp.to(getattr(torch, self.dtype_name)), - v_cmp.to(getattr(torch, self.dtype_name)), - offsets.to(torch.int32), - chunk_offsets.to(torch.int32), - token_indices.to(torch.int32)) - - + return 
_nsa_cmp_fwd_varlen_wrapped_kernel(self.seq_num, self.c_seq_len, self.heads,
+                                                  self.dim_k, self.dim_v, self.chunk_num,
+                                                  self.group, self.scale, self.bc, self.bs, self.bk,
+                                                  self.bv, self.dtype_name, self.accum_dtype_name,
+                                                  self.config["threads"],
+                                                  q.to(getattr(torch, self.dtype_name)),
+                                                  k_cmp.to(getattr(torch, self.dtype_name)),
+                                                  v_cmp.to(getattr(torch, self.dtype_name)),
+                                                  offsets.to(torch.int32),
+                                                  chunk_offsets.to(torch.int32),
+                                                  token_indices.to(torch.int32))

From 2dd78a2643db04cbb8ea322ba9083adeb03895a1 Mon Sep 17 00:00:00 2001
From: "jieneng.yu" <1033160740@qq.com>
Date: Tue, 3 Feb 2026 17:27:59 +0800
Subject: [PATCH 03/16] [Feat][NSA] Implement an NSA compression forward kernel.

---
 tests/ops/test_deepseek_nsa_cmp_fwd.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py
index a407fb7..24cef99 100644
--- a/tests/ops/test_deepseek_nsa_cmp_fwd.py
+++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py
@@ -14,7 +14,8 @@ def setup() -> None:
     ("seq_num, c_seq_len, heads, dim_k, dim_v, group, scale, bc, bs, bk, bv, "
      "dtype, accum_dtype, tune"),
     [
-        (9, 8192, 32, 128, 128, 16, 128**-0.5, 32, 32, 128, 128, torch.float16, torch.float32, False),
+        (9, 8192, 32, 128, 128, 16, 128**
+         -0.5, 32, 32, 128, 128, torch.float16, torch.float32, False),
     ],
 )
 def test_nsa_cmp_fwd_varlen_op(
@@ -54,7 +55,7 @@ def test_nsa_cmp_fwd_varlen_op(
     }
     benchmark = NSACmpFwdVarlenBenchmark(**params)
     inputs = benchmark.gen_inputs()
-    
+
     params["chunk_num"] = benchmark.chunk_num
     op = NSACmpFwdVarlenOp(**params)
     benchmark.check(op, *inputs)
@@ -75,6 +76,4 @@ def test_nsa_cmp_fwd_varlen_op(
         bv=128,
         dtype=torch.float16,
         accum_dtype=torch.float32,
-        tune=False
-    )
-
+        tune=False)

From 42f604de833c8ab2404a29d6b3a2d38022a3d92c Mon Sep 17 00:00:00 2001
From: "jieneng.yu" <1033160740@qq.com>
Date: Tue, 3 Feb 2026 17:35:17 +0800
Subject: [PATCH 04/16] [Feat][NSA] Implement an NSA compression forward kernel.

---
 tests/ops/test_deepseek_nsa_cmp_fwd.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py
index 24cef99..44fff4d 100644
--- a/tests/ops/test_deepseek_nsa_cmp_fwd.py
+++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py
@@ -37,22 +37,8 @@ def test_nsa_cmp_fwd_varlen_op(
 
     assert group % 16 == 0, "Group size must be a multiple of 16 in NSA"
 
-    params = {
-        "seq_num": seq_num,
-        "c_seq_len": c_seq_len,
-        "heads": heads,
-        "dim_k": dim_k,
-        "dim_v": dim_v,
-        "group": group,
-        "scale": scale,
-        "bc": bc,
-        "bs": bs,
-        "bk": bk,
-        "bv": bv,
-        "dtype": dtype,
-        "accum_dtype": accum_dtype,
-        "tune": tune,
-    }
+    # Use locals() to create params dictionary from function arguments
+    params = locals().copy()
 
     benchmark = NSACmpFwdVarlenBenchmark(**params)
     inputs = benchmark.gen_inputs()

From 35e228a4f6dc4441c01ef45e51e66bdd8a1ea27a Mon Sep 17 00:00:00 2001
From: "jieneng.yu" <1033160740@qq.com>
Date: Thu, 5 Feb 2026 16:41:30 +0800
Subject: [PATCH 05/16] [Feat][NSA] Implement an NSA compression forward kernel.
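
Replace the hand-written params dicts in the NSA op constructors with a
dict built from locals(), mirroring what PATCH 04 did in the test. One
caveat worth recording: locals() must be captured at the top of __init__,
before any other local is bound, or stray variables leak into the kernel
parameters. A minimal sketch of the idiom (class and signature abbreviated
for illustration):

    class ExampleOp:
        def __init__(self, seq_num: int, heads: int, tune: bool = False,
                     kernel_map=None) -> None:
            # Capture the constructor arguments before defining any other
            # local variable; 'self' and 'kernel_map' are not kernel params.
            params = {k: v for k, v in locals().items()
                      if k not in ("self", "kernel_map")}
            for key, value in params.items():
                setattr(self, key, value)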
--- top/ops/deepseek_nsa.py | 64 +++-------------------------------------- 1 file changed, 4 insertions(+), 60 deletions(-) diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index d7152ae..584b9c3 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -29,19 +29,7 @@ def __init__( tune: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, ) -> None: - params = { - "batch_size": batch_size, - "seq_len": seq_len, - "heads": heads, - "dim": dim, - "chunk_size": chunk_size, - "chunks_per_bacth": chunks_per_bacth, - "seq_num": seq_num, - "use_offsets": use_offsets, - "dtype": dtype, - "accum_dtype": accum_dtype, - "tune": tune, - } + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} for key, value in params.items(): setattr(self, key, value) @@ -81,22 +69,7 @@ def __init__( tune: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, ) -> None: - params = { - "seq_num": seq_num, - "c_seq_len": c_seq_len, - "heads": heads, - "dim": dim, - "chunk_num": chunk_num, - "group": group, - "scale": scale, - "selected_block_num": selected_block_num, - "bc": bc, - "bs": bs, - "bk": bk, - "dtype": dtype, - "accum_dtype": accum_dtype, - "tune": tune, - } + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} for key, value in params.items(): setattr(self, key, value) @@ -131,20 +104,7 @@ def __init__( tune: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, ) -> None: - params = { - "batch": batch, - "heads": heads, - "c_seq_len": c_seq_len, - "dim": dim, - "is_causal": is_causal, - "scale": scale, - "block_size": block_size, - "groups": groups, - "selected_blocks": selected_blocks, - "dtype": dtype, - "accum_dtype": accum_dtype, - "tune": tune, - } + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} for key, value in params.items(): setattr(self, key, value) @@ -182,23 +142,7 @@ def __init__( tune: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, ) -> None: - params = { - "seq_num": seq_num, - "c_seq_len": c_seq_len, - "heads": heads, - "dim_k": dim_k, - "dim_v": dim_v, - "chunk_num": chunk_num, - "group": group, - "scale": scale, - "bc": bc, - "bs": bs, - "bk": bk, - "bv": bv, - "dtype": dtype, - "accum_dtype": accum_dtype, - "tune": tune, - } + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} for key, value in params.items(): setattr(self, key, value) From a2fb5f6d5bb9977675b64505f6e187870016da2e Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 16:41:30 +0800 Subject: [PATCH 06/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. 
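
Add a varlen GQA attention kernel with optional sliding-window limits
(window_size_left/window_size_right, with -1 meaning unbounded on that
side), checked against flash_attn's flash_attn_varlen_func. Queries are
aligned to the end of the key sequence via offset = kv_len - q_len, and a
key j is visible from query i exactly when it passes the causal and window
tests. A standalone sketch of that visibility rule as used by both the
reference and the kernel masks (helper name is illustrative):

    import torch

    def sliding_window_mask(q_len, kv_len, left, right, causal):
        # True marks a masked (invisible) position, mirroring causal_mask,
        # window_mask_left and window_mask_right in the kernel.
        offset = kv_len - q_len              # align query i with key i+offset
        qi = torch.arange(q_len).unsqueeze(-1) + offset
        kj = torch.arange(kv_len).unsqueeze(0)
        mask = torch.zeros(q_len, kv_len, dtype=torch.bool)
        if causal:
            mask |= kj > qi                  # keys after the query
        if left >= 0:
            mask |= kj < qi - left           # too far in the past
        if right >= 0:
            mask |= kj > qi + right          # too far in the future
        return mask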
--- benchmarks/deepseek_nsa/deepseek_nsa.py | 194 ++++++++++- .../test_deepseek_nsa_gqa_window_sliding.py | 105 ++++++ top/kernels/deepseek_nsa/__init__.py | 7 +- .../deepseek_nsa/gqa_window_sliding.py | 310 ++++++++++++++++++ top/ops/__init__.py | 3 +- top/ops/deepseek_nsa.py | 43 ++- 6 files changed, 658 insertions(+), 4 deletions(-) create mode 100644 tests/ops/test_deepseek_nsa_gqa_window_sliding.py create mode 100644 top/kernels/deepseek_nsa/gqa_window_sliding.py diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index a08becc..a71abce 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -4,7 +4,7 @@ from einops import rearrange, repeat from benchmarks.benchmark import Benchmark -from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp +from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp from .utils import prepare_token_indices, prepare_chunk_offsets @@ -863,3 +863,195 @@ def baseline_profile( warmup=warmup, rep=rep, device=device) + + +class GQAWindowSlidingBenchmark(Benchmark): + op_type = GQAWindowSlidingOp + + def __init__( + self, + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + tune: bool = False, + ) -> None: + self.batch_size = batch_size + self.groups = groups + self.uq = uq + self.ukv = ukv + self.heads = heads + self.dim = dim + self.is_causal = is_causal + self.window_size_left = window_size_left + self.window_size_right = window_size_right + self.dtype = dtype + self.accum_dtype = accum_dtype + self.tune = tune + + @property + def total_flops(self) -> int: + total_flops = 2.0 * self.heads * self.uq * self.ukv * self.dim * 2 + if self.is_causal: + total_flops *= 0.5 + return int(total_flops) + + @property + def total_memory(self) -> int: + head_kv = self.heads // self.groups + q_memory = self.uq * self.heads * self.dim * self.dtype.itemsize + k_memory = self.ukv * head_kv * self.dim * self.dtype.itemsize + v_memory = self.ukv * head_kv * self.dim * self.dtype.itemsize + output_memory = self.uq * self.heads * self.dim * self.dtype.itemsize + return q_memory + k_memory + v_memory + output_memory + + def gen_inputs(self) -> tuple[torch.Tensor, ...]: + rand_indices_q = torch.randperm(self.uq)[:self.batch_size - 1] + cu_seqlens_q = torch.cat( + [torch.tensor([0]), + torch.arange(1, self.uq)[rand_indices_q], + torch.tensor([self.uq])], 0).cuda().sort()[0].to(torch.int32) + rand_indices_k = torch.randperm(self.ukv)[:self.batch_size - 1] + cu_seqlens_k = torch.cat([ + torch.tensor([0]), + torch.arange(1, self.ukv)[rand_indices_k], + torch.tensor([self.ukv]) + ], 0).cuda().sort()[0].to(torch.int32) + + q = torch.randn((self.uq, self.heads, self.dim), dtype=self.dtype, device="cuda") + k = torch.randn((self.ukv, self.heads // self.groups, self.dim), + dtype=self.dtype, + device="cuda") + v = torch.randn((self.ukv, self.heads // self.groups, self.dim), + dtype=self.dtype, + device="cuda") + max_seqlen_q = int((cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()) + return q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q + + def ref_program(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens_q: torch.LongTensor, cu_seqlens_k: torch.LongTensor, + max_seqlen_q: int) -> torch.Tensor: + """PyTorch reference implementation for GQA 
window sliding attention (vectorized)""" + device = q.device + head_kv = self.heads // self.groups + scale = (1.0 / self.dim)**0.5 + has_window = self.window_size_left >= 0 or self.window_size_right >= 0 + + output = torch.zeros((self.uq, self.heads, self.dim), dtype=q.dtype, device=device) + + for batch_idx in range(self.batch_size): + q_start = cu_seqlens_q[batch_idx].item() + q_end = cu_seqlens_q[batch_idx + 1].item() + kv_start = cu_seqlens_k[batch_idx].item() + kv_end = cu_seqlens_k[batch_idx + 1].item() + + q_seqlen = q_end - q_start + kv_seqlen = kv_end - kv_start + + if q_seqlen == 0: + continue + + q_batch = q[q_start:q_end] + k_batch = k[kv_start:kv_end] + v_batch = v[kv_start:kv_end] + + offset = kv_seqlen - q_seqlen + + output_batch = torch.zeros((q_seqlen, self.heads, self.dim), + dtype=q.dtype, + device=device) + + for kv_head_idx in range(head_kv): + head_start = kv_head_idx * self.groups + head_end = head_start + self.groups + + q_group = q_batch[:, head_start:head_end, :] + k_head = k_batch[:, kv_head_idx, :] + v_head = v_batch[:, kv_head_idx, :] + + scores = torch.einsum('qgd,kd->qgk', q_group, k_head) * scale + + q_positions = torch.arange(q_seqlen, device=device, dtype=torch.float32) + kv_positions = torch.arange(kv_seqlen, device=device, dtype=torch.float32) + q_abs_positions = q_positions.unsqueeze(-1) + offset + kv_abs_positions = kv_positions.unsqueeze(0) + + mask = torch.zeros((q_seqlen, kv_seqlen), dtype=torch.bool, device=device) + + if self.is_causal: + causal_mask = (q_positions.unsqueeze(-1) + offset < kv_positions.unsqueeze(0)) + mask = mask | causal_mask + + if has_window: + if self.window_size_left >= 0: + window_left_mask = kv_abs_positions < ( + q_abs_positions - self.window_size_left) + mask = mask | window_left_mask + + if self.window_size_right >= 0: + window_right_mask = kv_abs_positions > ( + q_abs_positions + self.window_size_right) + mask = mask | window_right_mask + + scores = scores.masked_fill(mask.unsqueeze(1), float('-inf')) + + if self.is_causal and offset < 0: + invalid_mask = (q_positions + offset < 0) + scores = scores.masked_fill( + invalid_mask.unsqueeze(-1).unsqueeze(-1), float('-inf')) + + probs = torch.softmax(scores, dim=-1) + + out_group = torch.einsum('qgk,kd->qgd', probs, v_head) + + if self.is_causal and offset < 0: + invalid_positions = (q_positions + offset < 0) + out_group[invalid_positions] = 0 + + output_batch[:, head_start:head_end, :] = out_group + + output[q_start:q_end] = output_batch + + return output + + def baseline_program(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens_q: torch.LongTensor, cu_seqlens_k: torch.LongTensor, + max_seqlen_q: int) -> torch.Tensor: + import flash_attn + max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()) + return flash_attn.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=self.is_causal, + window_size=(self.window_size_left, self.window_size_right) + if self.window_size_left >= 0 or self.window_size_right >= 0 else (-1, -1), + ) + + def baseline_profile( + self, + *inputs: tuple[torch.Tensor, ...], + warmup: int = 100, + rep: int = 100, + device: str = "cuda", + ) -> torch.Tensor: + print("===== Profiling FLA GQA Window Sliding backend =====") + return super().baseline_profile( + self.baseline_program, + *inputs, + backend="FLA GQA Window Sliding", + warmup=warmup, + rep=rep, + device=device) diff --git a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py 
b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py new file mode 100644 index 0000000..f12ee3b --- /dev/null +++ b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py @@ -0,0 +1,105 @@ +"""Test DeepSeek NSA GQA Window Sliding operation.""" + +import pytest +import torch + +from benchmarks.deepseek_nsa.deepseek_nsa import GQAWindowSlidingBenchmark +from top.ops import GQAWindowSlidingOp + + +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + ("batch_size", "groups", "uq", "ukv", "heads", "dim", "is_causal", "window_size_left", + "window_size_right", "dtype", "accum_dtype", "tune"), + [ + (1, 16, 1024, 1024, 64, 128, True, 32, -1, torch.float16, torch.float32, False), + (3, 16, 8192, 8192, 64, 128, True, 2048, 0, torch.float16, torch.float32, False), + (3, 16, 8192, 8192, 64, 128, False, -1, -1, torch.float16, torch.float32, False), + ], +) +def test_nsa_gqa_window_sliding_op( + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + tune: bool, +) -> None: + + assert groups % 16 == 0, "Group size must be a multiple of 16 in NSA" + + params = { + "batch_size": batch_size, + "groups": groups, + "uq": uq, + "ukv": ukv, + "heads": heads, + "dim": dim, + "is_causal": is_causal, + "window_size_left": window_size_left, + "window_size_right": window_size_right, + "dtype": dtype, + "accum_dtype": accum_dtype, + "tune": tune, + } + benchmark = GQAWindowSlidingBenchmark(**params) + op = GQAWindowSlidingOp(**params) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + benchmark.baseline_profile(*inputs) + benchmark.profile(op, *inputs) + + +if __name__ == "__main__": + + test_nsa_gqa_window_sliding_op( + batch_size=1, + groups=16, + uq=1024, + ukv=1024, + heads=64, + dim=128, + is_causal=True, + window_size_left=32, + window_size_right=-1, + dtype=torch.float16, + accum_dtype=torch.float32, + tune=False) + test_nsa_gqa_window_sliding_op( + batch_size=3, + groups=16, + uq=8192, + ukv=8192, + heads=64, + dim=128, + is_causal=True, + window_size_left=2048, + window_size_right=0, + dtype=torch.float16, + accum_dtype=torch.float32, + tune=False) + test_nsa_gqa_window_sliding_op( + batch_size=3, + groups=16, + uq=8192, + ukv=8192, + heads=64, + dim=128, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + dtype=torch.float16, + accum_dtype=torch.float32, + tune=False) diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py index 820ed3b..5a16357 100644 --- a/top/kernels/deepseek_nsa/__init__.py +++ b/top/kernels/deepseek_nsa/__init__.py @@ -2,7 +2,12 @@ from .nsa_fwd import NSAFwdVarlenKernel from .nsa_topk import NSATopkVarlenKernel from .nsa_cmp_fwd import NSACmpFwdVarlenKernel +from .gqa_window_sliding import GQAWindowSlidingKernel __all__ = [ - "MeanPoolingFwdKernel", "NSAFwdVarlenKernel", "NSATopkVarlenKernel", "NSACmpFwdVarlenKernel" + "MeanPoolingFwdKernel", + "NSAFwdVarlenKernel", + "NSATopkVarlenKernel", + "NSACmpFwdVarlenKernel", + "GQAWindowSlidingKernel", ] diff --git a/top/kernels/deepseek_nsa/gqa_window_sliding.py b/top/kernels/deepseek_nsa/gqa_window_sliding.py new file mode 100644 index 0000000..3f02bfb --- /dev/null +++ b/top/kernels/deepseek_nsa/gqa_window_sliding.py @@ -0,0 +1,310 @@ +import torch +from typing import Optional, Any, Callable +import itertools +import tilelang +from 
tilelang import language as T + +from top.kernels.kernel import Kernel + + +def _gqa_window_sliding_kernel( + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: str, + accum_dtype: str, +) -> Callable: + scale = (1.0 / dim)**0.5 * 1.44269504 + head_kv = heads // groups + has_window = window_size_left >= 0 or window_size_right >= 0 + + @tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def _gqa_window_sliding_func(block_m: int, block_n: int, num_stages: int, threads: int): + q_shape = [uq, heads, dim] + kv_shape = [ukv, head_kv, dim] + o_shape = [uq, heads, dim] + + @T.prim_func + def _parallel_gqa_window_sliding_main( + q_unpad: T.Tensor(q_shape, dtype), + k_unpad: T.Tensor(kv_shape, dtype), + v_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel( + T.ceildiv(max_seqlen_q, block_m), heads, batch_size, + threads=threads) as (bx, by, bz): + q_shared = T.alloc_shared([block_m, dim], dtype) + k_shared = T.alloc_shared([block_n, dim], dtype) + v_shared = T.alloc_shared([block_n, dim], dtype) + o_shared = T.alloc_shared([block_m, dim], dtype) + acc_s = T.alloc_fragment([block_m, block_n], accum_dtype) + acc_s_cast = T.alloc_fragment([block_m, block_n], dtype) + acc_o = T.alloc_fragment([block_m, dim], accum_dtype) + scores_max = T.alloc_fragment([block_m], accum_dtype) + scores_max_prev = T.alloc_fragment([block_m], accum_dtype) + scores_scale = T.alloc_fragment([block_m], accum_dtype) + scores_sum = T.alloc_fragment([block_m], accum_dtype) + logsum = T.alloc_fragment([block_m], accum_dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy( + q_unpad[q_start_idx + bx * block_m:q_start_idx + (bx + 1) * block_m, + head_idx, :], q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + offset = kv_current_seqlen - q_current_seqlen + + if is_causal: + max_visible_k_idx = offset + (bx + 1) * block_m + if has_window and window_size_left >= 0: + loop_range = T.min( + T.ceildiv(max_visible_k_idx, block_n), + T.ceildiv(kv_current_seqlen, block_n)) + else: + loop_range = T.min( + T.ceildiv(max_visible_k_idx, block_n), + T.ceildiv(kv_current_seqlen, block_n)) + else: + loop_range = T.ceildiv(kv_current_seqlen, block_n) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + k_unpad[kv_start_idx + k * block_n:kv_start_idx + (k + 1) * block_n, + kv_head_idx, :], k_shared) + + if is_causal: + for i, j in T.Parallel(block_m, block_n): + causal_mask = (bx * block_m + i + offset < k * block_n + j) + + window_mask_left = T.if_then_else( + has_window and window_size_left >= 0, (k * block_n + j) + < (bx * block_m + i + offset - window_size_left), (k * block_n + j) + < -1) + + boundary_mask = ( + bx * block_m + i >= q_current_seqlen or + k * block_n + j >= kv_current_seqlen) + + acc_s[i, j] = T.if_then_else( + causal_mask or window_mask_left or boundary_mask, + -1e9, + 0, + ) + else: + for i, j in T.Parallel(block_m, block_n): + 
window_mask_left = T.if_then_else( + has_window and window_size_left >= 0, (k * block_n + j) + < (bx * block_m + i + offset - window_size_left), (k * block_n + j) + < -1) + window_mask_right = T.if_then_else( + has_window and window_size_right >= 0, (k * block_n + j) + > (bx * block_m + i + offset + window_size_right), (k * block_n + j) + < -1) + + boundary_mask = ( + bx * block_m + i >= q_current_seqlen or + k * block_n + j >= kv_current_seqlen) + + acc_s[i, j] = T.if_then_else( + window_mask_left or window_mask_right or boundary_mask, + -1e9, + 0, + ) + + T.gemm( + q_shared, + k_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_m): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_m): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_m, block_n): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_m): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_m, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy( + v_unpad[kv_start_idx + k * block_n:kv_start_idx + (k + 1) * block_n, + kv_head_idx, :], v_shared) + + T.gemm(acc_s_cast, v_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_m, dim): + acc_o[i, j] = 0 if is_causal and bx * block_m + i + offset < 0 else acc_o[ + i, j] / logsum[i] + + T.copy(acc_o, o_shared) + for i, d in T.Parallel(block_m, dim): + if bx * block_m + i < q_current_seqlen: + output_unpad[q_start_idx + bx * block_m + i, head_idx, d] = o_shared[i, d] + + return _parallel_gqa_window_sliding_main + + return _gqa_window_sliding_func + + +@torch.library.custom_op("top::gqa_window_sliding_wrapped_kernel", mutates_args=()) +def _gqa_window_sliding_wrapped_kernel( + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: str, + accum_dtype: str, + block_m: int, + block_n: int, + num_stages: int, + threads: int, + q_unpad: torch.Tensor, + k_unpad: torch.Tensor, + v_unpad: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, +) -> torch.Tensor: + return _gqa_window_sliding_kernel(batch_size, groups, uq, ukv, heads, dim, is_causal, + window_size_left, window_size_right, dtype, + accum_dtype)(block_m, block_n, num_stages, + threads)(q_unpad, k_unpad, v_unpad, cu_seqlens_q, + cu_seqlens_k, max_seqlen_q) + + +@_gqa_window_sliding_wrapped_kernel.register_fake +def _( + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: str, + accum_dtype: str, + block_m: int, + block_n: int, + num_stages: int, + threads: int, + *inputs: tuple[Any], +) -> torch.Tensor: + _ = (batch_size, groups, uq, ukv, heads, dim, is_causal, window_size_left, window_size_right, + dtype, accum_dtype, block_m, block_n, num_stages, threads) + return torch.empty([uq, heads, dim], dtype=inputs[0].dtype, device=inputs[0].device) + + +class GQAWindowSlidingKernel(Kernel): + supported_archs: list[int] = [90] + + def __init__(self, + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + 
is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.batch_size = batch_size + self.groups = groups + self.uq = uq + self.ukv = ukv + self.heads = heads + self.dim = dim + self.is_causal = is_causal + self.window_size_left = window_size_left + self.window_size_right = window_size_right + self.dtype = dtype + self.accum_dtype = accum_dtype + self.dtype_name = str(dtype).split('.')[-1] + self.accum_dtype_name = str(accum_dtype).split('.')[-1] + + self.kernel = _gqa_window_sliding_kernel(self.batch_size, self.groups, self.uq, self.ukv, + self.heads, self.dim, self.is_causal, + self.window_size_left, self.window_size_right, + self.dtype_name, self.accum_dtype_name) + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_m": 128, + "block_n": 128, + "num_stages": 2, + "threads": 256, + } + + @property + def autotune_configs(self) -> list[dict]: + block_m = [64, 128] + block_n = [64, 128] + num_stages = [1] + threads = [128] + _configs = list(itertools.product(block_m, block_n, num_stages, threads)) + return [{ + "block_m": c[0], + "block_n": c[1], + "num_stages": c[2], + "threads": c[3] + } for c in _configs] + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, max_seqlen_q: int) -> torch.Tensor: + return _gqa_window_sliding_wrapped_kernel( + self.batch_size, self.groups, self.uq, self.ukv, self.heads, self.dim, self.is_causal, + self.window_size_left, self.window_size_right, self.dtype_name, self.accum_dtype_name, + self.config["block_m"], self.config["block_n"], self.config["num_stages"], + self.config["threads"], q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 83171b5..9791bc9 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -3,7 +3,7 @@ from .topk_selector import TopkSelectorOp from .fp8_quant import Fp8QuantOp from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp -from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp +from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp from .gemm import GemmOp from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp @@ -36,6 +36,7 @@ "NSATopkVarlenOp", "NSAFwdVarlenOp", "NSACmpFwdVarlenOp", + "GQAWindowSlidingOp", "ManifoldConstrainedHyperConnectionPreOp", "ManifoldConstrainedHyperConnectionPostOp", ] diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 584b9c3..447450a 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -6,10 +6,17 @@ from top.kernels.deepseek_nsa.nsa_fwd import NSAFwdVarlenKernel from top.kernels.deepseek_nsa.nsa_topk import NSATopkVarlenKernel from top.kernels.deepseek_nsa.nsa_cmp_fwd import NSACmpFwdVarlenKernel +from top.kernels.deepseek_nsa.gqa_window_sliding import GQAWindowSlidingKernel from top.kernels.kernel import Kernel from top.ops.op import Op -__all__ = ["MeanPoolingForwardOp", "NSAFwdVarlenOp", "NSATopkVarlenOp", "NSACmpFwdVarlenOp"] +__all__ = [ + "MeanPoolingForwardOp", + "NSAFwdVarlenOp", + "NSATopkVarlenOp", + "NSACmpFwdVarlenOp", + "GQAWindowSlidingOp", +] class MeanPoolingForwardOp(Op): @@ -157,3 +164,37 @@ def forward(self, q: 
torch.Tensor, k_cmp: torch.Tensor, v_cmp: torch.Tensor, offsets: torch.Tensor, chunk_offsets: torch.Tensor, token_indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return self.kernel(q, k_cmp, v_cmp, offsets, chunk_offsets, token_indices) + + +class GQAWindowSlidingOp(Op): + + def __init__( + self, + batch_size: int, + groups: int, + uq: int, + ukv: int, + heads: int, + dim: int, + is_causal: bool, + window_size_left: int, + window_size_right: int, + dtype: torch.dtype, + accum_dtype: torch.dtype, + tune: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None, + ) -> None: + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} + for key, value in params.items(): + setattr(self, key, value) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["gqa_window_sliding_kernel"](**params) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"gqa_window_sliding_kernel": GQAWindowSlidingKernel} + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, max_seqlen_q: int) -> torch.Tensor: + return self.kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) From 3e159a4b25efbee252a6e8782011c6fe6cd5f275 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 16:49:38 +0800 Subject: [PATCH 07/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. --- benchmarks/deepseek_nsa/deepseek_nsa.py | 4 ---- top/ops/__init__.py | 3 --- 2 files changed, 7 deletions(-) diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index 80f9561..a71abce 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -4,11 +4,7 @@ from einops import rearrange, repeat from benchmarks.benchmark import Benchmark -<<<<<<< HEAD from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp -======= -from top.ops import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp ->>>>>>> main from .utils import prepare_token_indices, prepare_chunk_offsets diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 08760ae..9791bc9 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -36,10 +36,7 @@ "NSATopkVarlenOp", "NSAFwdVarlenOp", "NSACmpFwdVarlenOp", -<<<<<<< HEAD "GQAWindowSlidingOp", -======= ->>>>>>> main "ManifoldConstrainedHyperConnectionPreOp", "ManifoldConstrainedHyperConnectionPostOp", ] From 77f08bd3bb5bdb198d9dd88935222851b0dd61bf Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 17:26:45 +0800 Subject: [PATCH 08/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. 
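For orientation, the visibility rule implemented by both the PyTorch baseline and the TileLang kernel above can be stated as a single predicate. A minimal standalone sketch, not part of the patch: `offset = kv_len - q_len` aligns the causal diagonal with the end of the key sequence, and the `(left, right)` window convention follows flash-attn's `window_size`, where -1 disables that side.

    def is_visible(q_pos: int, k_pos: int, q_len: int, kv_len: int,
                   is_causal: bool, window_left: int, window_right: int) -> bool:
        # Absolute position of this query row on the key axis.
        q_abs = q_pos + (kv_len - q_len)
        if is_causal and k_pos > q_abs:
            return False  # future key
        if window_left >= 0 and k_pos < q_abs - window_left:
            return False  # beyond the left window edge
        if window_right >= 0 and k_pos > q_abs + window_right:
            return False  # beyond the right window edge
        return True

Rows with q_abs < 0 (possible under causal masking when kv_len < q_len) see no keys at all, which is why both implementations zero those output rows rather than normalize an empty softmax.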
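The kernel's inner loop is the standard online-softmax (flash-attention) recurrence, and the constant 1.44269504 folded into `scale` is log2(e), which lets exp() be replaced by the cheaper exp2(). A hedged NumPy sketch of one K/V-tile update under that convention (illustrative names, not the kernel's API):

    import numpy as np

    def online_softmax_step(acc_o, logsum, m_prev, s_tile, v_tile, scale):
        # scale = log2(e) / sqrt(dim); masked scores arrive as -1e9 and underflow to 0.
        m_new = np.maximum(m_prev, s_tile.max(axis=1))          # running row max
        alpha = np.exp2((m_prev - m_new) * scale)               # rescale old state
        p = np.exp2(s_tile * scale - (m_new * scale)[:, None])  # tile probabilities
        logsum = logsum * alpha + p.sum(axis=1)
        acc_o = acc_o * alpha[:, None] + p @ v_tile
        return acc_o, logsum, m_new

    # After the last tile: output = acc_o / logsum[:, None]

This mirrors the kernel's sequence of scores_scale, exp2, reduce_sum, and the rescale of acc_o before the PV gemm.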
--- .../test_deepseek_nsa_gqa_window_sliding.py | 2 -- top/ops/deepseek_nsa.py | 34 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py index f12ee3b..fd5828a 100644 --- a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py +++ b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py @@ -58,8 +58,6 @@ def test_nsa_gqa_window_sliding_op( inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) - benchmark.baseline_profile(*inputs) - benchmark.profile(op, *inputs) if __name__ == "__main__": diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 684730c..ef16c20 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -213,4 +213,38 @@ def default_kernel_map(self) -> Dict[str, Kernel]: def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int) -> torch.Tensor: + # Security validation: prevent OOB writes by validating input tensors + # 1. Check tensor shapes + assert cu_seqlens_q.shape[0] == self.batch_size + 1, \ + f"cu_seqlens_q.shape[0] ({cu_seqlens_q.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" + assert cu_seqlens_k.shape[0] == self.batch_size + 1, \ + f"cu_seqlens_k.shape[0] ({cu_seqlens_k.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" + + # 2. Check that values are non-decreasing + cu_seqlens_q_diff = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + assert torch.all(cu_seqlens_q_diff >= 0), \ + "cu_seqlens_q must be non-decreasing" + cu_seqlens_k_diff = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + assert torch.all(cu_seqlens_k_diff >= 0), \ + "cu_seqlens_k must be non-decreasing" + + # 3. Check that maximum values don't exceed tensor dimensions + max_q_idx = cu_seqlens_q[-1].item() + assert max_q_idx <= self.uq, \ + f"cu_seqlens_q[-1] ({max_q_idx}) must not exceed uq ({self.uq})" + max_kv_idx = cu_seqlens_k[-1].item() + assert max_kv_idx <= self.ukv, \ + f"cu_seqlens_k[-1] ({max_kv_idx}) must not exceed ukv ({self.ukv})" + + # 4. Check that max_seqlen_q is consistent with actual maximum sequence length + actual_max_seqlen_q = cu_seqlens_q_diff.max().item() + assert max_seqlen_q >= actual_max_seqlen_q, \ + f"max_seqlen_q ({max_seqlen_q}) must be >= actual max sequence length ({actual_max_seqlen_q})" + + # 5. Additional safety: ensure cu_seqlens_q starts at 0 + assert cu_seqlens_q[0].item() == 0, \ + f"cu_seqlens_q[0] must be 0, got {cu_seqlens_q[0].item()}" + assert cu_seqlens_k[0].item() == 0, \ + f"cu_seqlens_k[0] must be 0, got {cu_seqlens_k[0].item()}" + return self.kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) From 5b788fe569317a822444e694e98cad3d2bbf59c1 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 17:45:15 +0800 Subject: [PATCH 09/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. 
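Beyond deduplicating the two identical branches, the simplified `loop_range` in the hunk below has a direct reading: the last row of query tile `bx` sits at absolute key position `offset + (bx + 1) * block_m - 1`, so no later K/V tile can be visible under causal masking. A worked check in plain Python, for illustration only:

    import math

    def num_kv_tiles(bx, block_m, block_n, q_len, kv_len, is_causal=True):
        if not is_causal:
            return math.ceil(kv_len / block_n)
        offset = kv_len - q_len
        max_visible = offset + (bx + 1) * block_m  # one past the last visible key
        return min(math.ceil(max_visible / block_n), math.ceil(kv_len / block_n))

    # q_len = kv_len = 8192, block_m = block_n = 128:
    # num_kv_tiles(0, 128, 128, 8192, 8192) == 1
    # num_kv_tiles(63, 128, 128, 8192, 8192) == 64

Note the loop still visits tiles that lie entirely left of a sliding window and masks them out; skipping them would require a matching lower bound on the loop index.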
--- .../ops/test_deepseek_nsa_gqa_window_sliding.py | 2 ++ top/kernels/deepseek_nsa/gqa_window_sliding.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py index fd5828a..f12ee3b 100644 --- a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py +++ b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py @@ -58,6 +58,8 @@ def test_nsa_gqa_window_sliding_op( inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) + benchmark.baseline_profile(*inputs) + benchmark.profile(op, *inputs) if __name__ == "__main__": diff --git a/top/kernels/deepseek_nsa/gqa_window_sliding.py b/top/kernels/deepseek_nsa/gqa_window_sliding.py index 3f02bfb..6236406 100644 --- a/top/kernels/deepseek_nsa/gqa_window_sliding.py +++ b/top/kernels/deepseek_nsa/gqa_window_sliding.py @@ -85,14 +85,9 @@ def _parallel_gqa_window_sliding_main( if is_causal: max_visible_k_idx = offset + (bx + 1) * block_m - if has_window and window_size_left >= 0: - loop_range = T.min( - T.ceildiv(max_visible_k_idx, block_n), - T.ceildiv(kv_current_seqlen, block_n)) - else: - loop_range = T.min( - T.ceildiv(max_visible_k_idx, block_n), - T.ceildiv(kv_current_seqlen, block_n)) + loop_range = T.min( + T.ceildiv(max_visible_k_idx, block_n), + T.ceildiv(kv_current_seqlen, block_n)) else: loop_range = T.ceildiv(kv_current_seqlen, block_n) @@ -177,8 +172,10 @@ def _parallel_gqa_window_sliding_main( T.copy(acc_o, o_shared) for i, d in T.Parallel(block_m, dim): - if bx * block_m + i < q_current_seqlen: - output_unpad[q_start_idx + bx * block_m + i, head_idx, d] = o_shared[i, d] + q_pos = bx * block_m + i + output_idx = q_start_idx + q_pos + if q_pos < q_current_seqlen and output_idx < q_end_idx: + output_unpad[output_idx, head_idx, d] = o_shared[i, d] return _parallel_gqa_window_sliding_main From 7b9f264090255d79d83564e09827a357246d3f45 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 17:53:26 +0800 Subject: [PATCH 10/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. --- top/ops/deepseek_nsa.py | 45 ++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index ef16c20..1b69bb4 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -214,37 +214,46 @@ def default_kernel_map(self) -> Dict[str, Kernel]: def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int) -> torch.Tensor: # Security validation: prevent OOB writes by validating input tensors + # Using explicit if statements instead of assert to ensure validation + # is always performed, even with Python -O optimization flag + # 1. Check tensor shapes - assert cu_seqlens_q.shape[0] == self.batch_size + 1, \ - f"cu_seqlens_q.shape[0] ({cu_seqlens_q.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" - assert cu_seqlens_k.shape[0] == self.batch_size + 1, \ - f"cu_seqlens_k.shape[0] ({cu_seqlens_k.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" + if cu_seqlens_q.shape[0] != self.batch_size + 1: + raise ValueError( + f"cu_seqlens_q.shape[0] ({cu_seqlens_q.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" + ) + if cu_seqlens_k.shape[0] != self.batch_size + 1: + raise ValueError( + f"cu_seqlens_k.shape[0] ({cu_seqlens_k.shape[0]}) must equal batch_size + 1 ({self.batch_size + 1})" + ) # 2. 
Check that values are non-decreasing cu_seqlens_q_diff = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - assert torch.all(cu_seqlens_q_diff >= 0), \ - "cu_seqlens_q must be non-decreasing" + if not torch.all(cu_seqlens_q_diff >= 0): + raise ValueError("cu_seqlens_q must be non-decreasing") cu_seqlens_k_diff = cu_seqlens_k[1:] - cu_seqlens_k[:-1] - assert torch.all(cu_seqlens_k_diff >= 0), \ - "cu_seqlens_k must be non-decreasing" + if not torch.all(cu_seqlens_k_diff >= 0): + raise ValueError("cu_seqlens_k must be non-decreasing") # 3. Check that maximum values don't exceed tensor dimensions max_q_idx = cu_seqlens_q[-1].item() - assert max_q_idx <= self.uq, \ - f"cu_seqlens_q[-1] ({max_q_idx}) must not exceed uq ({self.uq})" + if max_q_idx > self.uq: + raise ValueError(f"cu_seqlens_q[-1] ({max_q_idx}) must not exceed uq ({self.uq})") max_kv_idx = cu_seqlens_k[-1].item() - assert max_kv_idx <= self.ukv, \ - f"cu_seqlens_k[-1] ({max_kv_idx}) must not exceed ukv ({self.ukv})" + if max_kv_idx > self.ukv: + raise ValueError(f"cu_seqlens_k[-1] ({max_kv_idx}) must not exceed ukv ({self.ukv})") # 4. Check that max_seqlen_q is consistent with actual maximum sequence length actual_max_seqlen_q = cu_seqlens_q_diff.max().item() - assert max_seqlen_q >= actual_max_seqlen_q, \ - f"max_seqlen_q ({max_seqlen_q}) must be >= actual max sequence length ({actual_max_seqlen_q})" + if max_seqlen_q < actual_max_seqlen_q: + raise ValueError( + f"max_seqlen_q ({max_seqlen_q}) must be >= actual max sequence length ({actual_max_seqlen_q})" + ) # 5. Additional safety: ensure cu_seqlens_q starts at 0 - assert cu_seqlens_q[0].item() == 0, \ - f"cu_seqlens_q[0] must be 0, got {cu_seqlens_q[0].item()}" - assert cu_seqlens_k[0].item() == 0, \ - f"cu_seqlens_k[0] must be 0, got {cu_seqlens_k[0].item()}" + if cu_seqlens_q[0].item() != 0: + raise ValueError(f"cu_seqlens_q[0] must be 0, got {cu_seqlens_q[0].item()}") + if cu_seqlens_k[0].item() != 0: + raise ValueError(f"cu_seqlens_k[0] must be 0, got {cu_seqlens_k[0].item()}") return self.kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) From d12ba27b02bd888e7ba1dddcc96e4afcbdd8dca7 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Thu, 5 Feb 2026 17:57:29 +0800 Subject: [PATCH 11/16] [Feat][NSA] Implement a GQA forward kernel with sliding window. --- tests/ops/test_deepseek_nsa_gqa_window_sliding.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py index f12ee3b..fd5828a 100644 --- a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py +++ b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py @@ -58,8 +58,6 @@ def test_nsa_gqa_window_sliding_op( inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) - benchmark.baseline_profile(*inputs) - benchmark.profile(op, *inputs) if __name__ == "__main__": From dcf2f409d030ea39e896f943569beda952355b4d Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Mon, 9 Feb 2026 11:41:48 +0800 Subject: [PATCH 12/16] [Fix][NSA] Fixed a bug in pytest parameter passing in . 
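The validation above pins down the expected `cu_seqlens` layout: an int32 prefix sum that starts at 0, is non-decreasing, and ends at the total token count. A minimal sketch of constructing valid offsets from per-sequence lengths (hypothetical helper, shown for illustration):

    import torch

    def make_cu_seqlens(seqlens: list, device: str = "cuda") -> torch.Tensor:
        # [0, l0, l0 + l1, ...]; the final entry is the total token count.
        lens = torch.tensor([0] + list(seqlens), device=device)
        return torch.cumsum(lens, dim=0).to(torch.int32)

    # make_cu_seqlens([1024, 512, 2048]) -> tensor([0, 1024, 1536, 3584], ...)
    # and max_seqlen_q must be at least max(seqlens) = 2048.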
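The assert-to-raise conversion above is not stylistic: CPython compiles `assert` statements away under `python -O` (when `__debug__` is false), so assert-based bounds checks disappear exactly where out-of-bounds writes would matter. A quick illustration:

    # $ python -c  "assert False"    -> AssertionError
    # $ python -O -c "assert False"  -> exits cleanly; the assert was stripped
    def checked_get(xs: list, i: int):
        if not 0 <= i < len(xs):  # an explicit raise survives -O
            raise IndexError(f"index {i} out of range for length {len(xs)}")
        return xs[i]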
--- tests/ops/test_deepseek_nsa_cmp_fwd.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index 44fff4d..3faa2bb 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -1,3 +1,4 @@ +import inspect import pytest import torch @@ -37,8 +38,14 @@ def test_nsa_cmp_fwd_varlen_op( assert group % 16 == 0, "Group size must be a multiple of 16 in NSA" - # Use locals() to create params dictionary from function arguments - params = locals().copy() + # Create params dictionary from function arguments, excluding pytest internals + # Filter out any keys that start with '@' (pytest internal variables) + sig = inspect.signature(test_nsa_cmp_fwd_varlen_op) + params = { + name: locals()[name] + for name in sig.parameters.keys() + if not name.startswith('@') + } benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs() From 81ade8abed1257bb4ebd0c764223081ac0e0cd18 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Mon, 9 Feb 2026 11:49:46 +0800 Subject: [PATCH 13/16] [Fix][NSA] Fixed a bug in pytest parameter passing in . --- tests/ops/test_deepseek_nsa_cmp_fwd.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index 9ff501d..ac4a955 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -41,11 +41,7 @@ def test_nsa_cmp_fwd_varlen_op( # Create params dictionary from function arguments, excluding pytest internals # Filter out any keys that start with '@' (pytest internal variables) sig = inspect.signature(test_nsa_cmp_fwd_varlen_op) - params = { - name: locals()[name] - for name in sig.parameters.keys() - if not name.startswith('@') - } + params = {name: locals()[name] for name in sig.parameters.keys() if not name.startswith('@')} benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs() From 4ce06c0658d5db613658713d4e1b26701bfb427f Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 10 Feb 2026 11:00:26 +0800 Subject: [PATCH 14/16] [Fix][NSA] Fixed a bug in pytest parameter passing in . --- tests/ops/test_deepseek_nsa_cmp_fwd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index 553a716..815c068 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -35,8 +35,10 @@ def test_nsa_cmp_fwd_varlen_op( # Create params dictionary from function arguments, excluding pytest internals # Filter out any keys that start with '@' (pytest internal variables) + # Note: Need to capture locals() before list comprehension due to scope issues + local_vars = locals() sig = inspect.signature(test_nsa_cmp_fwd_varlen_op) - params = {name: locals()[name] for name in sig.parameters.keys() if not name.startswith('@')} + params = {name: local_vars[name] for name in sig.parameters.keys() if not name.startswith('@')} benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs() From 9cda9e19cb42d065d552eeb8e6f5013c048ee159 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 10 Feb 2026 11:19:51 +0800 Subject: [PATCH 15/16] [Fix][NSA] Fixed a bug in pytest parameter passing in . 
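The "scope issues" comment above refers to a Python detail: on CPython up to 3.11, comprehensions execute in their own frame, so `locals()` evaluated inside one returns the comprehension's namespace (essentially just the loop variable), not the enclosing function's arguments (PEP 709 in 3.12 inlines comprehensions and changes what `locals()` reports). A minimal reproduction on 3.11:

    def demo(a=1, b=2):
        names = ["a", "b"]
        # {n: locals()[n] for n in names}  -> KeyError: 'a' on CPython <= 3.11
        captured = locals()  # capture the function's frame first
        return {n: captured[n] for n in names}

    assert demo() == {"a": 1, "b": 2}

Iterating over `sig.parameters` also makes the `'@'` filter redundant, since pytest's rewritten-assertion locals (`@py_assert0`, ...) are never among a function's declared parameters; the hunk below removes it on that basis.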
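Patch 16 below avoids hardcoding the test's own name by resolving it through `inspect.stack()`. If the pattern spreads to other tests, it could be factored into a small frame-based helper; a hedged sketch (hypothetical `capture_args`, CPython-specific since `inspect.currentframe()` may return None on other implementations):

    import inspect

    def capture_args(exclude: tuple = ()) -> dict:
        # Collect the calling function's declared arguments by name.
        frame = inspect.currentframe().f_back
        info = inspect.getargvalues(frame)  # ArgInfo(args, varargs, keywords, locals)
        return {name: info.locals[name] for name in info.args if name not in exclude}

    def f(x=1, y=2):
        return capture_args()

    assert f() == {"x": 1, "y": 2}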
--- tests/ops/test_deepseek_nsa_cmp_fwd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index 815c068..b3d96f4 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -33,12 +33,12 @@ def test_nsa_cmp_fwd_varlen_op( assert group % 16 == 0, "Group size must be a multiple of 16 in NSA" - # Create params dictionary from function arguments, excluding pytest internals - # Filter out any keys that start with '@' (pytest internal variables) + # Create params dictionary from function arguments using the function signature + # to avoid including pytest-injected local variables. # Note: Need to capture locals() before list comprehension due to scope issues local_vars = locals() sig = inspect.signature(test_nsa_cmp_fwd_varlen_op) - params = {name: local_vars[name] for name in sig.parameters.keys() if not name.startswith('@')} + params = {name: local_vars[name] for name in sig.parameters} benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs() From 056b03d6785a128591ba94fca7c331dc79176c2a Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 10 Feb 2026 11:29:49 +0800 Subject: [PATCH 16/16] [Fix][NSA] Fixed a bug in pytest parameter passing in . --- tests/ops/test_deepseek_nsa_cmp_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index b3d96f4..7fe96b3 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -37,7 +37,7 @@ def test_nsa_cmp_fwd_varlen_op( # to avoid including pytest-injected local variables. # Note: Need to capture locals() before list comprehension due to scope issues local_vars = locals() - sig = inspect.signature(test_nsa_cmp_fwd_varlen_op) + sig = inspect.signature(globals()[inspect.stack()[0].function]) params = {name: local_vars[name] for name in sig.parameters} benchmark = NSACmpFwdVarlenBenchmark(**params) inputs = benchmark.gen_inputs()