diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 66315e63..35d0b4be 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -1,4 +1,5 @@ from .benchmark import Benchmark # noqa: F401 +from .deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark from .flash_attn import MultiHeadAttentionBenchmark, MultiHeadAttentionBwdBenchmark, MultiHeadAttentionFwdBenchmark, GroupQueryAttentionBenchmark, GroupQueryAttentionFwdBenchmark, GroupQueryAttentionBwdBenchmark from .gemm import GemmBenchmark, MatMulBenchmark from .flash_decode import MultiHeadAttentionDecodeBenchmark, GroupQueryAttentionDecodeBenchmark @@ -6,6 +7,7 @@ __all__ = [ 'Benchmark', + 'NativeSparseAttentionForwardBenchmark', 'MultiHeadAttentionBenchmark', 'MultiHeadAttentionBwdBenchmark', 'MultiHeadAttentionFwdBenchmark', diff --git a/benchmarks/deepseek_nsa/__init__.py b/benchmarks/deepseek_nsa/__init__.py new file mode 100644 index 00000000..911147ad --- /dev/null +++ b/benchmarks/deepseek_nsa/__init__.py @@ -0,0 +1,5 @@ +from .deepseek_nsa import NativeSparseAttentionForwardBenchmark + +__all__ = [ + "NativeSparseAttentionForwardBenchmark", +] diff --git a/benchmarks/deepseek_nsa/deepseek_nsa.py b/benchmarks/deepseek_nsa/deepseek_nsa.py new file mode 100644 index 00000000..96f58b3b --- /dev/null +++ b/benchmarks/deepseek_nsa/deepseek_nsa.py @@ -0,0 +1,203 @@ +from benchmarks.benchmark import Benchmark +from top.ops import NativeSparseAttentionForwardOp +from top.ops import MeanPoolingForwardOp + +import torch + +from typing import Any +from native_sparse_attention.ops.naive import naive_nsa +from native_sparse_attention.ops.parallel import parallel_nsa_fwd +from fla.ops.utils import mean_pooling + +from fla.ops.common.utils import prepare_chunk_indices + + +class NativeSparseAttentionForwardBenchmark(Benchmark): + op_type = NativeSparseAttentionForwardOp + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + 
groups=1, + selected_blocks=16, + tune=False): + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + + self.head_kv = self.heads // self.groups + self.dtype = torch.float16 + self.tune = tune + + @property + def total_flops(self): + B = self.batch + T = self.seq_len + HQ = self.heads + D = self.dim + S = self.selected_blocks + BS = self.block_size + + window_size = 0 + total_keys = S * BS + window_size + flops = 4 * B * T * HQ * D * total_keys + return flops + + @property + def total_memory(self): + return (self.batch * self.heads * (2 * self.seq_len) * self.dim * self.dtype.itemsize) + + def gen_inputs(self): + Q = torch.randn( + self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype) + K = torch.randn( + self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) + V = torch.randn( + self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype) + + self.o_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), + dtype=self.dtype, + device="cuda") + self.lse_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim), + dtype=torch.float, + device="cuda") + + self.g_slc = torch.ones((self.batch, self.seq_len, self.heads), + dtype=self.dtype, + device="cuda").requires_grad_(True) + self.g_swa = torch.ones((self.batch, self.seq_len, self.heads), + dtype=self.dtype, + device="cuda").requires_grad_(True) + + block_indices = torch.full((self.batch, self.seq_len, self.head_kv, self.selected_blocks), + self.seq_len, + dtype=torch.long, + device="cuda") + self.block_counts = torch.zeros((self.batch, self.seq_len, self.head_kv), + dtype=torch.long, + device="cuda") + for b in range(self.batch): + for t in range(self.seq_len): + for h in range(self.head_kv): + i_i = torch.randperm(max(1, (t // 
self.block_size)))[:self.selected_blocks] + block_indices[b, t, h, :len(i_i)] = i_i + self.block_counts[b, t, h] = (block_indices[b, t, h] + != self.seq_len).sum().item() + block_indices = block_indices.sort(-1)[0].to(torch.int32) + return Q, K, V, block_indices + + def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: + return naive_nsa( + q=Q, + k=K, + v=V, + g_slc=self.g_slc, + g_swa=self.g_swa, + block_indices=BlockIndices.to(torch.long), + block_counts=self.block_counts, + block_size=self.block_size, + scale=self.scale, + ) + + def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: + o, lse = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + block_indices=BlockIndices, + block_counts=self.block_counts, + block_size=self.block_size, + scale=self.scale, + ) + return o + + def baseline_profile(self, + *inputs: Any, + warmup: int = 100, + rep: int = 100, + device: str = "cuda:0") -> Any: + print("===== Profiling FLA NSA_Fwd backend =====") + return super().baseline_profile( + self.baseline_program, *inputs, backend="FLA", warmup=warmup, rep=rep, device=device) + + +class MeanPoolingForwardBenchmark(Benchmark): + op_type = MeanPoolingForwardOp + + def __init__(self, batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=True): + self.batch_size = batch_size + self.total_seqlen = total_seqlen + self.total_chunks = total_chunks + self.heads = heads + self.dim = dim + self.chunk_size = chunk_size + self.tune = tune + self.dtype = torch.float16 + + @property + def total_flops(self): + flops = self.heads * self.dim * (self.total_seqlen + self.total_chunks) + return flops + + @property + def total_memory(self): + return self.heads * self.dim * ( + self.total_seqlen + self.total_chunks) * self.dtype.itemsize + 16 * self.total_chunks + + def gen_inputs(self): + x_unpad = torch.randn( + self.total_seqlen, self.heads, self.dim, 
device='cuda', dtype=self.dtype) + # fixed length + b = self.batch_size + t = self.total_seqlen // b + + cu_seqlens = torch.arange(0, (b + 1) * t, t, dtype=torch.int32, device='cuda') + chunk_indices = prepare_chunk_indices(cu_seqlens, self.chunk_size) + + return x_unpad, cu_seqlens, chunk_indices + + def ref_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor) -> torch.Tensor: + b = self.batch_size + t = self.total_seqlen // b + x = x_unpad.view(b, t, self.heads, self.dim) + + return mean_pooling( + x, chunk_size=self.chunk_size, cu_seqlens=None, + head_first=False).view(-1, self.heads, self.dim) + + def baseline_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor) -> torch.Tensor: + b = self.batch_size + t = self.total_seqlen // b + x = x_unpad.view(b, t, self.heads, self.dim) + return mean_pooling( + x, chunk_size=self.chunk_size, cu_seqlens=None, + head_first=False).view(-1, self.heads, self.dim) + + def baseline_profile(self, + *inputs: Any, + warmup: int = 100, + rep: int = 100, + device: str = "cuda:0") -> Any: + print("===== Profiling Mean Pooling_Fwd backend =====") + return super().baseline_profile( + self.baseline_program, + *inputs, + backend="Mean Pooling", + warmup=warmup, + rep=rep, + device=device) diff --git a/benchmarks/input_params/deepseek_nsa.csv b/benchmarks/input_params/deepseek_nsa.csv new file mode 100644 index 00000000..0c338f08 --- /dev/null +++ b/benchmarks/input_params/deepseek_nsa.csv @@ -0,0 +1,5 @@ +batch,heads,seq_len,dim,is_causal,scale,block_size,groups,selected_blocks,tune +1,64,8192,128,True,0.1,32,16,16,True +1,64,16384,128,True,0.1,32,16,16,True +1,64,32768,128,True,0.1,32,16,16,True +1,64,65536,128,True,0.1,32,16,16,True \ No newline at end of file diff --git a/benchmarks/profile/profile_run.py b/benchmarks/profile/profile_run.py index b0f00c17..ac3172b7 100644 --- a/benchmarks/profile/profile_run.py +++ b/benchmarks/profile/profile_run.py @@ 
-90,6 +90,36 @@ def build_gqa_decode_cmd(args_dict): return cmd_args +def build_nsa_cmd(args_dict): + """ + Build command arguments for Native Sparse Attention test script + """ + cmd_args = [ + '--batch', + str(args_dict['batch']), + '--heads', + str(args_dict['heads']), + '--seq_len', + str(args_dict['seq_len']), + '--dim', + str(args_dict['dim']), + '--scale', + str(args_dict.get('scale', 0.1)), + '--block_size', + str(args_dict['block_size']), + '--groups', + str(args_dict['groups']), + '--selected_blocks', + str(args_dict['selected_blocks']), + ] + + if args_dict.get('is_causal', 'True').lower() == 'true': + cmd_args.append('--is_causal') + if args_dict.get('tune', 'False').lower() == 'true': + cmd_args.append('--tune') + return cmd_args + + def build_mla_decode_cmd(args_dict): """ Build command arguments for MLA decode test script @@ -196,6 +226,8 @@ def run_test_script(script_path, args_dict): cmd_args = build_mla_decode_cmd(args_dict) elif 'sparse_mla' in script_name: cmd_args = build_sparse_mla_cmd(args_dict) + elif 'deepseek_nsa' in script_name or 'nsa' in script_name: + cmd_args = build_nsa_cmd(args_dict) elif 'mha' in script_name: cmd_args = build_mha_cmd(args_dict) elif 'gqa' in script_name: diff --git a/test_tileops.py b/test_tileops.py new file mode 100644 index 00000000..2b695875 --- /dev/null +++ b/test_tileops.py @@ -0,0 +1,33 @@ +import torch +from top import MLAKernel + +device = "cuda" +dtype = torch.float16 + +batch = 128 +heads = 64 +kv_heads = 1 +kv_ctx = 8192 +dim = 512 +pe_dim = 64 + +# Query input: [batch, heads, dim] +q = torch.randn(batch, heads, dim, device=device, dtype=dtype) + +# Query positional encoding: [batch, heads, pe_dim] +q_pe = torch.randn(batch, heads, pe_dim, device=device, dtype=dtype) + +# KV cache input: [batch, kv_ctx, kv_heads, dim] +kv = torch.randn(batch, kv_ctx, kv_heads, dim, device=device, dtype=dtype) + +# KV positional encoding: [batch, kv_ctx, kv_heads, pe_dim] +k_pe = torch.randn(batch, kv_ctx, kv_heads, 
pe_dim, device=device, dtype=dtype) + +# Use MLA kernel +block_N = 64 +block_H = 64 +num_split = 1 + +mla = MLAKernel(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H, num_split) + +out = mla(q, q_pe, kv, k_pe) diff --git a/tests/functions/test_deepseek_nsa_func.py b/tests/functions/test_deepseek_nsa_func.py new file mode 100644 index 00000000..4628ff3a --- /dev/null +++ b/tests/functions/test_deepseek_nsa_func.py @@ -0,0 +1,82 @@ +import argparse +import pytest +import torch + +from top.functions import NativeSparseAttentionFunc +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + + +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + ], +) +def test_nsa_func( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune, +): + func = NativeSparseAttentionFunc( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(func, *inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=2, help='batch size') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192 * 3, help='sequence length') + 
parser.add_argument('--dim', type=int, default=128, help='head dim') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') + args = parser.parse_args() + + test_nsa_func( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) diff --git a/tests/layers/test_deepseek_nsa_layer.py b/tests/layers/test_deepseek_nsa_layer.py new file mode 100644 index 00000000..cc12b4cf --- /dev/null +++ b/tests/layers/test_deepseek_nsa_layer.py @@ -0,0 +1,81 @@ +import argparse +from top.layers import NativeSparseAttentionLayer +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark +import pytest +import torch + + +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + ], +) +def test_nsa_layer( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune, +): + layer = NativeSparseAttentionLayer( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + 
block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(layer, *inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=1, help='batch size') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192 * 4, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='head dim') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') + args = parser.parse_args() + + test_nsa_layer( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) diff --git a/tests/ops/test_deepseek_nsa_ops.py b/tests/ops/test_deepseek_nsa_ops.py new file mode 100644 index 00000000..7b9a03f1 --- /dev/null +++ b/tests/ops/test_deepseek_nsa_ops.py @@ -0,0 +1,106 @@ +"""Test NativeSparseAttention operation.""" +import argparse +import pytest +import torch + +from top.ops import NativeSparseAttentionForwardOp +from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark + + +@pytest.fixture(autouse=True) +def setup() -> None: + """Set up the test environment.""" + torch.manual_seed(1234) + + +@pytest.mark.parametrize( + "batch, heads, seq_len, 
dim, is_causal, scale, block_size, groups, selected_blocks, tune", + [ + # default configuration + (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True), + (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True), + (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*2, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*4, 128, True, 0.1, 32, 16, 16, True), + # (16, 64, 8192*8, 128, True, 0.1, 32, 16, 16, True), + # small batch size configuration + (1, 16, 1024, 128, True, 0.1, 32, 16, 16, True), + (4, 32, 2048, 128, True, 0.1, 32, 16, 16, True), + # different sequence length + (8, 32, 4096, 128, True, 0.1, 32, 16, 16, True), + (8, 32, 16384, 128, True, 0.1, 32, 16, 16, True), + # different block_size + (8, 32, 4096, 128, True, 0.1, 64, 16, 16, True), + (8, 32, 4096, 128, True, 0.1, 16, 16, 16, True), + # different groups + (8, 32, 4096, 128, True, 0.1, 32, 32, 16, True), + # different selected_blocks + (8, 32, 4096, 128, True, 0.1, 32, 16, 8, True), + (8, 32, 4096, 128, True, 0.1, 32, 16, 32, True), + # different scale + (8, 32, 4096, 128, True, 0.05, 32, 16, 16, True), + (8, 32, 4096, 128, True, 0.2, 32, 16, 16, True), + ], +) +def test_nsa_op( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune, +): + """Test NativeSparseAttention forward operation with various configurations.""" + op = NativeSparseAttentionForwardOp( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale, + block_size, groups, selected_blocks) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + # benchmark.profile(op, *inputs) + # benchmark.baseline_profile(*inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + 
parser.add_argument('--batch', type=int, default=16, help='batch size') + parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads') + parser.add_argument('--seq_len', type=int, default=8192 * 1, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='head dim') + parser.add_argument( + '--is_causal', action='store_true', default=True, help='enable causal attention') + parser.add_argument('--scale', type=float, default=0.1, help='scale') + parser.add_argument('--block_size', type=int, default=32, help='block size') + parser.add_argument('--groups', type=int, default=16, help='number of groups') + parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') + args = parser.parse_args() + + test_nsa_op( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.scale, + args.block_size, + args.groups, + args.selected_blocks, + args.tune, + ) diff --git a/tests/ops/test_mean_pooling_ops.py b/tests/ops/test_mean_pooling_ops.py new file mode 100644 index 00000000..aab7509b --- /dev/null +++ b/tests/ops/test_mean_pooling_ops.py @@ -0,0 +1,37 @@ +import argparse +from top.ops import MeanPoolingForwardOp +from benchmarks.deepseek_nsa.deepseek_nsa import MeanPoolingForwardBenchmark + + +def test_mean_pooling_op(batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=True): + op = MeanPoolingForwardOp( + batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) + + benchmark = MeanPoolingForwardBenchmark( + batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=tune) + + inputs = benchmark.gen_inputs() + benchmark.check(op, *inputs) + benchmark.profile(op, *inputs) + benchmark.baseline_profile(*inputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', type=int, default=1, 
help='logical batch size') + parser.add_argument('--total_seqlen', type=int, default=1 * 8192 * 1, help='number of heads') + parser.add_argument('--total_chunks', type=int, default=1 * 256 * 1, help='sequence length') + parser.add_argument('--heads', type=int, default=128, help='head dim') + parser.add_argument('--dim', type=int, default=128, help='scale') + parser.add_argument('--chunk_size', type=int, default=32, help='scale') + parser.add_argument('--tune', action='store_true', default=True, help='enable autotune') + args = parser.parse_args() + test_mean_pooling_op( + args.batch_size, + args.total_seqlen, + args.total_chunks, + args.heads, + args.dim, + args.chunk_size, + args.tune, + ) diff --git a/top/functions/__init__.py b/top/functions/__init__.py index 68370615..e79e0892 100644 --- a/top/functions/__init__.py +++ b/top/functions/__init__.py @@ -6,6 +6,7 @@ from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheFunc from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheFunc from .matmul import MatMulFunc +from .deepseek_nsa import NativeSparseAttentionFunc __all__ = [ "Function", @@ -16,4 +17,5 @@ "MultiHeadLatentAttentionDecodeWithKVCacheFunc", "DeepSeekSparseAttentionDecodeWithKVCacheFunc", "MatMulFunc", + "NativeSparseAttentionFunc", ] diff --git a/top/functions/deepseek_nsa.py b/top/functions/deepseek_nsa.py new file mode 100644 index 00000000..5edfe64a --- /dev/null +++ b/top/functions/deepseek_nsa.py @@ -0,0 +1,63 @@ +import torch +from top.functions.function import Function +from top.ops.deepseek_nsa import NativeSparseAttentionForwardOp + +__all__ = ['NativeSparseAttentionFunc'] + + +class nsa_decode_ctx(torch.autograd.Function): + + @staticmethod + def forward(ctx, Q, K, V, BlockIndices, fwd_op): + O = fwd_op(Q, K, V, BlockIndices) + return O + + @staticmethod + def backward(ctx, dO): + raise NotImplementedError("Backward pass is not implemented for nsa.") + + @staticmethod + def decode(ctx, dO): + raise 
NotImplementedError("Decode pass is not implemented for nsa.") + + +class NativeSparseAttentionFunc(Function): + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): + + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + self.tune = tune + + self.fwd_op = NativeSparseAttentionForwardOp( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: + return nsa_decode_ctx.apply(Q, K, V, BlockIndices, self.fwd_op) diff --git a/top/kernels/deepseek_nsa/__init__.py b/top/kernels/deepseek_nsa/__init__.py new file mode 100644 index 00000000..d28c4833 --- /dev/null +++ b/top/kernels/deepseek_nsa/__init__.py @@ -0,0 +1,2 @@ +from .nsa_fwd import * +from .mean_pooling_fwd import * \ No newline at end of file diff --git a/top/kernels/deepseek_nsa/mean_pooling_fwd.py b/top/kernels/deepseek_nsa/mean_pooling_fwd.py new file mode 100644 index 00000000..17d70fe2 --- /dev/null +++ b/top/kernels/deepseek_nsa/mean_pooling_fwd.py @@ -0,0 +1,170 @@ +import torch + +from typing import Optional +from top.kernels.kernel import Kernel +import itertools + +import tilelang +import tilelang.language as T + +__all__ = ["mean_pooling_fwd_kernel"] + + +def _mean_pooling_kernel( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, +): + dtype = T.float16 + accum_dtype = T.float32 + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + 
def _mean_pooling_func(block_D, threads): + + ND = T.ceildiv(dim, block_D) + + x_shape = [total_seqlen, heads, dim] + cu_seqlens_shape = [batch_size + 1] + chunk_indices_shape = [total_chunks, 2] + output_shape = [total_chunks, heads, dim] + + @T.prim_func + def _mean_pooling_main( + X_unpad: T.Tensor(x_shape, dtype), + cu_seqlens: T.Tensor(cu_seqlens_shape, T.int32), + chunk_indices: T.Tensor(chunk_indices_shape, T.int32), + Output: T.Tensor(output_shape, dtype), + ): + with T.Kernel(ND, total_chunks, heads, threads=threads) as (i_d, i_t, i_h): + accum = T.alloc_fragment([block_D], accum_dtype) + d_start = i_d * block_D + + seq_id = chunk_indices[i_t, 0] + local_chunk_id = chunk_indices[i_t, 1] + start = cu_seqlens[seq_id] + end = cu_seqlens[seq_id + 1] + seqlen = end - start + + chunk_start = local_chunk_id * chunk_size + chunk_end = T.min(chunk_start + chunk_size, seqlen) + actual_bt = chunk_end - chunk_start + + for d in T.Parallel(block_D): + accum[d] = T.cast(0, accum_dtype) + for t_rel in T.serial(actual_bt): + t_abs = start + chunk_start + t_rel + for d in T.Parallel(block_D): + if d_start + d < dim: + accum[d] += T.cast(X_unpad[t_abs, i_h, d_start + d], accum_dtype) + for d in T.Parallel(block_D): + if d_start + d < dim: + Output[i_t, i_h, + d_start + d] = T.cast(accum[d] / T.cast(actual_bt, accum_dtype), + dtype) + + return _mean_pooling_main + + return _mean_pooling_func + + +@torch.library.custom_op("top::mean_pooling_fwd_wrapped_kernel", mutates_args=()) +def _mean_pooling_wrapped_kernel( + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int, + threads: int, + x_unpad: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_indices: torch.Tensor, +) -> torch.Tensor: + return _mean_pooling_kernel( + batch_size, + total_seqlen, + total_chunks, + heads, + dim, + chunk_size, + )(block_D, threads)(x_unpad, cu_seqlens, chunk_indices) + + +@_mean_pooling_wrapped_kernel.register_fake +def _( + 
batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + block_D: int, + threads: int, + *inputs +) -> torch.Tensor: + fake_o = torch.empty_like(inputs[0]) + return fake_o + + +class mean_pooling_fwd_kernel(Kernel): + supported_archs: list[int] = [80, 89, 90, 100] + + def __init__(self, + batch_size: int, + total_seqlen: int, + total_chunks: int, + heads: int, + dim: int, + chunk_size: int, + config: Optional[dict] = None, + tune=False): + super().__init__() + self.batch_size = batch_size + self.total_seqlen = total_seqlen + self.total_chunks = total_chunks + self.heads = heads + self.dim = dim + self.chunk_size = chunk_size + + self.kernel = _mean_pooling_kernel( + self.batch_size, + self.total_seqlen, + self.total_chunks, + self.heads, + self.dim, + self.chunk_size, + ) + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_D": min(64, self.dim), + "threads": 128, + } + + @property + def autotune_configs(self) -> list[dict]: + block_D = [32, 64, 128] + threads = [32, 64, 128] + _configs = list(itertools.product(block_D, threads)) + configs = [{"block_D": c[0], "threads": c[1]} for c in _configs] + return configs + + def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor, chunk_indices: torch.Tensor): + return _mean_pooling_wrapped_kernel(self.batch_size, self.total_seqlen, self.total_chunks, + self.heads, self.dim, self.chunk_size, + self.config["block_D"], self.config["threads"], x_unpad, + cu_seqlens, chunk_indices) diff --git a/top/kernels/deepseek_nsa/nsa_fwd.py b/top/kernels/deepseek_nsa/nsa_fwd.py new file mode 100644 index 00000000..258251dd --- /dev/null +++ b/top/kernels/deepseek_nsa/nsa_fwd.py @@ -0,0 +1,240 @@ +import tilelang +import tilelang.language as T +from typing import Optional +from top.kernels.kernel import Kernel +import itertools +import torch + +__all__ = ["nsa_fwd_kernel"] + + +# dtype default float16, accum_dtype default float32 
+def _nsa_fwd_kernel(batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16): + + scale = (1.0 / dim)**0.5 * 1.44269504 if scale is None else scale * 1.44269504 + + head_kv = heads // groups + + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + + block_S = block_size + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + def _nsa_fwd_func(block_T, num_stages, threads): + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, f"The head dimension (dim={dim}) cannot be larger than block_T ({block_T}). " \ + f"This kernel processes Q and K in a single block, so dim must be <= block_T." + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + + @T.prim_func + def _nsa_fwd_main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], 
accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - + scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV]) + + return _nsa_fwd_main + + return _nsa_fwd_func + + +@torch.library.custom_op("top::nsa_fwd_wrapped_kernel", mutates_args=()) +def _nsa_fwd_wrapped_kernel( + batch: int, + heads: int, + seq_len: int, + dim: int, + is_causal: bool, + scale: float, + block_size: int, + groups: int, + selected_blocks: int, + block_T: int, + num_stages: int, + threads: int, + Q: 
torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + BlockIndices: torch.Tensor, +) -> torch.Tensor: + return _nsa_fwd_kernel(batch, heads, seq_len, dim, is_causal, scale, block_size, groups, + selected_blocks)(block_T, num_stages, threads)(Q, K, V, BlockIndices) + + +@_nsa_fwd_wrapped_kernel.register_fake +def _( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + block_T, + num_stages, + threads, + *inputs +) -> torch.Tensor: + fake_o = torch.empty_like(inputs[0]) + return fake_o + + +class nsa_fwd_kernel(Kernel): + supported_archs: list[int] = [80, 89, 90, 100] + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + config: Optional[dict] = None, + tune=False): + + super().__init__() + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + + self.kernel = _nsa_fwd_kernel(self.batch, self.heads, self.seq_len, self.dim, + self.is_causal, self.scale, self.block_size, self.groups, + self.selected_blocks) + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_T": min(128, tilelang.math.next_power_of_2(self.dim)), + "num_stages": 2, + "threads": 32, + } + + @property + def autotune_configs(self) -> list[dict]: + block_T = [32, 64, 128] + num_stages = [2, 3] + threads = [32, 64, 128] + _configs = list(itertools.product(block_T, num_stages, threads)) + configs = [{"block_T": c[0], "num_stages": c[1], "threads": c[2]} for c in _configs] + return configs + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor): + return _nsa_fwd_wrapped_kernel(self.batch, self.heads, self.seq_len, self.dim, + self.is_causal, self.scale, self.block_size, self.groups, + self.selected_blocks, 
self.config["block_T"], + self.config["num_stages"], self.config["threads"], Q, K, V, + BlockIndices) diff --git a/top/layers/__init__.py b/top/layers/__init__.py index f3fc167b..4bbb265a 100644 --- a/top/layers/__init__.py +++ b/top/layers/__init__.py @@ -2,9 +2,10 @@ from .flash_decode import MultiHeadAttentionDecodeLayer, GroupQueryAttentionDecodeLayer from .deepseek_mla import MultiHeadLatentAttentionDecodeLayer, DeepSeekSparseAttentionDecodeLayer from .linear import LinearLayer +from .deepseek_nsa import NativeSparseAttentionLayer __all__ = [ "MultiHeadAttentionLayer", "GroupQueryAttentionLayer", "MultiHeadAttentionDecodeLayer", "GroupQueryAttentionDecodeLayer", "MultiHeadLatentAttentionDecodeLayer", - "DeepSeekSparseAttentionDecodeLayer", "LinearLayer" + "DeepSeekSparseAttentionDecodeLayer", "LinearLayer", "NativeSparseAttentionLayer" ] diff --git a/top/layers/deepseek_nsa.py b/top/layers/deepseek_nsa.py new file mode 100644 index 00000000..26cfa295 --- /dev/null +++ b/top/layers/deepseek_nsa.py @@ -0,0 +1,46 @@ +import torch +from torch import nn +from top.functions import NativeSparseAttentionFunc + + +class NativeSparseAttentionLayer(nn.Module): + + def __init__(self, + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + tune=False): + super().__init__() + + self.batch = batch + self.heads = heads + self.seq_len = seq_len + self.dim = dim + self.is_causal = is_causal + self.scale = scale + self.block_size = block_size + self.groups = groups + self.selected_blocks = selected_blocks + self.tune = tune + + self.fn = NativeSparseAttentionFunc( + batch, + heads, + seq_len, + dim, + is_causal, + scale, + block_size, + groups, + selected_blocks, + tune=tune) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + BlockIndices: torch.Tensor) -> torch.Tensor: + return self.fn(Q, K, V, BlockIndices) diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 61d96403..e44d82c3 100644 
class MeanPoolingForwardOp(Op):
    """Forward op for chunk-wise mean pooling over an unpadded batch.

    Resolves and instantiates the ``mean_pooling_fwd_kernel`` (or a
    caller-supplied override from ``kernel_map``) for the given shape.
    """

    def __init__(self,
                 batch_size: int,
                 total_seqlen: int,
                 total_chunks: int,
                 heads: int,
                 dim: int,
                 chunk_size: int,
                 kernel_map: Optional[Dict[str, Kernel]] = None,
                 tune=False) -> None:
        # NOTE: return annotation fixed from ``-> torch.Tensor`` to
        # ``-> None`` — a constructor never returns a value.
        self.batch_size = batch_size
        self.total_seqlen = total_seqlen
        self.total_chunks = total_chunks
        self.heads = heads
        self.dim = dim
        self.chunk_size = chunk_size
        self.tune = tune

        # Populate self.kernel_map from kernel_map / default_kernel_map.
        self.dispatch_kernel(kernel_map)

        # Instantiate the concrete kernel for this problem shape.
        self.kernel = self.kernel_map["mean_pooling_fwd_kernel"](
            batch_size=self.batch_size,
            total_seqlen=self.total_seqlen,
            total_chunks=self.total_chunks,
            heads=self.heads,
            dim=self.dim,
            chunk_size=self.chunk_size,
            tune=self.tune)

    @property
    def default_kernel_map(self):
        """Mapping of kernel name -> kernel class used when no override is given."""
        return {"mean_pooling_fwd_kernel": mean_pooling_fwd_kernel}

    def forward(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor,
                chunk_indices: torch.Tensor):
        """Mean-pool ``x_unpad`` per chunk, using varlen metadata.

        Args:
            x_unpad: unpadded (packed) input tensor.
            cu_seqlens: cumulative sequence lengths delimiting each sample.
            chunk_indices: precomputed chunk index table for the kernel.
        """
        return self.kernel(x_unpad, cu_seqlens, chunk_indices)


class NativeSparseAttentionForwardOp(Op):
    """Forward op for Native Sparse Attention (selected-blocks branch).

    Resolves and instantiates the ``nsa_fwd_kernel`` (or a caller-supplied
    override from ``kernel_map``) for the given attention configuration.
    """

    def __init__(self,
                 batch,
                 heads,
                 seq_len,
                 dim,
                 is_causal,
                 scale=None,
                 block_size=64,
                 groups=1,
                 selected_blocks=16,
                 kernel_map: Optional[Dict[str, Kernel]] = None,
                 tune=False) -> None:
        self.batch = batch
        self.heads = heads
        self.seq_len = seq_len
        self.dim = dim
        self.is_causal = is_causal
        # scale may be None; the kernel derives a default in that case —
        # TODO confirm against the kernel builder.
        self.scale = scale
        self.block_size = block_size
        self.groups = groups
        self.selected_blocks = selected_blocks
        self.tune = tune

        # Populate self.kernel_map from kernel_map / default_kernel_map.
        self.dispatch_kernel(kernel_map)

        # Instantiate the concrete kernel for this problem shape.
        self.kernel = self.kernel_map["nsa_fwd_kernel"](
            self.batch,
            self.heads,
            self.seq_len,
            self.dim,
            self.is_causal,
            self.scale,
            self.block_size,
            self.groups,
            self.selected_blocks,
            tune=self.tune)

    @property
    def default_kernel_map(self):
        """Mapping of kernel name -> kernel class used when no override is given."""
        return {"nsa_fwd_kernel": nsa_fwd_kernel}

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                BlockIndices: torch.Tensor):
        """Run the NSA forward kernel on (Q, K, V) with the given block indices."""
        return self.kernel(Q, K, V, BlockIndices)