Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 40 additions & 38 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,37 +254,38 @@ def launch_cache_manager(
val_shape_str = str(val_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"

for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --engine_pid {pid_suffix}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
if self.cache_config.enable_hierarchical_cache:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --engine_pid {pid_suffix}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
Expand All @@ -294,13 +295,14 @@ def launch_cache_manager(
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)

exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)
if cache_manager_processes:
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)

# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
Expand Down
100 changes: 49 additions & 51 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
RearrangeExpertStatus,
Expand Down Expand Up @@ -548,6 +547,28 @@ def update_model_weight(self, timeout=300):
2 : worker update finish and notify client
"""
with self.clear_update_lock:
if self.fd_config.cache_config.enable_hierarchical_cache:
return False, "hierarchical cache updating is not supported"

# if self.enable_prefix_caching or self.enable_splitwise:
# # kv_cache_status_signal: CLEARED -> UPDATING -> NORMAL
# if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED:
# self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
# api_server_logger.info(f"Start to update kv cache {self.kv_cache_status_signal.value[0]}")
# while self.kv_cache_status_signal.value[0] != KVCacheStatus.NORMAL:
# api_server_logger.info(f"..updating kv cache {self.kv_cache_status_signal.value[0]}")
# time.sleep(1)

if self.enable_prefix_caching:
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)

# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
Expand All @@ -556,34 +577,13 @@ def update_model_weight(self, timeout=300):
return False, "worker is clearing model weight, cannot update now"

self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
all_updated = False
while timeout >= 0 and not all_updated:
api_server_logger.info(
f"Updating model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_updated = weight_updated and cache_updated and prefix_updated
else:
all_updated = weight_updated and cache_updated
else:
all_updated = weight_updated
api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update model weight timeout"
time.sleep(1)
return True, ""

def clear_load_weight(self, timeout=300):
Expand All @@ -594,6 +594,27 @@ def clear_load_weight(self, timeout=300):
"""

with self.clear_update_lock:
if self.fd_config.cache_config.enable_hierarchical_cache:
return False, "hierarchical cache clearing is not supported"
# if self.enable_prefix_caching or self.enable_splitwise:
# # kv_cache_status_signal: NORMAL -> CLEARING -> CLEARED
# if self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL:
# self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
# api_server_logger.info(f"Start to clear kv cache {self.kv_cache_status_signal.value[0]}")
# while self.kv_cache_status_signal.value[0] != KVCacheStatus.CLEARED:
# api_server_logger.info(f"..clearing kv cache {self.kv_cache_status_signal.value[0]}")
# time.sleep(1)

if self.enable_prefix_caching:
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)

# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
Expand All @@ -602,36 +623,13 @@ def clear_load_weight(self, timeout=300):
return False, "worker is updating model weight, cannot clear now"

self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING

api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
all_cleared = False
while timeout >= 0 and not all_cleared:
api_server_logger.info(
f"Clearing model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_cleared = weight_cleared and cache_cleared and prefix_cleared
else:
all_cleared = weight_cleared and cache_cleared
else:
all_cleared = weight_cleared
api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1

if timeout < 0:
return False, "Clear model weight timeout"
time.sleep(1)
return True, ""

def check_model_weight_status(self):
Expand Down
Loading