diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index 91b23a29717..cc27e958ed0 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -265,40 +265,41 @@ def launch_cache_manager(
         else:
             kvcache_storage_backend_str = "none"
 
-        for i in range(tensor_parallel_size):
-            launch_cmd = (
-                "FLAGS_allocator_strategy=auto_growth "
-                + visible_devices
-                + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
-                + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
-                + f" {sys.executable} {py_path}"
-                + f" --device_id {int(device_ids[i])}"
-                + f" --rank {i}"
-                + f" --splitwise_role {self.splitwise_role}"
-                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
-                + f" --mp_num {tensor_parallel_size}"
-                + f" --cache_dtype {cache_config.cache_dtype}"
-                + f" --key_cache_shape {key_cache_shape}"
-                + val_cache_arg_str
-                + f" --cache_queue_port {cache_config.local_cache_queue_port}"
-                + f" --enable_splitwise {int(self.enable_splitwise)}"
-                + f" --pod_ip {pod_ip}"
-                + f" --engine_worker_queue_port {engine_worker_queue_port}"
-                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
-                + f" --ipc_suffix {ipc_suffix}"
-                + f" --protocol {cache_config.cache_transfer_protocol}"
-                + f" --local_data_parallel_id {self.local_data_parallel_id}"
-                + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
-                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + f" --default_dtype '{self.config.model_config.dtype}'"
-                + (" --create_cache_tensor" if create_cache_tensor else "")
-                + f" --kvcache_storage_backend {kvcache_storage_backend_str}"
-                + f" --write_policy {cache_config.write_policy}"
-                + f" --max_model_len {self.config.model_config.max_model_len}"
-                + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
-            )
-            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
-            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+        if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
+            for i in range(tensor_parallel_size):
+                launch_cmd = (
+                    "FLAGS_allocator_strategy=auto_growth "
+                    + visible_devices
+                    + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+                    + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+                    + f" {sys.executable} {py_path}"
+                    + f" --device_id {int(device_ids[i])}"
+                    + f" --rank {i}"
+                    + f" --splitwise_role {self.splitwise_role}"
+                    + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+                    + f" --mp_num {tensor_parallel_size}"
+                    + f" --cache_dtype {cache_config.cache_dtype}"
+                    + f" --key_cache_shape {key_cache_shape}"
+                    + val_cache_arg_str
+                    + f" --cache_queue_port {cache_config.local_cache_queue_port}"
+                    + f" --enable_splitwise {int(self.enable_splitwise)}"
+                    + f" --pod_ip {pod_ip}"
+                    + f" --engine_worker_queue_port {engine_worker_queue_port}"
+                    + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+                    + f" --ipc_suffix {ipc_suffix}"
+                    + f" --protocol {cache_config.cache_transfer_protocol}"
+                    + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                    + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
+                    + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                    + f" --default_dtype '{self.config.model_config.dtype}'"
+                    + (" --create_cache_tensor" if create_cache_tensor else "")
+                    + f" --kvcache_storage_backend {kvcache_storage_backend_str}"
+                    + f" --write_policy {cache_config.write_policy}"
+                    + f" --max_model_len {self.config.model_config.max_model_len}"
+                    + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
+                )
+                logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
+                cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
 
         logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
@@ -308,13 +309,14 @@ def launch_cache_manager(
             while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
                 time.sleep(1)
 
-        exit_code = cache_manager_processes[-1].poll()
-        if exit_code is None:
-            logger.info("Launch cache transfer manager successful")
-        else:
-            logger.info(
-                "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
-            )
+        if cache_manager_processes:
+            exit_code = cache_manager_processes[-1].poll()
+            if exit_code is None:
+                logger.info("Launch cache transfer manager successful")
+            else:
+                logger.info(
+                    "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
+                )
 
         # Start additional threads
         if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
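With this change, transfer-manager processes are spawned only when a CPU swap space or an external KV-cache storage backend is configured, so `cache_manager_processes` may legitimately be empty and the final health check must be guarded. A minimal sketch of the launch-and-poll pattern; `CacheConfig`, `launch_transfer_managers`, and the placeholder command are hypothetical stand-ins, not the FastDeploy API:

```python
import subprocess
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CacheConfig:
    swap_space: int = 0  # hypothetical: CPU swap space size; 0 disables offload
    kvcache_storage_backend: Optional[str] = None  # hypothetical: remote KV store name


def launch_transfer_managers(cfg: CacheConfig, tp_size: int) -> List[subprocess.Popen]:
    processes: List[subprocess.Popen] = []
    # Only spawn per-rank transfer managers when some offload target exists.
    if cfg.swap_space or cfg.kvcache_storage_backend:
        for _rank in range(tp_size):
            processes.append(subprocess.Popen(["sleep", "1"]))  # placeholder command
    # Guard the health check: with no offload target, the list is empty.
    if processes:
        if processes[-1].poll() is None:
            print("Launch cache transfer manager successful")
        else:
            print("Launch cache transfer manager failed")
    return processes
```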
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 9babe8fec74..96af0dc6540 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -34,7 +34,6 @@
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (
     IPCSignal,
-    KVCacheStatus,
     ModelWeightsStatus,
     PrefixTreeStatus,
     RearrangeExpertStatus,
@@ -529,6 +528,19 @@ def update_model_weight(self, timeout=300):
             2 : worker update finish and notify client
         """
         with self.clear_update_lock:
+            if self.fd_config.cache_config.swap_space:
+                return False, "hierarchical cache updating is not supported"
+
+            if self.enable_prefix_caching:
+                # prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
+                if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
+                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
+                api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
+                while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
+                    api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
+                    time.sleep(1)
+
+            # model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
                 return True, ""
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
@@ -537,34 +549,13 @@ def update_model_weight(self, timeout=300):
                 return False, "worker is clearing model weight, cannot update now"
 
             self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
-            if self.enable_prefix_caching or self.enable_splitwise:
-                self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
-                if self.enable_prefix_caching:
-                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
-            api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
-            all_updated = False
-            while timeout >= 0 and not all_updated:
-                api_server_logger.info(
-                    f"Updating model weights.. "
-                    f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
-                    f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
-                    f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
-                )
-                weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
-                cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
-                prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
-                if self.enable_prefix_caching or self.enable_splitwise:
-                    if self.enable_prefix_caching:
-                        all_updated = weight_updated and cache_updated and prefix_updated
-                    else:
-                        all_updated = weight_updated and cache_updated
-                else:
-                    all_updated = weight_updated
+            api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
+            while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
+                api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
                 time.sleep(1)
                 timeout -= 1
             if timeout < 0:
                 return False, "Update model weight timeout"
-            time.sleep(1)
             return True, ""
 
     def clear_load_weight(self, timeout=300):
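The rewritten `update_model_weight` drops the aggregated `KVCacheStatus` bookkeeping and instead sequences two independent shared-status handshakes: the prefix tree first (CLEARED -> UPDATING -> NORMAL), then the model weights. A hedged sketch of the wait-until-status pattern this relies on; `Status`, `Signal`, and `wait_for` are illustrative stand-ins, not FastDeploy's `IPCSignal` API:

```python
import time
from enum import IntEnum


class Status(IntEnum):
    """Illustrative stand-ins for ModelWeightsStatus / PrefixTreeStatus values."""
    CLEARED = 0
    CLEARING = 1
    UPDATING = 2
    NORMAL = 3


class Signal:
    """Stand-in for an IPC signal: a one-element shared status word."""
    def __init__(self, status: Status):
        self.value = [status]


def wait_for(signal: Signal, target: Status, timeout: int = 300) -> bool:
    """Poll the shared status word until the worker flips it to `target`."""
    while timeout >= 0:
        if signal.value[0] == target:
            return True
        time.sleep(1)
        timeout -= 1
    return False


# Update path: the client requests UPDATING, then blocks until the worker
# acknowledges by writing NORMAL back into the shared slot.
weights = Signal(Status.CLEARED)
weights.value[0] = Status.UPDATING
weights.value[0] = Status.NORMAL  # simulated worker acknowledgement
assert wait_for(weights, Status.NORMAL, timeout=5)
```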
" - f"model_weights_status: {self.model_weights_status_signal.value[0]}, " - f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, " - f"kv_cache_status: {self.kv_cache_status_signal.value[0]} " - ) - weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL - cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL - prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL - if self.enable_prefix_caching or self.enable_splitwise: - if self.enable_prefix_caching: - all_updated = weight_updated and cache_updated and prefix_updated - else: - all_updated = weight_updated and cache_updated - else: - all_updated = weight_updated + api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}") + while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL: + api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}") time.sleep(1) timeout -= 1 if timeout < 0: return False, "Update model weight timeout" - time.sleep(1) return True, "" def clear_load_weight(self, timeout=300): @@ -575,6 +566,19 @@ def clear_load_weight(self, timeout=300): """ with self.clear_update_lock: + if self.fd_config.cache_config.swap_space: + return False, "hierarchical cache clearing is not supported" + + if self.enable_prefix_caching: + # prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED + if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL: + self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING + api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}") + while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED: + api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}") + time.sleep(1) + + # model_weights_status_signal: NORMAL -> CLEARING -> CLEARED if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED: return True, "" if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING: @@ -583,36 +587,13 @@ def clear_load_weight(self, timeout=300): return False, "worker is updating model weight, cannot clear now" self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING - if self.enable_prefix_caching or self.enable_splitwise: - self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING - if self.enable_prefix_caching: - self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING - - api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}") - all_cleared = False - while timeout >= 0 and not all_cleared: - api_server_logger.info( - f"Clearing model weights.. 
" - f"model_weights_status: {self.model_weights_status_signal.value[0]}, " - f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, " - f"kv_cache_status: {self.kv_cache_status_signal.value[0]} " - ) - weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED - cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED - prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED - if self.enable_prefix_caching or self.enable_splitwise: - if self.enable_prefix_caching: - all_cleared = weight_cleared and cache_cleared and prefix_cleared - else: - all_cleared = weight_cleared and cache_cleared - else: - all_cleared = weight_cleared + api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}") + while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED: + api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}") time.sleep(1) timeout -= 1 - if timeout < 0: return False, "Clear model weight timeout" - time.sleep(1) return True, "" def check_model_weight_status(self): diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index a834ff75843..df2b52d8abb 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -182,6 +182,7 @@ def _create_manager( local_rdma_comm_ports=None, kvcache_storage_backend=None, write_policy="write_through", + swap_space=4, ) model_config = SimpleNamespace( num_attention_heads=1, @@ -663,7 +664,6 @@ def test_launch_cache_messager_returns_none_when_process_fails(self): def test_launch_cache_manager_formats_value_cache_shape(self): manager = _create_manager() - manager.cache_config.enable_hierarchical_cache = False captured = {}