diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py
index 64bcf4ce3..93f952fd2 100644
--- a/astrbot/core/db/po.py
+++ b/astrbot/core/db/po.py
@@ -293,6 +293,41 @@ class CommandConflict(SQLModel, table=True):
     )
 
 
+class ApiKey(SQLModel, table=True):
+    """API Key table for external API access."""
+
+    __tablename__ = "api_keys"  # type: ignore
+
+    id: int | None = Field(
+        default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
+    )
+    key_id: str = Field(
+        max_length=36,
+        nullable=False,
+        unique=True,
+        default_factory=lambda: str(uuid.uuid4()),
+        index=True,
+    )
+    api_key: str = Field(nullable=False, max_length=255, index=True)
+    """The actual API key (hashed)"""
+    username: str = Field(nullable=False, max_length=255)
+    """WebUI username who created this API key"""
+    name: str | None = Field(default=None, max_length=255)
+    """Optional name/description for the API key"""
+    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+    expires_at: datetime | None = Field(default=None)
+    """Expiration time (not implemented yet)"""
+    last_used_at: datetime | None = Field(default=None)
+    """Last time this API key was used"""
+
+    __table_args__ = (
+        UniqueConstraint(
+            "key_id",
+            name="uix_api_key_id",
+        ),
+    )
+
+
 @dataclass
 class Conversation:
     """LLM 对话类
diff --git a/astrbot/dashboard/entities.py b/astrbot/dashboard/entities.py
new file mode 100644
index 000000000..4c426e032
--- /dev/null
+++ b/astrbot/dashboard/entities.py
@@ -0,0 +1,21 @@
+from pydantic.dataclasses import dataclass
+
+
+@dataclass
+class Response:
+    status: str | None = None
+    message: str | None = None
+    data: dict | list | None = None
+
+    def error(self, message: str):
+        self.status = "error"
+        self.message = message
+        return self
+
+    def ok(self, data: dict | list | None = None, message: str | None = None):
+        self.status = "ok"
+        if data is None:
+            data = {}
+        self.data = data
+        self.message = message
+        return self
diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py
index 951db956c..7db4631f8 100644
--- a/astrbot/dashboard/routes/__init__.py
+++ b/astrbot/dashboard/routes/__init__.py
@@ -1,3 +1,4 @@
+from .api_key import ApiKeyRoute
 from .auth import AuthRoute
 from .chat import ChatRoute
 from .command import CommandRoute
@@ -16,6 +17,7 @@ from .update import UpdateRoute
 
 
 __all__ = [
+    "ApiKeyRoute",
     "AuthRoute",
     "ChatRoute",
     "CommandRoute",
diff --git a/astrbot/dashboard/routes/api_key.py b/astrbot/dashboard/routes/api_key.py
new file mode 100644
index 000000000..e9351a3e3
--- /dev/null
+++ b/astrbot/dashboard/routes/api_key.py
@@ -0,0 +1,24 @@
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core.db import BaseDatabase
+
+from ..services.api_key import ApiKeyService
+from .route import Route, RouteContext
+
+
+class ApiKeyRoute(Route):
+    def __init__(
+        self,
+        context: RouteContext,
+        core_lifecycle: AstrBotCoreLifecycle,
+        db: BaseDatabase,
+    ):
+        super().__init__(context)
+        self.api_key_service = ApiKeyService(core_lifecycle, db)
+        self.routes = {
+            "/api-key": [
+                ("POST", self.api_key_service.create_api_key),
+                ("GET", self.api_key_service.list_api_keys),
+            ],
+            "/api-key/": [("DELETE", self.api_key_service.delete_api_key)],
+        }
+        self.register_routes()
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 71c3fecd3..25c058364 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -1,30 +1,8 @@
-import asyncio
-import json
-import mimetypes
-import os -import uuid -from contextlib import asynccontextmanager -from typing import cast - -from quart import Response as QuartResponse -from quart import g, make_response, request, send_file - -from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - -from .route import Response, Route, RouteContext - -@asynccontextmanager -async def track_conversation(convs: dict, conv_id: str): - convs[conv_id] = True - try: - yield - finally: - convs.pop(conv_id, None) +from ..services.chat import ChatService +from .route import Route, RouteContext class ChatRoute(Route): @@ -35,658 +13,19 @@ def __init__( core_lifecycle: AstrBotCoreLifecycle, ) -> None: super().__init__(context) + self.chat_service = ChatService(core_lifecycle, db) self.routes = { - "/chat/send": ("POST", self.chat), - "/chat/new_session": ("GET", self.new_session), - "/chat/sessions": ("GET", self.get_sessions), - "/chat/get_session": ("GET", self.get_session), - "/chat/delete_session": ("GET", self.delete_webchat_session), + "/chat/send": ("POST", self.chat_service.chat), + "/chat/new_session": ("GET", self.chat_service.new_session), + "/chat/sessions": ("GET", self.chat_service.get_sessions), + "/chat/get_session": ("GET", self.chat_service.get_session), + "/chat/delete_session": ("GET", self.chat_service.delete_session), "/chat/update_session_display_name": ( "POST", - self.update_session_display_name, + self.chat_service.update_session_display_name, ), - "/chat/get_file": ("GET", self.get_file), - "/chat/get_attachment": ("GET", self.get_attachment), - "/chat/post_file": ("POST", self.post_file), + "/chat/get_file": ("GET", self.chat_service.get_file), + "/chat/get_attachment": ("GET", self.chat_service.get_attachment), + "/chat/post_file": ("POST", self.chat_service.post_file), } - self.core_lifecycle = core_lifecycle self.register_routes() - self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") - os.makedirs(self.imgs_dir, exist_ok=True) - - self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] - self.conv_mgr = core_lifecycle.conversation_manager - self.platform_history_mgr = core_lifecycle.platform_message_history_manager - self.db = db - self.umop_config_router = core_lifecycle.umop_config_router - - self.running_convs: dict[str, bool] = {} - - async def get_file(self): - filename = request.args.get("filename") - if not filename: - return Response().error("Missing key: filename").__dict__ - - try: - file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) - real_file_path = os.path.realpath(file_path) - real_imgs_dir = os.path.realpath(self.imgs_dir) - - if not real_file_path.startswith(real_imgs_dir): - return Response().error("Invalid file path").__dict__ - - filename_ext = os.path.splitext(filename)[1].lower() - if filename_ext == ".wav": - return await send_file(real_file_path, mimetype="audio/wav") - if filename_ext[1:] in self.supported_imgs: - return await send_file(real_file_path, mimetype="image/jpeg") - return await send_file(real_file_path) - - except (FileNotFoundError, OSError): - return Response().error("File access error").__dict__ - - async def get_attachment(self): - """Get attachment file by attachment_id.""" - attachment_id = request.args.get("attachment_id") - if not attachment_id: - return Response().error("Missing key: 
attachment_id").__dict__ - - try: - attachment = await self.db.get_attachment_by_id(attachment_id) - if not attachment: - return Response().error("Attachment not found").__dict__ - - file_path = attachment.path - real_file_path = os.path.realpath(file_path) - - return await send_file(real_file_path, mimetype=attachment.mime_type) - - except (FileNotFoundError, OSError): - return Response().error("File access error").__dict__ - - async def post_file(self): - """Upload a file and create an attachment record, return attachment_id.""" - post_data = await request.files - if "file" not in post_data: - return Response().error("Missing key: file").__dict__ - - file = post_data["file"] - filename = file.filename or f"{uuid.uuid4()!s}" - content_type = file.content_type or "application/octet-stream" - - # 根据 content_type 判断文件类型并添加扩展名 - if content_type.startswith("image"): - attach_type = "image" - elif content_type.startswith("audio"): - attach_type = "record" - elif content_type.startswith("video"): - attach_type = "video" - else: - attach_type = "file" - - path = os.path.join(self.imgs_dir, filename) - await file.save(path) - - # 创建 attachment 记录 - attachment = await self.db.insert_attachment( - path=path, - type=attach_type, - mime_type=content_type, - ) - - if not attachment: - return Response().error("Failed to create attachment").__dict__ - - filename = os.path.basename(attachment.path) - - return ( - Response() - .ok( - data={ - "attachment_id": attachment.attachment_id, - "filename": filename, - "type": attach_type, - } - ) - .__dict__ - ) - - async def _build_user_message_parts(self, message: str | list) -> list[dict]: - """构建用户消息的部分列表 - - Args: - message: 文本消息 (str) 或消息段列表 (list) - """ - parts = [] - - if isinstance(message, list): - for part in message: - part_type = part.get("type") - if part_type == "plain": - parts.append({"type": "plain", "text": part.get("text", "")}) - elif part_type == "reply": - parts.append( - {"type": "reply", "message_id": part.get("message_id")} - ) - elif attachment_id := part.get("attachment_id"): - attachment = await self.db.get_attachment_by_id(attachment_id) - if attachment: - parts.append( - { - "type": attachment.type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(attachment.path), - "path": attachment.path, # will be deleted - } - ) - return parts - - if message: - parts.append({"type": "plain", "text": message}) - - return parts - - async def _create_attachment_from_file( - self, filename: str, attach_type: str - ) -> dict | None: - """从本地文件创建 attachment 并返回消息部分 - - 用于处理 bot 回复中的媒体文件 - - Args: - filename: 存储的文件名 - attach_type: 附件类型 (image, record, file, video) - """ - file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) - if not os.path.exists(file_path): - return None - - # guess mime type - mime_type, _ = mimetypes.guess_type(filename) - if not mime_type: - mime_type = "application/octet-stream" - - # insert attachment - attachment = await self.db.insert_attachment( - path=file_path, - type=attach_type, - mime_type=mime_type, - ) - if not attachment: - return None - - return { - "type": attach_type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(file_path), - } - - async def _save_bot_message( - self, - webchat_conv_id: str, - text: str, - media_parts: list, - reasoning: str, - agent_stats: dict, - ): - """保存 bot 消息到历史记录,返回保存的记录""" - bot_message_parts = [] - bot_message_parts.extend(media_parts) - if text: - bot_message_parts.append({"type": "plain", "text": text}) - - new_his 
= {"type": "bot", "message": bot_message_parts} - if reasoning: - new_his["reasoning"] = reasoning - if agent_stats: - new_his["agent_stats"] = agent_stats - - record = await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content=new_his, - sender_id="bot", - sender_name="bot", - ) - return record - - async def chat(self): - username = g.get("username", "guest") - - post_data = await request.json - if "message" not in post_data and "files" not in post_data: - return Response().error("Missing key: message or files").__dict__ - - if "session_id" not in post_data and "conversation_id" not in post_data: - return ( - Response().error("Missing key: session_id or conversation_id").__dict__ - ) - - message = post_data["message"] - session_id = post_data.get("session_id", post_data.get("conversation_id")) - selected_provider = post_data.get("selected_provider") - selected_model = post_data.get("selected_model") - enable_streaming = post_data.get("enable_streaming", True) - - # 检查消息是否为空 - if isinstance(message, list): - has_content = any( - part.get("type") in ("plain", "image", "record", "file", "video") - for part in message - ) - if not has_content: - return ( - Response() - .error("Message content is empty (reply only is not allowed)") - .__dict__ - ) - elif not message: - return Response().error("Message are both empty").__dict__ - - if not session_id: - return Response().error("session_id is empty").__dict__ - - webchat_conv_id = session_id - back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) - - # 构建用户消息段(包含 path 用于传递给 adapter) - message_parts = await self._build_user_message_parts(message) - - async def stream(): - client_disconnected = False - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - tool_calls = {} - agent_stats = {} - try: - async with track_conversation(self.running_convs, webchat_conv_id): - while True: - try: - result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") - client_disconnected = True - except Exception as e: - logger.error(f"WebChat stream error: {e}") - - if not result: - continue - - result_text = result["data"] - msg_type = result.get("type") - streaming = result.get("streaming", False) - chain_type = result.get("chain_type") - - if chain_type == "agent_stats": - stats_info = { - "type": "agent_stats", - "data": json.loads(result_text), - } - yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n" - agent_stats = stats_info["data"] - continue - - # 发送 SSE 数据 - try: - if not client_disconnected: - yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" - except Exception as e: - if not client_disconnected: - logger.debug( - f"[WebChat] 用户 {username} 断开聊天长连接。 {e}" - ) - client_disconnected = True - - try: - if not client_disconnected: - await asyncio.sleep(0.05) - except asyncio.CancelledError: - logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") - client_disconnected = True - - # 累积消息部分 - if msg_type == "plain": - chain_type = result.get("chain_type") - if chain_type == "tool_call": - tool_call = json.loads(result_text) - tool_calls[tool_call.get("id")] = tool_call - if accumulated_text: - # 如果累积了文本,则先保存文本 - accumulated_parts.append( - {"type": "plain", "text": accumulated_text} - ) - accumulated_text = "" - elif chain_type == "tool_call_result": - tcr = json.loads(result_text) - tc_id = tcr.get("id") - if tc_id in tool_calls: - 
tool_calls[tc_id]["result"] = tcr.get("result") - tool_calls[tc_id]["finished_ts"] = tcr.get("ts") - accumulated_parts.append( - { - "type": "tool_call", - "tool_calls": [tool_calls[tc_id]], - } - ) - tool_calls.pop(tc_id, None) - elif chain_type == "reasoning": - accumulated_reasoning += result_text - elif streaming: - accumulated_text += result_text - else: - accumulated_text = result_text - elif msg_type == "image": - filename = result_text.replace("[IMAGE]", "") - part = await self._create_attachment_from_file( - filename, "image" - ) - if part: - accumulated_parts.append(part) - elif msg_type == "record": - filename = result_text.replace("[RECORD]", "") - part = await self._create_attachment_from_file( - filename, "record" - ) - if part: - accumulated_parts.append(part) - elif msg_type == "file": - # 格式: [FILE]filename - filename = result_text.replace("[FILE]", "") - part = await self._create_attachment_from_file( - filename, "file" - ) - if part: - accumulated_parts.append(part) - - # 消息结束处理 - if msg_type == "end": - break - elif ( - (streaming and msg_type == "complete") or not streaming - # or msg_type == "break" - ): - if ( - chain_type == "tool_call" - or chain_type == "tool_call_result" - ): - continue - saved_record = await self._save_bot_message( - webchat_conv_id, - accumulated_text, - accumulated_parts, - accumulated_reasoning, - agent_stats, - ) - # 发送保存的消息信息给前端 - if saved_record and not client_disconnected: - saved_info = { - "type": "message_saved", - "data": { - "id": saved_record.id, - "created_at": saved_record.created_at.astimezone().isoformat(), - }, - } - try: - yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" - except Exception: - pass - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - # tool_calls = {} - agent_stats = {} - except BaseException as e: - logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) - - # 将消息放入会话特定的队列 - chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) - await chat_queue.put( - ( - username, - webchat_conv_id, - { - "message": message_parts, - "selected_provider": selected_provider, - "selected_model": selected_model, - "enable_streaming": enable_streaming, - }, - ), - ) - - message_parts_for_storage = [] - for part in message_parts: - part_copy = {k: v for k, v in part.items() if k != "path"} - message_parts_for_storage.append(part_copy) - - await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content={"type": "user", "message": message_parts_for_storage}, - sender_id=username, - sender_name=username, - ) - - response = cast( - QuartResponse, - await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Transfer-Encoding": "chunked", - "Connection": "keep-alive", - }, - ), - ) - response.timeout = None # fix SSE auto disconnect issue - return response - - async def delete_webchat_session(self): - """Delete a Platform session and all its related data.""" - session_id = request.args.get("session_id") - if not session_id: - return Response().error("Missing key: session_id").__dict__ - username = g.get("username", "guest") - - # 验证会话是否存在且属于当前用户 - session = await self.db.get_platform_session_by_id(session_id) - if not session: - return Response().error(f"Session {session_id} not found").__dict__ - if session.creator != username: - return Response().error("Permission denied").__dict__ - - # 删除该会话下的所有对话 - message_type = "GroupMessage" if session.is_group else "FriendMessage" - 
unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" - await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) - - # 获取消息历史中的所有附件 ID 并删除附件 - history_list = await self.platform_history_mgr.get( - platform_id=session.platform_id, - user_id=session_id, - page=1, - page_size=100000, # 获取足够多的记录 - ) - attachment_ids = self._extract_attachment_ids(history_list) - if attachment_ids: - await self._delete_attachments(attachment_ids) - - # 删除消息历史 - await self.platform_history_mgr.delete( - platform_id=session.platform_id, - user_id=session_id, - offset_sec=99999999, - ) - - # 删除与会话关联的配置路由 - try: - await self.umop_config_router.delete_route(unified_msg_origin) - except ValueError as exc: - logger.warning( - "Failed to delete UMO route %s during session cleanup: %s", - unified_msg_origin, - exc, - ) - - # 清理队列(仅对 webchat) - if session.platform_id == "webchat": - webchat_queue_mgr.remove_queues(session_id) - - # 删除会话 - await self.db.delete_platform_session(session_id) - - return Response().ok().__dict__ - - def _extract_attachment_ids(self, history_list) -> list[str]: - """从消息历史中提取所有 attachment_id""" - attachment_ids = [] - for history in history_list: - content = history.content - if not content or "message" not in content: - continue - message_parts = content.get("message", []) - for part in message_parts: - if isinstance(part, dict) and "attachment_id" in part: - attachment_ids.append(part["attachment_id"]) - return attachment_ids - - async def _delete_attachments(self, attachment_ids: list[str]): - """删除附件(包括数据库记录和磁盘文件)""" - try: - attachments = await self.db.get_attachments(attachment_ids) - for attachment in attachments: - if not os.path.exists(attachment.path): - continue - try: - os.remove(attachment.path) - except OSError as e: - logger.warning( - f"Failed to delete attachment file {attachment.path}: {e}" - ) - except Exception as e: - logger.warning(f"Failed to get attachments: {e}") - - # 批量删除数据库记录 - try: - await self.db.delete_attachments(attachment_ids) - except Exception as e: - logger.warning(f"Failed to delete attachments: {e}") - - async def new_session(self): - """Create a new Platform session (default: webchat).""" - username = g.get("username", "guest") - - # 获取可选的 platform_id 参数,默认为 webchat - platform_id = request.args.get("platform_id", "webchat") - - # 创建新会话 - session = await self.db.create_platform_session( - creator=username, - platform_id=platform_id, - is_group=0, - ) - - return ( - Response() - .ok( - data={ - "session_id": session.session_id, - "platform_id": session.platform_id, - } - ) - .__dict__ - ) - - async def get_sessions(self): - """Get all Platform sessions for the current user.""" - username = g.get("username", "guest") - - # 获取可选的 platform_id 参数 - platform_id = request.args.get("platform_id") - - sessions = await self.db.get_platform_sessions_by_creator( - creator=username, - platform_id=platform_id, - page=1, - page_size=100, # 暂时返回前100个 - ) - - # 转换为字典格式,并添加额外信息 - sessions_data = [] - for session in sessions: - sessions_data.append( - { - "session_id": session.session_id, - "platform_id": session.platform_id, - "creator": session.creator, - "display_name": session.display_name, - "is_group": session.is_group, - "created_at": session.created_at.astimezone().isoformat(), - "updated_at": session.updated_at.astimezone().isoformat(), - } - ) - - return Response().ok(data=sessions_data).__dict__ - - async def get_session(self): - """Get session information and message history by session_id.""" - 
session_id = request.args.get("session_id") - if not session_id: - return Response().error("Missing key: session_id").__dict__ - - # 获取会话信息以确定 platform_id - session = await self.db.get_platform_session_by_id(session_id) - platform_id = session.platform_id if session else "webchat" - - # Get platform message history using session_id - history_ls = await self.platform_history_mgr.get( - platform_id=platform_id, - user_id=session_id, - page=1, - page_size=1000, - ) - - history_res = [history.model_dump() for history in history_ls] - - return ( - Response() - .ok( - data={ - "history": history_res, - "is_running": self.running_convs.get(session_id, False), - }, - ) - .__dict__ - ) - - async def update_session_display_name(self): - """Update a Platform session's display name.""" - post_data = await request.json - - session_id = post_data.get("session_id") - display_name = post_data.get("display_name") - - if not session_id: - return Response().error("Missing key: session_id").__dict__ - if display_name is None: - return Response().error("Missing key: display_name").__dict__ - - username = g.get("username", "guest") - - # 验证会话是否存在且属于当前用户 - session = await self.db.get_platform_session_by_id(session_id) - if not session: - return Response().error(f"Session {session_id} not found").__dict__ - if session.creator != username: - return Response().error("Permission denied").__dict__ - - # 更新 display_name - await self.db.update_platform_session( - session_id=session_id, - display_name=display_name, - ) - - return Response().ok().__dict__ diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index f39cccfe6..4caa60749 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,152 +1,28 @@ import asyncio -import inspect -import os import traceback -from typing import Any from quart import request -from astrbot.core import astrbot_config, file_token_service, logger +from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import ( CONFIG_METADATA_2, CONFIG_METADATA_3, CONFIG_METADATA_3_SYSTEM, DEFAULT_CONFIG, - DEFAULT_VALUE_MAP, ) from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.platform.register import platform_cls_map, platform_registry -from astrbot.core.provider import Provider +from astrbot.core.platform.register import platform_registry from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core.utils.llm_metadata import LLM_METADATAS -from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config +from ..services.platform import PlatformService +from ..services.provider import ProviderService +from ..services.utils import save_config from .route import Response, Route, RouteContext -def try_cast(value: Any, type_: str): - if type_ == "int": - try: - return int(value) - except (ValueError, TypeError): - return None - elif ( - type_ == "float" - and isinstance(value, str) - and value.replace(".", "", 1).isdigit() - ) or (type_ == "float" and isinstance(value, int)): - return float(value) - elif type_ == "float": - try: - return float(value) - except (ValueError, TypeError): - return None - - -def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: - errors = [] - - def validate(data: dict, metadata: dict = schema, path=""): - for key, value in data.items(): - if key not 
in metadata: - continue - meta = metadata[key] - if "type" not in meta: - logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验") - continue - # null 转换 - if value is None: - data[key] = DEFAULT_VALUE_MAP[meta["type"]] - continue - if meta["type"] == "list" and not isinstance(value, list): - errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", - ) - elif ( - meta["type"] == "list" - and isinstance(value, list) - and value - and "items" in meta - and isinstance(value[0], dict) - ): - # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 - for item in value: - validate(item, meta["items"], path=f"{path}{key}.") - elif meta["type"] == "object" and isinstance(value, dict): - validate(value, meta["items"], path=f"{path}{key}.") - - if meta["type"] == "int" and not isinstance(value, int): - casted = try_cast(value, "int") - if casted is None: - errors.append( - f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}", - ) - data[key] = casted - elif meta["type"] == "float" and not isinstance(value, float): - casted = try_cast(value, "float") - if casted is None: - errors.append( - f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}", - ) - data[key] = casted - elif meta["type"] == "bool" and not isinstance(value, bool): - errors.append( - f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}", - ) - elif meta["type"] in ["string", "text"] and not isinstance(value, str): - errors.append( - f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}", - ) - elif meta["type"] == "list" and not isinstance(value, list): - errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", - ) - elif meta["type"] == "object" and not isinstance(value, dict): - errors.append( - f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}", - ) - - if is_core: - meta_all = { - **schema["platform_group"]["metadata"], - **schema["provider_group"]["metadata"], - **schema["misc_config_group"]["metadata"], - } - validate(data, meta_all) - else: - validate(data, schema) - - return errors, data - - -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): - """验证并保存配置""" - errors = None - logger.info(f"Saving config, is_core={is_core}") - try: - if is_core: - errors, post_config = validate_config( - post_config, - CONFIG_METADATA_2, - is_core, - ) - else: - errors, post_config = validate_config( - post_config, getattr(config, "schema", {}), is_core - ) - except BaseException as e: - logger.error(traceback.format_exc()) - logger.warning(f"验证配置时出现异常: {e}") - raise ValueError(f"验证配置时出现异常: {e}") - if errors: - raise ValueError(f"格式校验未通过: {errors}") - - config.save_config(post_config) - - class ConfigRoute(Route): def __init__( self, @@ -156,9 +32,12 @@ def __init__( super().__init__(context) self.core_lifecycle = core_lifecycle self.config: AstrBotConfig = core_lifecycle.astrbot_config - self._logo_token_cache = {} # 缓存logo token,避免重复注册 self.acm = core_lifecycle.astrbot_config_mgr self.ucr = core_lifecycle.umop_config_router + + self.provider_service = ProviderService(core_lifecycle) + self.platform_service = PlatformService(core_lifecycle) + self.routes = { "/config/abconf/new": ("POST", self.create_abconf), "/config/abconf": ("GET", self.get_abconf), @@ -173,164 +52,66 @@ def __init__( "/config/default": ("GET", self.get_default_config), "/config/astrbot/update": ("POST", self.post_astrbot_configs), "/config/plugin/update": ("POST", self.post_plugin_configs), - "/config/platform/new": ("POST", self.post_new_platform), - "/config/platform/update": 
("POST", self.post_update_platform), - "/config/platform/delete": ("POST", self.post_delete_platform), - "/config/platform/list": ("GET", self.get_platform_list), - "/config/provider/new": ("POST", self.post_new_provider), - "/config/provider/update": ("POST", self.post_update_provider), - "/config/provider/delete": ("POST", self.post_delete_provider), - "/config/provider/template": ("GET", self.get_provider_template), - "/config/provider/check_one": ("GET", self.check_one_provider_status), - "/config/provider/list": ("GET", self.get_provider_config_list), - "/config/provider/model_list": ("GET", self.get_provider_model_list), - "/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim), + "/config/platform/new": ( + "POST", + self.platform_service.post_new_platform, + ), + "/config/platform/update": ( + "POST", + self.platform_service.post_update_platform, + ), + "/config/platform/delete": ( + "POST", + self.platform_service.post_delete_platform, + ), + "/config/platform/list": ( + "GET", + self.platform_service.get_platform_list, + ), + # provider related + "/config/provider/new": ( + "POST", + self.provider_service.post_new_provider, + ), + "/config/provider/update": ( + "POST", + self.provider_service.post_update_provider, + ), + "/config/provider/delete": ( + "POST", + self.provider_service.post_delete_provider, + ), + "/config/provider/template": ( + "GET", + self.provider_service.get_provider_template, + ), + "/config/provider/check_one": ( + "GET", + self.provider_service.check_one_provider_status, + ), + "/config/provider/list": ( + "GET", + self.provider_service.get_provider_config_list, + ), + "/config/provider/get_embedding_dim": ( + "POST", + self.provider_service.get_embedding_dim, + ), "/config/provider_sources/models": ( "GET", - self.get_provider_source_models, + self.provider_service.get_provider_source_models, ), "/config/provider_sources/update": ( "POST", - self.update_provider_source, + self.provider_service.update_provider_source, ), "/config/provider_sources/delete": ( "POST", - self.delete_provider_source, + self.provider_service.delete_provider_source, ), } self.register_routes() - async def delete_provider_source(self): - """删除 provider_source,并更新关联的 providers""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - provider_source_id = post_data.get("id") - if not provider_source_id: - return Response().error("缺少 provider_source_id").__dict__ - - provider_sources = self.config.get("provider_sources", []) - target_idx = next( - ( - i - for i, ps in enumerate(provider_sources) - if ps.get("id") == provider_source_id - ), - -1, - ) - - if target_idx == -1: - return Response().error("未找到对应的 provider source").__dict__ - - # 删除 provider_source - del provider_sources[target_idx] - - # 写回配置 - self.config["provider_sources"] = provider_sources - - # 删除引用了该 provider_source 的 providers - await self.core_lifecycle.provider_manager.delete_provider( - provider_source_id=provider_source_id - ) - - try: - save_config(self.config, self.config, is_core=True) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - return Response().ok(message="删除 provider source 成功").__dict__ - - async def update_provider_source(self): - """更新或新增 provider_source,并重载关联的 providers""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - new_source_config = post_data.get("config") or post_data - original_id = post_data.get("original_id") - 
if not original_id: - return Response().error("缺少 original_id").__dict__ - - if not isinstance(new_source_config, dict): - return Response().error("缺少或错误的配置数据").__dict__ - - # 确保配置中有 id 字段 - if not new_source_config.get("id"): - new_source_config["id"] = original_id - - provider_sources = self.config.get("provider_sources", []) - - for ps in provider_sources: - if ps.get("id") == new_source_config["id"] and ps.get("id") != original_id: - return ( - Response() - .error( - f"Provider source ID '{new_source_config['id']}' exists already, please try another ID.", - ) - .__dict__ - ) - - # 查找旧的 provider_source,若不存在则追加为新配置 - target_idx = next( - (i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id), - -1, - ) - - old_id = original_id - if target_idx == -1: - provider_sources.append(new_source_config) - else: - old_id = provider_sources[target_idx].get("id") - provider_sources[target_idx] = new_source_config - - # 更新引用了该 provider_source 的 providers - affected_providers = [] - for provider in self.config.get("provider", []): - if provider.get("provider_source_id") == old_id: - provider["provider_source_id"] = new_source_config["id"] - affected_providers.append(provider) - - # 写回配置 - self.config["provider_sources"] = provider_sources - - try: - save_config(self.config, self.config, is_core=True) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - # 重载受影响的 providers,使新的 source 配置生效 - reload_errors = [] - prov_mgr = self.core_lifecycle.provider_manager - for provider in affected_providers: - try: - await prov_mgr.reload(provider) - except Exception as e: - logger.error(traceback.format_exc()) - reload_errors.append(f"{provider.get('id')}: {e}") - - if reload_errors: - return ( - Response() - .error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors)) - .__dict__ - ) - - return Response().ok(message="更新 provider source 成功").__dict__ - - async def get_provider_template(self): - config_schema = { - "provider": CONFIG_METADATA_2["provider_group"]["metadata"]["provider"] - } - data = { - "config_schema": config_schema, - "providers": astrbot_config["provider"], - "provider_sources": astrbot_config["provider_sources"], - } - return Response().ok(data=data).__dict__ - async def get_uc_table(self): """获取 UMOP 配置路由表""" return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__ @@ -482,88 +263,6 @@ async def update_abconf(self): logger.error(traceback.format_exc()) return Response().error(f"更新配置文件失败: {e!s}").__dict__ - async def _test_single_provider(self, provider): - """辅助函数:测试单个 provider 的可用性""" - meta = provider.meta() - provider_name = provider.provider_config.get("id", "Unknown Provider") - provider_capability_type = meta.provider_type - - status_info = { - "id": getattr(meta, "id", "Unknown ID"), - "model": getattr(meta, "model", "Unknown Model"), - "type": provider_capability_type.value, - "name": provider_name, - "status": "unavailable", # 默认为不可用 - "error": None, - } - logger.debug( - f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", - ) - - try: - await provider.test() - status_info["status"] = "available" - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", - ) - except Exception as e: - error_message = str(e) - status_info["error"] = error_message - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. 
Error: {error_message}", - ) - logger.debug( - f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", - ) - - return status_info - - def _error_response( - self, - message: str, - status_code: int = 500, - log_fn=logger.error, - ): - log_fn(message) - # 记录更详细的traceback信息,但只在是严重错误时 - if status_code == 500: - log_fn(traceback.format_exc()) - return Response().error(message).__dict__ - - async def check_one_provider_status(self): - """API: check a single LLM Provider's status by id""" - provider_id = request.args.get("id") - if not provider_id: - return self._error_response( - "Missing provider_id parameter", - 400, - logger.warning, - ) - - logger.info(f"API call: /config/provider/check_one id={provider_id}") - try: - prov_mgr = self.core_lifecycle.provider_manager - target = prov_mgr.inst_map.get(provider_id) - - if not target: - logger.warning( - f"Provider with id '{provider_id}' not found in provider_manager.", - ) - return ( - Response() - .error(f"Provider with id '{provider_id}' not found") - .__dict__ - ) - - result = await self._test_single_provider(target) - return Response().ok(result).__dict__ - - except Exception as e: - return self._error_response( - f"Critical error checking provider {provider_id}: {e}", - 500, - ) - async def get_configs(self): # plugin_name 为空时返回 AstrBot 配置 # 否则返回指定 plugin_name 的插件配置 @@ -572,231 +271,6 @@ async def get_configs(self): return Response().ok(await self._get_astrbot_config()).__dict__ return Response().ok(await self._get_plugin_config(plugin_name)).__dict__ - async def get_provider_config_list(self): - provider_type = request.args.get("provider_type", None) - if not provider_type: - return Response().error("缺少参数 provider_type").__dict__ - provider_type_ls = provider_type.split(",") - provider_list = [] - ps = self.core_lifecycle.provider_manager.providers_config - p_source_pt = { - psrc["id"]: psrc["provider_type"] - for psrc in self.core_lifecycle.provider_manager.provider_sources_config - } - for provider in ps: - ps_id = provider.get("provider_source_id", None) - if ( - ps_id - and ps_id in p_source_pt - and p_source_pt[ps_id] in provider_type_ls - ): - # chat - prov = self.core_lifecycle.provider_manager.get_merged_provider_config( - provider - ) - provider_list.append(prov) - elif not ps_id and provider.get("provider_type", None) in provider_type_ls: - # agent runner, embedding, etc - provider_list.append(provider) - return Response().ok(provider_list).__dict__ - - async def get_provider_model_list(self): - """获取指定提供商的模型列表""" - provider_id = request.args.get("provider_id", None) - if not provider_id: - return Response().error("缺少参数 provider_id").__dict__ - - prov_mgr = self.core_lifecycle.provider_manager - provider = prov_mgr.inst_map.get(provider_id, None) - if not provider: - return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__ - if not isinstance(provider, Provider): - return ( - Response() - .error(f"提供商 {provider_id} 类型不支持获取模型列表") - .__dict__ - ) - - try: - models = await provider.get_models() - models = models or [] - - metadata_map = {} - for model_id in models: - meta = LLM_METADATAS.get(model_id) - if meta: - metadata_map[model_id] = meta - - ret = { - "models": models, - "provider_id": provider_id, - "model_metadata": metadata_map, - } - return Response().ok(ret).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def get_embedding_dim(self): - """获取嵌入模型的维度""" - post_data = await request.json - provider_config = 
post_data.get("provider_config", None) - if not provider_config: - return Response().error("缺少参数 provider_config").__dict__ - - try: - # 动态导入 EmbeddingProvider - from astrbot.core.provider.provider import EmbeddingProvider - from astrbot.core.provider.register import provider_cls_map - - # 获取 provider 类型 - provider_type = provider_config.get("type", None) - if not provider_type: - return Response().error("provider_config 缺少 type 字段").__dict__ - - # 获取对应的 provider 类 - if provider_type not in provider_cls_map: - return ( - Response() - .error(f"未找到适用于 {provider_type} 的提供商适配器") - .__dict__ - ) - - provider_metadata = provider_cls_map[provider_type] - cls_type = provider_metadata.cls_type - - if not cls_type: - return Response().error(f"无法找到 {provider_type} 的类").__dict__ - - # 实例化 provider - inst = cls_type(provider_config, {}) - - # 检查是否是 EmbeddingProvider - if not isinstance(inst, EmbeddingProvider): - return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ - - init_fn = getattr(inst, "initialize", None) - if inspect.iscoroutinefunction(init_fn): - await init_fn() - - # 获取嵌入向量维度 - vec = await inst.get_embedding("echo") - dim = len(vec) - - logger.info( - f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}", - ) - - return Response().ok({"embedding_dimensions": dim}).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ - - async def get_provider_source_models(self): - """获取指定 provider_source 支持的模型列表 - - 本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例 - """ - provider_source_id = request.args.get("source_id") - if not provider_source_id: - return Response().error("缺少参数 source_id").__dict__ - - try: - from astrbot.core.provider.register import provider_cls_map - - # 从配置中查找对应的 provider_source - provider_sources = self.config.get("provider_sources", []) - provider_source = None - for ps in provider_sources: - if ps.get("id") == provider_source_id: - provider_source = ps - break - - if not provider_source: - return ( - Response() - .error(f"未找到 ID 为 {provider_source_id} 的 provider_source") - .__dict__ - ) - - # 获取 provider 类型 - provider_type = provider_source.get("type", None) - if not provider_type: - return Response().error("provider_source 缺少 type 字段").__dict__ - - try: - self.core_lifecycle.provider_manager.dynamic_import_provider( - provider_type - ) - except ImportError as e: - logger.error(traceback.format_exc()) - return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__ - - # 获取对应的 provider 类 - if provider_type not in provider_cls_map: - return ( - Response() - .error(f"未找到适用于 {provider_type} 的提供商适配器") - .__dict__ - ) - - provider_metadata = provider_cls_map[provider_type] - cls_type = provider_metadata.cls_type - - if not cls_type: - return Response().error(f"无法找到 {provider_type} 的类").__dict__ - - # 检查是否是 Provider 类型 - if not issubclass(cls_type, Provider): - return ( - Response() - .error(f"提供商 {provider_type} 不支持获取模型列表") - .__dict__ - ) - - # 临时实例化 provider - inst = cls_type(provider_source, {}) - - # 如果有 initialize 方法,调用它 - init_fn = getattr(inst, "initialize", None) - if inspect.iscoroutinefunction(init_fn): - await init_fn() - - # 获取模型列表 - models = await inst.get_models() - models = models or [] - - metadata_map = {} - for model_id in models: - meta = LLM_METADATAS.get(model_id) - if meta: - metadata_map[model_id] = meta - - # 销毁实例(如果有 terminate 方法) - terminate_fn = getattr(inst, "terminate", None) - if inspect.iscoroutinefunction(terminate_fn): - await terminate_fn() - - logger.info( - 
f"获取到 provider_source {provider_source_id} 的模型列表: {models}", - ) - - return ( - Response() - .ok({"models": models, "model_metadata": metadata_map}) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"获取模型列表失败: {e!s}").__dict__ - - async def get_platform_list(self): - """获取所有平台的列表""" - platform_list = [] - for platform in self.config["platform"]: - platform_list.append(platform) - return Response().ok({"platforms": platform_list}).__dict__ - async def post_astrbot_configs(self): data = await request.json config = data.get("config", None) @@ -831,178 +305,12 @@ async def post_plugin_configs(self): except Exception as e: return Response().error(str(e)).__dict__ - async def post_new_platform(self): - new_platform_config = await request.json - - # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid - ensure_platform_webhook_config(new_platform_config) - - self.config["platform"].append(new_platform_config) - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.load_platform( - new_platform_config, - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "新增平台配置成功~").__dict__ - - async def post_new_provider(self): - new_provider_config = await request.json - - try: - await self.core_lifecycle.provider_manager.create_provider( - new_provider_config - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "新增服务提供商配置成功").__dict__ - - async def post_update_platform(self): - update_platform_config = await request.json - origin_platform_id = update_platform_config.get("id", None) - new_config = update_platform_config.get("config", None) - if not origin_platform_id or not new_config: - return Response().error("参数错误").__dict__ - - if origin_platform_id != new_config.get("id", None): - return Response().error("机器人名称不允许修改").__dict__ - - # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid - ensure_platform_webhook_config(new_config) - - for i, platform in enumerate(self.config["platform"]): - if platform["id"] == origin_platform_id: - self.config["platform"][i] = new_config - break - else: - return Response().error("未找到对应平台").__dict__ - - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.reload(new_config) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "更新平台配置成功~").__dict__ - - async def post_update_provider(self): - update_provider_config = await request.json - origin_provider_id = update_provider_config.get("id", None) - new_config = update_provider_config.get("config", None) - if not origin_provider_id or not new_config: - return Response().error("参数错误").__dict__ - - try: - await self.core_lifecycle.provider_manager.update_provider( - origin_provider_id, new_config - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "更新成功,已经实时生效~").__dict__ - - async def post_delete_platform(self): - platform_id = await request.json - platform_id = platform_id.get("id") - for i, platform in enumerate(self.config["platform"]): - if platform["id"] == platform_id: - del self.config["platform"][i] - break - else: - return Response().error("未找到对应平台").__dict__ - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.terminate_platform(platform_id) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, 
"删除平台配置成功~").__dict__ - - async def post_delete_provider(self): - provider_id = await request.json - provider_id = provider_id.get("id", "") - if not provider_id: - return Response().error("缺少参数 id").__dict__ - - try: - await self.core_lifecycle.provider_manager.delete_provider( - provider_id=provider_id - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "删除成功,已经实时生效。").__dict__ - async def get_llm_tools(self): """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" tool_mgr = self.core_lifecycle.provider_manager.llm_tools tools = tool_mgr.get_func_desc_openai_style() return Response().ok(tools).__dict__ - async def _register_platform_logo(self, platform, platform_default_tmpl): - """注册平台logo文件并生成访问令牌""" - if not platform.logo_path: - return - - try: - # 检查缓存 - cache_key = f"{platform.name}:{platform.logo_path}" - if cache_key in self._logo_token_cache: - cached_token = self._logo_token_cache[cache_key] - # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl or not isinstance( - platform_default_tmpl[platform.name], dict - ): - platform_default_tmpl[platform.name] = {} - platform_default_tmpl[platform.name]["logo_token"] = cached_token - logger.debug(f"Using cached logo token for platform {platform.name}") - return - - # 获取平台适配器类 - platform_cls = platform_cls_map.get(platform.name) - if not platform_cls: - logger.warning(f"Platform class not found for {platform.name}") - return - - # 获取插件目录路径 - module_file = inspect.getfile(platform_cls) - plugin_dir = os.path.dirname(module_file) - - # 解析logo文件路径 - logo_file_path = os.path.join(plugin_dir, platform.logo_path) - - # 检查文件是否存在并注册令牌 - if os.path.exists(logo_file_path): - logo_token = await file_token_service.register_file( - logo_file_path, - timeout=3600, - ) - - # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl or not isinstance( - platform_default_tmpl[platform.name], dict - ): - platform_default_tmpl[platform.name] = {} - - platform_default_tmpl[platform.name]["logo_token"] = logo_token - - # 缓存token - self._logo_token_cache[cache_key] = logo_token - - logger.debug(f"Logo token registered for platform {platform.name}") - else: - logger.warning( - f"Platform {platform.name} logo file not found: {logo_file_path}", - ) - - except (ImportError, AttributeError) as e: - logger.warning( - f"Failed to import required modules for platform {platform.name}: {e}", - ) - except OSError as e: - logger.warning(f"File system error for platform {platform.name} logo: {e}") - except Exception as e: - logger.warning( - f"Unexpected error registering logo for platform {platform.name}: {e}", - ) - async def _get_astrbot_config(self): config = self.config @@ -1019,7 +327,9 @@ async def _get_astrbot_config(self): # 收集logo注册任务 if platform.logo_path: logo_registration_tasks.append( - self._register_platform_logo(platform, platform_default_tmpl), + self.platform_service.register_platform_logo( + platform, platform_default_tmpl + ), ) # 并行执行logo注册 diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 537a81f0b..4e8994295 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -1,19 +1,9 @@ """知识库管理 API 路由""" -import asyncio -import os -import traceback -import uuid - -import aiofiles -from quart import request - -from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.provider.provider import 
EmbeddingProvider, RerankProvider -from ..utils import generate_tsne_visualization -from .route import Response, Route, RouteContext +from ..services.knowledge_base import KnowledgeBaseService +from .route import Route, RouteContext class KnowledgeBaseRoute(Route): @@ -29,1235 +19,35 @@ def __init__( ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle - self.kb_manager = None # 延迟初始化 - self.kb_db = None - self.session_config_db = None # 会话配置数据库 - self.retrieval_manager = None - self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}} - self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}} + self.kb_service = KnowledgeBaseService(core_lifecycle) # 注册路由 self.routes = { # 知识库管理 - "/kb/list": ("GET", self.list_kbs), - "/kb/create": ("POST", self.create_kb), - "/kb/get": ("GET", self.get_kb), - "/kb/update": ("POST", self.update_kb), - "/kb/delete": ("POST", self.delete_kb), - "/kb/stats": ("GET", self.get_kb_stats), + "/kb/list": ("GET", self.kb_service.list_kbs), + "/kb/create": ("POST", self.kb_service.create_kb), + "/kb/get": ("GET", self.kb_service.get_kb), + "/kb/update": ("POST", self.kb_service.update_kb), + "/kb/delete": ("POST", self.kb_service.delete_kb), + "/kb/stats": ("GET", self.kb_service.get_kb_stats), # 文档管理 - "/kb/document/list": ("GET", self.list_documents), - "/kb/document/upload": ("POST", self.upload_document), - "/kb/document/import": ("POST", self.import_documents), - "/kb/document/upload/url": ("POST", self.upload_document_from_url), - "/kb/document/upload/progress": ("GET", self.get_upload_progress), - "/kb/document/get": ("GET", self.get_document), - "/kb/document/delete": ("POST", self.delete_document), + "/kb/document/list": ("GET", self.kb_service.list_documents), + "/kb/document/upload": ("POST", self.kb_service.upload_document), + "/kb/document/import": ("POST", self.kb_service.import_documents), + "/kb/document/upload/url": ( + "POST", + self.kb_service.upload_document_from_url, + ), + "/kb/document/upload/progress": ( + "GET", + self.kb_service.get_upload_progress, + ), + "/kb/document/get": ("GET", self.kb_service.get_document), + "/kb/document/delete": ("POST", self.kb_service.delete_document), # # 块管理 - "/kb/chunk/list": ("GET", self.list_chunks), - "/kb/chunk/delete": ("POST", self.delete_chunk), - # # 多媒体管理 - # "/kb/media/list": ("GET", self.list_media), - # "/kb/media/delete": ("POST", self.delete_media), + "/kb/chunk/list": ("GET", self.kb_service.list_chunks), + "/kb/chunk/delete": ("POST", self.kb_service.delete_chunk), # 检索 - "/kb/retrieve": ("POST", self.retrieve), + "/kb/retrieve": ("POST", self.kb_service.retrieve), } self.register_routes() - - def _get_kb_manager(self): - return self.core_lifecycle.kb_manager - - def _init_task(self, task_id: str, status: str = "pending") -> None: - self.upload_tasks[task_id] = { - "status": status, - "result": None, - "error": None, - } - - def _set_task_result( - self, task_id: str, status: str, result: any = None, error: str | None = None - ) -> None: - self.upload_tasks[task_id] = { - "status": status, - "result": result, - "error": error, - } - if task_id in self.upload_progress: - self.upload_progress[task_id]["status"] = status - - def _update_progress( - self, - task_id: str, - *, - status: str | None = None, - file_index: int | None = None, - file_name: str | None = None, - stage: str | None = None, - current: int | None = None, - total: int | None = None, - ) -> None: - if task_id not in self.upload_progress: - 
return - p = self.upload_progress[task_id] - if status is not None: - p["status"] = status - if file_index is not None: - p["file_index"] = file_index - if file_name is not None: - p["file_name"] = file_name - if stage is not None: - p["stage"] = stage - if current is not None: - p["current"] = current - if total is not None: - p["total"] = total - - def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): - async def _callback(stage: str, current: int, total: int): - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_name, - stage=stage, - current=current, - total=total, - ) - - return _callback - - async def _background_upload_task( - self, - task_id: str, - kb_helper, - files_to_upload: list, - chunk_size: int, - chunk_overlap: int, - batch_size: int, - tasks_limit: int, - max_retries: int, - ): - """后台上传任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": len(files_to_upload), - "stage": "waiting", - "current": 0, - "total": 100, - } - - uploaded_docs = [] - failed_docs = [] - - for file_idx, file_info in enumerate(files_to_upload): - try: - # 更新整体进度 - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_info["file_name"], - stage="parsing", - current=0, - total=100, - ) - - # 创建进度回调函数 - progress_callback = self._make_progress_callback( - task_id, file_idx, file_info["file_name"] - ) - - doc = await kb_helper.upload_document( - file_name=file_info["file_name"], - file_content=file_info["file_content"], - file_type=file_info["file_type"], - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, - ) - - uploaded_docs.append(doc.model_dump()) - except Exception as e: - logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") - failed_docs.append( - {"file_name": file_info["file_name"], "error": str(e)}, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": uploaded_docs, - "failed": failed_docs, - "total": len(files_to_upload), - "success_count": len(uploaded_docs), - "failed_count": len(failed_docs), - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台上传任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def _background_import_task( - self, - task_id: str, - kb_helper, - documents: list, - batch_size: int, - tasks_limit: int, - max_retries: int, - ): - """后台导入预切片文档任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": len(documents), - "stage": "waiting", - "current": 0, - "total": 100, - } - - uploaded_docs = [] - failed_docs = [] - - for file_idx, doc_info in enumerate(documents): - file_name = doc_info.get("file_name", f"imported_doc_{file_idx}") - chunks = doc_info.get("chunks", []) - - try: - # 更新整体进度 - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_name, - stage="importing", - current=0, - total=100, - ) - - # 创建进度回调函数 - progress_callback = self._make_progress_callback( - task_id, file_idx, file_name - ) - - # 调用 upload_document,传入 pre_chunked_text - doc = await kb_helper.upload_document( - file_name=file_name, - file_content=None, 
# 预切片模式下不需要原始内容 - file_type=doc_info.get("file_type") - or ( - file_name.rsplit(".", 1)[-1].lower() - if "." in file_name - else "txt" - ), - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, - pre_chunked_text=chunks, - ) - - uploaded_docs.append(doc.model_dump()) - except Exception as e: - logger.error(f"导入文档 {file_name} 失败: {e}") - failed_docs.append( - {"file_name": file_name, "error": str(e)}, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": uploaded_docs, - "failed": failed_docs, - "total": len(documents), - "success_count": len(uploaded_docs), - "failed_count": len(failed_docs), - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台导入任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def list_kbs(self): - """获取知识库列表 - - Query 参数: - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - - refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true) - """ - try: - kb_manager = self._get_kb_manager() - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - - kbs = await kb_manager.list_kbs() - - # 转换为字典列表 - kb_list = [] - for kb in kbs: - kb_list.append(kb.model_dump()) - - return ( - Response() - .ok({"items": kb_list, "page": page, "page_size": page_size}) - .__dict__ - ) - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库列表失败: {e!s}").__dict__ - - async def create_kb(self): - """创建知识库 - - Body: - - kb_name: 知识库名称 (必填) - - description: 描述 (可选) - - emoji: 图标 (可选) - - embedding_provider_id: 嵌入模型提供商ID (可选) - - rerank_provider_id: 重排序模型提供商ID (可选) - - chunk_size: 分块大小 (可选, 默认512) - - chunk_overlap: 块重叠大小 (可选, 默认50) - - top_k_dense: 密集检索数量 (可选, 默认50) - - top_k_sparse: 稀疏检索数量 (可选, 默认50) - - top_m_final: 最终返回数量 (可选, 默认5) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - kb_name = data.get("kb_name") - if not kb_name: - return Response().error("知识库名称不能为空").__dict__ - - description = data.get("description") - emoji = data.get("emoji") - embedding_provider_id = data.get("embedding_provider_id") - rerank_provider_id = data.get("rerank_provider_id") - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") - - # pre-check embedding dim - if not embedding_provider_id: - return Response().error("缺少参数 embedding_provider_id").__dict__ - prv = await kb_manager.provider_manager.get_provider_by_id( - embedding_provider_id, - ) # type: ignore - if not prv or not isinstance(prv, EmbeddingProvider): - return ( - Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__ - ) - try: - vec = await prv.get_embedding("astrbot") - if len(vec) != prv.get_dim(): - raise ValueError( - f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}", - ) - except Exception as e: - return Response().error(f"测试嵌入模型失败: {e!s}").__dict__ - # pre-check rerank - if rerank_provider_id: - rerank_prv: RerankProvider = ( - await kb_manager.provider_manager.get_provider_by_id( - rerank_provider_id, - ) - ) # type: ignore - if not rerank_prv: - return Response().error("重排序模型不存在").__dict__ - # 检查重排序模型可用性 - try: - res = await rerank_prv.rerank( - query="astrbot", - 
documents=["astrbot knowledge base"], - ) - if not res: - raise ValueError("重排序模型返回结果异常") - except Exception as e: - return ( - Response() - .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。") - .__dict__ - ) - - kb_helper = await kb_manager.create_kb( - kb_name=kb_name, - description=description, - emoji=emoji, - embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - top_k_dense=top_k_dense, - top_k_sparse=top_k_sparse, - top_m_final=top_m_final, - ) - kb = kb_helper.kb - - return Response().ok(kb.model_dump(), "创建知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"创建知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"创建知识库失败: {e!s}").__dict__ - - async def get_kb(self): - """获取知识库详情 - - Query 参数: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - kb = kb_helper.kb - - return Response().ok(kb.model_dump()).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库详情失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库详情失败: {e!s}").__dict__ - - async def update_kb(self): - """更新知识库 - - Body: - - kb_id: 知识库 ID (必填) - - kb_name: 新的知识库名称 (可选) - - description: 新的描述 (可选) - - emoji: 新的图标 (可选) - - embedding_provider_id: 新的嵌入模型提供商ID (可选) - - rerank_provider_id: 新的重排序模型提供商ID (可选) - - chunk_size: 分块大小 (可选) - - chunk_overlap: 块重叠大小 (可选) - - top_k_dense: 密集检索数量 (可选) - - top_k_sparse: 稀疏检索数量 (可选) - - top_m_final: 最终返回数量 (可选) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_name = data.get("kb_name") - description = data.get("description") - emoji = data.get("emoji") - embedding_provider_id = data.get("embedding_provider_id") - rerank_provider_id = data.get("rerank_provider_id") - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") - - # 检查是否至少提供了一个更新字段 - if all( - v is None - for v in [ - kb_name, - description, - emoji, - embedding_provider_id, - rerank_provider_id, - chunk_size, - chunk_overlap, - top_k_dense, - top_k_sparse, - top_m_final, - ] - ): - return Response().error("至少需要提供一个更新字段").__dict__ - - kb_helper = await kb_manager.update_kb( - kb_id=kb_id, - kb_name=kb_name, - description=description, - emoji=emoji, - embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - top_k_dense=top_k_dense, - top_k_sparse=top_k_sparse, - top_m_final=top_m_final, - ) - - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - kb = kb_helper.kb - return Response().ok(kb.model_dump(), "更新知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"更新知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"更新知识库失败: {e!s}").__dict__ - - async def delete_kb(self): - """删除知识库 - - Body: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = 
self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - success = await kb_manager.delete_kb(kb_id) - if not success: - return Response().error("知识库不存在").__dict__ - - return Response().ok(message="删除知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除知识库失败: {e!s}").__dict__ - - async def get_kb_stats(self): - """获取知识库统计信息 - - Query 参数: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - kb = kb_helper.kb - - stats = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - - return Response().ok(stats).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库统计失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库统计失败: {e!s}").__dict__ - - # ===== 文档管理 API ===== - - async def list_documents(self): - """获取文档列表 - - Query 参数: - - kb_id: 知识库 ID (必填) - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) - - offset = (page - 1) * page_size - limit = page_size - - doc_list = await kb_helper.list_documents(offset=offset, limit=limit) - - doc_list = [doc.model_dump() for doc in doc_list] - - return ( - Response() - .ok({"items": doc_list, "page": page, "page_size": page_size}) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取文档列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取文档列表失败: {e!s}").__dict__ - - async def upload_document(self): - """上传文档 - - 支持两种方式: - 1. multipart/form-data 文件上传(支持多文件,最多10个) - 2. JSON 格式 base64 编码上传(支持多文件,最多10个) - - Form Data (multipart/form-data): - - kb_id: 知识库 ID (必填) - - file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 
或 files[]) - - JSON Body (application/json): - - kb_id: 知识库 ID (必填) - - files: 文件数组 (必填) - - file_name: 文件名 (必填) - - file_content: base64 编码的文件内容 (必填) - - 返回: - - task_id: 任务ID,用于查询上传进度和结果 - """ - try: - kb_manager = self._get_kb_manager() - - # 检查 Content-Type - content_type = request.content_type - kb_id = None - chunk_size = None - chunk_overlap = None - batch_size = 32 - tasks_limit = 3 - max_retries = 3 - files_to_upload = [] # 存储待上传的文件信息列表 - - if content_type and "multipart/form-data" not in content_type: - return ( - Response().error("Content-Type 须为 multipart/form-data").__dict__ - ) - form_data = await request.form - files = await request.files - - kb_id = form_data.get("kb_id") - chunk_size = int(form_data.get("chunk_size", 512)) - chunk_overlap = int(form_data.get("chunk_overlap", 50)) - batch_size = int(form_data.get("batch_size", 32)) - tasks_limit = int(form_data.get("tasks_limit", 3)) - max_retries = int(form_data.get("max_retries", 3)) - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - # 收集所有文件 - file_list = [] - # 支持 file, file1, file2, ... 或 files[] 格式 - for key in files.keys(): - if key == "file" or key.startswith("file") or key == "files[]": - file_items = files.getlist(key) - file_list.extend(file_items) - - if not file_list: - return Response().error("缺少文件").__dict__ - - # 限制文件数量 - if len(file_list) > 10: - return Response().error("最多只能上传10个文件").__dict__ - - # 处理每个文件 - for file in file_list: - file_name = file.filename - - # 保存到临时文件 - temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}" - await file.save(temp_file_path) - - try: - # 异步读取文件内容 - async with aiofiles.open(temp_file_path, "rb") as f: - file_content = await f.read() - - # 提取文件类型 - file_type = ( - file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" - ) - - files_to_upload.append( - { - "file_name": file_name, - "file_content": file_content, - "file_type": file_type, - }, - ) - finally: - # 清理临时文件 - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_upload_task( - task_id=task_id, - kb_helper=kb_helper, - files_to_upload=files_to_upload, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "file_count": len(files_to_upload), - "message": "task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"上传文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"上传文档失败: {e!s}").__dict__ - - def _validate_import_request(self, data: dict): - kb_id = data.get("kb_id") - if not kb_id: - raise ValueError("缺少参数 kb_id") - - documents = data.get("documents") - if not documents or not isinstance(documents, list): - raise ValueError("缺少参数 documents 或格式错误") - - for doc in documents: - if "file_name" not in doc or "chunks" not in doc: - raise ValueError("文档格式错误,必须包含 file_name 和 chunks") - if not isinstance(doc["chunks"], list): - raise ValueError("chunks 必须是列表") - if not all( - isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"] - ): - raise ValueError("chunks 必须是非空字符串列表") - - batch_size = 
data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) - return kb_id, documents, batch_size, tasks_limit, max_retries - - async def import_documents(self): - """导入预切片文档 - - Body: - - kb_id: 知识库 ID (必填) - - documents: 文档列表 (必填) - - file_name: 文件名 (必填) - - chunks: 切片列表 (必填, list[str]) - - file_type: 文件类型 (可选, 默认从文件名推断或为 txt) - - batch_size: 批处理大小 (可选, 默认32) - - tasks_limit: 并发任务限制 (可选, 默认3) - - max_retries: 最大重试次数 (可选, 默认3) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id, documents, batch_size, tasks_limit, max_retries = ( - self._validate_import_request(data) - ) - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_import_task( - task_id=task_id, - kb_helper=kb_helper, - documents=documents, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "doc_count": len(documents), - "message": "import task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"导入文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"导入文档失败: {e!s}").__dict__ - - async def get_upload_progress(self): - """获取上传进度和结果 - - Query 参数: - - task_id: 任务 ID (必填) - - 返回状态: - - pending: 任务待处理 - - processing: 任务处理中 - - completed: 任务完成 - - failed: 任务失败 - """ - try: - task_id = request.args.get("task_id") - if not task_id: - return Response().error("缺少参数 task_id").__dict__ - - # 检查任务是否存在 - if task_id not in self.upload_tasks: - return Response().error("找不到该任务").__dict__ - - task_info = self.upload_tasks[task_id] - status = task_info["status"] - - # 构建返回数据 - response_data = { - "task_id": task_id, - "status": status, - } - - # 如果任务正在处理,返回进度信息 - if status == "processing" and task_id in self.upload_progress: - response_data["progress"] = self.upload_progress[task_id] - - # 如果任务完成,返回结果 - if status == "completed": - response_data["result"] = task_info["result"] - # 清理已完成的任务 - # del self.upload_tasks[task_id] - # if task_id in self.upload_progress: - # del self.upload_progress[task_id] - - # 如果任务失败,返回错误信息 - if status == "failed": - response_data["error"] = task_info["error"] - - return Response().ok(response_data).__dict__ - - except Exception as e: - logger.error(f"获取上传进度失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取上传进度失败: {e!s}").__dict__ - - async def get_document(self): - """获取文档详情 - - Query 参数: - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - doc_id = request.args.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - doc = await kb_helper.get_document(doc_id) - if not doc: - return Response().error("文档不存在").__dict__ - - return Response().ok(doc.model_dump()).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取文档详情失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取文档详情失败: 
{e!s}").__dict__ - - async def delete_document(self): - """删除文档 - - Body: - - kb_id: 知识库 ID (必填) - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - doc_id = data.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - await kb_helper.delete_document(doc_id) - return Response().ok(message="删除文档成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除文档失败: {e!s}").__dict__ - - async def delete_chunk(self): - """删除文本块 - - Body: - - kb_id: 知识库 ID (必填) - - chunk_id: 块 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - chunk_id = data.get("chunk_id") - if not chunk_id: - return Response().error("缺少参数 chunk_id").__dict__ - doc_id = data.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - await kb_helper.delete_chunk(chunk_id, doc_id) - return Response().ok(message="删除文本块成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除文本块失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除文本块失败: {e!s}").__dict__ - - async def list_chunks(self): - """获取块列表 - - Query 参数: - - kb_id: 知识库 ID (必填) - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - doc_id = request.args.get("doc_id") - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - offset = (page - 1) * page_size - limit = page_size - if not kb_helper: - return Response().error("知识库不存在").__dict__ - chunk_list = await kb_helper.get_chunks_by_doc_id( - doc_id=doc_id, - offset=offset, - limit=limit, - ) - return ( - Response() - .ok( - data={ - "items": chunk_list, - "page": page, - "page_size": page_size, - "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), - }, - ) - .__dict__ - ) - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取块列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取块列表失败: {e!s}").__dict__ - - # ===== 检索 API ===== - - async def retrieve(self): - """检索知识库 - - Body: - - query: 查询文本 (必填) - - kb_ids: 知识库 ID 列表 (必填) - - top_k: 返回结果数量 (可选, 默认 5) - - debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - query = data.get("query") - kb_names = data.get("kb_names") - debug = data.get("debug", False) - - if not query: - return Response().error("缺少参数 query").__dict__ - if not kb_names or not isinstance(kb_names, list): - return Response().error("缺少参数 kb_names 或格式错误").__dict__ - - top_k = data.get("top_k", 5) - - results = await kb_manager.retrieve( - query=query, - 
kb_names=kb_names, - top_m_final=top_k, - ) - result_list = [] - if results: - result_list = results["results"] - - response_data = { - "results": result_list, - "total": len(result_list), - "query": query, - } - - # Debug 模式:生成 t-SNE 可视化 - if debug: - try: - img_base64 = await generate_tsne_visualization( - query, - kb_names, - kb_manager, - ) - if img_base64: - response_data["visualization"] = img_base64 - except Exception as e: - logger.error(f"生成 t-SNE 可视化失败: {e}") - logger.error(traceback.format_exc()) - response_data["visualization_error"] = str(e) - - return Response().ok(response_data).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"检索失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"检索失败: {e!s}").__dict__ - - async def upload_document_from_url(self): - """从 URL 上传文档 - - Body: - - kb_id: 知识库 ID (必填) - - url: 要提取内容的网页 URL (必填) - - chunk_size: 分块大小 (可选, 默认512) - - chunk_overlap: 块重叠大小 (可选, 默认50) - - batch_size: 批处理大小 (可选, 默认32) - - tasks_limit: 并发任务限制 (可选, 默认3) - - max_retries: 最大重试次数 (可选, 默认3) - - 返回: - - task_id: 任务ID,用于查询上传进度和结果 - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - url = data.get("url") - if not url: - return Response().error("缺少参数 url").__dict__ - - chunk_size = data.get("chunk_size", 512) - chunk_overlap = data.get("chunk_overlap", 50) - batch_size = data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) - enable_cleaning = data.get("enable_cleaning", False) - cleaning_provider_id = data.get("cleaning_provider_id") - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_upload_from_url_task( - task_id=task_id, - kb_helper=kb_helper, - url=url, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - enable_cleaning=enable_cleaning, - cleaning_provider_id=cleaning_provider_id, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "url": url, - "message": "URL upload task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"从URL上传文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"从URL上传文档失败: {e!s}").__dict__ - - async def _background_upload_from_url_task( - self, - task_id: str, - kb_helper, - url: str, - chunk_size: int, - chunk_overlap: int, - batch_size: int, - tasks_limit: int, - max_retries: int, - enable_cleaning: bool, - cleaning_provider_id: str | None, - ): - """后台上传URL任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": 1, - "file_name": f"URL: {url}", - "stage": "extracting", - "current": 0, - "total": 100, - } - - # 创建进度回调函数 - progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") - - # 上传文档 - doc = await kb_helper.upload_from_url( - url=url, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - 
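Both the multipart upload handler and the URL upload handler above return immediately with a task_id and hand the actual ingestion to a background asyncio task; callers are expected to poll get_upload_progress until the task reaches a terminal state. A client-side polling loop could look like the sketch below; how the GET is actually issued (fetch_progress) and the poll interval are assumptions, while the status values and result fields mirror the handlers in this file.

```python
# Hypothetical helper that polls the data payload of get_upload_progress for a
# given task_id. Status values (pending/processing/completed/failed) and the
# result/error fields mirror the handlers above; fetch_progress and the poll
# interval are assumptions about the caller's HTTP layer.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any


async def wait_for_upload(
    fetch_progress: Callable[[], Awaitable[dict[str, Any]]],
    interval: float = 1.0,
) -> dict[str, Any]:
    while True:
        data = await fetch_progress()  # e.g. GET the progress endpoint with task_id
        status = data["status"]
        if status == "completed":
            return data["result"]  # uploaded / failed / success_count / failed_count ...
        if status == "failed":
            raise RuntimeError(data.get("error") or "upload failed")
        await asyncio.sleep(interval)  # pending or processing: keep polling
```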
progress_callback=progress_callback, - enable_cleaning=enable_cleaning, - cleaning_provider_id=cleaning_provider_id, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": [doc.model_dump()], - "failed": [], - "total": 1, - "success_count": 1, - "failed_count": 0, - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台上传URL任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d4..9cf17e41a 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from dataclasses import dataclass from quart import Quart @@ -12,31 +13,42 @@ class RouteContext: class Route: - routes: list | dict + routes: dict[ + str, + tuple[str, Callable] + | list[tuple[str, Callable]] + | tuple[str, Callable, str] + | list[tuple[str, Callable, str]], + ] def __init__(self, context: RouteContext): self.app = context.app self.config = context.config def register_routes(self): - def _add_rule(path, method, func): + def _add_rule(path, method, func, endpoint: str | None = None): # 统一添加 /api 前缀 full_path = f"/api{path}" - self.app.add_url_rule(full_path, view_func=func, methods=[method]) + self.app.add_url_rule( + full_path, view_func=func, methods=[method], endpoint=endpoint + ) - # 兼容字典和列表两种格式 - routes_to_register = ( - self.routes.items() if isinstance(self.routes, dict) else self.routes - ) - - for route, definition in routes_to_register: - # 兼容一个路由多个方法 - if isinstance(definition, list): - for method, func in definition: - _add_rule(route, method, func) + for route, defi in self.routes.items(): + if isinstance(defi, list): + for item in defi: + if len(item) == 2: + method, func = item + _add_rule(route, method, func) + elif len(item) == 3: + method, func, endpoint = item + _add_rule(route, method, func, endpoint) else: - method, func = definition - _add_rule(route, method, func) + if len(defi) == 2: + method, func = defi + _add_rule(route, method, func) + elif len(defi) == 3: + method, func, endpoint = defi + _add_rule(route, method, func, endpoint) @dataclass diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index db70a8820..b7b0d36f4 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -18,22 +18,18 @@ def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): self.config = core_lifecycle.astrbot_config self.manager = TemplateManager() # 使用列表保证路由注册顺序,避免 / 路由优先匹配 /reset_default - self.routes = [ - ("/t2i/templates", ("GET", self.list_templates)), - ("/t2i/templates/active", ("GET", self.get_active_template)), - ("/t2i/templates/create", ("POST", self.create_template)), - ("/t2i/templates/reset_default", ("POST", self.reset_default_template)), - ("/t2i/templates/set_active", ("POST", self.set_active_template)), - # 动态路由应该在静态路由之后注册 - ( - "/t2i/templates/", - [ - ("GET", self.get_template), - ("PUT", self.update_template), - ("DELETE", self.delete_template), - ], - ), - ] + self.routes = { + "/t2i/templates": ("GET", self.list_templates), + "/t2i/templates/active": ("GET", self.get_active_template), + "/t2i/templates/create": ("POST", self.create_template), + "/t2i/templates/reset_default": ("POST", self.reset_default_template), + "/t2i/templates/set_active": ("POST", self.set_active_template), + "/t2i/templates/": [ + ("GET", 
self.get_template), + ("PUT", self.update_template), + ("DELETE", self.delete_template), + ], + } self.register_routes() async def list_templates(self): diff --git a/astrbot/dashboard/routes/v1/README.md b/astrbot/dashboard/routes/v1/README.md new file mode 100644 index 000000000..127234d7b --- /dev/null +++ b/astrbot/dashboard/routes/v1/README.md @@ -0,0 +1,5 @@ +# AstrBot API + +为了更好地让外部系统与 AstrBot 进行交互,我们提供了可对外暴露的 API。 + +并且 AstrBot WebUI 也将逐渐切换到此 API 规范来进行前后端交互。 diff --git a/astrbot/dashboard/routes/v1/__init__.py b/astrbot/dashboard/routes/v1/__init__.py new file mode 100644 index 000000000..2f7446703 --- /dev/null +++ b/astrbot/dashboard/routes/v1/__init__.py @@ -0,0 +1,9 @@ +from .chat import V1ChatRoute +from .knowledge_base import V1KnowledgeBaseRoute +from .provider import V1ProviderRoute + +__all__ = [ + "V1ChatRoute", + "V1KnowledgeBaseRoute", + "V1ProviderRoute", +] diff --git a/astrbot/dashboard/routes/v1/chat.py b/astrbot/dashboard/routes/v1/chat.py new file mode 100644 index 000000000..ea7bedc31 --- /dev/null +++ b/astrbot/dashboard/routes/v1/chat.py @@ -0,0 +1,41 @@ +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase + +from ...services.chat import ChatService +from ..route import Route, RouteContext + + +class V1ChatRoute(Route): + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + db: BaseDatabase, + ): + super().__init__(context) + self.chat_service = ChatService(core_lifecycle, db) + self.routes = { + "/v1/chat": [("POST", self.chat_service.chat, "v1_chat_send")], + "/v1/chat/sessions": [ + ("GET", self.chat_service.get_sessions, "v1_chat_sessions") + ], + "/v1/chat/session": [ + ("POST", self.chat_service.new_session, "v1_chat_new_session") + ], + "/v1/chat/session/": [ + ("GET", self.chat_service.get_session, "v1_chat_get_session"), + ("DELETE", self.chat_service.delete_session, "v1_chat_delete_session"), + ( + "PUT", + self.chat_service.update_session_display_name, + "v1_chat_update_session_display_name", + ), + ], + "/v1/chat/attachment": [ + ("POST", self.chat_service.post_file, "v1_chat_post_file") + ], + "/v1/chat/attachment/": [ + ("GET", self.chat_service.get_attachment, "v1_chat_get_attachment") + ], + } + self.register_routes() diff --git a/astrbot/dashboard/routes/v1/knowledge_base.py b/astrbot/dashboard/routes/v1/knowledge_base.py new file mode 100644 index 000000000..91195d566 --- /dev/null +++ b/astrbot/dashboard/routes/v1/knowledge_base.py @@ -0,0 +1,57 @@ +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + +from ...services.knowledge_base import KnowledgeBaseService +from ..route import Route, RouteContext + + +class V1KnowledgeBaseRoute(Route): + def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + super().__init__(context) + self.kb_service = KnowledgeBaseService(core_lifecycle) + self.routes = { + "/v1/kbs": [("GET", self.kb_service.list_kbs, "v1_kb_list_kbs")], + "/v1/kb": [ + ("POST", self.kb_service.create_kb, "v1_kb_create_kb"), + ], + "/v1/kb/": [ + ("GET", self.kb_service.get_kb, "v1_kb_get_kb"), + ("PUT", self.kb_service.update_kb, "v1_kb_update_kb"), + ("DELETE", self.kb_service.delete_kb, "v1_kb_delete_kb"), + ], + "/v1/kb//documents": ( + "GET", + self.kb_service.list_documents, + "v1_kb_list_documents", + ), + "/v1/kb//document-file": [ + ("POST", self.kb_service.upload_document, "v1_kb_upload_document"), + ], + "/v1/kb//document-url": [ + ( + "POST", + self.kb_service.upload_document_from_url, + 
"v1_kb_upload_document_from_url", + ), + ], + "/v1/kb//document-preload": [ + ("POST", self.kb_service.import_documents, "v1_kb_import_documents"), + ], + "/v1/kb//document-progress/": [ + ( + "GET", + self.kb_service.get_upload_progress, + "v1_kb_get_upload_progress", + ), + ], + "/v1/kb//document/": [ + ("GET", self.kb_service.get_document, "v1_kb_get_document"), + ("DELETE", self.kb_service.delete_document, "v1_kb_delete_document"), + ], + "/v1/kb//document//chunks": [ + ("GET", self.kb_service.list_chunks, "v1_kb_list_chunks"), + ], + "/v1/kb//document//chunk/": [ + ("DELETE", self.kb_service.delete_chunk, "v1_kb_delete_chunk"), + ], + } + self.register_routes() diff --git a/astrbot/dashboard/routes/v1/provider.py b/astrbot/dashboard/routes/v1/provider.py new file mode 100644 index 000000000..4d7acbe16 --- /dev/null +++ b/astrbot/dashboard/routes/v1/provider.py @@ -0,0 +1,59 @@ +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + +from ...services.provider import ProviderService +from ..route import Route, RouteContext + + +class V1ProviderRoute(Route): + def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + super().__init__(context) + self.provider_service = ProviderService(core_lifecycle) + self.routes = { + # provider source + "/v1/provider-source": [ + ( + "PUT", + self.provider_service.update_provider_source, + "v1_provider_source_update", + ), + ( + "DELETE", + self.provider_service.delete_provider_source, + "v1_provider_source_delete", + ), + ( + "GET", + self.provider_service.list_provider_sources, + "v1_provider_source_list", + ), + ( + "GET", + self.provider_service.get_provider_source_models, + "v1_provider_source_models", + ), + ], + # provider + "/v1/provider": [ + ( + "POST", + self.provider_service.post_new_provider, + "v1_provider_post_new_provider", + ), + ( + "PUT", + self.provider_service.post_update_provider, + "v1_provider_post_update_provider", + ), + ( + "DELETE", + self.provider_service.post_delete_provider, + "v1_provider_post_delete_provider", + ), + ( + "GET", + self.provider_service.get_provider_config_list, + "v1_provider_get_provider_config_list", + ), + ], + } + self.register_routes() diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 6d6530c90..52dbdfcfd 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -23,8 +23,7 @@ from .routes.route import Response, RouteContext from .routes.session_management import SessionManagementRoute from .routes.t2i import T2iRoute - -APP: Quart +from .routes.v1 import * class AstrBotDashboard: @@ -36,9 +35,9 @@ def __init__( webui_dir: str | None = None, ) -> None: self.core_lifecycle = core_lifecycle + self.db = db self.config = core_lifecycle.astrbot_config - # 参数指定webui目录 if webui_dir and os.path.exists(webui_dir): self.data_path = os.path.abspath(webui_dir) else: @@ -47,15 +46,14 @@ def __init__( ) self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") - APP = self.app # noqa - self.app.config["MAX_CONTENT_LENGTH"] = ( - 128 * 1024 * 1024 - ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB + # 将 Flask 允许的最大上传文件体大小设置为 128 MB + self.app.config["MAX_CONTENT_LENGTH"] = 128 * 1024 * 1024 cast(DefaultJSONProvider, self.app.json).sort_keys = False self.app.before_request(self.auth_middleware) - # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) self.context = RouteContext(self.config, self.app) + + # Internal API self.ur = UpdateRoute( self.context, core_lifecycle.astrbot_updator, @@ -85,6 +83,7 @@ 
def __init__( self.t2i_route = T2iRoute(self.context, core_lifecycle) self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) self.platform_route = PlatformRoute(self.context, core_lifecycle) + self.api_key_route = ApiKeyRoute(self.context, core_lifecycle, db) self.app.add_url_rule( "/api/plug/", @@ -92,6 +91,13 @@ def __init__( methods=["GET", "POST"], ) + # External API + self.v1_chat_route = V1ChatRoute(self.context, core_lifecycle, db) + self.v1_knowledge_base_route = V1KnowledgeBaseRoute( + self.context, core_lifecycle + ) + self.v1_provider_route = V1ProviderRoute(self.context, core_lifecycle) + self.shutdown_event = shutdown_event self._init_jwt_secret() @@ -111,21 +117,43 @@ async def auth_middleware(self): allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"] if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return None - # 声明 JWT - token = request.headers.get("Authorization") - if not token: + + # 对于 /v1 路由,支持 API Key 认证 + is_v1_route = request.path.startswith("/api/v1") + auth_header = request.headers.get("Authorization") + + if not auth_header: r = jsonify(Response().error("未授权").__dict__) r.status_code = 401 return r - token = token.removeprefix("Bearer ") + + auth_value = auth_header.removeprefix("Bearer ") + + # 尝试 API Key 认证(仅对 /v1 路由) + if is_v1_route: + from .services.api_key import ApiKeyService + + api_key_service = ApiKeyService(self.core_lifecycle, self.db) + api_key_obj = await api_key_service.verify_api_key(auth_value) + if api_key_obj: + g.username = api_key_obj.username + g.api_key_id = api_key_obj.key_id + return None + + # JWT 认证(用于 WebUI 和 /v1 路由的备选方案) try: - payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + payload = jwt.decode(auth_value, self._jwt_secret, algorithms=["HS256"]) g.username = payload["username"] except jwt.ExpiredSignatureError: r = jsonify(Response().error("Token 过期").__dict__) r.status_code = 401 return r except jwt.InvalidTokenError: + if is_v1_route: + # 对于 /v1 路由,如果 JWT 也失败,返回错误 + r = jsonify(Response().error("API Key 或 Token 无效").__dict__) + r.status_code = 401 + return r r = jsonify(Response().error("Token 无效").__dict__) r.status_code = 401 return r diff --git a/astrbot/dashboard/services/__init__.py b/astrbot/dashboard/services/__init__.py new file mode 100644 index 000000000..2b128f7d0 --- /dev/null +++ b/astrbot/dashboard/services/__init__.py @@ -0,0 +1,7 @@ +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + + +class BaseService: + def __init__(self, core_lifecycle: AstrBotCoreLifecycle): + self.cl = core_lifecycle + self.clpm = core_lifecycle.provider_manager diff --git a/astrbot/dashboard/services/api_key.py b/astrbot/dashboard/services/api_key.py new file mode 100644 index 000000000..2694225ce --- /dev/null +++ b/astrbot/dashboard/services/api_key.py @@ -0,0 +1,152 @@ +"""API Key 服务 + +提供 API Key 的创建、查询、删除等业务逻辑 +""" + +import hashlib +import secrets +from datetime import datetime, timezone + +from quart import g, request +from sqlalchemy import select +from sqlmodel import col, desc + +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ApiKey + +from ..entities import Response +from . 
import BaseService + + +class ApiKeyService(BaseService): + """API Key 服务""" + + def __init__(self, core_lifecycle, db: BaseDatabase): + super().__init__(core_lifecycle) + self.db = db + + def _generate_api_key(self) -> str: + """生成一个安全的 API Key""" + # 生成 32 字节的随机 token,然后编码为 base64 URL-safe 字符串 + token = secrets.token_urlsafe(32) + # 添加前缀以便识别 + return f"astrbot_{token}" + + def _hash_api_key(self, api_key: str) -> str: + """对 API Key 进行哈希""" + return hashlib.sha256(api_key.encode()).hexdigest() + + async def create_api_key(self): + """创建新的 API Key""" + post_data = await request.json + name = post_data.get("name", "") + + # 获取当前用户名(从 JWT token 中) + username = getattr(g, "username", None) + if not username: + return Response().error("未授权").__dict__ + + # 生成新的 API Key + raw_api_key = self._generate_api_key() + hashed_key = self._hash_api_key(raw_api_key) + + # 创建数据库记录 + async with self.db.get_db() as session: + api_key_obj = ApiKey( + api_key=hashed_key, + username=username, + name=name if name else None, + ) + session.add(api_key_obj) + await session.commit() + await session.refresh(api_key_obj) + + # 返回完整的 API Key(只在创建时返回一次) + return ( + Response() + .ok( + { + "key_id": api_key_obj.key_id, + "api_key": raw_api_key, # 只返回一次,前端需要保存 + "name": api_key_obj.name, + "username": api_key_obj.username, + "created_at": api_key_obj.created_at.isoformat(), + }, + "API Key 创建成功,请妥善保管", + ) + .__dict__ + ) + + async def list_api_keys(self): + """列出当前用户的所有 API Keys""" + username = getattr(g, "username", None) + if not username: + return Response().error("未授权").__dict__ + + async with self.db.get_db() as session: + stmt = ( + select(ApiKey) + .where(col(ApiKey.username) == username) + .order_by(desc(ApiKey.created_at)) + ) + result = await session.execute(stmt) + api_keys = result.scalars().all() + + # 不返回实际的 API Key,只返回元数据 + keys_data = [ + { + "key_id": key.key_id, + "name": key.name, + "username": key.username, + "created_at": key.created_at.isoformat(), + "expires_at": key.expires_at.isoformat() + if key.expires_at + else None, + "last_used_at": key.last_used_at.isoformat() + if key.last_used_at + else None, + } + for key in api_keys + ] + + return Response().ok(keys_data).__dict__ + + async def delete_api_key(self, key_id: str): + """删除指定的 API Key""" + username = getattr(g, "username", None) + if not username: + return Response().error("未授权").__dict__ + + async with self.db.get_db() as session: + stmt = select(ApiKey).where( + col(ApiKey.key_id) == key_id, col(ApiKey.username) == username + ) + result = await session.execute(stmt) + api_key = result.scalar_one_or_none() + + if not api_key: + return Response().error("API Key 不存在或无权限").__dict__ + + await session.delete(api_key) + await session.commit() + + return Response().ok(None, "API Key 删除成功").__dict__ + + async def verify_api_key(self, api_key: str) -> ApiKey | None: + """验证 API Key 是否有效 + + 返回对应的 ApiKey 对象,如果无效则返回 None + """ + hashed_key = self._hash_api_key(api_key) + + async with self.db.get_db() as session: + stmt = select(ApiKey).where(col(ApiKey.api_key) == hashed_key) + result = await session.execute(stmt) + api_key_obj = result.scalar_one_or_none() + + if api_key_obj: + # 更新最后使用时间 + api_key_obj.last_used_at = datetime.now(timezone.utc) + await session.commit() + + return api_key_obj diff --git a/astrbot/dashboard/services/chat.py b/astrbot/dashboard/services/chat.py new file mode 100644 index 000000000..304c46ab1 --- /dev/null +++ b/astrbot/dashboard/services/chat.py @@ -0,0 +1,749 @@ +"""聊天服务 + +提供聊天会话、消息处理、附件管理等业务逻辑 +""" + +import 
asyncio +import json +import mimetypes +import os +import uuid +from contextlib import asynccontextmanager +from typing import cast + +from quart import Response as QuartResponse +from quart import g, make_response, request, send_file + +from astrbot.core import logger +from astrbot.core.db import BaseDatabase +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from ..entities import Response +from . import BaseService + + +@asynccontextmanager +async def track_conversation(convs: dict, conv_id: str): + convs[conv_id] = True + try: + yield + finally: + convs.pop(conv_id, None) + + +class ChatService(BaseService): + """聊天服务 + + 提供聊天会话、消息处理、附件管理等业务逻辑 + """ + + def __init__(self, core_lifecycle, db: BaseDatabase): + super().__init__(core_lifecycle) + self.db = db + self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.imgs_dir, exist_ok=True) + self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] + self.conv_mgr = core_lifecycle.conversation_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager + self.umop_config_router = core_lifecycle.umop_config_router + self.running_convs: dict[str, bool] = {} + + async def get_file(self): + """获取文件""" + filename = request.args.get("filename") + if not filename: + return Response().error("Missing key: filename").__dict__ + + try: + file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) + real_file_path = os.path.realpath(file_path) + real_imgs_dir = os.path.realpath(self.imgs_dir) + + if not real_file_path.startswith(real_imgs_dir): + return Response().error("Invalid file path").__dict__ + + filename_ext = os.path.splitext(filename)[1].lower() + if filename_ext == ".wav": + return await send_file(real_file_path, mimetype="audio/wav") + if filename_ext[1:] in self.supported_imgs: + return await send_file(real_file_path, mimetype="image/jpeg") + return await send_file(real_file_path) + + except (FileNotFoundError, OSError): + return Response().error("File access error").__dict__ + + async def get_attachment(self, attachment_id: str | None = None): + """获取附件文件 + + 路径参数或查询参数: + - attachment_id: 附件ID (必填) + """ + if attachment_id is None: + # 从路径参数获取 + attachment_id = ( + request.view_args.get("attachment_id") if request.view_args else None + ) + if not attachment_id: + # 从查询参数获取 + attachment_id = request.args.get("attachment_id") + if not attachment_id: + return Response().error("Missing key: attachment_id").__dict__ + + try: + attachment = await self.db.get_attachment_by_id(attachment_id) + if not attachment: + return Response().error("Attachment not found").__dict__ + + file_path = attachment.path + real_file_path = os.path.realpath(file_path) + + return await send_file(real_file_path, mimetype=attachment.mime_type) + + except (FileNotFoundError, OSError): + return Response().error("File access error").__dict__ + + async def post_file(self): + """上传文件并创建附件记录""" + post_data = await request.files + if "file" not in post_data: + return Response().error("Missing key: file").__dict__ + + file = post_data["file"] + + filename = file.filename or f"{uuid.uuid4()!s}" + content_type = file.content_type or "application/octet-stream" + + # 根据 content_type 判断文件类型并添加扩展名 + if content_type.startswith("image"): + attach_type = "image" + elif content_type.startswith("audio"): + attach_type = "record" + elif content_type.startswith("video"): + attach_type = "video" + else: + 
attach_type = "file" + + path = os.path.join(self.imgs_dir, filename) + await file.save(path) + + # 创建 attachment 记录 + attachment = await self.db.insert_attachment( + path=path, + type=attach_type, + mime_type=content_type, + ) + + if not attachment: + return Response().error("Failed to create attachment").__dict__ + + filename = os.path.basename(attachment.path) + + return ( + Response() + .ok( + data={ + "attachment_id": attachment.attachment_id, + "filename": filename, + "type": attach_type, + } + ) + .__dict__ + ) + + async def _build_user_message_parts(self, message: str | list) -> list[dict]: + """构建用户消息的部分列表 + + Args: + message: 文本消息 (str) 或消息段列表 (list) + """ + parts = [] + + if isinstance(message, list): + for part in message: + part_type = part.get("type") + if part_type == "plain": + parts.append({"type": "plain", "text": part.get("text", "")}) + elif part_type == "reply": + parts.append( + {"type": "reply", "message_id": part.get("message_id")} + ) + elif attachment_id := part.get("attachment_id"): + attachment = await self.db.get_attachment_by_id(attachment_id) + if attachment: + parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(attachment.path), + "path": attachment.path, # will be deleted + } + ) + return parts + + if message: + parts.append({"type": "plain", "text": message}) + + return parts + + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分 + + 用于处理 bot 回复中的媒体文件 + + Args: + filename: 存储的文件名 + attach_type: 附件类型 (image, record, file, video) + """ + file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) + if not os.path.exists(file_path): + return None + + # guess mime type + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + # insert attachment + attachment = await self.db.insert_attachment( + path=file_path, + type=attach_type, + mime_type=mime_type, + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(file_path), + } + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + agent_stats: dict, + ): + """保存 bot 消息到历史记录,返回保存的记录""" + bot_message_parts = [] + bot_message_parts.extend(media_parts) + if text: + bot_message_parts.append({"type": "plain", "text": text}) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + if agent_stats: + new_his["agent_stats"] = agent_stats + + record = await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + return record + + async def chat(self): + """处理聊天消息并返回流式响应 + + Body: + - message: 消息内容(文本或消息段列表)(必填) + - session_id 或 conversation_id: 会话ID (必填) + - selected_provider: 选中的提供商 (可选) + - selected_model: 选中的模型 (可选) + - enable_streaming: 是否启用流式响应 (可选, 默认 True) + + Returns: + QuartResponse: SSE 流式响应 + """ + username = g.get("username", "guest") + post_data = await request.json + + if "message" not in post_data and "files" not in post_data: + return Response().error("Missing key: message or files").__dict__ + + if "session_id" not in post_data and "conversation_id" not in post_data: + return ( + Response().error("Missing key: session_id or conversation_id").__dict__ + ) + + message = post_data["message"] + 
session_id = post_data.get("session_id", post_data.get("conversation_id")) + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") + enable_streaming = post_data.get("enable_streaming", True) + + # 检查消息是否为空 + if isinstance(message, list): + has_content = any( + part.get("type") in ("plain", "image", "record", "file", "video") + for part in message + ) + if not has_content: + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) + elif not message: + return Response().error("Message are both empty").__dict__ + + if not session_id: + return Response().error("session_id is empty").__dict__ + + webchat_conv_id = session_id + back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) + + # 构建用户消息段(包含 path 用于传递给 adapter) + message_parts = await self._build_user_message_parts(message) + + async def stream(): + client_disconnected = False + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + try: + async with track_conversation(self.running_convs, webchat_conv_id): + while True: + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") + client_disconnected = True + except Exception as e: + logger.error(f"WebChat stream error: {e}") + + if not result: + continue + + result_text = result["data"] + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + + if chain_type == "agent_stats": + stats_info = { + "type": "agent_stats", + "data": json.loads(result_text), + } + yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n" + agent_stats = stats_info["data"] + continue + + # 发送 SSE 数据 + try: + if not client_disconnected: + yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" + except Exception as e: + if not client_disconnected: + logger.debug( + f"[WebChat] 用户 {username} 断开聊天长连接。 {e}" + ) + client_disconnected = True + + try: + if not client_disconnected: + await asyncio.sleep(0.05) + except asyncio.CancelledError: + logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") + client_disconnected = True + + # 累积消息部分 + if msg_type == "plain": + chain_type = result.get("chain_type") + if chain_type == "tool_call": + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + # 如果累积了文本,则先保存文本 + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + elif chain_type == "tool_call_result": + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + { + "type": "tool_call", + "tool_calls": [tool_calls[tc_id]], + } + ) + tool_calls.pop(tc_id, None) + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = result_text.replace("[IMAGE]", "") + part = await self._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = result_text.replace("[RECORD]", "") + part = await self._create_attachment_from_file( + filename, "record" + ) + if part: + 
accumulated_parts.append(part) + elif msg_type == "file": + # 格式: [FILE]filename + filename = result_text.replace("[FILE]", "") + part = await self._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + + # 消息结束处理 + if msg_type == "end": + break + elif ( + (streaming and msg_type == "complete") or not streaming + # or msg_type == "break" + ): + if ( + chain_type == "tool_call" + or chain_type == "tool_call_result" + ): + continue + saved_record = await self._save_bot_message( + webchat_conv_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + ) + # 发送保存的消息信息给前端 + if saved_record and not client_disconnected: + saved_info = { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + } + try: + yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" + except Exception: + pass + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + # tool_calls = {} + agent_stats = {} + except BaseException as e: + logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) + + # 将消息放入会话特定的队列 + chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) + await chat_queue.put( + ( + username, + webchat_conv_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "enable_streaming": enable_streaming, + }, + ), + ) + + message_parts_for_storage = [] + for part in message_parts: + part_copy = {k: v for k, v in part.items() if k != "path"} + message_parts_for_storage.append(part_copy) + + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=username, + sender_name=username, + ) + + response = cast( + QuartResponse, + await make_response( + stream(), + { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, + ), + ) + response.timeout = None # fix SSE auto disconnect issue + return response + + def _extract_attachment_ids(self, history_list) -> list[str]: + """从消息历史中提取所有 attachment_id""" + attachment_ids = [] + for history in history_list: + content = history.content + if not content or "message" not in content: + continue + message_parts = content.get("message", []) + for part in message_parts: + if isinstance(part, dict) and "attachment_id" in part: + attachment_ids.append(part["attachment_id"]) + return attachment_ids + + async def _delete_attachments(self, attachment_ids: list[str]): + """删除附件(包括数据库记录和磁盘文件)""" + try: + attachments = await self.db.get_attachments(attachment_ids) + for attachment in attachments: + if not os.path.exists(attachment.path): + continue + try: + os.remove(attachment.path) + except OSError as e: + logger.warning( + f"Failed to delete attachment file {attachment.path}: {e}" + ) + except Exception as e: + logger.warning(f"Failed to get attachments: {e}") + + # 批量删除数据库记录 + try: + await self.db.delete_attachments(attachment_ids) + except Exception as e: + logger.warning(f"Failed to delete attachments: {e}") + + async def delete_session(self, session_id: str | None = None): + """删除会话及其相关数据 + + 路径参数或查询参数: + - session_id: 会话ID (必填) + """ + username = g.get("username", "guest") + if session_id is None: + # 从路径参数获取 + session_id = ( + request.view_args.get("session_id") if request.view_args else None + ) + if not session_id: + # 
从查询参数获取 + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ + + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 删除该会话下的所有对话 + message_type = "GroupMessage" if session.is_group else "FriendMessage" + unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" + await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + + # 获取消息历史中的所有附件 ID 并删除附件 + history_list = await self.platform_history_mgr.get( + platform_id=session.platform_id, + user_id=session_id, + page=1, + page_size=100000, # 获取足够多的记录 + ) + attachment_ids = self._extract_attachment_ids(history_list) + if attachment_ids: + await self._delete_attachments(attachment_ids) + + # 删除消息历史 + await self.platform_history_mgr.delete( + platform_id=session.platform_id, + user_id=session_id, + offset_sec=99999999, + ) + + # 删除与会话关联的配置路由 + try: + await self.umop_config_router.delete_route(unified_msg_origin) + except ValueError as exc: + logger.warning( + "Failed to delete UMO route %s during session cleanup: %s", + unified_msg_origin, + exc, + ) + + # 清理队列(仅对 webchat) + if session.platform_id == "webchat": + webchat_queue_mgr.remove_queues(session_id) + + # 删除会话 + await self.db.delete_platform_session(session_id) + + return Response().ok().__dict__ + + async def new_session(self): + """创建新会话 + + Query 参数或 Body: + - platform_id: 平台ID (可选, 默认 webchat) + """ + username = g.get("username", "guest") + # 优先从 query 参数获取,如果没有则从 body 获取 + platform_id = request.args.get("platform_id") + if not platform_id: + try: + post_data = await request.json + platform_id = post_data.get("platform_id", "webchat") + except Exception: + platform_id = "webchat" + # 创建新会话 + session = await self.db.create_platform_session( + creator=username, + platform_id=platform_id, + is_group=0, + ) + + return ( + Response() + .ok( + data={ + "session_id": session.session_id, + "platform_id": session.platform_id, + } + ) + .__dict__ + ) + + async def get_sessions(self): + """获取所有会话 + + Query 参数: + - platform_id: 平台ID (可选) + """ + username = g.get("username", "guest") + platform_id = request.args.get("platform_id") + sessions = await self.db.get_platform_sessions_by_creator( + creator=username, + platform_id=platform_id, + page=1, + page_size=100, # 暂时返回前100个 + ) + + # 转换为字典格式,并添加额外信息 + sessions_data = [] + for session in sessions: + sessions_data.append( + { + "session_id": session.session_id, + "platform_id": session.platform_id, + "creator": session.creator, + "display_name": session.display_name, + "is_group": session.is_group, + "created_at": session.created_at.astimezone().isoformat(), + "updated_at": session.updated_at.astimezone().isoformat(), + } + ) + + return Response().ok(data=sessions_data).__dict__ + + async def get_session(self, session_id: str | None = None): + """获取会话信息和消息历史 + + 路径参数或查询参数: + - session_id: 会话ID (必填) + """ + if session_id is None: + # 从路径参数获取 + session_id = ( + request.view_args.get("session_id") if request.view_args else None + ) + if not session_id: + # 从查询参数获取 + session_id = request.args.get("session_id") + if not session_id: + return Response().error("Missing key: session_id").__dict__ + + # 获取会话信息以确定 platform_id + session = await self.db.get_platform_session_by_id(session_id) + platform_id = 
session.platform_id if session else "webchat" + + # Get platform message history using session_id + history_ls = await self.platform_history_mgr.get( + platform_id=platform_id, + user_id=session_id, + page=1, + page_size=1000, + ) + + history_res = [history.model_dump() for history in history_ls] + + return ( + Response() + .ok( + data={ + "history": history_res, + "is_running": self.running_convs.get(session_id, False), + }, + ) + .__dict__ + ) + + async def update_session_display_name(self, session_id: str | None = None): + """更新会话显示名称 + + 路径参数或 Body: + - session_id: 会话ID (必填) + - display_name: 显示名称 (必填) + """ + username = g.get("username", "guest") + if session_id is None: + # 从路径参数获取 + session_id = ( + request.view_args.get("session_id") if request.view_args else None + ) + if not session_id: + # 从 body 获取 + post_data = await request.json + session_id = post_data.get("session_id") + display_name = post_data.get("display_name") + else: + # 如果从路径参数获取了 session_id,再从 body 获取 display_name + post_data = await request.json + display_name = post_data.get("display_name") + + if not session_id: + return Response().error("Missing key: session_id").__dict__ + if display_name is None: + return Response().error("Missing key: display_name").__dict__ + + # 验证会话是否存在且属于当前用户 + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + # 更新 display_name + await self.db.update_platform_session( + session_id=session_id, + display_name=display_name, + ) + + return Response().ok().__dict__ diff --git a/astrbot/dashboard/services/knowledge_base.py b/astrbot/dashboard/services/knowledge_base.py new file mode 100644 index 000000000..a113e3dfd --- /dev/null +++ b/astrbot/dashboard/services/knowledge_base.py @@ -0,0 +1,1345 @@ +"""知识库管理服务""" + +import asyncio +import os +import traceback +import uuid + +import aiofiles +from quart import request + +from astrbot.core import logger +from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider + +from ..entities import Response +from . 
import BaseService + + +class KnowledgeBaseService(BaseService): + """知识库管理服务 + + 提供知识库、文档、检索等业务逻辑 + """ + + def __init__(self, core_lifecycle): + super().__init__(core_lifecycle) + self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}} + self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}} + + def _get_kb_manager(self): + return self.cl.kb_manager + + def _init_task(self, task_id: str, status: str = "pending") -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": None, + "error": None, + } + + def _set_task_result( + self, + task_id: str, + status: str, + result: any = None, + error: str | None = None, # type: ignore + ) -> None: + self.upload_tasks[task_id] = { + "status": status, + "result": result, + "error": error, + } + if task_id in self.upload_progress: + self.upload_progress[task_id]["status"] = status + + def _update_progress( + self, + task_id: str, + *, + status: str | None = None, + file_index: int | None = None, + file_name: str | None = None, + stage: str | None = None, + current: int | None = None, + total: int | None = None, + ) -> None: + if task_id not in self.upload_progress: + return + p = self.upload_progress[task_id] + if status is not None: + p["status"] = status + if file_index is not None: + p["file_index"] = file_index + if file_name is not None: + p["file_name"] = file_name + if stage is not None: + p["stage"] = stage + if current is not None: + p["current"] = current + if total is not None: + p["total"] = total + + def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): + async def _callback(stage: str, current: int, total: int): + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_name, + stage=stage, + current=current, + total=total, + ) + + return _callback + + async def _background_upload_task( + self, + task_id: str, + kb_helper, + files_to_upload: list, + chunk_size: int, + chunk_overlap: int, + batch_size: int, + tasks_limit: int, + max_retries: int, + ): + """后台上传任务""" + try: + # 初始化任务状态 + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": len(files_to_upload), + "stage": "waiting", + "current": 0, + "total": 100, + } + + uploaded_docs = [] + failed_docs = [] + + for file_idx, file_info in enumerate(files_to_upload): + try: + # 更新整体进度 + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_info["file_name"], + stage="parsing", + current=0, + total=100, + ) + + # 创建进度回调函数 + progress_callback = self._make_progress_callback( + task_id, file_idx, file_info["file_name"] + ) + + doc = await kb_helper.upload_document( + file_name=file_info["file_name"], + file_content=file_info["file_content"], + file_type=file_info["file_type"], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + + uploaded_docs.append(doc.model_dump()) + except Exception as e: + logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") + failed_docs.append( + {"file_name": file_info["file_name"], "error": str(e)}, + ) + + # 更新任务完成状态 + result = { + "task_id": task_id, + "uploaded": uploaded_docs, + "failed": failed_docs, + "total": len(files_to_upload), + "success_count": len(uploaded_docs), + "failed_count": len(failed_docs), + } + + self._set_task_result(task_id, "completed", 
result=result) + + except Exception as e: + logger.error(f"后台上传任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def _background_import_task( + self, + task_id: str, + kb_helper, + documents: list, + batch_size: int, + tasks_limit: int, + max_retries: int, + ): + """后台导入预切片文档任务""" + try: + # 初始化任务状态 + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": len(documents), + "stage": "waiting", + "current": 0, + "total": 100, + } + + uploaded_docs = [] + failed_docs = [] + + for file_idx, doc_info in enumerate(documents): + file_name = doc_info.get("file_name", f"imported_doc_{file_idx}") + chunks = doc_info.get("chunks", []) + + try: + # 更新整体进度 + self._update_progress( + task_id, + status="processing", + file_index=file_idx, + file_name=file_name, + stage="importing", + current=0, + total=100, + ) + + # 创建进度回调函数 + progress_callback = self._make_progress_callback( + task_id, file_idx, file_name + ) + + # 调用 upload_document,传入 pre_chunked_text + doc = await kb_helper.upload_document( + file_name=file_name, + file_content=None, # 预切片模式下不需要原始内容 + file_type=doc_info.get("file_type") + or ( + file_name.rsplit(".", 1)[-1].lower() + if "." in file_name + else "txt" + ), + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + pre_chunked_text=chunks, + ) + + uploaded_docs.append(doc.model_dump()) + except Exception as e: + logger.error(f"导入文档 {file_name} 失败: {e}") + failed_docs.append( + {"file_name": file_name, "error": str(e)}, + ) + + # 更新任务完成状态 + result = { + "task_id": task_id, + "uploaded": uploaded_docs, + "failed": failed_docs, + "total": len(documents), + "success_count": len(uploaded_docs), + "failed_count": len(failed_docs), + } + + self._set_task_result(task_id, "completed", result=result) + + except Exception as e: + logger.error(f"后台导入任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def _background_upload_from_url_task( + self, + task_id: str, + kb_helper, + url: str, + chunk_size: int, + chunk_overlap: int, + batch_size: int, + tasks_limit: int, + max_retries: int, + enable_cleaning: bool, + cleaning_provider_id: str | None, + ): + """后台上传URL任务""" + try: + # 初始化任务状态 + self._init_task(task_id, status="processing") + self.upload_progress[task_id] = { + "status": "processing", + "file_index": 0, + "file_total": 1, + "file_name": f"URL: {url}", + "stage": "extracting", + "current": 0, + "total": 100, + } + + # 创建进度回调函数 + progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") + + # 上传文档 + doc = await kb_helper.upload_from_url( + url=url, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + enable_cleaning=enable_cleaning, + cleaning_provider_id=cleaning_provider_id, + ) + + # 更新任务完成状态 + result = { + "task_id": task_id, + "uploaded": [doc.model_dump()], + "failed": [], + "total": 1, + "success_count": 1, + "failed_count": 0, + } + + self._set_task_result(task_id, "completed", result=result) + + except Exception as e: + logger.error(f"后台上传URL任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + def _validate_import_request(self, data: dict): + kb_id = data.get("kb_id") + if not 
kb_id: + raise ValueError("缺少参数 kb_id") + + documents = data.get("documents") + if not documents or not isinstance(documents, list): + raise ValueError("缺少参数 documents 或格式错误") + + for doc in documents: + if "file_name" not in doc or "chunks" not in doc: + raise ValueError("文档格式错误,必须包含 file_name 和 chunks") + if not isinstance(doc["chunks"], list): + raise ValueError("chunks 必须是列表") + if not all( + isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"] + ): + raise ValueError("chunks 必须是非空字符串列表") + + batch_size = data.get("batch_size", 32) + tasks_limit = data.get("tasks_limit", 3) + max_retries = data.get("max_retries", 3) + return kb_id, documents, batch_size, tasks_limit, max_retries + + async def list_kbs(self): + """获取知识库列表 + + Query 参数: + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) + - refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true) + """ + try: + kb_manager = self._get_kb_manager() + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + + kbs = await kb_manager.list_kbs() + + # 转换为字典列表 + kb_list = [] + for kb in kbs: + kb_list.append(kb.model_dump()) + + return ( + Response() + .ok({"items": kb_list, "page": page, "page_size": page_size}) + .__dict__ + ) + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取知识库列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取知识库列表失败: {e!s}").__dict__ + + async def create_kb(self): + """创建知识库 + + Body: + - kb_name: 知识库名称 (必填) + - description: 描述 (可选) + - emoji: 图标 (可选) + - embedding_provider_id: 嵌入模型提供商ID (可选) + - rerank_provider_id: 重排序模型提供商ID (可选) + - chunk_size: 分块大小 (可选, 默认512) + - chunk_overlap: 块重叠大小 (可选, 默认50) + - top_k_dense: 密集检索数量 (可选, 默认50) + - top_k_sparse: 稀疏检索数量 (可选, 默认50) + - top_m_final: 最终返回数量 (可选, 默认5) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + kb_name = data.get("kb_name") + if not kb_name: + return Response().error("知识库名称不能为空").__dict__ + + description = data.get("description") + emoji = data.get("emoji") + embedding_provider_id = data.get("embedding_provider_id") + rerank_provider_id = data.get("rerank_provider_id") + chunk_size = data.get("chunk_size") + chunk_overlap = data.get("chunk_overlap") + top_k_dense = data.get("top_k_dense") + top_k_sparse = data.get("top_k_sparse") + top_m_final = data.get("top_m_final") + + # pre-check embedding dim + if not embedding_provider_id: + return Response().error("缺少参数 embedding_provider_id").__dict__ + prv = await kb_manager.provider_manager.get_provider_by_id( + embedding_provider_id, + ) # type: ignore + if not prv or not isinstance(prv, EmbeddingProvider): + return ( + Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__ + ) + try: + vec = await prv.get_embedding("astrbot") + if len(vec) != prv.get_dim(): + raise ValueError( + f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}", + ) + except Exception as e: + return Response().error(f"测试嵌入模型失败: {e!s}").__dict__ + # pre-check rerank + if rerank_provider_id: + rerank_prv: RerankProvider = ( + await kb_manager.provider_manager.get_provider_by_id( + rerank_provider_id, + ) + ) # type: ignore + if not rerank_prv: + return Response().error("重排序模型不存在").__dict__ + # 检查重排序模型可用性 + try: + res = await rerank_prv.rerank( + query="astrbot", + documents=["astrbot knowledge base"], + ) + if not res: + raise ValueError("重排序模型返回结果异常") + except Exception as e: + return ( + Response() + .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。") + .__dict__ + ) + + kb_helper = await 
kb_manager.create_kb( + kb_name=kb_name, + description=description, + emoji=emoji, + embedding_provider_id=embedding_provider_id, + rerank_provider_id=rerank_provider_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + top_k_dense=top_k_dense, + top_k_sparse=top_k_sparse, + top_m_final=top_m_final, + ) + kb = kb_helper.kb + + return Response().ok(kb.model_dump(), "创建知识库成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"创建知识库失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"创建知识库失败: {e!s}").__dict__ + + async def get_kb(self, kb_id: str | None = None): + """获取知识库详情 + + 路径参数或 Query 参数: + - kb_id: 知识库 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从查询参数获取 + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + kb = kb_helper.kb + + return Response().ok(kb.model_dump()).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取知识库详情失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取知识库详情失败: {e!s}").__dict__ + + async def update_kb(self, kb_id: str | None = None): + """更新知识库 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + - kb_name: 新的知识库名称 (可选) + - description: 新的描述 (可选) + - emoji: 新的图标 (可选) + - embedding_provider_id: 新的嵌入模型提供商ID (可选) + - rerank_provider_id: 新的重排序模型提供商ID (可选) + - chunk_size: 分块大小 (可选) + - chunk_overlap: 块重叠大小 (可选) + - top_k_dense: 密集检索数量 (可选) + - top_k_sparse: 稀疏检索数量 (可选) + - top_m_final: 最终返回数量 (可选) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从 body 获取 + data = await request.json + kb_id = data.get("kb_id") + else: + # 如果从路径参数获取了 kb_id,再从 body 获取其他参数 + data = await request.json + + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + kb_name = data.get("kb_name") + description = data.get("description") + emoji = data.get("emoji") + embedding_provider_id = data.get("embedding_provider_id") + rerank_provider_id = data.get("rerank_provider_id") + chunk_size = data.get("chunk_size") + chunk_overlap = data.get("chunk_overlap") + top_k_dense = data.get("top_k_dense") + top_k_sparse = data.get("top_k_sparse") + top_m_final = data.get("top_m_final") + + # 检查是否至少提供了一个更新字段 + if all( + v is None + for v in [ + kb_name, + description, + emoji, + embedding_provider_id, + rerank_provider_id, + chunk_size, + chunk_overlap, + top_k_dense, + top_k_sparse, + top_m_final, + ] + ): + return Response().error("至少需要提供一个更新字段").__dict__ + + kb_helper = await kb_manager.update_kb( + kb_id=kb_id, + kb_name=kb_name, + description=description, + emoji=emoji, + embedding_provider_id=embedding_provider_id, + rerank_provider_id=rerank_provider_id, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + top_k_dense=top_k_dense, + top_k_sparse=top_k_sparse, + top_m_final=top_m_final, + ) + + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + kb = kb_helper.kb + return Response().ok(kb.model_dump(), "更新知识库成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"更新知识库失败: {e}") + 
logger.error(traceback.format_exc()) + return Response().error(f"更新知识库失败: {e!s}").__dict__ + + async def delete_kb(self, kb_id: str | None = None): + """删除知识库 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从 body 获取 + data = await request.json + kb_id = data.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + success = await kb_manager.delete_kb(kb_id) + if not success: + return Response().error("知识库不存在").__dict__ + + return Response().ok(message="删除知识库成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"删除知识库失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"删除知识库失败: {e!s}").__dict__ + + async def get_kb_stats(self): + """获取知识库统计信息 + + Query 参数: + - kb_id: 知识库 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + kb = kb_helper.kb + + stats = { + "kb_id": kb.kb_id, + "kb_name": kb.kb_name, + "doc_count": kb.doc_count, + "chunk_count": kb.chunk_count, + "created_at": kb.created_at.isoformat(), + "updated_at": kb.updated_at.isoformat(), + } + + return Response().ok(stats).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取知识库统计失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取知识库统计失败: {e!s}").__dict__ + + async def list_documents(self, kb_id: str | None = None): + """获取文档列表 + + 路径参数或 Query 参数: + - kb_id: 知识库 ID (必填) + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从查询参数获取 + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 100, type=int) + + offset = (page - 1) * page_size + limit = page_size + + doc_list = await kb_helper.list_documents(offset=offset, limit=limit) + + doc_list = [doc.model_dump() for doc in doc_list] + + return ( + Response() + .ok({"items": doc_list, "page": page, "page_size": page_size}) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取文档列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取文档列表失败: {e!s}").__dict__ + + async def upload_document(self, kb_id: str | None = None): + """上传文档 + + 支持两种方式: + 1. multipart/form-data 文件上传(支持多文件,最多10个) + 2. JSON 格式 base64 编码上传(支持多文件,最多10个) + + 路径参数或 Form Data (multipart/form-data): + - kb_id: 知识库 ID (必填) + - file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 
或 files[]) + + 路径参数或 JSON Body (application/json): + - kb_id: 知识库 ID (必填) + - files: 文件数组 (必填) + - file_name: 文件名 (必填) + - file_content: base64 编码的文件内容 (必填) + + 返回: + - task_id: 任务ID,用于查询上传进度和结果 + """ + try: + kb_manager = self._get_kb_manager() + + # 检查 Content-Type + content_type = request.content_type + chunk_size = None + chunk_overlap = None + batch_size = 32 + tasks_limit = 3 + max_retries = 3 + files_to_upload = [] # 存储待上传的文件信息列表 + + if content_type and "multipart/form-data" not in content_type: + return ( + Response().error("Content-Type 须为 multipart/form-data").__dict__ + ) + form_data = await request.form + files = await request.files + + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从 form data 获取 + kb_id = form_data.get("kb_id") + chunk_size = int(form_data.get("chunk_size", 512)) + chunk_overlap = int(form_data.get("chunk_overlap", 50)) + batch_size = int(form_data.get("batch_size", 32)) + tasks_limit = int(form_data.get("tasks_limit", 3)) + max_retries = int(form_data.get("max_retries", 3)) + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + # 收集所有文件 + file_list = [] + # 支持 file, file1, file2, ... 或 files[] 格式 + for key in files.keys(): + if key == "file" or key.startswith("file") or key == "files[]": + file_items = files.getlist(key) + file_list.extend(file_items) + + if not file_list: + return Response().error("缺少文件").__dict__ + + # 限制文件数量 + if len(file_list) > 10: + return Response().error("最多只能上传10个文件").__dict__ + + # 处理每个文件 + for file in file_list: + file_name = file.filename + + # 保存到临时文件 + temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}" + await file.save(temp_file_path) + + try: + # 异步读取文件内容 + async with aiofiles.open(temp_file_path, "rb") as f: + file_content = await f.read() + + # 提取文件类型 + file_type = ( + file_name.rsplit(".", 1)[-1].lower() if "." 
in file_name else "" + ) + + files_to_upload.append( + { + "file_name": file_name, + "file_content": file_content, + "file_type": file_type, + }, + ) + finally: + # 清理临时文件 + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + # 获取知识库 + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, status="pending") + + # 启动后台任务 + asyncio.create_task( + self._background_upload_task( + task_id=task_id, + kb_helper=kb_helper, + files_to_upload=files_to_upload, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + ), + ) + + return ( + Response() + .ok( + { + "task_id": task_id, + "file_count": len(files_to_upload), + "message": "task created, processing in background", + }, + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"上传文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"上传文档失败: {e!s}").__dict__ + + async def import_documents(self, kb_id: str | None = None): + """导入预切片文档 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + - documents: 文档列表 (必填) + - file_name: 文件名 (必填) + - chunks: 切片列表 (必填, list[str]) + - file_type: 文件类型 (可选, 默认从文件名推断或为 txt) + - batch_size: 批处理大小 (可选, 默认32) + - tasks_limit: 并发任务限制 (可选, 默认3) + - max_retries: 最大重试次数 (可选, 默认3) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + + # 如果从路径参数获取了 kb_id,更新到 data 中 + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if kb_id: + data["kb_id"] = kb_id + + kb_id, documents, batch_size, tasks_limit, max_retries = ( + self._validate_import_request(data) + ) + + # 获取知识库 (kb_id 已经通过 _validate_import_request 验证,不会是 None) + assert kb_id is not None, "kb_id should not be None after validation" + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, status="pending") + + # 启动后台任务 + asyncio.create_task( + self._background_import_task( + task_id=task_id, + kb_helper=kb_helper, + documents=documents, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + ), + ) + + return ( + Response() + .ok( + { + "task_id": task_id, + "doc_count": len(documents), + "message": "import task created, processing in background", + }, + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"导入文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"导入文档失败: {e!s}").__dict__ + + async def get_upload_progress( + self, kb_id: str | None = None, task_id: str | None = None + ): + """获取上传进度和结果 + + 路径参数或 Query 参数: + - kb_id: 知识库 ID (必填) + - task_id: 任务 ID (必填) + + 返回状态: + - pending: 任务待处理 + - processing: 任务处理中 + - completed: 任务完成 + - failed: 任务失败 + """ + try: + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从查询参数获取 + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + if task_id is None: + # 从路径参数获取 + task_id = ( + request.view_args.get("task_id") if request.view_args else None + ) + if not task_id: + # 从查询参数获取 + task_id = request.args.get("task_id") + if not task_id: + return 
Response().error("缺少参数 task_id").__dict__ + + # 检查任务是否存在 + if task_id not in self.upload_tasks: + return Response().error("找不到该任务").__dict__ + + task_info = self.upload_tasks[task_id] + status = task_info["status"] + + # 构建返回数据 + response_data = { + "task_id": task_id, + "status": status, + } + + # 如果任务正在处理,返回进度信息 + if status == "processing" and task_id in self.upload_progress: + response_data["progress"] = self.upload_progress[task_id] + + # 如果任务完成,返回结果 + if status == "completed": + response_data["result"] = task_info["result"] + + # 如果任务失败,返回错误信息 + if status == "failed": + response_data["error"] = task_info["error"] + + return Response().ok(response_data).__dict__ + + except Exception as e: + logger.error(f"获取上传进度失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取上传进度失败: {e!s}").__dict__ + + async def get_document(self, kb_id: str | None = None, doc_id: str | None = None): + """获取文档详情 + + 路径参数或 Query 参数: + - kb_id: 知识库 ID (必填) + - doc_id: 文档 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从查询参数获取 + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + if doc_id is None: + # 从路径参数获取 + doc_id = request.view_args.get("doc_id") if request.view_args else None + if not doc_id: + # 从查询参数获取 + doc_id = request.args.get("doc_id") + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + doc = await kb_helper.get_document(doc_id) + if not doc: + return Response().error("文档不存在").__dict__ + + return Response().ok(doc.model_dump()).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取文档详情失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取文档详情失败: {e!s}").__dict__ + + async def delete_document( + self, kb_id: str | None = None, doc_id: str | None = None + ): + """删除文档 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + - doc_id: 文档 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if doc_id is None: + # 从路径参数获取 + doc_id = request.view_args.get("doc_id") if request.view_args else None + + if not kb_id or not doc_id: + # 从 body 获取 + data = await request.json + if not kb_id: + kb_id = data.get("kb_id") + if not doc_id: + doc_id = data.get("doc_id") + + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + await kb_helper.delete_document(doc_id) + return Response().ok(message="删除文档成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"删除文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"删除文档失败: {e!s}").__dict__ + + async def delete_chunk( + self, + kb_id: str | None = None, + doc_id: str | None = None, + chunk_id: str | None = None, + ): + """删除文本块 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + - doc_id: 文档 ID (必填) + - chunk_id: 块 ID (必填) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if 
doc_id is None: + # 从路径参数获取 + doc_id = request.view_args.get("doc_id") if request.view_args else None + if chunk_id is None: + # 从路径参数获取 + chunk_id = ( + request.view_args.get("chunk_id") if request.view_args else None + ) + + if not kb_id or not doc_id or not chunk_id: + # 从 body 获取 + data = await request.json + if not kb_id: + kb_id = data.get("kb_id") + if not doc_id: + doc_id = data.get("doc_id") + if not chunk_id: + chunk_id = data.get("chunk_id") + + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + if not chunk_id: + return Response().error("缺少参数 chunk_id").__dict__ + + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + await kb_helper.delete_chunk(chunk_id, doc_id) + return Response().ok(message="删除文本块成功").__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"删除文本块失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"删除文本块失败: {e!s}").__dict__ + + async def list_chunks(self, kb_id: str | None = None, doc_id: str | None = None): + """获取块列表 + + 路径参数或 Query 参数: + - kb_id: 知识库 ID (必填) + - doc_id: 文档 ID (必填) + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从查询参数获取 + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + if doc_id is None: + # 从路径参数获取 + doc_id = request.view_args.get("doc_id") if request.view_args else None + if not doc_id: + # 从查询参数获取 + doc_id = request.args.get("doc_id") + if not doc_id: + return Response().error("缺少参数 doc_id").__dict__ + + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 100, type=int) + kb_helper = await kb_manager.get_kb(kb_id) + offset = (page - 1) * page_size + limit = page_size + if not kb_helper: + return Response().error("知识库不存在").__dict__ + chunk_list = await kb_helper.get_chunks_by_doc_id( + doc_id=doc_id, + offset=offset, + limit=limit, + ) + return ( + Response() + .ok( + data={ + "items": chunk_list, + "page": page, + "page_size": page_size, + "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), + }, + ) + .__dict__ + ) + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"获取块列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取块列表失败: {e!s}").__dict__ + + async def retrieve(self): + """检索知识库 + + Body: + - query: 查询文本 (必填) + - kb_ids: 知识库 ID 列表 (必填) + - top_k: 返回结果数量 (可选, 默认 5) + - debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False) + """ + try: + kb_manager = self._get_kb_manager() + data = await request.json + + query = data.get("query") + kb_names = data.get("kb_names") + debug = data.get("debug", False) + + if not query: + return Response().error("缺少参数 query").__dict__ + if not kb_names or not isinstance(kb_names, list): + return Response().error("缺少参数 kb_names 或格式错误").__dict__ + + top_k = data.get("top_k", 5) + + results = await kb_manager.retrieve( + query=query, + kb_names=kb_names, + top_m_final=top_k, + ) + result_list = [] + if results: + result_list = results["results"] + + response_data = { + "results": result_list, + "total": len(result_list), + "query": query, + } + + # Debug 模式:生成 t-SNE 可视化 + if debug: + try: + from ..utils import 
generate_tsne_visualization + + img_base64 = await generate_tsne_visualization( + query, + kb_names, + kb_manager, + ) + if img_base64: + response_data["visualization"] = img_base64 + except Exception as e: + logger.error(f"生成 t-SNE 可视化失败: {e}") + logger.error(traceback.format_exc()) + response_data["visualization_error"] = str(e) + + return Response().ok(response_data).__dict__ + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"检索失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"检索失败: {e!s}").__dict__ + + async def upload_document_from_url(self, kb_id: str | None = None): + """从 URL 上传文档 + + 路径参数或 Body: + - kb_id: 知识库 ID (必填) + - url: 要提取内容的网页 URL (必填) + - chunk_size: 分块大小 (可选, 默认512) + - chunk_overlap: 块重叠大小 (可选, 默认50) + - batch_size: 批处理大小 (可选, 默认32) + - tasks_limit: 并发任务限制 (可选, 默认3) + - max_retries: 最大重试次数 (可选, 默认3) + + 返回: + - task_id: 任务ID,用于查询上传进度和结果 + """ + try: + kb_manager = self._get_kb_manager() + if kb_id is None: + # 从路径参数获取 + kb_id = request.view_args.get("kb_id") if request.view_args else None + if not kb_id: + # 从 body 获取 + data = await request.json + kb_id = data.get("kb_id") + else: + # 如果从路径参数获取了 kb_id,再从 body 获取其他参数 + data = await request.json + + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ + + url = data.get("url") + if not url: + return Response().error("缺少参数 url").__dict__ + + chunk_size = data.get("chunk_size", 512) + chunk_overlap = data.get("chunk_overlap", 50) + batch_size = data.get("batch_size", 32) + tasks_limit = data.get("tasks_limit", 3) + max_retries = data.get("max_retries", 3) + enable_cleaning = data.get("enable_cleaning", False) + cleaning_provider_id = data.get("cleaning_provider_id") + + # 获取知识库 + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, status="pending") + + # 启动后台任务 + asyncio.create_task( + self._background_upload_from_url_task( + task_id=task_id, + kb_helper=kb_helper, + url=url, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + enable_cleaning=enable_cleaning, + cleaning_provider_id=cleaning_provider_id, + ), + ) + + return ( + Response() + .ok( + { + "task_id": task_id, + "url": url, + "message": "URL upload task created, processing in background", + }, + ) + .__dict__ + ) + + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"从URL上传文档失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"从URL上传文档失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/services/platform.py b/astrbot/dashboard/services/platform.py new file mode 100644 index 000000000..19b16aa32 --- /dev/null +++ b/astrbot/dashboard/services/platform.py @@ -0,0 +1,159 @@ +import inspect +import os + +from quart import request + +from astrbot.core import file_token_service, logger +from astrbot.core.platform.register import platform_cls_map + +from ..entities import Response +from . 
import BaseService +from .utils import save_config + + +class PlatformService(BaseService): + def __init__(self, core_lifecycle): + super().__init__(core_lifecycle) + self._logo_token_cache = {} # 缓存logo token,避免重复注册 + + async def get_platform_list(self): + """获取所有平台的列表""" + platform_list = [] + config = self.cl.astrbot_config + for platform in config["platform"]: + platform_list.append(platform) + return Response().ok({"platforms": platform_list}).__dict__ + + async def post_new_platform(self): + """创建新的平台配置""" + new_platform_config = await request.json + + # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid + from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config + + ensure_platform_webhook_config(new_platform_config) + + config = self.cl.astrbot_config + config["platform"].append(new_platform_config) + try: + save_config(config, config, is_core=True) + await self.cl.platform_manager.load_platform(new_platform_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增平台配置成功~").__dict__ + + async def post_update_platform(self): + """更新平台配置""" + update_platform_config = await request.json + origin_platform_id = update_platform_config.get("id", None) + new_config = update_platform_config.get("config", None) + if not origin_platform_id or not new_config: + return Response().error("参数错误").__dict__ + + if origin_platform_id != new_config.get("id", None): + return Response().error("机器人名称不允许修改").__dict__ + + # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid + from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config + + ensure_platform_webhook_config(new_config) + + config = self.cl.astrbot_config + for i, platform in enumerate(config["platform"]): + if platform["id"] == origin_platform_id: + config["platform"][i] = new_config + break + else: + return Response().error("未找到对应平台").__dict__ + + try: + save_config(config, config, is_core=True) + await self.cl.platform_manager.reload(new_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新平台配置成功~").__dict__ + + async def post_delete_platform(self): + """删除平台配置""" + platform_id = await request.json + platform_id = platform_id.get("id") + config = self.cl.astrbot_config + for i, platform in enumerate(config["platform"]): + if platform["id"] == platform_id: + del config["platform"][i] + break + else: + return Response().error("未找到对应平台").__dict__ + try: + save_config(config, config, is_core=True) + await self.cl.platform_manager.terminate_platform(platform_id) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除平台配置成功~").__dict__ + + async def register_platform_logo(self, platform, platform_default_tmpl): + """注册平台logo文件并生成访问令牌""" + if not platform.logo_path: + return + + try: + # 检查缓存 + cache_key = f"{platform.name}:{platform.logo_path}" + if cache_key in self._logo_token_cache: + cached_token = self._logo_token_cache[cache_key] + # 确保platform_default_tmpl[platform.name]存在且为字典 + if platform.name not in platform_default_tmpl or not isinstance( + platform_default_tmpl[platform.name], dict + ): + platform_default_tmpl[platform.name] = {} + platform_default_tmpl[platform.name]["logo_token"] = cached_token + logger.debug(f"Using cached logo token for platform {platform.name}") + return + + # 获取平台适配器类 + platform_cls = platform_cls_map.get(platform.name) + if not platform_cls: + logger.warning(f"Platform class not found for {platform.name}") + return + + # 获取插件目录路径 + 
module_file = inspect.getfile(platform_cls) + plugin_dir = os.path.dirname(module_file) + + # 解析logo文件路径 + logo_file_path = os.path.join(plugin_dir, platform.logo_path) + + # 检查文件是否存在并注册令牌 + if os.path.exists(logo_file_path): + logo_token = await file_token_service.register_file( + logo_file_path, + timeout=3600, + ) + + # 确保platform_default_tmpl[platform.name]存在且为字典 + if platform.name not in platform_default_tmpl or not isinstance( + platform_default_tmpl[platform.name], dict + ): + platform_default_tmpl[platform.name] = {} + + platform_default_tmpl[platform.name]["logo_token"] = logo_token + + # 缓存token + self._logo_token_cache[cache_key] = logo_token + + logger.debug(f"Logo token registered for platform {platform.name}") + else: + logger.warning( + f"Platform {platform.name} logo file not found: {logo_file_path}", + ) + + except (ImportError, AttributeError) as e: + logger.warning( + f"Failed to import required modules for platform {platform.name}: {e}", + ) + except OSError as e: + logger.warning(f"File system error for platform {platform.name} logo: {e}") + except Exception as e: + logger.warning( + f"Unexpected error registering logo for platform {platform.name}: {e}", + ) diff --git a/astrbot/dashboard/services/provider.py b/astrbot/dashboard/services/provider.py new file mode 100644 index 000000000..e069a4268 --- /dev/null +++ b/astrbot/dashboard/services/provider.py @@ -0,0 +1,444 @@ +import inspect +import traceback + +from quart import request + +from astrbot.core import astrbot_config, logger +from astrbot.core.config.default import CONFIG_METADATA_2 +from astrbot.core.provider.provider import EmbeddingProvider, Provider +from astrbot.core.utils.llm_metadata import LLM_METADATAS + +from ..entities import Response +from . import BaseService +from .utils import save_config + + +class ProviderService(BaseService): + async def update_provider_source(self): + """更新或新增 provider_source,并重载关联的 providers""" + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + + new_source_config = post_data.get("config") or post_data + original_id = post_data.get("original_id") + if not original_id: + return Response().error("缺少 original_id").__dict__ + + if not isinstance(new_source_config, dict): + return Response().error("缺少或错误的配置数据").__dict__ + + # 确保配置中有 id 字段 + if not new_source_config.get("id"): + new_source_config["id"] = original_id + + provider_sources = astrbot_config.get("provider_sources", []) + + for ps in provider_sources: + if ps.get("id") == new_source_config["id"] and ps.get("id") != original_id: + return ( + Response() + .error( + f"Provider source ID '{new_source_config['id']}' exists already, please try another ID.", + ) + .__dict__ + ) + + # 查找旧的 provider_source,若不存在则追加为新配置 + target_idx = next( + (i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id), + -1, + ) + + old_id = original_id + if target_idx == -1: + provider_sources.append(new_source_config) + else: + old_id = provider_sources[target_idx].get("id") + provider_sources[target_idx] = new_source_config + + # 更新引用了该 provider_source 的 providers + affected_providers = [] + for provider in astrbot_config.get("provider", []): + if provider.get("provider_source_id") == old_id: + provider["provider_source_id"] = new_source_config["id"] + affected_providers.append(provider) + + # 写回配置 + astrbot_config["provider_sources"] = provider_sources + + try: + save_config(astrbot_config, astrbot_config, is_core=True) + except Exception as e: + 
logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + + # 重载受影响的 providers,使新的 source 配置生效 + reload_errors = [] + for provider in affected_providers: + try: + await self.clpm.reload(provider) + except Exception as e: + logger.error(traceback.format_exc()) + reload_errors.append(f"{provider.get('id')}: {e}") + + if reload_errors: + return ( + Response() + .error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors)) + .__dict__ + ) + + return Response().ok(message="更新 provider source 成功").__dict__ + + async def delete_provider_source(self): + """删除 provider_source,并更新关联的 providers""" + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + + provider_source_id = post_data.get("id") + if not provider_source_id: + return Response().error("缺少 provider_source_id").__dict__ + + provider_sources = astrbot_config.get("provider_sources", []) + target_idx = next( + ( + i + for i, ps in enumerate(provider_sources) + if ps.get("id") == provider_source_id + ), + -1, + ) + + if target_idx == -1: + return Response().error("未找到对应的 provider source").__dict__ + + # 删除 provider_source + del provider_sources[target_idx] + + # 写回配置 + astrbot_config["provider_sources"] = provider_sources + + # 删除引用了该 provider_source 的 providers + await self.clpm.delete_provider(provider_source_id=provider_source_id) + + try: + save_config(astrbot_config, astrbot_config, is_core=True) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + + return Response().ok(message="删除 provider source 成功").__dict__ + + async def get_provider_source_models(self): + """获取指定 provider_source 支持的模型列表 + + 本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例 + """ + provider_source_id = request.args.get("source_id") + if not provider_source_id: + return Response().error("缺少参数 source_id").__dict__ + + try: + from astrbot.core.provider.register import provider_cls_map + + # 从配置中查找对应的 provider_source + provider_sources = astrbot_config.get("provider_sources", []) + provider_source = None + for ps in provider_sources: + if ps.get("id") == provider_source_id: + provider_source = ps + break + + if not provider_source: + return ( + Response() + .error(f"未找到 ID 为 {provider_source_id} 的 provider_source") + .__dict__ + ) + + # 获取 provider 类型 + provider_type = provider_source.get("type", None) + if not provider_type: + return Response().error("provider_source 缺少 type 字段").__dict__ + + try: + self.clpm.dynamic_import_provider(provider_type) + except ImportError as e: + logger.error(traceback.format_exc()) + return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__ + + # 获取对应的 provider 类 + if provider_type not in provider_cls_map: + return ( + Response() + .error(f"未找到适用于 {provider_type} 的提供商适配器") + .__dict__ + ) + + provider_metadata = provider_cls_map[provider_type] + cls_type = provider_metadata.cls_type + + if not cls_type: + return Response().error(f"无法找到 {provider_type} 的类").__dict__ + + # 检查是否是 Provider 类型 + if not issubclass(cls_type, Provider): + return ( + Response() + .error(f"提供商 {provider_type} 不支持获取模型列表") + .__dict__ + ) + + # 临时实例化 provider + inst = cls_type(provider_source, {}) + + # 如果有 initialize 方法,调用它 + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() + + # 获取模型列表 + models = await inst.get_models() + models = models or [] + + metadata_map = {} + for model_id in models: + meta = LLM_METADATAS.get(model_id) + if meta: + metadata_map[model_id] = meta + + # 销毁实例(如果有 
terminate 方法) + terminate_fn = getattr(inst, "terminate", None) + if inspect.iscoroutinefunction(terminate_fn): + await terminate_fn() + + logger.info( + f"获取到 provider_source {provider_source_id} 的模型列表: {models}", + ) + + return ( + Response() + .ok({"models": models, "model_metadata": metadata_map}) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取模型列表失败: {e!s}").__dict__ + + async def get_provider_template(self): + """获取 provider 配置模板""" + config_schema = { + "provider": CONFIG_METADATA_2["provider_group"]["metadata"]["provider"] + } + data = { + "config_schema": config_schema, + "providers": astrbot_config["provider"], + "provider_sources": astrbot_config["provider_sources"], + } + return Response().ok(data=data).__dict__ + + async def list_provider_sources(self): + """获取 provider source 列表""" + return Response().ok(data=astrbot_config["provider_sources"]).__dict__ + + async def _test_single_provider(self, provider): + """辅助函数:测试单个 provider 的可用性""" + meta = provider.meta() + provider_name = provider.provider_config.get("id", "Unknown Provider") + provider_capability_type = meta.provider_type + + status_info = { + "id": getattr(meta, "id", "Unknown ID"), + "model": getattr(meta, "model", "Unknown Model"), + "type": provider_capability_type.value, + "name": provider_name, + "status": "unavailable", # 默认为不可用 + "error": None, + } + logger.debug( + f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", + ) + + try: + await provider.test() + status_info["status"] = "available" + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", + ) + except Exception as e: + error_message = str(e) + status_info["error"] = error_message + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. 
Error: {error_message}", + ) + logger.debug( + f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", + ) + + return status_info + + def _error_response( + self, + message: str, + status_code: int = 500, + log_fn=logger.error, + ): + log_fn(message) + # 记录更详细的traceback信息,但只在是严重错误时 + if status_code == 500: + log_fn(traceback.format_exc()) + return Response().error(message).__dict__ + + async def check_one_provider_status(self): + """API: check a single LLM Provider's status by id""" + provider_id = request.args.get("id") + if not provider_id: + return self._error_response( + "Missing provider_id parameter", + 400, + logger.warning, + ) + + logger.info(f"API call: /config/provider/check_one id={provider_id}") + try: + prov_mgr = self.clpm + target = prov_mgr.inst_map.get(provider_id) + + if not target: + logger.warning( + f"Provider with id '{provider_id}' not found in provider_manager.", + ) + return ( + Response() + .error(f"Provider with id '{provider_id}' not found") + .__dict__ + ) + + result = await self._test_single_provider(target) + return Response().ok(result).__dict__ + + except Exception as e: + return self._error_response( + f"Critical error checking provider {provider_id}: {e}", + 500, + ) + + async def get_provider_config_list(self): + """获取指定类型的 provider 配置列表""" + provider_type = request.args.get("provider_type", None) + if not provider_type: + return Response().error("缺少参数 provider_type").__dict__ + provider_type_ls = provider_type.split(",") + provider_list = [] + ps = self.clpm.providers_config + p_source_pt = { + psrc["id"]: psrc["provider_type"] + for psrc in self.clpm.provider_sources_config + } + for provider in ps: + ps_id = provider.get("provider_source_id", None) + if ( + ps_id + and ps_id in p_source_pt + and p_source_pt[ps_id] in provider_type_ls + ): + # chat + prov = self.clpm.get_merged_provider_config(provider) + provider_list.append(prov) + elif not ps_id and provider.get("provider_type", None) in provider_type_ls: + # agent runner, embedding, etc + provider_list.append(provider) + return Response().ok(provider_list).__dict__ + + async def get_embedding_dim(self): + """获取嵌入模型的维度""" + post_data = await request.json + provider_config = post_data.get("provider_config", None) + if not provider_config: + return Response().error("缺少参数 provider_config").__dict__ + + try: + # 动态导入 EmbeddingProvider + from astrbot.core.provider.register import provider_cls_map + + # 获取 provider 类型 + provider_type = provider_config.get("type", None) + if not provider_type: + return Response().error("provider_config 缺少 type 字段").__dict__ + + # 获取对应的 provider 类 + if provider_type not in provider_cls_map: + return ( + Response() + .error(f"未找到适用于 {provider_type} 的提供商适配器") + .__dict__ + ) + + provider_metadata = provider_cls_map[provider_type] + cls_type = provider_metadata.cls_type + + if not cls_type: + return Response().error(f"无法找到 {provider_type} 的类").__dict__ + + # 实例化 provider + inst = cls_type(provider_config, {}) + + # 检查是否是 EmbeddingProvider + if not isinstance(inst, EmbeddingProvider): + return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ + + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() + + # 获取嵌入向量维度 + vec = await inst.get_embedding("echo") + dim = len(vec) + + logger.info( + f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}", + ) + + return Response().ok({"embedding_dimensions": dim}).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return 
Response().error(f"获取嵌入维度失败: {e!s}").__dict__ + + async def post_new_provider(self): + """创建新的 provider""" + new_provider_config = await request.json + + try: + await self.clpm.create_provider(new_provider_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增服务提供商配置成功").__dict__ + + async def post_update_provider(self): + """更新 provider 配置""" + update_provider_config = await request.json + origin_provider_id = update_provider_config.get("id", None) + new_config = update_provider_config.get("config", None) + if not origin_provider_id or not new_config: + return Response().error("参数错误").__dict__ + + try: + await self.clpm.update_provider(origin_provider_id, new_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新成功,已经实时生效~").__dict__ + + async def post_delete_provider(self): + """删除 provider""" + provider_id = await request.json + provider_id = provider_id.get("id", "") + if not provider_id: + return Response().error("缺少参数 id").__dict__ + + try: + await self.clpm.delete_provider(provider_id=provider_id) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除成功,已经实时生效。").__dict__ diff --git a/astrbot/dashboard/services/utils.py b/astrbot/dashboard/services/utils.py new file mode 100644 index 000000000..f61451876 --- /dev/null +++ b/astrbot/dashboard/services/utils.py @@ -0,0 +1,126 @@ +import traceback +from typing import Any + +from astrbot.core import logger +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP + + +def try_cast(value: Any, type_: str): + if type_ == "int": + try: + return int(value) + except (ValueError, TypeError): + return None + elif ( + type_ == "float" + and isinstance(value, str) + and value.replace(".", "", 1).isdigit() + ) or (type_ == "float" and isinstance(value, int)): + return float(value) + elif type_ == "float": + try: + return float(value) + except (ValueError, TypeError): + return None + + +def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: + errors = [] + + def validate(data: dict, metadata: dict = schema, path=""): + for key, value in data.items(): + if key not in metadata: + continue + meta = metadata[key] + if "type" not in meta: + logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验") + continue + # null 转换 + if value is None: + data[key] = DEFAULT_VALUE_MAP[meta["type"]] + continue + if meta["type"] == "list" and not isinstance(value, list): + errors.append( + f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", + ) + elif ( + meta["type"] == "list" + and isinstance(value, list) + and value + and "items" in meta + and isinstance(value[0], dict) + ): + # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") + elif meta["type"] == "object" and isinstance(value, dict): + validate(value, meta["items"], path=f"{path}{key}.") + + if meta["type"] == "int" and not isinstance(value, int): + casted = try_cast(value, "int") + if casted is None: + errors.append( + f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}", + ) + data[key] = casted + elif meta["type"] == "float" and not isinstance(value, float): + casted = try_cast(value, "float") + if casted is None: + errors.append( + f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}", + ) + data[key] = casted + elif meta["type"] == "bool" and not 
isinstance(value, bool):
+                errors.append(
+                    f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}",
+                )
+            elif meta["type"] in ["string", "text"] and not isinstance(value, str):
+                errors.append(
+                    f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}",
+                )
+            elif meta["type"] == "list" and not isinstance(value, list):
+                errors.append(
+                    f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}",
+                )
+            elif meta["type"] == "object" and not isinstance(value, dict):
+                errors.append(
+                    f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}",
+                )
+
+    if is_core:
+        meta_all = {
+            **schema["platform_group"]["metadata"],
+            **schema["provider_group"]["metadata"],
+            **schema["misc_config_group"]["metadata"],
+        }
+        validate(data, meta_all)
+    else:
+        validate(data, schema)
+
+    return errors, data
+
+
+def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
+    """验证并保存配置"""
+    errors = None
+    logger.info(f"Saving config, is_core={is_core}")
+    try:
+        if is_core:
+            errors, post_config = validate_config(
+                post_config,
+                CONFIG_METADATA_2,
+                is_core,
+            )
+        else:
+            errors, post_config = validate_config(
+                post_config, getattr(config, "schema", {}), is_core
+            )
+    except BaseException as e:
+        logger.error(traceback.format_exc())
+        logger.warning(f"验证配置时出现异常: {e}")
+        raise ValueError(f"验证配置时出现异常: {e}")
+    if errors:
+        raise ValueError(f"格式校验未通过: {errors}")
+
+    config.save_config(post_config)
diff --git a/dashboard/src/components/shared/ApiKeyDialog.vue b/dashboard/src/components/shared/ApiKeyDialog.vue
new file mode 100644
index 000000000..8a71caab5
--- /dev/null
+++ b/dashboard/src/components/shared/ApiKeyDialog.vue
@@ -0,0 +1,205 @@
diff --git a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
index 4dc79356a..c38946337 100644
--- a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
+++ b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
@@ -18,6 +18,7 @@ import StyledMenu from '@/components/shared/StyledMenu.vue';
 import { useLanguageSwitcher } from '@/i18n/composables';
 import type { Locale } from '@/i18n/types';
 import AboutPage from '@/views/AboutPage.vue';
+import ApiKeyDialog from '@/components/shared/ApiKeyDialog.vue';
 enableKatex();
 enableMermaid();
@@ -30,6 +31,7 @@ let dialog = ref(false);
 let accountWarning = ref(false)
 let updateStatusDialog = ref(false);
 let aboutDialog = ref(false);
+let apiKeyDialog = ref(false);
 const username = localStorage.getItem('user');
 let password = ref('');
 let newPassword = ref('');
@@ -439,6 +441,18 @@ const changeLanguage = async (langCode: string) => {
+          API Keys
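
Two usage notes on the services added in this change.

First, the knowledge-base upload paths (upload_document, import_documents, upload_document_from_url) all return a task_id immediately and hand the heavy lifting to a background asyncio task, so clients are expected to poll get_upload_progress until the status becomes completed or failed. A minimal client-side sketch of that flow follows; the base URL, route paths, and auth header are assumptions (route registration is not part of this hunk), while the payload keys (task_id, status, progress, result, error) follow the handlers above.

# Client-side sketch of the asynchronous upload flow in KnowledgeBaseService.
# The base URL, route paths, and auth header below are assumptions; only the
# form/query fields and response keys are taken from the handlers in this diff.
import time

import requests

BASE = "http://localhost:6185/api/kb"          # assumed dashboard prefix
HEADERS = {"Authorization": "Bearer <jwt>"}    # assumed auth scheme

# 1. upload_document answers right away with a task_id; embedding and chunking
#    happen in _background_upload_task.
with open("manual.pdf", "rb") as f:
    created = requests.post(
        f"{BASE}/document/upload",             # assumed path
        data={"kb_id": "<kb_id>", "chunk_size": 512, "chunk_overlap": 50},
        files={"file": f},
        headers=HEADERS,
    ).json()["data"]
task_id = created["task_id"]

# 2. Poll get_upload_progress until the task settles on completed or failed.
while True:
    state = requests.get(
        f"{BASE}/document/upload_progress",    # assumed path
        params={"kb_id": "<kb_id>", "task_id": task_id},
        headers=HEADERS,
    ).json()["data"]
    if state["status"] in ("completed", "failed"):
        break
    progress = state.get("progress", {})
    print(progress.get("stage"), progress.get("current"), progress.get("total"))
    time.sleep(1)

print(state.get("result") or state.get("error"))

The same polling contract applies to pre-chunked imports and URL imports, since all three paths share the _init_task / _set_task_result / upload_progress bookkeeping shown above.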
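Second, the config helpers split out into services/utils.py can be exercised on their own. A small sketch, with a made-up schema and values, of how validate_config coerces string numbers and reports type mismatches:

# Minimal sketch of validate_config from astrbot/dashboard/services/utils.py.
# The schema and values below are invented for illustration only.
from astrbot.dashboard.services.utils import validate_config

schema = {
    "chunk_size": {"type": "int"},
    "enable_rerank": {"type": "bool"},
}

errors, cleaned = validate_config(
    {"chunk_size": "512", "enable_rerank": "yes"},  # "512" is cast, "yes" is rejected
    schema,
    is_core=False,
)

assert cleaned["chunk_size"] == 512   # try_cast coerced the string to int
assert errors                         # type error reported for enable_rerank

Because save_config raises ValueError whenever validate_config reports errors, a badly typed field never reaches config.save_config.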