From 3410c9ff41faba85e1127332b84b164ae6144902 Mon Sep 17 00:00:00 2001
From: AIxyz
Date: Mon, 12 Jan 2026 10:17:45 +0800
Subject: [PATCH 1/7] Add serving_dinfer_openai.md

---
 demo/serving_dinfer_openai.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)
 create mode 100644 demo/serving_dinfer_openai.md

diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md
new file mode 100644
index 0000000..6d8013d
--- /dev/null
+++ b/demo/serving_dinfer_openai.md
@@ -0,0 +1,13 @@
+# Demo of dInfer OpenAI-compatible serving
+
+## Serving
+
+```bash
+python3 serving_dinfer_openai.py
+```
+
+## Client
+
+```bash
+date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer 12345678" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date
+```

From 92ad3f5acd69159602445ddf7ed79981778b784f Mon Sep 17 00:00:00 2001
From: AIxyz
Date: Mon, 12 Jan 2026 11:12:11 +0800
Subject: [PATCH 2/7] Add open-webui in serving_dinfer_openai.md

---
 demo/serving_dinfer_openai.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md
index 6d8013d..c2931d0 100644
--- a/demo/serving_dinfer_openai.md
+++ b/demo/serving_dinfer_openai.md
@@ -11,3 +11,16 @@
 ```bash
 date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer 12345678" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date
 ```
+
+## Web demo
+
+```bash
+date && docker pull ghcr.io/open-webui/open-webui:main && date
+
+mkdir data-open-webui
+cd data-open-webui
+
+date && docker run -d -p 60111:8080 -v $PWD:/app/backend/data --name open-webui ghcr.io/open-webui/open-webui:main && date
+```
+
+Configure Open WebUI to use [http://0.0.0.0:48000/v1](http://0.0.0.0:48000/v1) as its OpenAI API base URL, as described in [this guide](https://developer.volcengine.com/articles/7533551308616237092#heading10).

From 1298cb313c8e0414b57a6a4114bc4fd69c8a20a2 Mon Sep 17 00:00:00 2001
From: AIxyz
Date: Mon, 12 Jan 2026 11:59:26 +0800
Subject: [PATCH 3/7] Add serving_dinfer_openai.py

---
 demo/serving_dinfer_openai.py | 219 ++++++++++++++++++++++++++++++++++
 1 file changed, 219 insertions(+)
 create mode 100644 demo/serving_dinfer_openai.py

diff --git a/demo/serving_dinfer_openai.py b/demo/serving_dinfer_openai.py
new file mode 100644
index 0000000..49bc56d
--- /dev/null
+++ b/demo/serving_dinfer_openai.py
@@ -0,0 +1,219 @@
+'''
+This is a FastAPI dInfer serving demo of LLaDA
+'''
+
+# pylint: disable=import-error, no-name-in-module
+# pylint: disable=global-statement, global-variable-not-assigned
+# pylint: disable=too-few-public-methods
+
+import json
+import logging
+import os
+import time
+from typing import List, Dict
+import uuid
+
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse, Response
+from pydantic import BaseModel, Field
+import torch
+from transformers import AutoTokenizer
+import uvicorn
+
+
+def init_logger(log_path: str = None):
+    ''' init logger with exception handling '''
+    log_format = '[%(asctime)s](%(created).17g) - %(levelname)s - ' \
+                 '|%(pathname)s|%(funcName)s|%(lineno)d| - %(message)s'
+    date_format = '%Y-%m-%d_%H:%M:%S'
+
+    class CustomFormatter(logging.Formatter):
+        ''' CustomFormatter '''
+
+        def format(self, record):
+            ''' format '''
+            try:
+                return super().format(record)
+            except TypeError:
+                if hasattr(record, 'args') and record.args:
+                    record.msg = f"{record.msg} - 
Unformattable args: {record.args}" + record.args = () + return super().format(record) + + task_logging_level = os.environ.get("TASK_LOGGING_LEVEL", "INFO") + log_level = logging.INFO if task_logging_level == "INFO" else logging.DEBUG + logging.basicConfig(filename=log_path, + level=log_level, + format=log_format, + datefmt=date_format) + + for handler in logging.root.handlers: + handler.setFormatter(CustomFormatter(log_format, date_format)) + + return logging + + +logging = init_logger() + +from dinfer import DiffusionLLMServing, SamplingParams, ThresholdParallelDecoder # pylint: disable=wrong-import-position + + +class CompletionsRequest(BaseModel): + ''' Completions Request + ''' + messages: List[Dict] = Field(title='messages') + stream: bool = Field(title='stream') + + +app = FastAPI( + title='xyz dllm serving', + redoc_url=None, + docs=None, +) +STATUS_OK = 1 +STATUS_ERR = 0 + +SPECIAL_MODEL_DIR = os.environ.get('SPECIAL_MODEL_DIR') +TASK_DLLM_NUM_GPUS = int(os.environ.get('TASK_DLLM_NUM_GPUS', 1)) + +TASK_DLLM_GEN_LENGTH = int(os.environ.get('TASK_DLLM_GEN_LENGTH', 512)) +TASK_DLLM_BLOCK_LENGTH = int(os.environ.get('TASK_DLLM_BLOCK_LENGTH', 32)) + +TASK_DLLM_MAX_LENGTH = int(os.environ.get('TASK_DLLM_MAX_LENGTH', 4096)) +TASK_DLLM_BATCH_SIZE = int(os.environ.get('TASK_DLLM_BATCH_SIZE', 2)) +TASK_DLLM_TEMPERATURE = float(os.environ.get('TASK_DLLM_TEMPERATURE', 0.0)) +TASK_DLLM_THRESHOLD = float(os.environ.get('TASK_DLLM_THRESHOLD', 0.9)) + +TASK_DLLM_MASK_ID = int(os.environ.get('TASK_DLLM_MASK_ID', 156895)) +TASK_DLLM_EOS_ID = int(os.environ.get('TASK_DLLM_EOS_ID', 156892)) + + +def get_dllm(): + ''' get dllm ''' + dllm_tokenizer = AutoTokenizer.from_pretrained(SPECIAL_MODEL_DIR, + trust_remote_code=True) + decoder = ThresholdParallelDecoder(temperature=TASK_DLLM_TEMPERATURE, + threshold=TASK_DLLM_THRESHOLD, + mask_id=TASK_DLLM_MASK_ID, + eos_id=TASK_DLLM_EOS_ID) + sample_params = SamplingParams(threshold=TASK_DLLM_THRESHOLD, + cache='prefix', + temperature=0., + early_stop=True, + cont_weight=0, + prefix_look=0, + after_look=0, + warmup_steps=0, + enable_torch_compile=True, + mask_id=TASK_DLLM_MASK_ID, + eos_id=TASK_DLLM_EOS_ID, + parallel_decoding='threshold', + use_credit=False, + use_bd=True, + max_length=TASK_DLLM_MAX_LENGTH, + ep_size=1, + batch_size=TASK_DLLM_BATCH_SIZE, + mini_batch_size=TASK_DLLM_BATCH_SIZE, + use_naive_batching=True) + dllm_server = DiffusionLLMServing(SPECIAL_MODEL_DIR, + model_type='llada2-mini', + sample_params=sample_params, + server_port=40570, + num_gpus=TASK_DLLM_NUM_GPUS, + dp_size=1, + tpep_size=TASK_DLLM_NUM_GPUS, + backend='sglang') + return dllm_tokenizer, dllm_server, decoder + + +tokenizer, MODEL_DLLM = None, None + + +def get_answer_openai(chat_uuid: str, data: Dict) -> str: + ''' get answer openai ''' + logging.info('[%s] resp: %s', chat_uuid, data) + + +def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: + ''' get answer openai (no stream) ''' + global tokenizer, MODEL_DLLM + input_ids = tokenizer.apply_chat_template( + data['messages'], + add_generation_prompt=True, + tokenize=True, + return_tensors='pt', + ) + batch_input_ids = torch.zeros((input_ids.shape[0], TASK_DLLM_MAX_LENGTH), + dtype=torch.long).fill_(TASK_DLLM_MASK_ID) + for s_k in range(input_ids.shape[0]): + batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k] + + x_tokens_yield = MODEL_DLLM.generate(batch_input_ids, + gen_length=TASK_DLLM_GEN_LENGTH, + block_length=TASK_DLLM_BLOCK_LENGTH) + + resp = {} + x_tokens_final = x_tokens_yield + + x_str = 
tokenizer.decode(x_tokens_final[0]) + text = x_str.split('ASSISTANT')[-1] + text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '') + text = text.replace('<|mask|>', ' ') + resp = { + 'id': + chat_uuid, + 'object': + 'chat.completion', + 'created': + time.time(), + 'model': + 'xyz-dllm', + 'choices': [{ + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': text, + 'resoning_content': None + }, + 'logprobs': None, + 'finish_reason': 'stop', + }], + 'prompt_token_ids': + None, + "usage": { + "prompt_tokens": input_ids.shape[-1], + "total_tokens": len(x_tokens_final[0]), + "completion_tokens": len(x_tokens_final[0]) - input_ids.shape[-1], + "prompt_tokens_details": None + } + } + logging.info('[%s] resp: %s', chat_uuid, resp) + return json.dumps(resp, ensure_ascii=False) + '\n' + + +@app.post('/v1/chat/completions') +def chat_openai(request: CompletionsRequest): + ''' chat + ''' + chat_uuid = f'chat-{str(uuid.uuid4())}' + data = request.dict() + logging.info('[%s] req: %s', chat_uuid, data) + if request.stream: + return StreamingResponse(get_answer_openai(chat_uuid, data)) + return Response( + content=get_answer_openai_no_stream(chat_uuid, data), + media_type='text/plain', + ) + + +def mission(): + ''' api demo + ''' + global tokenizer, MODEL_DLLM + tokenizer, MODEL_DLLM, _ = get_dllm() + port = int(os.environ.get('TASK_SERVER_PORT', '40081')) + uvicorn.run(app, host='0.0.0.0', port=port) + + +if __name__ == '__main__': + mission() From 600be117203acd77c3784e4bcf33bcd42b5fb42a Mon Sep 17 00:00:00 2001 From: AIxyz Date: Mon, 12 Jan 2026 12:04:51 +0800 Subject: [PATCH 4/7] Update serving_dinfer_openai.md --- demo/serving_dinfer_openai.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md index c2931d0..4785a52 100644 --- a/demo/serving_dinfer_openai.md +++ b/demo/serving_dinfer_openai.md @@ -3,13 +3,14 @@ ## Serving ```bash +export TASK_DLLM_BATCH_SIZE=2 python3 serving_dinfer_openai.py ``` ## Client ```bash -date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer 12345678" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date +date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date ``` ## Web demo From 1a7d4301613b384e53e94753f0a3b96578540b35 Mon Sep 17 00:00:00 2001 From: AIxyz Date: Mon, 12 Jan 2026 15:13:56 +0800 Subject: [PATCH 5/7] Update SPECIAL_MODEL_DIR in serving_dinfer_openai.md --- demo/serving_dinfer_openai.md | 1 + 1 file changed, 1 insertion(+) diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md index 4785a52..1f2bdc4 100644 --- a/demo/serving_dinfer_openai.md +++ b/demo/serving_dinfer_openai.md @@ -3,6 +3,7 @@ ## Serving ```bash +export SPECIAL_MODEL_DIR=/models/LLaDA2.0-mini--572899f-C8 # Download from https://www.modelscope.cn/models/inclusionAI/LLaDA2.0-mini export TASK_DLLM_BATCH_SIZE=2 python3 serving_dinfer_openai.py ``` From b813f3298bcf0664669bf08d6300b4905da5924a Mon Sep 17 00:00:00 2001 From: AIxyz Date: Tue, 13 Jan 2026 20:29:51 +0800 Subject: [PATCH 6/7] Support stream in serving_dinfer_openai --- demo/serving_dinfer_openai.md | 2 + demo/serving_dinfer_openai.py | 149 ++++++++++++++++++++- python/dinfer/decoding/generate_uniform.py | 21 ++- python/dinfer/decoding/serving.py 
| 14 +- 4 files changed, 169 insertions(+), 17 deletions(-) diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md index 1f2bdc4..4a9aa5c 100644 --- a/demo/serving_dinfer_openai.md +++ b/demo/serving_dinfer_openai.md @@ -12,6 +12,8 @@ python3 serving_dinfer_openai.py ```bash date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date + +date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": true}' http://0.0.0.0:48000/v1/chat/completions && date ``` ## Web demo diff --git a/demo/serving_dinfer_openai.py b/demo/serving_dinfer_openai.py index 49bc56d..0241675 100644 --- a/demo/serving_dinfer_openai.py +++ b/demo/serving_dinfer_openai.py @@ -5,12 +5,15 @@ # pylint: disable=import-error, no-name-in-module # pylint: disable=global-statement, global-variable-not-assigned # pylint: disable=too-few-public-methods +# pylint: disable=broad-exception-caught import json import logging import os +from queue import Queue +import threading import time -from typing import List, Dict +from typing import Dict, List, Optional import uuid from fastapi import FastAPI @@ -65,6 +68,12 @@ class CompletionsRequest(BaseModel): stream: bool = Field(title='stream') +class StreamRequest(BaseModel): + '''Stream Request''' + chat_uuid: str = Field(title='chat_uuid') + curr_x: Optional[List] = Field(title='curr_x') + + app = FastAPI( title='xyz dllm serving', redoc_url=None, @@ -87,6 +96,9 @@ class CompletionsRequest(BaseModel): TASK_DLLM_MASK_ID = int(os.environ.get('TASK_DLLM_MASK_ID', 156895)) TASK_DLLM_EOS_ID = int(os.environ.get('TASK_DLLM_EOS_ID', 156892)) +TASK_STREAM_SLEEP_SECONDS = float( + os.environ.get('TASK_STREAM_SLEEP_SECONDS', 0.005)) + def get_dllm(): ''' get dllm ''' @@ -127,11 +139,122 @@ def get_dllm(): tokenizer, MODEL_DLLM = None, None +global_stream_dict = {} +stream_lock = threading.Lock() + + +def stream_put_api(chat_uuid: str, curr_x: Optional[List]): + ''' stream_put_api ''' + with stream_lock: + if chat_uuid not in global_stream_dict: + global_stream_dict[chat_uuid] = Queue() + + if curr_x is None: + global_stream_dict[chat_uuid].put([]) + else: + global_stream_dict[chat_uuid].put(curr_x) + + +def stream_get_api(chat_uuid: str) -> Optional[List]: + ''' stream_get_api ''' + try: + with stream_lock: + if chat_uuid not in global_stream_dict: + return None + + queue = global_stream_dict[chat_uuid] + result = queue.get(timeout=0) + + if result == []: + del global_stream_dict[chat_uuid] + + return result + + except Exception: + return None + + +def generate_in_background(chat_uuid: str, data: Dict): + ''' generate_in_background ''' + + def _generate(): + try: + global tokenizer, MODEL_DLLM + input_ids = tokenizer.apply_chat_template( + data['messages'], + add_generation_prompt=True, + tokenize=True, + return_tensors='pt', + ) + batch_input_ids = torch.zeros( + (input_ids.shape[0], TASK_DLLM_MAX_LENGTH), + dtype=torch.long).fill_(TASK_DLLM_MASK_ID) + for s_k in range(input_ids.shape[0]): + batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k] + + _ = MODEL_DLLM.generate(batch_input_ids, + chat_uuid=chat_uuid, + gen_length=TASK_DLLM_GEN_LENGTH, + block_length=TASK_DLLM_BLOCK_LENGTH) + stream_put_api(chat_uuid, None) + + except Exception as err: + logging.error("Generate error: %s", err) + 
stream_put_api(chat_uuid, None) + + thread = threading.Thread(target=_generate, daemon=True) + thread.start() def get_answer_openai(chat_uuid: str, data: Dict) -> str: - ''' get answer openai ''' - logging.info('[%s] resp: %s', chat_uuid, data) + '''get answer openai''' + + generate_in_background(chat_uuid, data) + resp = {} + + while True: + curr_x = stream_get_api(chat_uuid) + + if curr_x is None: + time.sleep(TASK_STREAM_SLEEP_SECONDS) + continue + + if curr_x == []: + break + + x_str = tokenizer.decode(curr_x[0]) + text = x_str.split('ASSISTANT')[-1] + text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '') + text = text.replace('<|mask|>', ' ').rstrip(' ') + resp = { + 'id': + chat_uuid, + 'object': + 'chat.completion.chunk', + 'created': + time.time(), + 'model': + 'xyz-dllm', + 'choices': [{ + 'index': 0, + 'delta': { + 'role': 'assistant', + 'content': text, + 'resoning_content': None + }, + 'logprobs': None, + 'finish_reason': None, + }], + 'prompt_token_ids': + None + } + yield 'data: ' + json.dumps(resp, ensure_ascii=False) + '\n' + + logging.info('[%s] resp: %s', chat_uuid, resp) + + with stream_lock: + if chat_uuid in global_stream_dict: + del global_stream_dict[chat_uuid] def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: @@ -193,8 +316,7 @@ def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: @app.post('/v1/chat/completions') def chat_openai(request: CompletionsRequest): - ''' chat - ''' + ''' chat ''' chat_uuid = f'chat-{str(uuid.uuid4())}' data = request.dict() logging.info('[%s] req: %s', chat_uuid, data) @@ -206,9 +328,22 @@ def chat_openai(request: CompletionsRequest): ) +@app.post('/v1/stream_put') +def stream_put_endpoint(request: StreamRequest): + ''' stream_put_endpoint ''' + stream_put_api(request.chat_uuid, request.curr_x) + return {"status": "success", "message": "Data added to stream"} + + +@app.post('/v1/stream_get') +def stream_get_endpoint(request: StreamRequest): + ''' stream_get_endpoint ''' + curr_x = stream_get_api(request.chat_uuid) + return curr_x + + def mission(): - ''' api demo - ''' + ''' mission ''' global tokenizer, MODEL_DLLM tokenizer, MODEL_DLLM, _ = get_dllm() port = int(os.environ.get('TASK_SERVER_PORT', '40081')) diff --git a/python/dinfer/decoding/generate_uniform.py b/python/dinfer/decoding/generate_uniform.py index 9fb5eed..f06f18c 100644 --- a/python/dinfer/decoding/generate_uniform.py +++ b/python/dinfer/decoding/generate_uniform.py @@ -1,3 +1,6 @@ +import os +import requests + import torch import numpy as np import logging @@ -1041,7 +1044,7 @@ def cache_updates(self): return self.diff_iteration.cache_updates @ torch.no_grad() - def naive_batching_generate(self, prompt, gen_length=128, block_length=128): + def naive_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None): ''' Generate tokens with diffusion iterations block by block. ''' # recalculate gen length and init iteratory @@ -1075,7 +1078,13 @@ def naive_batching_generate(self, prompt, gen_length=128, block_length=128): # We need to reset iter_no at the beginning of generating a sequence. 
self.diff_iteration.iter_no = 0 + ip_port_url = 'http://0.0.0.0:' + os.environ.get('TASK_SERVER_PORT', '40081') + logger.info(f'[{chat_uuid}] ip_port_url: {ip_port_url}') for block_id, (block_loc, block) in enumerate(it): + if chat_uuid: + x_list = x.get_generated_tokens().tolist() + stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_list} + requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict) self.decoder.block_init(block, block_id) if self.backend == 'vllm': cross_block_attn_mask = bd_attn_mask[:,block_loc.start-block_length:block_loc.end, :block_loc.end] @@ -1086,10 +1095,16 @@ def naive_batching_generate(self, prompt, gen_length=128, block_length=128): if torch.all(decode_compl) and self.early_stop: break logger.info(f'The number of diffusion iterations: {self.num_forwards}') - return x.get_generated_tokens() + + x_generated_tokens = x.get_generated_tokens() + if chat_uuid: + stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_generated_tokens.tolist()} + requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict) + + return x_generated_tokens @ torch.no_grad() - def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128): + def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None): ''' Generate tokens with dynamic batching ''' assert self.cache_factory is not None diff --git a/python/dinfer/decoding/serving.py b/python/dinfer/decoding/serving.py index 056cfc3..c74114a 100644 --- a/python/dinfer/decoding/serving.py +++ b/python/dinfer/decoding/serving.py @@ -268,9 +268,9 @@ def generate(dllm, device, req_q, res_q): assert data == 'stop' break else: - input_ids, gen_len, block_len = data + input_ids, gen_len, block_len, chat_uuid = data input_ids = input_ids.to(device) - out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len) + out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len, chat_uuid=chat_uuid) num_forwards = dllm.num_forwards if res_q is not None: res_q.put((out, num_forwards)) @@ -485,18 +485,18 @@ def __init__(self): self.groups = [] self.need_response = 0 - def add_requests(self, reqs): + def add_requests(self, reqs, chat_uuid=None): prompts, gen_length, block_length = reqs self.need_response = 0 # assert len(self.groups) == prompts.shape[0], 'We cannot only use DP to support batch size > 1.' if len(self.groups) == prompts.shape[0]: for i, prompt in enumerate(prompts): - self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length)) + self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length, chat_uuid)) self.need_response = len(prompts) else: partial_data = torch.chunk(prompts, len(self.groups), dim=0) for i in range(len(partial_data)): - self.groups[i].add_request((partial_data[i], gen_length, block_length)) + self.groups[i].add_request((partial_data[i], gen_length, block_length, chat_uuid)) self.need_response = len(partial_data) @@ -580,7 +580,7 @@ def __init__(self, model, model_type='llada2', sample_params=None, server_port=N self.num_forwards = 0 self.timeout = timeout - def generate(self, prompts, gen_length=128, block_length=128): + def generate(self, prompts, gen_length=128, block_length=128, chat_uuid=None): ''' Generate tokens with diffusion iterations. Parameters: @@ -598,7 +598,7 @@ def generate(self, prompts, gen_length=128, block_length=128): The generation results of different lengths are padded with EOS. 
''' prompts = prompts.cpu() - handle.add_requests((prompts, gen_length, block_length)) + handle.add_requests((prompts, gen_length, block_length), chat_uuid=chat_uuid) rets = handle.get_responses(timeout=self.timeout) max_len = max([tensor.shape[1] for (tensor, _) in rets]) total_batch_size = sum([tensor.shape[0] for (tensor, _) in rets]) From c4d917932cbe49874b20344d209f86128e022d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=99=93=E9=98=B310335485?= Date: Tue, 13 Jan 2026 22:40:35 +0800 Subject: [PATCH 7/7] Support /v1/models (CORS) and delta stream in serving_dinfer_openai.py --- demo/serving_dinfer_openai.py | 86 +++++++++++++++++++++++++++++------ 1 file changed, 71 insertions(+), 15 deletions(-) diff --git a/demo/serving_dinfer_openai.py b/demo/serving_dinfer_openai.py index 0241675..248dc2e 100644 --- a/demo/serving_dinfer_openai.py +++ b/demo/serving_dinfer_openai.py @@ -17,6 +17,7 @@ import uuid from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, Response from pydantic import BaseModel, Field import torch @@ -39,12 +40,12 @@ def format(self, record): return super().format(record) except TypeError: if hasattr(record, 'args') and record.args: - record.msg = f"{record.msg} - Unformattable args: {record.args}" + record.msg = f'{record.msg} - Unformattable args: {record.args}' record.args = () return super().format(record) - task_logging_level = os.environ.get("TASK_LOGGING_LEVEL", "INFO") - log_level = logging.INFO if task_logging_level == "INFO" else logging.DEBUG + task_logging_level = os.environ.get('TASK_LOGGING_LEVEL', 'INFO') + log_level = logging.INFO if task_logging_level == 'INFO' else logging.DEBUG logging.basicConfig(filename=log_path, level=log_level, format=log_format, @@ -79,10 +80,18 @@ class StreamRequest(BaseModel): redoc_url=None, docs=None, ) +app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_methods=['*'], + allow_headers=['*'], +) STATUS_OK = 1 STATUS_ERR = 0 SPECIAL_MODEL_DIR = os.environ.get('SPECIAL_MODEL_DIR') +TASK_DLLM_SERVE_MODEL_NAME = os.environ.get('TASK_DLLM_SERVE_MODEL_NAME', + 'xyz-dllm') TASK_DLLM_NUM_GPUS = int(os.environ.get('TASK_DLLM_NUM_GPUS', 1)) TASK_DLLM_GEN_LENGTH = int(os.environ.get('TASK_DLLM_GEN_LENGTH', 512)) @@ -199,7 +208,7 @@ def _generate(): stream_put_api(chat_uuid, None) except Exception as err: - logging.error("Generate error: %s", err) + logging.error('Generate error: %s', err) stream_put_api(chat_uuid, None) thread = threading.Thread(target=_generate, daemon=True) @@ -212,6 +221,7 @@ def get_answer_openai(chat_uuid: str, data: Dict) -> str: generate_in_background(chat_uuid, data) resp = {} + pre_text = '' while True: curr_x = stream_get_api(chat_uuid) @@ -234,12 +244,12 @@ def get_answer_openai(chat_uuid: str, data: Dict) -> str: 'created': time.time(), 'model': - 'xyz-dllm', + TASK_DLLM_SERVE_MODEL_NAME, 'choices': [{ 'index': 0, 'delta': { 'role': 'assistant', - 'content': text, + 'content': text[len(pre_text):], 'resoning_content': None }, 'logprobs': None, @@ -248,14 +258,18 @@ def get_answer_openai(chat_uuid: str, data: Dict) -> str: 'prompt_token_ids': None } - yield 'data: ' + json.dumps(resp, ensure_ascii=False) + '\n' + pre_text = text + yield 'data: ' + json.dumps(resp, ensure_ascii=False) + '\n\n' - logging.info('[%s] resp: %s', chat_uuid, resp) + yield 'data: [DONE]' with stream_lock: if chat_uuid in global_stream_dict: del global_stream_dict[chat_uuid] + resp['choices'][0]['delta']['content'] = pre_text + 
logging.info('[%s] resp: %s', chat_uuid, resp) + def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: ''' get answer openai (no stream) ''' @@ -290,7 +304,7 @@ def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: 'created': time.time(), 'model': - 'xyz-dllm', + TASK_DLLM_SERVE_MODEL_NAME, 'choices': [{ 'index': 0, 'message': { @@ -303,11 +317,11 @@ def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str: }], 'prompt_token_ids': None, - "usage": { - "prompt_tokens": input_ids.shape[-1], - "total_tokens": len(x_tokens_final[0]), - "completion_tokens": len(x_tokens_final[0]) - input_ids.shape[-1], - "prompt_tokens_details": None + 'usage': { + 'prompt_tokens': input_ids.shape[-1], + 'total_tokens': len(x_tokens_final[0]), + 'completion_tokens': len(x_tokens_final[0]) - input_ids.shape[-1], + 'prompt_tokens_details': None } } logging.info('[%s] resp: %s', chat_uuid, resp) @@ -332,7 +346,7 @@ def chat_openai(request: CompletionsRequest): def stream_put_endpoint(request: StreamRequest): ''' stream_put_endpoint ''' stream_put_api(request.chat_uuid, request.curr_x) - return {"status": "success", "message": "Data added to stream"} + return {'status': 'success', 'message': 'Data added to stream'} @app.post('/v1/stream_get') @@ -342,6 +356,48 @@ def stream_get_endpoint(request: StreamRequest): return curr_x +@app.api_route('/v1/models', methods=['GET', 'POST', 'OPTIONS']) +def get_models(): + ''' get_models + ''' + created = time.time() + resp = { + 'object': + 'list', + 'data': [{ + 'id': + TASK_DLLM_SERVE_MODEL_NAME, + 'object': + 'model', + 'created': + created, + 'owned_by': + 'dinfer', + 'root': + os.environ.get('SPECIAL_MODEL_DIR'), + 'parent': + None, + 'max_model_len': + TASK_DLLM_MAX_LENGTH, + 'permission': [{ + 'id': 'modelperm-xyz-dllm', + 'object': 'model_permission', + 'created': created, + 'allow_create_engine': False, + 'allow_sampling': True, + 'allow_logprobs': True, + 'allow_search_indices': False, + 'allow_view': True, + 'allow_fine_tuning': False, + 'organization': '*', + 'group': None, + 'is_blocking': False + }] + }] + } + return resp + + def mission(): ''' mission ''' global tokenizer, MODEL_DLLM
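
---

For reference, below is a minimal Python client sketch for the server built up in the patches above. It is not part of the patch series: it assumes the demo is reachable at `http://0.0.0.0:48000/v1` as in the README examples (the actual port comes from `TASK_SERVER_PORT`), that the `Authorization` header is accepted but not validated by this demo, and that streamed chunks arrive as `data: {...}` lines terminated by `data: [DONE]` as implemented in PATCH 6/7. The file name `client_dinfer_openai.py` is illustrative.

```python
'''
Hypothetical client sketch for the demo server above (not part of the patches).
Assumes the server runs at http://0.0.0.0:48000 (i.e. TASK_SERVER_PORT=48000).
'''
import json

import requests

BASE_URL = 'http://0.0.0.0:48000/v1'  # adjust to your TASK_SERVER_PORT
HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR_API_KEY',  # accepted but not checked by the demo
}
MESSAGES = [{'role': 'user', 'content': '你好, 我是小明'}]


def chat_no_stream():
    ''' Single-shot completion, mirrors the "stream": false curl example. '''
    resp = requests.post(f'{BASE_URL}/chat/completions',
                         headers=HEADERS,
                         json={'messages': MESSAGES, 'stream': False},
                         timeout=600)
    data = json.loads(resp.text)
    print(data['choices'][0]['message']['content'])


def chat_stream():
    ''' Incremental completion, mirrors the "stream": true curl example. '''
    resp = requests.post(f'{BASE_URL}/chat/completions',
                         headers=HEADERS,
                         json={'messages': MESSAGES, 'stream': True},
                         stream=True,
                         timeout=600)
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue  # skip the blank separator lines between chunks
        payload = line[len('data: '):]
        if payload == '[DONE]':
            break
        chunk = json.loads(payload)
        delta = chunk['choices'][0]['delta'].get('content', '')
        print(delta, end='', flush=True)
    print()


if __name__ == '__main__':
    chat_no_stream()
    chat_stream()
```

Run it with `python3 client_dinfer_openai.py` while `serving_dinfer_openai.py` is up: the non-streaming call prints the full reply at once, while the streaming call prints the incremental deltas as the diffusion blocks are decoded. Open WebUI connects to the same server through the `/v1/models` endpoint added in PATCH 7.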