Skip to content

Commit 64d7e6d

Browse files
committed
update and pass existing tests
1 parent 3a23405 commit 64d7e6d

File tree

3 files changed

+65
-17
lines changed

3 files changed

+65
-17
lines changed

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <sstream>
2727
#include <unordered_map>
2828

29+
#include "tvm/ffi/error.h"
2930
#include "tvm_ffi_utils.h"
3031

3132
using tvm::ffi::Optional;
@@ -157,6 +158,9 @@ void trtllm_paged_attention_launcher(
157158
use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
158159
runner_params.mMultiCtasKvMode = use_multi_block;
159160

161+
runner_params.cumSeqLensQPtr = cum_seq_lens_q;
162+
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
163+
160164
size_t max_batch_size = 8192; // todo(Yingyi): get from dlfw
161165
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
162166
size_t num_semaphores =
@@ -209,22 +213,49 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
209213
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
210214
int64_t o_sf_vec_size, int64_t o_sf_start_index,
211215
int64_t window_left, int64_t sm_count, bool enable_pdl,
212-
int64_t workspace_size, Optional<TensorView> attention_sinks) {
216+
int64_t workspace_size, Optional<TensorView> attention_sinks,
217+
Optional<int64_t> optional_max_q_len,
218+
Optional<TensorView> cum_seq_lens_q,
219+
Optional<TensorView> cum_seq_lens_kv
220+
) {
213221
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
214222
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
215223
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
216224
for (int i = 0; i < key_cache.ndim(); i++) {
217225
TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
218226
}
219227
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
220-
// NOTE(Zihao): query is [B, Q, H, D]
221-
// where Q is the number of query tokens per request, used in MTP
222-
// based on profiled results, always use decode mode for MTP (q_len is small)
223-
// example: when kv_len = 10000, q < 200, decode mode is faster
224-
int batch_size = query.size(0);
225-
int q_len_per_request = query.size(1);
226-
int sum_seq_q = batch_size * q_len_per_request;
227-
int num_qo_heads = query.size(2);
228+
int batch_size;
229+
int max_q_len;
230+
int sum_seq_q;
231+
int num_qo_heads;
232+
int* cum_seq_lens_q_ptr = nullptr;
233+
int* cum_seq_lens_kv_ptr = nullptr;
234+
if (!optional_max_q_len.has_value()) {
235+
// each request has the same length
236+
237+
// NOTE(Zihao): query is [B, Q, H, D]
238+
// where Q is the number of query tokens per request, used in MTP
239+
// based on profiled results, always use decode mode for MTP (q_len is small)
240+
// example: when kv_len = 10000, q < 200, decode mode is faster
241+
int q_len_per_request = query.size(1);
242+
batch_size = query.size(0);
243+
sum_seq_q = batch_size * q_len_per_request;
244+
num_qo_heads = query.size(2);
245+
max_q_len = q_len_per_request;
246+
} else {
247+
// each request has different length
248+
TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided");
249+
TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided");
250+
// the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
251+
// the shape of cum_seq_lens_q: [batch_size + 1]
252+
batch_size = cum_seq_lens_q.value().size(0) - 1;
253+
sum_seq_q = query.size(0);
254+
num_qo_heads = query.size(1);
255+
max_q_len = optional_max_q_len.value();
256+
cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
257+
cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr());
258+
}
228259
// Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
229260
int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
230261
int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -281,9 +312,9 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
281312
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
282313
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
283314
static_cast<int*>(seq_lens.data_ptr()),
284-
/*cum_seq_lens_q=*/nullptr,
285-
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
286-
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
315+
cum_seq_lens_q_ptr,
316+
cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
317+
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
287318
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
288319
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
289320
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,

flashinfer/decode.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,9 @@ def _paged_run(
19201920
enable_pdl,
19211921
workspace_size,
19221922
sinks,
1923+
None, # max_q_len
1924+
None, # cum_seq_lens_q
1925+
None # cum_seq_lens_kv
19231926
)
19241927
return out
19251928

@@ -2065,7 +2068,7 @@ def trtllm_batch_decode_with_kv_cache(
20652068
workspace_buffer: torch.Tensor,
20662069
block_tables: torch.Tensor,
20672070
seq_lens: torch.Tensor,
2068-
max_seq_len: int,
2071+
max_kv_len: int,
20692072
bmm1_scale: Union[float, torch.Tensor] = 1.0,
20702073
bmm2_scale: Union[float, torch.Tensor] = 1.0,
20712074
window_left: int = -1,
@@ -2079,12 +2082,15 @@ def trtllm_batch_decode_with_kv_cache(
20792082
backend: str = "auto",
20802083
q_len_per_req: Optional[int] = 1,
20812084
o_scale: Optional[float] = 1.0,
2085+
max_q_len: Optional[int] = None,
2086+
cum_seq_lens_q: Optional[torch.Tensor] = None,
2087+
cum_seq_lens_kv: Optional[torch.Tensor] = None,
20822088
) -> Union[torch.Tensor, FP4Tensor]:
20832089
"""
20842090
Parameters
20852091
----------
20862092
query : torch.Tensor
2087-
query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
2093+
query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = total query tokens in the batch.
20882094
20892095
kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
20902096
If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2185,6 +2191,10 @@ def trtllm_batch_decode_with_kv_cache(
21852191
raise ValueError("xqa backend does not support nvfp4 output")
21862192
if o_sf_scale is not None or o_sf_vec_size is not None:
21872193
raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
2194+
if max_q_len is not None or cum_seq_lens_q is not None or cum_seq_lens_kv is not None:
2195+
raise ValueError(
2196+
"xqa backend does not support cum_seq_lens_q or cum_seq_lens_kv"
2197+
)
21882198

21892199
# Handle out and out_dtype
21902200
if out_dtype is None:
@@ -2199,7 +2209,7 @@ def trtllm_batch_decode_with_kv_cache(
21992209
workspace_buffer=workspace_buffer,
22002210
block_tables=block_tables,
22012211
seq_lens=seq_lens,
2202-
max_seq_len=max_seq_len,
2212+
max_seq_len=max_kv_len,
22032213
bmm1_scale=bmm1_scale,
22042214
bmm2_scale=bmm2_scale,
22052215
window_left=window_left,
@@ -2308,13 +2318,13 @@ def trtllm_batch_decode_with_kv_cache(
23082318
q_len_per_req,
23092319
query.size(1),
23102320
query.size(2),
2311-
),
2321+
) if q_len_per_req is not None else query,
23122322
k_cache,
23132323
v_cache,
23142324
workspace_buffer,
23152325
block_tables,
23162326
seq_lens,
2317-
max_seq_len,
2327+
max_kv_len,
23182328
bmm1_scale,
23192329
bmm2_scale,
23202330
o_sf_scale or -1.0,
@@ -2325,6 +2335,9 @@ def trtllm_batch_decode_with_kv_cache(
23252335
enable_pdl,
23262336
workspace_buffer.numel() * workspace_buffer.element_size(),
23272337
sinks,
2338+
max_q_len,
2339+
cum_seq_lens_q,
2340+
cum_seq_lens_kv,
23282341
)
23292342

23302343
return (

tests/attention/test_trtllm_gen_attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,3 +1369,7 @@ def test_trtllm_gen_prefill_deepseek_bs1(
13691369
test_trtllm_gen_prefill_deepseek(
13701370
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
13711371
)
1372+
1373+
1374+
if __name__ == "__main__":
1375+
pytest.main([__file__])

0 commit comments

Comments
 (0)