From 66ad2df28a1da5a31d1d3b86092381f0e6222ce6 Mon Sep 17 00:00:00 2001
From: ddchenhao66
Date: Sun, 4 Jan 2026 11:54:05 +0000
Subject: [PATCH] [XPU] Support ep4tp1 in PD disaggregation

---
 custom_ops/xpu_ops/src/ops/get_output.cc    |  30 +-
 custom_ops/xpu_ops/src/ops/pybind/pybind.cc |   6 +-
 fastdeploy/worker/xpu_model_runner.py       |   6 +-
 tests/xpu_ci/test_pd_21b_tp1ep4.py          | 327 ++++++++++++++++++++
 4 files changed, 355 insertions(+), 14 deletions(-)
 create mode 100644 tests/xpu_ci/test_pd_21b_tp1ep4.py

diff --git a/custom_ops/xpu_ops/src/ops/get_output.cc b/custom_ops/xpu_ops/src/ops/get_output.cc
index e2cf48aab42..825814ff48f 100644
--- a/custom_ops/xpu_ops/src/ops/get_output.cc
+++ b/custom_ops/xpu_ops/src/ops/get_output.cc
@@ -20,11 +20,11 @@
 #include "msg_utils.h"
 #include "paddle/extension.h"
 
-void GetOutputKVSignal(const paddle::Tensor &x,
+void GetOutputKVSignal(const paddle::Tensor& x,
                        int64_t rank_id,
                        bool wait_flag) {
   int msg_queue_id = 1024;
-  if (const char *msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+  if (const char* msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
     std::string msg_que_str(msg_que_str_tmp);
     msg_queue_id = std::stoi(msg_que_str);
   }
@@ -33,7 +33,7 @@ void GetOutputKVSignal(const paddle::Tensor &x,
   static key_t key = ftok("/opt/", msg_queue_id);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
-  int *out_data = const_cast<int *>(x.data<int>());
+  int* out_data = const_cast<int*>(x.data<int>());
   int ret = -1;
   if (!wait_flag) {
     ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
@@ -53,15 +53,12 @@ void GetOutputKVSignal(const paddle::Tensor &x,
   return;
 }
 
-void GetOutput(const paddle::Tensor &x,
+void GetOutput(const paddle::Tensor& x,
                int64_t rank_id,
                bool wait_flag,
                int msg_queue_id) {
-  if (rank_id > 0) {
-    return;
-  }
   static struct msgdata msg_rcv;
-  if (const char *inference_msg_queue_id_env_p =
+  if (const char* inference_msg_queue_id_env_p =
           std::getenv("INFERENCE_MSG_QUEUE_ID")) {
     std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
     int inference_msg_queue_id_from_env =
@@ -82,7 +79,7 @@ void GetOutput(const paddle::Tensor &x,
   std::cout << "get_output wait_flag: " << wait_flag << std::endl;
 #endif
 
-  int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
+  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
   int ret = -1;
   if (!wait_flag) {
     ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
@@ -110,11 +107,20 @@ void GetOutput(const paddle::Tensor &x,
   return;
 }
 
-void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) {
+void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
+  if (rank_id > 0) {
+    return;
+  }
+  GetOutput(x, rank_id, wait_flag, 1);
+}
+
+void GetOutputEPStatic(const paddle::Tensor& x,
+                       int64_t rank_id,
+                       bool wait_flag) {
   GetOutput(x, rank_id, wait_flag, 1);
 }
 
-void GetOutputDynamic(const paddle::Tensor &x,
+void GetOutputDynamic(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag,
                       int msg_queue_id) {
@@ -140,7 +146,7 @@ PD_BUILD_OP(get_output_ep)
     .Attrs({"rank_id: int64_t", "wait_flag: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
-    .SetKernelFn(PD_KERNEL(GetOutputStatic));
+    .SetKernelFn(PD_KERNEL(GetOutputEPStatic));
 
 PD_BUILD_OP(get_output_ep_dynamic)
     .Inputs({"x"})
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index 23817415772..a2631bab058 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -362,6 +362,10 @@ std::vector<paddle::Tensor> GetInferParam(
 void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
 
+void GetOutputEPStatic(const paddle::Tensor& x,
+                       int64_t rank_id,
+                       bool wait_flag);
+
 void GetOutputDynamic(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag,
@@ -839,7 +843,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         "get_output function");
 
   m.def("get_output_ep",
-        &GetOutputStatic,
+        &GetOutputEPStatic,
         py::arg("x"),
         py::arg("rank_id"),
         py::arg("wait_flag"),
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index e02f7b6dcfe..65511f98776 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -173,6 +173,10 @@ def __init__(
         # Forward meta store the global meta information of the forward
         self.forward_meta: ForwardMeta = None
 
+        # Postprocess env params: each worker exposes its own message queue id
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)
+        logger.info(f"queue id is {self.parallel_config.local_engine_worker_queue_port}")
+
         self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode
 
         # Initialize ZMQ client for async output
@@ -1439,7 +1443,7 @@ class at the server level, which is too granular for ModelRunner.
             # speculative decoding
             full_hidden_states=model_output if self.speculative_decoding else None,
             msg_queue_id=self.parallel_config.msg_queue_id,
-            mp_rank=self.local_rank,
+            mp_rank=self.parallel_config.tensor_parallel_rank,
             use_ep=self.parallel_config.use_ep,
             draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
             actual_draft_token_num=(
diff --git a/tests/xpu_ci/test_pd_21b_tp1ep4.py b/tests/xpu_ci/test_pd_21b_tp1ep4.py
new file mode 100644
index 00000000000..b4aad965cf8
--- /dev/null
+++ b/tests/xpu_ci/test_pd_21b_tp1ep4.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+PD disaggregation test - Prefill/Decode disaggregated deployment mode
+
+Test configuration:
+- Model: ERNIE-4.5-21B-A3B-Paddle
+- Quantization: wint4
+- Tensor Parallel: 1, Expert Parallel: 4 (EP4TP1)
+- Features: splitwise PD disaggregation, RDMA cache transfer
+- Nodes: Router + Prefill node + Decode node
+"""
+
+import os
+import shutil
+import subprocess
+import time
+
+import openai
+import pytest
+from conftest import (
+    cleanup_resources,
+    get_model_path,
+    get_port_num,
+    restore_pd_ep_env,
+    setup_pd_ep_env,
+    stop_processes,
+)
+
+
+def wait_for_pd_health_check(port_p, port_d, timeout=600, interval=10):
+    """
+    Wait until the PD-disaggregated service passes its health check (both the P and D nodes).
+
+    Args:
+        port_p: Prefill node port
+        port_d: Decode node port
+        timeout: timeout in seconds, defaults to 10 minutes
+        interval: polling interval in seconds, defaults to 10 seconds
+
+    Returns:
+        bool: whether the service came up successfully
+    """
+    endpoint_p = f"http://0.0.0.0:{port_p}/health"
+    endpoint_d = f"http://0.0.0.0:{port_d}/health"
+    start_time = time.time()
+
+    print(f"Starting the PD disaggregation + EP4TP1 health check, waiting up to {timeout} seconds")
+
+    while True:
+        elapsed = int(time.time() - start_time)
+
+        # Timeout check
+        if elapsed >= timeout:
+            print(f"\nPD service startup timed out: still not up after {timeout//60} minutes!")
+            return False
+
+        # Check the P node
+        try:
+            result_p = subprocess.run(
+                f'curl -s -o /dev/null -w "%{{http_code}}" -m 2 {endpoint_p}',
+                shell=True,
+                capture_output=True,
+                text=True,
+            )
+            http_code_p = result_p.stdout.strip()
+        except Exception:
+            http_code_p = "000"
+
+        # Check the D node
+        try:
+            result_d = subprocess.run(
+                f'curl -s -o /dev/null -w "%{{http_code}}" -m 2 {endpoint_d}',
+                shell=True,
+                capture_output=True,
+                text=True,
+            )
+            http_code_d = result_d.stdout.strip()
+        except Exception:
+            http_code_d = "000"
+
+        print(
+            f"\rHealth check in progress... waited {elapsed}s, P node HTTP {http_code_p}, D node HTTP {http_code_d}",
+            end="",
+            flush=True,
+        )
+
+        if http_code_p == "200" and http_code_d == "200":
+            print(f"\nPD-disaggregated service is up! Took {elapsed} seconds")
+            return True
+
+        time.sleep(interval)
+
+
+def print_pd_logs_on_failure():
+    """Print the PD-related logs when the test fails"""
+    log_dirs = ["log_router", "log_prefill", "log_decode"]
+
+    for log_dir in log_dirs:
+        nohup_path = os.path.join(log_dir, "nohup")
+        if os.path.exists(nohup_path):
+            print(f"\n========== {nohup_path} ==========")
+            with open(nohup_path, "r") as f:
+                print(f.read())
+
+
+def start_pd_server(model_path, port_num, wait_before_check=60):
+    """
+    Start the PD-disaggregated service (Router + Prefill node + Decode node).
+
+    Args:
+        model_path: model path
+        port_num: base port number
+        wait_before_check: seconds to wait after launch before the health check, defaults to 60
+
+    Returns:
+        bool: whether the service came up successfully
+    """
+
+    # Stop stale processes
+    stop_processes()
+
+    # Clean up resources
+    cleanup_resources()
+
+    # Remove and recreate the log directories
+    for log_dir in ["log_router", "log_prefill", "log_decode"]:
+        if os.path.exists(log_dir):
+            shutil.rmtree(log_dir)
+        os.makedirs(log_dir, exist_ok=True)
+
+    # 1. Start the Router
+    print("Starting the Router...")
+    router_env = os.environ.copy()
+    router_env["FD_LOG_DIR"] = "log_router"
+    router_cmd = [
+        "python",
+        "-m",
+        "fastdeploy.router.launch",
+        "--port",
+        str(port_num),
+        "--splitwise",
+    ]
+
+    with open("log_router/nohup", "w") as log_file:
+        subprocess.Popen(router_cmd, stdout=log_file, stderr=subprocess.STDOUT, start_new_session=True, env=router_env)
+    print(f"Router launch command: {' '.join(router_cmd)}")
+    time.sleep(1)
+
+    # 2. Start the Prefill node
+    print("Starting the Prefill node...")
+    prefill_env = os.environ.copy()
+    prefill_env["FD_LOG_DIR"] = "log_prefill"
+    prefill_env["XPU_VISIBLE_DEVICES"] = "0,1,2,3"
+
+    prefill_cmd = [
+        "python",
+        "-m",
+        "fastdeploy.entrypoints.openai.multi_api_server",
+        "--port",
+        f"{port_num + 11},{port_num + 12},{port_num + 13},{port_num + 14}",
+        "--num-servers",
+        "4",
+        "--args",
+        "--model",
+        f"{model_path}/ERNIE-4.5-21B-A3B-Paddle",
+        "--tensor-parallel-size",
+        "1",
+        "--data-parallel-size",
+        "4",
+        "--max-model-len",
+        "32768",
+        "--max-num-seqs",
+        "64",
+        "--quantization",
+        "wint4",
+        "--splitwise-role",
+        "prefill",
+        "--cache-transfer-protocol",
+        "rdma",
+        "--enable-expert-parallel",
+        "--disable-sequence-parallel-moe",
+        "--router",
+        f"0.0.0.0:{port_num}",
+    ]
+
+    with open("log_prefill/nohup", "w") as log_file:
+        subprocess.Popen(
+            prefill_cmd, stdout=log_file, stderr=subprocess.STDOUT, start_new_session=True, env=prefill_env
+        )
+    print(f"Prefill node launch command: {' '.join(prefill_cmd)}")
+
+    # 3. Start the Decode node
+    print("Starting the Decode node...")
+    decode_env = os.environ.copy()
+    decode_env["FD_LOG_DIR"] = "log_decode"
+    decode_env["XPU_VISIBLE_DEVICES"] = "4,5,6,7"
+
+    decode_cmd = [
+        "python",
+        "-m",
+        "fastdeploy.entrypoints.openai.multi_api_server",
+        "--port",
+        f"{port_num + 21},{port_num + 22},{port_num + 23},{port_num + 24}",
+        "--num-servers",
+        "4",
+        "--args",
+        "--model",
+        f"{model_path}/ERNIE-4.5-21B-A3B-Paddle",
+        "--tensor-parallel-size",
+        "1",
+        "--data-parallel-size",
+        "4",
+        "--max-model-len",
+        "32768",
+        "--max-num-seqs",
+        "64",
+        "--quantization",
+        "wint4",
+        "--splitwise-role",
+        "decode",
+        "--cache-transfer-protocol",
+        "rdma",
+        "--enable-expert-parallel",
+        "--disable-sequence-parallel-moe",
+        "--router",
+        f"0.0.0.0:{port_num}",
+    ]
+
+    with open("log_decode/nohup", "w") as log_file:
+        subprocess.Popen(decode_cmd, stdout=log_file, stderr=subprocess.STDOUT, start_new_session=True, env=decode_env)
+    print(f"Decode node launch command: {' '.join(decode_cmd)}")
+
+    # Wait for the services to start
+    print(f"Waiting {wait_before_check} seconds for the services to initialize...")
+    time.sleep(wait_before_check)
+
+    # Health check (both the P and D nodes)
+    port_p = port_num + 11
+    port_d = port_num + 21
+
+    if not wait_for_pd_health_check(port_p, port_d):
+        print_pd_logs_on_failure()
+        stop_processes()
+        return False
+    # ensure the PD service is ready
+    time.sleep(5)
+
+    return True
+
+
+def test_pd_separation():
+    """PD-disaggregated deployment mode test"""
+
+    print("\n============================Starting the PD disaggregation + EP4TP1 test!============================")
+
+    # Set the PD disaggregation environment variables
+    original_env = setup_pd_ep_env()
+
+    # Check that the RDMA NICs are configured
+    rdma_nics = os.environ.get("KVCACHE_RDMA_NICS", "")
+    if not rdma_nics:
+        pytest.fail("KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh")
+    print(f"KVCACHE_RDMA_NICS: {rdma_nics}")
+
+    try:
+        # Fetch the configuration
+        port_num = get_port_num()
+        model_path = get_model_path()
+
+        # Start the PD-disaggregated service
+        if not start_pd_server(model_path, port_num):
+            pytest.fail("Failed to start the PD-disaggregated service")
+
+        # Run the test - access through the Router port
+        ip = "0.0.0.0"
+        client = openai.Client(base_url=f"http://{ip}:{port_num}/v1", api_key="EMPTY_API_KEY")
+
+        # Non-streaming chat completion
+        response = client.chat.completions.create(
+            model="default",
+            messages=[
+                {"role": "user", "content": "你好,你是谁?"},
+            ],
+            temperature=1,
+            top_p=0,
+            max_tokens=64,
+            stream=False,
+        )
+
+        print(f"\nModel reply: {response.choices[0].message.content}")
+
+        # Validate the response (keywords stay in Chinese: the model replies in Chinese)
+        assert any(
+            keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言", "百度", "智能助手"]
+        ), f"Unexpected response content: {response.choices[0].message.content}"
+
+        print("\nPD disaggregation test passed!")
+
+    except Exception as e:
+        print(f"\nPD disaggregation test failed: {e}")
+        print_pd_logs_on_failure()
+        pytest.fail(f"PD disaggregation test failed: {e}")
+
+    finally:
+        # Stop the services
+        print("\nStopping the PD-disaggregated service...")
+        stop_processes()
+
+        # Restore the environment variables
+        restore_pd_ep_env(original_env)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
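
Reviewer note (illustrative, not part of the patch): the functional core of this change is that the `rank_id > 0` early-return moved out of GetOutput() and into GetOutputStatic() only, so the new GetOutputEPStatic() bound to `get_output_ep` polls its System V message queue on every rank; combined with the runner now exporting INFERENCE_MSG_QUEUE_ID per worker (from local_engine_worker_queue_port), each EP rank drains tokens from its own queue. A minimal Python sketch of that per-rank receive path, assuming the third-party sysv_ipc package and a hypothetical helper name (neither is FastDeploy API):

    import os

    import sysv_ipc  # assumption: third-party System V IPC bindings, illustration only

    def receive_tokens_ep(wait_flag: bool):
        # Hypothetical helper mirroring GetOutputEPStatic: every rank reads the
        # queue id its worker exported (default 1024 matches get_output.cc).
        queue_id = int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "1024"))
        key = sysv_ipc.ftok("/opt/", queue_id)  # same path/key derivation as the C++ op
        mq = sysv_ipc.MessageQueue(key, sysv_ipc.IPC_CREAT, mode=0o666)
        try:
            # block=False corresponds to the msgrcv(..., IPC_NOWAIT) branch
            payload, _mtype = mq.receive(block=wait_flag)
        except sysv_ipc.BusyError:
            return None  # empty queue, i.e. the ret == -1 path in GetOutput()
        return payload  # raw bytes laid out as (MAX_BSZ + 2) 4-byte slots

The plain `get_output` op keeps the old behavior: GetOutputStatic() still returns immediately on ranks > 0, so only rank 0 consumes its queue in the pure tensor-parallel case.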