From b0ff17783e8cd90e1b1facfef85cae6b4b541f9f Mon Sep 17 00:00:00 2001 From: Lishiling Date: Sat, 10 Jan 2026 18:45:37 +0800 Subject: [PATCH 1/3] fix: ensure atomic creation of knowledge base with proper cleanup on failure - Added pre-validation for embedding_provider_id parameter - Added check for existing knowledge base with same name - Implemented proper rollback mechanism when KBHelper initialization fails - Uses same session for cleanup to ensure data consistency - Fixes #4403 --- astrbot/core/knowledge_base/kb_mgr.py | 41 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 2219cc00b..9aae339b1 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -92,6 +92,15 @@ async def create_kb( top_m_final: int | None = None, ) -> KBHelper: """创建新的知识库实例""" + # 检测embedding_provider_id + if embedding_provider_id is None: + raise ValueError("创建知识库时必须提供embedding_provider_id") + + # 检查是否已存在同名知识库 + existing_kb = await self.kb_db.get_kb_by_name(kb_name) + if existing_kb is not None: + raise ValueError(f"知识库名称 '{kb_name}' 已存在") + kb = KnowledgeBase( kb_name=kb_name, description=description, @@ -104,21 +113,31 @@ async def create_kb( top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, top_m_final=top_m_final if top_m_final is not None else 5, ) + + kb_helper = None async with self.kb_db.get_db() as session: session.add(kb) await session.commit() await session.refresh(kb) - - kb_helper = KBHelper( - kb_db=self.kb_db, - kb=kb, - provider_manager=self.provider_manager, - kb_root_dir=FILES_PATH, - chunker=CHUNKER, - ) - await kb_helper.initialize() - self.kb_insts[kb.kb_id] = kb_helper - return kb_helper + try: + kb_helper = KBHelper( + kb_db=self.kb_db, + kb=kb, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, + ) + await kb_helper.initialize() + except Exception: + 
await session.refresh(kb) + await session.delete(kb) + await session.commit() + raise + # 判断是否成功创建 + if kb_helper: + self.kb_insts[kb.kb_id] = kb_helper + return kb_helper + raise RuntimeError("知识库创建失败:未知错误") async def get_kb(self, kb_id: str) -> KBHelper | None: """获取知识库实例""" From b3b63a2501806a62a658e449a7b88b30490110a9 Mon Sep 17 00:00:00 2001 From: Lishiling Date: Sat, 10 Jan 2026 19:30:35 +0800 Subject: [PATCH 2/3] fix: ensure atomic KB creation with session.flush() to remove race condition risks --- astrbot/core/knowledge_base/kb_mgr.py | 34 +++++++++------------------ 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 9aae339b1..1d4705c83 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -91,16 +91,9 @@ async def create_kb( top_k_sparse: int | None = None, top_m_final: int | None = None, ) -> KBHelper: - """创建新的知识库实例""" - # 检测embedding_provider_id + """创建知识库""" if embedding_provider_id is None: raise ValueError("创建知识库时必须提供embedding_provider_id") - - # 检查是否已存在同名知识库 - existing_kb = await self.kb_db.get_kb_by_name(kb_name) - if existing_kb is not None: - raise ValueError(f"知识库名称 '{kb_name}' 已存在") - kb = KnowledgeBase( kb_name=kb_name, description=description, @@ -113,13 +106,11 @@ async def create_kb( top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, top_m_final=top_m_final if top_m_final is not None else 5, ) + try: + async with self.kb_db.get_db() as session: + session.add(kb) + await session.flush() - kb_helper = None - async with self.kb_db.get_db() as session: - session.add(kb) - await session.commit() - await session.refresh(kb) - try: kb_helper = KBHelper( kb_db=self.kb_db, kb=kb, @@ -128,16 +119,13 @@ async def create_kb( chunker=CHUNKER, ) await kb_helper.initialize() - except Exception: - await session.refresh(kb) - await session.delete(kb) await session.commit() - raise - # 判断是否成功创建 - if 
kb_helper: - self.kb_insts[kb.kb_id] = kb_helper - return kb_helper - raise RuntimeError("知识库创建失败:未知错误") + self.kb_insts[kb.kb_id] = kb_helper + return kb_helper + except Exception as e: + if "kb_name" in str(e): + raise ValueError(f"知识库名称 '{kb_name}' 已存在") + raise async def get_kb(self, kb_id: str) -> KBHelper | None: """获取知识库实例""" From 40cd2086aad131d2cc29a2e6b15cfb88f19379ee Mon Sep 17 00:00:00 2001 From: Lishiling Date: Sat, 10 Jan 2026 19:40:32 +0800 Subject: [PATCH 3/3] fix: restore the original create_kb docstring --- astrbot/core/knowledge_base/kb_mgr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 1d4705c83..b085924ca 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -91,7 +91,7 @@ async def create_kb( top_k_sparse: int | None = None, top_m_final: int | None = None, ) -> KBHelper: - """创建知识库""" + """创建新的知识库实例""" if embedding_provider_id is None: raise ValueError("创建知识库时必须提供embedding_provider_id") kb = KnowledgeBase(