|
26 | 26 | #include <sstream> |
27 | 27 | #include <unordered_map> |
28 | 28 |
|
| 29 | +#include "tvm/ffi/error.h" |
29 | 30 | #include "tvm_ffi_utils.h" |
30 | 31 |
|
31 | 32 | using tvm::ffi::Optional; |
@@ -157,6 +158,9 @@ void trtllm_paged_attention_launcher( |
157 | 158 | use_multi_block ? TileScheduler::Static : TileScheduler::Persistent; |
158 | 159 | runner_params.mMultiCtasKvMode = use_multi_block; |
159 | 160 |
|
| 161 | + runner_params.cumSeqLensQPtr = cum_seq_lens_q; |
| 162 | + runner_params.cumSeqLensKvPtr = cum_seq_lens_kv; |
| 163 | + |
160 | 164 | size_t max_batch_size = 8192; // todo(Yingyi): get from dlfw |
161 | 165 | size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB |
162 | 166 | size_t num_semaphores = |
@@ -209,22 +213,49 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal |
209 | 213 | Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, |
210 | 214 | int64_t o_sf_vec_size, int64_t o_sf_start_index, |
211 | 215 | int64_t window_left, int64_t sm_count, bool enable_pdl, |
212 | | - int64_t workspace_size, Optional<TensorView> attention_sinks) { |
| 216 | + int64_t workspace_size, Optional<TensorView> attention_sinks, |
| 217 | + Optional<int64_t> optional_max_q_len, |
| 218 | + Optional<TensorView> cum_seq_lens_q, |
| 219 | + Optional<TensorView> cum_seq_lens_kv |
| 220 | + ) { |
213 | 221 | auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); |
214 | 222 | auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); |
215 | 223 | TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); |
216 | 224 | for (int i = 0; i < key_cache.ndim(); i++) { |
217 | 225 | TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i)); |
218 | 226 | } |
219 | 227 | auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); |
220 | | - // NOTE(Zihao): query is [B, Q, H, D] |
221 | | - // where Q is the number of query tokens per request, used in MTP |
222 | | - // based on profiled results, always use decode mode for MTP (q_len is small) |
223 | | - // example: when kv_len = 10000, q < 200, decode mode is faster |
224 | | - int batch_size = query.size(0); |
225 | | - int q_len_per_request = query.size(1); |
226 | | - int sum_seq_q = batch_size * q_len_per_request; |
227 | | - int num_qo_heads = query.size(2); |
| 228 | + int batch_size; |
| 229 | + int max_q_len; |
| 230 | + int sum_seq_q; |
| 231 | + int num_qo_heads; |
| 232 | + int* cum_seq_lens_q_ptr = nullptr; |
| 233 | + int* cum_seq_lens_kv_ptr = nullptr; |
| 234 | + if (!optional_max_q_len.has_value()) { |
| 235 | + // each request has the same length |
| 236 | + |
| 237 | + // NOTE(Zihao): query is [B, Q, H, D] |
| 238 | + // where Q is the number of query tokens per request, used in MTP |
| 239 | + // based on profiled results, always use decode mode for MTP (q_len is small) |
| 240 | + // example: when kv_len = 10000, q < 200, decode mode is faster |
| 241 | + int q_len_per_request = query.size(1); |
| 242 | + batch_size = query.size(0); |
| 243 | + sum_seq_q = batch_size * q_len_per_request; |
| 244 | + num_qo_heads = query.size(2); |
| 245 | + max_q_len = q_len_per_request; |
| 246 | + } else { |
| 247 | + // each request has different length |
| 248 | + TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided"); |
| 249 | + TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided"); |
| 250 | + // the shape of query: [sum_seq_q, num_qo_heads, head_dim_q] |
| 251 | + // the shape of cum_seq_lens_q: [batch_size + 1] |
| 252 | + batch_size = cum_seq_lens_q.value().size(0) - 1; |
| 253 | + sum_seq_q = query.size(0); |
| 254 | + num_qo_heads = query.size(1); |
| 255 | + max_q_len = optional_max_q_len.value(); |
| 256 | + cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr()); |
| 257 | + cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr()); |
| 258 | + } |
228 | 259 | // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even. |
229 | 260 | int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1); |
230 | 261 | int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1); |
@@ -281,9 +312,9 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal |
281 | 312 | out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), |
282 | 313 | workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), |
283 | 314 | static_cast<int*>(seq_lens.data_ptr()), |
284 | | - /*cum_seq_lens_q=*/nullptr, |
285 | | - /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, |
286 | | - TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, |
| 315 | + cum_seq_lens_q_ptr, |
| 316 | + cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, |
| 317 | + TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, |
287 | 318 | num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, |
288 | 319 | kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, |
289 | 320 | bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, |
|
0 commit comments