94 changes: 73 additions & 21 deletions fastdeploy/model_executor/xpu_pre_and_post_process.py
@@ -14,15 +14,18 @@
# limitations under the License.
"""

from typing import Dict, Optional
import queue
from typing import Dict, List, Optional

import numpy as np
import paddle

from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData

if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
@@ -49,6 +52,43 @@
)


def _build_stream_transfer_data(
output_tokens: paddle.Tensor,
pooler_outputs: Optional[List] = None,
logprobs: Optional[LogprobsTensors] = None,
prompt_logprobs_list: Optional[LogprobsTensors] = None,
):
"""Split output_tokens and output"""
stream_transfer_datas = []
if output_tokens is not None:
output_tokens = output_tokens.reshape([-1]).numpy()
output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])

for bid, output_token_per_sample in enumerate(output_tokens_lists):
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
)
if logprobs:
stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
if prompt_logprobs_list:
stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
stream_transfer_datas.append(stream_transfer_data)
elif pooler_outputs is not None:
for bid, pooler_output in enumerate(pooler_outputs):
if pooler_output is None:
continue
if pooler_output.dtype == paddle.bfloat16:
pooler_output = pooler_output.astype("float32")

pooler_output = pooler_output.numpy()

stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT, pooler_output=pooler_output, batch_id=bid
)
stream_transfer_datas.append(stream_transfer_data)
return stream_transfer_datas

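A minimal illustration (not part of the diff) of the split performed above, using made-up token ids and plain numpy:

import numpy as np

sampled = np.array([101, 57, 9])  # stands in for output_tokens.reshape([-1]).numpy() for a batch of three
per_sample = np.split(sampled, sampled.shape[0])  # one length-1 array per sample
# per_sample == [array([101]), array([57]), array([9])]; element bid is the token that
# _build_stream_transfer_data wraps into StreamTransferData(..., batch_id=bid).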

def xpu_pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
@@ -217,6 +257,8 @@ def xpu_post_process_normal(
share_inputs: Dict[str, paddle.Tensor],
block_size: int = 64,
skip_save_output: bool = False,
save_each_rank: bool = False,
async_output_queue: Optional[queue.Queue] = None,
think_end_id: int = None,
line_break_id: int = None,
) -> None:
@@ -314,27 +356,37 @@ def xpu_post_process_normal(
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(
sampled_token_ids,
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_list=model_output.prompt_logprobs_list,
)
if async_output_queue is not None:
async_output_queue.put(output)
else:
if save_output_topk is None:
raise ImportError(
"save_output_topk operator is not available. "
"Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
if sampler_output.logprobs_tensors is None:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
else:
if save_output_topk is None:
raise ImportError(
"save_output_topk operator is not available. "
"Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
)
save_output_topk(
sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)
save_output_topk(
sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)


def xpu_post_process_specualate(
120 changes: 118 additions & 2 deletions fastdeploy/worker/xpu_model_runner.py
@@ -15,19 +15,22 @@
"""

import os
import queue
import random
import time
from threading import Thread
from typing import List, Optional

import numpy as np
import paddle
import zmq
from paddle import nn

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
@@ -59,7 +62,7 @@
from fastdeploy.spec_decode import MTPProposer
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput

logger = get_logger("xpu_model_runner", "xpu_model_runner.log")

@@ -156,6 +159,106 @@ def __init__(

self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode

# Initialize ZMQ client for async output
self.zmq_client = None
self.async_output_queue = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
logger.info(f"zmq client get_save_output_rank{local_rank}")
self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
self.zmq_client.connect()
self.zmq_client.socket.SNDTIMEO = 3000
self.async_output_queue: queue.Queue = queue.Queue()
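# A daemon thread drains this queue and pushes each output over ZMQ, so the
# forward pass only pays for a queue.put instead of a blocking IPC send.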
self.async_output_copy_thread = Thread(
target=self._async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy",
)
self.async_output_copy_thread.start()
# Prompt-logprobs bookkeeping: requests that asked for prompt_logprobs, plus the
# per-request LogprobsTensors accumulated across prefill chunks.
self.prompt_logprobs_reqs: dict[str, Request] = {}
self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}

def _async_output_busy_loop(self):
"""Entrypoint for the thread which handles outputs asynchronously."""
while True:
try:
if self.async_output_queue is None or self.zmq_client is None:
break
output = self.async_output_queue.get()
if self.zmq_client is not None:
self.zmq_client.send_pyobj(output)
except Exception as e:
logger.exception("Exception in async output loop: %s", e)

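For orientation, a minimal self-contained sketch of the PUSH/PULL pattern the loop above implements. The endpoint name, the payload, and the consumer here are illustrative assumptions; the real endpoint naming and serialization are handled by ZmqIpcClient and the output-processing side.

import queue
import zmq

def producer_loop(q: queue.Queue, endpoint: str = "ipc:///tmp/fd_output_demo"):
    # Mirrors _async_output_busy_loop: drain a queue and push pickled objects.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(endpoint)
    sock.SNDTIMEO = 3000  # same 3-second send timeout as in the diff
    while True:
        item = q.get()
        if item is None:  # sentinel to stop the demo loop
            break
        sock.send_pyobj(item)

def consume_one(endpoint: str = "ipc:///tmp/fd_output_demo"):
    # The receiving side binds a PULL socket and unpickles whatever was pushed.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PULL)
    sock.bind(endpoint)
    return sock.recv_pyobj()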
def _get_prompt_logprobs_list(self, hidden_states: paddle.Tensor) -> list[Optional[LogprobsTensors]]:
"""
Build prompt_logprobs for requests that asked for it.
"""
if len(self.prompt_logprobs_reqs) > 0:
assert (
not self.fd_config.cache_config.enable_prefix_caching
), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."

if len(self.prompt_logprobs_reqs) == 0:
return self.scheduler_config.max_num_seqs * [None]

logprobs_mode = self.fd_config.model_config.logprobs_mode
prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
completed_prefill_reqs: list[Request] = []

for req_id, request in self.prompt_logprobs_reqs.items():
if not hasattr(request, "sampling_params") or request.sampling_params is None:
continue
num_prompt_logprobs = request.sampling_params.prompt_logprobs
if request.prompt_token_ids is None or num_prompt_logprobs is None:
continue
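# A value of -1 asks for logprobs over the entire vocabulary.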
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.ori_vocab_size

num_tokens = request.prefill_end_index - request.prefill_start_index
num_prompt_tokens = len(request.prompt_token_ids)

logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
if not logprobs_tensors:
logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)
self.in_progress_prompt_logprobs[req_id] = logprobs_tensors

start_idx = request.prefill_start_index
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
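# This prefill step may cover only a chunk of the prompt: clamp the number of
# positions scored to what remains, and mark the request finished once the final
# chunk has been processed.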
if num_tokens <= num_remaining_tokens:
num_logits = num_tokens
else:
num_logits = num_remaining_tokens
completed_prefill_reqs.append(request)
prompt_logprobs_list[request.idx] = logprobs_tensors
if num_logits <= 0:
continue

offset = self.share_inputs["cu_seqlens_q"][request.idx]
prompt_hidden_states = hidden_states[offset : offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states)
prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.sampler.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
raw_logprobs = logits
else:
raw_logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)

for req in completed_prefill_reqs:
del self.prompt_logprobs_reqs[req.request_id]
del self.in_progress_prompt_logprobs[req.request_id]
return prompt_logprobs_list

def exist_prefill(self):
"""
Check whether a prefill stage exists.
@@ -405,6 +508,13 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0

if (
hasattr(request, "sampling_params")
and request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None
):
self.prompt_logprobs_reqs[request.request_id] = request

if len(request.output_token_ids) == 0:
input_ids = request.prompt_token_ids
else:
@@ -1296,6 +1406,10 @@ class at the server level, which is too granular for ModelRunner.
# 5. Speculative decode

# 6. Post Process
prompt_logprobs_list = None
if not self.speculative_decoding:
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)

model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
@@ -1323,6 +1437,7 @@ class at the server level, which is too granular for ModelRunner.
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_logprobs_list=prompt_logprobs_list,
)
if self.speculative_decoding:
# base model post process
@@ -1334,6 +1449,7 @@ class at the server level, which is too granular for ModelRunner.
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
async_output_queue=self.async_output_queue,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
)