diff --git a/backend/api/schemas/conversation_schemas.py b/backend/api/schemas/conversation_schemas.py
index 13abaafe..bc0e98f0 100644
--- a/backend/api/schemas/conversation_schemas.py
+++ b/backend/api/schemas/conversation_schemas.py
@@ -26,7 +26,7 @@ def _validate_model(_source: ChatSourceTypes, model: str | None):
         logger.warning(f"model {model} not in openai_api models: {'|'.join(list(OpenaiApiChatModels))}")
 
 
-MAX_CONTEXT_MESSAGE_COUNT = 100
+MAX_CONTEXT_MESSAGE_COUNT = 1000
 
 
 class AskRequest(BaseModel):
diff --git a/backend/api/sources/openai_api.py b/backend/api/sources/openai_api.py
index 41f8fb9f..1da7d8e0 100644
--- a/backend/api/sources/openai_api.py
+++ b/backend/api/sources/openai_api.py
@@ -2,10 +2,8 @@
 import uuid
 from datetime import datetime, timezone
 from typing import Optional
-
 import httpx
 from pydantic import ValidationError
-
 from api.conf import Config, Credentials
 from api.enums import OpenaiApiChatModels, ChatSourceTypes
 from api.exceptions import OpenaiApiException
@@ -16,12 +14,9 @@
 from utils.logger import get_logger
 
 logger = get_logger(__name__)
-
 config = Config()
 credentials = Credentials()
-
-MAX_CONTEXT_MESSAGE_COUNT = 100
-
+MAX_CONTEXT_MESSAGE_COUNT = 1000
 
 async def _check_response(response: httpx.Response) -> None:
     # 改成自带的错误处理
@@ -35,7 +30,6 @@ async def _check_response(response: httpx.Response) -> None:
             )
         raise error from ex
 
-
 def make_session() -> httpx.AsyncClient:
     if config.openai_api.proxy is not None:
         proxies = {
@@ -47,12 +41,10 @@ def make_session() -> httpx.AsyncClient:
         session = httpx.AsyncClient(timeout=None)
     return session
 
-
 class OpenaiApiChatManager(metaclass=SingletonMeta):
     """
     OpenAI API Manager
     """
-
     def __init__(self):
         self.session = make_session()
 
@@ -62,9 +54,7 @@ def reset_session(self):
     async def complete(self, model: OpenaiApiChatModels, text_content: str, conversation_id: uuid.UUID = None,
                        parent_message_id: uuid.UUID = None, context_message_count: int = -1,
                        extra_args: Optional[dict] = None, **_kwargs):
-
         assert config.openai_api.enabled, "openai_api is not enabled"
-
         now_time = datetime.now().astimezone(tz=timezone.utc)
         message_id = uuid.uuid4()
         new_message = OpenaiApiChatMessage(
@@ -79,9 +69,7 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                 source="openai_api",
             )
         )
-
         messages = []
-
         if not conversation_id:
             assert parent_message_id is None, "parent_id must be None when conversation_id is None"
             messages = [new_message]
@@ -93,32 +81,19 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                 raise ValueError(f"{conversation_id} is not api conversation")
             if not conv_history.mapping.get(str(parent_message_id)):
                 raise ValueError(f"{parent_message_id} is not a valid parent of {conversation_id}")
-
-            # 从 current_node 开始往前找 context_message_count 个 message
+            # 从 current_node 开始往前找最多 5 个 message
             if not conv_history.current_node:
                 raise ValueError(f"{conversation_id} current_node is None")
-
             msg = conv_history.mapping.get(str(conv_history.current_node))
             assert msg, f"{conv_history.id} current_node({conv_history.current_node}) not found in mapping"
-
             count = 0
-            iter_count = 0
-
-            while msg:
-                count += 1
+            while msg and count < 5:
                 messages.append(msg)
-                if context_message_count != -1 and count >= context_message_count:
-                    break
-                iter_count += 1
-                if iter_count > MAX_CONTEXT_MESSAGE_COUNT:
-                    raise ValueError(f"too many messages to iterate, conversation_id={conversation_id}")
                 msg = conv_history.mapping.get(str(msg.parent))
-
+                count += 1
             messages.reverse()
             messages.append(new_message)
 
-        # TODO: credits 判断
-
         base_url = config.openai_api.openai_base_url
         data = {
             "model": model.code(),
@@ -126,12 +101,9 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
             "stream": True,
             **(extra_args or {})
         }
-
         reply_message = None
         text_content = ""
-
         timeout = httpx.Timeout(config.openai_api.read_timeout, connect=config.openai_api.connect_timeout)
-
         async with self.session.stream(method="POST",
                                        url=f"{base_url}chat/completions",
                                        json=data,
@@ -146,14 +118,11 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                     line = line[6:]
                 if "[DONE]" in line:
                     break
-
                 try:
                     line = json.loads(line)
                     resp = OpenaiChatResponse.model_validate(line)
-
                     if not resp.choices or len(resp.choices) == 0:
                         continue
-
                     if resp.choices[0].message is not None:
                         text_content = resp.choices[0].message.get("content")
                     if resp.choices[0].delta is not None:
@@ -175,13 +144,10 @@ async def complete(self, model: OpenaiApiChatModels, text_content: str, conversa
                     )
                 else:
                     reply_message.content = OpenaiApiChatMessageTextContent(content_type="text", text=text_content)
-
                     if resp.usage:
                         reply_message.metadata.usage = resp.usage
-
                     yield reply_message
-
                 except json.decoder.JSONDecodeError:
                     logger.warning(f"OpenAIChatResponse parse json error")
                 except ValidationError as e:
-                    logger.warning(f"OpenAIChatResponse validate error: {e}")
+                    logger.warning(f"OpenAIChatResponse validate error: {e}")
\ No newline at end of file
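
For reviewers, the behavioural change in openai_api.py above is the context assembly: the configurable walk
(context_message_count, safety-capped by MAX_CONTEXT_MESSAGE_COUNT) is replaced by a hard-coded limit of 5 messages
collected from current_node up the parent chain. Below is a minimal standalone sketch of that walk, for illustration
only — Message and build_context are hypothetical stand-ins, not the project's actual models or helpers:

    # Sketch: walk parent links from the newest node, keep at most `limit`
    # messages, then restore chronological order (oldest first).
    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Message:  # hypothetical stand-in for OpenaiApiChatMessage
        id: str
        parent: Optional[str]
        text: str

    def build_context(mapping: Dict[str, Message], current_node: str, limit: int = 5) -> List[Message]:
        messages: List[Message] = []
        msg = mapping.get(current_node)
        while msg and len(messages) < limit:
            messages.append(msg)
            msg = mapping.get(msg.parent) if msg.parent else None
        messages.reverse()  # the chat completions API expects chronological order
        return messages

    # Example: a 3-message chain a -> b -> c, trimmed to the last 2 messages.
    chain = {
        "a": Message("a", None, "hello"),
        "b": Message("b", "a", "hi there"),
        "c": Message("c", "b", "how are you?"),
    }
    print([m.text for m in build_context(chain, "c", limit=2)])  # ['hi there', 'how are you?']

Note that the removed loop treated context_message_count == -1 as "no limit" and raised once iter_count exceeded
MAX_CONTEXT_MESSAGE_COUNT, guarding against overly long (or cyclic) parent chains; the hard-coded limit of 5 drops
both behaviours, and context_message_count is no longer referenced in the code shown here.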