From 3476b96872d627ba95b04bbb953742d10f1ed2a6 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Fri, 26 Dec 2025 17:21:26 +0800 Subject: [PATCH 01/14] nsa kernel --- benchmarks/__init__.py | 3 ++- benchmarks/profile/profile_run.py | 25 +++++++++++++++++++++++++ top/functions/__init__.py | 3 ++- top/ops/__init__.py | 2 ++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 66315e63..7a0aed7d 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -1,11 +1,12 @@ from .benchmark import Benchmark # noqa: F401 +from .deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark from .flash_attn import MultiHeadAttentionBenchmark, MultiHeadAttentionBwdBenchmark, MultiHeadAttentionFwdBenchmark, GroupQueryAttentionBenchmark, GroupQueryAttentionFwdBenchmark, GroupQueryAttentionBwdBenchmark from .gemm import GemmBenchmark, MatMulBenchmark from .flash_decode import MultiHeadAttentionDecodeBenchmark, GroupQueryAttentionDecodeBenchmark from .deepseek_mla import MultiHeadLatentAttentionDecodeBenchmark, DeepSeekSparseAttentionDecodeBenchmark - __all__ = [ 'Benchmark', + 'NativeSparseAttentionForwardBenchmark', 'MultiHeadAttentionBenchmark', 'MultiHeadAttentionBwdBenchmark', 'MultiHeadAttentionFwdBenchmark', diff --git a/benchmarks/profile/profile_run.py b/benchmarks/profile/profile_run.py index b0f00c17..521f8dbc 100644 --- a/benchmarks/profile/profile_run.py +++ b/benchmarks/profile/profile_run.py @@ -90,6 +90,29 @@ def build_gqa_decode_cmd(args_dict): return cmd_args +def build_nsa_cmd(args_dict): + """ + Build command arguments for Native Sparse Attention test script + """ + cmd_args = [ + '--batch', + str(args_dict['batch']), '--heads', + str(args_dict['heads']), '--seq_len', + str(args_dict['seq_len']), '--dim', + str(args_dict['dim']), '--scale', + str(args_dict.get('scale', 0.1)), '--block_size', + str(args_dict['block_size']), '--groups', + str(args_dict['groups']), 
'--selected_blocks', + str(args_dict['selected_blocks']), + ] + + if args_dict.get('is_causal', 'True').lower() == 'true': + cmd_args.append('--is_causal') + if args_dict.get('tune', 'False').lower() == 'true': + cmd_args.append('--tune') + return cmd_args + + def build_mla_decode_cmd(args_dict): """ Build command arguments for MLA decode test script @@ -196,6 +219,8 @@ def run_test_script(script_path, args_dict): cmd_args = build_mla_decode_cmd(args_dict) elif 'sparse_mla' in script_name: cmd_args = build_sparse_mla_cmd(args_dict) + elif 'deepseek_nsa' in script_name or 'nsa' in script_name: + cmd_args = build_nsa_cmd(args_dict) elif 'mha' in script_name: cmd_args = build_mha_cmd(args_dict) elif 'gqa' in script_name: diff --git a/top/functions/__init__.py b/top/functions/__init__.py index 68370615..728b5339 100644 --- a/top/functions/__init__.py +++ b/top/functions/__init__.py @@ -6,7 +6,7 @@ from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheFunc from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheFunc from .matmul import MatMulFunc - +from .deepseek_nsa import NativeSparseAttentionFunc __all__ = [ "Function", "MultiHeadAttentionFunc", @@ -16,4 +16,5 @@ "MultiHeadLatentAttentionDecodeWithKVCacheFunc", "DeepSeekSparseAttentionDecodeWithKVCacheFunc", "MatMulFunc", + "NativeSparseAttentionFunc", ] diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 61d96403..cc49f188 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -6,6 +6,7 @@ from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp +from .deepseek_nsa import NativeSparseAttentionForwardOp __all__ = [ "Op", @@ -18,4 +19,5 @@ "GroupQueryAttentionDecodeWithKVCacheOp", "MultiHeadLatentAttentionDecodeWithKVCacheOp", "DeepSeekSparseAttentionDecodeWithKVCacheOp", + "NativeSparseAttentionForwardOp", ] 
From de259cdeabf5b486af9197023e3b5b6f3785496b Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Fri, 26 Dec 2025 23:22:53 +0800 Subject: [PATCH 02/14] add deepseek nsa --- benchmarks/deepseek_nsa/__init__.py | 5 + benchmarks/deepseek_nsa/deepseek_nsa.py | 81 +++++ benchmarks/input_params/deepseek_nsa.csv | 2 + benchmarks/profile_run2.sh | 78 +++++ test_tileops.py | 34 ++ tests/functions/test_deepseek_nsa_func.py | 42 +++ tests/layers/test_deepseek_nsa_layer.py | 42 +++ tests/ops/test_deepseek_nsa_ops.py | 53 +++ top/functions/deepseek_nsa.py | 116 +++++++ top/kernels/deepseek_nsa/__init__.py | 2 + top/kernels/deepseek_nsa/nsa_fwd.py | 319 ++++++++++++++++++ top/kernels/deepseek_nsa/nsa_torch.py | 380 ++++++++++++++++++++++ top/layers/deepseek_nsa.py | 100 ++++++ top/ops/deepseek_nsa.py | 109 +++++++ 14 files changed, 1363 insertions(+) create mode 100644 benchmarks/deepseek_nsa/__init__.py create mode 100644 benchmarks/deepseek_nsa/deepseek_nsa.py create mode 100644 benchmarks/input_params/deepseek_nsa.csv create mode 100644 benchmarks/profile_run2.sh create mode 100644 test_tileops.py create mode 100644 tests/functions/test_deepseek_nsa_func.py create mode 100644 tests/layers/test_deepseek_nsa_layer.py create mode 100644 tests/ops/test_deepseek_nsa_ops.py create mode 100644 top/functions/deepseek_nsa.py create mode 100644 top/kernels/deepseek_nsa/__init__.py create mode 100644 top/kernels/deepseek_nsa/nsa_fwd.py create mode 100644 top/kernels/deepseek_nsa/nsa_torch.py create mode 100644 top/layers/deepseek_nsa.py create mode 100644 top/ops/deepseek_nsa.py diff --git a/benchmarks/deepseek_nsa/__init__.py b/benchmarks/deepseek_nsa/__init__.py new file mode 100644 index 00000000..911147ad --- /dev/null +++ b/benchmarks/deepseek_nsa/__init__.py @@ -0,0 +1,5 @@ +from .deepseek_nsa import NativeSparseAttentionForwardBenchmark + +__all__ = [ + "NativeSparseAttentionForwardBenchmark", +] diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py 
b/benchmarks/deepseek_nsa/deepseek_nsa.py new file mode 100644 index 00000000..abbe54c7 --- /dev/null +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -0,0 +1,81 @@ +from benchmarks.benchmark import Benchmark +from top.ops import NativeSparseAttentionForwardOp +import torch +from torch.nn import functional as f +from top.kernels.deepseek_nsa.nsa_torch import naive_nsa + +class NativeSparseAttentionForwardBenchmark(Benchmark): + op_type = NativeSparseAttentionForwardOp + + def __init__( + self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + # tune=False + ): + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + + self.head_kv = self.heads // self.groups + self.dtype = torch.float16 + + @property + def total_flops(self): + flops_per_matmul = 2.0 * self.batch * self.heads * self.seq_len * self.dim + flops = flops_per_matmul * 2 + return flops + + @property + def total_memory(self): + return (self.batch * self.heads * (2 * self.seq_len) * self.dim * self.dtype.itemsize) + # q_shape = [batch, seq_len, heads, dim] + # kv_shape = [batch, seq_len, head_kv, dim] + # block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + def gen_inputs(self): + Q = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + K = torch.randn( + self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) + V = torch.randn( + self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) + + self.g_slc = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) + self.g_swa = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) + + block_indices = 
torch.full((self.batch, self.seq_len, self.head_kv, self.selected_blocks), self.seq_len, dtype=torch.long, device="cuda") + self.block_counts = torch.zeros((self.batch, self.seq_len, self.head_kv), dtype=torch.long, device="cuda") + for b in range(self.batch): + for t in range(self.seq_len): + for h in range(self.head_kv): + i_i = torch.randperm(max(1, (t // self.block_size)))[:self.selected_blocks] + block_indices[b, t, h, : len(i_i)] = i_i + self.block_counts[b, t, h] = (block_indices[b, t, h] != self.seq_len).sum().item() + block_indices = block_indices.sort(-1)[0] + return Q, K, V, block_indices + + def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): + return naive_nsa( + q=Q, + k=K, + v=V, + g_slc=self.g_slc, + g_swa=self.g_swa, + block_indices=BlockIndices, + block_counts=self.block_counts, + block_size=self.block_size, + scale=self.scale, + ) \ No newline at end of file diff --git a/benchmarks/input_params/deepseek_nsa.csv b/benchmarks/input_params/deepseek_nsa.csv new file mode 100644 index 00000000..9ff09f86 --- /dev/null +++ b/benchmarks/input_params/deepseek_nsa.csv @@ -0,0 +1,2 @@ +batch,heads,seq_len,dim,is_causal,scale,block_size,groups,selected_blocks,tune +2,16,64,32,True,0.1,32,2,32,False \ No newline at end of file diff --git a/benchmarks/profile_run2.sh b/benchmarks/profile_run2.sh new file mode 100644 index 00000000..2165231c --- /dev/null +++ b/benchmarks/profile_run2.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# Default parameters +PROFILE_OUT="./profile_out" +LOG_FILE="./profile_run.log" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --profile_out) + PROFILE_OUT="$2" + shift 2 + ;; + --log) + LOG_FILE="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Check and handle existing PROFILE_OUT directory +if [ -d "$PROFILE_OUT" ]; then + echo "Warning: PROFILE_OUT directory '$PROFILE_OUT' already exists."
+fi + +# Check and handle existing LOG_FILE +if [ -f "$LOG_FILE" ]; then + echo "Warning: LOG_FILE '$LOG_FILE' already exists. Overwriting..." +fi + +# Create output directory +mkdir -p "$PROFILE_OUT" + +# Separator function +print_separator() { + echo "========================================" >> "$LOG_FILE" + echo "========================================" +} + +# Function to run tests +run_test() { + local test_name=$1 + local script_path=$2 + local csv_path=$3 + + echo "Running $test_name test..." | tee -a "$LOG_FILE" + print_separator + + local output_csv="$PROFILE_OUT/${test_name}_results.csv" + + python3 ./benchmarks/profile/profile_run.py \ + --script "$script_path" \ + --input_csv "$csv_path" \ + --output_csv "$output_csv" \ + 2>&1 | tee -a "$LOG_FILE" + + echo "Results saved to: $output_csv" | tee -a "$LOG_FILE" + echo "" | tee -a "$LOG_FILE" +} + +# Main execution flow +{ + + +echo "Starting profile run at $(date)" +print_separator + +# Run GEMM test +run_test "deepseek_nsa" "./tests/ops/test_deepseek_nsa_ops.py" "./benchmarks/input_params/deepseek_nsa.csv" + +print_separator +echo "All tests completed at $(date)" + +} | tee -a "$LOG_FILE" \ No newline at end of file diff --git a/test_tileops.py b/test_tileops.py new file mode 100644 index 00000000..25d24aeb --- /dev/null +++ b/test_tileops.py @@ -0,0 +1,34 @@ +import torch +import top +from top import MLAKernel + +device = "cuda" +dtype = torch.float16 + +batch = 128 +heads = 64 +kv_heads = 1 +kv_ctx = 8192 +dim = 512 +pe_dim = 64 + +# Query input: [batch, heads, dim] +q = torch.randn(batch, heads, dim, device=device, dtype=dtype) + +# Query positional encoding: [batch, heads, pe_dim] +q_pe = torch.randn(batch, heads, pe_dim, device=device, dtype=dtype) + +# KV cache input: [batch, kv_ctx, kv_heads, dim] +kv = torch.randn(batch, kv_ctx, kv_heads, dim, device=device, dtype=dtype) + +# KV positional encoding: [batch, kv_ctx, kv_heads, pe_dim] +k_pe = torch.randn(batch, kv_ctx, kv_heads, pe_dim, 
device=device, dtype=dtype) + +# Use MLA kernel +block_N = 64 +block_H = 64 +num_split = 1 + +mla = MLAKernel(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H, num_split) + +out = mla(q, q_pe, kv, k_pe) diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py new file mode 100644 index 00000000..d7383b86 --- /dev/null +++ b/tests/functions/test_deepseek_nsa_func.py @@ -0,0 +1,42 @@ +import argparse +from top.functions import NativeSparseAttentionFunc +from top.utils import str2dtype +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + + +def test_nsa_op( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + # dtype='float16', + tune=False, + ): + func = NativeSparseAttentionFunc(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(func, *inputs) + benchmark.profile(func, *inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--heads', type=int, default=16, help='number of heads') + parser.add_argument('--seq_len', type=int, default=64, help='sequence length') + parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=2, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=32, help='number of selected blocks') + 
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + args = parser.parse_args() + + test_nsa_op(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.scale, args.block_size, args.groups, args.selected_blocks, args.tune) \ No newline at end of file diff --git a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py new file mode 100644 index 00000000..3acd8a61 --- /dev/null +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -0,0 +1,42 @@ +import argparse +from top.layers import NativeSparseAttentionForwardLayer +from top.utils import str2dtype +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + + +def test_nsa_op( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + # dtype='float16', + tune=False, + ): + layer = NativeSparseAttentionForwardLayer(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(layer, *inputs) + benchmark.profile(layer, *inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--heads', type=int, default=16, help='number of heads') + parser.add_argument('--seq_len', type=int, default=64, help='sequence length') + parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=2, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=32, help='number 
of selected blocks') + parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + args = parser.parse_args() + + test_nsa_op(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.scale, args.block_size, args.groups, args.selected_blocks, args.tune) \ No newline at end of file diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py new file mode 100644 index 00000000..7c08ff0e --- /dev/null +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -0,0 +1,53 @@ +import argparse +from top.ops import NativeSparseAttentionForwardOp +from top.utils import str2dtype +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + + +def test_nsa_op( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + # dtype='float16', + tune=False, +): + op = NativeSparseAttentionForwardOp(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + benchmark.profile(op, *inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--heads', type=int, default=16, help='number of heads') + parser.add_argument('--seq_len', type=int, default=64, help='sequence length') + parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=2, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=32, help='number of selected 
blocks') + parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + args = parser.parse_args() + + test_nsa_op( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) \ No newline at end of file diff --git a/top/functions/deepseek_nsa.py b/top/functions/deepseek_nsa.py new file mode 100644 index 00000000..deecfcc0 --- /dev/null +++ b/top/functions/deepseek_nsa.py @@ -0,0 +1,116 @@ +import torch +from top.functions.function import Function +from top.ops.deepseek_nsa import NativeSparseAttentionForwardOp + +from top.kernels.deepseek_nsa.nsa_torch import naive_nsa + +__all__ = ['NativeSparseAttentionFunc'] + + +class nsa_decode_ctx(torch.autograd.Function): + + @staticmethod + def forward(ctx, Q, K, V, BlockIndices, fwd_op): + O = fwd_op(Q, K, V, BlockIndices) + return O + + @staticmethod + def backward(ctx, dO): + raise NotImplementedError("Backward pass is not implemented for nsa.") + + @staticmethod + def decode(ctx, dO): + raise NotImplementedError("Decode pass is not implemented for nsa.") + + +class NativeSparseAttentionFunc(Function): + + def __init__( + self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): + + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + self.tune = tune + + self.fwd_op = NativeSparseAttentionForwardOp( + batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: + return nsa_decode_ctx.apply(Q, K, V, BlockIndices, self.fwd_op) + + +# def main(): +# B, SEQ_LEN, H, HQ, D, S, block_size, dtype, 
scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + +# block_T = min(128, 16) + +# kernel = NativeSparseAttentionFunc( +# batch=B, +# heads=HQ, +# seq_len=SEQ_LEN, +# dim=D, +# is_causal=True, +# block_size=block_size, +# groups=HQ // H, +# selected_blocks=S, +# scale=scale, +# tune=True, +# ) + + +# torch.random.manual_seed(0) +# Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) +# K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) +# V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) +# g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) +# g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) +# DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + +# block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") +# block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") +# for b in range(B): +# for t in range(SEQ_LEN): +# for h in range(H): +# i_i = torch.randperm(max(1, (t // block_size)))[:S] +# block_indices[b, t, h, : len(i_i)] = i_i +# block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() +# block_indices = block_indices.sort(-1)[0] + +# out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) + +# ref = naive_nsa( +# q=Q, +# k=K, +# v=V, +# g_slc=g_slc, +# g_swa=g_swa, +# block_indices=block_indices, +# block_counts=block_counts, +# block_size=block_size, +# scale=scale, +# ) + +# print("out", out) +# print("ref", ref) +# torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +# if __name__ == "__main__": +# main() \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py new file mode 100644 index 00000000..d399ca27 --- /dev/null +++ b/top/kernels/deepseek_nsa/__init__.py @@ -0,0 +1,2 @@ +from .nsa_fwd import * +from .nsa_torch import * 
\ No newline at end of file diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py new file mode 100644 index 00000000..11902666 --- /dev/null +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -0,0 +1,319 @@ +import tilelang +import tilelang.language as T +from typing import Optional, Tuple +from top.kernels.kernel import Kernel +import itertools +import torch + +from top.kernels.deepseek_nsa.nsa_torch import naive_nsa + +__all__ = ["nsa_fwd_kernel"] + + +# dtype default float16, accum_dtype default float32 +def _nsa_fwd_kernel( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16 + ): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + + block_S = block_size + # block_T = min(128, tilelang.math.next_power_of_2(dim)) + # num_stages = 2 + # threads = 32 + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + + def _nsa_fwd_func(block_T, num_stages, threads): + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + + @T.prim_func + def _nsa_fwd_main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + 
Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, 
policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return _nsa_fwd_main + return _nsa_fwd_func + + +@torch.library.custom_op("top::nsa_fwd_wrapped_kernel", mutates_args=()) +def _nsa_fwd_wrapped_kernel( + batch: int, + heads: int, + seq_len: int, + dim: int, + is_causal: bool, + scale: float, + block_size: int, + groups: int, + selected_blocks: int, + block_T: int, + num_stages: int, + threads: int, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + BlockIndices: torch.Tensor, +) ->torch.Tensor: + return _nsa_fwd_kernel(batch, heads, seq_len, dim, is_causal, scale, block_size, + groups, selected_blocks)(block_T, num_stages, + threads)(Q, K, V, BlockIndices) + + +@_nsa_fwd_wrapped_kernel.register_fake +def _( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + block_T, + num_stages, + threads, + *inputs +) -> torch.Tensor: + fake_o = torch.empty_like(inputs[0]) + return fake_o + + +class nsa_fwd_kernel(Kernel): + supported_archs: list[int] = [80, 89, 90, 100] + + def __init__( + self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + config: Optional[dict] = None, + tune=False): + + super().__init__() + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + + self.kernel = _nsa_fwd_kernel(self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks) + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_T": min(128, tilelang.math.next_power_of_2(self.dim)), + "num_stages": 2, + "threads": 32, + } + 
+ @property + def autotune_configs(self) -> list[dict]: + block_T = [32, 64, 128] + num_stages = [2,] + threads = [32, 64] + _configs = list(itertools.product(block_T, num_stages, threads)) + configs = [{ + "block_T": c[0], + "num_stages": c[1], + "threads": c[2] + } for c in _configs] + return configs + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): + return _nsa_fwd_wrapped_kernel(self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks, self.config["block_T"], self.config["num_stages"], self.config["threads"], Q, K, V, BlockIndices) + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + block_T = min(128, tilelang.math.next_power_of_2(D)) + kernel = _nsa_fwd_kernel( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + scale=scale, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + )(block_T=block_T, num_stages=2, threads=32) + + kernel2 = nsa_fwd_kernel( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + tune=True, + ) + + + src_kernel = kernel.get_kernel_source() + print(src_kernel) + # with open("nsa_fwd_kernel.cu", "w") as f: + # f.write(src_kernel) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, 
dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + out = kernel(Q, K, V, block_indices.to(torch.int32)) + + out2 = kernel2.forward(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + print("out", out) + print("out2", out2) + print("ref", ref) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref, out2, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/top/kernels/deepseek_nsa/nsa_torch.py b/top/kernels/deepseek_nsa/nsa_torch.py new file mode 100644 index 00000000..f155a765 --- /dev/null +++ b/top/kernels/deepseek_nsa/nsa_torch.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + + +@torch.compile +def compression( + k: torch.Tensor, + v: torch.Tensor, + block_size: int +) -> torch.Tensor: + # Currently, we set mean pooling as our basic compression function. 
+ B, T, H = k.shape[:3] + num_block = math.ceil(T / block_size) + if k.shape[1] % block_size != 0: + k = F.pad(k, (0, 0, 0, 0, 0, num_block * block_size - T)) + v = F.pad(v, (0, 0, 0, 0, 0, num_block * block_size - T)) + k_cmp = k.view(B, num_block, block_size, H, -1).mean(dim=2) + v_cmp = v.view(B, num_block, block_size, H, -1).mean(dim=2) + return k_cmp, v_cmp + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None`. + block_size (int): + Selected block size. Default: 64. 
+ window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, 'b h t -> b t h') + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + S = block_indices.shape[-1] + k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + if isinstance(block_counts, torch.Tensor): + block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o_slc = torch.zeros_like(v) + o_swa = torch.zeros_like(v) if window_size > 0 else None + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B*T, T)), block_indices.new_tensor([B*T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], 
block_indices[i] + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[i] + else: + s_b = block_counts + else: + T = cu_seqlens[i+1] - cu_seqlens[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( + lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], + (q, k, v, g_slc, g_swa, block_indices) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i+1]] + else: + s_b = block_counts + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [HQ] + g_slc_i = g_slc_b[i_q] + # [HQ] + g_swa_i = g_swa_b[i_q] + # [S*BS, HQ] + i_i = i_b[i_q] + # [HQ] + if isinstance(block_counts, torch.Tensor): + s_i = s_b[i_q] + else: + s_i = s_b + # [S*BS, HQ, -1] + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp( + 0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( + torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), + float('-inf') + ).softmax(0) + if not varlen: + o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + else: + o_slc[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + if window_size > 0: + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) + if not varlen: + o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + else: + o_swa[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + + if head_first: + o_slc = rearrange(o_slc, 'b t h d -> b h t d') + o_swa = rearrange(o_swa, 'b t h d -> b h t d') + + return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is 
not None else o_slc.to(dtype) + + +def naive_nsa_compression( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + scale: float, + head_first: bool = False +) -> torch.LongTensor: + dtype = q.dtype + B, T = q.shape[0], q.shape[1] + H, HQ = k.shape[2], q.shape[2] + G = HQ//H + BS = block_size + if isinstance(block_counts, int): + block_counts = torch.full((B, T, H), block_counts, dtype=torch.long, device=q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + k_cmp, v_cmp = compression(k, v, BS) + C = k_cmp.shape[1] + S = min(block_counts.max().item(), C) + k_cmp, v_cmp = map(lambda x: repeat(x, 'b c h d -> b c (h g) d', g=G), (k_cmp, v_cmp)) + + casual_mask = ((torch.arange(T) - BS + 1)[:, None] // BS < torch.arange(C)[None, :]).to(q.device) + empty_mask = casual_mask.all(-1, True) + local_mask = (torch.arange(T)[:, None] // BS == torch.arange(C)[None, :]).to(q.device) + + attn_cmp = torch.einsum('bqhd,bkhd->bhqk', q*scale, k_cmp) + attn_cmp = attn_cmp.masked_fill(casual_mask & empty_mask.logical_not(), float('-inf')) + attn_cmp = attn_cmp.softmax(-1).masked_fill(empty_mask, 0.0) + o_cmp = torch.einsum('bhqk, bkhd -> bqhd', attn_cmp, v_cmp) * g_cmp.unsqueeze(-1) + attn_select = attn_cmp.masked_fill(local_mask, float(1.0)) + attn_select = attn_select.view(B, H, G, T, C).sum(2) + block_indices = attn_select.topk(S, -1)[1] + + block_indices = block_indices.masked_fill(block_indices > (block_indices.new_tensor(range(T))[:, None] // BS), -1) + block_indices = block_indices.transpose(1, 2) + + if head_first: + o_cmp = rearrange(o_cmp, 'b t h d -> b h t d') + return block_indices, o_cmp.to(dtype) + + +def naive_nsa_compression_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + scale: float, + cu_seqlens: torch.LongTensor, + head_first: bool = False +) -> torch.LongTensor: + 
dtype = q.dtype + B, T = q.shape[0], q.shape[1] + H, HQ = k.shape[2], q.shape[2] + D = v.shape[-1] + G = HQ//H + BS = block_size + S = block_counts if isinstance(block_counts, int) else block_counts.max().item() + C = math.ceil(T / block_size) + S = min(S, C) + block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device=q.device) + o_cmp = torch.zeros(B, T, HQ, D, dtype=dtype, device=q.device) + for i in range(len(cu_seqlens) - 1): + T_b = cu_seqlens[i+1] - cu_seqlens[i] + C_b = math.ceil(T_b / block_size) + q_b, k_b, v_b, g_cmp_b = map( + lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], + (q, k, v, g_cmp) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i+1]] + else: + s_b = block_counts + + k_cmp, v_cmp = compression(k_b.unsqueeze(0), v_b.unsqueeze(0), BS) + S_b = s_b if isinstance(s_b, int) else s_b.max().item() + C_b = k_cmp.shape[1] + S_b = min(S_b, C_b) + k_cmp, v_cmp = map(lambda x: repeat(x.squeeze(0), 'c h d -> c (h g) d', g=G), (k_cmp, v_cmp)) + q_b, k_cmp, v_cmp = map(lambda x: x.float(), (q_b, k_cmp, v_cmp)) + + casual_mask = ((torch.arange(T_b) - BS + 1)[:, None] // BS < torch.arange(C_b)[None, :]).to(q_b.device) + local_mask = (torch.arange(T_b)[:, None] // BS == torch.arange(C_b)[None, :]).to(q.device) + + attn_cmp = torch.einsum('qhd,khd->hqk', q_b*scale, k_cmp) + attn_cmp = attn_cmp.masked_fill(casual_mask, float('-inf')) + attn_cmp = attn_cmp.softmax(-1) + o_cmp[0][cu_seqlens[i]:cu_seqlens[i+1]] = torch.einsum('hqk,khd->qhd', attn_cmp, v_cmp).nan_to_num() *\ + g_cmp_b.unsqueeze(-1) + attn_select = attn_cmp.masked_fill(local_mask, float(1.0)) + attn_select = attn_select.view(H, G, T_b, C_b).sum(1) + block_indices_b = attn_select.topk(S_b, -1)[1] + block_indices_b = block_indices_b.masked_fill( + block_indices_b > (block_indices_b.new_tensor(range(T_b))[:, None]//BS), + 0 + ) + block_indices[0][cu_seqlens[i]:cu_seqlens[i+1], :, :S_b] = block_indices_b.transpose(0, 1) + + if head_first: + o_cmp = 
rearrange(o_cmp, 'b t h d -> b h t d') + return block_indices, o_cmp.to(dtype) + + +def naive_nsa_with_compression( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_cmp (torch.Tensor): + Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API.
+ + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g_cmp, g_slc = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_cmp, g_slc)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, 'b h t -> b t h') + if cu_seqlens is not None: + block_indices, o_cmp = naive_nsa_compression_varlen( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + block_counts=block_counts, + block_size=block_size, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=False) + else: + block_indices, o_cmp = naive_nsa_compression( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + block_counts=block_counts, + block_size=block_size, + scale=scale, + head_first=False) + o = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=False + ) + o_cmp + + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + + return o, block_indices diff --git a/top/layers/deepseek_nsa.py b/top/layers/deepseek_nsa.py new file mode 100644 index 00000000..30fef125 --- /dev/null +++ b/top/layers/deepseek_nsa.py @@ -0,0 +1,100 @@ +import torch +from torch import nn +from top.functions import NativeSparseAttentionFunc + +from top.kernels.deepseek_nsa.nsa_torch import naive_nsa + + +class NativeSparseAttentionLayer(nn.Module): + + def __init__( + self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False + ): + super().__init__() + + self.batch = batch + 
self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + self.tune = tune + + self.fn = NativeSparseAttentionFunc( + batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: + return self.fn(Q, K, V, BlockIndices) + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + block_T = min(128, 16) + + kernel = NativeSparseAttentionLayer( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + tune=True, + ) + + + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + 
g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + print("out", out) + print("ref", ref) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py new file mode 100644 index 00000000..ee0b810b --- /dev/null +++ b/top/ops/deepseek_nsa.py @@ -0,0 +1,109 @@ +import torch +from top.ops.op import Op +from top.kernels.kernel import Kernel +from top.kernels.deepseek_nsa.nsa_fwd import nsa_fwd_kernel +from typing import Optional, Dict + +from top.kernels.deepseek_nsa.nsa_torch import naive_nsa + +__all__ = ["NativeSparseAttentionForwardOp"] + + +class NativeSparseAttentionForwardOp(Op): + def __init__( + self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune=False + ): + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + self.tune = tune + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["nsa_fwd_kernel"]( + self.batch, self.heads, self.seq_len, + self.dim, self.is_causal, self.scale, + self.block_size, self.groups, self.selected_blocks, tune=self.tune) + + @property + def default_kernel_map(self): + return {"nsa_fwd_kernel": nsa_fwd_kernel} + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): + return self.kernel(Q, K, V, BlockIndices) + + +# def main(): +# B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + +# block_T = min(128, 16) + +# kernel = NativeSparseAttentionForwardOp( +# batch=B, +# heads=HQ, +# seq_len=SEQ_LEN, +# dim=D, +# is_causal=True, +# 
block_size=block_size, +# groups=HQ // H, +# selected_blocks=S, +# scale=scale, +# tune=True, +# ) + + +# torch.random.manual_seed(0) +# Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) +# K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) +# V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) +# g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) +# g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) +# DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + +# block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") +# block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") +# for b in range(B): +# for t in range(SEQ_LEN): +# for h in range(H): +# i_i = torch.randperm(max(1, (t // block_size)))[:S] +# block_indices[b, t, h, : len(i_i)] = i_i +# block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() +# block_indices = block_indices.sort(-1)[0] + +# out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) + +# ref = naive_nsa( +# q=Q, +# k=K, +# v=V, +# g_slc=g_slc, +# g_swa=g_swa, +# block_indices=block_indices, +# block_counts=block_counts, +# block_size=block_size, +# scale=scale, +# ) + +# print("out", out) +# print("ref", ref) +# torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +# if __name__ == "__main__": +# main() \ No newline at end of file From 884935368532ce8719e41c46dfd799bb06e8524c Mon Sep 17 00:00:00 2001 From: jienengyu Date: Tue, 30 Dec 2025 14:47:33 +0800 Subject: [PATCH 03/14] nsa fwd benchmark --- benchmarks/deepseek_nsa/deepseek_nsa.py | 32 +++--- tests/ops/test_deepseek_nsa_ops.py | 12 +-- top/ops/deepseek_nsa.py | 127 +++++++++++++----------- 3 files changed, 95 insertions(+), 76 deletions(-) diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py 
b/benchmarks/deepseek_nsa/deepseek_nsa.py index abbe54c7..923315db 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -18,7 +18,7 @@ def __init__( block_size=64, groups=1, selected_blocks=16, - # tune=False + tune=False ): self.batch = batch self.heads = heads @@ -32,19 +32,27 @@ def __init__( self.head_kv = self.heads // self.groups self.dtype = torch.float16 + self.tune = tune @property def total_flops(self): - flops_per_matmul = 2.0 * self.batch * self.heads * self.seq_len * self.dim - flops = flops_per_matmul * 2 - return flops + B = self.batch + T = self.seq_len + HQ = self.heads + D = self.dim + S = self.selected_blocks + BS = self.block_size + window_size = 0 + total_keys = S * BS + window_size + flops = 4 * B * T * HQ * D * total_keys + return flops + @property def total_memory(self): return (self.batch * self.heads * (2 * self.seq_len) * self.dim * self.dtype.itemsize) - # q_shape = [batch, seq_len, heads, dim] - # kv_shape = [batch, seq_len, head_kv, dim] - # block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + + def gen_inputs(self): Q = torch.randn( self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) @@ -64,7 +72,7 @@ def gen_inputs(self): i_i = torch.randperm(max(1, (t // self.block_size)))[:self.selected_blocks] block_indices[b, t, h, : len(i_i)] = i_i self.block_counts[b, t, h] = (block_indices[b, t, h] != self.seq_len).sum().item() - block_indices = block_indices.sort(-1)[0] + block_indices = block_indices.sort(-1)[0].to(torch.int32) return Q, K, V, block_indices def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): @@ -74,8 +82,8 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIn v=V, g_slc=self.g_slc, g_swa=self.g_swa, - block_indices=BlockIndices, - block_counts=slblock_counts, - block_size=block_size, - scale=scale, + block_indices=BlockIndices.to(torch.long), + 
block_counts=self.block_counts, + block_size=self.block_size, + scale=self.scale, ) \ No newline at end of file diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index 7c08ff0e..bdc24b37 100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -27,16 +27,16 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--batch', type=int, default=16, help='batch size') parser.add_argument('--heads', type=int, default=16, help='number of heads') - parser.add_argument('--seq_len', type=int, default=64, help='sequence length') - parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') - parser.add_argument('--groups', type=int, default=2, help='number of groups') - parser.add_argument('--selected_blocks', type=int, default=32, help='number of selected blocks') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=1, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() test_nsa_op( diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index ee0b810b..2c00ac3f 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -35,6 +35,17 @@ def __init__( self.selected_blocks = selected_blocks 
self.tune = tune + # print("batch ", self.batch) + # print("heads ", self.heads) + # print("seq_len ", self.seq_len) + # print("dim ", self.dim) + # print("is_causal ", self.is_causal) + # print("scale ", self.scale) + # print("block_size ", self.block_size) + # print("groups ", self.groups) + # print("selected_blocks ", self.selected_blocks) + # print("tune ", self.tune) + self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["nsa_fwd_kernel"]( self.batch, self.heads, self.seq_len, @@ -49,61 +60,61 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndice return self.kernel(Q, K, V, BlockIndices) -# def main(): -# B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - -# block_T = min(128, 16) - -# kernel = NativeSparseAttentionForwardOp( -# batch=B, -# heads=HQ, -# seq_len=SEQ_LEN, -# dim=D, -# is_causal=True, -# block_size=block_size, -# groups=HQ // H, -# selected_blocks=S, -# scale=scale, -# tune=True, -# ) - - -# torch.random.manual_seed(0) -# Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) -# K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) -# V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) -# g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) -# g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) -# DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - -# block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") -# block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") -# for b in range(B): -# for t in range(SEQ_LEN): -# for h in range(H): -# i_i = torch.randperm(max(1, (t // block_size)))[:S] -# block_indices[b, t, h, : len(i_i)] = i_i -# block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() -# block_indices = 
block_indices.sort(-1)[0] - -# out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) - -# ref = naive_nsa( -# q=Q, -# k=K, -# v=V, -# g_slc=g_slc, -# g_swa=g_swa, -# block_indices=block_indices, -# block_counts=block_counts, -# block_size=block_size, -# scale=scale, -# ) - -# print("out", out) -# print("ref", ref) -# torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - - -# if __name__ == "__main__": -# main() \ No newline at end of file +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + block_T = min(128, 16) + + kernel = NativeSparseAttentionForwardOp( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + tune=True, + ) + + + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, 
+ block_size=block_size, + scale=scale, + ) + + print("out", out) + print("ref", ref) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() \ No newline at end of file From edc6a194f58fa85d039e13b761033f26c65d732f Mon Sep 17 00:00:00 2001 From: jienengyu Date: Sat, 3 Jan 2026 15:47:00 +0800 Subject: [PATCH 04/14] feat(python): run benchmarks and mean_pooling_tilelang kernel --- benchmarks/deepseek_nsa/__init__.py | 1 + benchmarks/deepseek_nsa/deepseek_nsa.py | 1 + benchmarks/input_params/deepseek_nsa.csv | 5 ++- tests/functions/test_deepseek_nsa_func.py | 29 +++++++++---- tests/layers/test_deepseek_nsa_layer.py | 29 +++++++++---- tests/ops/test_deepseek_nsa_ops.py | 8 ++-- top/kernels/deepseek_nsa/__init__.py | 4 +- top/kernels/deepseek_nsa/nsa_fwd.py | 35 ++++++++-------- top/layers/__init__.py | 3 +- top/ops/deepseek_nsa.py | 51 ++++++++++++----------- 10 files changed, 100 insertions(+), 66 deletions(-) diff --git a/benchmarks/deepseek_nsa/__init__.py b/benchmarks/deepseek_nsa/__init__.py index 911147ad..58c34136 100644 --- a/benchmarks/deepseek_nsa/__init__.py +++ b/benchmarks/deepseek_nsa/__init__.py @@ -1,5 +1,6 @@ from .deepseek_nsa import NativeSparseAttentionForwardBenchmark + __all__ = [ "NativeSparseAttentionForwardBenchmark", ] diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index 923315db..770b05e5 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -76,6 +76,7 @@ def gen_inputs(self): return Q, K, V, block_indices def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): + print("running naive nsa ref_program") return naive_nsa( q=Q, k=K, diff --git a/benchmarks/input_params/deepseek_nsa.csv b/benchmarks/input_params/deepseek_nsa.csv index 9ff09f86..0c338f08 100644 --- a/benchmarks/input_params/deepseek_nsa.csv +++ b/benchmarks/input_params/deepseek_nsa.csv @@ -1,2 
+1,5 @@ batch,heads,seq_len,dim,is_causal,scale,block_size,groups,selected_blocks,tune -2,16,64,32,True,0.1,32,2,32,False \ No newline at end of file +1,64,8192,128,True,0.1,32,16,16,True +1,64,16384,128,True,0.1,32,16,16,True +1,64,32768,128,True,0.1,32,16,16,True +1,64,65536,128,True,0.1,32,16,16,True \ No newline at end of file diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py index d7383b86..b9136793 100644 --- a/tests/functions/test_deepseek_nsa_func.py +++ b/tests/functions/test_deepseek_nsa_func.py @@ -1,5 +1,5 @@ import argparse -from top.functions import NativeSparseAttentionForwardFunc +from top.functions import NativeSparseAttentionFunc from top.utils import str2dtype from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark @@ -17,7 +17,7 @@ def test_nsa_op( # dtype='float16', tune=False, ): - func = NativeSparseAttentionForwardFunc(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + func = NativeSparseAttentionFunc(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() @@ -28,15 +28,26 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=2, help='batch size') - parser.add_argument('--heads', type=int, default=16, help='number of heads') - parser.add_argument('--seq_len', type=int, default=64, help='sequence length') - parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--heads', type=int, default=16*4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--is_causal', 
action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') - parser.add_argument('--groups', type=int, default=2, help='number of groups') - parser.add_argument('--selected_blocks', type=int, default=32, help='number of selected blocks') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() - test_nsa_op(args.batch, args.heads, args.seq_len, args.dim, str2dtype[args.dtype], args.tune) \ No newline at end of file + test_nsa_op( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) \ No newline at end of file diff --git a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py index 3acd8a61..7177ad57 100644 --- a/tests/layers/test_deepseek_nsa_layer.py +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -1,5 +1,5 @@ import argparse -from top.layers import NativeSparseAttentionForwardLayer +from top.layers import NativeSparseAttentionLayer from top.utils import str2dtype from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark @@ -17,7 +17,7 @@ def test_nsa_op( # dtype='float16', tune=False, ): - layer = NativeSparseAttentionForwardLayer(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + layer = NativeSparseAttentionLayer(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, 
seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() @@ -28,15 +28,26 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=2, help='batch size') - parser.add_argument('--heads', type=int, default=16, help='number of heads') - parser.add_argument('--seq_len', type=int, default=64, help='sequence length') - parser.add_argument('--dim', type=int, default=32, help='head dim') + parser.add_argument('--heads', type=int, default=16*4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') - parser.add_argument('--groups', type=int, default=2, help='number of groups') - parser.add_argument('--selected_blocks', type=int, default=32, help='number of selected blocks') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() - test_nsa_op(args.batch, args.heads, args.seq_len, args.dim, str2dtype[args.dtype], args.tune) \ No newline at end of file + test_nsa_op( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) \ No newline at end of file diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index bdc24b37..53868af5 
100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -27,15 +27,15 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=16, help='batch size') - parser.add_argument('--heads', type=int, default=16, help='number of heads') - parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') + parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--heads', type=int, default=16*4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') parser.add_argument('--groups', type=int, default=16, help='number of groups') - parser.add_argument('--selected_blocks', type=int, default=1, help='number of selected blocks') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py index d399ca27..f362f32b 100644 --- a/top/kernels/deepseek_nsa/__init__.py +++ b/top/kernels/deepseek_nsa/__init__.py @@ -1,2 +1,4 @@ from .nsa_fwd import * -from .nsa_torch import * \ No newline at end of file +from .nsa_torch import * +from .mean_pooling_triton import * +from .utils import * \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py index 11902666..445353fd 100644 --- a/top/kernels/deepseek_nsa/nsa_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -226,8 
+226,8 @@ def default_config(self) -> dict: @property def autotune_configs(self) -> list[dict]: block_T = [32, 64, 128] - num_stages = [2,] - threads = [32, 64] + num_stages = [2, 3] + threads = [32, 64, 128] _configs = list(itertools.product(block_T, num_stages, threads)) configs = [{ "block_T": c[0], @@ -241,7 +241,8 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndice def main(): - B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + # B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 8192, 4, 16*4, 128, 16, 32, torch.float16, 0.1 block_T = min(128, tilelang.math.next_power_of_2(D)) kernel = _nsa_fwd_kernel( @@ -296,23 +297,23 @@ def main(): out2 = kernel2.forward(Q, K, V, block_indices.to(torch.int32)) - ref = naive_nsa( - q=Q, - k=K, - v=V, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - scale=scale, - ) + # ref = naive_nsa( + # q=Q, + # k=K, + # v=V, + # g_slc=g_slc, + # g_swa=g_swa, + # block_indices=block_indices, + # block_counts=block_counts, + # block_size=block_size, + # scale=scale, + # ) print("out", out) print("out2", out2) - print("ref", ref) - torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - torch.testing.assert_close(ref, out2, atol=1e-2, rtol=1e-2) + # print("ref", ref) + # torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + # torch.testing.assert_close(ref, out2, atol=1e-2, rtol=1e-2) if __name__ == "__main__": diff --git a/top/layers/__init__.py b/top/layers/__init__.py index f3fc167b..c4c371b4 100644 --- a/top/layers/__init__.py +++ b/top/layers/__init__.py @@ -2,9 +2,10 @@ from .flash_decode import MultiHeadAttentionDecodeLayer, GroupQueryAttentionDecodeLayer from .deepseek_mla import MultiHeadLatentAttentionDecodeLayer, DeepSeekSparseAttentionDecodeLayer from .linear import 
LinearLayer +from .deepseek_nsa import NativeSparseAttentionLayer __all__ = [ "MultiHeadAttentionLayer", "GroupQueryAttentionLayer", "MultiHeadAttentionDecodeLayer", "GroupQueryAttentionDecodeLayer", "MultiHeadLatentAttentionDecodeLayer", - "DeepSeekSparseAttentionDecodeLayer", "LinearLayer" + "DeepSeekSparseAttentionDecodeLayer", "LinearLayer", "NativeSparseAttentionLayer" ] diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 2c00ac3f..27b87002 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -35,22 +35,23 @@ def __init__( self.selected_blocks = selected_blocks self.tune = tune - # print("batch ", self.batch) - # print("heads ", self.heads) - # print("seq_len ", self.seq_len) - # print("dim ", self.dim) - # print("is_causal ", self.is_causal) - # print("scale ", self.scale) - # print("block_size ", self.block_size) - # print("groups ", self.groups) - # print("selected_blocks ", self.selected_blocks) - # print("tune ", self.tune) + print("batch ", self.batch) + print("heads ", self.heads) + print("seq_len ", self.seq_len) + print("dim ", self.dim) + print("is_causal ", self.is_causal) + print("scale ", self.scale) + print("block_size ", self.block_size) + print("groups ", self.groups) + print("selected_blocks ", self.selected_blocks) + print("tune ", self.tune) self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["nsa_fwd_kernel"]( self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks, tune=self.tune) + @property def default_kernel_map(self): @@ -61,7 +62,9 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndice def main(): - B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + # B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 8192, 4, 16*4, 128, 16, 32, torch.float16, 
0.1 block_T = min(128, 16) @@ -99,21 +102,21 @@ def main(): out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) - ref = naive_nsa( - q=Q, - k=K, - v=V, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - scale=scale, - ) + # ref = naive_nsa( + # q=Q, + # k=K, + # v=V, + # g_slc=g_slc, + # g_swa=g_swa, + # block_indices=block_indices, + # block_counts=block_counts, + # block_size=block_size, + # scale=scale, + # ) print("out", out) - print("ref", ref) - torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + # print("ref", ref) + # torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) if __name__ == "__main__": From 86447196ceea32b900a702d197e327a0a8388570 Mon Sep 17 00:00:00 2001 From: jienengyu Date: Sat, 3 Jan 2026 15:47:25 +0800 Subject: [PATCH 05/14] feat(python): run benchmarks and mean_pooling_tilelang kernel --- benchmarks/profile_run2.sh | 78 -------------------------------------- 1 file changed, 78 deletions(-) delete mode 100644 benchmarks/profile_run2.sh diff --git a/benchmarks/profile_run2.sh b/benchmarks/profile_run2.sh deleted file mode 100644 index 2165231c..00000000 --- a/benchmarks/profile_run2.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/bin/bash - -# Default parameters -PROFILE_OUT="./profile_out" -LOG_FILE="./profile_run.log" - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - --profile_out) - PROFILE_OUT="$2" - shift 2 - ;; - --log) - LOG_FILE="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac -done - -# Check and handle existing PROFILE_OUT directory -if [ -d "$PROFILE_OUT" ]; then - echo "Warning: PROFILE_OUT directory '$PROFILE_OUT' already exists." -fi - -# Check and handle existing LOG_FILE -if [ -f "$LOG_FILE" ]; then - echo "Warning: LOG_FILE '$LOG_FILE' already exists. Overwriting..." 
-fi - -# Create output directory -mkdir -p "$PROFILE_OUT" - -# Separator function -print_separator() { - echo "========================================" >> "$LOG_FILE" - echo "========================================" -} - -# Function to run tests -run_test() { - local test_name=$1 - local script_path=$2 - local csv_path=$3 - - echo "Running $test_name test..." | tee -a "$LOG_FILE" - print_separator - - local output_csv="$PROFILE_OUT/${test_name}_results.csv" - - python3 ./benchmarks/profile/profile_run.py \ - --script "$script_path" \ - --input_csv "$csv_path" \ - --output_csv "$output_csv" \ - 2>&1 | tee -a "$LOG_FILE" - - echo "Results saved to: $output_csv" | tee -a "$LOG_FILE" - echo "" | tee -a "$LOG_FILE" -} - -# Main execution flow -{ - - -echo "Starting profile run at $(date)" -print_separator - -# Run GEMM test -run_test "deepseek_nsa" "./tests/ops/test_deepseek_nsa_ops.py" "./benchmarks/input_params/deepseek_nsa.csv" - -print_separator -echo "All tests completed at $(date)" - -} | tee -a "$LOG_FILE" \ No newline at end of file From b687402b07c905ea3017d84fe86f7475e7d1cefe Mon Sep 17 00:00:00 2001 From: jienengyu Date: Sat, 3 Jan 2026 15:47:38 +0800 Subject: [PATCH 06/14] feat(python): run benchmarks and mean_pooling_tilelang kernel --- top/kernels/deepseek_nsa/mean_pooling.py | 138 +++++++++++ .../deepseek_nsa/mean_pooling_triton.py | 150 +++++++++++ top/kernels/deepseek_nsa/utils.py | 233 ++++++++++++++++++ 3 files changed, 521 insertions(+) create mode 100644 top/kernels/deepseek_nsa/mean_pooling.py create mode 100644 top/kernels/deepseek_nsa/mean_pooling_triton.py create mode 100644 top/kernels/deepseek_nsa/utils.py diff --git a/top/kernels/deepseek_nsa/mean_pooling.py b/top/kernels/deepseek_nsa/mean_pooling.py new file mode 100644 index 00000000..b894fc7b --- /dev/null +++ b/top/kernels/deepseek_nsa/mean_pooling.py @@ -0,0 +1,138 @@ +import torch + +from top.kernels.deepseek_nsa.mean_pooling_triton import mean_pooling +from 
top.kernels.deepseek_nsa.mean_pooling_triton import prepare_chunk_indices + +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[3]) +def mean_pooling_tilelang_kernel( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int = 64, + threads: int = 128, +): + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + X_unpad: T.Tensor([total_seqlen, heads, dim], dtype), + cu_seqlens: T.Tensor([batch_size + 1], T.int32), + chunk_indices: T.Tensor([total_chunks, 2], T.int32), + Output: T.Tensor([total_chunks, heads, dim], dtype), + ): + with T.Kernel( + T.ceildiv(dim, block_D), + total_chunks, + heads, + threads=threads + ) as (i_d, i_t, i_h): + accum = T.alloc_fragment([block_D], accum_dtype) + d_start = i_d * block_D + + seq_id = chunk_indices[i_t, 0] + local_chunk_id = chunk_indices[i_t, 1] + start = cu_seqlens[seq_id] + end = cu_seqlens[seq_id + 1] + seqlen = end - start + + chunk_start = local_chunk_id * chunk_size + chunk_end = T.min(chunk_start + chunk_size, seqlen) + actual_bt = chunk_end - chunk_start + + for d in T.Parallel(block_D): + accum[d] = T.cast(0, accum_dtype) + for t_rel in T.serial(actual_bt): + t_abs = start + chunk_start + t_rel + for d in T.Parallel(block_D): + if d_start + d < dim: + accum[d] += T.cast(X_unpad[t_abs, i_h, d_start + d], accum_dtype) + for d in T.Parallel(block_D): + if d_start + d < dim: + Output[i_t, i_h, d_start + d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), dtype) + + return main + + +def mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size, block_D=64): + total_T, H, D = x_unpad.shape + B = cu_seqlens.shape[0] - 1 + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + total_chunks = chunk_indices.shape[0] + + kernel = mean_pooling_tilelang_kernel( + batch_size=B, + total_seqlen=total_T, + total_chunks=total_chunks, + heads=H, + dim=D, + chunk_size=chunk_size, + block_D=block_D, + threads=128, + ) + 
return kernel(x_unpad, cu_seqlens, chunk_indices) + + +def test_varlen(): + print("=== 🌊 Testing Variable-Length Mode ===") + device = "cuda" + torch.manual_seed(42) + + seqlens = torch.tensor([100, 150], device=device, dtype=torch.int32) + cu_seqlens = torch.zeros(seqlens.shape[0] + 1, device=device, dtype=torch.int32) + cu_seqlens[1:] = seqlens.cumsum(0) + total_T = cu_seqlens[-1].item() + H, D, chunk_size = 4, 64, 32 + + x_unpad = torch.randn(total_T, H, D, dtype=torch.float16, device=device) + # x_triton = x_unpad.unsqueeze(0) # (1, total_T, H, D) + + # Triton + out_triton = mean_pooling(x_unpad.unsqueeze(0), chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False) + out_triton = out_triton.squeeze(0) + + # TileLang + out_tilelang = mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size) + + print(f"Triton: {out_triton.shape}") + print(f"TileLang: {out_tilelang.shape}") + print(f"Max diff: {(out_triton - out_tilelang).abs().max().item():.6f}") + torch.testing.assert_close(out_triton.float(), out_tilelang.float(), atol=1e-2, rtol=1e-2) + print("✅ Varlen test passed!\n") + + +# Test 2: Fixed-Length +def test_fixed(): + print("=== 📏 Testing Fixed-Length Mode ===") + device = "cuda" + torch.manual_seed(42) + + B, T, H, D = 3, 1024, 128, 128 + chunk_size = 32 + + x = torch.randn(B, T, H, D, dtype=torch.float16, device=device) + out_triton = mean_pooling(x, chunk_size=chunk_size, cu_seqlens=None, head_first=False) # (B, NT, H, D) + out_triton_reshaped = out_triton.view(-1, H, D) # (B*NT, H, D) + + x_unpad = x.view(-1, H, D) # (B*T, H, D) + cu_seqlens = torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=device) # [0, T, 2T] + out_tilelang = mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size) # (total_chunks, H, D) + + print(f"Triton: {out_triton_reshaped.shape}") + print(f"TileLang: {out_tilelang.shape}") + print(f"Max diff: {(out_triton_reshaped - out_tilelang).abs().max().item():.6f}") + torch.testing.assert_close(out_triton_reshaped.float(), 
out_tilelang.float(), atol=1e-2, rtol=1e-2) + print("✅ Fixed-length test passed!\n") + + +if __name__ == "__main__": + test_varlen() + test_fixed() + print("🎉 All tests passed! TileLang and Triton outputs match perfectly.") \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/mean_pooling_triton.py b/top/kernels/deepseek_nsa/mean_pooling_triton.py new file mode 100644 index 00000000..2755e918 --- /dev/null +++ b/top/kernels/deepseek_nsa/mean_pooling_triton.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from .utils import input_guard, prepare_chunk_indices + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_fwd_kernel( + x, + o, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + # [BD] + b_o = tl.sum(b_x, axis=0) / min(BT, T - 
i_t * BT) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +def mean_pooling_fwd( + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = x.shape + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = x.new_empty(B, NT, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_fwd_kernel[grid]( + x, + o, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT, + ) + return o + + +class MeanPoolingFunction(torch.autograd.Function): + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None + ) -> torch.Tensor: + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + o = mean_pooling_fwd(x, chunk_size, offsets, indices) + ctx.batch_size = x.shape[0] + ctx.seq_len = x.shape[1] + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + return o + + +def mean_pooling( + x: torch.Tensor, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + if head_first: + x = x.transpose(1, 2) + if cu_seqlens is not None: + if x.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens) + if head_first: + o = o.transpose(1, 2) + return o + + +def test_mean_pooling(): + torch.manual_seed(42) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + batch_size = 2 + seq_len = 1024 + num_heads = 4 + head_dim = 64 + chunk_size = 32 + + x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, requires_grad=True) + output = mean_pooling(x, chunk_size=chunk_size, head_first=False) + + x_hf = x.permute(0, 2, 1, 3).contiguous().requires_grad_(True) # (B, H, T, D) + output_hf = mean_pooling(x_hf, chunk_size=chunk_size, head_first=True) + out1 = output_hf.permute(0, 2, 1, 3).contiguous() + out2 = output.contiguous() + + print("max abs diff:", (out1 - out2).abs().max().item()) + print("mean abs diff:", (out1 - out2).abs().mean().item()) + assert torch.allclose(output_hf.permute(0, 2, 1, 3).contiguous().clone(), output.contiguous().clone(), atol=1e-4) + + +if __name__ == "__main__": + test_mean_pooling() \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/utils.py b/top/kernels/deepseek_nsa/utils.py new file mode 100644 index 00000000..d2be54e4 --- /dev/null +++ b/top/kernels/deepseek_nsa/utils.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- + +import contextlib +import functools +import os +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +import torch +import triton +from packaging import version + + +def tensor_cache( + fn: Callable[..., torch.Tensor] +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. 
+ + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: Optional[Tuple] = None + last_kwargs: Optional[Dict] = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if last_args is not None and last_kwargs is not None: + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + if all(a is b for a, b in zip(args, last_args)) and \ + all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +def input_guard( + fn: Callable[..., torch.Tensor] +) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +contiguous = input_guard + + +def require_version(version, hint): + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. 
+ """ + def decorator(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + from transformers.utils.versions import require_version + require_version(version, hint) + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + return decorator + + +def checkpoint(fn): + def wrapper(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) + return wrapper + + +@lru_cache(maxsize=None) +def check_pytorch_version(version_s: str = '2.4') -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +@lru_cache(maxsize=None) +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] + + +@lru_cache(maxsize=None) +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + import warnings + warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1) + return 'cpu' + + +@lru_cache(maxsize=None) +def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: + device = get_available_device() + if device == 'cuda': + return 'nvidia' + elif device == 'hip': + return 'amd' + elif device == 'xpu': + return 'intel' + else: + return device + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
+device = get_available_device() if get_available_device() != 'hip' else 'cuda' +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_intel = (device_platform == 'intel') +is_nvidia = (device_platform == 'nvidia') +is_amd = (device_platform == 'amd') +is_intel_a770 = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) +use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') + +# Nvidia Ampere or newer, haven't check AMD and intel yet. +is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8) + + +def get_all_max_shared_memory(): + return [ + triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem'] + for i in range(device_torch_lib.device_count()) + ] + + +@lru_cache(maxsize=None) +def is_triton_shared_mem_enough(max_shared_mem: int = 102400, tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_memory() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= max_shared_mem + except Exception: + return False + + +device_capacity = is_triton_shared_mem_enough() + + +if check_pytorch_version('2.4'): + device = 'cuda' if device == 'cpu' else device + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.' 
+ autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) + + +@tensor_cache +def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor: + return offsets[1:] - offsets[:-1] + + +@tensor_cache +def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n) for n in prepare_lens(offsets).tolist()]).to(offsets.device) + + +@tensor_cache +def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor: + return position_ids.eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(offsets) + return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets) + + +@tensor_cache +def prepare_chunk_offsets( + offsets: torch.Tensor, + chunk_size: int +) -> torch.LongTensor: + return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1) + + +@tensor_cache +def prepare_chunk_indices( + offsets: torch.LongTensor, + chunk_size: int +) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()]) + return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets) \ No newline at end of file From 6b57bcdf5dc07ef7096e24f41783f9d9549334be Mon Sep 17 00:00:00 2001 From: jienengyu Date: Sun, 4 Jan 2026 15:31:41 +0800 Subject: [PATCH 07/14] feat(python): run benchmarks and mean_pooling_tilelang kernel --- top/kernels/deepseek_nsa/nsa_fwd.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py index 445353fd..34bde028 100644 --- a/top/kernels/deepseek_nsa/nsa_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -297,24 +297,6 @@ def main(): out2 = kernel2.forward(Q, K, V, 
block_indices.to(torch.int32)) - # ref = naive_nsa( - # q=Q, - # k=K, - # v=V, - # g_slc=g_slc, - # g_swa=g_swa, - # block_indices=block_indices, - # block_counts=block_counts, - # block_size=block_size, - # scale=scale, - # ) - - print("out", out) - print("out2", out2) - # print("ref", ref) - # torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - # torch.testing.assert_close(ref, out2, atol=1e-2, rtol=1e-2) - if __name__ == "__main__": main() From e4bb142a0de07f087a412dab69833b0b50a2776e Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 6 Jan 2026 14:33:09 +0800 Subject: [PATCH 08/14] [Feat]add nsa_fwd kernel/op & mean_pool kernel/op --- benchmarks/deepseek_nsa/deepseek_nsa.py | 105 ++++- tests/ops/test_deepseek_nsa_ops.py | 6 +- top/functions/deepseek_nsa.py | 61 --- top/kernels/deepseek_nsa/__init__.py | 4 +- top/kernels/deepseek_nsa/mean_pooling.py | 138 ------- .../deepseek_nsa/mean_pooling_triton.py | 150 ------- top/kernels/deepseek_nsa/nsa_fwd.py | 69 +--- top/kernels/deepseek_nsa/nsa_torch.py | 380 ------------------ top/kernels/deepseek_nsa/utils.py | 233 ----------- top/layers/deepseek_nsa.py | 63 --- top/ops/__init__.py | 3 +- top/ops/deepseek_nsa.py | 138 +++---- 12 files changed, 176 insertions(+), 1174 deletions(-) delete mode 100644 top/kernels/deepseek_nsa/mean_pooling.py delete mode 100644 top/kernels/deepseek_nsa/mean_pooling_triton.py delete mode 100644 top/kernels/deepseek_nsa/nsa_torch.py delete mode 100644 top/kernels/deepseek_nsa/utils.py diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index 770b05e5..d4257fa7 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -1,8 +1,17 @@ from benchmarks.benchmark import Benchmark from top.ops import NativeSparseAttentionForwardOp +from top.ops import MeanPoolingForwardOp + import torch from torch.nn import functional as f -from top.kernels.deepseek_nsa.nsa_torch import 
naive_nsa + +from typing import Tuple, Any, Optional +from native_sparse_attention.ops.naive import naive_nsa +from native_sparse_attention.ops.parallel import parallel_nsa_fwd +from fla.ops.utils import mean_pooling + +from fla.ops.common.utils import prepare_chunk_indices + class NativeSparseAttentionForwardBenchmark(Benchmark): op_type = NativeSparseAttentionForwardOp @@ -61,6 +70,9 @@ def gen_inputs(self): V = torch.randn( self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) + self.o_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), dtype=self.dtype, device="cuda") + self.lse_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), dtype=torch.float, device="cuda") + self.g_slc = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) self.g_swa = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) @@ -75,8 +87,8 @@ def gen_inputs(self): block_indices = block_indices.sort(-1)[0].to(torch.int32) return Q, K, V, block_indices - def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): - print("running naive nsa ref_program") + + def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: return naive_nsa( q=Q, k=K, @@ -87,4 +99,89 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIn block_counts=self.block_counts, block_size=self.block_size, scale=self.scale, - ) \ No newline at end of file + ) + + + def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor)-> torch.Tensor: + o, lse = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + block_indices=BlockIndices, + block_counts=self.block_counts, + block_size=self.block_size, + scale=self.scale, + ) + return o + + + def baseline_profile(self, *inputs: Any, warmup: int = 100, rep: int = 
100, device: str = "cuda:0") -> Any: + print("===== Profiling FLA NSA_Fwd backend =====") + return super().baseline_profile( + self.baseline_program, *inputs, backend="FLA", warmup=warmup, rep=rep, device=device) + + +class MeanPoolingForwardBenchmark(Benchmark): + op_type = MeanPoolingForwardOp + + def __init__( + self, + batch_size, + total_seqlen, + total_chunks, + heads, + dim, + chunk_size, + tune= True + ): + self.batch_size = batch_size + self.total_seqlen = total_seqlen + self.total_chunks = total_chunks + self.heads = heads + self.dim = dim + self.chunk_size = chunk_size + self.tune= tune + self.dtype = torch.float16 + + @property + def total_flops(self): + flops = self.heads * self.dim * (self.total_seqlen + self.total_chunks) + return flops + + @property + def total_memory(self): + return self.heads*self.dim*(self.total_seqlen+self.total_chunks)*self.dtype.itemsize + 16*self.total_chunks + + def gen_inputs(self): + x_unpad = torch.randn(self.total_seqlen, self.heads, self.dim, device='cuda', dtype=self.dtype) + # fixed length + b = self.batch_size + t = self.total_seqlen//b + + cu_seqlens = torch.arange(0, (b + 1) * t, t, dtype=torch.int32, device='cuda') + chunk_indices = prepare_chunk_indices(cu_seqlens, self.chunk_size) + + return x_unpad, cu_seqlens, chunk_indices + + + def ref_program(self, x_unpad:torch.Tensor, cu_seqlens:torch.Tensor, chunk_indices:torch.Tensor) -> torch.Tensor: + b = self.batch_size + t = self.total_seqlen//b + x = x_unpad.view(b, t, self.heads, self.dim) + + return mean_pooling(x, chunk_size=self.chunk_size, cu_seqlens=None, head_first=False).view(-1,self.heads, self.dim) + + + def baseline_program(self, x_unpad:torch.Tensor, cu_seqlens:torch.Tensor, chunk_indices:torch.Tensor) -> torch.Tensor: + b = self.batch_size + t = self.total_seqlen//b + x = x_unpad.view(b, t, self.heads, self.dim) + return mean_pooling(x, chunk_size=self.chunk_size, cu_seqlens=None, head_first=False).view(-1,self.heads, self.dim) + + + + def 
baseline_profile(self, *inputs: Any, warmup: int = 100, rep: int = 100, device: str = "cuda:0") -> Any: + print("===== Profiling Mean Pooling_Fwd backend =====") + return super().baseline_profile( + self.baseline_program, *inputs, backend="Mean Pooling", warmup=warmup, rep=rep, device=device) + diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index 53868af5..f89ba0a7 100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -23,13 +23,15 @@ def test_nsa_op( inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) benchmark.profile(op, *inputs) + benchmark.baseline_profile(*inputs) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--batch', type=int, default=16, help='batch size') parser.add_argument('--heads', type=int, default=16*4, help='number of heads') - parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') + parser.add_argument('--seq_len', type=int, default=8192*1, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') diff --git a/top/functions/deepseek_nsa.py b/top/functions/deepseek_nsa.py index deecfcc0..65b990fe 100644 --- a/top/functions/deepseek_nsa.py +++ b/top/functions/deepseek_nsa.py @@ -2,7 +2,6 @@ from top.functions.function import Function from top.ops.deepseek_nsa import NativeSparseAttentionForwardOp -from top.kernels.deepseek_nsa.nsa_torch import naive_nsa __all__ = ['NativeSparseAttentionFunc'] @@ -54,63 +53,3 @@ def __init__( def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: return nsa_decode_ctx.apply(Q, K, V, BlockIndices, self.fwd_op) - - -# def main(): -# B, 
SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - -# block_T = min(128, 16) - -# kernel = NativeSparseAttentionFunc( -# batch=B, -# heads=HQ, -# seq_len=SEQ_LEN, -# dim=D, -# is_causal=True, -# block_size=block_size, -# groups=HQ // H, -# selected_blocks=S, -# scale=scale, -# tune=True, -# ) - - -# torch.random.manual_seed(0) -# Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) -# K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) -# V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) -# g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) -# g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) -# DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - -# block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") -# block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") -# for b in range(B): -# for t in range(SEQ_LEN): -# for h in range(H): -# i_i = torch.randperm(max(1, (t // block_size)))[:S] -# block_indices[b, t, h, : len(i_i)] = i_i -# block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() -# block_indices = block_indices.sort(-1)[0] - -# out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) - -# ref = naive_nsa( -# q=Q, -# k=K, -# v=V, -# g_slc=g_slc, -# g_swa=g_swa, -# block_indices=block_indices, -# block_counts=block_counts, -# block_size=block_size, -# scale=scale, -# ) - -# print("out", out) -# print("ref", ref) -# torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - - -# if __name__ == "__main__": -# main() \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py index f362f32b..d28c4833 100644 --- a/top/kernels/deepseek_nsa/__init__.py +++ b/top/kernels/deepseek_nsa/__init__.py @@ -1,4 
+1,2 @@ from .nsa_fwd import * -from .nsa_torch import * -from .mean_pooling_triton import * -from .utils import * \ No newline at end of file +from .mean_pooling_fwd import * \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/mean_pooling.py b/top/kernels/deepseek_nsa/mean_pooling.py deleted file mode 100644 index b894fc7b..00000000 --- a/top/kernels/deepseek_nsa/mean_pooling.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch - -from top.kernels.deepseek_nsa.mean_pooling_triton import mean_pooling -from top.kernels.deepseek_nsa.mean_pooling_triton import prepare_chunk_indices - -import tilelang -import tilelang.language as T - - -@tilelang.jit(out_idx=[3]) -def mean_pooling_tilelang_kernel( - batch_size: int, - total_seqlen: int, - total_chunks: int, - heads: int, - dim: int, - chunk_size: int, - block_D: int = 64, - threads: int = 128, -): - dtype = T.float16 - accum_dtype = T.float32 - - @T.prim_func - def main( - X_unpad: T.Tensor([total_seqlen, heads, dim], dtype), - cu_seqlens: T.Tensor([batch_size + 1], T.int32), - chunk_indices: T.Tensor([total_chunks, 2], T.int32), - Output: T.Tensor([total_chunks, heads, dim], dtype), - ): - with T.Kernel( - T.ceildiv(dim, block_D), - total_chunks, - heads, - threads=threads - ) as (i_d, i_t, i_h): - accum = T.alloc_fragment([block_D], accum_dtype) - d_start = i_d * block_D - - seq_id = chunk_indices[i_t, 0] - local_chunk_id = chunk_indices[i_t, 1] - start = cu_seqlens[seq_id] - end = cu_seqlens[seq_id + 1] - seqlen = end - start - - chunk_start = local_chunk_id * chunk_size - chunk_end = T.min(chunk_start + chunk_size, seqlen) - actual_bt = chunk_end - chunk_start - - for d in T.Parallel(block_D): - accum[d] = T.cast(0, accum_dtype) - for t_rel in T.serial(actual_bt): - t_abs = start + chunk_start + t_rel - for d in T.Parallel(block_D): - if d_start + d < dim: - accum[d] += T.cast(X_unpad[t_abs, i_h, d_start + d], accum_dtype) - for d in T.Parallel(block_D): - if d_start + d < dim: - Output[i_t, i_h, d_start 
+ d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), dtype) - - return main - - -def mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size, block_D=64): - total_T, H, D = x_unpad.shape - B = cu_seqlens.shape[0] - 1 - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) - total_chunks = chunk_indices.shape[0] - - kernel = mean_pooling_tilelang_kernel( - batch_size=B, - total_seqlen=total_T, - total_chunks=total_chunks, - heads=H, - dim=D, - chunk_size=chunk_size, - block_D=block_D, - threads=128, - ) - return kernel(x_unpad, cu_seqlens, chunk_indices) - - -def test_varlen(): - print("=== 🌊 Testing Variable-Length Mode ===") - device = "cuda" - torch.manual_seed(42) - - seqlens = torch.tensor([100, 150], device=device, dtype=torch.int32) - cu_seqlens = torch.zeros(seqlens.shape[0] + 1, device=device, dtype=torch.int32) - cu_seqlens[1:] = seqlens.cumsum(0) - total_T = cu_seqlens[-1].item() - H, D, chunk_size = 4, 64, 32 - - x_unpad = torch.randn(total_T, H, D, dtype=torch.float16, device=device) - # x_triton = x_unpad.unsqueeze(0) # (1, total_T, H, D) - - # Triton - out_triton = mean_pooling(x_unpad.unsqueeze(0), chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False) - out_triton = out_triton.squeeze(0) - - # TileLang - out_tilelang = mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size) - - print(f"Triton: {out_triton.shape}") - print(f"TileLang: {out_tilelang.shape}") - print(f"Max diff: {(out_triton - out_tilelang).abs().max().item():.6f}") - torch.testing.assert_close(out_triton.float(), out_tilelang.float(), atol=1e-2, rtol=1e-2) - print("✅ Varlen test passed!\n") - - -# Test 2: Fixed-Length -def test_fixed(): - print("=== 📏 Testing Fixed-Length Mode ===") - device = "cuda" - torch.manual_seed(42) - - B, T, H, D = 3, 1024, 128, 128 - chunk_size = 32 - - x = torch.randn(B, T, H, D, dtype=torch.float16, device=device) - out_triton = mean_pooling(x, chunk_size=chunk_size, cu_seqlens=None, head_first=False) # (B, NT, H, D) - out_triton_reshaped 
= out_triton.view(-1, H, D) # (B*NT, H, D) - - x_unpad = x.view(-1, H, D) # (B*T, H, D) - cu_seqlens = torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=device) # [0, T, 2T] - out_tilelang = mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size) # (total_chunks, H, D) - - print(f"Triton: {out_triton_reshaped.shape}") - print(f"TileLang: {out_tilelang.shape}") - print(f"Max diff: {(out_triton_reshaped - out_tilelang).abs().max().item():.6f}") - torch.testing.assert_close(out_triton_reshaped.float(), out_tilelang.float(), atol=1e-2, rtol=1e-2) - print("✅ Fixed-length test passed!\n") - - -if __name__ == "__main__": - test_varlen() - test_fixed() - print("🎉 All tests passed! TileLang and Triton outputs match perfectly.") \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/mean_pooling_triton.py b/top/kernels/deepseek_nsa/mean_pooling_triton.py deleted file mode 100644 index 2755e918..00000000 --- a/top/kernels/deepseek_nsa/mean_pooling_triton.py +++ /dev/null @@ -1,150 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -from typing import Optional, Tuple - -import torch -import triton -import triton.language as tl - -from .utils import input_guard, prepare_chunk_indices - - -@triton.heuristics({ - 'USE_OFFSETS': lambda args: args['offsets'] is not None -}) -@triton.autotune( - configs=[ - triton.Config({'BD': BD}, num_warps=num_warps) - for BD in [16, 32, 64, 128] - for num_warps in [1, 2, 4, 8] - ], - key=['BT'] -) -@triton.jit(do_not_specialize=['T']) -def mean_pooling_fwd_kernel( - x, - o, - offsets, - indices, - T: tl.constexpr, - H: tl.constexpr, - D: tl.constexpr, - BT: tl.constexpr, - BD: tl.constexpr, - NT: tl.constexpr, - USE_OFFSETS: tl.constexpr -): - i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // H, i_bh % H - if USE_OFFSETS: - i_tg = i_t - i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) - bos, eos = 
tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) - T = eos - bos - NT = tl.cdiv(T, BT) - else: - NT = tl.cdiv(T, BT) - i_tg = i_b * NT + i_t - bos, eos = i_b * T, i_b * T + T - - p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) - p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) - # [BT, BD] - b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) - # [BD] - b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) - - -def mean_pooling_fwd( - x: torch.Tensor, - chunk_size: int, - offsets: Optional[torch.LongTensor] = None, - indices: Optional[torch.LongTensor] = None -) -> torch.Tensor: - B, T, H, D = x.shape - BT = chunk_size - NT = triton.cdiv(T, BT) if offsets is None else len(indices) - - o = x.new_empty(B, NT, H, D) - def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) - mean_pooling_fwd_kernel[grid]( - x, - o, - offsets, - indices, - T=T, - H=H, - D=D, - BT=BT, - NT=NT, - ) - return o - - -class MeanPoolingFunction(torch.autograd.Function): - @staticmethod - @input_guard - def forward( - ctx, - x: torch.Tensor, - chunk_size: int, - offsets: Optional[torch.LongTensor] = None - ) -> torch.Tensor: - # 2-d indices denoting the offsets of chunks in each sequence - # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, - # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be - # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] - indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None - o = mean_pooling_fwd(x, chunk_size, offsets, indices) - ctx.batch_size = x.shape[0] - ctx.seq_len = x.shape[1] - ctx.chunk_size = chunk_size - ctx.offsets = offsets - ctx.indices = indices - return o - - -def mean_pooling( - x: torch.Tensor, - chunk_size: int, - cu_seqlens: 
Optional[torch.LongTensor] = None, - head_first: bool = False -) -> torch.Tensor: - if head_first: - x = x.transpose(1, 2) - if cu_seqlens is not None: - if x.shape[0] != 1: - raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") - o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens) - if head_first: - o = o.transpose(1, 2) - return o - - -def test_mean_pooling(): - torch.manual_seed(42) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - batch_size = 2 - seq_len = 1024 - num_heads = 4 - head_dim = 64 - chunk_size = 32 - - x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, requires_grad=True) - output = mean_pooling(x, chunk_size=chunk_size, head_first=False) - - x_hf = x.permute(0, 2, 1, 3).contiguous().requires_grad_(True) # (B, H, T, D) - output_hf = mean_pooling(x_hf, chunk_size=chunk_size, head_first=True) - out1 = output_hf.permute(0, 2, 1, 3).contiguous() - out2 = output.contiguous() - - print("max abs diff:", (out1 - out2).abs().max().item()) - print("mean abs diff:", (out1 - out2).abs().mean().item()) - assert torch.allclose(output_hf.permute(0, 2, 1, 3).contiguous().clone(), output.contiguous().clone(), atol=1e-4) - - -if __name__ == "__main__": - test_mean_pooling() \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py index 34bde028..a8023735 100644 --- a/top/kernels/deepseek_nsa/nsa_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -5,7 +5,6 @@ import itertools import torch -from top.kernels.deepseek_nsa.nsa_torch import naive_nsa __all__ = ["nsa_fwd_kernel"] @@ -34,9 +33,7 @@ def _nsa_fwd_kernel( accum_dtype = T.float32 block_S = block_size - # block_T = min(128, tilelang.math.next_power_of_2(dim)) - # num_stages = 2 - # threads = 32 + @tilelang.jit( out_idx=[-1], pass_configs={ @@ -45,8 +42,6 @@ def _nsa_fwd_kernel( 
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, ) - - def _nsa_fwd_func(block_T, num_stages, threads): NK = tilelang.cdiv(dim, block_T) @@ -238,65 +233,3 @@ def autotune_configs(self) -> list[dict]: def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): return _nsa_fwd_wrapped_kernel(self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks, self.config["block_T"], self.config["num_stages"], self.config["threads"], Q, K, V, BlockIndices) - - -def main(): - # B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 8192, 4, 16*4, 128, 16, 32, torch.float16, 0.1 - - block_T = min(128, tilelang.math.next_power_of_2(D)) - kernel = _nsa_fwd_kernel( - batch=B, - heads=HQ, - seq_len=SEQ_LEN, - dim=D, - is_causal=True, - scale=scale, - block_size=block_size, - groups=HQ // H, - selected_blocks=S, - )(block_T=block_T, num_stages=2, threads=32) - - kernel2 = nsa_fwd_kernel( - batch=B, - heads=HQ, - seq_len=SEQ_LEN, - dim=D, - is_causal=True, - block_size=block_size, - groups=HQ // H, - selected_blocks=S, - scale=scale, - tune=True, - ) - - - src_kernel = kernel.get_kernel_source() - print(src_kernel) - # with open("nsa_fwd_kernel.cu", "w") as f: - # f.write(src_kernel) - torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, 
dtype=torch.long, device="cuda") - block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") - for b in range(B): - for t in range(SEQ_LEN): - for h in range(H): - i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, : len(i_i)] = i_i - block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() - block_indices = block_indices.sort(-1)[0] - - out = kernel(Q, K, V, block_indices.to(torch.int32)) - - out2 = kernel2.forward(Q, K, V, block_indices.to(torch.int32)) - - -if __name__ == "__main__": - main() diff --git a/top/kernels/deepseek_nsa/nsa_torch.py b/top/kernels/deepseek_nsa/nsa_torch.py deleted file mode 100644 index f155a765..00000000 --- a/top/kernels/deepseek_nsa/nsa_torch.py +++ /dev/null @@ -1,380 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -import math -from typing import Optional, Union - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - - - -@torch.compile -def compression( - k: torch.Tensor, - v: torch.Tensor, - block_size: int -) -> torch.Tensor: - # Currently, we set mean pooling as our basic compression function. 
- B, T, H = k.shape[:3] - num_block = math.ceil(T / block_size) - if k.shape[1] % block_size != 0: - k = F.pad(k, (0, 0, 0, 0, 0, num_block * block_size - T)) - v = F.pad(v, (0, 0, 0, 0, 0, num_block * block_size - T)) - k_cmp = k.view(B, num_block, block_size, H, -1).mean(dim=2) - v_cmp = v.view(B, num_block, block_size, H, -1).mean(dim=2) - return k_cmp, v_cmp - - -def naive_nsa( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_indices: torch.LongTensor, - block_counts: Optional[Union[torch.LongTensor, int]] = None, - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False -) -> torch.Tensor: - r""" - Args: - q (torch.Tensor): - Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. - k (torch.Tensor): - Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. - GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. - v (torch.Tensor): - Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. - g_slc (torch.Tensor): - Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. - g_swa (torch.Tensor): - Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. - block_indices (torch.LongTensor): - Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. - `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. - block_counts (Union[torch.LongTensor, int]): - Number of selected blocks for each token. - If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, - each token can select the same number of blocks. - If not provided, it will default to `S`, Default: `None`. - block_size (int): - Selected block size. Default: 64. 
- window_size (int): - Sliding window size. Default: 0. - scale (Optional[int]): - Scale factor for attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `False`. - - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. - """ - if scale is None: - scale = k.shape[-1] ** -0.5 - if cu_seqlens is not None: - assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" - if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") - if head_first: - q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, block_indices)) - g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa)) - if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') - - dtype = q.dtype - G = q.shape[2] // k.shape[2] - BS = block_size - S = block_indices.shape[-1] - k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) - if isinstance(block_counts, torch.Tensor): - block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G) - c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) - q, k, v = map(lambda x: x.float(), (q, k, v)) - - o_slc = torch.zeros_like(v) - o_swa = torch.zeros_like(v) if window_size > 0 else None - varlen = True - if cu_seqlens is None: - varlen = False - B, T = q.shape[:2] - cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B*T, T)), block_indices.new_tensor([B*T])]) - - for i in range(len(cu_seqlens) - 1): - if not varlen: - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], 
block_indices[i] - if isinstance(block_counts, torch.Tensor): - s_b = block_counts[i] - else: - s_b = block_counts - else: - T = cu_seqlens[i+1] - cu_seqlens[i] - q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], - (q, k, v, g_slc, g_swa, block_indices) - ) - if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i+1]] - else: - s_b = block_counts - - i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) - # [T, S*BS, HQ] - i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) - for i_q in range(T): - # [HQ, D] - q_i = q_b[i_q] * scale - # [HQ] - g_slc_i = g_slc_b[i_q] - # [HQ] - g_swa_i = g_swa_b[i_q] - # [S*BS, HQ] - i_i = i_b[i_q] - # [HQ] - if isinstance(block_counts, torch.Tensor): - s_i = s_b[i_q] - else: - s_i = s_b - # [S*BS, HQ, -1] - k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp( - 0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) - # [S*BS, HQ] - attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill( - torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), - float('-inf') - ).softmax(0) - if not varlen: - o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) - else: - o_slc[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) - if window_size > 0: - k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1], (k_b, v_b)) - attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0) - if not varlen: - o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) - else: - o_swa[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) - - if head_first: - o_slc = rearrange(o_slc, 'b t h d -> b h t d') - o_swa = rearrange(o_swa, 'b t h d -> b h t d') - - return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is 
not None else o_slc.to(dtype) - - -def naive_nsa_compression( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_cmp: torch.Tensor, - block_counts: Union[torch.LongTensor, int], - block_size: int, - scale: float, - head_first: bool = False -) -> torch.LongTensor: - dtype = q.dtype - B, T = q.shape[0], q.shape[1] - H, HQ = k.shape[2], q.shape[2] - G = HQ//H - BS = block_size - if isinstance(block_counts, int): - block_counts = torch.full((B, T, H), block_counts, dtype=torch.long, device=q.device) - q, k, v = map(lambda x: x.float(), (q, k, v)) - k_cmp, v_cmp = compression(k, v, BS) - C = k_cmp.shape[1] - S = min(block_counts.max().item(), C) - k_cmp, v_cmp = map(lambda x: repeat(x, 'b c h d -> b c (h g) d', g=G), (k_cmp, v_cmp)) - - casual_mask = ((torch.arange(T) - BS + 1)[:, None] // BS < torch.arange(C)[None, :]).to(q.device) - empty_mask = casual_mask.all(-1, True) - local_mask = (torch.arange(T)[:, None] // BS == torch.arange(C)[None, :]).to(q.device) - - attn_cmp = torch.einsum('bqhd,bkhd->bhqk', q*scale, k_cmp) - attn_cmp = attn_cmp.masked_fill(casual_mask & empty_mask.logical_not(), float('-inf')) - attn_cmp = attn_cmp.softmax(-1).masked_fill(empty_mask, 0.0) - o_cmp = torch.einsum('bhqk, bkhd -> bqhd', attn_cmp, v_cmp) * g_cmp.unsqueeze(-1) - attn_select = attn_cmp.masked_fill(local_mask, float(1.0)) - attn_select = attn_select.view(B, H, G, T, C).sum(2) - block_indices = attn_select.topk(S, -1)[1] - - block_indices = block_indices.masked_fill(block_indices > (block_indices.new_tensor(range(T))[:, None] // BS), -1) - block_indices = block_indices.transpose(1, 2) - - if head_first: - o_cmp = rearrange(o_cmp, 'b t h d -> b h t d') - return block_indices, o_cmp.to(dtype) - - -def naive_nsa_compression_varlen( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_cmp: torch.Tensor, - block_counts: Union[torch.LongTensor, int], - block_size: int, - scale: float, - cu_seqlens: torch.LongTensor, - head_first: bool = False -) -> torch.LongTensor: - 
dtype = q.dtype - B, T = q.shape[0], q.shape[1] - H, HQ = k.shape[2], q.shape[2] - D = v.shape[-1] - G = HQ//H - BS = block_size - S = block_counts if isinstance(block_counts, int) else block_counts.max().item() - C = math.ceil(T / block_size) - S = min(S, C) - block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device=q.device) - o_cmp = torch.zeros(B, T, HQ, D, dtype=dtype, device=q.device) - for i in range(len(cu_seqlens) - 1): - T_b = cu_seqlens[i+1] - cu_seqlens[i] - C_b = math.ceil(T_b / block_size) - q_b, k_b, v_b, g_cmp_b = map( - lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], - (q, k, v, g_cmp) - ) - if isinstance(block_counts, torch.Tensor): - s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i+1]] - else: - s_b = block_counts - - k_cmp, v_cmp = compression(k_b.unsqueeze(0), v_b.unsqueeze(0), BS) - S_b = s_b if isinstance(s_b, int) else s_b.max().item() - C_b = k_cmp.shape[1] - S_b = min(S_b, C_b) - k_cmp, v_cmp = map(lambda x: repeat(x.squeeze(0), 'c h d -> c (h g) d', g=G), (k_cmp, v_cmp)) - q_b, k_cmp, v_cmp = map(lambda x: x.float(), (q_b, k_cmp, v_cmp)) - - casual_mask = ((torch.arange(T_b) - BS + 1)[:, None] // BS < torch.arange(C_b)[None, :]).to(q_b.device) - local_mask = (torch.arange(T_b)[:, None] // BS == torch.arange(C_b)[None, :]).to(q.device) - - attn_cmp = torch.einsum('qhd,khd->hqk', q_b*scale, k_cmp) - attn_cmp = attn_cmp.masked_fill(casual_mask, float('-inf')) - attn_cmp = attn_cmp.softmax(-1) - o_cmp[0][cu_seqlens[i]:cu_seqlens[i+1]] = torch.einsum('hqk,khd->qhd', attn_cmp, v_cmp).nan_to_num() *\ - g_cmp_b.unsqueeze(-1) - attn_select = attn_cmp.masked_fill(local_mask, float(1.0)) - attn_select = attn_select.view(H, G, T_b, C_b).sum(1) - block_indices_b = attn_select.topk(S_b, -1)[1] - block_indices_b = block_indices_b.masked_fill( - block_indices_b > (block_indices_b.new_tensor(range(T_b))[:, None]//BS), - 0 - ) - block_indices[0][cu_seqlens[i]:cu_seqlens[i+1], :, :S_b] = block_indices_b.transpose(0, 1) - - if head_first: - o_cmp = 
rearrange(o_cmp, 'b t h d -> b h t d') - return block_indices, o_cmp.to(dtype) - - -def naive_nsa_with_compression( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g_cmp: torch.Tensor, - g_slc: torch.Tensor, - g_swa: torch.Tensor, - block_counts: Union[torch.LongTensor, int], - block_size: int = 64, - window_size: int = 0, - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False -) -> torch.Tensor: - r""" - Args: - q (torch.Tensor): - Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. - k (torch.Tensor): - Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. - GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. - v (torch.Tensor): - Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. - g_cmp (torch.Tensor): - Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. - g_slc (torch.Tensor): - Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. - g_swa (torch.Tensor): - Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. - block_counts (Union[torch.LongTensor, int]): - Number of selected blocks for each token. - If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, - each token can select the same number of blocks. - block_size (int): - Selected block size. Default: 64. - window_size (int): - Sliding window size. Default: 0. - scale (Optional[int]): - Scale factor for attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `False`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. 
- - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. - """ - if scale is None: - scale = k.shape[-1] ** -0.5 - if cu_seqlens is not None: - assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" - if head_first: - raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") - if head_first: - q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) - g_cmp, g_slc = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_cmp, g_slc)) - if isinstance(block_counts, torch.Tensor): - block_counts = rearrange(block_counts, 'b h t -> b t h') - if cu_seqlens is not None: - block_indices, o_cmp = naive_nsa_compression_varlen( - q=q, - k=k, - v=v, - g_cmp=g_cmp, - block_counts=block_counts, - block_size=block_size, - scale=scale, - cu_seqlens=cu_seqlens, - head_first=False) - else: - block_indices, o_cmp = naive_nsa_compression( - q=q, - k=k, - v=v, - g_cmp=g_cmp, - block_counts=block_counts, - block_size=block_size, - scale=scale, - head_first=False) - o = naive_nsa( - q=q, - k=k, - v=v, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - window_size=window_size, - scale=scale, - cu_seqlens=cu_seqlens, - head_first=False - ) + o_cmp - - if head_first: - o = rearrange(o, 'b t h d -> b h t d') - - return o, block_indices diff --git a/top/kernels/deepseek_nsa/utils.py b/top/kernels/deepseek_nsa/utils.py deleted file mode 100644 index d2be54e4..00000000 --- a/top/kernels/deepseek_nsa/utils.py +++ /dev/null @@ -1,233 +0,0 @@ -# -*- coding: utf-8 -*- - -import contextlib -import functools -import os -from functools import lru_cache -from typing import Any, Callable, Dict, Literal, Optional, Tuple - -import torch -import triton -from packaging import version - - -def tensor_cache( - fn: Callable[..., torch.Tensor] -) -> Callable[..., torch.Tensor]: - """ - A decorator that caches the most recent 
result of a function with tensor inputs. - - This decorator will store the output of the decorated function for the most recent set of input tensors. - If the function is called again with the same input tensors, it will return the cached result. - - - Args: - fn (Callable[..., torch.Tensor]): - The function to be decorated. It should take tensor inputs and return tensor outputs. - - Returns: - Callable[..., torch.Tensor]: - A wrapped version of the input function with single-entry caching. - """ - last_args: Optional[Tuple] = None - last_kwargs: Optional[Dict] = None - last_result: Any = None - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal last_args, last_kwargs, last_result - - if last_args is not None and last_kwargs is not None: - if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): - if all(a is b for a, b in zip(args, last_args)) and \ - all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): - return last_result - - result = fn(*args, **kwargs) - last_args, last_kwargs, last_result = args, kwargs, result - return result - - return wrapper - - -def input_guard( - fn: Callable[..., torch.Tensor] -) -> Callable[..., torch.Tensor]: - """ - A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
- """ - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) - contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} - - tensor = None - for arg in args: - if isinstance(arg, torch.Tensor): - tensor = arg - break - if tensor is None: - for value in kwargs.values(): - if isinstance(value, torch.Tensor): - tensor = value - break - - if tensor is not None: - ctx = custom_device_ctx(tensor.device.index) - else: - ctx = contextlib.nullcontext() - - with ctx: - return fn(*contiguous_args, **contiguous_kwargs) - - return wrapper - - -contiguous = input_guard - - -def require_version(version, hint): - """ - Perform a runtime check of the dependency versions, using the exact same syntax used by pip. - """ - def decorator(fn): - @functools.wraps(fn) - def wrapper(ctx, *args, **kwargs): - from transformers.utils.versions import require_version - require_version(version, hint) - return fn(ctx, - *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), - **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) - return wrapper - return decorator - - -def checkpoint(fn): - def wrapper(*args, **kwargs): - return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) - return wrapper - - -@lru_cache(maxsize=None) -def check_pytorch_version(version_s: str = '2.4') -> bool: - return version.parse(torch.__version__) >= version.parse(version_s) - - -@lru_cache(maxsize=None) -def get_multiprocessor_count(tensor_idx: int = 0) -> int: - return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] - - -@lru_cache(maxsize=None) -def get_available_device() -> str: - try: - return triton.runtime.driver.active.get_current_target().backend - except BaseException: - import warnings - warnings.warn(('Triton is not supported on current platform, roll back 
to CPU.'), stacklevel=1) - return 'cpu' - - -@lru_cache(maxsize=None) -def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: - device = get_available_device() - if device == 'cuda': - return 'nvidia' - elif device == 'hip': - return 'amd' - elif device == 'xpu': - return 'intel' - else: - return device - - -# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. -# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. -# Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != 'hip' else 'cuda' -device_torch_lib = getattr(torch, device) -device_platform = _check_platform() - -is_intel = (device_platform == 'intel') -is_nvidia = (device_platform == 'nvidia') -is_amd = (device_platform == 'amd') -is_intel_a770 = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) -use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') - -# Nvidia Ampere or newer, haven't check AMD and intel yet. 
-is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8) - - -def get_all_max_shared_memory(): - return [ - triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem'] - for i in range(device_torch_lib.device_count()) - ] - - -@lru_cache(maxsize=None) -def is_triton_shared_mem_enough(max_shared_mem: int = 102400, tensor_idx: int = 0) -> bool: - try: - device_shared_mem_list = get_all_max_shared_memory() - max_shared_memory = device_shared_mem_list[tensor_idx] - return max_shared_memory >= max_shared_mem - except Exception: - return False - - -device_capacity = is_triton_shared_mem_enough() - - -if check_pytorch_version('2.4'): - device = 'cuda' if device == 'cpu' else device - autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) - autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) - - def custom_device_ctx(index: int): - return device_torch_lib.device(index) -else: - assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.' 
- autocast_custom_fwd = device_torch_lib.amp.custom_fwd - autocast_custom_bwd = device_torch_lib.amp.custom_bwd - - def custom_device_ctx(index: int): - return torch.cuda.device(index) - - -@tensor_cache -def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor: - return offsets[1:] - offsets[:-1] - - -@tensor_cache -def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor: - return torch.cat([torch.arange(n) for n in prepare_lens(offsets).tolist()]).to(offsets.device) - - -@tensor_cache -def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor: - return position_ids.eq(0).cumsum(0) - 1 - - -@tensor_cache -def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor: - position_ids = prepare_position_ids(offsets) - return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets) - - -@tensor_cache -def prepare_chunk_offsets( - offsets: torch.Tensor, - chunk_size: int -) -> torch.LongTensor: - return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1) - - -@tensor_cache -def prepare_chunk_indices( - offsets: torch.LongTensor, - chunk_size: int -) -> torch.LongTensor: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()]) - return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets) \ No newline at end of file diff --git a/top/layers/deepseek_nsa.py b/top/layers/deepseek_nsa.py index 30fef125..2ca7336a 100644 --- a/top/layers/deepseek_nsa.py +++ b/top/layers/deepseek_nsa.py @@ -2,11 +2,8 @@ from torch import nn from top.functions import NativeSparseAttentionFunc -from top.kernels.deepseek_nsa.nsa_torch import naive_nsa - class NativeSparseAttentionLayer(nn.Module): - def __init__( self, batch, @@ -38,63 +35,3 @@ def __init__( def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: return self.fn(Q, K, V, BlockIndices) - - -def 
main(): - B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - - block_T = min(128, 16) - - kernel = NativeSparseAttentionLayer( - batch=B, - heads=HQ, - seq_len=SEQ_LEN, - dim=D, - is_causal=True, - block_size=block_size, - groups=HQ // H, - selected_blocks=S, - scale=scale, - tune=True, - ) - - - torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") - block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") - for b in range(B): - for t in range(SEQ_LEN): - for h in range(H): - i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, : len(i_i)] = i_i - block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() - block_indices = block_indices.sort(-1)[0] - - out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) - - ref = naive_nsa( - q=Q, - k=K, - v=V, - g_slc=g_slc, - g_swa=g_swa, - block_indices=block_indices, - block_counts=block_counts, - block_size=block_size, - scale=scale, - ) - - print("out", out) - print("ref", ref) - torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/top/ops/__init__.py b/top/ops/__init__.py index cc49f188..81419a76 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -6,7 +6,7 @@ from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp from .deepseek_mla_decode import 
MultiHeadLatentAttentionDecodeWithKVCacheOp from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp -from .deepseek_nsa import NativeSparseAttentionForwardOp +from .deepseek_nsa import NativeSparseAttentionForwardOp, MeanPoolingForwardOp __all__ = [ "Op", @@ -20,4 +20,5 @@ "MultiHeadLatentAttentionDecodeWithKVCacheOp", "DeepSeekSparseAttentionDecodeWithKVCacheOp", "NativeSparseAttentionForwardOp", + "MeanPoolingForwardOp" ] diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 27b87002..31c7977a 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -2,13 +2,64 @@ from top.ops.op import Op from top.kernels.kernel import Kernel from top.kernels.deepseek_nsa.nsa_fwd import nsa_fwd_kernel -from typing import Optional, Dict +from top.kernels.deepseek_nsa.mean_pooling_fwd import mean_pooling_fwd_kernel +from typing import Optional, Dict, Callable, Tuple, Any +from fla.ops.utils import mean_pooling +from fla.ops.common.utils import prepare_chunk_indices -from top.kernels.deepseek_nsa.nsa_torch import naive_nsa -__all__ = ["NativeSparseAttentionForwardOp"] +import tilelang +import tilelang.language as T +import functools +__all__ = ["NativeSparseAttentionForwardOp", "MeanPoolingForwardOp"] + + +class MeanPoolingForwardOp(Op): + def __init__( + self, + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune=False + )-> torch.Tensor: + self.batch_size = batch_size + self.total_seqlen = total_seqlen + self.total_chunks = total_chunks + self.heads = heads + self.dim = dim + self.chunk_size = chunk_size + self.tune = tune + + self.dispatch_kernel(kernel_map) + + self.kernel = self.kernel_map["mean_pooling_fwd_kernel"]( + batch_size=self.batch_size, + total_seqlen=self.total_seqlen, + total_chunks=self.total_chunks, + heads=self.heads, + dim=self.dim, + chunk_size=self.chunk_size, + tune= self.tune + ) + @property + def 
default_kernel_map(self): + return {"mean_pooling_fwd_kernel": mean_pooling_fwd_kernel} + + def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): + return self.kernel(x_unpad, cu_seqlens, chunk_indices) + # def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): + # out = self.kernel(x, cu_seqlens, chunk_indices) + # print(self.batch_size) + # return out.view(self.batch_size,-1, self.heads, self.dim) + + + class NativeSparseAttentionForwardOp(Op): def __init__( self, @@ -35,17 +86,6 @@ def __init__( self.selected_blocks = selected_blocks self.tune = tune - print("batch ", self.batch) - print("heads ", self.heads) - print("seq_len ", self.seq_len) - print("dim ", self.dim) - print("is_causal ", self.is_causal) - print("scale ", self.scale) - print("block_size ", self.block_size) - print("groups ", self.groups) - print("selected_blocks ", self.selected_blocks) - print("tune ", self.tune) - self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["nsa_fwd_kernel"]( self.batch, self.heads, self.seq_len, @@ -61,63 +101,19 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndice return self.kernel(Q, K, V, BlockIndices) -def main(): - # B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - - B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 8192, 4, 16*4, 128, 16, 32, torch.float16, 0.1 +def mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size, block_D=64): + total_T, H, D = x_unpad.shape + B = cu_seqlens.shape[0] - 1 + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + total_chunks = chunk_indices.shape[0] - block_T = min(128, 16) - - kernel = NativeSparseAttentionForwardOp( - batch=B, - heads=HQ, - seq_len=SEQ_LEN, + op = MeanPoolingForwardOp( + batch_size=B, + total_seqlen=total_T, + total_chunks=total_chunks, + heads=H, dim=D, - is_causal=True, - block_size=block_size, - groups=HQ // H, - 
selected_blocks=S, - scale=scale, - tune=True, + chunk_size=chunk_size, + tune= True ) - - - torch.random.manual_seed(0) - Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") - block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") - for b in range(B): - for t in range(SEQ_LEN): - for h in range(H): - i_i = torch.randperm(max(1, (t // block_size)))[:S] - block_indices[b, t, h, : len(i_i)] = i_i - block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() - block_indices = block_indices.sort(-1)[0] - - out = kernel.forward(Q, K, V, block_indices.to(torch.int32)) - - # ref = naive_nsa( - # q=Q, - # k=K, - # v=V, - # g_slc=g_slc, - # g_swa=g_swa, - # block_indices=block_indices, - # block_counts=block_counts, - # block_size=block_size, - # scale=scale, - # ) - - print("out", out) - # print("ref", ref) - # torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) - - -if __name__ == "__main__": - main() \ No newline at end of file + return op.forward(x_unpad, cu_seqlens, chunk_indices) From 92da705f69ca4678812d6fbf0787b758e88eac2e Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 6 Jan 2026 14:39:52 +0800 Subject: [PATCH 09/14] [Feat]add nsa_fwd kernel/op & mean_pool kernel/op --- tests/ops/test_mean_pooling_ops.py | 44 +++++ top/kernels/deepseek_nsa/mean_pooling_fwd.py | 187 +++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 
tests/ops/test_mean_pooling_ops.py create mode 100644 top/kernels/deepseek_nsa/mean_pooling_fwd.py diff --git a/tests/ops/test_mean_pooling_ops.py b/tests/ops/test_mean_pooling_ops.py new file mode 100644 index 00000000..08ad77b3 --- /dev/null +++ b/tests/ops/test_mean_pooling_ops.py @@ -0,0 +1,44 @@ +import argparse +from top.ops import MeanPoolingForwardOp +from top.utils import str2dtype +from benchmarks.deepseek_nsa.deepseek_nsa import MeanPoolingForwardBenchmark + + +def test_mean_pooling_op( + batch_size, + total_seqlen, + total_chunks, + heads, + dim, + chunk_size, + tune= True +): + op = MeanPoolingForwardOp(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) + + benchmark = MeanPoolingForwardBenchmark(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size,tune=tune) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + benchmark.profile(op, *inputs) + benchmark.baseline_profile(*inputs) + +if __name__ == "__main__": + import sys + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', type=int, default=1, help='logical batch size') + parser.add_argument('--total_seqlen', type=int, default=1*8192*1, help='number of heads') + parser.add_argument('--total_chunks', type=int, default=1*256*1, help='sequence length') + parser.add_argument('--heads', type=int, default=128, help='head dim') + parser.add_argument('--dim', type=int, default=128, help='scale') + parser.add_argument('--chunk_size', type=int, default=32, help='scale') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') + args = parser.parse_args() + test_mean_pooling_op( + args.batch_size, + args.total_seqlen, + args.total_chunks, + args.heads, + args.dim, + args.chunk_size, + args.tune, + ) \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/mean_pooling_fwd.py b/top/kernels/deepseek_nsa/mean_pooling_fwd.py new file mode 100644 index 00000000..dde7add1 --- /dev/null +++ 
b/top/kernels/deepseek_nsa/mean_pooling_fwd.py @@ -0,0 +1,187 @@ +import torch +from fla.ops.utils import mean_pooling +from fla.ops.common.utils import prepare_chunk_indices + +from typing import Optional, Tuple +from top.kernels.kernel import Kernel +import itertools + +import tilelang +import tilelang.language as T + + +__all__ = ["mean_pooling_fwd_kernel"] + +def _mean_pooling_kernel( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, +): + dtype = T.float16 + accum_dtype = T.float32 + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + def _mean_pooling_func(block_D, threads): + + ND = T.ceildiv(dim, block_D) + + x_shape = [total_seqlen, heads, dim] + cu_seqlens_shape = [batch_size + 1] + chunk_indices_shape = [total_chunks, 2] + output_shape = [total_chunks, heads, dim] + @T.prim_func + def _mean_pooling_main( + X_unpad: T.Tensor(x_shape, dtype), + cu_seqlens: T.Tensor(cu_seqlens_shape, T.int32), + chunk_indices: T.Tensor(chunk_indices_shape, T.int32), + Output: T.Tensor(output_shape, dtype), + ): + with T.Kernel( + ND, + total_chunks, + heads, + threads=threads + ) as (i_d, i_t, i_h): + accum = T.alloc_fragment([block_D], accum_dtype) + d_start = i_d * block_D + + seq_id = chunk_indices[i_t, 0] + local_chunk_id = chunk_indices[i_t, 1] + start = cu_seqlens[seq_id] + end = cu_seqlens[seq_id + 1] + seqlen = end - start + + chunk_start = local_chunk_id * chunk_size + chunk_end = T.min(chunk_start + chunk_size, seqlen) + actual_bt = chunk_end - chunk_start + + for d in T.Parallel(block_D): + accum[d] = T.cast(0, accum_dtype) + for t_rel in T.serial(actual_bt): + t_abs = start + chunk_start + t_rel + for d in T.Parallel(block_D): + if d_start + d < dim: + accum[d] += T.cast(X_unpad[t_abs, i_h, d_start + d], accum_dtype) + for d in 
T.Parallel(block_D): + if d_start + d < dim: + Output[i_t, i_h, d_start + d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), dtype) + + return _mean_pooling_main + return _mean_pooling_func + + +@torch.library.custom_op("top::mean_pooling_fwd_wrapped_kernel", mutates_args=()) +def _mean_pooling_wrapped_kernel( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int, + threads: int, + x_unpad: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor, +)->torch.Tensor: + return _mean_pooling_kernel( + batch_size, + total_seqlen, + total_chunks, + heads, + dim, + chunk_size, + )(block_D, threads)(x_unpad, cu_seqlens, chunk_indices) + + +@_mean_pooling_wrapped_kernel.register_fake +def _( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int, + threads: int, + *inputs +)->torch.Tensor: + fake_o = torch.empty_like(inputs[0]) + return fake_o + + +class mean_pooling_fwd_kernel(Kernel): + supported_archs: list[int] = [80, 89, 90, 100] + + def __init__( + self, + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + config: Optional[dict] = None, + tune=False + ): + super().__init__() + self.batch_size = batch_size + self.total_seqlen = total_seqlen + self.total_chunks = total_chunks + self.heads = heads + self.dim = dim + self.chunk_size = chunk_size + + self.kernel = _mean_pooling_kernel( + self.batch_size, + self.total_seqlen, + self.total_chunks, + self.heads, + self.dim, + self.chunk_size, + ) + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_D": min(64, self.dim), + "threads": 128, + } + + @property + def autotune_configs(self) -> list[dict]: + block_D = [32, 64, 128] + threads = [32, 64, 128] + _configs = list(itertools.product(block_D, threads)) + configs = [{ + "block_D": c[0], + "threads": c[1] + } for c in 
_configs] + return configs + + def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): + return _mean_pooling_wrapped_kernel( + self.batch_size, + self.total_seqlen, + self.total_chunks, + self.heads, + self.dim, + self.chunk_size, + self.config["block_D"], + self.config["threads"], + x_unpad, + cu_seqlens, + chunk_indices + ) \ No newline at end of file From 406dc5b65c3820a36c7ab22fd4f6c61332516821 Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Tue, 6 Jan 2026 17:53:52 +0800 Subject: [PATCH 10/14] [Feat]add nsa_fwd kernel/op & mean_pool kernel/op --- benchmarks/__init__.py | 1 + benchmarks/deepseek_nsa/__init__.py | 1 - benchmarks/deepseek_nsa/deepseek_nsa.py | 158 ++++++++++--------- benchmarks/profile/profile_run.py | 21 ++- test_tileops.py | 1 - tests/functions/test_deepseek_nsa_func.py | 29 ++-- tests/layers/test_deepseek_nsa_layer.py | 29 ++-- tests/ops/test_deepseek_nsa_ops.py | 28 ++-- tests/ops/test_mean_pooling_ops.py | 27 ++-- top/functions/__init__.py | 1 + top/functions/deepseek_nsa.py | 42 +++-- top/kernels/deepseek_nsa/mean_pooling_fwd.py | 97 +++++------- top/kernels/deepseek_nsa/nsa_fwd.py | 144 +++++++++-------- top/layers/__init__.py | 2 +- top/layers/deepseek_nsa.py | 39 +++-- top/ops/__init__.py | 15 +- top/ops/deepseek_nsa.py | 87 +++++----- 17 files changed, 382 insertions(+), 340 deletions(-) diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 7a0aed7d..35d0b4be 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -4,6 +4,7 @@ from .gemm import GemmBenchmark, MatMulBenchmark from .flash_decode import MultiHeadAttentionDecodeBenchmark, GroupQueryAttentionDecodeBenchmark from .deepseek_mla import MultiHeadLatentAttentionDecodeBenchmark, DeepSeekSparseAttentionDecodeBenchmark + __all__ = [ 'Benchmark', 'NativeSparseAttentionForwardBenchmark', diff --git a/benchmarks/deepseek_nsa/__init__.py b/benchmarks/deepseek_nsa/__init__.py index 58c34136..911147ad 
100644 --- a/benchmarks/deepseek_nsa/__init__.py +++ b/benchmarks/deepseek_nsa/__init__.py @@ -1,6 +1,5 @@ from .deepseek_nsa import NativeSparseAttentionForwardBenchmark - __all__ = [ "NativeSparseAttentionForwardBenchmark", ] diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py index d4257fa7..96f58b3b 100644 --- a/benchmarks/deepseek_nsa/deepseek_nsa.py +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -3,9 +3,8 @@ from top.ops import MeanPoolingForwardOp import torch -from torch.nn import functional as f -from typing import Tuple, Any, Optional +from typing import Any from native_sparse_attention.ops.naive import naive_nsa from native_sparse_attention.ops.parallel import parallel_nsa_fwd from fla.ops.utils import mean_pooling @@ -16,19 +15,17 @@ class NativeSparseAttentionForwardBenchmark(Benchmark): op_type = NativeSparseAttentionForwardOp - def __init__( - self, - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - tune=False - ): + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): self.batch = batch self.heads = heads self.seq_len = seq_len @@ -56,12 +53,11 @@ def total_flops(self): total_keys = S * BS + window_size flops = 4 * B * T * HQ * D * total_keys return flops - + @property def total_memory(self): return (self.batch * self.heads * (2 * self.seq_len) * self.dim * self.dtype.itemsize) - def gen_inputs(self): Q = torch.randn( self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) @@ -69,26 +65,40 @@ def gen_inputs(self): self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) V = torch.randn( self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) - - self.o_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), dtype=self.dtype, device="cuda") - self.lse_slc = 
torch.empty((self.batch, self.seq_len, self.heads, self.dim), dtype=torch.float, device="cuda") - - self.g_slc = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) - self.g_swa = torch.ones((self.batch, self.seq_len, self.heads), dtype=self.dtype, device="cuda").requires_grad_(True) - - block_indices = torch.full((self.batch, self.seq_len, self.head_kv, self.selected_blocks), self.seq_len, dtype=torch.long, device="cuda") - self.block_counts = torch.zeros((self.batch, self.seq_len, self.head_kv), dtype=torch.long, device="cuda") + + self.o_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), + dtype=self.dtype, + device="cuda") + self.lse_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), + dtype=torch.float, + device="cuda") + + self.g_slc = torch.ones((self.batch, self.seq_len, self.heads), + dtype=self.dtype, + device="cuda").requires_grad_(True) + self.g_swa = torch.ones((self.batch, self.seq_len, self.heads), + dtype=self.dtype, + device="cuda").requires_grad_(True) + + block_indices = torch.full((self.batch, self.seq_len, self.head_kv, self.selected_blocks), + self.seq_len, + dtype=torch.long, + device="cuda") + self.block_counts = torch.zeros((self.batch, self.seq_len, self.head_kv), + dtype=torch.long, + device="cuda") for b in range(self.batch): for t in range(self.seq_len): - for h in range(self.head_kv): + for h in range(self.head_kv): i_i = torch.randperm(max(1, (t // self.block_size)))[:self.selected_blocks] - block_indices[b, t, h, : len(i_i)] = i_i - self.block_counts[b, t, h] = (block_indices[b, t, h] != self.seq_len).sum().item() + block_indices[b, t, h, :len(i_i)] = i_i + self.block_counts[b, t, h] = (block_indices[b, t, h] + != self.seq_len).sum().item() block_indices = block_indices.sort(-1)[0].to(torch.int32) return Q, K, V, block_indices - - def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: + 
def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: return naive_nsa( q=Q, k=K, @@ -97,12 +107,12 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIn g_swa=self.g_swa, block_indices=BlockIndices.to(torch.long), block_counts=self.block_counts, - block_size=self.block_size, + block_size=self.block_size, scale=self.scale, - ) - + ) - def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor)-> torch.Tensor: + def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: o, lse = parallel_nsa_fwd( q=Q, k=K, @@ -114,74 +124,80 @@ def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, Bl ) return o - - def baseline_profile(self, *inputs: Any, warmup: int = 100, rep: int = 100, device: str = "cuda:0") -> Any: + def baseline_profile(self, + *inputs: Any, + warmup: int = 100, + rep: int = 100, + device: str = "cuda:0") -> Any: print("===== Profiling FLA NSA_Fwd backend =====") return super().baseline_profile( self.baseline_program, *inputs, backend="FLA", warmup=warmup, rep=rep, device=device) - + class MeanPoolingForwardBenchmark(Benchmark): op_type = MeanPoolingForwardOp - def __init__( - self, - batch_size, - total_seqlen, - total_chunks, - heads, - dim, - chunk_size, - tune= True - ): + def __init__(self, batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=True): self.batch_size = batch_size self.total_seqlen = total_seqlen self.total_chunks = total_chunks self.heads = heads self.dim = dim self.chunk_size = chunk_size - self.tune= tune + self.tune = tune self.dtype = torch.float16 @property def total_flops(self): flops = self.heads * self.dim * (self.total_seqlen + self.total_chunks) return flops - + @property def total_memory(self): - return self.heads*self.dim*(self.total_seqlen+self.total_chunks)*self.dtype.itemsize + 
16*self.total_chunks - + return self.heads * self.dim * ( + self.total_seqlen + self.total_chunks) * self.dtype.itemsize + 16 * self.total_chunks + def gen_inputs(self): - x_unpad = torch.randn(self.total_seqlen, self.heads, self.dim, device='cuda', dtype=self.dtype) + x_unpad = torch.randn( + self.total_seqlen, self.heads, self.dim, device='cuda', dtype=self.dtype) # fixed length b = self.batch_size - t = self.total_seqlen//b + t = self.total_seqlen // b - cu_seqlens = torch.arange(0, (b + 1) * t, t, dtype=torch.int32, device='cuda') + cu_seqlens = torch.arange(0, (b + 1) * t, t, dtype=torch.int32, device='cuda') chunk_indices = prepare_chunk_indices(cu_seqlens, self.chunk_size) - - return x_unpad, cu_seqlens, chunk_indices + return x_unpad, cu_seqlens, chunk_indices - def ref_program(self, x_unpad:torch.Tensor, cu_seqlens:torch.Tensor, chunk_indices:torch.Tensor) -> torch.Tensor: + def ref_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor) -> torch.Tensor: b = self.batch_size - t = self.total_seqlen//b + t = self.total_seqlen // b x = x_unpad.view(b, t, self.heads, self.dim) - - return mean_pooling(x, chunk_size=self.chunk_size, cu_seqlens=None, head_first=False).view(-1,self.heads, self.dim) - - def baseline_program(self, x_unpad:torch.Tensor, cu_seqlens:torch.Tensor, chunk_indices:torch.Tensor) -> torch.Tensor: + return mean_pooling( + x, chunk_size=self.chunk_size, cu_seqlens=None, + head_first=False).view(-1, self.heads, self.dim) + + def baseline_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor) -> torch.Tensor: b = self.batch_size - t = self.total_seqlen//b + t = self.total_seqlen // b x = x_unpad.view(b, t, self.heads, self.dim) - return mean_pooling(x, chunk_size=self.chunk_size, cu_seqlens=None, head_first=False).view(-1,self.heads, self.dim) - - - - def baseline_profile(self, *inputs: Any, warmup: int = 100, rep: int = 100, device: str = "cuda:0") -> Any: + return 
mean_pooling( + x, chunk_size=self.chunk_size, cu_seqlens=None, + head_first=False).view(-1, self.heads, self.dim) + + def baseline_profile(self, + *inputs: Any, + warmup: int = 100, + rep: int = 100, + device: str = "cuda:0") -> Any: print("===== Profiling Mean Pooling_Fwd backend =====") return super().baseline_profile( - self.baseline_program, *inputs, backend="Mean Pooling", warmup=warmup, rep=rep, device=device) - + self.baseline_program, + *inputs, + backend="Mean Pooling", + warmup=warmup, + rep=rep, + device=device) diff --git a/benchmarks/profile/profile_run.py b/benchmarks/profile/profile_run.py index 521f8dbc..ac3172b7 100644 --- a/benchmarks/profile/profile_run.py +++ b/benchmarks/profile/profile_run.py @@ -96,13 +96,20 @@ def build_nsa_cmd(args_dict): """ cmd_args = [ '--batch', - str(args_dict['batch']), '--heads', - str(args_dict['heads']), '--seq_len', - str(args_dict['seq_len']), '--dim', - str(args_dict['dim']), '--scale', - str(args_dict.get('scale', 0.1)), '--block_size', - str(args_dict['block_size']), '--groups', - str(args_dict['groups']), '--selected_blocks', + str(args_dict['batch']), + '--heads', + str(args_dict['heads']), + '--seq_len', + str(args_dict['seq_len']), + '--dim', + str(args_dict['dim']), + '--scale', + str(args_dict.get('scale', 0.1)), + '--block_size', + str(args_dict['block_size']), + '--groups', + str(args_dict['groups']), + '--selected_blocks', str(args_dict['selected_blocks']), ] diff --git a/test_tileops.py b/test_tileops.py index 25d24aeb..2b695875 100644 --- a/test_tileops.py +++ b/test_tileops.py @@ -1,5 +1,4 @@ import torch -import top from top import MLAKernel device = "cuda" diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py index b9136793..21a439fe 100644 --- a/tests/functions/test_deepseek_nsa_func.py +++ b/tests/functions/test_deepseek_nsa_func.py @@ -1,10 +1,9 @@ import argparse from top.functions import NativeSparseAttentionFunc -from top.utils import str2dtype 
from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark -def test_nsa_op( +def test_nsa_op( batch, heads, seq_len, @@ -16,9 +15,20 @@ def test_nsa_op( selected_blocks=16, # dtype='float16', tune=False, - ): - func = NativeSparseAttentionFunc(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) - benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) +): + func = NativeSparseAttentionFunc( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() benchmark.check(func, *inputs) @@ -28,10 +38,11 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=2, help='batch size') - parser.add_argument('--heads', type=int, default=16*4, help='number of heads') - parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192 * 3, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') parser.add_argument('--groups', type=int, default=16, help='number of groups') @@ -50,4 +61,4 @@ def test_nsa_op( args.groups, args.selected_blocks, args.tune, - ) \ No newline at end of file + ) diff --git 
a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py index 7177ad57..b3732673 100644 --- a/tests/layers/test_deepseek_nsa_layer.py +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -1,10 +1,9 @@ import argparse from top.layers import NativeSparseAttentionLayer -from top.utils import str2dtype from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark -def test_nsa_op( +def test_nsa_op( batch, heads, seq_len, @@ -16,9 +15,20 @@ def test_nsa_op( selected_blocks=16, # dtype='float16', tune=False, - ): - layer = NativeSparseAttentionLayer(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) - benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) +): + layer = NativeSparseAttentionLayer( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() benchmark.check(layer, *inputs) @@ -28,10 +38,11 @@ def test_nsa_op( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=2, help='batch size') - parser.add_argument('--heads', type=int, default=16*4, help='number of heads') - parser.add_argument('--seq_len', type=int, default=8192*3, help='sequence length') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192 * 3, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') 
parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') parser.add_argument('--groups', type=int, default=16, help='number of groups') @@ -50,4 +61,4 @@ def test_nsa_op( args.groups, args.selected_blocks, args.tune, - ) \ No newline at end of file + ) diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index f89ba0a7..2d4b6b11 100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -1,10 +1,9 @@ import argparse from top.ops import NativeSparseAttentionForwardOp -from top.utils import str2dtype from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark -def test_nsa_op( +def test_nsa_op( batch, heads, seq_len, @@ -17,8 +16,19 @@ def test_nsa_op( # dtype='float16', tune=False, ): - op = NativeSparseAttentionForwardOp(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) - benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks) + op = NativeSparseAttentionForwardOp( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) @@ -26,14 +36,14 @@ def test_nsa_op( benchmark.baseline_profile(*inputs) - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=16, help='batch size') - parser.add_argument('--heads', type=int, default=16*4, help='number of heads') - parser.add_argument('--seq_len', type=int, default=8192*1, help='sequence length') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, 
default=8192 * 1, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') parser.add_argument('--scale', type=float, default=0.1, help='scale') parser.add_argument('--block_size', type=int, default=32, help='block size') parser.add_argument('--groups', type=int, default=16, help='number of groups') @@ -52,4 +62,4 @@ def test_nsa_op( args.groups, args.selected_blocks, args.tune, - ) \ No newline at end of file + ) diff --git a/tests/ops/test_mean_pooling_ops.py b/tests/ops/test_mean_pooling_ops.py index 08ad77b3..aab7509b 100644 --- a/tests/ops/test_mean_pooling_ops.py +++ b/tests/ops/test_mean_pooling_ops.py @@ -1,33 +1,26 @@ import argparse from top.ops import MeanPoolingForwardOp -from top.utils import str2dtype from benchmarks.deepseek_nsa.deepseek_nsa import MeanPoolingForwardBenchmark -def test_mean_pooling_op( - batch_size, - total_seqlen, - total_chunks, - heads, - dim, - chunk_size, - tune= True -): - op = MeanPoolingForwardOp(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) - - benchmark = MeanPoolingForwardBenchmark(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size,tune=tune) +def test_mean_pooling_op(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=True): + op = MeanPoolingForwardOp( + batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) + + benchmark = MeanPoolingForwardBenchmark( + batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) benchmark.profile(op, *inputs) benchmark.baseline_profile(*inputs) + if __name__ == "__main__": - import sys parser = argparse.ArgumentParser() parser.add_argument('--batch_size', type=int, default=1, help='logical 
batch size') - parser.add_argument('--total_seqlen', type=int, default=1*8192*1, help='number of heads') - parser.add_argument('--total_chunks', type=int, default=1*256*1, help='sequence length') + parser.add_argument('--total_seqlen', type=int, default=1 * 8192 * 1, help='number of heads') + parser.add_argument('--total_chunks', type=int, default=1 * 256 * 1, help='sequence length') parser.add_argument('--heads', type=int, default=128, help='head dim') parser.add_argument('--dim', type=int, default=128, help='scale') parser.add_argument('--chunk_size', type=int, default=32, help='scale') @@ -41,4 +34,4 @@ def test_mean_pooling_op( args.dim, args.chunk_size, args.tune, - ) \ No newline at end of file + ) diff --git a/top/functions/__init__.py b/top/functions/__init__.py index 728b5339..e79e0892 100644 --- a/top/functions/__init__.py +++ b/top/functions/__init__.py @@ -7,6 +7,7 @@ from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheFunc from .matmul import MatMulFunc from .deepseek_nsa import NativeSparseAttentionFunc + __all__ = [ "Function", "MultiHeadAttentionFunc", diff --git a/top/functions/deepseek_nsa.py b/top/functions/deepseek_nsa.py index 65b990fe..5edfe64a 100644 --- a/top/functions/deepseek_nsa.py +++ b/top/functions/deepseek_nsa.py @@ -2,7 +2,6 @@ from top.functions.function import Function from top.ops.deepseek_nsa import NativeSparseAttentionForwardOp - __all__ = ['NativeSparseAttentionFunc'] @@ -16,7 +15,7 @@ def forward(ctx, Q, K, V, BlockIndices, fwd_op): @staticmethod def backward(ctx, dO): raise NotImplementedError("Backward pass is not implemented for nsa.") - + @staticmethod def decode(ctx, dO): raise NotImplementedError("Decode pass is not implemented for nsa.") @@ -24,18 +23,17 @@ def decode(ctx, dO): class NativeSparseAttentionFunc(Function): - def __init__( - self, - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - tune=False): + def __init__(self, + 
batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): self.batch = batch self.heads = heads @@ -49,7 +47,17 @@ def __init__( self.tune = tune self.fwd_op = NativeSparseAttentionForwardOp( - batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) - - def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: return nsa_decode_ctx.apply(Q, K, V, BlockIndices, self.fwd_op) diff --git a/top/kernels/deepseek_nsa/mean_pooling_fwd.py b/top/kernels/deepseek_nsa/mean_pooling_fwd.py index dde7add1..17d70fe2 100644 --- a/top/kernels/deepseek_nsa/mean_pooling_fwd.py +++ b/top/kernels/deepseek_nsa/mean_pooling_fwd.py @@ -1,17 +1,15 @@ import torch -from fla.ops.utils import mean_pooling -from fla.ops.common.utils import prepare_chunk_indices -from typing import Optional, Tuple +from typing import Optional from top.kernels.kernel import Kernel import itertools import tilelang import tilelang.language as T - __all__ = ["mean_pooling_fwd_kernel"] + def _mean_pooling_kernel( batch_size: int, total_seqlen: int, @@ -39,19 +37,15 @@ def _mean_pooling_func(block_D, threads): cu_seqlens_shape = [batch_size + 1] chunk_indices_shape = [total_chunks, 2] output_shape = [total_chunks, heads, dim] + @T.prim_func def _mean_pooling_main( - X_unpad: T.Tensor(x_shape, dtype), - cu_seqlens: T.Tensor(cu_seqlens_shape, T.int32), - chunk_indices: T.Tensor(chunk_indices_shape, T.int32), - Output: T.Tensor(output_shape, dtype), + X_unpad: T.Tensor(x_shape, dtype), + cu_seqlens: T.Tensor(cu_seqlens_shape, T.int32), + chunk_indices: T.Tensor(chunk_indices_shape, T.int32), + Output: T.Tensor(output_shape, dtype), 
): - with T.Kernel( - ND, - total_chunks, - heads, - threads=threads - ) as (i_d, i_t, i_h): + with T.Kernel(ND, total_chunks, heads, threads=threads) as (i_d, i_t, i_h): accum = T.alloc_fragment([block_D], accum_dtype) d_start = i_d * block_D @@ -74,9 +68,12 @@ def _mean_pooling_main( accum[d] += T.cast(X_unpad[t_abs, i_h, d_start + d], accum_dtype) for d in T.Parallel(block_D): if d_start + d < dim: - Output[i_t, i_h, d_start + d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), dtype) + Output[i_t, i_h, + d_start + d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), + dtype) return _mean_pooling_main + return _mean_pooling_func @@ -93,7 +90,7 @@ def _mean_pooling_wrapped_kernel( x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor, -)->torch.Tensor: +) -> torch.Tensor: return _mean_pooling_kernel( batch_size, total_seqlen, @@ -101,21 +98,21 @@ def _mean_pooling_wrapped_kernel( heads, dim, chunk_size, - )(block_D, threads)(x_unpad, cu_seqlens, chunk_indices) + )(block_D, threads)(x_unpad, cu_seqlens, chunk_indices) @_mean_pooling_wrapped_kernel.register_fake def _( - batch_size: int, - total_seqlen: int, - total_chunks: int, - heads: int, - dim: int, - chunk_size: int, - block_D: int, - threads: int, - *inputs -)->torch.Tensor: + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int, + threads: int, + *inputs +) -> torch.Tensor: fake_o = torch.empty_like(inputs[0]) return fake_o @@ -123,17 +120,15 @@ def _( class mean_pooling_fwd_kernel(Kernel): supported_archs: list[int] = [80, 89, 90, 100] - def __init__( - self, - batch_size: int, - total_seqlen: int, - total_chunks: int, - heads: int, - dim: int, - chunk_size: int, - config: Optional[dict] = None, - tune=False - ): + def __init__(self, + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + config: Optional[dict] = None, + tune=False): super().__init__() 
self.batch_size = batch_size self.total_seqlen = total_seqlen @@ -152,7 +147,7 @@ def __init__( ) self.init_config(config, tune) - + @property def default_config(self) -> dict: return { @@ -165,23 +160,11 @@ def autotune_configs(self) -> list[dict]: block_D = [32, 64, 128] threads = [32, 64, 128] _configs = list(itertools.product(block_D, threads)) - configs = [{ - "block_D": c[0], - "threads": c[1] - } for c in _configs] + configs = [{"block_D": c[0], "threads": c[1]} for c in _configs] return configs - def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): - return _mean_pooling_wrapped_kernel( - self.batch_size, - self.total_seqlen, - self.total_chunks, - self.heads, - self.dim, - self.chunk_size, - self.config["block_D"], - self.config["threads"], - x_unpad, - cu_seqlens, - chunk_indices - ) \ No newline at end of file + def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): + return _mean_pooling_wrapped_kernel(self.batch_size, self.total_seqlen, self.total_chunks, + self.heads, self.dim, self.chunk_size, + self.config["block_D"], self.config["threads"], x_unpad, + cu_seqlens, chunk_indices) diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py index a8023735..9ac9783f 100644 --- a/top/kernels/deepseek_nsa/nsa_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -1,39 +1,34 @@ import tilelang import tilelang.language as T -from typing import Optional, Tuple +from typing import Optional from top.kernels.kernel import Kernel import itertools import torch - __all__ = ["nsa_fwd_kernel"] # dtype default float16, accum_dtype default float32 -def _nsa_fwd_kernel( - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16 - ): - if scale is None: - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - else: - scale = scale * 1.44269504 # log2(e) +def _nsa_fwd_kernel(batch, + heads, + seq_len, + dim, + 
is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16): + + scale = (1.0 / dim)**0.5 * 1.44269504 if scale is None else scale * 1.44269504 head_kv = heads // groups - + block_indices_dtype = T.int32 dtype = T.float16 accum_dtype = T.float32 block_S = block_size - + @tilelang.jit( out_idx=[-1], pass_configs={ @@ -56,14 +51,14 @@ def _nsa_fwd_func(block_T, num_stages, threads): q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - + @T.prim_func def _nsa_fwd_main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), - Output: T.Tensor(q_shape, dtype), + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), ): with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([G, BK], dtype) @@ -84,7 +79,7 @@ def _nsa_fwd_main( i_b, i_h = i_bh // head_kv, i_bh % head_kv NS = S - T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -94,22 +89,29 @@ def _nsa_fwd_main( i_s = BlockIndices[i_b, i_t, i_h, i] * BS if i_s <= i_t and i_s >= 0: # [BS, BK] - T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) if is_causal: for i, j in T.Parallel(G, BS): - acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, + -T.infinity(acc_s.dtype)) else: T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) # Softmax 
T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=True) for i in T.Parallel(G): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - + scores_max[i] * scale) for i, j in T.Parallel(G, BS): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) @@ -122,15 +124,16 @@ def _nsa_fwd_main( acc_o[i, j] *= scores_scale[i] # V * softmax(Q * K) - T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(G, BV): acc_o[i, j] /= logsum[i] T.copy(acc_o, O_shared) - T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) return _nsa_fwd_main + return _nsa_fwd_func @@ -152,27 +155,26 @@ def _nsa_fwd_wrapped_kernel( K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor, -) ->torch.Tensor: - return _nsa_fwd_kernel(batch, heads, seq_len, dim, is_causal, scale, block_size, - groups, selected_blocks)(block_T, num_stages, - threads)(Q, K, V, BlockIndices) +) -> torch.Tensor: + return _nsa_fwd_kernel(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, + selected_blocks)(block_T, num_stages, threads)(Q, K, V, BlockIndices) @_nsa_fwd_wrapped_kernel.register_fake def _( - batch, - heads, - seq_len, - dim, - is_causal, - scale, - block_size, - groups, - selected_blocks, - block_T, - num_stages, - threads, - *inputs + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + block_T, + num_stages, + threads, + *inputs ) -> torch.Tensor: fake_o = torch.empty_like(inputs[0]) return fake_o @@ -181,19 +183,18 @@ def _( class 
nsa_fwd_kernel(Kernel): supported_archs: list[int] = [80, 89, 90, 100] - def __init__( - self, - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - config: Optional[dict] = None, - tune=False): + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + config: Optional[dict] = None, + tune=False): super().__init__() self.batch = batch @@ -206,7 +207,9 @@ def __init__( self.groups = groups self.selected_blocks = selected_blocks - self.kernel = _nsa_fwd_kernel(self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks) + self.kernel = _nsa_fwd_kernel(self.batch, self.heads, self.seq_len, self.dim, + self.is_causal, self.scale, self.block_size, self.groups, + self.selected_blocks) self.init_config(config, tune) @@ -224,12 +227,13 @@ def autotune_configs(self) -> list[dict]: num_stages = [2, 3] threads = [32, 64, 128] _configs = list(itertools.product(block_T, num_stages, threads)) - configs = [{ - "block_T": c[0], - "num_stages": c[1], - "threads": c[2] - } for c in _configs] + configs = [{"block_T": c[0], "num_stages": c[1], "threads": c[2]} for c in _configs] return configs - def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): - return _nsa_fwd_wrapped_kernel(self.batch, self.heads, self.seq_len, self.dim, self.is_causal, self.scale, self.block_size, self.groups, self.selected_blocks, self.config["block_T"], self.config["num_stages"], self.config["threads"], Q, K, V, BlockIndices) + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor): + return _nsa_fwd_wrapped_kernel(self.batch, self.heads, self.seq_len, self.dim, + self.is_causal, self.scale, self.block_size, self.groups, + self.selected_blocks, self.config["block_T"], + self.config["num_stages"], 
self.config["threads"], Q, K, V, + BlockIndices) diff --git a/top/layers/__init__.py b/top/layers/__init__.py index c4c371b4..4bbb265a 100644 --- a/top/layers/__init__.py +++ b/top/layers/__init__.py @@ -2,7 +2,7 @@ from .flash_decode import MultiHeadAttentionDecodeLayer, GroupQueryAttentionDecodeLayer from .deepseek_mla import MultiHeadLatentAttentionDecodeLayer, DeepSeekSparseAttentionDecodeLayer from .linear import LinearLayer -from .deepseek_nsa import NativeSparseAttentionLayer +from .deepseek_nsa import NativeSparseAttentionLayer __all__ = [ "MultiHeadAttentionLayer", "GroupQueryAttentionLayer", "MultiHeadAttentionDecodeLayer", diff --git a/top/layers/deepseek_nsa.py b/top/layers/deepseek_nsa.py index 2ca7336a..26cfa295 100644 --- a/top/layers/deepseek_nsa.py +++ b/top/layers/deepseek_nsa.py @@ -4,19 +4,18 @@ class NativeSparseAttentionLayer(nn.Module): - def __init__( - self, - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - tune=False - ): + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): super().__init__() self.batch = batch @@ -31,7 +30,17 @@ def __init__( self.tune = tune self.fn = NativeSparseAttentionFunc( - batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune=tune) + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) - def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor) -> torch.Tensor: + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: return self.fn(Q, K, V, BlockIndices) diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 81419a76..e44d82c3 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -9,16 +9,9 @@ from .deepseek_nsa import 
NativeSparseAttentionForwardOp, MeanPoolingForwardOp __all__ = [ - "Op", - "MultiHeadAttentionFwdOp", - "MultiHeadAttentionBwdOp", - "GroupQueryAttentionFwdOp", - "GroupQueryAttentionBwdOp", - "GemmOp", - "MultiHeadAttentionDecodeWithKVCacheOp", - "GroupQueryAttentionDecodeWithKVCacheOp", - "MultiHeadLatentAttentionDecodeWithKVCacheOp", - "DeepSeekSparseAttentionDecodeWithKVCacheOp", - "NativeSparseAttentionForwardOp", + "Op", "MultiHeadAttentionFwdOp", "MultiHeadAttentionBwdOp", "GroupQueryAttentionFwdOp", + "GroupQueryAttentionBwdOp", "GemmOp", "MultiHeadAttentionDecodeWithKVCacheOp", + "GroupQueryAttentionDecodeWithKVCacheOp", "MultiHeadLatentAttentionDecodeWithKVCacheOp", + "DeepSeekSparseAttentionDecodeWithKVCacheOp", "NativeSparseAttentionForwardOp", "MeanPoolingForwardOp" ] diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 31c7977a..ec5e35ca 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -3,31 +3,23 @@ from top.kernels.kernel import Kernel from top.kernels.deepseek_nsa.nsa_fwd import nsa_fwd_kernel from top.kernels.deepseek_nsa.mean_pooling_fwd import mean_pooling_fwd_kernel -from typing import Optional, Dict, Callable, Tuple, Any -from fla.ops.utils import mean_pooling +from typing import Optional, Dict from fla.ops.common.utils import prepare_chunk_indices - -import tilelang -import tilelang.language as T -import functools - - __all__ = ["NativeSparseAttentionForwardOp", "MeanPoolingForwardOp"] class MeanPoolingForwardOp(Op): - def __init__( - self, - batch_size: int, - total_seqlen: int, - total_chunks: int, - heads: int, - dim: int, - chunk_size: int, - kernel_map: Optional[Dict[str, Kernel]] = None, - tune=False - )-> torch.Tensor: + + def __init__(self, + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune=False) -> torch.Tensor: self.batch_size = batch_size self.total_seqlen = total_seqlen 
self.total_chunks = total_chunks @@ -35,7 +27,7 @@ def __init__( self.dim = dim self.chunk_size = chunk_size self.tune = tune - + self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["mean_pooling_fwd_kernel"]( @@ -45,36 +37,35 @@ def __init__( heads=self.heads, dim=self.dim, chunk_size=self.chunk_size, - tune= self.tune - ) + tune=self.tune) + @property def default_kernel_map(self): return {"mean_pooling_fwd_kernel": mean_pooling_fwd_kernel} def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): return self.kernel(x_unpad, cu_seqlens, chunk_indices) + # def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): # out = self.kernel(x, cu_seqlens, chunk_indices) # print(self.batch_size) # return out.view(self.batch_size,-1, self.heads, self.dim) - - + class NativeSparseAttentionForwardOp(Op): - def __init__( - self, - batch, - heads, - seq_len, - dim, - is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - kernel_map: Optional[Dict[str, Kernel]] = None, - tune=False - ): + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune=False): self.batch = batch self.heads = heads self.seq_len = seq_len @@ -88,16 +79,23 @@ def __init__( self.dispatch_kernel(kernel_map) self.kernel = self.kernel_map["nsa_fwd_kernel"]( - self.batch, self.heads, self.seq_len, - self.dim, self.is_causal, self.scale, - self.block_size, self.groups, self.selected_blocks, tune=self.tune) - + self.batch, + self.heads, + self.seq_len, + self.dim, + self.is_causal, + self.scale, + self.block_size, + self.groups, + self.selected_blocks, + tune=self.tune) @property def default_kernel_map(self): return {"nsa_fwd_kernel": nsa_fwd_kernel} - def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): + def forward(self, Q: 
torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor): return self.kernel(Q, K, V, BlockIndices) @@ -114,6 +112,5 @@ def mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size, block_D=64): heads=H, dim=D, chunk_size=chunk_size, - tune= True - ) + tune=True) return op.forward(x_unpad, cu_seqlens, chunk_indices) From abe19828113c4dc1f30039437cbe3423a056ee3b Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Wed, 7 Jan 2026 12:10:08 +0800 Subject: [PATCH 11/14] test: using pytest for better extensibility. --- tests/functions/test_deepseek_nsa_func.py | 35 +++++++++---- tests/layers/test_deepseek_nsa_layer.py | 38 ++++++++++---- tests/ops/test_deepseek_nsa_ops.py | 64 +++++++++++++++++++---- top/kernels/deepseek_nsa/nsa_fwd.py | 3 +- 4 files changed, 108 insertions(+), 32 deletions(-) diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py index 21a439fe..0368e092 100644 --- a/tests/functions/test_deepseek_nsa_func.py +++ b/tests/functions/test_deepseek_nsa_func.py @@ -1,20 +1,38 @@ import argparse +import pytest +import torch + from top.functions import NativeSparseAttentionFunc from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + -def test_nsa_op( +@pytest.mark.parametrize( + "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + ], +) +def test_nsa_func( batch, heads, seq_len, dim, is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - # dtype='float16', - tune=False, + scale, + 
block_size, + groups, + selected_blocks, + tune, ): func = NativeSparseAttentionFunc( batch, @@ -32,7 +50,6 @@ def test_nsa_op( inputs = benchmark.gen_inputs() benchmark.check(func, *inputs) - benchmark.profile(func, *inputs) if __name__ == "__main__": @@ -50,7 +67,7 @@ def test_nsa_op( parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() - test_nsa_op( + test_nsa_func( args.batch, args.heads, args.seq_len, diff --git a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py index b3732673..4994e755 100644 --- a/tests/layers/test_deepseek_nsa_layer.py +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -1,20 +1,37 @@ import argparse from top.layers import NativeSparseAttentionLayer from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark +import pytest +import torch +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) -def test_nsa_op( +@pytest.mark.parametrize( + "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + + ], +) +def test_nsa_layer( batch, heads, seq_len, dim, is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - # dtype='float16', - tune=False, + scale, + block_size, + groups, + selected_blocks, + tune, ): layer = NativeSparseAttentionLayer( batch, @@ -32,14 +49,13 @@ def test_nsa_op( inputs = benchmark.gen_inputs() benchmark.check(layer, *inputs) - benchmark.profile(layer, *inputs) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=2, help='batch size') + 
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') - parser.add_argument('--seq_len', type=int, default=8192 * 3, help='sequence length') + parser.add_argument('--seq_len', type=int, default=8192 * 4, help='sequence length') parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument( '--is_causal', action='store_true', default=True, help='enable causal attention') @@ -50,7 +66,7 @@ def test_nsa_op( parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') args = parser.parse_args() - test_nsa_op( + test_nsa_layer( args.batch, args.heads, args.seq_len, diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index 2d4b6b11..fb3913d1 100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -1,21 +1,62 @@ +"""Test NativeSparseAttention operation.""" import argparse +import pytest +import torch + from top.ops import NativeSparseAttentionForwardOp from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + # small batch size configuration + (1, 16, 1024, 128, True, 0.1, 32, 16, 16, True), + (4, 32, 2048, 128, True, 0.1, 32, 16, 16, 
True), + # different sequence length + (8, 32, 4096, 128, True, 0.1, 32, 16, 16, True), + (8, 32, 16384, 128, True, 0.1, 32, 16, 16, True), + # different block_size + (8, 32, 4096, 128, True, 0.1, 64, 16, 16, True), + (8, 32, 4096, 128, True, 0.1, 16, 16, 16, True), + # different groups + (8, 32, 4096, 128, True, 0.1, 32, 32, 16, True), + # different selected_blocks + (8, 32, 4096, 128, True, 0.1, 32, 16, 8, True), + (8, 32, 4096, 128, True, 0.1, 32, 16, 32, True), + # different scale + (8, 32, 4096, 128, True, 0.05, 32, 16, 16, True), + (8, 32, 4096, 128, True, 0.2, 32, 16, 16, True), + ], +) def test_nsa_op( batch, heads, seq_len, dim, is_causal, - scale=None, - block_size=64, - groups=1, - selected_blocks=16, - # dtype='float16', - tune=False, + scale, + block_size, + groups, + selected_blocks, + tune, ): + """Test NativeSparseAttention forward operation with various configurations.""" op = NativeSparseAttentionForwardOp( batch, heads, @@ -27,13 +68,14 @@ def test_nsa_op( groups, selected_blocks, tune=tune) - benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, - block_size, groups, selected_blocks) + benchmark = NativeSparseAttentionForwardBenchmark( + batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) - benchmark.profile(op, *inputs) - benchmark.baseline_profile(*inputs) + # benchmark.profile(op, *inputs) + # benchmark.baseline_profile(*inputs) if __name__ == "__main__": @@ -62,4 +104,4 @@ def test_nsa_op( args.groups, args.selected_blocks, args.tune, - ) + ) \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py index 9ac9783f..258251dd 100644 --- a/top/kernels/deepseek_nsa/nsa_fwd.py +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -41,7 +41,8 @@ def _nsa_fwd_func(block_T, num_stages, threads): NK = tilelang.cdiv(dim, block_T) NV = tilelang.cdiv(dim, block_T) - assert 
NK == 1, "The key dimension can not be larger than 256" + assert NK == 1, f"The head dimension (dim={dim}) cannot be larger than block_T ({block_T}). " \ + f"This kernel processes Q and K in a single block, so dim must be <= block_T." S = selected_blocks G = groups From dca22b84580fc34946fc9637231f74888f9ea4ae Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Wed, 7 Jan 2026 12:13:04 +0800 Subject: [PATCH 12/14] test: using pytest for better extensibility. --- tests/functions/test_deepseek_nsa_func.py | 7 ++++--- tests/layers/test_deepseek_nsa_layer.py | 9 +++++---- tests/ops/test_deepseek_nsa_ops.py | 13 ++++++------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py index 0368e092..4628ff3a 100644 --- a/tests/functions/test_deepseek_nsa_func.py +++ b/tests/functions/test_deepseek_nsa_func.py @@ -5,6 +5,7 @@ from top.functions import NativeSparseAttentionFunc from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + @pytest.fixture(autouse=True) def setup() -> None: """Set up the test environment.""" @@ -16,9 +17,9 @@ def setup() -> None: [ # default configuration (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), ], ) diff --git a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py index 4994e755..cc12b4cf 100644 --- a/tests/layers/test_deepseek_nsa_layer.py +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -4,21 +4,22 @@ import pytest import torch + @pytest.fixture(autouse=True) def setup() -> None: """Set up the test 
environment.""" torch.manual_seed(1234) + @pytest.mark.parametrize( "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", [ # default configuration (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), - ], ) def test_nsa_layer( diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py index fb3913d1..7b9a03f1 100644 --- a/tests/ops/test_deepseek_nsa_ops.py +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -18,9 +18,9 @@ def setup() -> None: [ # default configuration (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), - (1, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), # (16, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), # (16, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), @@ -68,9 +68,8 @@ def test_nsa_op( groups, selected_blocks, tune=tune) - benchmark = NativeSparseAttentionForwardBenchmark( - batch, heads, seq_len, dim, is_causal, scale, - block_size, groups, selected_blocks) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) inputs = benchmark.gen_inputs() benchmark.check(op, *inputs) @@ -104,4 +103,4 @@ def test_nsa_op( args.groups, args.selected_blocks, args.tune, - ) \ No newline at end of file + ) From 
afc38bfcd0cba476b7eeda5eeced8771ee93010b Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Wed, 7 Jan 2026 13:05:12 +0800 Subject: [PATCH 13/14] test: using pytest for better extensibility. --- top/ops/deepseek_nsa.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index ec5e35ca..882e2631 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -4,7 +4,6 @@ from top.kernels.deepseek_nsa.nsa_fwd import nsa_fwd_kernel from top.kernels.deepseek_nsa.mean_pooling_fwd import mean_pooling_fwd_kernel from typing import Optional, Dict -from fla.ops.common.utils import prepare_chunk_indices __all__ = ["NativeSparseAttentionForwardOp", "MeanPoolingForwardOp"] @@ -46,10 +45,6 @@ def default_kernel_map(self): def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): return self.kernel(x_unpad, cu_seqlens, chunk_indices) - # def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): - # out = self.kernel(x, cu_seqlens, chunk_indices) - # print(self.batch_size) - # return out.view(self.batch_size,-1, self.heads, self.dim) class NativeSparseAttentionForwardOp(Op): @@ -98,19 +93,3 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): return self.kernel(Q, K, V, BlockIndices) - -def mean_pooling_tilelang(x_unpad, cu_seqlens, chunk_size, block_D=64): - total_T, H, D = x_unpad.shape - B = cu_seqlens.shape[0] - 1 - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) - total_chunks = chunk_indices.shape[0] - - op = MeanPoolingForwardOp( - batch_size=B, - total_seqlen=total_T, - total_chunks=total_chunks, - heads=H, - dim=D, - chunk_size=chunk_size, - tune=True) - return op.forward(x_unpad, cu_seqlens, chunk_indices) From f35d549bca14aaf2fecba0ef2726ccb6bdb78bec Mon Sep 17 00:00:00 2001 From: "jieneng.yu" <1033160740@qq.com> Date: Wed, 7 Jan 2026 13:15:26 
+0800 Subject: [PATCH 14/14] test: using pytest for better extensibility. --- top/ops/deepseek_nsa.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/top/ops/deepseek_nsa.py b/top/ops/deepseek_nsa.py index 882e2631..b5c19a73 100644 --- a/top/ops/deepseek_nsa.py +++ b/top/ops/deepseek_nsa.py @@ -46,7 +46,6 @@ def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices return self.kernel(x_unpad, cu_seqlens, chunk_indices) - class NativeSparseAttentionForwardOp(Op): def __init__(self, @@ -92,4 +91,3 @@ def default_kernel_map(self): def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, BlockIndices: torch.Tensor): return self.kernel(Q, K, V, BlockIndices) -