Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .benchmark import Benchmark # noqa: F401
from .deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark
from .flash_attn import MultiHeadAttentionBenchmark, MultiHeadAttentionBwdBenchmark, MultiHeadAttentionFwdBenchmark, GroupQueryAttentionBenchmark, GroupQueryAttentionFwdBenchmark, GroupQueryAttentionBwdBenchmark
from .gemm import GemmBenchmark, MatMulBenchmark
from .flash_decode import MultiHeadAttentionDecodeBenchmark, GroupQueryAttentionDecodeBenchmark
from .deepseek_mla import MultiHeadLatentAttentionDecodeBenchmark, DeepSeekSparseAttentionDecodeBenchmark

__all__ = [
'Benchmark',
'NativeSparseAttentionForwardBenchmark',
'MultiHeadAttentionBenchmark',
'MultiHeadAttentionBwdBenchmark',
'MultiHeadAttentionFwdBenchmark',
Expand Down
5 changes: 5 additions & 0 deletions benchmarks/deepseek_nsa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .deepseek_nsa import NativeSparseAttentionForwardBenchmark

__all__ = [
"NativeSparseAttentionForwardBenchmark",
]
203 changes: 203 additions & 0 deletions benchmarks/deepseek_nsa/deepseek_nsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from benchmarks.benchmark import Benchmark
from top.ops import NativeSparseAttentionForwardOp
from top.ops import MeanPoolingForwardOp

import torch

from typing import Any
from native_sparse_attention.ops.naive import naive_nsa
from native_sparse_attention.ops.parallel import parallel_nsa_fwd
from fla.ops.utils import mean_pooling

from fla.ops.common.utils import prepare_chunk_indices


class NativeSparseAttentionForwardBenchmark(Benchmark):
    """Benchmark for the DeepSeek Native Sparse Attention (NSA) forward pass.

    Generates random Q/K/V plus per-position sparse block indices, checks the
    project's NSA op against the naive reference (``naive_nsa``), and profiles
    the FLA Triton kernel (``parallel_nsa_fwd``) as the baseline.
    """

    op_type = NativeSparseAttentionForwardOp

    def __init__(self,
                 batch,
                 heads,
                 seq_len,
                 dim,
                 is_causal,
                 scale=None,
                 block_size=64,
                 groups=1,
                 selected_blocks=16,
                 tune=False):
        """Store the problem configuration.

        Args:
            batch: batch size B.
            heads: number of query heads HQ.
            seq_len: sequence length T.
            dim: head dimension D.
            is_causal: whether attention is causal.
            scale: softmax scale; None lets the backend choose its default.
            block_size: size in tokens of each selectable KV block.
            groups: GQA group count, i.e. query heads per KV head.
            selected_blocks: number of KV blocks each query position selects.
            tune: enable kernel autotuning.
        """
        self.batch = batch
        self.heads = heads
        self.seq_len = seq_len
        self.dim = dim
        self.is_causal = is_causal
        self.scale = scale
        self.block_size = block_size
        self.groups = groups
        self.selected_blocks = selected_blocks

        # Number of KV heads; assumes `heads` is divisible by `groups`.
        self.head_kv = self.heads // self.groups
        self.dtype = torch.float16
        self.tune = tune

    @property
    def total_flops(self):
        """Nominal forward FLOPs.

        Counts 4*D FLOPs per (query, attended key) pair — 2 for QK^T and 2
        for PV — over the S*BS selected keys per query. The sliding window
        contributes nothing (window_size is fixed to 0 here).
        """
        B = self.batch
        T = self.seq_len
        HQ = self.heads
        D = self.dim
        S = self.selected_blocks
        BS = self.block_size

        window_size = 0
        total_keys = S * BS + window_size
        flops = 4 * B * T * HQ * D * total_keys
        return flops

    @property
    def total_memory(self):
        # Bytes for 2*T activations per (batch, head) at head dim D.
        # NOTE(review): this ignores V, the output, LSE, and block-index
        # traffic, and uses `heads` rather than `head_kv` for the K term —
        # confirm this accounting is intended.
        return (self.batch * self.heads * (2 * self.seq_len) * self.dim * self.dtype.itemsize)

    def gen_inputs(self):
        """Generate benchmark inputs on the GPU.

        Returns:
            (Q, K, V, block_indices) where Q is (B, T, HQ, D), K/V are
            (B, T, H_kv, D), and block_indices is (B, T, H_kv, S) int32 with
            `seq_len` as the padding sentinel for unselected slots.

        Side effects:
            Populates self.o_slc / self.lse_slc scratch buffers, the
            g_slc / g_swa gate tensors consumed by ref_program, and
            self.block_counts (valid entries per position).
        """
        Q = torch.randn(
            self.batch, self.seq_len, self.heads, self.dim, device='cuda', dtype=self.dtype)
        K = torch.randn(
            self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype)
        V = torch.randn(
            self.batch, self.seq_len, self.head_kv, self.dim, device='cuda', dtype=self.dtype)

        # Output scratch buffers (not consumed by ref/baseline below).
        self.o_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim),
                                 dtype=self.dtype,
                                 device="cuda")
        # NOTE(review): LSE buffers are typically (B, T, H); the trailing
        # `dim` axis here looks oversized — confirm the expected shape.
        self.lse_slc = torch.empty((self.batch, self.seq_len, self.heads, self.dim),
                                   dtype=torch.float,
                                   device="cuda")

        # Gating tensors for the selected-blocks and sliding-window branches;
        # all-ones so gating is a no-op in this benchmark.
        self.g_slc = torch.ones((self.batch, self.seq_len, self.heads),
                                dtype=self.dtype,
                                device="cuda").requires_grad_(True)
        self.g_swa = torch.ones((self.batch, self.seq_len, self.heads),
                                dtype=self.dtype,
                                device="cuda").requires_grad_(True)

        # Initialize every slot to the `seq_len` sentinel (= "unselected").
        block_indices = torch.full((self.batch, self.seq_len, self.head_kv, self.selected_blocks),
                                   self.seq_len,
                                   dtype=torch.long,
                                   device="cuda")
        self.block_counts = torch.zeros((self.batch, self.seq_len, self.head_kv),
                                        dtype=torch.long,
                                        device="cuda")
        # NOTE(review): this O(B*T*H_kv) Python loop is very slow for the
        # seq_len values in the input CSV (up to 65536) — consider vectorizing.
        for b in range(self.batch):
            for t in range(self.seq_len):
                for h in range(self.head_kv):
                    # Sample up to `selected_blocks` distinct blocks that lie
                    # strictly before position t (causal); `max(1, ...)` keeps
                    # randperm valid for the earliest positions.
                    i_i = torch.randperm(max(1, (t // self.block_size)))[:self.selected_blocks]
                    block_indices[b, t, h, :len(i_i)] = i_i
                    # Count real (non-sentinel) selections for this position.
                    self.block_counts[b, t, h] = (block_indices[b, t, h]
                                                  != self.seq_len).sum().item()
        # Kernels expect sorted int32 indices.
        block_indices = block_indices.sort(-1)[0].to(torch.int32)
        return Q, K, V, block_indices

    def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    BlockIndices: torch.Tensor) -> torch.Tensor:
        """Reference output from the naive (pure PyTorch) NSA implementation."""
        return naive_nsa(
            q=Q,
            k=K,
            v=V,
            g_slc=self.g_slc,
            g_swa=self.g_swa,
            block_indices=BlockIndices.to(torch.long),
            block_counts=self.block_counts,
            block_size=self.block_size,
            scale=self.scale,
        )

    def baseline_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                         BlockIndices: torch.Tensor) -> torch.Tensor:
        """Baseline output from FLA's Triton forward kernel (LSE discarded)."""
        o, lse = parallel_nsa_fwd(
            q=Q,
            k=K,
            v=V,
            block_indices=BlockIndices,
            block_counts=self.block_counts,
            block_size=self.block_size,
            scale=self.scale,
        )
        return o

    def baseline_profile(self,
                         *inputs: Any,
                         warmup: int = 100,
                         rep: int = 100,
                         device: str = "cuda:0") -> Any:
        """Profile the FLA baseline through the shared Benchmark helper."""
        print("===== Profiling FLA NSA_Fwd backend =====")
        return super().baseline_profile(
            self.baseline_program, *inputs, backend="FLA", warmup=warmup, rep=rep, device=device)


class MeanPoolingForwardBenchmark(Benchmark):
    """Benchmark for the forward pass of chunk-wise mean pooling.

    NOTE(review): ref_program and baseline_program currently run the exact
    same FLA ``mean_pooling`` call (with cu_seqlens=None, so the generated
    cu_seqlens/chunk_indices metadata is unused) — confirm this is intended.
    """

    op_type = MeanPoolingForwardOp

    def __init__(self, batch_size, total_seqlen, total_chunks, heads, dim, chunk_size, tune=True):
        self.batch_size = batch_size
        self.total_seqlen = total_seqlen
        self.total_chunks = total_chunks
        self.heads = heads
        self.dim = dim
        self.chunk_size = chunk_size
        self.tune = tune
        self.dtype = torch.float16

    @property
    def total_flops(self):
        """Nominal FLOPs: one op per input element plus one per output element."""
        return self.heads * self.dim * (self.total_seqlen + self.total_chunks)

    @property
    def total_memory(self):
        """Bytes moved: input + output activations plus index metadata."""
        activation_bytes = self.heads * self.dim * (
            self.total_seqlen + self.total_chunks) * self.dtype.itemsize
        return activation_bytes + 16 * self.total_chunks

    def gen_inputs(self):
        """Create the unpadded input tensor plus cu_seqlens / chunk-index metadata."""
        x_unpad = torch.randn(
            self.total_seqlen, self.heads, self.dim, device='cuda', dtype=self.dtype)
        # Fixed-length batch: every sequence holds total_seqlen // batch_size tokens.
        num_seqs = self.batch_size
        tokens_per_seq = self.total_seqlen // num_seqs

        cu_seqlens = torch.arange(
            0, (num_seqs + 1) * tokens_per_seq, tokens_per_seq, dtype=torch.int32, device='cuda')
        chunk_indices = prepare_chunk_indices(cu_seqlens, self.chunk_size)

        return x_unpad, cu_seqlens, chunk_indices

    def ref_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor,
                    chunk_indices: torch.Tensor) -> torch.Tensor:
        """Reference: FLA mean pooling over the (batch, time, heads, dim) view."""
        num_seqs = self.batch_size
        tokens_per_seq = self.total_seqlen // num_seqs
        batched = x_unpad.view(num_seqs, tokens_per_seq, self.heads, self.dim)

        pooled = mean_pooling(
            batched, chunk_size=self.chunk_size, cu_seqlens=None, head_first=False)
        return pooled.view(-1, self.heads, self.dim)

    def baseline_program(self, x_unpad: torch.Tensor, cu_seqlens: torch.Tensor,
                         chunk_indices: torch.Tensor) -> torch.Tensor:
        """Baseline: same FLA mean-pooling path as the reference."""
        return self.ref_program(x_unpad, cu_seqlens, chunk_indices)

    def baseline_profile(self,
                         *inputs: Any,
                         warmup: int = 100,
                         rep: int = 100,
                         device: str = "cuda:0") -> Any:
        """Profile the mean-pooling baseline through the shared Benchmark helper."""
        print("===== Profiling Mean Pooling_Fwd backend =====")
        return super().baseline_profile(
            self.baseline_program,
            *inputs,
            backend="Mean Pooling",
            warmup=warmup,
            rep=rep,
            device=device)
