
Commit b321b27

wip
1 parent accdc20


tests/attention/test_trtllm_gen_attention.py

Lines changed: 105 additions & 6 deletions
@@ -1,5 +1,8 @@
 import math
 
+import sys
+sys.path.append("./")
+
 import pytest
 import torch
 from tests.test_helpers.utils_fp4 import (
@@ -54,8 +57,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
-    q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
+    if q_len_per_req is not None:
+        assert max_q_len is None, "Can not specify both q_len_per_req and max_q_len."
+        q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+    else:
+        assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
+        q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
     in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
     in_kv_lens[-1] = max_in_kv_len
     seq_lens = q_lens + in_kv_lens
@@ -746,6 +754,7 @@ def _test_trtllm_batch_decode(
     max_in_kv_len,
     head_dim,
     device_scale=False,
+    max_q_len=None,
 ):
     """
     Common function for testing trtllm-gen decode.
@@ -780,7 +789,7 @@ def _test_trtllm_batch_decode(
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
-        batch_size, q_len_per_req, max_in_kv_len
+        batch_size, q_len_per_req, max_in_kv_len, max_q_len
     )
 
     # Create query tensor and related data
@@ -835,7 +844,7 @@ def _test_trtllm_batch_decode(
         "window_left": window_left,
     }
     if not enable_sink:
-        if q_len_per_req == 1:
+        if q_len_per_req is not None and q_len_per_req == 1:
            wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer_ref, kv_layout, use_tensor_cores=True
            )
@@ -923,6 +932,9 @@ def _test_trtllm_batch_decode(
        q_len_per_req=q_len_per_req,
        o_scale=o_scale,
        mask=mask,
+       max_q_len=max_q_len if max_q_len is not None else None,
+       cum_seq_lens_q=q_indptr if max_q_len is not None else None,
+       cum_seq_lens_kv=kv_indptr if max_q_len is not None else None,
    )
    if backend == "trtllm-gen":
        # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +960,7 @@ def _test_trtllm_batch_decode(
 
     # convert to float32 for fp8 is not supported by assert_close
     # relax rtol and atol for speculative decoding test
-    if q_len_per_req > 1:
+    if (q_len_per_req and q_len_per_req > 1) or (max_q_len and max_q_len > 1):
        rtol, atol = rtol * 2, atol * 2
 
    # Arbitary small mismatch rate
@@ -1436,5 +1448,92 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     )
 
 
+def test_trtllm_batch_decode_spec(
+    kv_layout,
+    batch_size,
+    max_q_len,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    _test_trtllm_batch_decode(
+        "trtllm-gen",
+        kv_layout,
+        batch_size,
+        None,  # q_len_per_req
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+        max_q_len=max_q_len,
+    )
+
+
 if __name__ == "__main__":
-    pytest.main([__file__])
+    # pytest.main([__file__])
+    test_trtllm_batch_decode_spec(
+        kv_layout="HND",
+        batch_size=4,
+        max_q_len=12,
+        page_size=64,
+        num_kv_heads=4,
+        head_grp_size=1,
+        window_left=-1,
+        q_dtype="bf16",
+        kv_dtype="bf16",
+        o_dtype="bf16",
+        enable_pdl=None,
+        enable_sink=False,
+        max_in_kv_len=110,
+        head_dim=128,
+    )
+    # _test_trtllm_batch_decode(
+    #     backend='trtllm-gen',
+    #     kv_layout="HND",
+    #     batch_size=4,
+    #     q_len_per_req=3,
+    #     page_size=64,
+    #     num_kv_heads=4,
+    #     head_grp_size=1,
+    #     window_left=-1,
+    #     q_dtype="bf16",
+    #     kv_dtype="bf16",
+    #     o_dtype="bf16",
+    #     enable_pdl=None,
+    #     enable_sink=False,
+    #     max_in_kv_len=110,
+    #     head_dim=128,
+    # )
+
+    # _test_trtllm_batch_decode(
+    #     backend='trtllm-gen',
+    #     kv_layout="HND",
+    #     batch_size=4,
+    #     q_len_per_req=1,
+    #     page_size=64,
+    #     num_kv_heads=4,
+    #     head_grp_size=1,
+    #     window_left=-1,
+    #     q_dtype="fp8",
+    #     kv_dtype="fp8",
+    #     o_dtype="nvfp4",
+    #     enable_pdl=None,
+    #     enable_sink=False,
+    #     max_in_kv_len=110,
+    #     head_dim=128,
+    # )
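
Note on the core change: generate_seq_lens_decode now has two mutually exclusive modes, a uniform query length per request (q_len_per_req) or random per-request lengths drawn from [1, max_q_len]. A minimal standalone sketch, restating the logic exactly as added above, with the torch import and a small driver added for illustration (the driver values are hypothetical):

import torch


def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
    # Exactly one of q_len_per_req / max_q_len selects the mode.
    if q_len_per_req is not None:
        assert max_q_len is None, "Can not specify both q_len_per_req and max_q_len."
        # Uniform mode: every request gets the same query length.
        q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
    else:
        assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
        # Variable mode: random per-request query lengths in [1, max_q_len].
        q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
    in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
    in_kv_lens[-1] = max_in_kv_len  # pin the last request to the maximum KV length
    seq_lens = q_lens + in_kv_lens
    return q_lens, in_kv_lens, seq_lens


# Variable-length mode, as exercised by the new test_trtllm_batch_decode_spec:
q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
    batch_size=4, q_len_per_req=None, max_in_kv_len=110, max_q_len=12
)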

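The q_indptr and kv_indptr buffers forwarded above as cum_seq_lens_q / cum_seq_lens_kv are built outside the hunks shown in this commit. Assuming they follow the usual exclusive-prefix-sum ("indptr") layout, batch_size + 1 entries whose last element is the total token count, they could be derived from the generated lengths as in this hypothetical reconstruction (not the test's actual code):

import torch

# Hypothetical per-request lengths; in the test these come from
# generate_seq_lens_decode.
q_lens = torch.tensor([3, 1, 12, 7], dtype=torch.int32)
seq_lens = torch.tensor([40, 25, 110, 64], dtype=torch.int32)

# indptr[i] holds the sum of lengths of requests 0..i-1, so request i owns
# the token range [indptr[i], indptr[i + 1]).
q_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(q_lens, 0, dtype=torch.int32)]
)
kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_lens, 0, dtype=torch.int32)]
)
# q_indptr  -> [0, 3, 4, 16, 23]
# kv_indptr -> [0, 40, 65, 175, 239]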