Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions custom_ops/xpu_ops/src/ops/get_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
#include "msg_utils.h"
#include "paddle/extension.h"

void GetOutputKVSignal(const paddle::Tensor &x,
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024;
if (const char *msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
if (const char* msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
Expand All @@ -33,7 +33,7 @@ void GetOutputKVSignal(const paddle::Tensor &x,
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);

int *out_data = const_cast<int *>(x.data<int>());
int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
Expand All @@ -53,15 +53,12 @@ void GetOutputKVSignal(const paddle::Tensor &x,
return;
}

void GetOutput(const paddle::Tensor &x,
void GetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char *inference_msg_queue_id_env_p =
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
Expand All @@ -82,7 +79,7 @@ void GetOutput(const paddle::Tensor &x,
std::cout << "get_output wait_flag: " << wait_flag << std::endl;
#endif

int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
Expand Down Expand Up @@ -110,11 +107,20 @@ void GetOutput(const paddle::Tensor &x,
return;
}

void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) {
// Static-graph entry point for the `get_output` op: drains the default
// message queue (id 1) into `x`, but only on rank 0 — all other ranks are
// a no-op because only rank 0 owns the inference output queue.
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
  // Non-zero ranks have nothing to receive; bail out before touching the queue.
  if (rank_id <= 0) {
    GetOutput(x, rank_id, wait_flag, /*msg_queue_id=*/1);
  }
}

// Expert-parallel (EP) static-graph entry point for `get_output_ep`.
// Unlike GetOutputStatic, this intentionally performs NO rank guard:
// under expert parallelism every rank must drain its own copy of the
// message queue (GetOutput itself decides how rank_id is used).
void GetOutputEPStatic(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag) {
  constexpr int kDefaultMsgQueueId = 1;  // same default queue as the non-EP path
  GetOutput(x, rank_id, wait_flag, kDefaultMsgQueueId);
}

void GetOutputDynamic(const paddle::Tensor &x,
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
Expand All @@ -140,7 +146,7 @@ PD_BUILD_OP(get_output_ep)
.Attrs({"rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputStatic));
.SetKernelFn(PD_KERNEL(GetOutputEPStatic));

PD_BUILD_OP(get_output_ep_dynamic)
.Inputs({"x"})
Expand Down
6 changes: 5 additions & 1 deletion custom_ops/xpu_ops/src/ops/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ std::vector<paddle::Tensor> GetInferParam(

void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);

void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);

void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
Expand Down Expand Up @@ -839,7 +843,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"get_output function");

m.def("get_output_ep",
&GetOutputStatic,
&GetOutputEPStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
Expand Down
6 changes: 5 additions & 1 deletion fastdeploy/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def __init__(
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None

# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.local_engine_worker_queue_port)}")

self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode

# Initialize ZMQ client for async output
Expand Down Expand Up @@ -1439,7 +1443,7 @@ class at the server level, which is too granular for ModelRunner.
# Speculative decoding
full_hidden_states=model_output if self.speculative_decoding else None,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
mp_rank=self.parallel_config.tensor_parallel_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
Expand Down
Loading
Loading