5 changes: 5 additions & 0 deletions benchmarks/input_params/deepseek_nsa.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
batch,heads,seq_len,dim,is_causal,scale,block_size,groups,selected_blocks,tune
1,64,8192,128,True,0.1,32,16,16,True
1,64,16384,128,True,0.1,32,16,16,True
1,64,32768,128,True,0.1,32,16,16,True
1,64,65536,128,True,0.1,32,16,16,True
32 changes: 32 additions & 0 deletions benchmarks/profile/profile_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@ def build_gqa_decode_cmd(args_dict):
return cmd_args


def build_nsa_cmd(args_dict):
    """
    Build command arguments for Native Sparse Attention test script

    Args:
        args_dict: mapping of parameter name -> value. Values are normally the
            strings read from the input CSV, but real bools/numbers are also
            tolerated for the flag fields.

    Returns:
        list[str]: argv-style arguments for the NSA test script.

    Raises:
        KeyError: if a required numeric parameter is missing.
    """
    cmd_args = [
        '--batch',
        str(args_dict['batch']),
        '--heads',
        str(args_dict['heads']),
        '--seq_len',
        str(args_dict['seq_len']),
        '--dim',
        str(args_dict['dim']),
        '--scale',
        str(args_dict.get('scale', 0.1)),
        '--block_size',
        str(args_dict['block_size']),
        '--groups',
        str(args_dict['groups']),
        '--selected_blocks',
        str(args_dict['selected_blocks']),
    ]

    # Coerce through str() so a real bool (True/False) works as well as the
    # CSV string form ('True'/'False'); bool.lower() would raise AttributeError.
    if str(args_dict.get('is_causal', 'True')).lower() == 'true':
        cmd_args.append('--is_causal')
    if str(args_dict.get('tune', 'False')).lower() == 'true':
        cmd_args.append('--tune')
    return cmd_args


