Merge pull request #30 from moeakwak/dev
6.4
spammenotinoz authored Oct 4, 2023
2 parents cb5b793 + 17c7ede commit 6249716
Showing 44 changed files with 1,538 additions and 382 deletions.
5 changes: 4 additions & 1 deletion backend/api/conf/config.py
@@ -44,6 +44,7 @@ class DataSetting(BaseModel):
data_dir: str = './data'
database_url: str = 'sqlite+aiosqlite:///data/database.db'
mongodb_url: str = 'mongodb://cws:password@mongo:27017' # 'mongodb://cws:password@localhost:27017'
mongodb_db_name: str = 'cws'
run_migration: bool = False
max_file_upload_size: int = Field(100 * 1024 * 1024, ge=0)

@@ -66,7 +67,7 @@ class OpenaiWebChatGPTSetting(BaseModel):
is_plus_account: bool = True
chatgpt_base_url: Optional[str] = None
proxy: Optional[str] = None
common_timeout: int = Field(10, ge=1) # connect, read, write
common_timeout: int = Field(20, ge=1) # connect, read, write
ask_timeout: int = Field(600, ge=1)
sync_conversations_on_startup: bool = True
sync_conversations_schedule: bool = False
@@ -75,6 +76,8 @@
"gpt_4_browsing"]
model_code_mapping: dict[OpenaiWebChatModels, str] = default_openai_web_model_code_mapping
file_upload_strategy: OpenaiWebFileUploadStrategyOption = OpenaiWebFileUploadStrategyOption.browser_upload_only
enable_uploading_attachments: bool = True
enable_uploading_multimodal_images: bool = True

@validator("chatgpt_base_url")
def chatgpt_base_url_end_with_slash(cls, v):
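Note: a minimal sketch (not part of the diff) of the new settings added above, using pydantic v1 style as elsewhere in this repo; the class names here are illustrative and the defaults mirror the diff.

from pydantic import BaseModel, Field


class DataSettingSketch(BaseModel):
    # New in this commit: the MongoDB database name is configurable rather than hard-coded.
    mongodb_url: str = 'mongodb://cws:password@mongo:27017'
    mongodb_db_name: str = 'cws'


class OpenaiWebSettingSketch(BaseModel):
    # The default connect/read/write timeout is raised from 10 to 20 seconds.
    common_timeout: int = Field(20, ge=1)
    # New global switches that gate attachment and multimodal image uploads.
    enable_uploading_attachments: bool = True
    enable_uploading_multimodal_images: bool = True


# For example, disable multimodal image uploads instance-wide:
print(OpenaiWebSettingSketch(enable_uploading_multimodal_images=False).dict())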
8 changes: 4 additions & 4 deletions backend/api/database/mongodb.py
@@ -9,18 +9,18 @@
logger = get_logger(__name__)
config = Config()

DATABASE_NAME = "cws"

client: AsyncIOMotorClient | None = None


async def init_mongodb():
global client
client = AsyncIOMotorClient(config.data.mongodb_url)
await init_beanie(database=client[DATABASE_NAME],
await init_beanie(database=client[config.data.mongodb_db_name],
document_models=[OpenaiApiConversationHistoryDocument, OpenaiWebConversationHistoryDocument, AskLogDocument,
RequestLogDocument])
# Show current MongoDB database usage
db = client[DATABASE_NAME]
db = client[config.data.mongodb_db_name]
stats = await db.command({"dbStats": 1})
logger.info(
f"MongoDB initialized. dataSize: {stats['dataSize'] / 1024 / 1024:.2f} MB, objects: {stats['objects']}")
@@ -34,7 +34,7 @@ async def handle_timeseries():
"""
global client
assert client is not None, "MongoDB not initialized"
db = client[DATABASE_NAME]
db = client[config.data.mongodb_db_name]
time_series_docs = [AskLogDocument, RequestLogDocument]
config_ttls = [config.stats.ask_stats_ttl, config.stats.request_stats_ttl]
for doc, config_ttl in zip(time_series_docs, config_ttls):
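Note: the hard-coded DATABASE_NAME constant is gone; every lookup now goes through config.data.mongodb_db_name. A standalone sketch of the same pattern (assumes motor is installed and a reachable MongoDB; connection details are placeholders):

from motor.motor_asyncio import AsyncIOMotorClient


async def report_db_usage(mongodb_url: str, mongodb_db_name: str) -> None:
    # The database name is a parameter here, standing in for config.data.mongodb_db_name.
    client = AsyncIOMotorClient(mongodb_url)
    db = client[mongodb_db_name]
    stats = await db.command({"dbStats": 1})
    print(f"dataSize: {stats['dataSize'] / 1024 / 1024:.2f} MB, objects: {stats['objects']}")


# e.g. asyncio.run(report_db_usage('mongodb://cws:password@localhost:27017', 'cws'))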
1 change: 0 additions & 1 deletion backend/api/enums/options.py
@@ -3,7 +3,6 @@


class OpenaiWebFileUploadStrategyOption(StrEnum):
disable_upload = auto()
server_upload_only = auto()
browser_upload_only = auto()
browser_upload_when_file_size_exceed = auto()
1 change: 1 addition & 0 deletions backend/api/file_provider.py
@@ -31,6 +31,7 @@ async def save_file(self, file: UploadFile, user_id: int, session: AsyncSession)
if not file_dir_path.exists():
file_dir_path.mkdir(parents=True)


async with aiofiles.open(file_path, "wb") as buffer:
while True:
chunk = await file.read(1024 * 1024) # read by 1MB chunk
3 changes: 2 additions & 1 deletion backend/api/models/json.py
@@ -1,5 +1,5 @@
import datetime
from typing import Optional, Generic, TypeVar, get_args
from typing import Optional, Generic, TypeVar, get_args, Literal

from pydantic import BaseModel, Field, create_model, root_validator
from pydantic.generics import GenericModel
@@ -57,5 +57,6 @@ class CustomOpenaiApiSettings(BaseModel):

class UploadedFileOpenaiWebInfo(BaseModel):
file_id: Optional[str]
use_case: Optional[Literal['ace_upload', 'multimodal'] | str]
upload_url: Optional[str] = Field(description="URL to upload the file to; this field should be cleared after the upload completes")
download_url: Optional[str]
26 changes: 17 additions & 9 deletions backend/api/routers/chat.py
@@ -209,9 +209,15 @@ async def check_limits(user: UserReadAdmin, ask_request: AskRequest):
# Check whether attachments are allowed
if ask_request.openai_web_attachments and len(ask_request.openai_web_attachments) > 0:
if ask_request.model != OpenaiWebChatModels.gpt_4_code_interpreter or \
config.openai_web.file_upload_strategy == OpenaiWebFileUploadStrategyOption.disable_upload:
config.openai_web.enable_uploading_attachments is False:
raise WebsocketInvalidAskException("errors.attachmentsNotAllowed")

# Check whether multimodal images are allowed
if ask_request.openai_web_multimodal_image_parts and len(ask_request.openai_web_multimodal_image_parts) > 0:
if ask_request.model != OpenaiWebChatModels.gpt_4 or \
config.openai_web.enable_uploading_multimodal_images is False:
raise WebsocketInvalidAskException("errors.multimodalImagesNotAllowed")


def check_message(msg: str):
# Check the message for sensitive information
@@ -282,7 +288,7 @@ async def reply(response: AskResponse):
queueing_start_time = None
queueing_end_time = None

# rev: queueing
# queueing
if ask_request.source == ChatSourceTypes.openai_web:
if openai_web_manager.is_busy():
await reply(AskResponse(
@@ -324,12 +330,14 @@ async def reply(response: AskResponse):
model = OpenaiApiChatModels(ask_request.model)

# stream the response
async for data in manager.ask(content=ask_request.content,
async for data in manager.ask(text_content=ask_request.text_content,
conversation_id=ask_request.conversation_id,
parent_id=ask_request.parent,
model=model,
plugin_ids=ask_request.openai_web_plugin_ids,
attachments=ask_request.openai_web_attachments):
attachments=ask_request.openai_web_attachments,
multimodal_image_parts=ask_request.openai_web_multimodal_image_parts,
):
has_got_reply = True

try:
@@ -449,7 +457,7 @@ async def reply(response: AskResponse):
if ask_request.source == ChatSourceTypes.openai_api:
assert message.parent is not None, "message.parent is None"

content = ask_request.content
content = ask_request.text_content
if isinstance(content, str):
content = OpenaiApiChatMessageTextContent(content_type="text", text=content)

@@ -508,11 +516,11 @@ async def reply(response: AskResponse):
if ask_request.new_conversation:
assert conversation_id is not None, "has_got_reply but conversation_id is None"

# rev: set the default title
if ask_request.source == ChatSourceTypes.openai_web:
# Set the default title
if ask_request.source == ChatSourceTypes.openai_web and ask_request.new_title is not None and \
ask_request.new_title.strip() != "":
try:
if ask_request.new_title is not None:
await openai_web_manager.set_conversation_title(str(conversation_id), ask_request.new_title)
await openai_web_manager.set_conversation_title(str(conversation_id), ask_request.new_title)
except Exception as e:
logger.warning(f"set_conversation_title error {e.__class__.__name__}: {str(e)}")

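Note: the two checks added to check_limits above act as independent gates. A standalone restatement (simplified; model names as plain strings) with a small self-test:

def uploads_permitted(model: str, has_attachments: bool, has_image_parts: bool,
                      enable_attachments: bool, enable_images: bool) -> bool:
    # Attachments are only accepted for gpt_4_code_interpreter, multimodal image parts
    # only for gpt_4, and each also requires its global switch to be enabled.
    if has_attachments and (model != "gpt_4_code_interpreter" or not enable_attachments):
        return False
    if has_image_parts and (model != "gpt_4" or not enable_images):
        return False
    return True


assert uploads_permitted("gpt_4", False, True, True, True)                      # image on gpt_4: allowed
assert not uploads_permitted("gpt_3_5", False, True, True, True)                # image on gpt_3_5: rejected
assert not uploads_permitted("gpt_4_code_interpreter", True, False, False, True)  # attachments globally off: rejected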
21 changes: 6 additions & 15 deletions backend/api/routers/conv.py
@@ -200,22 +200,19 @@ async def delete_all_conversation(_user: User = Depends(current_super_user)):
return response(200)


@router.patch("/conv/{conversation_id}/gen_title", tags=["conversation"], response_model=OpenaiWebConversationSchema)
@router.patch("/conv/{conversation_id}/gen_title", tags=["conversation"], response_model=str)
async def generate_conversation_title(message_id: str,
conversation: OpenaiWebConversation = Depends(_get_conversation_by_id)):
if conversation.title is not None:
raise InvalidParamsException("errors.conversationTitleAlreadyGenerated")
async with get_async_session_context() as session:
result = await openai_web_manager.generate_conversation_title(conversation.id, message_id)
if result["title"]:
conversation.title = result["title"]
title = await openai_web_manager.generate_conversation_title(conversation.conversation_id, message_id)
if title:
conversation.title = title
session.add(conversation)
await session.commit()
await session.refresh(conversation)
else:
raise InvalidParamsException(f"{result['message']}")
result = jsonable_encoder(conversation)
return result
raise InternalException("errors.generateTitleFailed")
return title


@router.get("/conv/{conversation_id}/interpreter", tags=["conversation"], response_model=OpenaiChatInterpreterInfo)
@@ -224,12 +221,6 @@ async def get_conversation_interpreter_info(conversation_id: str):
return response(200, result=url)


@router.get("/conv/files/{file_id}/download-url", tags=["conversation"])
async def get_file_download_url(file_id: str):
url = await openai_web_manager.get_file_download_url(file_id)
return response(200, result=url)


@router.get("/conv/{conversation_id}/interpreter/download-url", tags=["conversation"])
async def get_conversation_interpreter_download_url(conversation_id: str, message_id: str, sandbox_path: str):
if message_id is None or sandbox_path is None:
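Note: gen_title now responds with the bare title string (response_model=str) and raises InternalException when generation fails, instead of returning the whole conversation object. A sketch of a caller, assuming an already-authenticated httpx.AsyncClient pointed at the API base URL (the route path and message_id query parameter follow the diff):

import httpx


async def generate_title(client: httpx.AsyncClient, conversation_id: str, message_id: str) -> str:
    # PATCH /conv/{conversation_id}/gen_title?message_id=...  ->  plain string title
    resp = await client.patch(f"/conv/{conversation_id}/gen_title",
                              params={"message_id": message_id})
    resp.raise_for_status()
    return resp.json()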
30 changes: 23 additions & 7 deletions backend/api/routers/files.py
@@ -3,6 +3,7 @@
from mimetypes import guess_type

from fastapi import APIRouter, Depends, HTTPException, File, UploadFile
from fastapi_cache.decorator import cache
from starlette.responses import FileResponse

from api.conf import Config
@@ -12,6 +13,7 @@
from api.file_provider import FileProvider
from api.models.db import User, UploadedFileInfo
from api.models.json import UploadedFileOpenaiWebInfo
from api.schemas import UserRead
from api.schemas.file_schemas import UploadedFileInfoSchema, StartUploadResponseSchema
from api.schemas.openai_schemas import OpenaiChatFileUploadInfo
from api.sources import OpenaiWebChatManager
@@ -23,14 +25,23 @@
openai_web_manager = OpenaiWebChatManager()


@router.get("/files/{file_id}/download-url", tags=["conversation"], response_model=str)
@cache(expire=10 * 60)
async def get_file_download_url(file_id: str):
"""
file_id: the id assigned by OpenAI, which starts with "file-"
"""
url = await openai_web_manager.get_file_download_url(file_id)
return url


@router.post("/files/local/upload", tags=["files"], response_model=UploadedFileInfoSchema)
async def upload_file_to_local(file: UploadFile = File(...), user: User = Depends(current_active_user)):
"""
Upload a file to the server. The file is saved on the server and its metadata is returned.
Use this only when the uploaded file needs to be kept on the server.
"""
if config.openai_web.file_upload_strategy in [OpenaiWebFileUploadStrategyOption.browser_upload_only,
OpenaiWebFileUploadStrategyOption.disable_upload]:
if config.openai_web.file_upload_strategy == OpenaiWebFileUploadStrategyOption.browser_upload_only:
raise InvalidRequestException(f"File upload disabled")
if file.size > config.data.max_file_upload_size:
raise InvalidRequestException(f"File too large! Max size: {config.data.max_file_upload_size}")
@@ -68,12 +79,17 @@ async def start_upload_to_openai(upload_info: OpenaiChatFileUploadInfo, user: Us
b. Then call the upload_local_file_to_openai_web endpoint to tell the server to upload the file to OpenAI Web
"""
file_size_exceed = upload_info.file_size > config.data.max_file_upload_size
if upload_info.use_case != "ace_upload":
raise InvalidRequestException(f"Invalid use case: {upload_info.use_case}")
if config.openai_web.file_upload_strategy == OpenaiWebFileUploadStrategyOption.server_upload_only and file_size_exceed:
raise InvalidRequestException(f"File too large! Max size: {config.data.max_file_upload_size}")
if config.openai_web.file_upload_strategy == OpenaiWebFileUploadStrategyOption.disable_upload:
raise InvalidRequestException(f"File upload disabled")
raise InvalidRequestException(f"File is too large! Max size: {config.data.max_file_upload_size}")
user_info = UserRead.from_orm(user)
if upload_info.use_case == "ace_upload" and \
(user_info.setting.openai_web.allow_uploading_attachments is False or
config.openai_web.enable_uploading_attachments is False):
raise InvalidRequestException(f"Uploading attachments disabled")
if upload_info.use_case == "multimodal" and \
(user_info.setting.openai_web.allow_uploading_multimodal_images is False or
config.openai_web.enable_uploading_multimodal_images is False):
raise InvalidRequestException(f"Uploading multimodal images disabled")

file_info = None

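Note: the download-url route has moved here from conv.py and its result is now cached for ten minutes via fastapi-cache2's @cache decorator. A minimal, self-contained sketch of that caching pattern (the endpoint body and in-memory backend are illustrative, not the project's actual lookup):

from fastapi import FastAPI
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache

app = FastAPI()


@app.on_event("startup")
async def init_cache():
    # fastapi-cache2 must be initialised once before any @cache-decorated route is hit.
    FastAPICache.init(InMemoryBackend())


@app.get("/files/{file_id}/download-url", response_model=str)
@cache(expire=10 * 60)  # repeated requests for the same file_id reuse the cached URL for 10 minutes
async def get_file_download_url(file_id: str) -> str:
    # Placeholder for the real call to openai_web_manager.get_file_download_url(file_id).
    return f"https://example.com/download/{file_id}"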
7 changes: 6 additions & 1 deletion backend/api/routers/users.py
@@ -17,6 +17,7 @@
get_user_manager, UserManager

router = APIRouter()
config = Config()


# router.include_router(
@@ -93,10 +94,14 @@ async def get_me(user: User = Depends(current_active_user)):
user_read = UserRead.from_orm(user)
for source in ["openai_api", "openai_web"]:
source_setting = getattr(user_read.setting, source)
global_enabled_models = getattr(Config(), source).enabled_models
global_enabled_models = getattr(config, source).enabled_models
source_setting.available_models = list(
set(source_setting.available_models).intersection(set(global_enabled_models)))
setattr(user_read.setting, source, source_setting)
if not config.openai_web.enable_uploading_attachments:
user_read.setting.openai_web.allow_uploading_attachments = False
if not config.openai_web.enable_uploading_multimodal_images:
user_read.setting.openai_web.allow_uploading_multimodal_images = False
return user_read


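Note: get_me now clamps the per-user upload permissions whenever the matching global switch is off, so clients never see a capability the server would reject. The rule, restated as a standalone helper (names are illustrative):

def effective_upload_flags(user_allows_attachments: bool, user_allows_images: bool,
                           global_attachments_enabled: bool, global_images_enabled: bool) -> tuple[bool, bool]:
    # A user-level permission only takes effect while the corresponding global switch is on.
    return (user_allows_attachments and global_attachments_enabled,
            user_allows_images and global_images_enabled)


assert effective_upload_flags(True, True, False, True) == (False, True)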
8 changes: 5 additions & 3 deletions backend/api/schemas/conversation_schemas.py
@@ -7,7 +7,8 @@
from strenum import StrEnum

from api.enums import ChatSourceTypes, OpenaiWebChatModels, OpenaiApiChatModels
from api.models.doc import OpenaiWebChatMessage, OpenaiApiChatMessage
from api.models.doc import OpenaiWebChatMessage, OpenaiApiChatMessage, \
OpenaiWebChatMessageMultimodalTextContentImagePart
from api.schemas.openai_schemas import OpenaiWebAskAttachment
from utils.logger import get_logger

@@ -29,13 +30,14 @@ class AskRequest(BaseModel):
source: ChatSourceTypes
model: str
new_conversation: bool
new_title: Optional[str] = None
new_title: Optional[str] = None  # if empty, a title will be generated
conversation_id: Optional[uuid.UUID] = None
parent: Optional[uuid.UUID] = None
api_context_message_count: int = Field(-1, ge=-1)
content: str
text_content: str
openai_web_plugin_ids: Optional[list[str]] = None
openai_web_attachments: Optional[list[OpenaiWebAskAttachment]] = None
openai_web_multimodal_image_parts: Optional[list[OpenaiWebChatMessageMultimodalTextContentImagePart]] = None

@root_validator
def check(cls, values):
2 changes: 1 addition & 1 deletion backend/api/schemas/openai_schemas.py
@@ -62,7 +62,7 @@ class OpenaiChatInterpreterInfo(BaseModel):
class OpenaiChatFileUploadInfo(BaseModel):
file_name: str
file_size: int
use_case: str | Literal['ace_upload']
use_case: Literal['ace_upload', 'multimodal']


class OpenaiChatFileUploadUrlResponse(BaseModel):
22 changes: 20 additions & 2 deletions backend/api/schemas/user_schemas.py
@@ -2,12 +2,15 @@
from typing import Optional

from fastapi_users import schemas
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, validator, root_validator

from api.conf import Config
from api.enums import OpenaiWebChatStatus, OpenaiWebChatModels, OpenaiApiChatModels
from api.models.json import CustomOpenaiApiSettings, TimeWindowRateLimit, DailyTimeSlot, \
OpenaiWebPerModelAskCount, OpenaiApiPerModelAskCount

config = Config()


class BaseSourceSettingSchema(BaseModel):
allow_to_use: bool
@@ -45,12 +48,17 @@
class OpenaiWebSourceSettingSchema(BaseSourceSettingSchema):
available_models: list[OpenaiWebChatModels]
per_model_ask_count: OpenaiWebPerModelAskCount
allow_uploading_attachments: bool
allow_uploading_multimodal_images: bool

@staticmethod
def default():
return OpenaiWebSourceSettingSchema(
available_models=[OpenaiWebChatModels(m) for m in OpenaiWebChatModels],
available_models=[OpenaiWebChatModels(m) for m in
["gpt_3_5", "gpt_4", "gpt_4_code_interpreter", "gpt_4_plugins", "gpt_4_browsing"]],
per_model_ask_count=OpenaiWebPerModelAskCount(),
allow_uploading_attachments=config.openai_web.enable_uploading_attachments,
allow_uploading_multimodal_images=config.openai_web.enable_uploading_multimodal_images,
**BaseSourceSettingSchema.default().dict()
)

@@ -59,9 +67,19 @@ def unlimited():
return OpenaiWebSourceSettingSchema(
available_models=[OpenaiWebChatModels(m) for m in OpenaiWebChatModels],
per_model_ask_count=OpenaiWebPerModelAskCount.unlimited(),
allow_uploading_attachments=True,
allow_uploading_multimodal_images=True,
**BaseSourceSettingSchema.unlimited().dict()
)

@root_validator(pre=True)
def check(cls, values):
if "allow_uploading_attachments" not in values:
values["allow_uploading_attachments"] = config.openai_web.enable_uploading_attachments
if "allow_uploading_multimodal_images" not in values:
values["allow_uploading_multimodal_images"] = config.openai_web.enable_uploading_multimodal_images
return values

class Config:
orm_mode = True

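Note: the pre=True root validator above backfills the two new fields so that user settings stored before this release still validate. A standalone illustration of the pattern (pydantic v1; the hard-coded True defaults stand in for the config switches):

from pydantic import BaseModel, root_validator


class SourceSettingSketch(BaseModel):
    allow_uploading_attachments: bool
    allow_uploading_multimodal_images: bool

    @root_validator(pre=True)
    def backfill_new_fields(cls, values):
        # Records written before this commit lack the two new keys; default them
        # before field validation runs so parsing old data does not fail.
        values.setdefault("allow_uploading_attachments", True)
        values.setdefault("allow_uploading_multimodal_images", True)
        return values


print(SourceSettingSketch.parse_obj({}))  # an old record without the new keys still parses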
4 changes: 2 additions & 2 deletions backend/api/sources/openai_api.py
@@ -60,7 +60,7 @@ def __init__(self):
def reset_session(self):
self.session = make_session()

async def ask(self, content: str, conversation_id: uuid.UUID = None,
async def ask(self, text_content: str, conversation_id: uuid.UUID = None,
parent_id: uuid.UUID = None, model: OpenaiApiChatModels = None,
context_message_count: int = -1, extra_args: Optional[dict] = None, **_kwargs):

@@ -75,7 +75,7 @@ async def ask(self, content: str, conversation_id: uuid.UUID = None,
create_time=now_time,
parent=parent_id,
children=[],
content=OpenaiApiChatMessageTextContent(content_type="text", text=content),
content=OpenaiApiChatMessageTextContent(content_type="text", text=text_content),
metadata=OpenaiApiChatMessageMetadata(
source="openai_api",
)