diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py
index 2673af27684..60620ce7671 100644
--- a/fastdeploy/model_executor/xpu_pre_and_post_process.py
+++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py
@@ -14,15 +14,18 @@
 # limitations under the License.
 """
 
-from typing import Dict, Optional
+import queue
+from typing import Dict, List, Optional
 
+import numpy as np
 import paddle
 
 from fastdeploy import envs
 from fastdeploy.model_executor.forward_meta import XPUForwardMeta
 from fastdeploy.model_executor.layers.sample.sampler import Sampler
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
-from fastdeploy.worker.output import ModelOutputData
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
 
 if current_platform.is_xpu():
     from fastdeploy.model_executor.ops.xpu import (
@@ -49,6 +52,43 @@
     )
 
 
+def _build_stream_transfer_data(
+    output_tokens: paddle.Tensor,
+    pooler_outputs: List = None,
+    logprobs: Optional[LogprobsTensors] = None,
+    prompt_logprobs_list: Optional[LogprobsTensors] = None,
+):
+    """Split output_tokens and output"""
+    stream_transfer_datas = []
+    if output_tokens is not None:
+        output_tokens = output_tokens.reshape([-1]).numpy()
+        output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
+
+        for bid, output_token_per_sample in enumerate(output_tokens_lists):
+            stream_transfer_data = StreamTransferData(
+                decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
+            )
+            if logprobs:
+                stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
+            if prompt_logprobs_list:
+                stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
+            stream_transfer_datas.append(stream_transfer_data)
+    elif pooler_outputs is not None:
+        for bid, pooler_output in enumerate(pooler_outputs):
+            if pooler_output is None:
+                continue
+            if pooler_output.dtype == paddle.bfloat16:
+                pooler_output = pooler_output.astype("float32")
+
+            pooler_output = pooler_output.numpy()
+
+            stream_transfer_data = StreamTransferData(
+                decoder_state=DecoderState.TEXT, pooler_output=pooler_output, batch_id=bid
+            )
+            stream_transfer_datas.append(stream_transfer_data)
+    return stream_transfer_datas
+
+
 def xpu_pre_process(
     input_ids: paddle.Tensor,
     seq_lens_this_time: int,
@@ -217,6 +257,8 @@ def xpu_post_process_normal(
     share_inputs: Dict[str, paddle.Tensor],
     block_size: int = 64,
     skip_save_output: bool = False,
+    save_each_rank: bool = False,
+    async_output_queue: queue.Queue = None,
     think_end_id: int = None,
     line_break_id: int = None,
 ) -> None:
@@ -314,27 +356,37 @@ def xpu_post_process_normal(
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
-        if sampler_output.logprobs_tensors is None:
-            save_output(
-                sampled_token_ids,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-                False,  # use_ep
-            )
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            if save_each_rank or model_output.mp_rank == 0:
+                output = _build_stream_transfer_data(
+                    sampled_token_ids,
+                    logprobs=sampler_output.logprobs_tensors,
+                    prompt_logprobs_list=model_output.prompt_logprobs_list,
+                )
+                if async_output_queue is not None:
+                    async_output_queue.put(output)
         else:
-            if save_output_topk is None:
-                raise ImportError(
-                    "save_output_topk operator is not available. "
-                    "Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
+            if sampler_output.logprobs_tensors is None:
+                save_output(
+                    sampled_token_ids,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
+                    False,  # use_ep
+                )
+            else:
+                if save_output_topk is None:
+                    raise ImportError(
+                        "save_output_topk operator is not available. "
+                        "Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
+                    )
+                save_output_topk(
+                    sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
                 )
-            save_output_topk(
-                sampled_token_ids,
-                sampler_output.logprobs_tensors.logprob_token_ids,
-                sampler_output.logprobs_tensors.logprobs,
-                sampler_output.logprobs_tensors.selected_token_ranks,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-            )
 
 
 def xpu_post_process_specualate(
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 07dd0a3c883..99688ba425e 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -15,19 +15,22 @@
 """
 
 import os
+import queue
 import random
 import time
+from threading import Thread
 from typing import List, Optional
 
 import numpy as np
 import paddle
+import zmq
 from paddle import nn
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
@@ -59,7 +62,7 @@
 from fastdeploy.spec_decode import MTPProposer
 from fastdeploy.utils import get_logger
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
 
 logger = get_logger("xpu_model_runner", "xpu_model_runner.log")
 
@@ -156,6 +159,106 @@ def __init__(
         self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode
 
+        # Initialize ZMQ client for async output
+        self.zmq_client = None
+        self.async_output_queue = None
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            logger.info(f"zmq client get_save_output_rank{local_rank}")
+            self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
+            self.zmq_client.connect()
+            self.zmq_client.socket.SNDTIMEO = 3000
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self._async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy",
+            )
+            self.async_output_copy_thread.start()
+        # prompt logprobs state
+        self.prompt_logprobs_reqs: dict[str, Request] = {}
+        self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
+
+    def _async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            try:
+                if self.async_output_queue is None or self.zmq_client is None:
+                    break
+                output = self.async_output_queue.get()
+                if self.zmq_client is not None:
+                    self.zmq_client.send_pyobj(output)
+            except Exception as e:
+                logger.exception("Exception in async output loop: %s", e)
+
+    def _get_prompt_logprobs_list(self, hidden_states: paddle.Tensor) -> list[Optional[LogprobsTensors]]:
+        """
+        Build prompt_logprobs for requests that asked for it.
+        """
+        if len(self.prompt_logprobs_reqs) > 0:
+            assert (
+                not self.fd_config.cache_config.enable_prefix_caching
+            ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."
+
+        if len(self.prompt_logprobs_reqs) == 0:
+            return self.scheduler_config.max_num_seqs * [None]
+
+        logprobs_mode = self.fd_config.model_config.logprobs_mode
+        prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
+        completed_prefill_reqs: list[Request] = []
+
+        for req_id, request in self.prompt_logprobs_reqs.items():
+            if not hasattr(request, "sampling_params") or request.sampling_params is None:
+                continue
+            num_prompt_logprobs = request.sampling_params.prompt_logprobs
+            if request.prompt_token_ids is None or num_prompt_logprobs is None:
+                continue
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.ori_vocab_size
+
+            num_tokens = request.prefill_end_index - request.prefill_start_index
+            num_prompt_tokens = len(request.prompt_token_ids)
+
+            logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
+            if not logprobs_tensors:
+                logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
+
+            start_idx = request.prefill_start_index
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                num_logits = num_tokens
+            else:
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(request)
+                prompt_logprobs_list[request.idx] = logprobs_tensors
+            if num_logits <= 0:
+                continue
+
+            offset = self.share_inputs["cu_seqlens_q"][request.idx]
+            prompt_hidden_states = hidden_states[offset : offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states)
+            prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
+            prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
+            if logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            elif logprobs_mode == "raw_logits":
+                raw_logprobs = logits
+            else:
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
+            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)
+
+        for req in completed_prefill_reqs:
+            del self.prompt_logprobs_reqs[req.request_id]
+            del self.in_progress_prompt_logprobs[req.request_id]
+        return prompt_logprobs_list
+
     def exist_prefill(self):
         """
         check whether prefill stage exist
         """
@@ -405,6 +508,13 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
                 self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
                 self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
 
+            if (
+                hasattr(request, "sampling_params")
+                and request.sampling_params is not None
+                and request.sampling_params.prompt_logprobs is not None
+            ):
+                self.prompt_logprobs_reqs[request.request_id] = request
+
             if len(request.output_token_ids) == 0:
                 input_ids = request.prompt_token_ids
             else:
@@ -1296,6 +1406,10 @@ class at the server level, which is too granular for ModelRunner.
 
         # 5. Speculative decode
        # 6. Post Process
+        prompt_logprobs_list = None
+        if not self.speculative_decoding:
+            prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
+
         model_output_data = ModelOutputData(
             next_tokens=self.share_inputs["next_tokens"],
             stop_flags=self.share_inputs["stop_flags"],
@@ -1323,6 +1437,7 @@ class at the server level, which is too granular for ModelRunner.
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
+            prompt_logprobs_list=prompt_logprobs_list,
         )
         if self.speculative_decoding:
             # base model post process
@@ -1334,6 +1449,7 @@ class at the server level, which is too granular for ModelRunner.
                 share_inputs=self.share_inputs,
                 block_size=self.cache_config.block_size,
                 skip_save_output=is_dummy_run,
+                async_output_queue=self.async_output_queue,
                 think_end_id=self.model_config.think_end_id,
                 line_break_id=self.model_config.line_break_id,
             )
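
A note on the `_build_stream_transfer_data` helper added above: it flattens the sampled-token tensor to shape `[batch]` and uses `np.split` so that each in-flight request gets its own length-1 token array and `batch_id`. The numpy sketch below isolates just that fan-out step; `SimpleRecord` is a stand-in for `StreamTransferData`, not a FastDeploy class.

```python
# Illustration of the np.split fan-out used by _build_stream_transfer_data.
# SimpleRecord is a stand-in for StreamTransferData, not a FastDeploy class.
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleRecord:
    batch_id: int
    tokens: np.ndarray


def split_tokens(output_tokens: np.ndarray) -> list:
    # Flatten [batch, 1] -> [batch], then split into `batch` arrays of length 1,
    # i.e. one sampled token per in-flight request.
    flat = output_tokens.reshape([-1])
    per_sample = np.split(flat, flat.shape[0])
    return [SimpleRecord(batch_id=bid, tokens=tok) for bid, tok in enumerate(per_sample)]


if __name__ == "__main__":
    records = split_tokens(np.array([[11], [22], [33]], dtype=np.int64))
    assert [r.tokens.tolist() for r in records] == [[11], [22], [33]]
```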
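On the async output path: `__init__` now starts a daemon thread that drains `self.async_output_queue` and ships each list of `StreamTransferData` over a ZMQ PUSH socket created by `ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)`. Below is a minimal sketch of what a consumer of that stream could look like with plain pyzmq. The `ipc://` endpoint is an assumption for illustration only; the real endpoint is derived inside `ZmqIpcClient`, and FastDeploy's own output processor (not part of this diff) is the actual receiver.

```python
# Minimal consumer sketch for the PUSH stream set up in XPUModelRunner.__init__.
# The ipc endpoint below is illustrative only; FastDeploy's ZmqIpcClient derives
# the real endpoint from the name "get_save_output_rank{local_rank}".
import zmq


def drain_worker_outputs(endpoint: str = "ipc:///tmp/get_save_output_rank0") -> None:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PULL)  # counterpart of the worker-side zmq.PUSH socket
    sock.bind(endpoint)
    try:
        while True:
            # Each message is the list of StreamTransferData produced by
            # _build_stream_transfer_data for one decoding step.
            stream_transfer_datas = sock.recv_pyobj()
            for data in stream_transfer_datas:
                print(data.batch_id, data.tokens)
    finally:
        sock.close()
        ctx.term()
```

The worker side sets `SNDTIMEO = 3000` so a stalled consumer cannot block the copy thread indefinitely; the queue plus daemon thread keeps serialization and transport off the model-execution path.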
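On `_get_prompt_logprobs_list`: it supports chunked prefill by allocating one `[num_prompt_tokens - 1, num_prompt_logprobs + 1]` buffer per request via `LogprobsTensors.empty_cpu`, writing rows `[prefill_start_index, prefill_start_index + num_logits)` for each chunk, and only marking the request complete when a chunk runs past the remaining prompt tokens. The numpy sketch below reproduces only that bookkeeping; `fake_logprobs_for_chunk` is a hypothetical stand-in for the `compute_logprobs`/`gather_logprobs` calls.

```python
# Sketch of the chunked-prefill accumulation in _get_prompt_logprobs_list.
# Buffer shape mirrors LogprobsTensors.empty_cpu(num_prompt_tokens - 1, k + 1);
# fake_logprobs_for_chunk is a hypothetical stand-in for compute/gather_logprobs.
import numpy as np


def fake_logprobs_for_chunk(num_rows: int, k: int) -> np.ndarray:
    return np.zeros((num_rows, k + 1), dtype=np.float32)


def accumulate_chunk(buf: np.ndarray, start_idx: int, num_tokens: int, num_prompt_tokens: int, k: int) -> bool:
    """Write one prefill chunk into the per-request buffer; True means the prompt is finished."""
    start_tok = start_idx + 1                    # logprobs exist only for prompt tokens 1..n-1
    num_remaining = num_prompt_tokens - start_tok
    num_logits = min(num_tokens, num_remaining)  # the final chunk may overshoot the prompt end
    if num_logits > 0:
        buf[start_idx : start_idx + num_logits] = fake_logprobs_for_chunk(num_logits, k)
    return num_tokens > num_remaining


if __name__ == "__main__":
    num_prompt_tokens, k = 10, 4
    buf = np.full((num_prompt_tokens - 1, k + 1), np.nan, dtype=np.float32)
    # Two prefill chunks covering prompt positions [0, 6) and [6, 10).
    assert accumulate_chunk(buf, 0, 6, num_prompt_tokens, k) is False
    assert accumulate_chunk(buf, 6, 4, num_prompt_tokens, k) is True
    assert not np.isnan(buf).any()               # all num_prompt_tokens - 1 rows were filled
```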