def build_mla_decode_cmd(args_dict):
"""
Build command arguments for MLA decode test script
Expand Down Expand Up @@ -196,6 +226,8 @@ def run_test_script(script_path, args_dict):
cmd_args = build_mla_decode_cmd(args_dict)
elif 'sparse_mla' in script_name:
cmd_args = build_sparse_mla_cmd(args_dict)
elif 'deepseek_nsa' in script_name or 'nsa' in script_name:
cmd_args = build_nsa_cmd(args_dict)
elif 'mha' in script_name:
cmd_args = build_mha_cmd(args_dict)
elif 'gqa' in script_name:
Expand Down
33 changes: 33 additions & 0 deletions test_tileops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Ad-hoc smoke test: runs one MLA decode forward pass on random CUDA inputs.
# NOTE(review): this exercises MLAKernel only and is unrelated to the NSA
# changes; consider moving it under tests/ or dropping it from this PR.
import torch
from top import MLAKernel

device = "cuda"
dtype = torch.float16

# Problem sizes for a decode step: 128 queries of 64 heads against an
# 8192-token single-head latent KV cache.
batch = 128
heads = 64
kv_heads = 1
kv_ctx = 8192
dim = 512
pe_dim = 64

# Query input: [batch, heads, dim]
q = torch.randn(batch, heads, dim, device=device, dtype=dtype)

# Query positional encoding: [batch, heads, pe_dim]
q_pe = torch.randn(batch, heads, pe_dim, device=device, dtype=dtype)

# KV cache input: [batch, kv_ctx, kv_heads, dim]
kv = torch.randn(batch, kv_ctx, kv_heads, dim, device=device, dtype=dtype)

# KV positional encoding: [batch, kv_ctx, kv_heads, pe_dim]
k_pe = torch.randn(batch, kv_ctx, kv_heads, pe_dim, device=device, dtype=dtype)

# Use MLA kernel
# Tiling parameters; num_split=1 disables split-KV.
# NOTE(review): semantics of block_N/block_H come from MLAKernel — confirm.
block_N = 64
block_H = 64
num_split = 1

mla = MLAKernel(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H, num_split)

# Output is not validated anywhere — this only checks the kernel runs.
out = mla(q, q_pe, kv, k_pe)
Comment on lines 1 to 33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This entire file seems to be for testing MLAKernel, which is unrelated to the Native Sparse Attention (NSA) changes in this pull request. It appears to be a temporary test file that was accidentally included. Please remove it to keep the PR focused.

82 changes: 82 additions & 0 deletions tests/functions/test_deepseek_nsa_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import argparse
import pytest
import torch

from top.functions import NativeSparseAttentionFunc
from benchmarks.deepseek_nsa.deepseek_nsa import NativeSparseAttentionForwardBenchmark


@pytest.fixture(autouse=True)
def setup() -> None:
    """Reseed torch's RNG before every test so runs are reproducible."""
    seed = 1234
    torch.manual_seed(seed)


