diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 4142aeccaa2..5c0cc7cc5ad 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -254,37 +254,38 @@ def launch_cache_manager( val_shape_str = str(val_cache_shape) val_cache_arg_str = f" --value_cache_shape {val_shape_str}" - for i in range(tensor_parallel_size): - launch_cmd = ( - "FLAGS_allocator_strategy=auto_growth " - + visible_devices - + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" - + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" - + f" {sys.executable} {py_path}" - + f" --device_id {int(device_ids[i])}" - + f" --rank {i}" - + f" --splitwise_role {self.splitwise_role}" - + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" - + f" --mp_num {tensor_parallel_size}" - + f" --cache_dtype {cache_config.cache_dtype}" - + f" --key_cache_shape {key_cache_shape}" - + val_cache_arg_str - + f" --cache_queue_port {cache_config.cache_queue_port}" - + f" --enable_splitwise {int(self.enable_splitwise)}" - + f" --pod_ip {pod_ip}" - + f" --engine_worker_queue_port {engine_worker_queue_port}" - + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" - + f" --engine_pid {pid_suffix}" - + f" --default_dtype '{self.config.model_config.dtype}'" - + f" --protocol {cache_config.cache_transfer_protocol}" - + f" --local_data_parallel_id {self.local_data_parallel_id}" - + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" - + f" --speculative_config '{self.speculative_config.to_json_string()}'" - + (" --create_cache_tensor" if create_cache_tensor else "") - + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1" - ) - logger.info(f"Launch cache transfer manager, command:{launch_cmd}") - cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) + if self.cache_config.enable_hierarchical_cache: + for i in range(tensor_parallel_size): + launch_cmd = ( + "FLAGS_allocator_strategy=auto_growth " + + visible_devices + + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + + f" {sys.executable} {py_path}" + + f" --device_id {int(device_ids[i])}" + + f" --rank {i}" + + f" --splitwise_role {self.splitwise_role}" + + f" --num_layers {cache_config.model_cfg.num_hidden_layers}" + + f" --mp_num {tensor_parallel_size}" + + f" --cache_dtype {cache_config.cache_dtype}" + + f" --key_cache_shape {key_cache_shape}" + + val_cache_arg_str + + f" --cache_queue_port {cache_config.cache_queue_port}" + + f" --enable_splitwise {int(self.enable_splitwise)}" + + f" --pod_ip {pod_ip}" + + f" --engine_worker_queue_port {engine_worker_queue_port}" + + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" + + f" --engine_pid {pid_suffix}" + + f" --default_dtype '{self.config.model_config.dtype}'" + + f" --protocol {cache_config.cache_transfer_protocol}" + + f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + + f" --speculative_config '{self.speculative_config.to_json_string()}'" + + (" --create_cache_tensor" if create_cache_tensor else "") + + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1" + ) + logger.info(f"Launch cache transfer manager, command:{launch_cmd}") + cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) logger.info("PrefixCacheManager is waiting for kv cache to be initialized.") while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: @@ -294,13 +295,14 @@ def launch_cache_manager( while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size: time.sleep(1) - exit_code = cache_manager_processes[-1].poll() - if exit_code is None: - logger.info("Launch cache transfer manager successful") - else: - logger.info( - "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information" - ) + if cache_manager_processes: + exit_code = cache_manager_processes[-1].poll() + if exit_code is None: + logger.info("Launch cache transfer manager successful") + else: + logger.info( + "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information" + ) # Start additional threads if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 7d387acc609..0998ce4a8b4 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -33,7 +33,6 @@ from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import ( IPCSignal, - KVCacheStatus, ModelWeightsStatus, PrefixTreeStatus, RearrangeExpertStatus, @@ -548,6 +547,28 @@ def update_model_weight(self, timeout=300): 2 : worker update finish and notify client """ with self.clear_update_lock: + if self.fd_config.cache_config.enable_hierarchical_cache: + return False, "hierarchical cache updating is not supported" + + # if self.enable_prefix_caching or self.enable_splitwise: + # # kv_cache_status_signal: CLEARED -> UPDATING -> NORMAL + # if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED: + # self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING + # api_server_logger.info(f"Start to update kv cache {self.kv_cache_status_signal.value[0]}") + # while self.kv_cache_status_signal.value[0] != KVCacheStatus.NORMAL: + # api_server_logger.info(f"..updating kv cache {self.kv_cache_status_signal.value[0]}") + # time.sleep(1) + + if self.enable_prefix_caching: + # prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL + if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED: + self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING + api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}") + while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL: + api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}") + time.sleep(1) + + # model_weights_status_signal: CLEARED -> UPDATING -> NORMAL if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL: return True, "" if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING: @@ -556,34 +577,13 @@ def update_model_weight(self, timeout=300): return False, "worker is clearing model weight, cannot update now" self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING - if self.enable_prefix_caching or self.enable_splitwise: - self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING - if self.enable_prefix_caching: - self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING - api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}") - all_updated = False - while timeout >= 0 and not all_updated: - api_server_logger.info( - f"Updating model weights.. " - f"model_weights_status: {self.model_weights_status_signal.value[0]}, " - f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, " - f"kv_cache_status: {self.kv_cache_status_signal.value[0]} " - ) - weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL - cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL - prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL - if self.enable_prefix_caching or self.enable_splitwise: - if self.enable_prefix_caching: - all_updated = weight_updated and cache_updated and prefix_updated - else: - all_updated = weight_updated and cache_updated - else: - all_updated = weight_updated + api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}") + while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL: + api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}") time.sleep(1) timeout -= 1 if timeout < 0: return False, "Update model weight timeout" - time.sleep(1) return True, "" def clear_load_weight(self, timeout=300): @@ -594,6 +594,27 @@ def clear_load_weight(self, timeout=300): """ with self.clear_update_lock: + if self.fd_config.cache_config.enable_hierarchical_cache: + return False, "hierarchical cache clearing is not supported" + # if self.enable_prefix_caching or self.enable_splitwise: + # # kv_cache_status_signal: NORMAL -> CLEARING -> CLEARED + # if self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL: + # self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING + # api_server_logger.info(f"Start to clear kv cache {self.kv_cache_status_signal.value[0]}") + # while self.kv_cache_status_signal.value[0] != KVCacheStatus.CLEARED: + # api_server_logger.info(f"..clearing kv cache {self.kv_cache_status_signal.value[0]}") + # time.sleep(1) + + if self.enable_prefix_caching: + # prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED + if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL: + self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING + api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}") + while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED: + api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}") + time.sleep(1) + + # model_weights_status_signal: NORMAL -> CLEARING -> CLEARED if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED: return True, "" if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING: @@ -602,36 +623,13 @@ def clear_load_weight(self, timeout=300): return False, "worker is updating model weight, cannot clear now" self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING - if self.enable_prefix_caching or self.enable_splitwise: - self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING - if self.enable_prefix_caching: - self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING - - api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}") - all_cleared = False - while timeout >= 0 and not all_cleared: - api_server_logger.info( - f"Clearing model weights.. " - f"model_weights_status: {self.model_weights_status_signal.value[0]}, " - f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, " - f"kv_cache_status: {self.kv_cache_status_signal.value[0]} " - ) - weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED - cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED - prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED - if self.enable_prefix_caching or self.enable_splitwise: - if self.enable_prefix_caching: - all_cleared = weight_cleared and cache_cleared and prefix_cleared - else: - all_cleared = weight_cleared and cache_cleared - else: - all_cleared = weight_cleared + api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}") + while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED: + api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}") time.sleep(1) timeout -= 1 - if timeout < 0: return False, "Clear model weight timeout" - time.sleep(1) return True, "" def check_model_weight_status(self):