Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,40 +265,41 @@ def launch_cache_manager(
else:
kvcache_storage_backend_str = "none"

for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.local_cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --ipc_suffix {ipc_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
+ f" --write_policy {cache_config.write_policy}"
+ f" --max_model_len {self.config.model_config.max_model_len}"
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.local_cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --ipc_suffix {ipc_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
+ f" --write_policy {cache_config.write_policy}"
+ f" --max_model_len {self.config.model_config.max_model_len}"
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
Expand All @@ -308,13 +309,14 @@ def launch_cache_manager(
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)

exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)
if cache_manager_processes:
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)

# Start additional threads
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
Expand Down
83 changes: 32 additions & 51 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
RearrangeExpertStatus,
Expand Down Expand Up @@ -529,6 +528,19 @@ def update_model_weight(self, timeout=300):
2 : worker update finish and notify client
"""
with self.clear_update_lock:
if self.fd_config.cache_config.swap_space:
return False, "hierarchical cache updating is not supported"

if self.enable_prefix_caching:
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)

# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
Expand All @@ -537,34 +549,13 @@ def update_model_weight(self, timeout=300):
return False, "worker is clearing model weight, cannot update now"

self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
all_updated = False
while timeout >= 0 and not all_updated:
api_server_logger.info(
f"Updating model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_updated = weight_updated and cache_updated and prefix_updated
else:
all_updated = weight_updated and cache_updated
else:
all_updated = weight_updated
api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update model weight timeout"
time.sleep(1)
return True, ""

def clear_load_weight(self, timeout=300):
Expand All @@ -575,6 +566,19 @@ def clear_load_weight(self, timeout=300):
"""

with self.clear_update_lock:
if self.fd_config.cache_config.swap_space:
return False, "hierarchical cache clearing is not supported"

if self.enable_prefix_caching:
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)

# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
Expand All @@ -583,36 +587,13 @@ def clear_load_weight(self, timeout=300):
return False, "worker is updating model weight, cannot clear now"

self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING

api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
all_cleared = False
while timeout >= 0 and not all_cleared:
api_server_logger.info(
f"Clearing model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_cleared = weight_cleared and cache_cleared and prefix_cleared
else:
all_cleared = weight_cleared and cache_cleared
else:
all_cleared = weight_cleared
api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1

if timeout < 0:
return False, "Clear model weight timeout"
time.sleep(1)
return True, ""

def check_model_weight_status(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/cache_manager/test_prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def _create_manager(
local_rdma_comm_ports=None,
kvcache_storage_backend=None,
write_policy="write_through",
swap_space=4,
)
model_config = SimpleNamespace(
num_attention_heads=1,
Expand Down Expand Up @@ -663,7 +664,6 @@ def test_launch_cache_messager_returns_none_when_process_fails(self):

def test_launch_cache_manager_formats_value_cache_shape(self):
manager = _create_manager()
manager.cache_config.enable_hierarchical_cache = False

captured = {}

Expand Down
Loading