@pytest.mark.parametrize(
    "batch, heads, seq_len, dim, is_causal, scale, block_size, groups, selected_blocks, tune",
    [
        # default configuration
        (1, 64, 8192, 128, True, 0.1, 32, 16, 16, True),
        (1, 64, 8192 * 2, 128, True, 0.1, 32, 16, 16, True),
        (1, 64, 8192 * 4, 128, True, 0.1, 32, 16, 16, True),
        (1, 64, 8192 * 8, 128, True, 0.1, 32, 16, 16, True),
        (16, 64, 8192, 128, True, 0.1, 32, 16, 16, True),
    ],
)
def test_nsa_func(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    scale,
    block_size,
    groups,
    selected_blocks,
    tune,
):
    """Validate NativeSparseAttentionFunc against the benchmark's reference."""
    # Op under test, built with the same configuration as the benchmark.
    func = NativeSparseAttentionFunc(
        batch,
        heads,
        seq_len,
        dim,
        is_causal,
        scale,
        block_size,
        groups,
        selected_blocks,
        tune=tune)
    # The benchmark supplies input generation and the naive reference
    # implementation that check() compares against.
    benchmark = NativeSparseAttentionForwardBenchmark(batch, heads, seq_len, dim, is_causal, scale,
                                                      block_size, groups, selected_blocks)

    inputs = benchmark.gen_inputs()
    # NOTE(review): check()'s comparison tolerance is defined in Benchmark —
    # confirm it is appropriate for fp16 at these sequence lengths.
    benchmark.check(func, *inputs)


if __name__ == "__main__":
    # Standalone entry point: run a single configuration from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=2, help='batch size')
    parser.add_argument('--heads', type=int, default=16 * 4, help='number of heads')
    parser.add_argument('--seq_len', type=int, default=8192 * 3, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='head dim')
    parser.add_argument(
        '--is_causal', action='store_true', default=True, help='enable causal attention')
    parser.add_argument('--scale', type=float, default=0.1, help='scale')
    parser.add_argument('--block_size', type=int, default=32, help='block size')
    parser.add_argument('--groups', type=int, default=16, help='number of groups')
    parser.add_argument('--selected_blocks', type=int, default=16, help='number of selected blocks')
    parser.add_argument('--tune', action='store_true', default=True, help='enable autotune')
    cli = parser.parse_args()

    test_nsa_func(cli.batch, cli.heads, cli.seq_len, cli.dim, cli.is_causal, cli.scale,
                  cli.block_size, cli.groups, cli.selected_blocks, cli.tune)
Loading
Loading