diff --git a/demo/serving_dinfer_openai.md b/demo/serving_dinfer_openai.md
new file mode 100644
index 0000000..4a9aa5c
--- /dev/null
+++ b/demo/serving_dinfer_openai.md
@@ -0,0 +1,30 @@
+# Demo of dinfer openai serving
+
+## Serving
+
+```bash
+export SPECIAL_MODEL_DIR=/models/LLaDA2.0-mini--572899f-C8 # Download from https://www.modelscope.cn/models/inclusionAI/LLaDA2.0-mini
+export TASK_DLLM_BATCH_SIZE=2
+python3 serving_dinfer_openai.py
+```
+
+## Client
+
+```bash
+date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": false}' http://0.0.0.0:48000/v1/chat/completions && date
+
+date && curl -X POST -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -N -d '{"messages": [{"role": "user", "content": "你好, 我是小明"}], "stream": true}' http://0.0.0.0:48000/v1/chat/completions && date
+```
+
+## Web demo
+
+```bash
+date && docker pull ghcr.io/open-webui/open-webui:main && date
+
+mkdir data-open-webui
+cd data-open-webui
+
+date && docker run -d -p 60111:8080 -v $PWD:/app/backend/data --name open-webui ghcr.io/open-webui/open-webui:main && date
+```
+
+Configure your open-webui with [http://0.0.0.0:48000/v1](http://0.0.0.0:48000/v1) like [this](https://developer.volcengine.com/articles/7533551308616237092#heading10)
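+
+## Python client example
+
+A minimal Python sketch equivalent to the `curl` commands above, assuming the `requests` package is installed and the server is exposed on port 48000 as in the examples above. It reads the SSE chunks emitted by `/v1/chat/completions` when `stream` is true:
+
+```python
+import json
+
+import requests
+
+payload = {
+    'messages': [{'role': 'user', 'content': '你好, 我是小明'}],
+    'stream': True,  # required field of the server's CompletionsRequest
+}
+headers = {'Authorization': 'Bearer YOUR_API_KEY'}
+
+# stream=True keeps the connection open so SSE chunks can be read as they arrive
+with requests.post('http://0.0.0.0:48000/v1/chat/completions',
+                   json=payload, headers=headers, stream=True) as resp:
+    for raw in resp.iter_lines():
+        if not raw:
+            continue
+        # each event looks like: data: {"choices": [{"delta": {"content": "..."}}], ...}
+        line = raw.decode('utf-8')
+        if not line.startswith('data: ') or line == 'data: [DONE]':
+            continue
+        chunk = json.loads(line[len('data: '):])
+        print(chunk['choices'][0]['delta']['content'], end='', flush=True)
+print()
+```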
diff --git a/demo/serving_dinfer_openai.py b/demo/serving_dinfer_openai.py
new file mode 100644
index 0000000..248dc2e
--- /dev/null
+++ b/demo/serving_dinfer_openai.py
@@ -0,0 +1,410 @@
+'''
+A FastAPI-based OpenAI-compatible serving demo of LLaDA using dinfer.
+'''
+
+# pylint: disable=import-error, no-name-in-module
+# pylint: disable=global-statement, global-variable-not-assigned
+# pylint: disable=too-few-public-methods
+# pylint: disable=broad-exception-caught
+
+import json
+import logging
+import os
+from queue import Queue
+import threading
+import time
+from typing import Dict, List, Optional
+import uuid
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, Response
+from pydantic import BaseModel, Field
+import torch
+from transformers import AutoTokenizer
+import uvicorn
+
+
+def init_logger(log_path: str = None):
+    ''' init logger with exception handling '''
+    log_format = '[%(asctime)s](%(created).17g) - %(levelname)s - ' \
+                 '|%(pathname)s|%(funcName)s|%(lineno)d| - %(message)s'
+    date_format = '%Y-%m-%d_%H:%M:%S'
+
+    class CustomFormatter(logging.Formatter):
+        ''' CustomFormatter '''
+
+        def format(self, record):
+            ''' format '''
+            try:
+                return super().format(record)
+            except TypeError:
+                if hasattr(record, 'args') and record.args:
+                    record.msg = f'{record.msg} - Unformattable args: {record.args}'
+                    record.args = ()
+                return super().format(record)
+
+    task_logging_level = os.environ.get('TASK_LOGGING_LEVEL', 'INFO')
+    log_level = logging.INFO if task_logging_level == 'INFO' else logging.DEBUG
+    logging.basicConfig(filename=log_path,
+                        level=log_level,
+                        format=log_format,
+                        datefmt=date_format)
+
+    for handler in logging.root.handlers:
+        handler.setFormatter(CustomFormatter(log_format, date_format))
+
+    return logging
+
+
+logging = init_logger()
+
+from dinfer import DiffusionLLMServing, SamplingParams, ThresholdParallelDecoder  # pylint: disable=wrong-import-position
+
+
+class CompletionsRequest(BaseModel):
+    ''' Completions Request
+    '''
+    messages: List[Dict] = Field(title='messages')
+    stream: bool = Field(title='stream')
+
+
+class StreamRequest(BaseModel):
+    '''Stream Request'''
+    chat_uuid: str = Field(title='chat_uuid')
+    curr_x: Optional[List] = Field(title='curr_x')
+
+
+app = FastAPI(
+    title='xyz dllm serving',
+    redoc_url=None,
+    docs_url=None,
+)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=['*'],
+    allow_methods=['*'],
+    allow_headers=['*'],
+)
+STATUS_OK = 1
+STATUS_ERR = 0
+
+SPECIAL_MODEL_DIR = os.environ.get('SPECIAL_MODEL_DIR')
+TASK_DLLM_SERVE_MODEL_NAME = os.environ.get('TASK_DLLM_SERVE_MODEL_NAME', 'xyz-dllm')
+TASK_DLLM_NUM_GPUS = int(os.environ.get('TASK_DLLM_NUM_GPUS', 1))
+
+TASK_DLLM_GEN_LENGTH = int(os.environ.get('TASK_DLLM_GEN_LENGTH', 512))
+TASK_DLLM_BLOCK_LENGTH = int(os.environ.get('TASK_DLLM_BLOCK_LENGTH', 32))
+
+TASK_DLLM_MAX_LENGTH = int(os.environ.get('TASK_DLLM_MAX_LENGTH', 4096))
+TASK_DLLM_BATCH_SIZE = int(os.environ.get('TASK_DLLM_BATCH_SIZE', 2))
+TASK_DLLM_TEMPERATURE = float(os.environ.get('TASK_DLLM_TEMPERATURE', 0.0))
+TASK_DLLM_THRESHOLD = float(os.environ.get('TASK_DLLM_THRESHOLD', 0.9))
+
+TASK_DLLM_MASK_ID = int(os.environ.get('TASK_DLLM_MASK_ID', 156895))
+TASK_DLLM_EOS_ID = int(os.environ.get('TASK_DLLM_EOS_ID', 156892))
+
+TASK_STREAM_SLEEP_SECONDS = float(
+    os.environ.get('TASK_STREAM_SLEEP_SECONDS', 0.005))
+
+
+def get_dllm():
+    ''' get dllm '''
+    dllm_tokenizer = AutoTokenizer.from_pretrained(SPECIAL_MODEL_DIR,
+                                                   trust_remote_code=True)
+    decoder = ThresholdParallelDecoder(temperature=TASK_DLLM_TEMPERATURE,
+                                       threshold=TASK_DLLM_THRESHOLD,
+                                       mask_id=TASK_DLLM_MASK_ID,
+                                       eos_id=TASK_DLLM_EOS_ID)
+    sample_params = SamplingParams(threshold=TASK_DLLM_THRESHOLD,
+                                   cache='prefix',
+                                   temperature=0.,
+                                   early_stop=True,
+                                   cont_weight=0,
+                                   prefix_look=0,
+                                   after_look=0,
+                                   warmup_steps=0,
+                                   enable_torch_compile=True,
+                                   mask_id=TASK_DLLM_MASK_ID,
+                                   eos_id=TASK_DLLM_EOS_ID,
+                                   parallel_decoding='threshold',
+                                   use_credit=False,
+                                   use_bd=True,
+                                   max_length=TASK_DLLM_MAX_LENGTH,
+                                   ep_size=1,
+                                   batch_size=TASK_DLLM_BATCH_SIZE,
+                                   mini_batch_size=TASK_DLLM_BATCH_SIZE,
+                                   use_naive_batching=True)
+    dllm_server = DiffusionLLMServing(SPECIAL_MODEL_DIR,
+                                      model_type='llada2-mini',
+                                      sample_params=sample_params,
+                                      server_port=40570,
+                                      num_gpus=TASK_DLLM_NUM_GPUS,
+                                      dp_size=1,
+                                      tpep_size=TASK_DLLM_NUM_GPUS,
+                                      backend='sglang')
+    return dllm_tokenizer, dllm_server, decoder
+
+
+tokenizer, MODEL_DLLM = None, None
+global_stream_dict = {}
+stream_lock = threading.Lock()
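+
+
+# Streaming hand-off between the dinfer worker and this FastAPI frontend:
+# the patched naive_batching_generate() in dinfer POSTs the partially decoded
+# token ids for each block to /v1/stream_put together with the request's
+# chat_uuid. stream_put_api() buffers them in a per-chat Queue, and
+# get_answer_openai() drains that queue to emit OpenAI-style SSE chunks.
+# An empty list is the end-of-stream sentinel (stream_put_api(chat_uuid, None)).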
+def stream_put_api(chat_uuid: str, curr_x: Optional[List]):
+    ''' stream_put_api '''
+    with stream_lock:
+        if chat_uuid not in global_stream_dict:
+            global_stream_dict[chat_uuid] = Queue()
+
+        if curr_x is None:
+            global_stream_dict[chat_uuid].put([])
+        else:
+            global_stream_dict[chat_uuid].put(curr_x)
+
+
+def stream_get_api(chat_uuid: str) -> Optional[List]:
+    ''' stream_get_api '''
+    try:
+        with stream_lock:
+            if chat_uuid not in global_stream_dict:
+                return None
+
+            queue = global_stream_dict[chat_uuid]
+            result = queue.get(timeout=0)
+
+            if result == []:
+                del global_stream_dict[chat_uuid]
+
+            return result
+
+    except Exception:
+        return None
+
+
+def generate_in_background(chat_uuid: str, data: Dict):
+    ''' generate_in_background '''
+
+    def _generate():
+        try:
+            global tokenizer, MODEL_DLLM
+            input_ids = tokenizer.apply_chat_template(
+                data['messages'],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors='pt',
+            )
+            batch_input_ids = torch.zeros(
+                (input_ids.shape[0], TASK_DLLM_MAX_LENGTH),
+                dtype=torch.long).fill_(TASK_DLLM_MASK_ID)
+            for s_k in range(input_ids.shape[0]):
+                batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k]
+
+            _ = MODEL_DLLM.generate(batch_input_ids,
+                                    chat_uuid=chat_uuid,
+                                    gen_length=TASK_DLLM_GEN_LENGTH,
+                                    block_length=TASK_DLLM_BLOCK_LENGTH)
+            stream_put_api(chat_uuid, None)
+
+        except Exception as err:
+            logging.error('Generate error: %s', err)
+            stream_put_api(chat_uuid, None)
+
+    thread = threading.Thread(target=_generate, daemon=True)
+    thread.start()
+
+
+def get_answer_openai(chat_uuid: str, data: Dict) -> str:
+    '''get answer openai'''
+
+    generate_in_background(chat_uuid, data)
+    resp = {}
+
+    pre_text = ''
+    while True:
+        curr_x = stream_get_api(chat_uuid)
+
+        if curr_x is None:
+            time.sleep(TASK_STREAM_SLEEP_SECONDS)
+            continue
+
+        if curr_x == []:
+            break
+
+        x_str = tokenizer.decode(curr_x[0])
+        text = x_str.split('ASSISTANT')[-1]
+        text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
+        text = text.replace('<|mask|>', ' ').rstrip(' ')
+        resp = {
+            'id': chat_uuid,
+            'object': 'chat.completion.chunk',
+            'created': time.time(),
+            'model': TASK_DLLM_SERVE_MODEL_NAME,
+            'choices': [{
+                'index': 0,
+                'delta': {
+                    'role': 'assistant',
+                    'content': text[len(pre_text):],
+                    'reasoning_content': None
+                },
+                'logprobs': None,
+                'finish_reason': None,
+            }],
+            'prompt_token_ids': None
+        }
+        pre_text = text
+        yield 'data: ' + json.dumps(resp, ensure_ascii=False) + '\n\n'
+
+    yield 'data: [DONE]'
+
+    with stream_lock:
+        if chat_uuid in global_stream_dict:
+            del global_stream_dict[chat_uuid]
+
+    resp['choices'][0]['delta']['content'] = pre_text
+    logging.info('[%s] resp: %s', chat_uuid, resp)
+
+
+def get_answer_openai_no_stream(chat_uuid: str, data: Dict) -> str:
+    ''' get answer openai (no stream) '''
+    global tokenizer, MODEL_DLLM
+    input_ids = tokenizer.apply_chat_template(
+        data['messages'],
+        add_generation_prompt=True,
+        tokenize=True,
+        return_tensors='pt',
+    )
+    batch_input_ids = torch.zeros((input_ids.shape[0], TASK_DLLM_MAX_LENGTH),
+                                  dtype=torch.long).fill_(TASK_DLLM_MASK_ID)
+    for s_k in range(input_ids.shape[0]):
+        batch_input_ids[s_k, :input_ids.shape[-1]] = input_ids[s_k]
+
+    x_tokens_yield = MODEL_DLLM.generate(batch_input_ids,
+                                         gen_length=TASK_DLLM_GEN_LENGTH,
+                                         block_length=TASK_DLLM_BLOCK_LENGTH)
+
+    resp = {}
+    x_tokens_final = x_tokens_yield
+
+    x_str = tokenizer.decode(x_tokens_final[0])
+    text = x_str.split('ASSISTANT')[-1]
+    text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
+    text = text.replace('<|mask|>', ' ')
+    resp = {
+        'id': chat_uuid,
+        'object': 'chat.completion',
+        'created': time.time(),
+        'model': TASK_DLLM_SERVE_MODEL_NAME,
+        'choices': [{
+            'index': 0,
+            'message': {
+                'role': 'assistant',
+                'content': text,
+                'reasoning_content': None
+            },
+            'logprobs': None,
+            'finish_reason': 'stop',
+        }],
+        'prompt_token_ids': None,
+        'usage': {
+            'prompt_tokens': input_ids.shape[-1],
+            'total_tokens': len(x_tokens_final[0]),
+            'completion_tokens': len(x_tokens_final[0]) - input_ids.shape[-1],
+            'prompt_tokens_details': None
+        }
+    }
+    logging.info('[%s] resp: %s', chat_uuid, resp)
+    return json.dumps(resp, ensure_ascii=False) + '\n'
+
+
+@app.post('/v1/chat/completions')
+def chat_openai(request: CompletionsRequest):
+    ''' chat '''
+    chat_uuid = f'chat-{str(uuid.uuid4())}'
+    data = request.dict()
+    logging.info('[%s] req: %s', chat_uuid, data)
+    if request.stream:
+        return StreamingResponse(get_answer_openai(chat_uuid, data))
+    return Response(
+        content=get_answer_openai_no_stream(chat_uuid, data),
+        media_type='text/plain',
+    )
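+
+
+# Internal endpoints used by the dinfer worker process: the generation loop in
+# dinfer/decoding/generate_uniform.py POSTs intermediate results here as JSON
+# of the form {'chat_uuid': ..., 'curr_x': [[token ids]]}.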
+@app.post('/v1/stream_put')
+def stream_put_endpoint(request: StreamRequest):
+    ''' stream_put_endpoint '''
+    stream_put_api(request.chat_uuid, request.curr_x)
+    return {'status': 'success', 'message': 'Data added to stream'}
+
+
+@app.post('/v1/stream_get')
+def stream_get_endpoint(request: StreamRequest):
+    ''' stream_get_endpoint '''
+    curr_x = stream_get_api(request.chat_uuid)
+    return curr_x
+
+
+@app.api_route('/v1/models', methods=['GET', 'POST', 'OPTIONS'])
+def get_models():
+    ''' get_models
+    '''
+    created = time.time()
+    resp = {
+        'object': 'list',
+        'data': [{
+            'id': TASK_DLLM_SERVE_MODEL_NAME,
+            'object': 'model',
+            'created': created,
+            'owned_by': 'dinfer',
+            'root': os.environ.get('SPECIAL_MODEL_DIR'),
+            'parent': None,
+            'max_model_len': TASK_DLLM_MAX_LENGTH,
+            'permission': [{
+                'id': 'modelperm-xyz-dllm',
+                'object': 'model_permission',
+                'created': created,
+                'allow_create_engine': False,
+                'allow_sampling': True,
+                'allow_logprobs': True,
+                'allow_search_indices': False,
+                'allow_view': True,
+                'allow_fine_tuning': False,
+                'organization': '*',
+                'group': None,
+                'is_blocking': False
+            }]
+        }]
+    }
+    return resp
+
+
+def mission():
+    ''' mission '''
+    global tokenizer, MODEL_DLLM
+    tokenizer, MODEL_DLLM, _ = get_dllm()
+    port = int(os.environ.get('TASK_SERVER_PORT', '40081'))
+    uvicorn.run(app, host='0.0.0.0', port=port)
+
+
+if __name__ == '__main__':
+    mission()
diff --git a/python/dinfer/decoding/generate_uniform.py b/python/dinfer/decoding/generate_uniform.py
index 9fb5eed..f06f18c 100644
--- a/python/dinfer/decoding/generate_uniform.py
+++ b/python/dinfer/decoding/generate_uniform.py
@@ -1,3 +1,6 @@
+import os
+import requests
+
 import torch
 import numpy as np
 import logging
@@ -1041,7 +1044,7 @@ def cache_updates(self):
         return self.diff_iteration.cache_updates

     @ torch.no_grad()
-    def naive_batching_generate(self, prompt, gen_length=128, block_length=128):
+    def naive_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None):
         ''' Generate tokens with diffusion iterations block by block. '''
         # recalculate gen length and init iteratory
@@ -1075,7 +1078,13 @@ def naive_batching_generate(self, prompt, gen_length=128, block_length=128):
         # We need to reset iter_no at the beginning of generating a sequence.
         self.diff_iteration.iter_no = 0
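+        # Streaming hook: when a chat_uuid is attached to the request, the
+        # partially decoded sequence is POSTed to the serving frontend's
+        # /v1/stream_put endpoint before each block and once after decoding.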
+        ip_port_url = 'http://0.0.0.0:' + os.environ.get('TASK_SERVER_PORT', '40081')
+        logger.info(f'[{chat_uuid}] ip_port_url: {ip_port_url}')
         for block_id, (block_loc, block) in enumerate(it):
+            if chat_uuid:
+                x_list = x.get_generated_tokens().tolist()
+                stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_list}
+                requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict)
             self.decoder.block_init(block, block_id)
             if self.backend == 'vllm':
                 cross_block_attn_mask = bd_attn_mask[:,block_loc.start-block_length:block_loc.end, :block_loc.end]
@@ -1086,10 +1095,16 @@
             if torch.all(decode_compl) and self.early_stop:
                 break
         logger.info(f'The number of diffusion iterations: {self.num_forwards}')
-        return x.get_generated_tokens()
+
+        x_generated_tokens = x.get_generated_tokens()
+        if chat_uuid:
+            stream_dict = {'chat_uuid': chat_uuid, 'curr_x': x_generated_tokens.tolist()}
+            requests.post(f'{ip_port_url}/v1/stream_put', json=stream_dict)
+
+        return x_generated_tokens

     @ torch.no_grad()
-    def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128):
+    def dynamic_batching_generate(self, prompt, gen_length=128, block_length=128, chat_uuid=None):
         ''' Generate tokens with dynamic batching '''
         assert self.cache_factory is not None
diff --git a/python/dinfer/decoding/serving.py b/python/dinfer/decoding/serving.py
index 056cfc3..c74114a 100644
--- a/python/dinfer/decoding/serving.py
+++ b/python/dinfer/decoding/serving.py
@@ -268,9 +268,9 @@ def generate(dllm, device, req_q, res_q):
             assert data == 'stop'
             break
         else:
-            input_ids, gen_len, block_len = data
+            input_ids, gen_len, block_len, chat_uuid = data
             input_ids = input_ids.to(device)
-            out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len)
+            out = dllm.generate(input_ids, gen_length=gen_len, block_length=block_len, chat_uuid=chat_uuid)
             num_forwards = dllm.num_forwards
             if res_q is not None:
                 res_q.put((out, num_forwards))
@@ -485,18 +485,18 @@ def __init__(self):
         self.groups = []
         self.need_response = 0

-    def add_requests(self, reqs):
+    def add_requests(self, reqs, chat_uuid=None):
         prompts, gen_length, block_length = reqs
         self.need_response = 0
         # assert len(self.groups) == prompts.shape[0], 'We cannot only use DP to support batch size > 1.'
         if len(self.groups) == prompts.shape[0]:
             for i, prompt in enumerate(prompts):
-                self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length))
+                self.groups[i].add_request((prompt.unsqueeze(0), gen_length, block_length, chat_uuid))
             self.need_response = len(prompts)
         else:
             partial_data = torch.chunk(prompts, len(self.groups), dim=0)
             for i in range(len(partial_data)):
-                self.groups[i].add_request((partial_data[i], gen_length, block_length))
+                self.groups[i].add_request((partial_data[i], gen_length, block_length, chat_uuid))
             self.need_response = len(partial_data)

@@ -580,7 +580,7 @@ def __init__(self, model, model_type='llada2', sample_params=None, server_port=N
         self.num_forwards = 0
         self.timeout = timeout

-    def generate(self, prompts, gen_length=128, block_length=128):
+    def generate(self, prompts, gen_length=128, block_length=128, chat_uuid=None):
         ''' Generate tokens with diffusion iterations.

         Parameters:
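+            chat_uuid: optional identifier of the serving request; when set,
+                intermediate tokens are streamed back to the frontend via /v1/stream_put.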
@@ -598,7 +598,7 @@ def generate(self, prompts, gen_length=128, block_length=128):
         The generation results of different lengths are padded with EOS.
         '''
         prompts = prompts.cpu()
-        handle.add_requests((prompts, gen_length, block_length))
+        handle.add_requests((prompts, gen_length, block_length), chat_uuid=chat_uuid)
         rets = handle.get_responses(timeout=self.timeout)
         max_len = max([tensor.shape[1] for (tensor, _) in rets])
         total_batch_size = sum([tensor.shape[0] for (tensor, _) in rets])