🔨[DEV] Support OpenAI API Format for transformers model.
fairyshine committed Sep 25, 2024
1 parent 021c638 commit be7d1b5
Showing 9 changed files with 131 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/Quick_Start/user_guide.md
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion src/fastmindapi/model/__init__.py
@@ -1,4 +1,4 @@
from .transformers.CasualLM import TransformersCausalLM
from .transformers.CausalLM import TransformersCausalLM
from .transformers.PeftModel import PeftCausalLM
from .llama_cpp.LLM import LlamacppLLM

52 changes: 50 additions & 2 deletions src/fastmindapi/model/transformers/CausalLM.py
@@ -1,4 +1,6 @@

from ...server.router.openai import ChatMessage

class TransformersCausalLM:
def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
@@ -103,5 +105,51 @@ def generate(self,

return generation_output

def chat(self):
pass
def chat(self, messages: list[ChatMessage], max_completion_tokens: int | None = None):
import torch
import time

# Convert the message list into a single prompt string (plain "role: content" lines rather than the model's chat template)
input_text = ""
for message in messages:
role = message.role
content = message.content
input_text += f"{role}: {content}\n"

inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

generate_kwargs = {
    "max_new_tokens": max_completion_tokens,  # None defers to the model's generation_config defaults
}

with torch.no_grad():
outputs = self.model.generate(**inputs, **generate_kwargs)

# Decode the full sequences, then strip the re-decoded prompt prefix so only the completion remains
full_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
output_texts = [full_text[len(re_inputs):] for full_text in full_texts]

choices = []
for i, output_text in enumerate(output_texts):
choices.append({
"index": i,
"message": {
"role": "assistant",
"content": output_text
},
"finish_reason": "stop"
})

completion_tokens = sum(len(self.tokenizer.encode(text)) for text in output_texts)
response = {
    "id": f"chatcmpl-{int(time.time())}",
    "object": "chat.completion",
    "created": int(time.time()),
    "model": self.model.config.name_or_path,
    "choices": choices,
    "usage": {
        "prompt_tokens": inputs.input_ids.shape[1],
        "completion_tokens": completion_tokens,
        "total_tokens": inputs.input_ids.shape[1] + completion_tokens
    }
}
return response
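
For context, a minimal sketch of driving the new `chat()` method directly, outside the server. The checkpoint name below is an assumption (the tests point at a local gemma-2-2b directory); any Hugging Face causal LM with a matching tokenizer should behave the same way:

```python
# Hedged usage sketch; "google/gemma-2-2b" is an assumed checkpoint, not from the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

from fastmindapi.model import TransformersCausalLM
from fastmindapi.server.router.openai import ChatMessage

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

llm = TransformersCausalLM(tokenizer, model)
response = llm.chat(
    messages=[ChatMessage(role="user", content="Do you know something about Dota2?")],
    max_completion_tokens=16,
)
print(response["choices"][0]["message"]["content"])  # assistant completion text
```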
2 changes: 1 addition & 1 deletion src/fastmindapi/model/transformers/PeftModel.py
@@ -1,4 +1,4 @@
from .CasualLM import TransformersCausalLM
from .CausalLM import TransformersCausalLM

class PeftCausalLM(TransformersCausalLM):
def __init__(self, base_model: TransformersCausalLM, peft_model):
11 changes: 5 additions & 6 deletions src/fastmindapi/server/core/main.py
@@ -2,7 +2,7 @@
from ...model import ModelModule
from ..router.basic import get_basic_router
from ..router.model import get_model_router

from ..router.openai import get_openai_router
from ... import logger

class Server:
@@ -24,10 +24,10 @@ def __init__(self):
# Load the routers
self.app.include_router(get_basic_router())
self.app.include_router(get_model_router())
self.app.include_router(get_openai_router())


# def load_model(self, model_name: str, model):
#     self.module["model"].load_model(model_name, model)

def run(self):
match self.deploy_mode:
@@ -39,5 +39,4 @@ def run(self):
port=self.port,
log_config=None)
self.logger.info("Client stops running.")



1 change: 1 addition & 0 deletions src/fastmindapi/server/router/__init__.py
@@ -1,5 +1,6 @@
from .basic import get_basic_router
from .model import get_model_router
from .openai import get_openai_router



19 changes: 19 additions & 0 deletions src/fastmindapi/server/router/model.py
@@ -23,6 +23,16 @@ class GenerationOutput(BaseModel):
output_text: str
logits: list

class ChatMessage(BaseModel):
role: str
content: str

class ChatRequest(BaseModel):
messages: list[ChatMessage]
max_new_tokens: int = 256

model_config=ConfigDict(protected_namespaces=())

def add_model_info(request: Request, item: BasicModel):
server = request.app.state.server
if item.model_name in server.module["model"].available_models:
@@ -73,6 +83,14 @@ def generate(request: Request, model_name: str, item: GenerationRequest):
outputs = server.module["model"].loaded_models[model_name].generate(**item.model_dump())
return outputs

def chat(request: Request, model_name: str, item: ChatRequest):
server = request.app.state.server
try:
assert model_name in server.module["model"].loaded_models
except AssertionError:
return f"【Error】: {model_name} is not loaded."
outputs = server.module["model"].loaded_models[model_name].chat(messages=item.messages, max_completion_tokens=item.max_new_tokens)  # map this route's max_new_tokens onto chat()'s max_completion_tokens
return outputs

def get_model_router():
router = APIRouter(prefix=PREFIX)
@@ -82,4 +100,5 @@ def get_model_router():
router.add_api_route("/unload/{model_name}", unload_model, methods=["GET"])
router.add_api_route("/call/{model_name}", simple_generate, methods=["POST"])
router.add_api_route("/generate/{model_name}", generate, methods=["POST"])
router.add_api_route("/chat/{model_name}", chat, methods=["POST"])
return router
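
A quick way to exercise the new `/model/chat/{model_name}` route is a plain HTTP POST; here is a sketch with `requests`, assuming the setup from tests/test.md (server on port 8000, `gemma2` loaded, the same bearer token):

```python
# Hedged sketch: POST to the new chat route; host, port, token, and model name
# follow the curl examples in tests/test.md and are assumptions about your setup.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/model/chat/gemma2",
    headers={"Authorization": "Bearer sk-anything"},
    json={
        "messages": [{"role": "user", "content": "Do you know something about Dota2?"}],
        "max_new_tokens": 16,  # ChatRequest field on this route (default 256)
    },
)
print(resp.json())
```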
35 changes: 35 additions & 0 deletions src/fastmindapi/server/router/openai.py
@@ -0,0 +1,35 @@
from pydantic import BaseModel, ConfigDict
from fastapi import APIRouter, Request

PREFIX = "/openai"

class ChatMessage(BaseModel):
role: str
content: str

class ChatRequest(BaseModel):
model: str
messages: list[ChatMessage]
max_completion_tokens: int | None = None

model_config=ConfigDict(protected_namespaces=())


def chat_completions(request: Request, item: ChatRequest):
server = request.app.state.server
try:
assert item.model in server.module["model"].loaded_models
except AssertionError:
return f"【Error】: {item.model} is not loaded."

outputs = server.module["model"].loaded_models[item.model].chat(
messages=item.messages,
max_completion_tokens=item.max_completion_tokens
)
return outputs

def get_openai_router():
router = APIRouter(prefix=PREFIX)

router.add_api_route("/chat/completions", chat_completions, methods=["POST"])
return router
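
Because this route mirrors the OpenAI wire format, the official `openai` Python client can be pointed at it. A sketch assuming a local server and a client version recent enough to accept `max_completion_tokens`:

```python
# Hedged sketch: base_url targets this router's /openai prefix; the api_key value
# mirrors the bearer token used in tests/test.md and is otherwise arbitrary.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/openai", api_key="sk-anything")

completion = client.chat.completions.create(
    model="gemma2",  # must match a model name loaded on the server
    messages=[
        {"role": "system", "content": "You are a test assistant."},
        {"role": "user", "content": "Do you know something about Dota2?"},
    ],
    max_completion_tokens=16,
)
print(completion.choices[0].message.content)
```

Note that the error path above returns a bare string rather than an OpenAI-style error object, so a typed client will fail to parse the response when the requested model is not loaded.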
19 changes: 18 additions & 1 deletion tests/test.md
@@ -28,23 +28,40 @@ curl http://127.0.0.1:8000/model/add_info \
"model_path": "/Users/wumengsong/Resource/gemma-2-2b"
}'

curl http://127.0.0.1:8000/model/load/gemma2
curl http://127.0.0.1:8000/model/load/gemma2 -H "Authorization: Bearer sk-anything"

curl http://127.0.0.1:8000/model/call/gemma2 \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-anything" \
-d '{
"input_text": "Do you know something about Dota2?",
"max_new_tokens": 2
}'

curl http://127.0.0.1:8000/model/generate/gemma2 \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-anything" \
-d '{
"input_text": "Do you know something about Dota2?",
"max_new_tokens": 2,
"return_logits": true,
"stop_strings": ["\n"]
}'

curl http://127.0.0.1:8000/openai/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer sk-anything" -d '{
"model": "gemma2",
"messages": [
{
"role": "system",
"content": "You are a test assistant."
},
{
"role": "user",
"content": "Do you know something about Dota2?"
}
],
"max_completion_tokens": 2
}'
```

