Skip to content

Commit

Permalink
🔨[DEV] improve OpenAIChatModel
Browse files Browse the repository at this point in the history
  • Loading branch information
fairyshine committed Sep 27, 2024
1 parent 058b5e9 commit aecb7f2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
50 changes: 35 additions & 15 deletions src/fastmindapi/model/openai/ChatModel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ...server.router.openai import ChatMessage
from ...utils.transform import convert_openai_logprobs
from ... import logger

class OpenAIChatModel:
def __init__(self, client, model_name: str, system_prompt: str = "You are a helpful assistant."):
Expand All @@ -25,21 +27,39 @@ def __call__(self, input_text: str, max_new_tokens: int = 256):
except Exception as e:
return "【Error】: " + str(e)

# def generate(self,
# input_text: str,
# max_new_tokens: int = 256,
# return_logits: bool = False,
# logits_top_k: int = 10,
# stop_strings: list[str] = None):
# generation_output = {"output_text": output_text,
# "input_id_list": input_id_list,
# "input_token_list": input_token_list,
# "input_text": input_text,
# "full_id_list": full_id_list,
# "full_token_list": full_token_list,
# "full_text": full_text,
# "logits": logits_list}
# return generation_output
def generate(self,
input_text: str,
max_new_tokens: int = 256,
return_logits: bool = False,
logits_top_k: int = 10,
stop_strings: list[str] = None):
while True:
try:
completion = self.client.chat.completions.create(
model= self.model_name,
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": input_text}
],
max_completion_tokens=max_new_tokens,
logprobs=return_logits,
top_logprobs=logits_top_k,
stop=stop_strings
)
break
except Exception as e:
logger.info(f"【Error】: {e}")
output_text = completion.choices[0].message.content
logits_list = convert_openai_logprobs(completion.choices[0].logprobs)
generation_output = {"output_text": output_text,
# "input_id_list": input_id_list,
# "input_token_list": input_token_list,
"input_text": input_text,
# "full_id_list": full_id_list,
# "full_token_list": full_token_list,
# "full_text": full_text,
"logits": logits_list}
return generation_output

def chat(self, messages: list[ChatMessage], max_completion_tokens: int = None, logprobs: bool = False, top_logprobs: int =10):
try:
Expand Down
6 changes: 6 additions & 0 deletions tests/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import tests_settings  # noqa: F401
import fastmindapi as FM


def main():
    """Start a FastMindAPI server with default settings (blocks until stopped)."""
    server = FM.Server()
    server.run()


# Guard so importing this module (e.g. by a test collector) does not
# start the server as an import side effect.
if __name__ == "__main__":
    main()

0 comments on commit aecb7f2

Please sign in to comment.