
Commit fff9999
Add chunkSize
1 parent 9a69f5f commit fff9999

8 files changed: +101 −46 lines


ucm/integration/vllm/uc_connector.py

Lines changed: 62 additions & 30 deletions
@@ -113,6 +113,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.num_head = vllm_config.model_config.get_num_kv_heads(
             vllm_config.parallel_config
         )
+        self.chunk_size = 256
+        self.blocks_per_chunk = self.chunk_size // self.block_size
         self.head_size = vllm_config.model_config.get_head_size()
         if (
             self._vllm_config.kv_transfer_config is not None
@@ -139,8 +141,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             config_base
             * self.num_layers
             * (1 if self.is_mla else self.num_head * self.total_tp_size * 2)
-        )
-        config["io_size"] = config_base * (1 if self.is_mla else self.num_head)
+        ) * self.blocks_per_chunk
+        self.io_size = config_base * (1 if self.is_mla else self.num_head) * self.blocks_per_chunk
+        config["io_size"] = self.io_size
         logger.info(
             "kv_block_size = %d, io_size = %d,",
             config["kv_block_size"],
@@ -160,8 +163,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
                 "use_layerwise"
             ]
         )
-        self.chunk_size = 256
-        self.blocks_per_chunk = self.chunk_size // self.block_size

     def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
         for layer_name in forward_context.no_compile_layers:
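For intuition, here is a standalone sketch of the chunk arithmetic these three hunks introduce. The block size, head count, head size and dtype width are illustrative assumptions, and config_base is approximated as bytes per head per block; only chunk_size = 256 comes from the diff.

block_size = 128                              # assumed vLLM KV block size (tokens)
chunk_size = 256                              # fixed by this commit
blocks_per_chunk = chunk_size // block_size   # -> 2

num_head, head_size, dtype_bytes = 8, 128, 2  # assumed model / dtype
config_base = block_size * head_size * dtype_bytes   # approx. bytes per head per block

# Non-MLA branch: both kv_block_size and io_size scale by blocks_per_chunk,
# so one transfer now covers a whole 256-token chunk instead of a single block.
io_size = config_base * num_head * blocks_per_chunk
print(blocks_per_chunk, io_size)              # 2 524288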
@@ -207,25 +208,54 @@ def DataOffset(self, kv_layer, rank, layer_id, is_v):
             layer_size * layer_id + layer_size / self.total_tp_size * self.rank
         )

-    def get_tensor_and_offset_layerwise(
+    def get_pointers_and_offset_layerwise(
         self, vllm_block_ids_tensors: List[torch.Tensor], kv_layer: torch.Tensor, layer_name: str
+    ) -> tuple[List[List[int]], List[int]]:
+        k_pointer_lists = []
+        k_offsets = []
+        v_pointer_lists = []
+        v_offsets = []
+        layer_id = self._extract_layer_index(layer_name)
+
+        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+            vllm_block_ids = vllm_block_ids_tensor.tolist()
+            k_pointer_list = []
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
+            for vllm_block_id in vllm_block_ids:
+                if self.is_mla:
+                    k_pointer_list.append(kv_layer[vllm_block_id].data_ptr())
+                else:
+                    k_pointer_list.append(kv_layer[0][vllm_block_id].data_ptr())
+            k_pointer_lists.append(k_pointer_list)
+            k_offsets.append(k_data_offset)
+            if not self.is_mla:
+                v_pointer_list = []
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
+                for vllm_block_id in vllm_block_ids:
+                    v_pointer_list.append(kv_layer[1][vllm_block_id].data_ptr())
+                v_offsets.append(v_data_offset)
+                v_pointer_lists.append(v_pointer_list)
+        return k_pointer_lists + v_pointer_lists, k_offsets + v_offsets
+
+    def get_tensor_and_offset_layerwise(
+        self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[List[torch.Tensor], List[int]]:
         k_tensors = []
         k_offsets = []
         v_tensors = []
         v_offsets = []
         layer_id = self._extract_layer_index(layer_name)

-        for vllm_block_ids_tensor in vllm_block_ids_tensors:
+        for blk_id in vllm_block_ids:
             k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_tensors.append(kv_layer[vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[blk_id])
             else:
-                k_tensors.append(kv_layer[0][vllm_block_ids_tensor])
+                k_tensors.append(kv_layer[0][blk_id])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
                 v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
-                v_tensors.append(kv_layer[1][vllm_block_ids_tensor])
+                v_tensors.append(kv_layer[1][blk_id])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets

@@ -277,18 +307,20 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 vllm_block_ids_tensor.tolist()
             )
             for layer_name, kv_layer in self.kv_caches.items():
-                tensors, offsets = self.get_tensor_and_offset_layerwise(
+                pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                     vllm_block_ids_tensors, kv_layer, layer_name
                 )
-                k_task_id = self.connector.load(
-                    storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+                size = [self.io_size] * blocks_len
+                k_task_id = self.connector.fetch_data(
+                    storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len], size
                 )
                 v_task_id = None
                 if not self.is_mla:
-                    v_task_id = self.connector.load(
+                    v_task_id = self.connector.fetch_data(
                         storage_block_ids,
                         offsets[blocks_len:],
-                        tensors[blocks_len:],
+                        pointers_list[blocks_len:],
+                        size
                     )
                 if request.request_id not in self.layerwise_load_tasks:
                     self.layerwise_load_tasks[request.request_id] = {}
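A hedged sketch of the load path's new call shape: per 256-token chunk, the connector now receives one list of raw device pointers (one entry per vLLM block) plus a uniform io_size, instead of one tensor per block. The stub store below only mirrors the fetch_data signature from ucmstore.py; the tensor shape, pointer values, and io_size are illustrative placeholders.

from typing import List
import torch

class _StubStore:
    """Hypothetical stand-in for the UCM store; it only checks argument shapes."""
    def fetch_data(self, block_ids: List[str], offset: List[int],
                   dst_addr: List[List[int]], size: List[int]) -> str:
        assert len(block_ids) == len(offset) == len(dst_addr) == len(size)
        return f"load-task({len(block_ids)} chunks)"

# Dummy non-MLA KV layer: kv_layer[0] holds K blocks, kv_layer[1] holds V blocks.
kv_layer = torch.zeros(2, 16, 128, 2, 8)
io_size = 524_288                                   # assumed, see the __init__ sketch above
storage_block_ids = ["hash_a", "hash_b"]            # one id per 256-token chunk
vllm_block_ids_tensors = [torch.tensor([5, 8]), torch.tensor([12, 3])]
offsets = [0, 0, 0, 0]                              # K offsets then V offsets (values illustrative)

# One pointer list per chunk: K lists first, then V lists, as in the diff.
k_ptrs = [[kv_layer[0][b].data_ptr() for b in ids.tolist()] for ids in vllm_block_ids_tensors]
v_ptrs = [[kv_layer[1][b].data_ptr() for b in ids.tolist()] for ids in vllm_block_ids_tensors]
pointers_list = k_ptrs + v_ptrs

blocks_len = len(storage_block_ids)
size = [io_size] * blocks_len
k_task = _StubStore().fetch_data(storage_block_ids, offsets[:blocks_len],
                                 pointers_list[:blocks_len], size)
print(k_task)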
@@ -404,7 +436,7 @@ def save_kv_layer(
         storage_block_ids = [block[0] for block in request.dump_blocks]
         vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]  # [5, 8, 12]
         blocks_len = len(storage_block_ids)
-        tensors, offsets = self.get_tensor_and_offset_layerwise(
+        pointers_list, offsets = self.get_pointers_and_offset_layerwise(
             vllm_block_ids_tensors, kv_layer, layer_name
         )

@@ -413,18 +445,18 @@ def save_kv_layer(
         elif kv_layer[0].device.type == "cuda":
             torch.cuda.current_stream().synchronize()

-        for block_id, offset, tensor in zip(
-            storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+        for block_id, offset, pointers in zip(
+            storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
         ):
-            task = self.connector.dump([block_id], [offset], [tensor])
+            task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
             self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                 block_id, []
             ).append(task)
         if not self.is_mla:
-            for block_id, offset, tensor in zip(
-                storage_block_ids, offsets[blocks_len:], tensors[blocks_len:]
+            for block_id, offset, pointer_lists in zip(
+                storage_block_ids, offsets[blocks_len:], pointers_list[blocks_len:]
             ):
-                task = self.connector.dump([block_id], [offset], [tensor])
+                task = self.connector.dump_data([block_id], [offset], [pointer_lists], [self.io_size])
                 self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                     block_id, []
                 ).append(task)
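The save path mirrors the same shape but submits one chunk per dump_data call. A minimal self-contained sketch; dump_data_stub is hypothetical and only echoes the src_addr: List[List[int]] signature from ucmstore.py, and the pointer values are fabricated.

from typing import List

def dump_data_stub(block_ids: List[str], offset: List[int],
                   src_addr: List[List[int]], size: List[int]) -> str:
    # Hypothetical stand-in for connector.dump_data; one storage block per call here.
    return f"dump-task:{block_ids[0]}"

io_size = 524_288                                   # assumed chunk payload size
storage_block_ids = ["hash_a", "hash_b"]
offsets = [0, 0]                                    # K offsets for these two chunks
pointers_list = [[140_000_000, 140_524_288],        # fabricated device pointers,
                 [150_000_000, 150_524_288]]        # one list per chunk

dump_tasks: dict = {}
for block_id, off, pointers in zip(storage_block_ids, offsets, pointers_list):
    task = dump_data_stub([block_id], [off], [pointers], [io_size])
    dump_tasks.setdefault("req-0", {}).setdefault(block_id, []).append(task)
print(dump_tasks)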
@@ -465,23 +497,23 @@ def wait_for_tasks():
         vllm_block_ids_tensors = [block[1] for block in request.dump_blocks]
         blocks_len = len(storage_block_ids)
         for layer_name, kv_layer in self.kv_caches.items():
-            tensors, offsets = self.get_tensor_and_offset_layerwise(
+            pointers_list, offsets = self.get_pointers_and_offset_layerwise(
                 vllm_block_ids_tensors, kv_layer, layer_name
             )
-            for block_id, offset, tensor in zip(
-                storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
+            for block_id, offset, pointers in zip(
+                storage_block_ids, offsets[:blocks_len], pointers_list[:blocks_len]
             ):
-                task = self.connector.dump([block_id], [offset], [tensor])
+                task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                 self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                     block_id, []
                 ).append(task)
             if not self.is_mla:
-                for block_id, offset, tensor in zip(
+                for block_id, offset, pointers in zip(
                     storage_block_ids,
                     offsets[blocks_len:],
-                    tensors[blocks_len:],
+                    pointers_list[blocks_len:],
                 ):
-                    task = self.connector.dump([block_id], [offset], [tensor])
+                    task = self.connector.dump_data([block_id], [offset], [pointers], [self.io_size])
                     self.dump_tasks.setdefault(request.request_id, {}).setdefault(
                         block_id, []
                     ).append(task)
@@ -633,7 +665,7 @@ def hash_request_tokens(
                 start_position=start_position,
             )
             self._need_load_reqs[request.request_id] = []
-            return num_lookup_hits * self.block_size, True
+            return num_lookup_hits * self.block_size - num_computed_tokens, True

         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
@@ -650,7 +682,7 @@ def hash_request_tokens(
                 start_position=start_position,
            )

-        return num_lookup_hits * self.block_size, False
+        return num_lookup_hits * self.block_size - num_computed_tokens, False

     def update_state_after_alloc(
         self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
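A worked example of the adjusted return value (all numbers made up): subtracting num_computed_tokens keeps the reported external-hit count from including tokens vLLM has already computed locally.

block_size = 128
num_lookup_hits = 6          # blocks found in the external store
num_computed_tokens = 256    # tokens already computed in HBM for this request

external_tokens_before = num_lookup_hits * block_size                        # 768
external_tokens_after = num_lookup_hits * block_size - num_computed_tokens   # 512
print(external_tokens_before, external_tokens_after)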

ucm/store/dramstore/cpy/dramstore.py.cc

Lines changed: 5 additions & 1 deletion
@@ -78,8 +78,12 @@ class DRAMStorePy : public DRAMStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;
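From the Python side this means each element of the addresses argument is now itself a list, which the binding walks into a std::vector<uintptr_t> before Task::Append. A call-shape sketch with fabricated pointer values (the same pattern applies to the localstore and nfsstore bindings below):

block_ids = ["hash_a", "hash_b"]
offsets   = [0, 0]
addresses = [[140_000_000, 140_524_288],   # was a flat List[int] before this commit
             [150_000_000, 150_524_288]]
lengths   = [1_048_576, 1_048_576]

for block_id, offset, addr_list, length in zip(block_ids, offsets, addresses, lengths):
    # The C++ loop does the equivalent of this per-chunk conversion.
    addr_vec = [int(a) for a in addr_list]
    assert all(isinstance(a, int) for a in addr_vec)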

ucm/store/localstore/cpy/localstore.py.cc

Lines changed: 5 additions & 1 deletion
@@ -78,8 +78,12 @@ class LocalStorePy : public LocalStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;

ucm/store/nfsstore/cc/domain/trans/posix_queue.cc

Lines changed: 18 additions & 6 deletions
@@ -103,7 +103,13 @@ Status PosixQueue::D2S(Task::Shard& shard, const Device& device)
         return Status::OutOfMemory();
     }
     auto hub = shard.buffer.get();
-    auto status = device->D2HSync((std::byte*)hub, (std::byte*)shard.address, shard.length);
+    auto dAddr = new std::byte*[shard.address.size()];
+    auto hAddr = new std::byte*[shard.address.size()];
+    for (size_t i = 0; i < shard.address.size(); i++) {
+        hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
+        dAddr[i] = (std::byte*)shard.address[i];
+    }
+    auto status = device->D2HBatchSync(hAddr, const_cast<const std::byte**>(dAddr), shard.address.size(), shard.length / shard.address.size());
     if (status.Failure()) { return status; }
     auto path = this->layout_->DataFilePath(shard.block, true);
     return File::Write(path, shard.offset, shard.length, (uintptr_t)hub);
@@ -120,21 +126,27 @@ Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
     auto path = this->layout_->DataFilePath(shard.block, false);
     auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub);
     if (status.Failure()) { return status; }
-    return device->H2DAsync((std::byte*)shard.address, (std::byte*)hub, shard.length);
+    auto dAddr = new std::byte*[shard.address.size()];
+    auto hAddr = new std::byte*[shard.address.size()];
+    for (size_t i = 0; i < shard.address.size(); i++) {
+        hAddr[i] = (std::byte*)hub + i * shard.length / shard.address.size();
+        dAddr[i] = (std::byte*)shard.address[i];
+    }
+    return device->H2DBatchSync(dAddr, const_cast<const std::byte**>(hAddr), shard.address.size(), shard.length / shard.address.size());
 }

 Status PosixQueue::H2S(Task::Shard& shard)
 {
     auto path = this->layout_->DataFilePath(shard.block, true);
-    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address);
-    return File::Write(path, shard.offset, shard.length, shard.address, aligned);
+    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address[0]);
+    return File::Write(path, shard.offset, shard.length, shard.address[0], aligned);
 }

 Status PosixQueue::S2H(Task::Shard& shard)
 {
     auto path = this->layout_->DataFilePath(shard.block, false);
-    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address);
-    return File::Read(path, shard.offset, shard.length, shard.address, aligned);
+    auto aligned = IsAligned(shard.offset) && IsAligned(shard.length) && IsAligned(shard.address[0]);
+    return File::Read(path, shard.offset, shard.length, shard.address[0], aligned);
 }

 } // namespace UC
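The batched copy splits each shard's staging buffer evenly across its device addresses: with n = shard.address.size(), piece i of shard.length / n bytes is paired with device pointer i (this assumes shard.length divides evenly by n). A small sketch of that arithmetic with made-up numbers:

shard_length = 1_048_576                      # bytes staged for one shard
device_addrs = [140_000_000, 140_524_288]     # fabricated device pointers

n = len(device_addrs)
piece = shard_length // n                     # bytes copied per device block
host_offsets = [i * piece for i in range(n)]  # mirrors hAddr[i] = hub + i * piece

for off, dev in zip(host_offsets, device_addrs):
    print(f"copy {piece} bytes: host+{off} <-> device 0x{dev:x}")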

ucm/store/nfsstore/cpy/nfsstore.py.cc

Lines changed: 5 additions & 3 deletions
@@ -91,8 +91,12 @@ class NFSStorePy : public NFSStore {
         auto length = lengths.begin();
         while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
                (address != addresses.end()) && (length != lengths.end())) {
+            std::vector<uintptr_t> addr_vec;
+            for (auto addr_item : address->cast<py::list>()) {
+                addr_vec.push_back(addr_item.cast<uintptr_t>());
+            }
             task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
-                        address->cast<uintptr_t>(), length->cast<size_t>());
+                        std::move(addr_vec), length->cast<size_t>());
             blockId++;
             offset++;
             address++;
@@ -123,8 +127,6 @@ PYBIND11_MODULE(ucmnfsstore, module)
     config.def_readwrite("transferBufferNumber", &UC::NFSStorePy::Config::transferBufferNumber);
     config.def_readwrite("transferTimeoutMs", &UC::NFSStorePy::Config::transferTimeoutMs);
     config.def_readwrite("tempDumpDirEnable", &UC::NFSStorePy::Config::tempDumpDirEnable);
-    config.def_readwrite("hotnessEnable", &UC::NFSStorePy::Config::hotnessEnable);
-    config.def_readwrite("hotnessInterval", &UC::NFSStorePy::Config::hotnessInterval);
     store.def(py::init<>());
     store.def("CCStoreImpl", &UC::NFSStorePy::CCStoreImpl);
     store.def("Setup", &UC::NFSStorePy::Setup);

ucm/store/nfsstore/nfsstore_connector.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ def __init__(self, config: Dict):
         if transfer_enable:
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
+            param.transferStreamNumber = config.get("transfer_stream_number", 128)
         ret = self.store.Setup(param)
         if ret != 0:
             msg = f"Failed to initialize ucmnfsstore, errcode: {ret}."
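The connector now also forwards a stream count to the NFS store. The "transfer_stream_number" key and its default of 128 come from the diff; the surrounding values in this sketch are placeholders.

config = {
    "device": 0,
    "io_size": 524_288,
    # omit "transfer_stream_number" to fall back to the default of 128
    "transfer_stream_number": 64,
}
transfer_stream_number = config.get("transfer_stream_number", 128)
print(transfer_stream_number)   # 64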

ucm/store/task/task_shard.h

Lines changed: 3 additions & 3 deletions
@@ -47,13 +47,13 @@ class Task {
         Location location;
         std::string block;
         size_t offset;
-        uintptr_t address;
+        std::vector<uintptr_t> address;
         size_t length;
         size_t owner;
         std::shared_ptr<void> buffer;
         std::function<void(void)> done;
         Shard(const Type type, const Location location, const std::string& block,
-              const size_t offset, const uintptr_t address, const size_t length, const size_t owner)
+              const size_t offset, const std::vector<uintptr_t> address, const size_t length, const size_t owner)
             : type{type}, location{location}, block{block}, offset{offset}, address{address},
               length{length}, owner{owner}, buffer{nullptr}, done{nullptr}
         {
@@ -86,7 +86,7 @@ class Task {
     auto Id() const noexcept { return id_; }
     auto StartTp() const noexcept { return startTp_; }
     auto Str() const noexcept { return fmt::format("{},{},{},{}", id_, brief_, number_, size_); }
-    void Append(const std::string& block, const size_t offset, const uintptr_t address,
+    void Append(const std::string& block, const size_t offset, const std::vector<uintptr_t> address,
                 const size_t length)
     {
         shards_.emplace_back(type_, location_, block, offset, address, length, id_);

ucm/store/ucmstore.py

Lines changed: 2 additions & 2 deletions
@@ -129,7 +129,7 @@ def fetch_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        dst_addr: List[int],
+        dst_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         """
@@ -150,7 +150,7 @@ def dump_data(
         self,
         block_ids: List[str],
         offset: List[int],
-        src_addr: List[int],
+        src_addr: List[List[int]],
         size: List[int],
     ) -> Task:
         """
