11import math
22
3+ import sys
4+ sys .path .append ("./" )
5+
36import pytest
47import torch
58from tests .test_helpers .utils_fp4 import (
@@ -54,8 +57,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
5457 return q_lens , in_kv_lens , seq_lens
5558
5659
57- def generate_seq_lens_decode (batch_size , q_len_per_req , max_in_kv_len ):
58- q_lens = torch .full ((batch_size ,), q_len_per_req , dtype = torch .int32 )
60+ def generate_seq_lens_decode (batch_size , q_len_per_req , max_in_kv_len , max_q_len ):
61+ if q_len_per_req is not None :
62+ assert max_q_len is None , "Can not specify both q_len_per_req and max_q_len."
63+ q_lens = torch .full ((batch_size ,), q_len_per_req , dtype = torch .int32 )
64+ else :
65+ assert max_q_len is not None , "Must specify either q_len_per_req or max_q_len."
66+ q_lens = torch .randint (1 , max_q_len + 1 , (batch_size ,), dtype = torch .int32 )
5967 in_kv_lens = torch .randint (0 , max_in_kv_len + 1 , (batch_size ,), dtype = torch .int )
6068 in_kv_lens [- 1 ] = max_in_kv_len
6169 seq_lens = q_lens + in_kv_lens
@@ -746,6 +754,7 @@ def _test_trtllm_batch_decode(
746754 max_in_kv_len ,
747755 head_dim ,
748756 device_scale = False ,
757+ max_q_len = None ,
749758):
750759 """
751760 Common function for testing trtllm-gen decode.
@@ -780,7 +789,7 @@ def _test_trtllm_batch_decode(
780789 # Generate random sequence lengths
781790 num_qo_heads = num_kv_heads * head_grp_size
782791 q_lens , in_kv_lens , seq_lens = generate_seq_lens_decode (
783- batch_size , q_len_per_req , max_in_kv_len
792+ batch_size , q_len_per_req , max_in_kv_len , max_q_len
784793 )
785794
786795 # Create query tensor and related data
@@ -835,7 +844,7 @@ def _test_trtllm_batch_decode(
835844 "window_left" : window_left ,
836845 }
837846 if not enable_sink :
838- if q_len_per_req == 1 :
847+ if q_len_per_req is not None and q_len_per_req == 1 :
839848 wrapper_ref = flashinfer .decode .BatchDecodeWithPagedKVCacheWrapper (
840849 workspace_buffer_ref , kv_layout , use_tensor_cores = True
841850 )
@@ -923,6 +932,9 @@ def _test_trtllm_batch_decode(
923932 q_len_per_req = q_len_per_req ,
924933 o_scale = o_scale ,
925934 mask = mask ,
935+ max_q_len = max_q_len if max_q_len is not None else None ,
936+ cum_seq_lens_q = q_indptr if max_q_len is not None else None ,
937+ cum_seq_lens_kv = kv_indptr if max_q_len is not None else None ,
926938 )
927939 if backend == "trtllm-gen" :
928940 # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +960,7 @@ def _test_trtllm_batch_decode(
948960
949961 # convert to float32 for fp8 is not supported by assert_close
950962 # relax rtol and atol for speculative decoding test
951- if q_len_per_req > 1 :
963+ if ( q_len_per_req and q_len_per_req > 1 ) or ( max_q_len and max_q_len > 1 ) :
952964 rtol , atol = rtol * 2 , atol * 2
953965
954966 # Arbitary small mismatch rate
@@ -1436,5 +1448,92 @@ def test_trtllm_gen_prefill_deepseek_bs1(
14361448 )
14371449
14381450
def test_trtllm_batch_decode_spec(
    kv_layout,
    batch_size,
    max_q_len,
    page_size,
    num_kv_heads,
    head_grp_size,
    window_left,
    q_dtype,
    o_dtype,
    kv_dtype,
    enable_pdl,
    enable_sink,
    max_in_kv_len,
    head_dim,
):
    """Exercise trtllm-gen batch decode with variable (speculative) query lengths.

    Unlike the fixed-length decode tests, this forwards ``q_len_per_req=None``
    together with ``max_q_len``, so ``generate_seq_lens_decode`` draws a random
    query length in ``[1, max_q_len]`` for every request.

    NOTE(review): no ``pytest.mark.parametrize`` decorators are visible above
    this test in the current view; without them pytest would treat these
    positional arguments as fixtures — confirm the decorators exist upstream.
    """
    _test_trtllm_batch_decode(
        backend="trtllm-gen",
        kv_layout=kv_layout,
        batch_size=batch_size,
        # Explicitly disable the fixed per-request length so the random
        # max_q_len-driven path is taken.
        q_len_per_req=None,
        page_size=page_size,
        num_kv_heads=num_kv_heads,
        head_grp_size=head_grp_size,
        window_left=window_left,
        q_dtype=q_dtype,
        o_dtype=o_dtype,
        kv_dtype=kv_dtype,
        enable_pdl=enable_pdl,
        enable_sink=enable_sink,
        max_in_kv_len=max_in_kv_len,
        head_dim=head_dim,
        max_q_len=max_q_len,
    )
1485+
1486+
if __name__ == "__main__":
    # Manual smoke run of the speculative-decoding path only.
    # Run `pytest <this file>` for the full parametrized suite; the dead
    # commented-out invocations that used to live here have been removed.
    test_trtllm_batch_decode_spec(
        kv_layout="HND",
        batch_size=4,
        max_q_len=12,
        page_size=64,
        num_kv_heads=4,
        head_grp_size=1,
        window_left=-1,
        q_dtype="bf16",
        kv_dtype="bf16",
        o_dtype="bf16",
        enable_pdl=None,
        enable_sink=False,
        max_in_kv_len=110,
        head_dim=128,
    )
0 commit comments