diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 2219cc00b..b085924ca 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -92,6 +92,8 @@ async def create_kb( top_m_final: int | None = None, ) -> KBHelper: """创建新的知识库实例""" + if embedding_provider_id is None: + raise ValueError("创建知识库时必须提供embedding_provider_id") kb = KnowledgeBase( kb_name=kb_name, description=description, @@ -104,21 +106,26 @@ async def create_kb( top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, top_m_final=top_m_final if top_m_final is not None else 5, ) - async with self.kb_db.get_db() as session: - session.add(kb) - await session.commit() - await session.refresh(kb) - - kb_helper = KBHelper( - kb_db=self.kb_db, - kb=kb, - provider_manager=self.provider_manager, - kb_root_dir=FILES_PATH, - chunker=CHUNKER, - ) - await kb_helper.initialize() - self.kb_insts[kb.kb_id] = kb_helper - return kb_helper + try: + async with self.kb_db.get_db() as session: + session.add(kb) + await session.flush() + + kb_helper = KBHelper( + kb_db=self.kb_db, + kb=kb, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, + ) + await kb_helper.initialize() + await session.commit() + self.kb_insts[kb.kb_id] = kb_helper + return kb_helper + except Exception as e: + if "kb_name" in str(e): + raise ValueError(f"知识库名称 '{kb_name}' 已存在") + raise async def get_kb(self, kb_id: str) -> KBHelper | None: """获取知识库实例"""