From be7d1b5b122bc6e9564bf4fa13179371a4e1bc41 Mon Sep 17 00:00:00 2001
From: Radicat <2285225334@qq.com>
Date: Wed, 25 Sep 2024 20:04:06 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8[DEV]=20Support=20OpenAI=20API=20Fo?=
 =?UTF-8?q?rmat=20for=20transformers=20model.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/Quick_Start/user_guide.md                |  1 +
 src/fastmindapi/model/__init__.py             |  2 +-
 .../transformers/{CasualLM.py => CausalLM.py} | 52 ++++++++++++++++++-
 .../model/transformers/PeftModel.py           |  2 +-
 src/fastmindapi/server/core/main.py           | 11 ++--
 src/fastmindapi/server/router/__init__.py     |  1 +
 src/fastmindapi/server/router/model.py        | 19 +++++++
 src/fastmindapi/server/router/openai.py       | 35 +++++++++++++
 tests/test.md                                 | 19 ++++++-
 9 files changed, 131 insertions(+), 11 deletions(-)
 rename src/fastmindapi/model/transformers/{CasualLM.py => CausalLM.py} (73%)
 create mode 100644 src/fastmindapi/server/router/openai.py

diff --git a/docs/Quick_Start/user_guide.md b/docs/Quick_Start/user_guide.md
index e69de29..8b13789 100644
--- a/docs/Quick_Start/user_guide.md
+++ b/docs/Quick_Start/user_guide.md
@@ -0,0 +1 @@
+
diff --git a/src/fastmindapi/model/__init__.py b/src/fastmindapi/model/__init__.py
index 5301033..96acffd 100644
--- a/src/fastmindapi/model/__init__.py
+++ b/src/fastmindapi/model/__init__.py
@@ -1,4 +1,4 @@
-from .transformers.CasualLM import TransformersCausalLM
+from .transformers.CausalLM import TransformersCausalLM
 from .transformers.PeftModel import PeftCausalLM
 from .llama_cpp.LLM import LlamacppLLM
 
diff --git a/src/fastmindapi/model/transformers/CasualLM.py b/src/fastmindapi/model/transformers/CausalLM.py
similarity index 73%
rename from src/fastmindapi/model/transformers/CasualLM.py
rename to src/fastmindapi/model/transformers/CausalLM.py
index 7eba4d2..8362d22 100644
--- a/src/fastmindapi/model/transformers/CasualLM.py
+++ b/src/fastmindapi/model/transformers/CausalLM.py
@@ -1,4 +1,6 @@
+from ...server.router.openai import ChatMessage
+
 class TransformersCausalLM:
     def __init__(self, tokenizer, model):
         self.tokenizer = tokenizer
         self.model = model
@@ -103,5 +105,51 @@ def generate(self,
 
         return generation_output
 
-    def chat(self):
-        pass
\ No newline at end of file
+    def chat(self, messages: list[ChatMessage], max_completion_tokens: int | None = None):
+        import torch
+        import time
+
+        # Convert the message list into a single input text
+        input_text = ""
+        for message in messages:
+            role = message.role
+            content = message.content
+            input_text += f"{role}: {content}\n"
+
+        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
+
+        generate_kwargs = {
+            "max_new_tokens": max_completion_tokens,
+        }
+
+        with torch.no_grad():
+            outputs = self.model.generate(**inputs, **generate_kwargs)
+
+        full_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        output_texts = [full_text[len(re_inputs):] for full_text in full_texts]
+
+        choices = []
+        for i, output_text in enumerate(output_texts):
+            choices.append({
+                "index": i,
+                "message": {
+                    "role": "assistant",
+                    "content": output_text
+                },
+                "finish_reason": "stop"
+            })
+
+        response = {
+            "id": f"chatcmpl-{int(time.time())}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model.config.name_or_path,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": inputs.input_ids.shape[1],
+                "completion_tokens": sum(len(self.tokenizer.encode(text)) for text in output_texts),
+                "total_tokens": inputs.input_ids.shape[1] + sum(len(self.tokenizer.encode(text)) for text in output_texts)
+            }
+        }
+        return response
\ No newline at end of file
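The new `chat` method flattens the history into plain `role: content` lines, which chat-tuned models usually cannot parse as a conversation. A minimal sketch of an alternative prompt builder, assuming the loaded tokenizer ships a chat template; `apply_chat_template` is the standard `transformers` API, but the `build_chat_inputs` helper and its arguments are illustrative, not part of this patch:

```python
# Illustrative helper, not part of the patch: build the prompt with the
# tokenizer's chat template instead of flat "role: content" lines.
# Assumes the tokenizer defines a chat template (true for most chat models).
def build_chat_inputs(tokenizer, model, messages):
    # apply_chat_template expects plain dicts, so unwrap the Pydantic models
    chat = [{"role": m.role, "content": m.content} for m in messages]
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,  # append the assistant turn prefix
        return_tensors="pt",
    )
    return input_ids.to(model.device)
```

Slicing the generation output as `outputs[:, input_ids.shape[1]:]` would then recover the completion tokens directly, instead of decoding the full sequence and trimming the prompt text by character length.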
diff --git a/src/fastmindapi/model/transformers/PeftModel.py b/src/fastmindapi/model/transformers/PeftModel.py
index a789779..43992d5 100644
--- a/src/fastmindapi/model/transformers/PeftModel.py
+++ b/src/fastmindapi/model/transformers/PeftModel.py
@@ -1,4 +1,4 @@
-from .CasualLM import TransformersCausalLM
+from .CausalLM import TransformersCausalLM
 
 class PeftCausalLM(TransformersCausalLM):
     def __init__(self, base_model: TransformersCausalLM, peft_model):
diff --git a/src/fastmindapi/server/core/main.py b/src/fastmindapi/server/core/main.py
index 8b5a198..64f0531 100644
--- a/src/fastmindapi/server/core/main.py
+++ b/src/fastmindapi/server/core/main.py
@@ -2,7 +2,7 @@
 from ...model import ModelModule
 from ..router.basic import get_basic_router
 from ..router.model import get_model_router
-
+from ..router.openai import get_openai_router
 from ... import logger
 
 class Server:
@@ -24,10 +24,10 @@ def __init__(self):
         # Load the routers
         self.app.include_router(get_basic_router())
         self.app.include_router(get_model_router())
+        self.app.include_router(get_openai_router())
 
-
-    # def load_model(self, model_name: str, model):
-    #     self.module["model"].load_model(model_name, model)
+    # def load_model(self, model_name: str, model):
+    #     self.module["model"].load_model(model_name, model)
 
     def run(self):
         match self.deploy_mode:
@@ -39,5 +39,4 @@ def run(self):
                         port=self.port,
                         log_config=None)
         self.logger.info("Client stops running.")
-
-
\ No newline at end of file
+
diff --git a/src/fastmindapi/server/router/__init__.py b/src/fastmindapi/server/router/__init__.py
index b0aa016..68323f1 100644
--- a/src/fastmindapi/server/router/__init__.py
+++ b/src/fastmindapi/server/router/__init__.py
@@ -1,5 +1,6 @@
 from .basic import get_basic_router
 from .model import get_model_router
+from .openai import get_openai_router
 
 
 
diff --git a/src/fastmindapi/server/router/model.py b/src/fastmindapi/server/router/model.py
index 7133a24..d2a8380 100644
--- a/src/fastmindapi/server/router/model.py
+++ b/src/fastmindapi/server/router/model.py
@@ -23,6 +23,16 @@ class GenerationOutput(BaseModel):
     output_text: str
     logits: list
 
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    messages: list[ChatMessage]
+    max_new_tokens: int = 256
+
+    model_config = ConfigDict(protected_namespaces=())
+
 def add_model_info(request: Request, item: BasicModel):
     server = request.app.state.server
     if item.model_name in server.module["model"].available_models:
@@ -73,6 +83,14 @@ def generate(request: Request, model_name: str, item: GenerationRequest):
     outputs = server.module["model"].loaded_models[model_name].generate(**item.model_dump())
     return outputs
 
+def chat(request: Request, model_name: str, item: ChatRequest):
+    server = request.app.state.server
+    try:
+        assert model_name in server.module["model"].loaded_models
+    except AssertionError:
+        return f"【Error】: {model_name} is not loaded."
+    outputs = server.module["model"].loaded_models[model_name].chat(messages=item.messages, max_completion_tokens=item.max_new_tokens)
+    return outputs
 
 def get_model_router():
     router = APIRouter(prefix=PREFIX)
@@ -82,4 +100,5 @@ def get_model_router():
     router.add_api_route("/unload/{model_name}", unload_model, methods=["GET"])
     router.add_api_route("/call/{model_name}", simple_generate, methods=["POST"])
     router.add_api_route("/generate/{model_name}", generate, methods=["POST"])
+    router.add_api_route("/chat/{model_name}", chat, methods=["POST"])
     return router
diff --git a/src/fastmindapi/server/router/openai.py b/src/fastmindapi/server/router/openai.py
new file mode 100644
index 0000000..a3ba5bc
--- /dev/null
+++ b/src/fastmindapi/server/router/openai.py
@@ -0,0 +1,35 @@
+from pydantic import BaseModel, ConfigDict
+from fastapi import APIRouter, Request
+
+PREFIX = "/openai"
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: list[ChatMessage]
+    max_completion_tokens: int | None = None
+
+    model_config = ConfigDict(protected_namespaces=())
+
+
+def chat_completions(request: Request, item: ChatRequest):
+    server = request.app.state.server
+    try:
+        assert item.model in server.module["model"].loaded_models
+    except AssertionError:
+        return f"【Error】: {item.model} is not loaded."
+
+    outputs = server.module["model"].loaded_models[item.model].chat(
+        messages=item.messages,
+        max_completion_tokens=item.max_completion_tokens
+    )
+    return outputs
+
+def get_openai_router():
+    router = APIRouter(prefix=PREFIX)
+
+    router.add_api_route("/chat/completions", chat_completions, methods=["POST"])
+    return router
\ No newline at end of file
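Because the route mirrors the OpenAI `/chat/completions` contract, the official `openai` Python client can be pointed at the server. A minimal sketch, assuming a recent `openai` (v1+) client that supports `max_completion_tokens`, the server from the test script below running on `127.0.0.1:8000`, and a loaded model registered as `gemma2`:

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:8000/openai",  # the PREFIX added by get_openai_router()
    api_key="sk-anything",                    # any bearer token the server accepts
)

response = client.chat.completions.create(
    model="gemma2",
    messages=[
        {"role": "system", "content": "You are a test assistant."},
        {"role": "user", "content": "Do you know something about Dota2?"},
    ],
    max_completion_tokens=2,
)
print(response.choices[0].message.content)
```

Note that the not-loaded error path returns a plain string rather than an OpenAI-style error object, so the client will only parse responses for models that are actually loaded.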
diff --git a/tests/test.md b/tests/test.md
index 5a15284..03cd71d 100644
--- a/tests/test.md
+++ b/tests/test.md
@@ -28,10 +28,11 @@ curl http://127.0.0.1:8000/model/add_info \
     "model_path": "/Users/wumengsong/Resource/gemma-2-2b"
   }'
 
-curl http://127.0.0.1:8000/model/load/gemma2
+curl http://127.0.0.1:8000/model/load/gemma2 -H "Authorization: Bearer sk-anything"
 
 curl http://127.0.0.1:8000/model/call/gemma2 \
   -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-anything" \
   -d '{
     "input_text": "Do you know something about Dota2?",
     "max_new_tokens": 2
@@ -39,12 +40,28 @@ curl http://127.0.0.1:8000/model/call/gemma2 \
 
 curl http://127.0.0.1:8000/model/generate/gemma2 \
   -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-anything" \
   -d '{
     "input_text": "Do you know something about Dota2?",
     "max_new_tokens": 2,
     "return_logits": true,
     "stop_strings": ["\n"]
   }'
+
+curl http://127.0.0.1:8000/openai/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer sk-anything" -d '{
+    "model": "gemma2",
+    "messages": [
+      {
+        "role": "system",
+        "content": "You are a test assistant."
+      },
+      {
+        "role": "user",
+        "content": "Do you know something about Dota2?"
+      }
+    ],
+    "max_completion_tokens": 2
+  }'
 ```
 
 ```shell