
Commit e562c1c

Add chunk_size
1 parent 9a69f5f commit e562c1c

File tree

8 files changed
+649 -589 lines changed

ucm/integration/vllm/uc_connector.py

Lines changed: 62 additions & 30 deletions
@@ -113,6 +113,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.num_head = vllm_config.model_config.get_num_kv_heads(
             vllm_config.parallel_config
         )
+        self.chunk_size = 256
+        self.blocks_per_chunk = self.chunk_size // self.block_size
         self.head_size = vllm_config.model_config.get_head_size()
         if (
             self._vllm_config.kv_transfer_config is not None
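
The two new attributes define how many vLLM KV-cache blocks make up one UCM transfer chunk. A minimal sketch of the arithmetic, assuming a block_size of 16 (the real value comes from vllm_config, not from this example):

# Illustrative only: block_size is an assumed example value.
chunk_size = 256                              # tokens per storage chunk
block_size = 16                               # tokens per vLLM KV-cache block (assumed)
blocks_per_chunk = chunk_size // block_size   # -> 16 vLLM blocks moved per chunk
assert chunk_size % block_size == 0           # integer division drops any remainder otherwise
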
@@ -139,8 +141,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
                 config_base
                 * self.num_layers
                 * (1 if self.is_mla else self.num_head * self.total_tp_size * 2)
-            )
-            config["io_size"] = config_base * (1 if self.is_mla else self.num_head)
+            ) * self.blocks_per_chunk
+            self.io_size = config_base * (1 if self.is_mla else self.num_head) * self.blocks_per_chunk
+            config["io_size"] = self.io_size
             logger.info(
                 "kv_block_size = %d, io_size = %d,",
                 config["kv_block_size"],
@@ -160,8 +163,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
                     "use_layerwise"
                 ]
             )
-        self.chunk_size = 256
-        self.blocks_per_chunk = self.chunk_size // self.block_size

     def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
         for layer_name in forward_context.no_compile_layers:
@@ -207,25 +208,54 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
             layer_size * layer_id + layer_size / self.total_tp_size * self.rank
         )

-    def get_tensor_and_offset_layerwise(
+    def get_pointers_and_offset_layerwise(
         self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
+    ) -> tuple[List[List[int]], List[int]]:
+        k_pointer_lists = []
+        k_offsets = []
+        v_pointer_lists = []
+        v_offsets = []
+        layer_id = self._extract_layer_index(layer_name)
+
+        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+            vllm_block_ids = vllm_block_ids_tensor.tolist()
+            k_pointer_list = []
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
+            for vllm_block_id in vllm_block_ids:
+                if self.is_mla:
+                    k_pointer_list.append(kv_layer[vllm_block_id].data_ptr())
+                else:
+                    k_pointer_list.append(kv_layer[0][vllm_block_id].data_ptr())
+            k_pointer_lists.append(k_pointer_list)
+            k_offsets.append(k_data_offset)
+            if not self.is_mla:
+                v_pointer_list = []
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
+                for vllm_block_id in vllm_block_ids:
+                    v_pointer_list.append(kv_layer[1][vllm_block_id].data_ptr())
+                v_offsets.append(v_data_offset)
+                v_pointer_lists.append(v_pointer_list)
+        return k_pointer_lists + v_pointer_lists, k_offsets + v_offsets
+
+    def get_tensor_and_offset_layerwise(
+        self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[List[torch.Tensor], List[int]]:
         k_tensors = []
         k_offsets = []
         v_tensors = []
         v_offsets = []
         layer_id = self._extract_layer_index(layer_name)

-        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+        for blk_id in vllm_block_ids:
             k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_tensors.append(kv_layer[vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[blk_id])
             else:
-                k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[0][blk_id])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
                 v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
-                v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
+                v_tensors.append(kv_layer[1][blk_id])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets
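
The new helper returns raw device addresses grouped per chunk instead of per-block tensor views, which is what a scatter-gather style fetch/dump needs. A self-contained sketch of the same gathering pattern, with an assumed non-MLA cache layout (kv_layer[0] holds K blocks, kv_layer[1] holds V blocks; all dimensions are illustrative):

from typing import List

import torch

# Assumed layout: [K/V, num_blocks, block_size, num_heads, head_size]
kv_layer = torch.zeros(2, 32, 16, 8, 128, dtype=torch.float16)
chunks: List[torch.Tensor] = [torch.tensor([5, 8]), torch.tensor([12, 3])]  # two chunks of block ids

k_pointer_lists = [[kv_layer[0][b].data_ptr() for b in chunk.tolist()] for chunk in chunks]
v_pointer_lists = [[kv_layer[1][b].data_ptr() for b in chunk.tolist()] for chunk in chunks]

# One flat list per direction, mirroring the "K first, then V" ordering of the method's return value.
pointers = k_pointer_lists + v_pointer_lists
print(len(pointers), len(pointers[0]))  # -> 4 2
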

@@ -277,18 +307,20 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                     vllm_block_ids_tensor.tolist()
                 )
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                k_task_id = self.connector.load(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                size = [self.io_size] * blocks_len
+                k_task_id = self.connector.fetch_data(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len], size
                 )
                 v_task_id = None
                 if not self.is_mla:
-                    v_task_id = self.connector.load(
+                    v_task_id = self.connector.fetch_data(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
+                        size
                     )
                 if request.request_id not in self.layerwise_load_tasks:
                     self.layerwise_load_tasks[request.request_id] = {}
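
The load path now hands the connector per-chunk pointer lists plus an explicit per-block transfer size instead of tensor objects. The fetch_data signature used below is inferred from these call sites only; the connector class and all values are stand-ins, not the real UCM store API:

from typing import List


class StubConnector:
    """Stand-in connector; fetch_data is assumed to enqueue an async read and return a task handle."""

    def fetch_data(self, block_ids: List[str], offsets: List[int],
                   pointers: List[List[int]], sizes: List[int]) -> int:
        return 0


connector = StubConnector()
storage_block_ids = ["blk_hash_a", "blk_hash_b"]       # illustrative storage block hashes
offsets = [0, 0]                                       # K offsets for this layer/rank (illustrative)
pointers_list = [[0x7F0000000000], [0x7F0000010000]]   # per-chunk destination addresses (illustrative)
io_size = 1 << 20                                      # assumed chunk transfer size in bytes

size = [io_size] * len(storage_block_ids)
k_task_id = connector.fetch_data(storage_block_ids, offsets, pointers_list, size)
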
@@ -404,7 +436,7 @@ def save_kv_layer(
         storage_block_ids = [block[0] for block in request.dump_blocks]
         vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]  # [5, 8, 12]
         blocks_len = len(storage_block_ids)
-        tensors, offsets = self.get_tensor_and_offset_layerwise(
+        pointers_list, offsets = self.get_pointers_and_offset_layerwise(
             vllm_block_ids_tensors, kv_layer, layer_name
         )

@@ -413,18 +445,18 @@ def save_kv_layer(
         elif kv_layer[0].device.type == "cuda":
             torch.cuda.current_stream().synchronize()

-        for block_id, offset, tensor in zip(
-            storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+        for block_id, offset, pointers in zip(
+            storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
         ):
-            task = self.connector.dump([block_id], [offset], [tensor])
+            task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
             self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                 block_id, []
             ).append(task)
         if not self.is_mla:
-            for block_id, offset, tensor in zip(
-                storage_block_ids, offsets[blocks_len:], tensors[blocks_len:]
+            for block_id, offset, pointer_lists in zip(
+                storage_block_ids, offsets[blocks_len:], pointers_list[blocks_len:]
             ):
-                task = self.connector.dump([block_id], [offset], [tensor])
+                task = self.connector.dump_data([block_id], [offset], [pointer_lists], [self.io_size])
                 self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                     block_id, []
                 ).append(task)
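
The save path mirrors the load path: one dump_data call per storage block, with the returned task handle filed under (request_id, block_id) so completion can be polled later. A minimal sketch of that bookkeeping, where dump_data is a hypothetical stub rather than the real connector method:

from typing import Dict, List


def dump_data(block_ids, offsets, pointers, sizes) -> int:
    return 0  # stand-in for the async task handle the real connector returns


dump_tasks: Dict[str, Dict[str, List[int]]] = {}
request_id = "req-0"
io_size = 1 << 20  # assumed per-chunk transfer size

for block_id, offset, pointers in zip(["blk_hash_a", "blk_hash_b"], [0, 0],
                                      [[0x7F0000000000], [0x7F0000010000]]):
    task = dump_data([block_id], [offset], [pointers], [io_size])
    dump_tasks.setdefault(request_id, {}).setdefault(block_id, []).append(task)

print(dump_tasks)  # {'req-0': {'blk_hash_a': [0], 'blk_hash_b': [0]}}
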
@@ -465,23 +497,23 @@ def wait_for_tasks():
             vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
             blocks_len = len(storage_block_ids)
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                for block_id, offset, tensor in zip(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                for block_id, offset, pointers in zip(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
                 ):
-                    task = self.connector.dump([block_id], [offset], [tensor])
+                    task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                     self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                         block_id, []
                     ).append(task)
                 if not self.is_mla:
-                    for block_id, offset, tensor in zip(
+                    for block_id, offset, pointers in zip(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
                     ):
-                        task = self.connector.dump([block_id], [offset], [tensor])
+                        task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                         self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                             block_id, []
                         ).append(task)
@@ -633,7 +665,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )
             self._need_load_reqs[request.request_id] = []
-            return num_lookup_hits * self.block_size, True
+            return num_lookup_hits * self.block_size - num_computed_tokens, True

         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
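
The return value now subtracts the tokens vLLM has already computed, so only the genuinely external tokens are reported as loadable; previously the full hit count was returned even when part of it overlapped the already-computed prefix. A worked example with illustrative numbers:

block_size = 16
num_lookup_hits = 20        # blocks found in external storage (illustrative)
num_computed_tokens = 64    # tokens already present for this request (illustrative)

external_tokens = num_lookup_hits * block_size - num_computed_tokens
print(external_tokens)      # 320 - 64 = 256 tokens actually need to be fetched
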
@@ -650,7 +682,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )

-        return num_lookup_hits * self.block_size, False
+        return num_lookup_hits * self.block_size - num_computed_tokens, False

     def update_state_after_alloc(
         self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
