From 9a69f5f882383e4c670fc5bc877fcbc857016366 Mon Sep 17 00:00:00 2001
From: qyh
Date: Tue, 28 Oct 2025 21:34:11 +0800
Subject: [PATCH 1/4] Introduce chunk size

---
 ucm/integration/vllm/uc_connector.py | 116 +++++++++++++++------------
 1 file changed, 65 insertions(+), 51 deletions(-)

diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index ddba78d6..c8d8e9a8 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -66,15 +66,17 @@ class RequestBlockInfo:
     block_operations: list[BlockOperation] = field(default_factory=list)
     # Next block position to process
     start_position: int = 0
+    # vLLM block IDs resident in HBM
+    vllm_block_ids: list[int] = field(default_factory=list)


 @dataclass
 class ReqMeta:
     request_id: str
-    # list[(block_hash, vllm_block_id)]
-    load_blocks: list[tuple[str, int]] = field(default_factory=list)
-    # list[(block_hash, vllm_block_id)]
-    dump_blocks: list[tuple[str, int]] = field(default_factory=list)
+    # list[(chunk_hash, tensor of vllm_block_ids)]
+    load_blocks: list[tuple[str, torch.Tensor]] = field(default_factory=list)
+    # list[(chunk_hash, tensor of vllm_block_ids)]
+    dump_blocks: list[tuple[str, torch.Tensor]] = field(default_factory=list)
     # Whether to use load_async
     load_async: bool = False
@@ -158,6 +160,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
                 "use_layerwise"
             ]
         )
+        self.chunk_size = 256
+        self.blocks_per_chunk = self.chunk_size // self.block_size

     def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
         for layer_name in forward_context.no_compile_layers:
@@ -204,7 +208,7 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
         )

     def get_tensor_and_offset_layerwise(
-        self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
+        self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[List[torch.Tensor], List[int]]:
         k_tensors = []
         k_offsets = []
@@ -212,16 +216,16 @@
         v_offsets = []
         layer_id = self._extract_layer_index(layer_name)

-        for blk_id in vllm_block_ids:
+        for vllm_block_ids_tensor in vllm_block_ids_tensors:
             k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_tensors.append(kv_layer[blk_id])
+                k_tensors.append(kv_layer[vllm_block_ids_tensor])
             else:
-                k_tensors.append(kv_layer[0][blk_id])
+                k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
                 v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
-                v_tensors.append(kv_layer[1][blk_id])
+                v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets

@@ -266,14 +270,15 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 continue

             storage_block_ids = [block[0] for block in request.load_blocks]
-            vllm_block_ids = [block[1] for block in request.load_blocks]
+            vllm_block_ids_tensors = [block[1] for block in request.load_blocks]
             blocks_len = len(storage_block_ids)
-            self._load_req_to_blocks.setdefault(request.request_id, set()).update(
-                vllm_block_ids
-            )
+            for vllm_block_ids_tensor in vllm_block_ids_tensors:
+                self._load_req_to_blocks.setdefault(request.request_id, set()).update(
+                    vllm_block_ids_tensor.tolist()
+                )
             for layer_name, kv_layer in self.kv_caches.items():
                 tensors, offsets = self.get_tensor_and_offset_layerwise(
-                    vllm_block_ids, kv_layer, layer_name
+                    vllm_block_ids_tensors, kv_layer, layer_name
                 )
                 k_task_id = self.connector.load(
                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
                 )
@@ -397,10 +402,10 @@ def save_kv_layer(
         # Example: [("hash_123", 5), ("hash_456", 8), ("hash_789", 12)]
         # ["hash_123", "hash_456", "hash_789"]
         storage_block_ids = [block[0] for block in request.dump_blocks]
-        vllm_block_ids = [block[1] for block in request.dump_blocks]  # [5, 8, 12]
+        vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
         blocks_len = len(storage_block_ids)
         tensors, offsets = self.get_tensor_and_offset_layerwise(
-            vllm_block_ids, kv_layer, layer_name
+            vllm_block_ids_tensors, kv_layer, layer_name
         )

         if kv_layer[0].device.type == "npu":
@@ -457,11 +462,11 @@ def wait_for_tasks():
                 continue

             storage_block_ids = [block[0] for block in request.dump_blocks]
-            vllm_block_ids = [block[1] for block in request.dump_blocks]
+            vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
             blocks_len = len(storage_block_ids)
             for layer_name, kv_layer in self.kv_caches.items():
                 tensors, offsets = self.get_tensor_and_offset_layerwise(
-                    vllm_block_ids, kv_layer, layer_name
+                    vllm_block_ids_tensors, kv_layer, layer_name
                 )
                 for block_id, offset, tensor in zip(
                     storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
@@ -580,13 +585,13 @@ def hash_request_tokens(
             return ret

         assert num_computed_tokens % self.block_size == 0
-        block_hashes = hash_request_tokens(md5, self.block_size, request)
+        block_hashes = hash_request_tokens(md5, self.chunk_size, request)
         if not block_hashes:
             logger.debug("Maybe tokens too short to load.")
             return 0, False

         # Calculate start position (exclude blocks already in HBM)
-        start_position = num_computed_tokens // self.block_size
+        start_position = num_computed_tokens // self.chunk_size

         block_operations = [BlockOperation.NONE] * len(block_hashes)

@@ -655,12 +660,14 @@ def update_state_after_alloc(
         """
         if request.request_id in self._need_load_reqs:
             local_block_ids = (
-                # since we use unhashed blocks, so we don't need to reset start_position
-                blocks.get_unhashed_block_ids()
+                blocks.get_block_ids()
                 if num_external_tokens > 0
                 else []
             )
-            self._need_load_reqs[request.request_id] = local_block_ids
+            self._need_load_reqs[request.request_id] = local_block_ids[0]
+            request_block_info = self.request_block_infos.get(request.request_id, None)
+            if request_block_info:
+                request_block_info.start_position = 0
             return

         request_block_info = self.request_block_infos.get(request.request_id, None)
@@ -699,15 +706,16 @@ def build_connector_meta(
         for req_id, block_ids in self._need_load_reqs.items():
             block_info = self.request_block_infos.get(req_id)
             if block_info:
-                load_blocks, dump_blocks = self._extract_blocks(block_ids, block_info)
-                meta.requests.append(
-                    ReqMeta(
-                        request_id=req_id,
-                        load_blocks=load_blocks,
-                        dump_blocks=dump_blocks,
-                        load_async=True,
+                block_info.vllm_block_ids = block_ids
+                load_blocks, dump_blocks = self._extract_blocks(block_info)
+                meta.requests.append(
+                    ReqMeta(
+                        request_id=req_id,
+                        load_blocks=load_blocks,
+                        dump_blocks=dump_blocks,
+                        load_async=True,
+                    )
                 )
-                )
         self._need_load_reqs.clear()

         for new_req in scheduler_output.scheduled_new_reqs:
@@ -716,9 +724,8 @@

             block_info = self.request_block_infos.get(req_id)
             if block_info:
-                load_blocks, dump_blocks = self._extract_blocks(
-                    vllm_block_ids, block_info
-                )
+                block_info.vllm_block_ids = vllm_block_ids
+                load_blocks, dump_blocks = self._extract_blocks(block_info)
                 if load_blocks or dump_blocks:
                     meta.requests.append(
                         ReqMeta(
@@ -756,9 +763,8 @@ def get_requests():
         for req_id, new_block_ids in get_requests():
            block_info = self.request_block_infos.get(req_id)
             if block_info:
-                load_blocks, dump_blocks = self._extract_blocks(
-                    new_block_ids[0], block_info
-                )
+                block_info.vllm_block_ids.extend(new_block_ids[0])
+                load_blocks, dump_blocks = self._extract_blocks(block_info)
                 if load_blocks or dump_blocks:
                     meta.requests.append(
                         ReqMeta(
@@ -791,8 +797,8 @@ def request_finished(
         return False, None

     def _extract_blocks(
-        self, vllm_block_ids: list[int], block_info: RequestBlockInfo
-    ) -> tuple[list[tuple[str, int]], list[tuple[str, int]]]:
+        self, block_info: RequestBlockInfo
+    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
         """
         Extract blocks that need load and dump, block_info.start_position
         is the next block position to process, only return blocks that need
@@ -802,23 +808,31 @@

         if start_pos >= len(block_info.block_operations):
             return [], []
-
-        process_length = min(
-            len(block_info.block_operations) - start_pos, len(vllm_block_ids)
-        )
-        ops = block_info.block_operations[start_pos : start_pos + process_length]
-        hashes = block_info.block_hashes[start_pos : start_pos + process_length]
-        vllm_ids = vllm_block_ids[:process_length]
-
+
         load_blocks = []
         dump_blocks = []
-        for op, hash, vllm_id in zip(ops, hashes, vllm_ids):
-            if op == BlockOperation.LOAD:
-                load_blocks.append((hash, vllm_id))
-            elif op == BlockOperation.DUMP:
-                dump_blocks.append((hash, vllm_id))
-        block_info.start_position += process_length
+        block_mapping: dict[str, torch.Tensor] = {}
+        vllm_block_ids = block_info.vllm_block_ids
+        for idx, vllm_block_id in enumerate(vllm_block_ids[start_pos * self.blocks_per_chunk :], start_pos * self.blocks_per_chunk):
+            chunk_idx = idx // self.blocks_per_chunk
+            if chunk_idx >= len(block_info.block_hashes):
+                break
+            if idx + self.blocks_per_chunk > len(vllm_block_ids):
+                break
+            chunk_blocks = vllm_block_ids[idx : idx + self.blocks_per_chunk]
+            block_mapping[block_info.block_hashes[chunk_idx]] = torch.tensor(chunk_blocks)
+
+        for i in range(start_pos, start_pos + len(block_mapping)):
+            if block_info.block_operations[i] == BlockOperation.LOAD:
+                chunk_hash = block_info.block_hashes[i]
+                load_blocks.append((chunk_hash, block_mapping[chunk_hash]))
+            elif block_info.block_operations[i] == BlockOperation.DUMP:
+                chunk_hash = block_info.block_hashes[i]
+                dump_blocks.append((chunk_hash, block_mapping[chunk_hash]))
+
+        block_info.start_position += len(block_mapping)
+
         return load_blocks, dump_blocks

     def get_block_ids_with_load_errors(self) -> set[int]:

From fff9999555fa8216160db03b498e91bf0dd1ad2c Mon Sep 17 00:00:00 2001
From: qyh
Date: Wed, 29 Oct 2025 19:50:12 +0800
Subject: [PATCH 2/4] Add chunk size support to the store connectors

---
 ucm/integration/vllm/uc_connector.py        | 92 +++++++++++++------
 ucm/store/dramstore/cpy/dramstore.py.cc     |  6 +-
 ucm/store/localstore/cpy/localstore.py.cc   |  6 +-
 .../nfsstore/cc/domain/trans/posix_queue.cc | 24 +++--
 ucm/store/nfsstore/cpy/nfsstore.py.cc       |  8 +-
 ucm/store/nfsstore/nfsstore_connector.py    |  1 +
 ucm/store/task/task_shard.h                 |  6 +-
 ucm/store/ucmstore.py                       |  4 +-
 8 files changed, 101 insertions(+), 46 deletions(-)

diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index c8d8e9a8..e9edb442 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -113,6 +113,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.num_head = vllm_config.model_config.get_num_kv_heads(
             vllm_config.parallel_config
         )
+        self.chunk_size = 256
+        self.blocks_per_chunk = self.chunk_size // self.block_size
         self.head_size = vllm_config.model_config.get_head_size()
         if (
             self._vllm_config.kv_transfer_config is not None
@@ -139,8 +141,9 @@
             config_base
             * self.num_layers
             * (1 if self.is_mla else self.num_head * self.total_tp_size * 2)
-        )
-        config["io_size"] = config_base * (1 if self.is_mla else self.num_head)
+        ) * self.blocks_per_chunk
+        self.io_size = config_base * (1 if self.is_mla else self.num_head) * self.blocks_per_chunk
+        config["io_size"] = self.io_size
         logger.info(
             "kv_block_size = %d, io_size = %d,",
             config["kv_block_size"],
@@ -160,8 +163,6 @@
                 "use_layerwise"
             ]
         )
-        self.chunk_size = 256
-        self.blocks_per_chunk = self.chunk_size // self.block_size

     def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
         for layer_name in forward_context.no_compile_layers:
@@ -207,8 +208,37 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
             layer_size * layer_id + layer_size / self.total_tp_size * self.rank
         )

-    def get_tensor_and_offset_layerwise(
+    def get_pointers_and_offset_layerwise(
         self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
+    ) -> tuple[List[List[int]], List[int]]:
+        k_pointer_lists = []
+        k_offsets = []
+        v_pointer_lists = []
+        v_offsets = []
+        layer_id = self._extract_layer_index(layer_name)
+
+        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+            vllm_block_ids = vllm_block_ids_tensor.tolist()
+            k_pointer_list = []
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
+            for vllm_block_id in vllm_block_ids:
+                if self.is_mla:
+                    k_pointer_list.append(kv_layer[vllm_block_id].data_ptr())
+                else:
+                    k_pointer_list.append(kv_layer[0][vllm_block_id].data_ptr())
+            k_pointer_lists.append(k_pointer_list)
+            k_offsets.append(k_data_offset)
+            if not self.is_mla:
+                v_pointer_list = []
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
+                for vllm_block_id in vllm_block_ids:
+                    v_pointer_list.append(kv_layer[1][vllm_block_id].data_ptr())
+                v_offsets.append(v_data_offset)
+                v_pointer_lists.append(v_pointer_list)
+        return k_pointer_lists + v_pointer_lists, k_offsets + v_offsets
+
+    def get_tensor_and_offset_layerwise(
+        self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[List[torch.Tensor], List[int]]:
         k_tensors = []
         k_offsets = []
@@ -216,16 +246,16 @@
         v_offsets = []
         layer_id = self._extract_layer_index(layer_name)

-        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+        for blk_id in vllm_block_ids:
             k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_tensors.append(kv_layer[vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[blk_id])
             else:
-                k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[0][blk_id])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
                 v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
-                v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
+                v_tensors.append(kv_layer[1][blk_id])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets

@@ -277,18 +307,20 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                     vllm_block_ids_tensor.tolist()
                 )
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                    vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                k_task_id = self.connector.load(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                size = [self.io_size] * blocks_len
+                k_task_id = self.connector.fetch_data(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len], size
                 )
                 v_task_id = None
                 if not self.is_mla:
-                    v_task_id = self.connector.load(
+                    v_task_id = self.connector.fetch_data(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
+                        size,
                     )
                 if request.request_id not in self.layerwise_load_tasks:
                     self.layerwise_load_tasks[request.request_id] = {}
@@ -404,7 +436,7 @@ def save_kv_layer(
         storage_block_ids = [block[0] for block in request.dump_blocks]
         vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
         blocks_len = len(storage_block_ids)
-        tensors, offsets = self.get_tensor_and_offset_layerwise(
+        pointers_list, offsets = self.get_pointers_and_offset_layerwise(
             vllm_block_ids_tensors, kv_layer, layer_name
         )

         if kv_layer[0].device.type == "npu":
         elif kv_layer[0].device.type == "cuda":
             torch.cuda.current_stream().synchronize()

-        for block_id, offset, tensor in zip(
-            storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+        for block_id, offset, pointers in zip(
+            storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
         ):
-            task = self.connector.dump([block_id], [offset], [tensor])
+            task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
             self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                 block_id, []
             ).append(task)
         if not self.is_mla:
-            for block_id, offset, tensor in zip(
-                storage_block_ids, offsets[blocks_len:], tensors[blocks_len:]
+            for block_id, offset, pointer_lists in zip(
+                storage_block_ids, offsets[blocks_len:], pointers_list[blocks_len:]
             ):
-                task = self.connector.dump([block_id], [offset], [tensor])
+                task = self.connector.dump_data([block_id], [offset], [pointer_lists], [self.io_size])
                 self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                     block_id, []
                 ).append(task)
@@ -465,23 +497,23 @@ def wait_for_tasks():
             vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
             blocks_len = len(storage_block_ids)
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                for block_id, offset, tensor in zip(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                for block_id, offset, pointers in zip(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
                 ):
-                    task = self.connector.dump([block_id], [offset], [tensor])
+                    task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                     self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                         block_id, []
                     ).append(task)
                 if not self.is_mla:
-                    for block_id, offset, tensor in zip(
+                    for block_id, offset, pointers in zip(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
                     ):
-                        task = self.connector.dump([block_id], [offset], [tensor])
+                        task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                         self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                             block_id, []
                         ).append(task)
@@ -633,7 +665,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )
             self._need_load_reqs[request.request_id] = []
-            return num_lookup_hits * self.block_size, True
+            return num_lookup_hits * self.block_size - num_computed_tokens, True

         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
@@ -650,7 +682,7 @@
                 start_position=start_position,
             )

-        return num_lookup_hits * self.block_size, False
+        return num_lookup_hits * self.block_size - num_computed_tokens, False

     def update_state_after_alloc(
         self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
diff --git a/ucm/store/dramstore/cpy/dramstore.py.cc b/ucm/store/dramstore/cpy/dramstore.py.cc
index 5b50748f..e2aed243 100644
--- a/ucm/store/dramstore/cpy/dramstore.py.cc
+++ b/ucm/store/dramstore/cpy/dramstore.py.cc
@@ -78,8 +78,12 @@ class DRAMStorePy : public DRAMStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;
diff --git a/ucm/store/localstore/cpy/localstore.py.cc b/ucm/store/localstore/cpy/localstore.py.cc
index c067df23..a809559b 100644
--- a/ucm/store/localstore/cpy/localstore.py.cc
+++ b/ucm/store/localstore/cpy/localstore.py.cc
@@ -78,8 +78,12 @@ class LocalStorePy : public LocalStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;
diff --git a/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
index 21e4b85d..768388d4 100644
--- a/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
+++ b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
@@ -103,7 +103,13 @@ Status PosixQueue::D2S(Task::Shard& shard, const Device& device)
         return Status::OutOfMemory();
     }
     auto hub = shard.buffer.get();
-    auto status = device->D2HSync((std::byte*)hub, (std::byte*)shard.address, shard.length);
+    auto dAddr = new std::byte*[shard.address.size()];
+    auto hAddr = new std::byte*[shard.address.size()];
+    for (size_t i = 0; i < shard.address.size(); i++) {
+        hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
+        dAddr[i] = (std::byte*)shard.address[i];
+    }
+    auto status = device->D2HBatchSync(hAddr, const_cast<const std::byte**>(dAddr), shard.address.size(), shard.length / shard.address.size());
     if (status.Failure()) { return status; }
     auto path = this->layout_->DataFilePath(shard.block, true);
     return File::Write(path, shard.offset, shard.length, (uintptr_t)hub);
@@ -120,21 +126,27 @@ Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
     auto path = this->layout_->DataFilePath(shard.block, false);
     auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub);
     if (status.Failure()) { return status; }
-    return device->H2DAsync((std::byte*)shard.address, (std::byte*)hub, shard.length);
+    auto dAddr = new std::byte*[shard.address.size()];
+    auto hAddr = new std::byte*[shard.address.size()];
+    for (size_t i = 0; i < shard.address.size(); i++) {
+        hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
+        dAddr[i] = (std::byte*)shard.address[i];
+    }
+    return device->H2DBatchSync(dAddr, const_cast<const std::byte**>(hAddr), shard.address.size(), shard.length / shard.address.size());
 }

 Status PosixQueue::H2S(Task::Shard& shard)
 {
     auto path = this->layout_->DataFilePath(shard.block, true);
-    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address);
-    return File::Write(path, shard.offset, shard.length, shard.address, aligned);
+    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address[0]);
+    return File::Write(path, shard.offset, shard.length, shard.address[0], aligned);
 }

 Status PosixQueue::S2H(Task::Shard& shard)
 {
     auto path = this->layout_->DataFilePath(shard.block, false);
-    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address);
-    return File::Read(path, shard.offset, shard.length, shard.address, aligned);
+    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address[0]);
+    return File::Read(path, shard.offset, shard.length, shard.address[0], aligned);
 }

 } // namespace UC
diff --git a/ucm/store/nfsstore/cpy/nfsstore.py.cc b/ucm/store/nfsstore/cpy/nfsstore.py.cc
index 7f148f1b..59002424 100644
--- a/ucm/store/nfsstore/cpy/nfsstore.py.cc
+++ b/ucm/store/nfsstore/cpy/nfsstore.py.cc
@@ -91,8 +91,12 @@ class NFSStorePy : public NFSStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;
@@ -123,8 +127,6 @@ PYBIND11_MODULE(ucmnfsstore, module)
     config.def_readwrite("transferBufferNumber", &UC::NFSStorePy::Config::transferBufferNumber);
     config.def_readwrite("transferTimeoutMs", &UC::NFSStorePy::Config::transferTimeoutMs);
     config.def_readwrite("tempDumpDirEnable", &UC::NFSStorePy::Config::tempDumpDirEnable);
-    config.def_readwrite("hotnessEnable", &UC::NFSStorePy::Config::hotnessEnable);
-    config.def_readwrite("hotnessInterval", &UC::NFSStorePy::Config::hotnessInterval);
     store.def(py::init<>());
     store.def("CCStoreImpl", &UC::NFSStorePy::CCStoreImpl);
     store.def("Setup", &UC::NFSStorePy::Setup);
diff --git a/ucm/store/nfsstore/nfsstore_connector.py b/ucm/store/nfsstore/nfsstore_connector.py
index 0f57b68a..eaab0ad6 100644
--- a/ucm/store/nfsstore/nfsstore_connector.py
+++ b/ucm/store/nfsstore/nfsstore_connector.py
@@ -51,6 +51,7 @@ def __init__(self, config: Dict):
         if transfer_enable:
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
+            param.transferStreamNumber = config.get("transfer_stream_number", 128)
         ret = self.store.Setup(param)
         if ret != 0:
             msg = f"Failed to initialize ucmnfsstore, errcode: {ret}."
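
[Editor's note — illustrative sketch, not part of the patch] The diffs above change the store transfer API from one tensor per block to one raw-pointer list per 256-token chunk, with `io_size` sized to a full chunk. A minimal Python sketch of the expected call shape, assuming a toy non-MLA KV layer and a hypothetical `store` object:

```python
import torch

block_size, blocks_per_chunk = 16, 16      # chunk_size 256 // block_size 16
num_blocks, num_heads, head_size = 64, 8, 128
# Toy KV layer laid out as [K/V, num_blocks, block_size, num_heads, head_size]
kv_layer = torch.zeros(2, num_blocks, block_size, num_heads, head_size)

# One chunk = blocks_per_chunk consecutive vLLM blocks; pass their raw pointers.
chunk = torch.arange(blocks_per_chunk)
k_pointers = [kv_layer[0][b].data_ptr() for b in chunk.tolist()]

# io_size now covers the whole chunk: per-block bytes * blocks_per_chunk.
io_size = kv_layer[0][0].numel() * kv_layer.element_size() * blocks_per_chunk

# One entry per chunk in every argument list ("chunk_hash", k_offset, and the
# store object are hypothetical here):
# store.fetch_data(["chunk_hash"], [k_offset], [k_pointers], [io_size])
```
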
diff --git a/ucm/store/task/task_shard.h b/ucm/store/task/task_shard.h
index 2f71738f..fb6de989 100644
--- a/ucm/store/task/task_shard.h
+++ b/ucm/store/task/task_shard.h
@@ -47,13 +47,13 @@ class Task {
         Location location;
         std::string block;
         size_t offset;
-        uintptr_t address;
+        std::vector<uintptr_t> address;
         size_t length;
         size_t owner;
         std::shared_ptr buffer;
         std::function done;
         Shard(const Type type, const Location location, const std::string& block,
-              const size_t offset, const uintptr_t address, const size_t length, const size_t owner)
+              const size_t offset, const std::vector<uintptr_t> address, const size_t length, const size_t owner)
             : type{type}, location{location}, block{block}, offset{offset}, address{address},
               length{length}, owner{owner}, buffer{nullptr}, done{nullptr}
         {
@@ -86,7 +86,7 @@
     auto Id() const noexcept { return id_; }
     auto StartTp() const noexcept { return startTp_; }
     auto Str() const noexcept { return fmt::format("{},{},{},{}", id_, brief_, number_, size_); }
-    void Append(const std::string& block, const size_t offset, const uintptr_t address,
+    void Append(const std::string& block, const size_t offset, const std::vector<uintptr_t> address,
                 const size_t length)
     {
         shards_.emplace_back(type_, location_, block, offset, address, length, id_);
diff --git a/ucm/store/ucmstore.py b/ucm/store/ucmstore.py
index f473bab5..b6cde07f 100644
--- a/ucm/store/ucmstore.py
+++ b/ucm/store/ucmstore.py
@@ -129,7 +129,7 @@ def fetch_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        dst_addr: List[int],
+        dst_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         """
@@ -150,7 +150,7 @@ def dump_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        src_addr: List[int],
+        src_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         """

From 4a91a1bb1236277296f3028b61a81778d17cf5e7 Mon Sep 17 00:00:00 2001
From: qyh
Date: Thu, 30 Oct 2025 10:50:19 +0800
Subject: [PATCH 3/4] Fix chunk index iteration in _extract_blocks

---
 ucm/integration/vllm/uc_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index e9edb442..71abbe7f 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -846,7 +846,7 @@ def _extract_blocks(

         block_mapping: dict[str, torch.Tensor] = {}
         vllm_block_ids = block_info.vllm_block_ids
-        for idx, vllm_block_id in enumerate(vllm_block_ids[start_pos * self.blocks_per_chunk :], start_pos * self.blocks_per_chunk):
+        for idx in range(start_pos * self.blocks_per_chunk, len(vllm_block_ids), self.blocks_per_chunk):
             chunk_idx = idx // self.blocks_per_chunk
             if chunk_idx >= len(block_info.block_hashes):
                 break

From 50a30740e9185972d62ddbe8055d7b9501e0f745 Mon Sep 17 00:00:00 2001
From: qyh
Date: Thu, 30 Oct 2025 20:59:21 +0800
Subject: [PATCH 4/4] Fix batch transfer buffers and add throughput logging

---
 ucm/integration/vllm/uc_connector.py          | 23 +++++++++++++++----
 .../nfsstore/cc/domain/trans/posix_queue.cc   | 12 +++++-----
 ucm/store/nfsstore/nfsstore_connector.py      |  4 ++--
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index 71abbe7f..fedcb950 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -27,6 +27,7 @@
 import pickle
 from dataclasses import dataclass, field
 from enum import Enum
+import time
 from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union

 import torch
@@ -187,10 +188,8 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
         # One block size
         k_min_data_block_size = (
            kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel()
-        ) * elem_size
-        v_min_data_block_size = (
-            kv_layer[1][0].numel() if not self.is_mla else 0
-        ) * elem_size
+        ) * elem_size * self.blocks_per_chunk
+        v_min_data_block_size = k_min_data_block_size
         # When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
         layer_size = (k_min_data_block_size + v_min_data_block_size) * (
             self.total_tp_size if not self.is_mla else 1
         )
@@ -295,6 +294,8 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         self.layerwise_load_tasks.clear()
         self.current_layer = 0

+        total_size = 0
+        start_time = time.perf_counter()
         for request in metadata.requests:
             if not request.load_blocks:
                 continue
@@ -310,6 +311,7 @@
                 pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
+                total_size += len(offsets) * self.io_size
                 size = [self.io_size] * blocks_len
                 k_task_id = self.connector.fetch_data(
                     storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len], size
@@ -354,6 +356,11 @@
                 if v_task and self.connector.wait(v_task) != 0:
                     self._load_failed_reqs.add(request.request_id)
                     break
+        end_time = time.perf_counter()
+        elapsed_time = end_time - start_time
+        throughput_gbps = (total_size / (1024**3)) / elapsed_time
+        if total_size > 0:
+            logger.info(f"LOAD: volume={(total_size / (1024**3)):.4f} GB, elapsed={elapsed_time:.4f} s, KV load transfer done: throughput={throughput_gbps:.4f} GB/s")

     def wait_for_layer_load(self, layer_name: str) -> None:
         """
@@ -489,6 +496,8 @@ def wait_for_tasks():
                 self.dump_tasks.clear()
                 return success_dumped_blocks if success_dumped_blocks else None

+        start_time = time.perf_counter()
+        total_size = 0
         for request in metadata.requests:
             if not request.dump_blocks:
                 continue
@@ -500,6 +509,7 @@
                 pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
+                total_size += len(offsets) * self.io_size
                 for block_id, offset, pointers in zip(
                     storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
                 ):
@@ -519,6 +529,11 @@
                     ).append(task)
         wait_for_tasks()
         self.dump_tasks.clear()
+        end_time = time.perf_counter()
+        elapsed_time = end_time - start_time
+        throughput_gbps = (total_size / (1024**3)) / elapsed_time
+        if total_size > 0:
+            logger.info(f"DUMP: volume={(total_size / (1024**3)):.4f} GB, elapsed={elapsed_time:.4f} s, KV dump transfer done: throughput={throughput_gbps:.4f} GB/s")
         return success_dumped_blocks if success_dumped_blocks else None

     def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
diff --git a/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
index 768388d4..728b1cec 100644
--- a/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
+++ b/ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
@@ -103,13 +103,13 @@ Status PosixQueue::D2S(Task::Shard& shard, const Device& device)
         return Status::OutOfMemory();
     }
     auto hub = shard.buffer.get();
-    auto dAddr = new std::byte*[shard.address.size()];
-    auto hAddr = new std::byte*[shard.address.size()];
+    std::vector<std::byte*> dAddr(shard.address.size());
+    std::vector<std::byte*> hAddr(shard.address.size());
     for (size_t i = 0; i < shard.address.size(); i++) {
         hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
         dAddr[i] = (std::byte*)shard.address[i];
     }
-    auto status = device->D2HBatchSync(hAddr, const_cast<const std::byte**>(dAddr), shard.address.size(), shard.length / shard.address.size());
+    auto status = device->D2HBatchSync(hAddr.data(), const_cast<const std::byte**>(dAddr.data()), shard.address.size(), shard.length / shard.address.size());
     if (status.Failure()) { return status; }
     auto path = this->layout_->DataFilePath(shard.block, true);
     return File::Write(path, shard.offset, shard.length, (uintptr_t)hub);
@@ -126,13 +126,13 @@ Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
     auto path = this->layout_->DataFilePath(shard.block, false);
     auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub);
     if (status.Failure()) { return status; }
-    auto dAddr = new std::byte*[shard.address.size()];
-    auto hAddr = new std::byte*[shard.address.size()];
+    std::vector<std::byte*> dAddr(shard.address.size());
+    std::vector<std::byte*> hAddr(shard.address.size());
     for (size_t i = 0; i < shard.address.size(); i++) {
         hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
         dAddr[i] = (std::byte*)shard.address[i];
     }
-    return device->H2DBatchSync(dAddr, const_cast<const std::byte**>(hAddr), shard.address.size(), shard.length / shard.address.size());
+    return device->H2DBatchSync(dAddr.data(), const_cast<const std::byte**>(hAddr.data()), shard.address.size(), shard.length / shard.address.size());
 }

 Status PosixQueue::H2S(Task::Shard& shard)
diff --git a/ucm/store/nfsstore/nfsstore_connector.py b/ucm/store/nfsstore/nfsstore_connector.py
index eaab0ad6..d9ce9b98 100644
--- a/ucm/store/nfsstore/nfsstore_connector.py
+++ b/ucm/store/nfsstore/nfsstore_connector.py
@@ -93,7 +93,7 @@ def fetch_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        dst_addr: List[int],
+        dst_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         task_id = self.store.LoadToDevice(block_ids, offset, dst_addr, size)
@@ -103,7 +103,7 @@ def dump_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        src_addr: List[int],
+        src_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         task_id = self.store.DumpFromDevice(block_ids, offset, src_addr, size)
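
[Editor's note — illustrative sketch, not part of the patch series] The core of this series is the chunk-to-block grouping that `_extract_blocks` performs, as corrected in PATCH 3: each chunk hash maps to the tensor of `blocks_per_chunk` vLLM block IDs backing it, and any trailing partial chunk is skipped. A standalone sketch with simplified names and hypothetical hash values (no vLLM imports):

```python
import torch

blocks_per_chunk = 4
block_hashes = ["h0", "h1", "h2"]   # one hash per 4-block chunk (hypothetical)
vllm_block_ids = list(range(10))    # only two full chunks fit
start_pos = 0                       # next unprocessed chunk index

block_mapping: dict[str, torch.Tensor] = {}
for idx in range(start_pos * blocks_per_chunk, len(vllm_block_ids), blocks_per_chunk):
    chunk_idx = idx // blocks_per_chunk
    # Stop at the first chunk that has no hash or is not fully backed by blocks.
    if chunk_idx >= len(block_hashes) or idx + blocks_per_chunk > len(vllm_block_ids):
        break
    block_mapping[block_hashes[chunk_idx]] = torch.tensor(
        vllm_block_ids[idx : idx + blocks_per_chunk]
    )

print(block_mapping)
# {'h0': tensor([0, 1, 2, 3]), 'h1': tensor([4, 5, 6, 7])}  -- 'h2' is skipped
```

Each tensor in `block_mapping` is what `ReqMeta.load_blocks`/`dump_blocks` now carry per chunk hash, which is why their element type changed from `int` to `torch.Tensor` in PATCH 1.
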