diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md
new file mode 100644
index 0000000..4a9aa5c
--- /dev/null
+++ b/demo/serving_dinfer_openai.md
@@ -0,0 +1,30 @@
+# Demo of dinfer OpenAI-compatible serving
+
+## Serving
+
+```bash
+export SPECIAL_MODEL_DIR=/models/LLaDA2.0-mini--572899f-C8 # Download from https://www.modelscope.cn/models/inclusionAI/LLaDA2.0-mini
+export TASK_SERVER_PORT=48000 # Serve on the port used by the client examples below (the script defaults to 40081)
+export TASK_DLLM_BATCH_SIZE=2
+python3 serving_dinfer_openai.py
+```
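+
+Once the server is up, you can sanity-check it via the model listing route exposed by `serving_dinfer_openai.py` (the port assumes `TASK_SERVER_PORT=48000` from the step above):
+
+```bash
+curl http://0.0.0.0:48000/v1/models
+```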
+
+## Client
+
+```bash
+date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date
+
+date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": true}' http://0.0.0.0:48000/v1/chat/completions && date
+```
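+
+To pull just the assistant text out of the non-streaming JSON response, you can pipe it through `jq` (a sketch that assumes `jq` is installed; the path matches the `choices[0].message.content` field built by `serving_dinfer_openai.py`):
+
+```bash
+curl -s -X POST -H "Content-Type: application/json" \
+  -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' \
+  http://0.0.0.0:48000/v1/chat/completions | jq -r '.choices[0].message.content'
+```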
+
+## Web demo
+
+```bash
+date && docker pull ghcr.io/open-webui/open-webui:main && date
+
+mkdir data-open-webui
+cd data-open-webui
+
+date && docker run -d -p 60111:8080 -v $PWD:/app/backend/data --name open-webui ghcr.io/open-webui/open-webui:main && date
+```
+
+Configure your open-webui instance to use [http://0.0.0.0:48000/v1](http://0.0.0.0:48000/v1) as the OpenAI API base URL, as shown in [this guide](https://developer.volcengine.com/articles/7533551308616237092#heading10). Note that from inside the open-webui container you may need the host machine's IP address instead of `0.0.0.0`.
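+
+Alternatively, the endpoint can be handed to open-webui when the container starts, via environment variables (a sketch assuming the `OPENAI_API_BASE_URL` / `OPENAI_API_KEY` variables supported by open-webui images; `<host-ip>` is a placeholder for the host machine's address):
+
+```bash
+docker run -d -p 60111:8080 \
+  -e OPENAI_API_BASE_URL=http://<host-ip>:48000/v1 \
+  -e OPENAI_API_KEY=YOUR_API_KEY \
+  -v $PWD:/app/backend/data --name open-webui ghcr.io/open-webui/open-webui:main
+```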
diff --git a/demo/serving_dinfer_openai.py b/demo/serving_dinfer_openai.py
new file mode 100644
index 0000000..248dc2e
--- /dev/null
+++ b/demo/serving_dinfer_openai.py
@@ -0,0 +1,410 @@
+'''
+A FastAPI demo that serves LLaDA through dinfer and exposes an
+OpenAI-compatible chat completions API (streaming and non-streaming),
+plus the internal /v1/stream_put and /v1/stream_get endpoints used to
+relay partial generations from the dinfer worker back to the server.
+'''
+
+# pylint: disable=import-error, no-name-in-module
+# pylint: disable=global-statement, global-variable-not-assigned
+# pylint: disable=too-few-public-methods
+# pylint: disable=broad-exception-caught
+
+import json
+import logging
+import os
+from queue import Queue
+import threading
+import time
+from typing import Dict, List, Optional
+import uuid
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, Response
+from pydantic import BaseModel, Field
+import torch
+from transformers import AutoTokenizer
+import uvicorn
+
+
+def init_logger(log_path: Optional[str] = None):
+ ''' init logger with exception handling '''
+ log_format = '[%(asctime)s](%(created).17g) - %(levelname)s - ' \
+ '|%(pathname)s|%(funcName)s|%(lineno)d| - %(message)s'
+ date_format = '%Y-%m-%d_%H:%M:%S'
+
+ class CustomFormatter(logging.Formatter):
+ ''' CustomFormatter '''
+
+ def format(self, record):
+ ''' format '''
+ try:
+ return super().format(record)
+ except TypeError:
+ if hasattr(record, 'args') and record.args:
+ record.msg = f'{record.msg} - Unformattable args: {record.args}'
+ record.args = ()
+ return super().format(record)
+
+ task_logging_level = os.environ.get('TASK_LOGGING_LEVEL', 'INFO')
+ log_level = logging.INFO if task_logging_level == 'INFO' else logging.DEBUG
+ logging.basicConfig(filename=log_path,
+ level=log_level,
+ format=log_format,
+ datefmt=date_format)
+
+ for handler in logging.root.handlers:
+ handler.setFormatter(CustomFormatter(log_format, date_format))
+
+ return logging
+
+
+logging = init_logger()
+
+from dinfer import DiffusionLLMServing, SamplingParams, ThresholdParallelDecoder # pylint: disable=wrong-import-position
+
+
+class CompletionsRequest(BaseModel):
+ ''' Completions Request
+ '''
+ messages: List[Dict] = Field(title='messages')
+    stream: bool = Field(default=False, title='stream')  # optional, as in the OpenAI API
+
+
+class StreamRequest(BaseModel):
+ '''Stream Request'''
+ chat_uuid: str = Field(title='chat_uuid')
+    curr_x: Optional[List] = Field(default=None, title='curr_x')
+
+
+app = FastAPI(
+ title='xyz dllm serving',
+ redoc_url=None,
+    docs_url=None,
+)
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=['*'],
+ allow_methods=['*'],
+ allow_headers=['*'],
+)
+STATUS_OK = 1
+STATUS_ERR = 0
+
+SPECIAL_MODEL_DIR = os.environ.get('SPECIAL_MODEL_DIR')
+TASK_DLLM_SERVE_MODEL_NAME = os.environ.get('TASK_DLLM_SERVE_MODEL_NAME',
+ 'xyz-dllm')
+TASK_DLLM_NUM_GPUS = int(os.environ.get('TASK_DLLM_NUM_GPUS', 1))
+
+TASK_DLLM_GEN_LENGTH = int(os.environ.get('TASK_DLLM_GEN_LENGTH', 512))
+TASK_DLLM_BLOCK_LENGTH = int(os.environ.get('TASK_DLLM_BLOCK_LENGTH', 32))
+
+TASK_DLLM_MAX_LENGTH = int(os.environ.get('TASK_DLLM_MAX_LENGTH', 4096))
+TASK_DLLM_BATCH_SIZE = int(os.environ.get('TASK_DLLM_BATCH_SIZE', 2))
+TASK_DLLM_TEMPERATURE = float(os.environ.get('TASK_DLLM_TEMPERATURE', 0.0))
+TASK_DLLM_THRESHOLD = float(os.environ.get('TASK_DLLM_THRESHOLD', 0.9))
+
+TASK_DLLM_MASK_ID = int(os.environ.get('TASK_DLLM_MASK_ID', 156895))
+TASK_DLLM_EOS_ID = int(os.environ.get('TASK_DLLM_EOS_ID', 156892))
+
+TASK_STREAM_SLEEP_SECONDS = float(
+ os.environ.get('TASK_STREAM_SLEEP_SECONDS', 0.005))
+
+
+def get_dllm():
+ ''' get dllm '''
+ dllm_tokenizer = AutoTokenizer.from_pretrained(SPECIAL_MODEL_DIR,
+ trust_remote_code=True)
+ decoder = ThresholdParallelDecoder(temperature=TASK_DLLM_TEMPERATURE,
+ threshold=TASK_DLLM_THRESHOLD,
+ mask_id=TASK_DLLM_MASK_ID,
+ eos_id=TASK_DLLM_EOS_ID)
+ sample_params = SamplingParams(threshold=TASK_DLLM_THRESHOLD,
+ cache='prefix',
+ temperature=0.,
+ early_stop=True,
+ cont_weight=0,
+ prefix_look=0,
+ after_look=0,
+ warmup_steps=0,
+ enable_torch_compile=True,
+ mask_id=TASK_DLLM_MASK_ID,
+ eos_id=TASK_DLLM_EOS_ID,
+ parallel_decoding='threshold',
+ use_credit=False,
+ use_bd=True,
+ max_length=TASK_DLLM_MAX_LENGTH,
+ ep_size=1,
+ batch_size=TASK_DLLM_BATCH_SIZE,
+ mini_batch_size=TASK_DLLM_BATCH_SIZE,
+ use_naive_batching=True)
+ dllm_server = DiffusionLLMServing(SPECIAL_MODEL_DIR,
+ model_type='llada2-mini',
+ sample_params=sample_params,
+ server_port=40570,
+ num_gpus=TASK_DLLM_NUM_GPUS,
+ dp_size=1,
+ tpep_size=TASK_DLLM_NUM_GPUS,
+ backend='sglang')
+ return dllm_tokenizer, dllm_server, decoder
+
+
+tokenizer, MODEL_DLLM = None, None
+global_stream_dict = {}
+stream_lock = threading.Lock()
+
+
+def stream_put_api(chat_uuid: str, curr_x: Optional[List]):
+ ''' stream_put_api '''
+ with stream_lock:
+ if chat_uuid not in global_stream_dict:
+ global_stream_dict[chat_uuid] = Queue()
+
+ if curr_x is None:
+ global_stream_dict[chat_uuid].put([])
+ else:
+ global_stream_dict[chat_uuid].put(curr_x)
+
+
+def stream_get_api(chat_uuid: str) -> Optional[List]:
+ ''' stream_get_api '''
+ try:
+ with stream_lock:
+ if chat_uuid not in global_stream_dict:
+ return None
+
+ queue = global_stream_dict[chat_uuid]
+            result = queue.get_nowait()  # non-blocking; queue.Empty is caught below
+
+ if result == []:
+ del global_stream_dict[chat_uuid]
+
+ return result
+
+ except Exception:
+ return None
+
+
+def generate_in_background(chat_uuid: str, data: Dict):
+ ''' generate_in_background '''
+
+ def _generate():
+ try:
+ global tokenizer, MODEL_DLLM
+ input_ids = tokenizer.apply_chat_template(
+ data['messages'],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_tensors='pt',
+ )
+ batch_input_ids = torch.zeros(
+ (input_ids.shape[0], TASK_DLLM_MAX_LENGTH),
+ dtype=torch.long).fill_(TASK_DLLM_MASK_ID)
+ for s_k in range(input_ids.shape[0]):
+ batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k]
+
+ _ = MODEL_DLLM.generate(batch_input_ids,
+ chat_uuid=chat_uuid,
+ gen_length=TASK_DLLM_GEN_LENGTH,
+ block_length=TASK_DLLM_BLOCK_LENGTH)
+ stream_put_api(chat_uuid, None)
+
+ except Exception as err:
+ logging.error('Generate error: %s', err)
+ stream_put_api(chat_uuid, None)
+
+ thread = threading.Thread(target=_generate, daemon=True)
+ thread.start()
+
+
+def get_answer_openai(chat_uuid: str, data: Dict):
+    '''Yield OpenAI-style chat.completion.chunk lines in SSE format (generator).'''
+
+ generate_in_background(chat_uuid, data)
+ resp = {}
+
+ pre_text = ''
+ while True:
+ curr_x = stream_get_api(chat_uuid)
+
+ if curr_x is None:
+ time.sleep(TASK_STREAM_SLEEP_SECONDS)
+ continue
+
+ if curr_x == []:
+ break
+
+ x_str = tokenizer.decode(curr_x[0])
+ text = x_str.split('ASSISTANT')[-1]
+ text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
+ text = text.replace('<|mask|>', ' ').rstrip(' ')
+        resp = {
+            'id': chat_uuid,
+            'object': 'chat.completion.chunk',
+            'created': int(time.time()),
+            'model': TASK_DLLM_SERVE_MODEL_NAME,
+            'choices': [{
+                'index': 0,
+                'delta': {
+                    'role': 'assistant',
+                    'content': text[len(pre_text):],
+                    'reasoning_content': None
+                },
+                'logprobs': None,
+                'finish_reason': None,
+            }],
+            'prompt_token_ids': None
+        }
+        pre_text = text
+        yield 'data: ' + json.dumps(resp, ensure_ascii=False) + '\n\n'
+
+    yield 'data: [DONE]\n\n'
+
+ with stream_lock:
+ if chat_uuid in global_stream_dict:
+ del global_stream_dict[chat_uuid]
+
+    if resp:
+        resp['choices'][0]['delta']['content'] = pre_text
+    logging.info('[%s] resp: %s', chat_uuid, resp)
+
+
+def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str:
+ ''' get answer openai (no stream) '''
+ global tokenizer, MODEL_DLLM
+ input_ids = tokenizer.apply_chat_template(
+ data['messages'],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_tensors='pt',
+ )
+ batch_input_ids = torch.zeros((input_ids.shape[0], TASK_DLLM_MAX_LENGTH),
+ dtype=torch.long).fill_(TASK_DLLM_MASK_ID)
+ for s_k in range(input_ids.shape[0]):
+ batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k]
+
+ x_tokens_yield = MODEL_DLLM.generate(batch_input_ids,
+ gen_length=TASK_DLLM_GEN_LENGTH,
+ block_length=TASK_DLLM_BLOCK_LENGTH)
+
+ resp = {}
+ x_tokens_final = x_tokens_yield
+
+ x_str = tokenizer.decode(x_tokens_final[0])
+ text = x_str.split('ASSISTANT')[-1]
+ text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
+ text = text.replace('<|mask|>', ' ')
+    resp = {
+        'id': chat_uuid,
+        'object': 'chat.completion',
+        'created': int(time.time()),
+        'model': TASK_DLLM_SERVE_MODEL_NAME,
+        'choices': [{
+            'index': 0,
+            'message': {
+                'role': 'assistant',
+                'content': text,
+                'reasoning_content': None
+            },
+            'logprobs': None,
+            'finish_reason': 'stop',
+        }],
+        'prompt_token_ids': None,
+        'usage': {
+            'prompt_tokens': input_ids.shape[-1],
+            'total_tokens': len(x_tokens_final[0]),
+            'completion_tokens': len(x_tokens_final[0]) - input_ids.shape[-1],
+            'prompt_tokens_details': None
+        }
+    }
+ logging.info('[%s] resp: %s', chat_uuid, resp)
+ return json.dumps(resp, ensure_ascii=False) + '\n'
+
+
+@app.post('/v1/chat/completions')
+def chat_openai(request: CompletionsRequest):
+ ''' chat '''
+ chat_uuid = f'chat-{str(uuid.uuid4())}'
+ data = request.dict()
+ logging.info('[%s] req: %s', chat_uuid, data)
+ if request.stream:
+        return StreamingResponse(get_answer_openai(chat_uuid, data),
+                                 media_type='text/event-stream')
+ return Response(
+ content=get_answer_openai_no_stream(chat_uuid, data),
+        media_type='application/json',
+ )
+
+
+@app.post('/v1/stream_put')
+def stream_put_endpoint(request: StreamRequest):
+ ''' stream_put_endpoint '''
+ stream_put_api(request.chat_uuid, request.curr_x)
+ return {'status': 'success', 'message': 'Data added to stream'}
+
+
+@app.post('/v1/stream_get')
+def stream_get_endpoint(request: StreamRequest):
+ ''' stream_get_endpoint '''
+ curr_x = stream_get_api(request.chat_uuid)
+ return curr_x
+
+
+@app.api_route('/v1/models', methods=['GET', 'POST', 'OPTIONS'])
+def get_models():
+ ''' get_models
+ '''
+    created = int(time.time())  # OpenAI-style integer Unix timestamp
+    resp = {
+        'object': 'list',
+        'data': [{
+            'id': TASK_DLLM_SERVE_MODEL_NAME,
+            'object': 'model',
+            'created': created,
+            'owned_by': 'dinfer',
+            'root': SPECIAL_MODEL_DIR,
+            'parent': None,
+            'max_model_len': TASK_DLLM_MAX_LENGTH,
+            'permission': [{
+                'id': 'modelperm-xyz-dllm',
+                'object': 'model_permission',
+                'created': created,
+                'allow_create_engine': False,
+                'allow_sampling': True,
+                'allow_logprobs': True,
+                'allow_search_indices': False,
+                'allow_view': True,
+                'allow_fine_tuning': False,
+                'organization': '*',
+                'group': None,
+                'is_blocking': False
+            }]
+        }]
+    }
+ return resp
+
+
+def mission():
+ ''' mission '''
+ global tokenizer, MODEL_DLLM
+ tokenizer, MODEL_DLLM, _ = get_dllm()
+ port = int(os.environ.get('TASK_SERVER_PORT', '40081'))
+ uvicorn.run(app, host='0.0.0.0', port=port)
+
+
+if __name__ == '__main__':
+ mission()
diff --git a/python/dinfer/decoding/generate_uniform.py b/python/dinfer/decoding/generate_uniform.py
index 9fb5eed..f06f18c 100644
--- a/python/dinfer/decoding/generate_uniform.py
+++ b/python/dinfer/decoding/generate_uniform.py
@@ -1,3 +1,6 @@
+import os
+import requests
+
import torch
import numpy as np
import logging
@@ -1041,7 +1044,7 @@ def cache_updates(self):
return self.diff_iteration.cache_updates
@ torch.no_grad()
- def naive_batching_generate(self, prompt, gen_length=128, block_length=128):
+ def naive_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None):
''' Generate tokens with diffusion iterations block by block.
'''
# recalculate gen length and init iteratory
@@ -1075,7 +1078,13 @@ def naive_batching_generate(self, prompt, gen_length=128, block_length=128):
# We need to reset iter_no at the beginning of generating a sequence.
self.diff_iteration.iter_no = 0
+        ip_port_url = 'http://0.0.0.0:' + os.environ.get('TASK_SERVER_PORT', '40081')
+        logger.info(f'[{chat_uuid}] ip_port_url: {ip_port_url}')
for block_id, (block_loc, block) in enumerate(it):
+ if chat_uuid:
+ x_list = x.get_generated_tokens().tolist()
+ stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_list}
+ requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict)
self.decoder.block_init(block, block_id)
if self.backend == 'vllm':
cross_block_attn_mask = bd_attn_mask[:,block_loc.start-block_length:block_loc.end, :block_loc.end]
@@ -1086,10 +1095,16 @@ def naive_batching_generate(self, prompt, gen_length=128, block_length=128):
if torch.all(decode_compl) and self.early_stop:
break
logger.info(f'The number of diffusion iterations: {self.num_forwards}')
-        return x.get_generated_tokens()
+
+        x_generated_tokens = x.get_generated_tokens()
+        if chat_uuid:
+            stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_generated_tokens.tolist()}
+            requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict)
+
+        return x_generated_tokens
@ torch.no_grad()
- def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128):
+ def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None):
''' Generate tokens with dynamic batching
'''
assert self.cache_factory is not None
diff --git a/python/dinfer/decoding/serving.py b/python/dinfer/decoding/serving.py
index 056cfc3..c74114a 100644
--- a/python/dinfer/decoding/serving.py
+++ b/python/dinfer/decoding/serving.py
@@ -268,9 +268,9 @@ def generate(dllm, device, req_q, res_q):
assert data == 'stop'
break
else:
- input_ids, gen_len, block_len = data
+ input_ids, gen_len, block_len, chat_uuid = data
input_ids = input_ids.to(device)
- out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len)
+ out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len, chat_uuid=chat_uuid)
num_forwards = dllm.num_forwards
if res_q is not None:
res_q.put((out, num_forwards))
@@ -485,18 +485,18 @@ def __init__(self):
self.groups = []
self.need_response = 0
- def add_requests(self, reqs):
+ def add_requests(self, reqs, chat_uuid=None):
prompts, gen_length, block_length = reqs
self.need_response = 0
# assert len(self.groups) == prompts.shape[0], 'We cannot only use DP to support batch size > 1.'
if len(self.groups) == prompts.shape[0]:
for i, prompt in enumerate(prompts):
- self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length))
+ self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length, chat_uuid))
self.need_response = len(prompts)
else:
partial_data = torch.chunk(prompts, len(self.groups), dim=0)
for i in range(len(partial_data)):
- self.groups[i].add_request((partial_data[i], gen_length, block_length))
+ self.groups[i].add_request((partial_data[i], gen_length, block_length, chat_uuid))
self.need_response = len(partial_data)
@@ -580,7 +580,7 @@ def __init__(self, model, model_type='llada2', sample_params=None, server_port=N
self.num_forwards = 0
self.timeout = timeout
- def generate(self, prompts, gen_length=128, block_length=128):
+ def generate(self, prompts, gen_length=128, block_length=128, chat_uuid=None):
''' Generate tokens with diffusion iterations.
Parameters:
@@ -598,7 +598,7 @@ def generate(self, prompts, gen_length=128, block_length=128):
The generation results of different lengths are padded with EOS.
'''
prompts = prompts.cpu()
- handle.add_requests((prompts, gen_length, block_length))
+ handle.add_requests((prompts, gen_length, block_length), chat_uuid=chat_uuid)
rets = handle.get_responses(timeout=self.timeout)
max_len = max([tensor.shape[1] for (tensor, _) in rets])
total_batch_size = sum([tensor.shape[0] for (tensor, _) in rets])