Commit 576913f

spammenotinoz committed Mar 17, 2024
1 parent 5347391 commit 576913f
Showing 2 changed files with 6 additions and 40 deletions.
backend/api/schemas/conversation_schemas.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ def _validate_model(_source: ChatSourceTypes, model: str | None):
         logger.warning(f"model {model} not in openai_api models: {'|'.join(list(OpenaiApiChatModels))}")
 
 
-MAX_CONTEXT_MESSAGE_COUNT = 100
+MAX_CONTEXT_MESSAGE_COUNT = 1000
 
 
 class AskRequest(BaseModel):
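The schema-side constant is raised from 100 to 1000 here, mirroring the change in openai_api.py below. A minimal sketch of how such a cap is typically enforced on the request model, assuming `AskRequest` bounds a `context_message_count` field with Pydantic (this field wiring is an assumption for illustration, not shown in the diff):

```python
from pydantic import BaseModel, Field

MAX_CONTEXT_MESSAGE_COUNT = 1000  # value after this commit


class AskRequest(BaseModel):
    # Hypothetical field: -1 is assumed to mean "no explicit limit";
    # anything above the cap is rejected at validation time.
    context_message_count: int = Field(default=-1, ge=-1, le=MAX_CONTEXT_MESSAGE_COUNT)
```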
backend/api/sources/openai_api.py (44 changes: 5 additions & 39 deletions)
@@ -2,10 +2,8 @@
 import uuid
 from datetime import datetime, timezone
 from typing import Optional
-
 import httpx
 from pydantic import ValidationError
-
 from api.conf import Config, Credentials
 from api.enums import OpenaiApiChatModels, ChatSourceTypes
 from api.exceptions import OpenaiApiException
@@ -16,12 +14,9 @@
 from utils.logger import get_logger
 
 logger = get_logger(__name__)
-
 config = Config()
 credentials = Credentials()
-
-MAX_CONTEXT_MESSAGE_COUNT = 100
-
+MAX_CONTEXT_MESSAGE_COUNT = 1000
 
 async def _check_response(response: httpx.Response) -> None:
     # changed to use the built-in error handling
@@ -35,7 +30,6 @@ async def _check_response(response: httpx.Response) -> None:
         )
         raise error from ex
 
-
 def make_session() -> httpx.AsyncClient:
     if config.openai_api.proxy is not None:
         proxies = {
@@ -47,12 +41,10 @@ def make_session() -> httpx.AsyncClient:
         session = httpx.AsyncClient(timeout=None)
     return session
 
-
 class OpenaiApiChatManager(metaclass=SingletonMeta):
     """
     OpenAI API Manager
     """
-
     def __init__(self):
         self.session = make_session()
 
@@ -62,9 +54,7 @@ def reset_session(self):
     async def complete(self, model: OpenaiApiChatModels, text_content: str, conversation_id: uuid.UUID = None,
                        parent_message_id: uuid.UUID = None,
                        context_message_count: int = -1, extra_args: Optional[dict] = None, **_kwargs):
-
         assert config.openai_api.enabled, "openai_api is not enabled"
-
         now_time = datetime.now().astimezone(tz=timezone.utc)
         message_id = uuid.uuid4()
         new_message = OpenaiApiChatMessage(
@@ -79,9 +69,7 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                 source="openai_api",
             )
         )
-
         messages = []
-
         if not conversation_id:
             assert parent_message_id is None, "parent_id must be None when conversation_id is None"
             messages = [new_message]
@@ -93,45 +81,29 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                 raise ValueError(f"{conversation_id} is not api conversation")
             if not conv_history.mapping.get(str(parent_message_id)):
                 raise ValueError(f"{parent_message_id} is not a valid parent of {conversation_id}")
-
-            # walk back from current_node to collect context_message_count messages
+            # walk back from current_node to collect at most 5 messages
             if not conv_history.current_node:
                 raise ValueError(f"{conversation_id} current_node is None")
-
             msg = conv_history.mapping.get(str(conv_history.current_node))
             assert msg, f"{conv_history.id} current_node({conv_history.current_node}) not found in mapping"
-
             count = 0
-            iter_count = 0
-
-            while msg:
-                count += 1
+            while msg and count < 5:
                 messages.append(msg)
-                if context_message_count != -1 and count >= context_message_count:
-                    break
-                iter_count += 1
-                if iter_count > MAX_CONTEXT_MESSAGE_COUNT:
-                    raise ValueError(f"too many messages to iterate, conversation_id={conversation_id}")
                 msg = conv_history.mapping.get(str(msg.parent))
-
+                count += 1
             messages.reverse()
             messages.append(new_message)
-
         # TODO: credits check
-
         base_url = config.openai_api.openai_base_url
         data = {
             "model": model.code(),
             "messages": [{"role": msg.role, "content": msg.content.text} for msg in messages],
             "stream": True,
             **(extra_args or {})
         }
-
         reply_message = None
         text_content = ""
-
         timeout = httpx.Timeout(config.openai_api.read_timeout, connect=config.openai_api.connect_timeout)
-
         async with self.session.stream(method="POST",
                                        url=f"{base_url}chat/completions",
                                        json=data,
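The hunk above removes the `context_message_count` parameter handling and the `MAX_CONTEXT_MESSAGE_COUNT` iteration guard, replacing them with a fixed cap of five prior messages. A self-contained sketch of the new walk, with a hypothetical helper name and the `mapping`/`parent` shapes inferred from the diff:

```python
import uuid


def collect_context(mapping: dict, current_node: uuid.UUID, limit: int = 5) -> list:
    """Walk parent links from current_node and return at most `limit` messages, oldest first."""
    messages = []
    count = 0
    msg = mapping.get(str(current_node))
    while msg and count < limit:
        messages.append(msg)
        msg = mapping.get(str(msg.parent))  # parent message id, or None at the conversation root
        count += 1
    messages.reverse()  # chronological order, ready for the chat/completions payload
    return messages
```

As in the diff, the new user message is appended separately after this walk, so the request carries at most five history messages plus the new one.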
@@ -146,14 +118,11 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                     line = line[6:]
                 if "[DONE]" in line:
                     break
-
                 try:
                     line = json.loads(line)
                     resp = OpenaiChatResponse.model_validate(line)
-
                     if not resp.choices or len(resp.choices) == 0:
                         continue
-
                     if resp.choices[0].message is not None:
                         text_content = resp.choices[0].message.get("content")
                     if resp.choices[0].delta is not None:
@@ -175,13 +144,10 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                         )
                     else:
                         reply_message.content = OpenaiApiChatMessageTextContent(content_type="text", text=text_content)
-
                     if resp.usage:
                         reply_message.metadata.usage = resp.usage
-
                     yield reply_message
-
                 except json.decoder.JSONDecodeError:
                     logger.warning(f"OpenAIChatResponse parse json error")
                 except ValidationError as e:
-                    logger.warning(f"OpenAIChatResponse validate error: {e}")
+                    logger.warning(f"OpenAIChatResponse validate error: {e}")
