diff --git a/CHANGELOG.md b/CHANGELOG.md index 58b7ef4..64eddcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,30 @@ 所有重要更改都将记录在此文件中。 +## [Next-2.0.1] - 2026-02-23 + +### 🔧 Bug 修复 + +#### 插件卸载/重载卡死 (100% CPU) +- 修复 5 个后台 `while True` 任务(`_daily_mood_updater`、`_periodic_memory_sync`、`_periodic_context_cleanup`、`_periodic_knowledge_update`、`_periodic_recommendation_refresh`)未被跟踪和取消的问题 +- `plugin_lifecycle.py` 中 3 个 `asyncio.create_task()` 调用现在全部注册到 `background_tasks` 集合,确保关停时被取消 +- 所有关停步骤添加 `asyncio.wait_for` 超时保护(每步 8s),避免单个服务阻塞整个关停流程 +- `ServiceRegistry.stop_all_services()` 每个服务添加 5s 超时 +- `GroupLearningOrchestrator.cancel_all()` 添加 per-task 超时 +- `Server.stop()` 将 `thread.join()` 移至线程池执行器,避免阻塞事件循环 +- `WebUIManager.stop()` 添加锁获取超时,防止死锁 +- 关停时清理 `SingletonABCMeta._instances`,防止重载后单例残留 + +#### MySQL 兼容性修复 +- 修复 `persona_content` 列 INSERT 时传入 `None` 导致 `IntegrityError (1048)` 的问题 +- 修复 `TEXT` 列不能有 `DEFAULT` 值的 MySQL 严格模式错误 +- 启用启动时自动列迁移,跳过 TEXT/BLOB/JSON 列的 DEFAULT 生成 +- Facade 文件中 65 处方法内延迟导入移至模块级别,修复热重载后 `ModuleNotFoundError` + +#### 人格审批修复 +- 传统审批路径(纯数字 ID)改为通过 `PersonaWebManager` 路由,解决跨线程调用导致的卡死 +- 修复 `save_or_update_jargon` 参数顺序和类型错误 + ## [Next-2.0.0] - 2026-02-22 ### 🎯 新功能 diff --git a/VIDEO_SCRIPT.md b/VIDEO_SCRIPT.md deleted file mode 100644 index b4f30b3..0000000 --- a/VIDEO_SCRIPT.md +++ /dev/null @@ -1,131 +0,0 @@ -# Self-Learning 插件 Bilibili 视频脚本 - -> 预计时长:约 5 分钟 | 定位:功能讲解 + 使用教程 | 语气:口语化、轻松 - ---- - -## 一、开场 Hook(约 30 秒) - -各位好,先问大家一个问题—— - -你有没有觉得,现在的 AI 聊天机器人,虽然什么都能答,但说话总是……一股"AI味儿"? - -你跟它说"6",它回你"您的意思是数字六吗?"。你在群里发个梗,它给你来一段百科解释。 - -说白了,它不懂你们群的"语言",也不知道你们平时怎么聊天。 - -那有没有办法,让 AI 自己去"偷师",学会像真人一样说话? - -今天给大家介绍的这个插件,就是专门干这件事的。 - ---- - -## 二、一句话介绍(约 20 秒) - -这个插件叫 **Self-Learning**,是给 AstrBot 用的一个自主学习插件。 - -它能做的事情用一句话概括就是——**让你的 Bot 潜伏在群里,自动学习大家的说话方式,然后越聊越像真人。** - -不需要你手动喂数据,不需要你写什么 prompt 模板,它全自动完成。 - ---- - -## 三、核心功能讲解(约 3 分钟) - -接下来我挨个说说它到底能干什么。 - -### 1. 自动学群友说话(约 40 秒) - -首先,最核心的能力——**表达模式学习**。 - -插件开启之后,它会在后台默默收集群里的聊天消息。然后定时触发一次学习,用大模型分析这些对话:**在什么场景下,大家会用什么样的表达方式。** - -比如它可能会学到:表示赞同的时候,群友喜欢说"确实"而不是"我同意你的看法"。 - -这些学到的表达模式会被自动注入到 Bot 的回复里,这样 Bot 说话就不会那么正式、那么像机器了。 - -而且它有一个时间衰减机制,过时的表达会自动降权,新学到的会优先使用。所以 Bot 的说话风格会跟着群聊氛围一起"进化"。 - -### 2. 听得懂黑话(约 30 秒) - -第二个功能——**黑话挖掘**。 - -每个群都有自己的"黑话"对吧?比如某个群里"发财了"可能是表示"太好了"的意思,"下次一定"其实是在拒绝。 - -这些东西你不教,AI 是真不懂。 - -这个插件会自动检测群里的高频特殊用语,然后调用大模型根据上下文推断它的真实含义,保存下来。之后 Bot 在回复消息的时候,就能正确理解这些黑话了,不会再闹笑话。 - -### 3. 社交关系网络(约 30 秒) - -第三个——**社交关系分析**。 - -插件不光学说话,它还会"看人"。它会自动记录群里谁跟谁聊得多、谁 at 了谁、谁经常回复谁,把这些互动关系整理成一张社交网络图。 - -这有什么用呢?Bot 回复的时候,它知道这个群里谁跟谁关系好、谁是活跃分子、谁是边缘人。这样它聊天的时候就能更"懂事",不会在两个关系很好的人面前说不合时宜的话。 - -在管理后台里你还能看到一张可视化的关系图谱,节点越大说明这个人越活跃,连线越粗说明两个人互动越频繁,挺有意思的。 - -### 4. 好感度系统(约 30 秒) - -第四个——**好感度和情绪系统**。 - -插件会记录每个人跟 Bot 的互动。经常夸它、跟它友好聊天,好感度就会涨;反过来骂它,好感度就掉。 - -好感度会影响 Bot 的回复态度。对喜欢的人说话更热情、更主动,对不喜欢的人就冷淡一些。 - -Bot 自己还有一套情绪系统,每天会自动切换心情。开心的时候活泼一点,低落的时候话少一点。这样聊起来就更有"人味儿"了。 - -### 5. 人格审查——你说了算(约 30 秒) - -可能有人会担心:Bot 自己学习,万一学歪了怎么办? - -放心,插件有一个**人格审查机制**。Bot 学完之后生成的人格更新建议,不会直接应用,而是先提交到审查队列里。 - -你可以在管理后台看到它打算怎么改、改了哪些内容,觉得没问题就批准,觉得不对就驳回。**最终决定权始终在你手里。** - -### 6. 可视化管理后台(约 30 秒) - -可能有人会担心:Bot 自己学习,万一学歪了怎么办? - -放心,插件有一个**人格审查机制**。Bot 学完之后生成的人格更新建议,不会直接应用,而是先提交到审查队列里。 - -你可以在管理后台看到它打算怎么改、改了哪些内容,觉得没问题就批准,觉得不对就驳回。**最终决定权始终在你手里。** - -### 5. 可视化管理后台(约 30 秒) - -最后说一下**WebUI 管理界面**。 - -插件带了一个完整的网页后台,默认端口 7833,浏览器直接就能访问。 - -里面能看到消息收集的数据统计、学习进度、社交关系网络图、好感度排行榜,还有刚才说的人格审查和风格学习详情。 - -基本上 Bot 在做什么、学到了什么、效果怎么样,一目了然。不需要去翻日志,也不用敲命令,全部可视化搞定。 - ---- - -## 四、安装使用(约 30 秒) - -说了这么多,怎么用呢?非常简单。 - -第一步,确保你已经在跑 AstrBot,版本不低于 4.11.4。 - -第二步,在 AstrBot 的插件商店里搜索 **self-learning**,一键安装。 - -第三步,到 AstrBot 后台的插件配置里,把"启用消息抓取"和"启用自动学习"打开。 - -然后你就不用管了。它会自动开始收集消息、自动学习。过几个小时你再去看,Bot 说话就已经开始变了。 - -如果你想看详细数据,浏览器打开 `你的服务器地址:7833`,登录管理后台就行。默认密码在配置里,记得第一次登录之后改掉。 - ---- - -## 收尾(约 20 秒) - -总结一下,这个插件做的事情就是:**让你的 AI Bot 从一个"什么都会但不会说人话"的机器,变成一个能融入群聊、听得懂黑话、记得住谁对它好的"拟人化 Bot"。** - -感兴趣的话,GitHub 搜 **astrbot_plugin_self_learning** 就能找到,觉得有用欢迎给个 Star。 - -遇到问题或者想交流,可以加 QQ 群 **1021544792**。 - -好,这期就到这里,我们下期见。 diff --git a/core/database/engine.py b/core/database/engine.py index 354a029..df1d50e 100644 --- a/core/database/engine.py +++ b/core/database/engine.py @@ -287,10 +287,13 @@ def _get_existing_columns(sync_conn): col_type = col.type.compile(self.engine.dialect) nullable = "NULL" if col.nullable else "NOT NULL" default = "" - if col.server_default is not None: - default = f" DEFAULT {col.server_default.arg!r}" - elif col.default is not None and col.default.is_scalar: - default = f" DEFAULT {col.default.arg!r}" + # MySQL 不允许 TEXT/BLOB 列有 DEFAULT 值 + is_text_type = col_type.upper() in ("TEXT", "BLOB", "MEDIUMTEXT", "LONGTEXT", "JSON") + if not is_text_type: + if col.server_default is not None: + default = f" DEFAULT {col.server_default.arg!r}" + elif col.default is not None and col.default.is_scalar: + default = f" DEFAULT {col.default.arg!r}" alter_statements.append( f"ALTER TABLE `{table.name}` ADD COLUMN " f"`{col.name}` {col_type} {nullable}{default}" diff --git a/core/patterns.py b/core/patterns.py index fc1756f..69ec263 100644 --- a/core/patterns.py +++ b/core/patterns.py @@ -348,16 +348,22 @@ async def start_all_services(self) -> bool: return all(results) + _SERVICE_STOP_TIMEOUT = 5 # 每个服务停止的超时秒数 + async def stop_all_services(self) -> bool: - """停止所有服务""" + """停止所有服务(每个服务带超时,避免卡死)""" + import asyncio + self._logger.info("停止所有服务") results = [] - + for name, service in self._services.items(): try: - # 检查服务是否有stop方法 if hasattr(service, 'stop') and callable(getattr(service, 'stop')): - result = await service.stop() + result = await asyncio.wait_for( + service.stop(), + timeout=self._SERVICE_STOP_TIMEOUT, + ) results.append(result) if not result: self._logger.error(f"服务 {name} 停止失败") @@ -365,14 +371,16 @@ async def stop_all_services(self) -> bool: self._logger.info(f"服务 {name} 已停止") else: self._logger.warning(f"服务 {name} 没有stop方法,跳过停止") - results.append(True) # 没有stop方法就认为成功 - except AttributeError as e: - self._logger.error(f"停止服务 {name} 异常:{e}") + results.append(True) + except asyncio.TimeoutError: + self._logger.warning( + f"服务 {name} 停止超时 ({self._SERVICE_STOP_TIMEOUT}s),跳过" + ) results.append(False) except Exception as e: self._logger.error(f"停止服务 {name} 异常: {e}") results.append(False) - + return all(results) def get_service_status(self) -> Dict[str, str]: diff --git a/core/plugin_lifecycle.py b/core/plugin_lifecycle.py index e50cb7f..d2d1d07 100644 --- a/core/plugin_lifecycle.py +++ b/core/plugin_lifecycle.py @@ -243,11 +243,15 @@ def bootstrap( ) need_immediate_start = self._webui_manager.create_server() if need_immediate_start: - asyncio.create_task(self._webui_manager.immediate_start(p.db_manager)) + _t = asyncio.create_task(self._webui_manager.immediate_start(p.db_manager)) + p.background_tasks.add(_t) + _t.add_done_callback(p.background_tasks.discard) # ------ 自动学习启动(必须在 _group_orchestrator 创建之后)------ if plugin_config.enable_auto_learning: - asyncio.create_task(p._group_orchestrator.delayed_auto_start_learning()) + _t = asyncio.create_task(p._group_orchestrator.delayed_auto_start_learning()) + p.background_tasks.add(_t) + _t.add_done_callback(p.background_tasks.discard) logger.info(StatusMessages.FACTORY_SERVICES_INIT_COMPLETE) @@ -296,7 +300,9 @@ def _setup_internal_components( p.learning_scheduler = component_factory.create_learning_scheduler(p) p.background_tasks = set() - asyncio.create_task(self._delayed_provider_reinitialization()) + _t = asyncio.create_task(self._delayed_provider_reinitialization()) + p.background_tasks.add(_t) + _t.add_done_callback(p.background_tasks.discard) # Phase 2: 异步启动(on_load 阶段调用) @@ -375,8 +381,23 @@ async def on_load(self) -> None: # Phase 3: 有序关停(terminate 阶段调用) + _STEP_TIMEOUT = 8 # 每个关停步骤的超时秒数 + _TASK_CANCEL_TIMEOUT = 3 # 每个后台任务取消等待的超时秒数 + + async def _safe_step(self, label: str, coro, timeout: float = None) -> None: + """执行一个关停步骤,超时或异常均不阻塞后续步骤""" + if timeout is None: + timeout = self._STEP_TIMEOUT + try: + await asyncio.wait_for(coro, timeout=timeout) + logger.info(f"{label} 完成") + except asyncio.TimeoutError: + logger.warning(f"{label} 超时 ({timeout}s),跳过") + except Exception as e: + logger.error(f"{label} 失败: {e}") + async def shutdown(self) -> None: - """有序关停所有服务""" + """有序关停所有服务(每步带超时,避免卡死)""" p = self._plugin try: logger.info("开始插件清理工作...") @@ -384,25 +405,30 @@ async def shutdown(self) -> None: # 1. 停止学习任务 logger.info("停止所有学习任务...") if getattr(p, "_group_orchestrator", None): - await p._group_orchestrator.cancel_all() + await self._safe_step( + "停止学习任务", + p._group_orchestrator.cancel_all(), + ) # 2. 停止学习调度器 if hasattr(p, "learning_scheduler"): - try: - await p.learning_scheduler.stop() - logger.info("学习调度器已停止") - except Exception as e: - logger.error(f"停止学习调度器失败: {e}") + await self._safe_step( + "停止学习调度器", + p.learning_scheduler.stop(), + ) - # 3. 取消后台任务 + # 3. 取消后台任务(每个任务单独超时) logger.info("取消所有后台任务...") for task in list(p.background_tasks): try: if not task.done(): task.cancel() try: - await task - except asyncio.CancelledError: + await asyncio.wait_for( + asyncio.shield(task), + timeout=self._TASK_CANCEL_TIMEOUT, + ) + except (asyncio.CancelledError, asyncio.TimeoutError): pass except Exception as e: logger.error( @@ -411,21 +437,18 @@ async def shutdown(self) -> None: p.background_tasks.clear() # 4. 停止服务工厂 - logger.info("停止所有服务...") if hasattr(p, "factory_manager"): - try: - await p.factory_manager.cleanup() - logger.info("服务工厂已清理") - except Exception as e: - logger.error(f"清理服务工厂失败: {e}") + await self._safe_step( + "清理服务工厂", + p.factory_manager.cleanup(), + ) # 4.5 停止 V2 if getattr(p, "v2_integration", None): - try: - await p.v2_integration.stop() - logger.info("V2LearningIntegration stopped") - except Exception as e: - logger.error(f"V2LearningIntegration stop failed: {e}") + await self._safe_step( + "停止 V2LearningIntegration", + p.v2_integration.stop(), + ) # 4.6 重置单例 try: @@ -437,25 +460,34 @@ async def shutdown(self) -> None: except Exception: pass + try: + from .patterns import SingletonABCMeta + + SingletonABCMeta._instances.clear() + logger.info("SingletonABCMeta 实例缓存已清理") + except Exception: + pass + # 5. 清理临时人格 if hasattr(p, "temporary_persona_updater"): - try: - await p.temporary_persona_updater.cleanup_temp_personas() - logger.info("临时人格已清理") - except Exception as e: - logger.error(f"清理临时人格失败: {e}") + await self._safe_step( + "清理临时人格", + p.temporary_persona_updater.cleanup_temp_personas(), + ) # 6. 保存状态 if hasattr(p, "message_collector"): - try: - await p.message_collector.save_state() - logger.info("消息收集器状态已保存") - except Exception as e: - logger.error(f"保存消息收集器状态失败: {e}") + await self._safe_step( + "保存消息收集器状态", + p.message_collector.save_state(), + ) # 7. 停止 WebUI if self._webui_manager: - await self._webui_manager.stop() + await self._safe_step( + "停止 WebUI", + self._webui_manager.stop(), + ) # 8. 保存配置 try: diff --git a/metadata.yaml b/metadata.yaml index 79bb374..5cead40 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning" author: "NickMo" display_name: "self-learning" description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)" -version: "Next-2.0.0" +version: "Next-2.0.1" repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning" tags: - "自学习" diff --git a/models/orm/psychological.py b/models/orm/psychological.py index 736a44b..9fa543c 100644 --- a/models/orm/psychological.py +++ b/models/orm/psychological.py @@ -231,7 +231,7 @@ class PersonaBackup(Base): imitation_dialogues = Column(Text) # JSON backup_reason = Column(Text) backup_time = Column(Float, nullable=True) # legacy column in production DB - persona_content = Column(Text, nullable=True, default='', server_default='') # legacy column in production DB + persona_content = Column(Text, nullable=True) # legacy column in production DB created_at = Column(DateTime, default=func.now()) __table_args__ = ( diff --git a/services/analysis/intelligence_enhancement.py b/services/analysis/intelligence_enhancement.py index e4fda47..a25f7be 100644 --- a/services/analysis/intelligence_enhancement.py +++ b/services/analysis/intelligence_enhancement.py @@ -98,18 +98,27 @@ async def _do_start(self) -> bool: await self._load_knowledge_graph() await self._load_user_preferences() - # 启动定期任务 - asyncio.create_task(self._periodic_knowledge_update()) - asyncio.create_task(self._periodic_recommendation_refresh()) - + # 启动定期任务(保留引用以便 stop 时取消) + self._knowledge_task = asyncio.create_task(self._periodic_knowledge_update()) + self._recommend_task = asyncio.create_task(self._periodic_recommendation_refresh()) + self._logger.info("智能化提升服务启动成功") return True except Exception as e: self._logger.error(f"智能化提升服务启动失败: {e}") return False - + async def _do_stop(self) -> bool: """停止智能化服务""" + # 取消后台任务 + for task in (getattr(self, '_knowledge_task', None), + getattr(self, '_recommend_task', None)): + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass await self._save_emotion_profiles() await self._save_knowledge_graph() await self._save_user_preferences() @@ -879,34 +888,33 @@ async def update_adaptive_learning_rate(self, group_id: str, user_id: str, async def _periodic_knowledge_update(self): """定期更新知识图谱""" - while True: - try: + try: + while True: await asyncio.sleep(3600) # 每小时更新一次 - + # 清理过期实体 current_time = time.time() expired_entities = [] - + for entity_id, entity in self.knowledge_entities.items(): if current_time - entity.last_mentioned > 86400 * 7: # 7天未提及 expired_entities.append(entity_id) - + for entity_id in expired_entities: del self.knowledge_entities[entity_id] if self.knowledge_graph.has_node(entity_id): self.knowledge_graph.remove_node(entity_id) - + self._logger.info(f"清理过期知识实体: {len(expired_entities)}") - - except Exception as e: - self._logger.error(f"知识图谱更新失败: {e}") + except asyncio.CancelledError: + self._logger.debug("知识图谱更新任务已取消") async def _periodic_recommendation_refresh(self): """定期刷新推荐缓存""" - while True: - try: + try: + while True: await asyncio.sleep(1800) # 30分钟刷新一次 - + # 清理过期推荐 current_time = time.time() for user_key in list(self.recommendation_cache.keys()): @@ -915,14 +923,13 @@ async def _periodic_recommendation_refresh(self): rec for rec in recommendations if current_time - rec.timestamp < 3600 # 1小时内的推荐 ] - + if fresh_recommendations: self.recommendation_cache[user_key] = fresh_recommendations else: del self.recommendation_cache[user_key] - - except Exception as e: - self._logger.error(f"推荐缓存刷新失败: {e}") + except asyncio.CancelledError: + self._logger.debug("推荐缓存刷新任务已取消") async def _load_emotion_profiles(self): """加载情感档案""" diff --git a/services/database/facades/admin_facade.py b/services/database/facades/admin_facade.py index 276912f..1297b2f 100644 --- a/services/database/facades/admin_facade.py +++ b/services/database/facades/admin_facade.py @@ -6,6 +6,11 @@ from astrbot.api import logger from ._base import BaseFacade +from sqlalchemy import delete as sa_delete, select +from ....models.orm.learning import LearningBatch +from ....models.orm.message import FilteredMessage, RawMessage +from ....models.orm.performance import LearningPerformanceHistory +from ....models.orm.reinforcement import PersonaFusionHistory, ReinforcementLearningResult, StrategyOptimizationResult class AdminFacade(BaseFacade): @@ -15,14 +20,6 @@ async def clear_all_messages_data(self) -> bool: """清除所有消息与学习数据(批量删除多个表)""" try: async with self.get_session() as session: - from sqlalchemy import delete as sa_delete - from ....models.orm.message import RawMessage, FilteredMessage - from ....models.orm.learning import LearningBatch - from ....models.orm.reinforcement import ( - ReinforcementLearningResult, PersonaFusionHistory, - StrategyOptimizationResult - ) - from ....models.orm.performance import LearningPerformanceHistory tables = [ FilteredMessage, RawMessage, LearningBatch, @@ -50,8 +47,6 @@ async def export_messages_learning_data( """导出原始消息和筛选消息""" try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.message import RawMessage, FilteredMessage raw_stmt = select(RawMessage) filtered_stmt = select(FilteredMessage) diff --git a/services/database/facades/expression_facade.py b/services/database/facades/expression_facade.py index 015765d..e07fe74 100644 --- a/services/database/facades/expression_facade.py +++ b/services/database/facades/expression_facade.py @@ -9,6 +9,14 @@ from ._base import BaseFacade from ....repositories.style_profile_repository import StyleProfileRepository +from sqlalchemy import desc, func, select +from ....models.orm.expression import ( + ExpressionPattern, + LanguageStylePattern, + StyleLearningRecord, + StyleProfile, +) +from ....repositories.expression_repository import ExpressionPatternRepository class ExpressionFacade(BaseFacade): @@ -18,7 +26,6 @@ async def get_all_expression_patterns(self) -> Dict[str, List[Dict[str, Any]]]: """获取所有群组的表达模式""" try: async with self.get_session() as session: - from ....repositories.expression_repository import ExpressionPatternRepository repo = ExpressionPatternRepository(session) all_patterns = await repo.get_all(limit=1000) @@ -38,8 +45,6 @@ async def get_expression_patterns_statistics(self) -> Dict[str, Any]: """获取表达模式统计""" try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.expression import ExpressionPattern total_stmt = select(func.count()).select_from(ExpressionPattern) total_result = await session.execute(total_stmt) @@ -60,7 +65,6 @@ async def get_group_expression_patterns( """获取指定群组的表达模式""" try: async with self.get_session() as session: - from ....repositories.expression_repository import ExpressionPatternRepository repo = ExpressionPatternRepository(session) patterns = await repo.find_many( @@ -77,8 +81,6 @@ async def get_recent_week_expression_patterns( """获取最近指定时间范围的表达模式""" try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.expression import ExpressionPattern cutoff = time.time() - (hours * 3600) stmt = select(ExpressionPattern).where( @@ -124,8 +126,6 @@ async def save_style_profile( """保存风格画像(upsert)""" try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.expression import StyleProfile stmt = select(StyleProfile).where(StyleProfile.profile_name == profile_name) result = await session.execute(stmt) @@ -154,7 +154,6 @@ async def save_style_learning_record(self, record_data: Dict[str, Any]) -> bool: """保存风格学习记录""" try: async with self.get_session() as session: - from ....models.orm.expression import StyleLearningRecord rec = StyleLearningRecord( style_type=record_data.get('style_type', 'unknown'), @@ -176,8 +175,6 @@ async def save_language_style_pattern( """保存语言风格模式(upsert)""" try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.expression import LanguageStylePattern stmt = select(LanguageStylePattern).where( LanguageStylePattern.language_style == language_style diff --git a/services/database/facades/jargon_facade.py b/services/database/facades/jargon_facade.py index 6e95027..fa9bbe9 100644 --- a/services/database/facades/jargon_facade.py +++ b/services/database/facades/jargon_facade.py @@ -7,9 +7,11 @@ import json from typing import Dict, List, Optional, Any +from sqlalchemy import select, and_, func, desc, or_, case from astrbot.api import logger from ._base import BaseFacade +from ....models.orm.jargon import Jargon class JargonFacade(BaseFacade): @@ -28,8 +30,6 @@ async def get_jargon(self, chat_id: str, content: str) -> Optional[Dict[str, Any """ try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(and_( Jargon.chat_id == chat_id, @@ -59,7 +59,6 @@ async def insert_jargon(self, jargon_data: Dict[str, Any]) -> Optional[int]: """ try: async with self.get_session() as session: - from ....models.orm.jargon import Jargon now_ts = int(time.time()) @@ -124,8 +123,6 @@ async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool: try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(Jargon.id == jargon_id) result = await session.execute(stmt) @@ -191,8 +188,6 @@ async def get_jargon_statistics(self, group_id: str = None) -> Dict[str, Any]: } try: async with self.get_session() as session: - from sqlalchemy import select, func, case - from ....models.orm.jargon import Jargon columns = [ func.count().label('total'), @@ -263,8 +258,6 @@ async def get_recent_jargon_list( try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon # 构建查询 stmt = select(Jargon) @@ -338,8 +331,6 @@ async def get_jargon_count( """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.jargon import Jargon stmt = select(func.count(Jargon.id)) @@ -380,8 +371,6 @@ async def search_jargon( """ try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.jargon import Jargon conditions = [ Jargon.content.ilike(f'%{keyword}%'), @@ -434,8 +423,6 @@ async def get_jargon_by_id(self, jargon_id: int) -> Optional[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(Jargon.id == jargon_id) result = await session.execute(stmt) @@ -464,8 +451,6 @@ async def delete_jargon_by_id(self, jargon_id: int) -> bool: """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(Jargon.id == jargon_id) result = await session.execute(stmt) @@ -497,8 +482,6 @@ async def set_jargon_global(self, jargon_id: int, is_global: bool) -> bool: """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(Jargon.id == jargon_id) result = await session.execute(stmt) @@ -534,8 +517,6 @@ async def sync_global_jargon_to_group(self, target_chat_id: str) -> int: """ try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.jargon import Jargon # 获取非目标群组的全局黑话 stmt = select(Jargon).where(and_( @@ -607,8 +588,6 @@ async def save_or_update_jargon( """ try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.jargon import Jargon stmt = select(Jargon).where(and_( Jargon.chat_id == chat_id, @@ -686,8 +665,6 @@ async def get_global_jargon_list(self, limit: int = 50) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.jargon import Jargon stmt = select(Jargon).where( Jargon.is_jargon == True, @@ -734,8 +711,6 @@ async def get_jargon_groups(self) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.jargon import Jargon stmt = select( Jargon.chat_id, diff --git a/services/database/facades/learning_facade.py b/services/database/facades/learning_facade.py index 1956974..e563b99 100644 --- a/services/database/facades/learning_facade.py +++ b/services/database/facades/learning_facade.py @@ -8,6 +8,16 @@ from astrbot.api import logger from ._base import BaseFacade +from sqlalchemy import delete as sa_delete, desc, func, select +from ....models.orm.learning import ( + LearningBatch, + LearningSession, + PersonaLearningReview, + StyleLearningPattern, + StyleLearningReview, +) +from ....models.orm.message import FilteredMessage +from ....models.orm.performance import LearningPerformanceHistory class LearningFacade(BaseFacade): @@ -26,7 +36,6 @@ async def add_persona_learning_review(self, review_data: Dict[str, Any]) -> int: """ try: async with self.get_session() as session: - from ....models.orm.learning import PersonaLearningReview metadata = review_data.get('metadata', {}) record = PersonaLearningReview( @@ -59,8 +68,6 @@ async def get_pending_persona_update_records(self) -> List[Dict[str, Any]]: """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import PersonaLearningReview stmt = ( select(PersonaLearningReview) @@ -117,8 +124,6 @@ async def update_persona_update_record_status( """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == record_id @@ -148,8 +153,6 @@ async def delete_persona_update_record(self, record_id: int) -> bool: """ try: async with self.get_session() as session: - from sqlalchemy import select, delete as sa_delete - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == record_id @@ -183,8 +186,6 @@ async def get_persona_update_record_by_id( """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == record_id @@ -227,8 +228,6 @@ async def get_reviewed_persona_update_records( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import PersonaLearningReview if status_filter: stmt = ( @@ -285,8 +284,6 @@ async def get_pending_persona_learning_reviews( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import PersonaLearningReview stmt = ( select(PersonaLearningReview) @@ -337,8 +334,6 @@ async def get_reviewed_persona_learning_updates( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import PersonaLearningReview if status_filter: stmt = ( @@ -391,8 +386,6 @@ async def delete_persona_learning_review_by_id(self, review_id: int) -> bool: """ try: async with self.get_session() as session: - from sqlalchemy import select, delete as sa_delete - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == review_id @@ -443,8 +436,6 @@ async def update_persona_learning_review_status( """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == review_id @@ -483,7 +474,6 @@ async def create_style_learning_review( """ try: async with self.get_session() as session: - from ....models.orm.learning import StyleLearningReview learned_patterns = review_data.get('learned_patterns', []) record = StyleLearningReview( @@ -519,8 +509,6 @@ async def get_pending_style_reviews(self, limit=None, offset=0) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import StyleLearningReview stmt = ( select(StyleLearningReview) @@ -570,8 +558,6 @@ async def get_approved_few_shots( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import StyleLearningReview stmt = ( select(StyleLearningReview.few_shots_content) @@ -605,8 +591,6 @@ async def get_reviewed_style_learning_updates( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import StyleLearningReview if status_filter: stmt = ( @@ -663,8 +647,6 @@ async def update_style_review_status( """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.learning import StyleLearningReview stmt = select(StyleLearningReview).where( StyleLearningReview.id == review_id @@ -694,8 +676,6 @@ async def delete_style_review_by_id(self, review_id: int) -> bool: """ try: async with self.get_session() as session: - from sqlalchemy import select, delete as sa_delete - from ....models.orm.learning import StyleLearningReview stmt = select(StyleLearningReview).where( StyleLearningReview.id == review_id @@ -732,8 +712,6 @@ async def get_learning_batch_history( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import LearningBatch stmt = ( select(LearningBatch) @@ -761,8 +739,6 @@ async def get_recent_learning_batches(self, limit=5) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import LearningBatch stmt = ( select(LearningBatch) @@ -788,8 +764,6 @@ async def get_learning_sessions(self, group_id, limit=5) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import LearningSession stmt = ( select(LearningSession) @@ -815,8 +789,6 @@ async def get_recent_learning_sessions(self, days=7) -> List[Dict]: """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import LearningSession cutoff = time.time() - (days * 24 * 3600) stmt = ( @@ -845,8 +817,6 @@ async def save_learning_session_record( """ try: async with self.get_session() as session: - from ....models.orm.learning import LearningSession - from sqlalchemy import select sid = session_data.get('session_id', '') @@ -895,7 +865,6 @@ async def save_learning_performance_record( """ try: async with self.get_session() as session: - from ....models.orm.performance import LearningPerformanceHistory metadata = performance_data.get('metadata', {}) record = LearningPerformanceHistory( @@ -941,8 +910,6 @@ async def count_pending_persona_updates(self) -> int: """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.learning import PersonaLearningReview stmt = ( select(func.count()) @@ -963,8 +930,6 @@ async def count_style_learning_patterns(self) -> int: """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.learning import StyleLearningPattern stmt = select(func.count()).select_from(StyleLearningPattern) result = await session.execute(stmt) @@ -981,8 +946,6 @@ async def count_refined_messages(self) -> int: """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.message import FilteredMessage stmt = select(func.count()).select_from(FilteredMessage) result = await session.execute(stmt) @@ -999,8 +962,6 @@ async def get_style_learning_statistics(self) -> Dict[str, Any]: """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.learning import StyleLearningReview total_stmt = select(func.count()).select_from(StyleLearningReview) total_result = await session.execute(total_stmt) @@ -1048,8 +1009,6 @@ async def get_style_progress_data( """ try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import LearningBatch stmt = ( select(LearningBatch) @@ -1097,8 +1056,6 @@ async def get_learning_patterns_data( """ try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.learning import StyleLearningPattern stmt = select( StyleLearningPattern.pattern_type, diff --git a/services/database/facades/message_facade.py b/services/database/facades/message_facade.py index 878f6e8..9c914bf 100644 --- a/services/database/facades/message_facade.py +++ b/services/database/facades/message_facade.py @@ -10,6 +10,9 @@ from ....repositories.raw_message_repository import RawMessageRepository from ....repositories.filtered_message_repository import FilteredMessageRepository from ....repositories.bot_message_repository import BotMessageRepository +from sqlalchemy import and_, desc, distinct, func, select +from ....models.orm.message import BotMessage, FilteredMessage, RawMessage +from ....models.orm.social_relation import SocialRelation class MessageFacade(BaseFacade): @@ -28,7 +31,6 @@ async def save_raw_message(self, message_data) -> int: """ try: async with self.get_session() as session: - from ....models.orm.message import RawMessage if hasattr(message_data, '__dict__'): data = message_data.__dict__ @@ -146,8 +148,6 @@ async def get_messages_for_replay( """获取用于记忆重放的消息""" try: async with self.get_session() as session: - from sqlalchemy import select, desc, and_ - from ....models.orm.message import RawMessage cutoff_time = time.time() - (days * 24 * 3600) stmt = ( @@ -272,8 +272,6 @@ async def get_message_statistics( try: async with self.get_session() as session: - from sqlalchemy import select, func, and_ - from ....models.orm.message import RawMessage, FilteredMessage total_stmt = select(func.count()).select_from(RawMessage).where( RawMessage.group_id == group_id @@ -308,8 +306,6 @@ async def get_messages_statistics(self) -> Dict[str, Any]: """获取全局消息统计""" try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.message import RawMessage, FilteredMessage, BotMessage raw_count = (await session.execute( select(func.count()).select_from(RawMessage) @@ -367,8 +363,6 @@ async def get_groups_for_social_analysis(self) -> List[Dict[str, Any]]: """ try: async with self.get_session() as session: - from sqlalchemy import select, func, distinct - from ....models.orm.message import RawMessage # 查询 1: 从 RawMessage 获取群组列表、消息数、成员数 msg_stmt = ( @@ -393,7 +387,6 @@ async def get_groups_for_social_analysis(self) -> List[Dict[str, Any]]: # 查询 2: 从 SocialRelation 获取每个群组的关系数(可选) if groups: try: - from ....models.orm.social_relation import SocialRelation rel_stmt = ( select( SocialRelation.group_id, diff --git a/services/database/facades/metrics_facade.py b/services/database/facades/metrics_facade.py index 497262f..1615d60 100644 --- a/services/database/facades/metrics_facade.py +++ b/services/database/facades/metrics_facade.py @@ -7,6 +7,14 @@ from astrbot.api import logger from ._base import BaseFacade +from sqlalchemy import and_, func, select +from ....models.orm.learning import ( + LearningBatch, + PersonaLearningReview, + StyleLearningPattern, + StyleLearningReview, +) +from ....models.orm.message import BotMessage, FilteredMessage, RawMessage class MetricsFacade(BaseFacade): @@ -16,9 +24,6 @@ async def get_group_statistics(self, group_id: str = None) -> Dict[str, Any]: """获取群组综合统计数据""" try: async with self.get_session() as session: - from sqlalchemy import select, func, and_ - from ....models.orm.message import RawMessage, FilteredMessage - from ....models.orm.learning import PersonaLearningReview, StyleLearningReview # 原始消息数 raw_stmt = select(func.count()).select_from(RawMessage) @@ -63,12 +68,6 @@ async def get_detailed_metrics(self, group_id: str = None) -> Dict[str, Any]: """获取详细指标""" try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.message import RawMessage, FilteredMessage, BotMessage - from ....models.orm.learning import ( - PersonaLearningReview, StyleLearningReview, - LearningBatch, StyleLearningPattern - ) async def _count(model, group_col=None): stmt = select(func.count()).select_from(model) @@ -111,9 +110,6 @@ async def get_trends_data(self) -> Dict[str, Any]: """获取趋势数据""" try: async with self.get_session() as session: - from sqlalchemy import select, func - from ....models.orm.message import RawMessage - from ....models.orm.learning import LearningBatch # 过去7天每天的消息数 cutoff = int(time.time()) - (7 * 24 * 3600) diff --git a/services/database/facades/persona_facade.py b/services/database/facades/persona_facade.py index a46152d..d0b3072 100644 --- a/services/database/facades/persona_facade.py +++ b/services/database/facades/persona_facade.py @@ -9,6 +9,9 @@ from ._base import BaseFacade from ....repositories.persona_backup_repository import PersonaBackupRepository +from sqlalchemy import desc, select +from ....models.orm.learning import PersonaLearningReview +from ....models.orm.psychological import PersonaBackup class PersonaFacade(BaseFacade): @@ -18,7 +21,6 @@ async def backup_persona(self, backup_data: Dict[str, Any]) -> bool: """创建人格备份""" try: async with self.get_session() as session: - from ....models.orm.psychological import PersonaBackup now = time.time() backup = PersonaBackup( @@ -31,6 +33,7 @@ async def backup_persona(self, backup_data: Dict[str, Any]) -> bool: imitation_dialogues=json.dumps(backup_data.get('imitation_dialogues', []), ensure_ascii=False), backup_reason=backup_data.get('backup_reason', ''), backup_time=now, + persona_content=backup_data.get('persona_content', ''), ) session.add(backup) await session.commit() @@ -88,8 +91,6 @@ async def get_persona_update_history( """获取人格更新历史""" try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.learning import PersonaLearningReview stmt = select(PersonaLearningReview).order_by( desc(PersonaLearningReview.timestamp) diff --git a/services/database/facades/psychological_facade.py b/services/database/facades/psychological_facade.py index d2e4464..450048b 100644 --- a/services/database/facades/psychological_facade.py +++ b/services/database/facades/psychological_facade.py @@ -9,6 +9,8 @@ from ._base import BaseFacade from ....repositories.emotion_profile_repository import EmotionProfileRepository +from sqlalchemy import and_, select +from ....models.orm.psychological import EmotionProfile class PsychologicalFacade(BaseFacade): @@ -43,8 +45,6 @@ async def save_emotion_profile( """保存情绪画像(upsert)""" try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.psychological import EmotionProfile stmt = select(EmotionProfile).where( and_(EmotionProfile.user_id == user_id, EmotionProfile.group_id == group_id) diff --git a/services/database/facades/reinforcement_facade.py b/services/database/facades/reinforcement_facade.py index 75a0753..c6f499f 100644 --- a/services/database/facades/reinforcement_facade.py +++ b/services/database/facades/reinforcement_facade.py @@ -12,6 +12,8 @@ PersonaFusionRepository, StrategyOptimizationRepository, ) +from sqlalchemy import desc, select +from ....models.orm.performance import LearningPerformanceHistory class ReinforcementFacade(BaseFacade): @@ -23,8 +25,6 @@ async def get_learning_history_for_reinforcement( """获取用于强化学习的历史数据""" try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.performance import LearningPerformanceHistory stmt = ( select(LearningPerformanceHistory) @@ -90,8 +90,6 @@ async def get_learning_performance_history( """获取学习性能历史""" try: async with self.get_session() as session: - from sqlalchemy import select, desc - from ....models.orm.performance import LearningPerformanceHistory stmt = ( select(LearningPerformanceHistory) diff --git a/services/database/facades/social_facade.py b/services/database/facades/social_facade.py index 6eabb13..80720c2 100644 --- a/services/database/facades/social_facade.py +++ b/services/database/facades/social_facade.py @@ -10,6 +10,14 @@ from ._base import BaseFacade from ....repositories.user_profile_repository import UserProfileRepository from ....repositories.user_preferences_repository import UserPreferencesRepository +from sqlalchemy import and_, or_, select +from ....models.orm.social_relation import ( + UserPreferences, + UserProfile, + UserSocialProfile, + UserSocialRelationComponent, +) +from ....repositories.social_repository import SocialRelationComponentRepository class SocialFacade(BaseFacade): @@ -21,7 +29,6 @@ async def load_user_profile(self, qq_id: str) -> Optional[Dict[str, Any]]: """加载用户画像""" try: async with self.get_session() as session: - from ....models.orm.social_relation import UserProfile profile = await session.get(UserProfile, qq_id) if not profile: return None @@ -43,7 +50,6 @@ async def save_user_profile(self, qq_id: str, profile_data: Dict[str, Any]) -> b """保存用户画像(upsert)""" try: async with self.get_session() as session: - from ....models.orm.social_relation import UserProfile profile = await session.get(UserProfile, qq_id) if profile: profile.qq_name = profile_data.get('qq_name', profile.qq_name) @@ -79,8 +85,6 @@ async def load_user_preferences( """加载用户偏好""" try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.social_relation import UserPreferences stmt = select(UserPreferences).where( and_(UserPreferences.user_id == user_id, UserPreferences.group_id == group_id) ) @@ -106,8 +110,6 @@ async def save_user_preferences( """保存用户偏好(upsert)""" try: async with self.get_session() as session: - from sqlalchemy import select, and_ - from ....models.orm.social_relation import UserPreferences stmt = select(UserPreferences).where( and_(UserPreferences.user_id == user_id, UserPreferences.group_id == group_id) ) @@ -146,8 +148,6 @@ async def get_social_relations_by_group(self, group_id: str) -> List[Dict[str, A """ try: async with self.get_session() as session: - from sqlalchemy import select - from ....models.orm.social_relation import UserSocialRelationComponent stmt = select(UserSocialRelationComponent).where( UserSocialRelationComponent.group_id == group_id @@ -188,14 +188,7 @@ async def save_social_relation( """ try: async with self.get_session() as session: - from ....models.orm.social_relation import ( - UserSocialRelationComponent, - UserSocialProfile, - ) - from sqlalchemy import select - import time as _time - - now = int(_time.time()) + now = int(time.time()) from_user = relation_data.get('from_user', relation_data.get('from_user_id', '')) # 获取或创建 from_user 的 profile 以满足外键约束 @@ -245,10 +238,7 @@ async def get_user_social_relations( """获取用户的社交关系""" try: async with self.get_session() as session: - from ....repositories.social_repository import SocialRelationComponentRepository repo = SocialRelationComponentRepository(session) - from sqlalchemy import select, or_ - from ....models.orm.social_relation import UserSocialRelationComponent stmt = select(UserSocialRelationComponent).where( UserSocialRelationComponent.group_id == group_id, diff --git a/services/database/sqlalchemy_database_manager.py b/services/database/sqlalchemy_database_manager.py index 3068b87..99ec3a6 100644 --- a/services/database/sqlalchemy_database_manager.py +++ b/services/database/sqlalchemy_database_manager.py @@ -73,7 +73,7 @@ async def start(self) -> bool: self.engine = DatabaseEngine(db_url, echo=False) logger.info("[DomainRouter] 数据库引擎已创建") - await self.engine.create_tables() + await self.engine.create_tables(enable_auto_migration=True) if await self.engine.health_check(): self._init_facades() diff --git a/services/learning/group_orchestrator.py b/services/learning/group_orchestrator.py index 7dcac1c..0bcc084 100644 --- a/services/learning/group_orchestrator.py +++ b/services/learning/group_orchestrator.py @@ -216,14 +216,25 @@ def _apply_filter(stmt): async def cancel_all(self) -> None: """Cancel all running learning tasks (called during shutdown).""" + # Signal all groups to stop first (non-blocking) + try: + await asyncio.wait_for( + self._progressive_learning.stop_learning(), + timeout=3, + ) + except (asyncio.TimeoutError, Exception) as e: + logger.warning(f"stop_learning 超时或失败: {e}") + + # Cancel and wait for each task with individual timeouts for group_id, task in list(self.learning_tasks.items()): try: - await self._progressive_learning.stop_learning() if not task.done(): task.cancel() try: - await task - except asyncio.CancelledError: + await asyncio.wait_for( + asyncio.shield(task), timeout=2, + ) + except (asyncio.CancelledError, asyncio.TimeoutError): pass logger.info(f"群组 {group_id} 学习任务已停止") except Exception as e: diff --git a/services/state/affection_manager.py b/services/state/affection_manager.py index 69b8c78..902e8e3 100644 --- a/services/state/affection_manager.py +++ b/services/state/affection_manager.py @@ -131,19 +131,27 @@ async def _do_start(self) -> bool: # 为所有活跃群组设置初始随机情绪(如果启用) if self.config.enable_startup_random_mood: await self._initialize_random_moods_for_active_groups() - + # 启动每日情绪更新任务 if self.config.enable_daily_mood: - asyncio.create_task(self._daily_mood_updater()) - + self._mood_task = asyncio.create_task(self._daily_mood_updater()) + self._logger.info("好感度管理服务启动成功") return True except Exception as e: self._logger.error(f"好感度管理服务启动失败: {e}") return False - + async def _do_stop(self) -> bool: """停止好感度管理服务""" + # 取消后台任务 + task = getattr(self, '_mood_task', None) + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # 保存当前状态 await self._save_current_state() return True @@ -993,21 +1001,15 @@ async def get_mood_influenced_system_prompt(self, group_id: str, base_prompt: st async def _daily_mood_updater(self): """每日情绪更新任务""" - while True: - try: + try: + while True: current_hour = datetime.now().hour if current_hour == self.config.mood_change_hour: - # 为所有活跃群组更新情绪 - # 这里需要获取活跃群组列表,简化实现暂时跳过 - # await self._update_all_group_moods() pass - - # 每小时检查一次 - await asyncio.sleep(3600) - - except Exception as e: - self._logger.error(f"每日情绪更新失败: {e}") + await asyncio.sleep(3600) + except asyncio.CancelledError: + self._logger.debug("每日情绪更新任务已取消") async def _save_current_state(self): """保存当前状态到数据库""" diff --git a/services/state/enhanced_interaction.py b/services/state/enhanced_interaction.py index 3d95d34..746ba1c 100644 --- a/services/state/enhanced_interaction.py +++ b/services/state/enhanced_interaction.py @@ -72,19 +72,28 @@ async def _do_start(self) -> bool: try: await self._load_cross_group_memories() await self._load_group_interests() - - # 启动定期任务 - asyncio.create_task(self._periodic_memory_sync()) - asyncio.create_task(self._periodic_context_cleanup()) - + + # 启动定期任务(保留引用以便 stop 时取消) + self._sync_task = asyncio.create_task(self._periodic_memory_sync()) + self._cleanup_task = asyncio.create_task(self._periodic_context_cleanup()) + self._logger.info("增强交互服务启动成功") return True except Exception as e: self._logger.error(f"增强交互服务启动失败: {e}") return False - + async def _do_stop(self) -> bool: """停止增强交互服务""" + # 取消后台任务 + for task in (getattr(self, '_sync_task', None), + getattr(self, '_cleanup_task', None)): + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass await self._save_cross_group_memories() await self._save_group_interests() return True @@ -448,31 +457,30 @@ async def _select_engaging_topic(self, group_id: str, interests: Dict) -> Option async def _periodic_memory_sync(self): """定期同步跨群记忆""" - while True: - try: + try: + while True: await asyncio.sleep(self.memory_sync_interval) await self._save_cross_group_memories() - except Exception as e: - self._logger.error(f"记忆同步失败: {e}") - + except asyncio.CancelledError: + self._logger.debug("记忆同步任务已取消") + async def _periodic_context_cleanup(self): """定期清理过期上下文""" - while True: - try: + try: + while True: await asyncio.sleep(600) # 10分钟清理一次 current_time = time.time() - + expired_contexts = [ group_id for group_id, context in self.conversation_contexts.items() if current_time - context.last_activity > self.context_retention_time ] - + for group_id in expired_contexts: del self.conversation_contexts[group_id] self._logger.debug(f"清理过期对话上下文: {group_id}") - - except Exception as e: - self._logger.error(f"上下文清理失败: {e}") + except asyncio.CancelledError: + self._logger.debug("上下文清理任务已取消") async def _load_cross_group_memories(self): """加载跨群记忆数据""" diff --git a/webui/manager.py b/webui/manager.py index 22c1fc8..ae67ed0 100644 --- a/webui/manager.py +++ b/webui/manager.py @@ -156,7 +156,22 @@ async def stop(self) -> None: """有序停止 WebUI 服务器""" global _server_instance, _server_cleanup_lock - async with _server_cleanup_lock: + try: + await asyncio.wait_for( + _server_cleanup_lock.acquire(), timeout=3.0, + ) + except asyncio.TimeoutError: + logger.warning("[WebUI] 获取清理锁超时,强制继续清理") + # 拿不到锁也要继续清理 + if _server_instance: + try: + await _server_instance.stop() + except Exception: + pass + _server_instance = None + return + + try: if not _server_instance: return try: @@ -173,6 +188,8 @@ async def stop(self) -> None: except Exception as e: logger.error(f"停止 Web 服务器失败: {e}", exc_info=True) _server_instance = None + finally: + _server_cleanup_lock.release() # 内部方法 diff --git a/webui/server.py b/webui/server.py index 4f277e8..cd0b107 100644 --- a/webui/server.py +++ b/webui/server.py @@ -171,9 +171,18 @@ async def stop(self): except Exception: pass - # 等待线程退出 + # 在线程池中等待线程退出,避免阻塞事件循环 if self.server_thread: - self.server_thread.join(timeout=5.0) + loop = asyncio.get_event_loop() + try: + await asyncio.wait_for( + loop.run_in_executor( + None, self.server_thread.join, 5.0, + ), + timeout=6.0, + ) + except asyncio.TimeoutError: + logger.warning("[WebUI] 服务器线程退出超时,强制继续") self.server_thread = None self._thread_loop = None