@@ -113,6 +113,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.num_head = vllm_config.model_config.get_num_kv_heads(
             vllm_config.parallel_config
         )
+        self.chunk_size = 256
+        self.blocks_per_chunk = self.chunk_size // self.block_size
         self.head_size = vllm_config.model_config.get_head_size()
         if (
             self._vllm_config.kv_transfer_config is not None
@@ -139,8 +141,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             config_base
             * self.num_layers
             * (1 if self.is_mla else self.num_head * self.total_tp_size * 2)
-        )
-        config["io_size"] = config_base * (1 if self.is_mla else self.num_head)
+        ) * self.blocks_per_chunk
+        self.io_size = config_base * (1 if self.is_mla else self.num_head) * self.blocks_per_chunk
+        config["io_size"] = self.io_size
         logger.info(
             "kv_block_size = %d, io_size = %d,",
             config["kv_block_size"],
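Note on the two hunks above: `chunk_size` and `blocks_per_chunk` now have to be defined before the size computation, because both `kv_block_size` and `io_size` are scaled by `blocks_per_chunk`, so a single transfer now covers a whole chunk of vLLM blocks rather than one block. A rough sketch of the arithmetic with illustrative numbers only (`config_base` is not shown in this diff; treating it as the per-block, per-head payload size in bytes is an assumption):

```python
# Illustrative only: every concrete value below is an assumption for the sketch.
block_size = 16                      # tokens per vLLM block
chunk_size = 256                     # tokens per transfer chunk
blocks_per_chunk = chunk_size // block_size           # -> 16

config_base = 2 * block_size * 128   # assumed: dtype_bytes * block_size * head_size
num_head, num_layers, total_tp_size = 8, 32, 2
is_mla = False

kv_block_size = (
    config_base
    * num_layers
    * (1 if is_mla else num_head * total_tp_size * 2)
) * blocks_per_chunk
io_size = config_base * (1 if is_mla else num_head) * blocks_per_chunk

print(kv_block_size, io_size)        # both sizes now describe a chunk, not a single block
```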
@@ -160,8 +163,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
                 "use_layerwise"
             ]
         )
-        self.chunk_size = 256
-        self.blocks_per_chunk = self.chunk_size // self.block_size

     def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
         for layer_name in forward_context.no_compile_layers:
@@ -207,25 +208,54 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
             layer_size * layer_id + layer_size / self.total_tp_size * self.rank
         )

-    def get_tensor_and_offset_layerwise(
+    def get_pointers_and_offset_layerwise(
         self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
+    ) -> tuple[List[List[int]], List[int]]:
+        k_pointer_lists = []
+        k_offsets = []
+        v_pointer_lists = []
+        v_offsets = []
+        layer_id = self._extract_layer_index(layer_name)
+
+        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+            vllm_block_ids = vllm_block_ids_tensor.tolist()
+            k_pointer_list = []
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
+            for vllm_block_id in vllm_block_ids:
+                if self.is_mla:
+                    k_pointer_list.append(kv_layer[vllm_block_id].data_ptr())
+                else:
+                    k_pointer_list.append(kv_layer[0][vllm_block_id].data_ptr())
+            k_pointer_lists.append(k_pointer_list)
+            k_offsets.append(k_data_offset)
+            if not self.is_mla:
+                v_pointer_list = []
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
+                for vllm_block_id in vllm_block_ids:
+                    v_pointer_list.append(kv_layer[1][vllm_block_id].data_ptr())
+                v_offsets.append(v_data_offset)
+                v_pointer_lists.append(v_pointer_list)
+        return k_pointer_lists + v_pointer_lists, k_offsets + v_offsets
+
+    def get_tensor_and_offset_layerwise(
+        self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[List[torch.Tensor], List[int]]:
         k_tensors = []
         k_offsets = []
         v_tensors = []
         v_offsets = []
         layer_id = self._extract_layer_index(layer_name)

-        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+        for blk_id in vllm_block_ids:
             k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_tensors.append(kv_layer[vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[blk_id])
             else:
-                k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[0][blk_id])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
                 v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
-                v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
+                v_tensors.append(kv_layer[1][blk_id])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets

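The new `get_pointers_and_offset_layerwise` returns K pointer lists followed by V pointer lists, and K offsets followed by V offsets, which is why every caller below slices the result at `blocks_len`. A toy illustration of that layout (made-up pointer values, not real `data_ptr()` results):

```python
# Toy values only; in the connector these come from data_ptr() on the KV cache tensors.
k_pointer_lists = [[0x1000, 0x1100], [0x2000]]    # one list of block pointers per storage block
v_pointer_lists = [[0x5000, 0x5100], [0x6000]]
k_offsets, v_offsets = [0, 0], [4096, 4096]

pointers = k_pointer_lists + v_pointer_lists       # what the method returns (non-MLA case)
offsets = k_offsets + v_offsets
blocks_len = len(k_pointer_lists)

assert pointers[:blocks_len] == k_pointer_lists    # K half -> K fetch/dump task
assert pointers[blocks_len:] == v_pointer_lists    # V half -> V fetch/dump task
assert offsets[blocks_len:] == v_offsets
```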
@@ -277,18 +307,20 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 vllm_block_ids_tensor.tolist()
             )
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                k_task_id = self.connector.load(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                size = [self.io_size] * blocks_len
+                k_task_id = self.connector.fetch_data(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len], size
                 )
                 v_task_id = None
                 if not self.is_mla:
-                    v_task_id = self.connector.load(
+                    v_task_id = self.connector.fetch_data(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
+                        size
                     )
                 if request.request_id not in self.layerwise_load_tasks:
                     self.layerwise_load_tasks[request.request_id] = {}
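On the load path the connector call switches from passing tensors to passing raw device pointers plus an explicit per-entry byte count. A minimal sketch of the assumed call shape; `MockConnector` and all values here are hypothetical stand-ins for the external connector's `fetch_data`:

```python
from typing import List


class MockConnector:
    """Stand-in that only mirrors the argument layout used in start_load_kv above."""

    def fetch_data(
        self,
        block_ids: List[str],
        offsets: List[int],
        pointer_lists: List[List[int]],
        sizes: List[int],
    ) -> int:
        # one storage block id, one data offset, one pointer list and one byte size per entry
        assert len(block_ids) == len(offsets) == len(pointer_lists) == len(sizes)
        return 0  # task id


io_size = 4096                                   # illustrative
storage_block_ids = ["blk-a", "blk-b"]
offsets = [0, 0]
pointers_list = [[0x1000, 0x1100], [0x2000, 0x2100]]
size = [io_size] * len(storage_block_ids)

k_task_id = MockConnector().fetch_data(storage_block_ids, offsets, pointers_list, size)
```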
@@ -404,7 +436,7 @@ def save_kv_layer(
         storage_block_ids = [block[0] for block in request.dump_blocks]
         vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]  # [5, 8, 12]
         blocks_len = len(storage_block_ids)
-        tensors, offsets = self.get_tensor_and_offset_layerwise(
+        pointers_list, offsets = self.get_pointers_and_offset_layerwise(
             vllm_block_ids_tensors, kv_layer, layer_name
         )

@@ -413,18 +445,18 @@ def save_kv_layer(
         elif kv_layer[0].device.type == "cuda":
             torch.cuda.current_stream().synchronize()

-        for block_id, offset, tensor in zip(
-            storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+        for block_id, offset, pointers in zip(
+            storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
         ):
-            task = self.connector.dump([block_id], [offset], [tensor])
+            task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
             self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                 block_id, []
             ).append(task)
         if not self.is_mla:
-            for block_id, offset, tensor in zip(
-                storage_block_ids, offsets[blocks_len:], tensors[blocks_len:]
+            for block_id, offset, pointer_lists in zip(
+                storage_block_ids, offsets[blocks_len:], pointers_list[blocks_len:]
             ):
-                task = self.connector.dump([block_id], [offset], [tensor])
+                task = self.connector.dump_data([block_id], [offset], [pointer_lists], [self.io_size])
                 self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                     block_id, []
                 ).append(task)
@@ -465,23 +497,23 @@ def wait_for_tasks():
             vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
             blocks_len = len(storage_block_ids)
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                for block_id, offset, tensor in zip(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                for block_id, offset, pointers in zip(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
                 ):
-                    task = self.connector.dump([block_id], [offset], [tensor])
+                    task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                     self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                         block_id, []
                     ).append(task)
                 if not self.is_mla:
-                    for block_id, offset, tensor in zip(
+                    for block_id, offset, pointers in zip(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
                     ):
-                        task = self.connector.dump([block_id], [offset], [tensor])
+                        task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                         self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                             block_id, []
                         ).append(task)
@@ -633,7 +665,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )
             self._need_load_reqs[request.request_id] = []
-            return num_lookup_hits * self.block_size, True
+            return num_lookup_hits * self.block_size - num_computed_tokens, True

         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
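This hunk (and the matching one below) subtracts `num_computed_tokens` from the reported hit length, on the assumption that the caller expects only the number of additional tokens the connector can supply beyond what is already computed locally. A worked example with illustrative numbers:

```python
# Illustrative numbers only.
block_size = 16
num_lookup_hits = 10          # blocks found in external storage
num_computed_tokens = 64      # tokens already covered by the local prefix cache

before = num_lookup_hits * block_size                        # 160: counted locally cached tokens too
after = num_lookup_hits * block_size - num_computed_tokens   # 96: only tokens still to be loaded
print(before, after)
```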
@@ -650,7 +682,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )

-        return num_lookup_hits * self.block_size, False
+        return num_lookup_hits * self.block_size - num_computed_tokens, False

     def update_state_after_alloc(
         self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int