diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c52cb5..b052273 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,50 @@ 所有重要更改都将记录在此文件中。 +## [Next-2.0.6] - 2026-03-02 + +### 新功能 + +#### WebUI 监听地址可配置 +- 新增 `web_interface_host` 配置项,允许用户自定义 WebUI 监听地址(默认 `0.0.0.0`) + +### Bug 修复 + +#### SQLite 并发访问错误 +- 修复 SQLite 连接池使用 `StaticPool`(共享单连接)导致 WebUI 并发请求事务状态污染的问题 +- 改为 `NullPool`,每个会话获取独立连接,消除 "Cannot operate on a closed database" 错误 + +#### 插件卸载 CPU 100% +- 移除 WebUI 关停流程中 `server.py` 和 `manager.py` 的两次 `gc.collect()` 调用 +- 每次 `gc.collect()` 遍历 ~200 个模块的对象图耗时 80+ 秒,导致卸载期间 CPU 满载 + +#### 命令处理器空指针 +- 为全部 6 个管理命令(`learning_status`、`start_learning`、`stop_learning`、`force_learning`、`affection_status`、`set_mood`)添加空值守卫 +- 当 `bootstrap()` 失败导致 `_command_handlers` 为 `None` 时,返回友好提示而非抛出 `'NoneType' object has no attribute` 异常 + +#### 人格审查系统 +- 修复撤回操作崩溃和已审查列表数据缺失问题 +- 修复风格学习审查记录在已审查历史中显示空内容、类型"未知"、置信度 0.0% 的问题,补全 `StyleLearningReview` 到前端统一格式的字段映射 +- WebUI 风格统计查询改用 Facade 而非直接 Repository 调用 + +#### MySQL 8 连接 +- 禁用 MySQL 8 默认 SSL 要求,解决 `ssl.SSLError` 连接失败 +- 强化会话生命周期管理 + +#### ORM 字段映射 +- 修正心理状态和情绪持久化的 ORM 字段映射 +- 使用防御性 `getattr` 处理 ORM-to-dataclass 组件映射中的缺失属性 + +#### 其他修复 +- WebUI 使用全局默认人格代替随机 UMO +- WebUI 响应速度指标无 LLM 数据时使用中性回退值 +- 黑话 meaning 字段 dict/list 类型序列化为 JSON 字符串后写入数据库 +- 批量学习路径中正确保存筛选后的消息到数据库 +- 防护 `background_tasks` 在关停序列中的访问安全 + +### 测试 +- 新增核心模块单元测试,扩展覆盖率配置 + ## [Next-2.0.5] - 2026-02-24 ### Bug 修复 diff --git a/README.md b/README.md index 5464e80..6af5a80 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-Next--2.0.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-Next--2.0.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [核心功能](#-我们能做什么) · [快速开始](#-快速开始) · [管理界面](#-可视化管理界面) · [社区交流](#-社区交流) · [贡献指南](CONTRIBUTING.md) @@ -229,10 +229,18 @@ http://localhost:7833 --- +--- +
**如果觉得有帮助,欢迎 Star 支持!** +### 赞助支持 + +如果这个项目对你有帮助,欢迎通过爱发电赞助支持开发者持续维护: + +爱发电赞助二维码 + [回到顶部](#astrbot-自主学习插件)
diff --git a/README_EN.md b/README_EN.md index 250354c..1a993ec 100644 --- a/README_EN.md +++ b/README_EN.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-Next--2.0.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-Next--2.0.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [Features](#what-we-can-do) · [Quick Start](#quick-start) · [Web UI](#visual-management-interface) · [Community](#community) · [Contributing](CONTRIBUTING.md) diff --git a/core/database/engine.py b/core/database/engine.py index 0208aec..a2a99f2 100644 --- a/core/database/engine.py +++ b/core/database/engine.py @@ -8,7 +8,7 @@ 避免 "Task got Future attached to a different loop" 错误。 """ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.pool import StaticPool +from sqlalchemy.pool import NullPool from sqlalchemy import types as sa_types from astrbot.api import logger from typing import Optional @@ -90,11 +90,13 @@ def _create_sqlite_engine(self): logger.info(f"[DatabaseEngine] 创建数据库目录: {db_dir}") # SQLite 配置 - # StaticPool reuses a single connection, avoiding per-query overhead + # NullPool: 每个 session 独立创建/关闭连接,避免 StaticPool + # 单连接共享导致的并发事务状态污染和 "closed database" 错误。 + # SQLite 建连成本极低(打开文件句柄),配合 WAL 模式可安全并发读。 engine = create_async_engine( db_url, echo=self.echo, - poolclass=StaticPool, + poolclass=NullPool, connect_args={ 'check_same_thread': False, 'timeout': 30, @@ -138,6 +140,7 @@ def _create_mysql_engine(self): connect_args={ 'connect_timeout': 10, 'charset': 'utf8mb4', + 'ssl': False, } ) diff --git a/core/framework_llm_adapter.py b/core/framework_llm_adapter.py index d5ffe2a..c5f8df7 100644 --- a/core/framework_llm_adapter.py +++ b/core/framework_llm_adapter.py @@ -243,6 +243,10 @@ async def filter_chat_completion( # 尝试延迟初始化 self._try_lazy_init() + # 确保 contexts 不为 None,避免 Provider 内部调用 len(None) + if contexts is None: + contexts = [] + if not self.filter_provider: logger.warning("筛选Provider未配置,尝试使用备选Provider或降级处理") # 尝试使用其他可用的Provider作为备选 @@ -301,6 +305,10 @@ async def refine_chat_completion( # 尝试延迟初始化 self._try_lazy_init() + # 确保 contexts 不为 None,避免 Provider 内部调用 len(None) + if contexts is None: + contexts = [] + if not self.refine_provider: logger.warning("提炼Provider未配置,尝试使用备选Provider或降级处理") # 尝试使用其他可用的Provider作为备选 @@ -359,6 +367,10 @@ async def reinforce_chat_completion( # 尝试延迟初始化 self._try_lazy_init() + # 确保 contexts 不为 None,避免 Provider 内部调用 len(None) + if contexts is None: + contexts = [] + if not self.reinforce_provider: logger.warning("强化Provider未配置,尝试使用备选Provider或降级处理") # 尝试使用其他可用的Provider作为备选 diff --git a/image/afdian-NickMo.jpeg b/image/afdian-NickMo.jpeg new file mode 100644 index 0000000..3152387 Binary files /dev/null and b/image/afdian-NickMo.jpeg differ diff --git a/main.py b/main.py index de11aa2..3fe3071 100644 --- a/main.py +++ b/main.py @@ -231,6 +231,9 @@ async def inject_diversity_to_llm_request(self, event: AstrMessageEvent, req=Non @filter.permission_type(PermissionType.ADMIN) async def learning_status_command(self, event: AstrMessageEvent): """查看学习状态""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.learning_status(event): yield result @@ -238,6 +241,9 @@ async def learning_status_command(self, event: AstrMessageEvent): @filter.permission_type(PermissionType.ADMIN) async def start_learning_command(self, event: AstrMessageEvent): """手动启动学习""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.start_learning(event): yield result @@ -245,6 +251,9 @@ async def start_learning_command(self, event: AstrMessageEvent): @filter.permission_type(PermissionType.ADMIN) async def stop_learning_command(self, event: AstrMessageEvent): """停止学习""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.stop_learning(event): yield result @@ -252,6 +261,9 @@ async def stop_learning_command(self, event: AstrMessageEvent): @filter.permission_type(PermissionType.ADMIN) async def force_learning_command(self, event: AstrMessageEvent): """强制执行一次学习周期""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.force_learning(event): yield result @@ -259,6 +271,9 @@ async def force_learning_command(self, event: AstrMessageEvent): @filter.permission_type(PermissionType.ADMIN) async def affection_status_command(self, event: AstrMessageEvent): """查看好感度状态""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.affection_status(event): yield result @@ -266,5 +281,8 @@ async def affection_status_command(self, event: AstrMessageEvent): @filter.permission_type(PermissionType.ADMIN) async def set_mood_command(self, event: AstrMessageEvent): """手动设置bot情绪""" + if not self._command_handlers: + yield event.plain_result("插件服务未就绪,请检查启动日志") + return async for result in self._command_handlers.set_mood(event): yield result diff --git a/metadata.yaml b/metadata.yaml index 5634dc3..9070d2d 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning" author: "NickMo" display_name: "self-learning" description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)" -version: "Next-2.0.5" +version: "Next-2.0.6" repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning" tags: - "自学习" diff --git a/persona_web_manager.py b/persona_web_manager.py index 010ccd3..3930853 100644 --- a/persona_web_manager.py +++ b/persona_web_manager.py @@ -143,8 +143,9 @@ async def get_all_personas_for_web(self) -> List[Dict[str, Any]]: async def get_default_persona_for_web(self) -> Dict[str, Any]: """获取默认人格,格式化为Web界面需要的格式 - 使用 group_id_to_unified_origin 映射中的 UMO 来获取当前活跃配置的人格, - 而非始终返回 default 配置的人格。 + 始终传入 None 调用 get_default_persona_v3 以获取 AstrBot 全局默认人格, + 避免在多配置文件场景下因随机选取 UMO 而导致每次返回不同人格。 + 如需查看特定配置的人格,应通过 get_persona_for_group 并明确指定 group_id。 """ fallback = { "persona_id": "default", @@ -157,14 +158,9 @@ async def get_default_persona_for_web(self) -> Dict[str, Any]: return fallback try: - # 尝试从映射中获取一个 UMO,以加载当前活跃配置的人格 - umo = None - if self.group_id_to_unified_origin: - # 取任意一个 UMO(通常同一配置文件下的群组共享同一配置) - umo = next(iter(self.group_id_to_unified_origin.values()), None) - + # 获取全局默认人格,不依赖 group_id_to_unified_origin 映射 default_persona = await self._run_on_main_loop( - self.persona_manager.get_default_persona_v3(umo) + self.persona_manager.get_default_persona_v3(None) ) if default_persona: diff --git a/pytest.ini b/pytest.ini index f51256a..d4008a1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -23,6 +23,12 @@ addopts = # Strict markers (only registered markers allowed) --strict-markers # Coverage options + --cov=config + --cov=constants + --cov=exceptions + --cov=core + --cov=utils + --cov=services --cov=webui --cov-report=html --cov-report=term-missing @@ -40,7 +46,10 @@ markers = auth: Authentication related tests service: Service layer tests blueprint: Blueprint/route tests - security: Security-related tests + core: Core module tests + config: Configuration tests + utils: Utility module tests + quality: Quality monitoring tests # Log output log_cli = true @@ -53,7 +62,14 @@ asyncio_mode = auto # Coverage options [coverage:run] -source = webui +source = + config + constants + exceptions + core + utils + services + webui omit = */tests/* */test_*.py diff --git a/repositories/bot_mood_repository.py b/repositories/bot_mood_repository.py index b8f9d99..7048b1e 100644 --- a/repositories/bot_mood_repository.py +++ b/repositories/bot_mood_repository.py @@ -55,7 +55,14 @@ async def save(self, mood_data: Dict[str, Any]) -> Optional[BotMood]: mood = BotMood(**mood_data) self.session.add(mood) await self.session.commit() - await self.session.refresh(mood) + try: + await self.session.refresh(mood) + except Exception as refresh_err: + # 部分 async 驱动在 commit 后 refresh 可能失败, + # 此时 mood 对象已持久化且 ID 已赋值,可安全返回 + logger.debug( + f"[BotMoodRepository] refresh after commit skipped: {refresh_err}" + ) return mood except Exception as e: await self.session.rollback() diff --git a/repositories/psychological_repository.py b/repositories/psychological_repository.py index ccc74f3..96b07ec 100644 --- a/repositories/psychological_repository.py +++ b/repositories/psychological_repository.py @@ -98,51 +98,57 @@ async def get_components( 获取状态的所有组件 Args: - state_id: 状态 ID + state_id: 复合心理状态记录的主键 ID Returns: List[PsychologicalStateComponent]: 组件列表 """ - return await self.find_many(state_id=state_id) + return await self.find_many(composite_state_id=state_id) async def update_component( self, state_id: int, component_name: str, value: float, - threshold: float = None + threshold: float = None, + group_id: str = "", + state_id_str: str = "" ) -> Optional[PsychologicalStateComponent]: """ 更新组件值 Args: - state_id: 状态 ID - component_name: 组件名称 + state_id: 复合心理状态记录的主键 ID + component_name: 组件类别名称(category) value: 组件值 threshold: 阈值 + group_id: 群组 ID (创建新组件时使用) + state_id_str: 状态标识符 (创建新组件时使用) Returns: Optional[PsychologicalStateComponent]: 组件对象 """ component = await self.find_one( - state_id=state_id, - component_name=component_name + composite_state_id=state_id, + category=component_name ) if component: component.value = value if threshold is not None: component.threshold = threshold - component.updated_at = int(time.time()) return await self.update(component) else: + now = int(time.time()) return await self.create( - state_id=state_id, - component_name=component_name, + composite_state_id=state_id, + group_id=group_id, + state_id=state_id_str or f"{group_id}:{component_name}", + category=component_name, + state_type=component_name, value=value, - threshold=threshold or 0.5, - created_at=int(time.time()), - updated_at=int(time.time()) + threshold=threshold or 0.3, + start_time=now, ) @@ -158,7 +164,9 @@ async def add_history( from_state: str, to_state: str, trigger_event: str = None, - intensity_change: float = 0.0 + intensity_change: float = 0.0, + group_id: str = "", + category: str = "" ) -> Optional[PsychologicalStateHistory]: """ 添加历史记录 @@ -169,16 +177,21 @@ async def add_history( to_state: 结束状态 trigger_event: 触发事件 intensity_change: 强度变化 + group_id: 群组 ID + category: 状态类别 Returns: Optional[PsychologicalStateHistory]: 历史记录 """ return await self.create( - state_id=state_id, - from_state=from_state, - to_state=to_state, - trigger_event=trigger_event, - intensity_change=intensity_change, + group_id=group_id, + state_id=str(state_id), + category=category or "unknown", + old_state_type=from_state, + new_state_type=to_state or "", + old_value=0.0, + new_value=intensity_change, + change_reason=trigger_event, timestamp=int(time.time()) ) diff --git a/services/commands/handlers.py b/services/commands/handlers.py index 102d52a..19b3a54 100644 --- a/services/commands/handlers.py +++ b/services/commands/handlers.py @@ -257,6 +257,10 @@ async def affection_status(self, event: Any) -> AsyncGenerator: yield event.plain_result(CommandMessages.AFFECTION_DISABLED) return + if not self._affection_manager: + yield event.plain_result("好感度管理器未初始化,请检查启动日志") + return + affection_status = await self._affection_manager.get_affection_status(group_id) current_mood = None @@ -324,6 +328,14 @@ async def set_mood(self, event: Any) -> AsyncGenerator: yield event.plain_result(CommandMessages.AFFECTION_DISABLED) return + if not self._temporary_persona_updater: + yield event.plain_result("临时人格更新器未初始化,无法设置情绪") + return + + if not self._affection_manager: + yield event.plain_result("好感度管理器未初始化,无法设置情绪") + return + args = event.get_message_str().split()[1:] if len(args) < 1: yield event.plain_result( diff --git a/services/core_learning/progressive_learning.py b/services/core_learning/progressive_learning.py index 4851c82..951ac4c 100644 --- a/services/core_learning/progressive_learning.py +++ b/services/core_learning/progressive_learning.py @@ -250,7 +250,26 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals logger.debug("没有通过筛选的消息") await self._mark_messages_processed(unprocessed_messages) return - + + # 2.5 将筛选后的消息写入 FilteredMessage 表(供 WebUI 统计) + saved_count = 0 + for msg in filtered_messages: + try: + await self.message_collector.add_filtered_message({ + "raw_message_id": msg.get("id"), + "message": msg.get("message", ""), + "sender_id": msg.get("sender_id", ""), + "group_id": msg.get("group_id", group_id), + "timestamp": msg.get("timestamp", int(time.time())), + "confidence": msg.get("relevance_score", 1.0), + "filter_reason": msg.get("filter_reason", "batch_learning"), + }) + saved_count += 1 + except Exception: + pass # best-effort, don't block learning + if saved_count: + logger.debug(f"已保存 {saved_count}/{len(filtered_messages)} 条筛选消息到 FilteredMessage 表") + # 3. 获取当前人格设置 (针对特定群组) current_persona = await self._get_current_persona(group_id) diff --git a/services/database/facades/_base.py b/services/database/facades/_base.py index c877aa2..19aea50 100644 --- a/services/database/facades/_base.py +++ b/services/database/facades/_base.py @@ -26,13 +26,20 @@ def __init__(self, engine: DatabaseEngine, config: PluginConfig): @asynccontextmanager async def get_session(self): - """获取异步数据库会话(上下文管理器)""" + """获取异步数据库会话(上下文管理器) + + 自动处理会话的创建、提交和回滚。 + 使用 ``async with session`` 确保事务完整性, + 不再在 finally 中重复调用 close 以避免连接状态异常。 + """ + if self.engine is None or self.engine.engine is None: + raise RuntimeError("数据库引擎未初始化或已关闭") session = self.engine.get_session() try: async with session: yield session - finally: - await session.close() + except Exception: + raise @staticmethod def _row_to_dict(obj: Any, fields: Optional[List[str]] = None) -> Dict[str, Any]: diff --git a/services/database/facades/jargon_facade.py b/services/database/facades/jargon_facade.py index 7ca2282..a7f3961 100644 --- a/services/database/facades/jargon_facade.py +++ b/services/database/facades/jargon_facade.py @@ -152,7 +152,13 @@ async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool: if 'raw_content' in jargon_data: record.raw_content = jargon_data['raw_content'] if 'meaning' in jargon_data: - record.meaning = jargon_data['meaning'] + meaning_val = jargon_data['meaning'] + if isinstance(meaning_val, dict): + record.meaning = json.dumps(meaning_val, ensure_ascii=False) + elif isinstance(meaning_val, list): + record.meaning = json.dumps(meaning_val, ensure_ascii=False) + else: + record.meaning = str(meaning_val) if meaning_val is not None else None if 'is_jargon' in jargon_data: record.is_jargon = jargon_data['is_jargon'] if 'count' in jargon_data: @@ -743,8 +749,12 @@ async def get_jargon_groups(self) -> List[Dict]: groups = [] for row in rows: try: + chat_id = row.chat_id or '' groups.append({ - 'chat_id': row.chat_id, + 'group_id': chat_id, + 'group_name': chat_id, + 'id': chat_id, + 'chat_id': chat_id, 'count': row.count or 0 }) except Exception as row_error: diff --git a/services/database/sqlalchemy_database_manager.py b/services/database/sqlalchemy_database_manager.py index 99ec3a6..586d30f 100644 --- a/services/database/sqlalchemy_database_manager.py +++ b/services/database/sqlalchemy_database_manager.py @@ -150,37 +150,53 @@ def _get_database_url(self) -> str: return f"sqlite:///{db_path}" async def _ensure_mysql_database_exists(self): - """确保 MySQL 数据库存在""" - try: - import aiomysql - host = getattr(self.config, 'mysql_host', 'localhost') - port = getattr(self.config, 'mysql_port', 3306) - user = getattr(self.config, 'mysql_user', 'root') - password = getattr(self.config, 'mysql_password', '') - database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') + """确保 MySQL 数据库存在 + + 使用 aiomysql 直连 MySQL 服务器以检查/创建目标数据库。 + 显式禁用 SSL 以避免 MySQL 8 默认 TLS 握手导致的 + struct.unpack 解包异常。 + """ + import aiomysql + host = getattr(self.config, 'mysql_host', 'localhost') + port = getattr(self.config, 'mysql_port', 3306) + user = getattr(self.config, 'mysql_user', 'root') + password = getattr(self.config, 'mysql_password', '') + database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') - conn = await aiomysql.connect( - host=host, port=port, user=user, - password=password, charset='utf8mb4', + try: + conn = await asyncio.wait_for( + aiomysql.connect( + host=host, port=port, user=user, + password=password, charset='utf8mb4', + ssl=False, connect_timeout=10, + ), + timeout=15, ) - try: - async with conn.cursor() as cursor: + except asyncio.TimeoutError: + logger.error("[DomainRouter] 连接 MySQL 超时 (15s)") + raise + except Exception as e: + logger.error(f"[DomainRouter] 连接 MySQL 失败: {e}") + raise + + try: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s", + (database,), + ) + if not await cursor.fetchone(): + logger.info(f"[DomainRouter] 数据库 {database} 不存在,正在创建...") await cursor.execute( - "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s", - (database,), + f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" ) - if not await cursor.fetchone(): - logger.info(f"[DomainRouter] 数据库 {database} 不存在,正在创建…") - await cursor.execute( - f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" - ) - await conn.commit() - logger.info(f"[DomainRouter] 数据库 {database} 创建成功") - finally: - conn.close() + await conn.commit() + logger.info(f"[DomainRouter] 数据库 {database} 创建成功") except Exception as e: logger.error(f"[DomainRouter] 确保 MySQL 数据库存在失败: {e}") raise + finally: + conn.close() # Infrastructure: session @@ -206,8 +222,8 @@ async def get_session(self): try: async with session: yield session - finally: - await session.close() + except Exception: + raise # Domain delegates: AffectionFacade diff --git a/services/jargon/jargon_miner.py b/services/jargon/jargon_miner.py index 331a4b9..9f85733 100644 --- a/services/jargon/jargon_miner.py +++ b/services/jargon/jargon_miner.py @@ -116,7 +116,13 @@ async def infer_meaning( logger.info(f"黑话 {content} 信息不足,等待下次推断") return {'no_info': True} - meaning1 = inference1.get('meaning', '').strip() + meaning1_raw = inference1.get('meaning', '') + if isinstance(meaning1_raw, dict): + meaning1 = json.dumps(meaning1_raw, ensure_ascii=False) + elif isinstance(meaning1_raw, list): + meaning1 = json.dumps(meaning1_raw, ensure_ascii=False) + else: + meaning1 = str(meaning1_raw).strip() if meaning1_raw else '' if not meaning1: return {'no_info': True} @@ -153,9 +159,18 @@ async def infer_meaning( is_similar = comparison.get('is_similar', False) is_jargon = not is_similar + if is_jargon: + final_meaning = meaning1 + else: + meaning2_raw = inference2.get('meaning', '') + if isinstance(meaning2_raw, (dict, list)): + final_meaning = json.dumps(meaning2_raw, ensure_ascii=False) + else: + final_meaning = str(meaning2_raw).strip() if meaning2_raw else '' + return { 'is_jargon': is_jargon, - 'meaning': meaning1 if is_jargon else inference2.get('meaning', ''), + 'meaning': final_meaning, 'no_info': False } diff --git a/services/state/enhanced_psychological_state_manager.py b/services/state/enhanced_psychological_state_manager.py index af4ffc2..71f4842 100644 --- a/services/state/enhanced_psychological_state_manager.py +++ b/services/state/enhanced_psychological_state_manager.py @@ -196,15 +196,26 @@ async def get_current_state( # 获取组件 components = await component_repo.get_components(state.id) - # 转换为 CompositePsychologicalState + # 转换 ORM 组件对象为 dataclass 实例 + # 使用 getattr 进行防御性读取,防止 ORM 属性未加载时抛出 + # AttributeError(如 lazy-load 失败或 schema 不一致) state_components = [] for comp in components: - state_components.append(PsychologicalStateComponent( - dimension=comp.component_name, - state_type=comp.component_name, # TODO: 需要解析类型 - value=comp.value, - threshold=comp.threshold - )) + try: + state_components.append(PsychologicalStateComponent( + category=getattr(comp, 'category', 'unknown'), + state_type=getattr(comp, 'state_type', 'unknown'), + value=float(getattr(comp, 'value', 0.5)), + threshold=float(getattr(comp, 'threshold', 0.3)), + description=getattr(comp, 'description', '') or "", + start_time=float(getattr(comp, 'start_time', 0)) + if getattr(comp, 'start_time', None) else time.time() + )) + except Exception as comp_err: + self._logger.warning( + f"[增强型心理状态] 转换组件失败: {comp_err}, " + f"comp_type={type(comp).__name__}" + ) composite_state = CompositePsychologicalState( group_id=group_id, @@ -267,7 +278,9 @@ async def update_state( await component_repo.update_component( state.id, dimension, - new_value + new_value, + group_id=group_id, + state_id_str=f"{group_id}:{user_id}" ) # 记录历史 @@ -276,7 +289,9 @@ async def update_state( from_state=state.overall_state, to_state=str(new_state_type), trigger_event=trigger_event, - intensity_change=0.0 + intensity_change=0.0, + group_id=group_id, + category=dimension ) # 清除缓存 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_cache_manager.py b/tests/unit/test_cache_manager.py new file mode 100644 index 0000000..01b93f0 --- /dev/null +++ b/tests/unit/test_cache_manager.py @@ -0,0 +1,251 @@ +""" +Unit tests for CacheManager + +Tests the unified cache management system: +- Cache get/set/delete/clear operations +- Named cache isolation (affection, memory, state, etc.) +- Hit rate statistics tracking +- Cache stats reporting +- Global singleton management +- Unknown cache name handling +""" +import pytest +from unittest.mock import patch + +from utils.cache_manager import CacheManager, get_cache_manager, cached, async_cached + + +@pytest.mark.unit +@pytest.mark.utils +class TestCacheManagerOperations: + """Test basic CacheManager CRUD operations.""" + + def test_set_and_get(self): + """Test setting and getting a cache value.""" + mgr = CacheManager() + mgr.set("general", "key1", "value1") + + result = mgr.get("general", "key1") + assert result == "value1" + + def test_get_nonexistent_key(self): + """Test getting a nonexistent key returns None.""" + mgr = CacheManager() + + result = mgr.get("general", "nonexistent") + assert result is None + + def test_delete_existing_key(self): + """Test deleting an existing cache entry.""" + mgr = CacheManager() + mgr.set("general", "key1", "value1") + + mgr.delete("general", "key1") + + assert mgr.get("general", "key1") is None + + def test_delete_nonexistent_key(self): + """Test deleting a nonexistent key does not raise.""" + mgr = CacheManager() + mgr.delete("general", "nonexistent") # Should not raise + + def test_clear_specific_cache(self): + """Test clearing a specific named cache.""" + mgr = CacheManager() + mgr.set("affection", "k1", "v1") + mgr.set("affection", "k2", "v2") + + mgr.clear("affection") + + assert mgr.get("affection", "k1") is None + assert mgr.get("affection", "k2") is None + + def test_clear_all_caches(self): + """Test clearing all caches at once.""" + mgr = CacheManager() + mgr.set("affection", "k1", "v1") + mgr.set("memory", "k2", "v2") + mgr.set("general", "k3", "v3") + + mgr.clear_all() + + assert mgr.get("affection", "k1") is None + assert mgr.get("memory", "k2") is None + assert mgr.get("general", "k3") is None + + +@pytest.mark.unit +@pytest.mark.utils +class TestCacheManagerIsolation: + """Test cache name isolation between different caches.""" + + def test_different_caches_are_isolated(self): + """Test same key in different caches are independent.""" + mgr = CacheManager() + mgr.set("affection", "shared_key", "affection_value") + mgr.set("memory", "shared_key", "memory_value") + + assert mgr.get("affection", "shared_key") == "affection_value" + assert mgr.get("memory", "shared_key") == "memory_value" + + @pytest.mark.parametrize("cache_name", [ + "affection", "memory", "state", "relation", + "context", "embedding_query", + "conversation", "summary", "general", + ]) + def test_all_named_caches_accessible(self, cache_name): + """Test all named caches are accessible.""" + mgr = CacheManager() + mgr.set(cache_name, "test_key", "test_value") + + result = mgr.get(cache_name, "test_key") + assert result == "test_value" + + def test_unknown_cache_name_returns_none(self): + """Test accessing an unknown cache name returns None.""" + mgr = CacheManager() + + result = mgr.get("unknown_cache", "key1") + assert result is None + + def test_set_to_unknown_cache_does_not_raise(self): + """Test setting to an unknown cache does not raise.""" + mgr = CacheManager() + mgr.set("unknown_cache", "key1", "value1") # Should not raise + + +@pytest.mark.unit +@pytest.mark.utils +class TestCacheManagerStats: + """Test cache statistics and hit rate tracking.""" + + def test_hit_rate_empty(self): + """Test hit rates with no operations.""" + mgr = CacheManager() + + stats = mgr.get_hit_rates() + assert stats == {} + + def test_hit_rate_tracking(self): + """Test hit/miss tracking across operations.""" + mgr = CacheManager() + mgr.set("general", "key1", "value1") + + # Hit + mgr.get("general", "key1") + # Miss + mgr.get("general", "nonexistent") + + stats = mgr.get_hit_rates() + assert "general" in stats + assert stats["general"]["hits"] == 1 + assert stats["general"]["misses"] == 1 + assert stats["general"]["hit_rate"] == 0.5 + + def test_get_stats_for_cache(self): + """Test getting stats for a specific cache.""" + mgr = CacheManager() + mgr.set("affection", "k1", "v1") + + stats = mgr.get_stats("affection") + + assert "size" in stats + assert "maxsize" in stats + assert stats["size"] == 1 + + def test_get_stats_unknown_cache(self): + """Test getting stats for unknown cache returns empty dict.""" + mgr = CacheManager() + + stats = mgr.get_stats("unknown") + assert stats == {} + + +@pytest.mark.unit +@pytest.mark.utils +class TestCachedDecorator: + """Test the synchronous cached decorator.""" + + def test_cached_decorator_caches_result(self): + """Test cached decorator returns cached result on second call.""" + mgr = CacheManager() + call_count = 0 + + @cached(cache_name="general", key_func=lambda x: f"key_{x}", manager=mgr) + def expensive_func(x): + nonlocal call_count + call_count += 1 + return x * 2 + + result1 = expensive_func(5) + result2 = expensive_func(5) + + assert result1 == 10 + assert result2 == 10 + assert call_count == 1 # Only called once + + def test_cached_decorator_different_keys(self): + """Test cached decorator uses correct keys for different inputs.""" + mgr = CacheManager() + + @cached(cache_name="general", key_func=lambda x: f"key_{x}", manager=mgr) + def add_one(x): + return x + 1 + + assert add_one(1) == 2 + assert add_one(2) == 3 + + +@pytest.mark.unit +@pytest.mark.utils +class TestAsyncCachedDecorator: + """Test the asynchronous cached decorator.""" + + @pytest.mark.asyncio + async def test_async_cached_decorator(self): + """Test async cached decorator caches result.""" + mgr = CacheManager() + call_count = 0 + + @async_cached( + cache_name="general", + key_func=lambda x: f"async_key_{x}", + manager=mgr, + ) + async def async_expensive_func(x): + nonlocal call_count + call_count += 1 + return x * 3 + + result1 = await async_expensive_func(7) + result2 = await async_expensive_func(7) + + assert result1 == 21 + assert result2 == 21 + assert call_count == 1 + + +@pytest.mark.unit +@pytest.mark.utils +class TestGlobalCacheManager: + """Test global singleton cache manager.""" + + def test_get_cache_manager_returns_instance(self): + """Test get_cache_manager returns a CacheManager instance.""" + # Reset global to ensure clean state + import utils.cache_manager as module + module._global_cache_manager = None + + mgr = get_cache_manager() + + assert isinstance(mgr, CacheManager) + + def test_get_cache_manager_returns_same_instance(self): + """Test get_cache_manager always returns the same singleton.""" + import utils.cache_manager as module + module._global_cache_manager = None + + mgr1 = get_cache_manager() + mgr2 = get_cache_manager() + + assert mgr1 is mgr2 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..cdc9d4f --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,379 @@ +""" +Unit tests for PluginConfig + +Tests the plugin configuration management including: +- Default value initialization +- Configuration creation from dict +- Configuration validation +- File persistence (save/load) +- Boundary value verification +""" +import os +import json +import tempfile +import pytest +from unittest.mock import patch, MagicMock + +from config import PluginConfig + + +@pytest.mark.unit +@pytest.mark.config +class TestPluginConfigDefaults: + """Test PluginConfig default value initialization.""" + + def test_create_default_instance(self): + """Test creating a default PluginConfig instance.""" + config = PluginConfig() + + assert config.enable_message_capture is True + assert config.enable_auto_learning is True + assert config.enable_realtime_learning is False + assert config.enable_web_interface is True + assert config.web_interface_port == 7833 + assert config.web_interface_host == "0.0.0.0" + + def test_create_default_classmethod(self): + """Test the create_default classmethod.""" + config = PluginConfig.create_default() + + assert isinstance(config, PluginConfig) + assert config.learning_interval_hours == 6 + assert config.min_messages_for_learning == 50 + assert config.max_messages_per_batch == 200 + + def test_default_learning_parameters(self): + """Test default learning parameter values.""" + config = PluginConfig() + + assert config.message_min_length == 5 + assert config.message_max_length == 500 + assert config.confidence_threshold == 0.7 + assert config.relevance_threshold == 0.6 + assert config.style_analysis_batch_size == 100 + assert config.style_update_threshold == 0.6 + + def test_default_database_settings(self): + """Test default database configuration values.""" + config = PluginConfig() + + assert config.db_type == "sqlite" + assert config.mysql_host == "localhost" + assert config.mysql_port == 3306 + assert config.postgresql_host == "localhost" + assert config.postgresql_port == 5432 + assert config.max_connections == 10 + + def test_default_affection_settings(self): + """Test default affection system configuration.""" + config = PluginConfig() + + assert config.enable_affection_system is True + assert config.max_total_affection == 250 + assert config.max_user_affection == 100 + assert config.affection_decay_rate == 0.95 + + def test_default_provider_ids_none(self): + """Test provider IDs default to None.""" + config = PluginConfig() + + assert config.filter_provider_id is None + assert config.refine_provider_id is None + assert config.reinforce_provider_id is None + assert config.embedding_provider_id is None + assert config.rerank_provider_id is None + + def test_sqlalchemy_always_true(self): + """Test that use_sqlalchemy is always True (hardcoded).""" + config = PluginConfig() + assert config.use_sqlalchemy is True + + +@pytest.mark.unit +@pytest.mark.config +class TestPluginConfigFromDict: + """Test PluginConfig creation from configuration dict.""" + + def test_create_from_basic_config(self): + """Test creating config from a basic configuration dict.""" + raw_config = { + 'Self_Learning_Basic': { + 'enable_message_capture': False, + 'enable_auto_learning': False, + 'web_interface_port': 8080, + } + } + + config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test") + + assert config.enable_message_capture is False + assert config.enable_auto_learning is False + assert config.web_interface_port == 8080 + assert config.data_dir == "/tmp/test" + + def test_create_from_config_with_model_settings(self): + """Test config creation with model configuration.""" + raw_config = { + 'Model_Configuration': { + 'filter_provider_id': 'provider_1', + 'refine_provider_id': 'provider_2', + 'reinforce_provider_id': 'provider_3', + } + } + + config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test") + + assert config.filter_provider_id == 'provider_1' + assert config.refine_provider_id == 'provider_2' + assert config.reinforce_provider_id == 'provider_3' + + def test_create_from_config_missing_data_dir(self): + """Test config creation with empty data_dir uses fallback.""" + config = PluginConfig.create_from_config({}, data_dir="") + + assert config.data_dir == "./data/self_learning_data" + + def test_create_from_config_with_database_settings(self): + """Test config creation with database settings.""" + raw_config = { + 'Database_Settings': { + 'db_type': 'mysql', + 'mysql_host': '192.168.1.100', + 'mysql_port': 3307, + 'mysql_user': 'admin', + 'mysql_password': 'secret', + 'mysql_database': 'test_db', + } + } + + config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test") + + assert config.db_type == 'mysql' + assert config.mysql_host == '192.168.1.100' + assert config.mysql_port == 3307 + assert config.mysql_user == 'admin' + assert config.mysql_database == 'test_db' + + def test_create_from_config_with_v2_settings(self): + """Test config creation with v2 architecture settings.""" + raw_config = { + 'V2_Architecture_Settings': { + 'embedding_provider_id': 'embed_provider', + 'rerank_provider_id': 'rerank_provider', + 'knowledge_engine': 'lightrag', + 'memory_engine': 'mem0', + } + } + + config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test") + + assert config.embedding_provider_id == 'embed_provider' + assert config.rerank_provider_id == 'rerank_provider' + assert config.knowledge_engine == 'lightrag' + assert config.memory_engine == 'mem0' + + def test_create_from_empty_config(self): + """Test config creation from empty dict uses all defaults.""" + config = PluginConfig.create_from_config({}, data_dir="/tmp/test") + + assert config.enable_message_capture is True + assert config.learning_interval_hours == 6 + assert config.db_type == 'sqlite' + + def test_extra_fields_ignored(self): + """Test that extra/unknown fields are ignored.""" + config = PluginConfig( + unknown_field_1="value1", + unknown_field_2=42, + ) + assert not hasattr(config, 'unknown_field_1') + + +@pytest.mark.unit +@pytest.mark.config +class TestPluginConfigValidation: + """Test PluginConfig validation logic.""" + + def test_valid_config_no_errors(self): + """Test validation of a valid default config.""" + config = PluginConfig( + filter_provider_id="provider_1", + refine_provider_id="provider_2", + ) + errors = config.validate_config() + + # Should have no blocking errors (may have warnings for reinforce) + blocking_errors = [e for e in errors if not e.startswith(" ")] + assert len(blocking_errors) == 0 + + def test_invalid_learning_interval(self): + """Test validation catches invalid learning interval.""" + config = PluginConfig(learning_interval_hours=0) + errors = config.validate_config() + + assert any("学习间隔必须大于0" in e for e in errors) + + def test_invalid_min_messages(self): + """Test validation catches invalid min messages for learning.""" + config = PluginConfig(min_messages_for_learning=0) + errors = config.validate_config() + + assert any("最少学习消息数量必须大于0" in e for e in errors) + + def test_invalid_max_batch_size(self): + """Test validation catches invalid max batch size.""" + config = PluginConfig(max_messages_per_batch=-1) + errors = config.validate_config() + + assert any("每批最大消息数量必须大于0" in e for e in errors) + + def test_invalid_message_length_range(self): + """Test validation catches min_length >= max_length.""" + config = PluginConfig(message_min_length=500, message_max_length=100) + errors = config.validate_config() + + assert any("最小长度必须小于最大长度" in e for e in errors) + + def test_invalid_confidence_threshold(self): + """Test validation catches confidence threshold out of range.""" + config = PluginConfig(confidence_threshold=1.5) + errors = config.validate_config() + + assert any("置信度阈值必须在0-1之间" in e for e in errors) + + def test_invalid_style_threshold(self): + """Test validation catches style update threshold out of range.""" + config = PluginConfig(style_update_threshold=-0.1) + errors = config.validate_config() + + assert any("风格更新阈值必须在0-1之间" in e for e in errors) + + def test_no_providers_configured(self): + """Test validation warns when no providers are configured.""" + config = PluginConfig( + filter_provider_id=None, + refine_provider_id=None, + reinforce_provider_id=None, + ) + errors = config.validate_config() + + assert any("至少需要配置一个模型提供商ID" in e for e in errors) + + def test_partial_providers_configured(self): + """Test validation with only some providers configured.""" + config = PluginConfig( + filter_provider_id="provider_1", + refine_provider_id=None, + reinforce_provider_id=None, + ) + errors = config.validate_config() + + # Should have warnings but no blocking errors + blocking_errors = [e for e in errors if not e.startswith(" ")] + assert len(blocking_errors) == 0 + + +@pytest.mark.unit +@pytest.mark.config +class TestPluginConfigSerialization: + """Test PluginConfig serialization and deserialization.""" + + def test_to_dict(self): + """Test converting config to dict.""" + config = PluginConfig( + enable_message_capture=False, + web_interface_port=9090, + ) + + config_dict = config.to_dict() + + assert isinstance(config_dict, dict) + assert config_dict['enable_message_capture'] is False + assert config_dict['web_interface_port'] == 9090 + assert 'learning_interval_hours' in config_dict + + def test_save_to_file_success(self): + """Test saving config to file.""" + config = PluginConfig() + + with tempfile.NamedTemporaryFile( + mode='w', suffix='.json', delete=False + ) as f: + filepath = f.name + + try: + result = config.save_to_file(filepath) + + assert result is True + assert os.path.exists(filepath) + + with open(filepath, 'r', encoding='utf-8') as f: + saved_data = json.load(f) + assert saved_data['enable_message_capture'] is True + finally: + os.unlink(filepath) + + def test_load_from_file_success(self): + """Test loading config from existing file.""" + config_data = { + 'enable_message_capture': False, + 'web_interface_port': 9999, + 'learning_interval_hours': 12, + } + + with tempfile.NamedTemporaryFile( + mode='w', suffix='.json', delete=False + ) as f: + json.dump(config_data, f) + filepath = f.name + + try: + loaded_config = PluginConfig.load_from_file(filepath) + + assert loaded_config.enable_message_capture is False + assert loaded_config.web_interface_port == 9999 + assert loaded_config.learning_interval_hours == 12 + finally: + os.unlink(filepath) + + def test_load_from_nonexistent_file(self): + """Test loading config from nonexistent file returns defaults.""" + loaded_config = PluginConfig.load_from_file("/nonexistent/path.json") + + assert loaded_config.enable_message_capture is True + assert loaded_config.learning_interval_hours == 6 + + def test_load_from_file_with_data_dir(self): + """Test loading config with explicit data_dir override.""" + config_data = {'enable_message_capture': True} + + with tempfile.NamedTemporaryFile( + mode='w', suffix='.json', delete=False + ) as f: + json.dump(config_data, f) + filepath = f.name + + try: + loaded_config = PluginConfig.load_from_file( + filepath, data_dir="/custom/data/dir" + ) + + assert loaded_config.data_dir == "/custom/data/dir" + finally: + os.unlink(filepath) + + def test_load_from_corrupt_file(self): + """Test loading config from corrupt file returns defaults.""" + with tempfile.NamedTemporaryFile( + mode='w', suffix='.json', delete=False + ) as f: + f.write("this is not valid json {{{") + filepath = f.name + + try: + loaded_config = PluginConfig.load_from_file(filepath) + + # Should return default config + assert loaded_config.enable_message_capture is True + finally: + os.unlink(filepath) diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py new file mode 100644 index 0000000..badadb2 --- /dev/null +++ b/tests/unit/test_constants.py @@ -0,0 +1,147 @@ +""" +Unit tests for constants module + +Tests the update type normalization and review source resolution: +- normalize_update_type exact and fuzzy matching +- get_review_source_from_update_type classification +- Legacy format backward compatibility +""" +import pytest + +from constants import ( + UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, + UPDATE_TYPE_STYLE_LEARNING, + UPDATE_TYPE_EXPRESSION_LEARNING, + UPDATE_TYPE_TRADITIONAL, + LEGACY_UPDATE_TYPE_MAPPING, + normalize_update_type, + get_review_source_from_update_type, +) + + +@pytest.mark.unit +class TestNormalizeUpdateType: + """Test normalize_update_type function.""" + + def test_empty_input_returns_traditional(self): + """Test empty or None input returns traditional type.""" + assert normalize_update_type("") == UPDATE_TYPE_TRADITIONAL + assert normalize_update_type(None) == UPDATE_TYPE_TRADITIONAL + + def test_exact_match_progressive_learning(self): + """Test exact match for progressive_learning legacy key.""" + result = normalize_update_type("progressive_learning") + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + def test_exact_match_style_learning(self): + """Test exact match for style_learning.""" + result = normalize_update_type("style_learning") + assert result == UPDATE_TYPE_STYLE_LEARNING + + def test_exact_match_expression_learning(self): + """Test exact match for expression_learning.""" + result = normalize_update_type("expression_learning") + assert result == UPDATE_TYPE_EXPRESSION_LEARNING + + def test_legacy_chinese_progressive_style(self): + """Test legacy Chinese format for progressive style analysis.""" + result = normalize_update_type("渐进式学习-风格分析") + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + def test_legacy_chinese_progressive_persona(self): + """Test legacy Chinese format for progressive persona update.""" + result = normalize_update_type("渐进式学习-人格更新") + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + def test_fuzzy_match_chinese_progressive(self): + """Test fuzzy match with Chinese progressive learning keyword.""" + result = normalize_update_type("渐进式学习-新类型") + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + def test_fuzzy_match_english_progressive(self): + """Test fuzzy match with English progressive keyword.""" + result = normalize_update_type("PROGRESSIVE_update") + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + def test_unknown_type_returns_traditional(self): + """Test unknown type returns traditional.""" + result = normalize_update_type("some_unknown_type") + assert result == UPDATE_TYPE_TRADITIONAL + + def test_already_normalized_value(self): + """Test passing an already normalized constant.""" + result = normalize_update_type(UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING) + assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + + +@pytest.mark.unit +class TestGetReviewSourceFromUpdateType: + """Test get_review_source_from_update_type function.""" + + def test_progressive_persona_learning_source(self): + """Test progressive persona learning maps to persona_learning.""" + result = get_review_source_from_update_type( + UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING + ) + assert result == 'persona_learning' + + def test_style_learning_source(self): + """Test style learning maps to style_learning.""" + result = get_review_source_from_update_type(UPDATE_TYPE_STYLE_LEARNING) + assert result == 'style_learning' + + def test_expression_learning_source(self): + """Test expression learning maps to persona_learning.""" + result = get_review_source_from_update_type( + UPDATE_TYPE_EXPRESSION_LEARNING + ) + assert result == 'persona_learning' + + def test_traditional_source(self): + """Test traditional update maps to traditional.""" + result = get_review_source_from_update_type(UPDATE_TYPE_TRADITIONAL) + assert result == 'traditional' + + def test_unknown_type_defaults_to_traditional(self): + """Test unknown update type defaults to traditional source.""" + result = get_review_source_from_update_type("random_unknown_type") + assert result == 'traditional' + + def test_legacy_format_normalization(self): + """Test legacy Chinese format is normalized before classification.""" + result = get_review_source_from_update_type("渐进式学习-风格分析") + assert result == 'persona_learning' + + def test_empty_string(self): + """Test empty string defaults to traditional.""" + result = get_review_source_from_update_type("") + assert result == 'traditional' + + +@pytest.mark.unit +class TestLegacyMapping: + """Test legacy update type mapping completeness.""" + + def test_all_legacy_keys_mapped(self): + """Test all legacy keys exist in the mapping.""" + expected_keys = { + "渐进式学习-风格分析", + "渐进式学习-人格更新", + "progressive_learning", + "style_learning", + "expression_learning", + } + + assert set(LEGACY_UPDATE_TYPE_MAPPING.keys()) == expected_keys + + def test_all_legacy_values_are_valid_constants(self): + """Test all legacy values map to valid update type constants.""" + valid_types = { + UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, + UPDATE_TYPE_STYLE_LEARNING, + UPDATE_TYPE_EXPRESSION_LEARNING, + UPDATE_TYPE_TRADITIONAL, + } + + for value in LEGACY_UPDATE_TYPE_MAPPING.values(): + assert value in valid_types, f"Invalid mapping value: {value}" diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000..e0667a9 --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,110 @@ +""" +Unit tests for custom exception hierarchy + +Tests the exception class inheritance chain and instantiation: +- Base exception class +- All derived exception types +- Exception message propagation +- isinstance checks for polymorphic handling +""" +import pytest + +from exceptions import ( + SelfLearningError, + ConfigurationError, + MessageCollectionError, + StyleAnalysisError, + PersonaUpdateError, + ModelAccessError, + DataStorageError, + LearningSchedulerError, + LearningError, + ServiceError, + ResponseError, + BackupError, + ExpressionLearningError, + MemoryGraphError, + TimeDecayError, + MessageAnalysisError, + KnowledgeGraphError, +) + + +@pytest.mark.unit +class TestExceptionHierarchy: + """Test exception class hierarchy and inheritance.""" + + def test_base_exception_is_exception(self): + """Test SelfLearningError is a proper Exception subclass.""" + assert issubclass(SelfLearningError, Exception) + + @pytest.mark.parametrize("exc_class", [ + ConfigurationError, + MessageCollectionError, + StyleAnalysisError, + PersonaUpdateError, + ModelAccessError, + DataStorageError, + LearningSchedulerError, + LearningError, + ServiceError, + ResponseError, + BackupError, + ExpressionLearningError, + MemoryGraphError, + TimeDecayError, + MessageAnalysisError, + KnowledgeGraphError, + ]) + def test_subclass_inherits_from_base(self, exc_class): + """Test all custom exceptions inherit from SelfLearningError.""" + assert issubclass(exc_class, SelfLearningError) + + @pytest.mark.parametrize("exc_class", [ + ConfigurationError, + MessageCollectionError, + StyleAnalysisError, + PersonaUpdateError, + ModelAccessError, + DataStorageError, + LearningSchedulerError, + LearningError, + ServiceError, + ResponseError, + BackupError, + ExpressionLearningError, + MemoryGraphError, + TimeDecayError, + MessageAnalysisError, + KnowledgeGraphError, + ]) + def test_exception_instantiation(self, exc_class): + """Test all exception classes can be instantiated with a message.""" + msg = f"Test error message for {exc_class.__name__}" + exc = exc_class(msg) + + assert str(exc) == msg + + def test_catch_specific_exception(self): + """Test catching a specific exception type.""" + with pytest.raises(ConfigurationError): + raise ConfigurationError("invalid config") + + def test_catch_base_exception_for_derived(self): + """Test catching base exception catches derived types.""" + with pytest.raises(SelfLearningError): + raise DataStorageError("storage failure") + + def test_exception_with_no_message(self): + """Test exception can be raised without a message.""" + exc = SelfLearningError() + assert str(exc) == "" + + def test_exception_chain(self): + """Test exception chaining with __cause__.""" + original = ValueError("original cause") + try: + raise ConfigurationError("config error") from original + except ConfigurationError as e: + assert e.__cause__ is original + assert str(e.__cause__) == "original cause" diff --git a/tests/unit/test_guardrails_models.py b/tests/unit/test_guardrails_models.py new file mode 100644 index 0000000..c72b012 --- /dev/null +++ b/tests/unit/test_guardrails_models.py @@ -0,0 +1,264 @@ +""" +Unit tests for Guardrails Pydantic validation models + +Tests the Pydantic model definitions used for structured LLM output: +- PsychologicalStateTransition validation +- GoalAnalysisResult validation +- ConversationIntentAnalysis defaults and validation +- RelationChange and SocialRelationAnalysis validation +- Field range constraints (ge, le, min_length, max_length) +""" +import pytest +from pydantic import ValidationError + +from utils.guardrails_manager import ( + PsychologicalStateTransition, + GoalAnalysisResult, + ConversationIntentAnalysis, + RelationChange, + SocialRelationAnalysis, +) + + +@pytest.mark.unit +@pytest.mark.utils +class TestPsychologicalStateTransition: + """Test PsychologicalStateTransition Pydantic model.""" + + def test_valid_creation(self): + """Test creating a valid state transition.""" + state = PsychologicalStateTransition( + new_state="愉悦", + confidence=0.85, + reason="Positive conversation detected", + ) + + assert state.new_state == "愉悦" + assert state.confidence == 0.85 + assert state.reason == "Positive conversation detected" + + def test_default_values(self): + """Test default confidence and reason values.""" + state = PsychologicalStateTransition(new_state="平静") + + assert state.confidence == 0.8 + assert state.reason == "" + + def test_state_name_too_long(self): + """Test state name longer than 20 chars is rejected.""" + with pytest.raises(ValidationError): + PsychologicalStateTransition(new_state="a" * 21) + + def test_empty_state_name(self): + """Test empty state name is rejected.""" + with pytest.raises(ValidationError): + PsychologicalStateTransition(new_state="") + + def test_confidence_below_zero(self): + """Test confidence below 0.0 is rejected.""" + with pytest.raises(ValidationError): + PsychologicalStateTransition(new_state="测试", confidence=-0.1) + + def test_confidence_above_one(self): + """Test confidence above 1.0 is rejected.""" + with pytest.raises(ValidationError): + PsychologicalStateTransition(new_state="测试", confidence=1.1) + + def test_confidence_boundary_values(self): + """Test confidence at exact boundaries (0.0 and 1.0).""" + s_low = PsychologicalStateTransition(new_state="低", confidence=0.0) + s_high = PsychologicalStateTransition(new_state="高", confidence=1.0) + + assert s_low.confidence == 0.0 + assert s_high.confidence == 1.0 + + def test_state_name_whitespace_stripped(self): + """Test state name with whitespace is stripped.""" + state = PsychologicalStateTransition(new_state=" 愉悦 ") + assert state.new_state == "愉悦" + + +@pytest.mark.unit +@pytest.mark.utils +class TestGoalAnalysisResult: + """Test GoalAnalysisResult Pydantic model.""" + + def test_valid_creation(self): + """Test creating a valid goal analysis result.""" + result = GoalAnalysisResult( + goal_type="emotional_support", + topic="工作压力", + confidence=0.85, + reasoning="User seems stressed", + ) + + assert result.goal_type == "emotional_support" + assert result.topic == "工作压力" + assert result.confidence == 0.85 + + def test_default_values(self): + """Test default values for optional fields.""" + result = GoalAnalysisResult( + goal_type="casual_chat", + topic="日常", + ) + + assert result.confidence == 0.7 + assert result.reasoning == "" + + def test_goal_type_too_long(self): + """Test goal_type exceeding 50 chars is rejected.""" + with pytest.raises(ValidationError): + GoalAnalysisResult( + goal_type="a" * 51, + topic="test", + ) + + def test_topic_too_long(self): + """Test topic exceeding 100 chars is rejected.""" + with pytest.raises(ValidationError): + GoalAnalysisResult( + goal_type="test", + topic="a" * 101, + ) + + def test_empty_goal_type(self): + """Test empty goal_type is rejected.""" + with pytest.raises(ValidationError): + GoalAnalysisResult(goal_type="", topic="test") + + def test_empty_topic(self): + """Test empty topic is rejected.""" + with pytest.raises(ValidationError): + GoalAnalysisResult(goal_type="test", topic="") + + +@pytest.mark.unit +@pytest.mark.utils +class TestConversationIntentAnalysis: + """Test ConversationIntentAnalysis Pydantic model.""" + + def test_default_values(self): + """Test all default values are correctly set.""" + intent = ConversationIntentAnalysis() + + assert intent.goal_switch_needed is False + assert intent.new_goal_type is None + assert intent.new_topic is None + assert intent.topic_completed is False + assert intent.stage_completed is False + assert intent.stage_adjustment_needed is False + assert intent.suggested_stage is None + assert intent.completion_signals == 0 + assert intent.user_engagement == 0.5 + assert intent.reasoning == "" + + def test_custom_values(self): + """Test setting custom values.""" + intent = ConversationIntentAnalysis( + goal_switch_needed=True, + new_goal_type="knowledge_sharing", + user_engagement=0.9, + completion_signals=3, + ) + + assert intent.goal_switch_needed is True + assert intent.new_goal_type == "knowledge_sharing" + assert intent.user_engagement == 0.9 + assert intent.completion_signals == 3 + + def test_engagement_below_zero(self): + """Test user_engagement below 0.0 is rejected.""" + with pytest.raises(ValidationError): + ConversationIntentAnalysis(user_engagement=-0.1) + + def test_engagement_above_one(self): + """Test user_engagement above 1.0 is rejected.""" + with pytest.raises(ValidationError): + ConversationIntentAnalysis(user_engagement=1.1) + + def test_negative_completion_signals(self): + """Test negative completion_signals is rejected.""" + with pytest.raises(ValidationError): + ConversationIntentAnalysis(completion_signals=-1) + + +@pytest.mark.unit +@pytest.mark.utils +class TestRelationChange: + """Test RelationChange Pydantic model.""" + + def test_valid_creation(self): + """Test creating a valid relation change.""" + change = RelationChange( + relation_type="挚友", + value_delta=0.1, + reason="Shared positive experience", + ) + + assert change.relation_type == "挚友" + assert change.value_delta == 0.1 + + def test_relation_type_too_long(self): + """Test relation_type exceeding 30 chars is rejected.""" + with pytest.raises(ValidationError): + RelationChange( + relation_type="a" * 31, + value_delta=0.1, + ) + + def test_value_delta_below_negative_one(self): + """Test value_delta below -1.0 is rejected.""" + with pytest.raises(ValidationError): + RelationChange(relation_type="test", value_delta=-1.1) + + def test_value_delta_above_one(self): + """Test value_delta above 1.0 is rejected.""" + with pytest.raises(ValidationError): + RelationChange(relation_type="test", value_delta=1.1) + + def test_boundary_values(self): + """Test boundary values for value_delta.""" + low = RelationChange(relation_type="低", value_delta=-1.0) + high = RelationChange(relation_type="高", value_delta=1.0) + + assert low.value_delta == -1.0 + assert high.value_delta == 1.0 + + +@pytest.mark.unit +@pytest.mark.utils +class TestSocialRelationAnalysis: + """Test SocialRelationAnalysis Pydantic model.""" + + def test_valid_creation(self): + """Test creating a valid social relation analysis.""" + analysis = SocialRelationAnalysis( + relations=[ + RelationChange(relation_type="友情", value_delta=0.05), + RelationChange(relation_type="信任", value_delta=0.02), + ], + overall_sentiment="positive", + ) + + assert len(analysis.relations) == 2 + assert analysis.overall_sentiment == "positive" + + def test_empty_relations(self): + """Test empty relations list is valid.""" + analysis = SocialRelationAnalysis(relations=[]) + assert len(analysis.relations) == 0 + + def test_default_sentiment(self): + """Test default overall_sentiment is neutral.""" + analysis = SocialRelationAnalysis(relations=[]) + assert analysis.overall_sentiment == "neutral" + + def test_max_five_relations(self): + """Test relations are capped at 5.""" + relations = [ + RelationChange(relation_type=f"type_{i}", value_delta=0.01) + for i in range(7) + ] + analysis = SocialRelationAnalysis(relations=relations) + assert len(analysis.relations) == 5 diff --git a/tests/unit/test_interfaces.py b/tests/unit/test_interfaces.py new file mode 100644 index 0000000..0643489 --- /dev/null +++ b/tests/unit/test_interfaces.py @@ -0,0 +1,244 @@ +""" +Unit tests for core interfaces module + +Tests the core data classes, enums, and interface definitions: +- MessageData dataclass construction and defaults +- AnalysisResult dataclass construction and defaults +- PersonaUpdateRecord dataclass construction and defaults +- ServiceLifecycle enum values +- LearningStrategyType enum values +- AnalysisType enum values +""" +import pytest +from unittest.mock import MagicMock + +from core.interfaces import ( + ServiceLifecycle, + MessageData, + AnalysisResult, + PersonaUpdateRecord, + LearningStrategyType, + AnalysisType, +) + + +@pytest.mark.unit +@pytest.mark.core +class TestMessageData: + """Test MessageData dataclass.""" + + def test_required_fields(self): + """Test creating MessageData with all required fields.""" + msg = MessageData( + sender_id="user_001", + sender_name="Alice", + message="Hello world", + group_id="group_001", + timestamp=1700000000.0, + platform="qq", + ) + + assert msg.sender_id == "user_001" + assert msg.sender_name == "Alice" + assert msg.message == "Hello world" + assert msg.group_id == "group_001" + assert msg.timestamp == 1700000000.0 + assert msg.platform == "qq" + + def test_optional_fields_default_none(self): + """Test optional fields default to None.""" + msg = MessageData( + sender_id="user_001", + sender_name="Alice", + message="Hello", + group_id="group_001", + timestamp=1700000000.0, + platform="qq", + ) + + assert msg.message_id is None + assert msg.reply_to is None + + def test_optional_fields_set_explicitly(self): + """Test optional fields can be set explicitly.""" + msg = MessageData( + sender_id="user_001", + sender_name="Alice", + message="Hello", + group_id="group_001", + timestamp=1700000000.0, + platform="qq", + message_id="msg_123", + reply_to="msg_100", + ) + + assert msg.message_id == "msg_123" + assert msg.reply_to == "msg_100" + + +@pytest.mark.unit +@pytest.mark.core +class TestAnalysisResult: + """Test AnalysisResult dataclass.""" + + def test_required_fields(self): + """Test creating AnalysisResult with required fields.""" + result = AnalysisResult( + success=True, + confidence=0.85, + data={"key": "value"}, + ) + + assert result.success is True + assert result.confidence == 0.85 + assert result.data == {"key": "value"} + + def test_default_values(self): + """Test AnalysisResult default values.""" + result = AnalysisResult( + success=True, + confidence=0.9, + data={}, + ) + + assert result.timestamp == 0.0 + assert result.error is None + assert result.consistency_score is None + + def test_with_error(self): + """Test AnalysisResult with error information.""" + result = AnalysisResult( + success=False, + confidence=0.0, + data={}, + error="Analysis failed due to insufficient data", + ) + + assert result.success is False + assert result.error == "Analysis failed due to insufficient data" + + def test_with_consistency_score(self): + """Test AnalysisResult with consistency score.""" + result = AnalysisResult( + success=True, + confidence=0.8, + data={"metrics": [1, 2, 3]}, + consistency_score=0.75, + ) + + assert result.consistency_score == 0.75 + + +@pytest.mark.unit +@pytest.mark.core +class TestPersonaUpdateRecord: + """Test PersonaUpdateRecord dataclass.""" + + def test_required_fields(self): + """Test creating PersonaUpdateRecord with required fields.""" + record = PersonaUpdateRecord( + timestamp=1700000000.0, + group_id="group_001", + update_type="prompt_update", + original_content="Original prompt", + new_content="New prompt", + reason="Style analysis update", + ) + + assert record.timestamp == 1700000000.0 + assert record.group_id == "group_001" + assert record.update_type == "prompt_update" + assert record.original_content == "Original prompt" + assert record.new_content == "New prompt" + assert record.reason == "Style analysis update" + + def test_default_values(self): + """Test PersonaUpdateRecord default values.""" + record = PersonaUpdateRecord( + timestamp=0.0, + group_id="g1", + update_type="test", + original_content="", + new_content="", + reason="", + ) + + assert record.confidence_score == 0.5 + assert record.id is None + assert record.status == "pending" + assert record.reviewer_comment is None + assert record.review_time is None + + def test_approved_record(self): + """Test PersonaUpdateRecord with approved status.""" + record = PersonaUpdateRecord( + timestamp=1700000000.0, + group_id="g1", + update_type="prompt_update", + original_content="old", + new_content="new", + reason="update", + id=42, + status="approved", + reviewer_comment="Looks good", + review_time=1700001000.0, + ) + + assert record.id == 42 + assert record.status == "approved" + assert record.reviewer_comment == "Looks good" + assert record.review_time == 1700001000.0 + + +@pytest.mark.unit +@pytest.mark.core +class TestServiceLifecycleEnum: + """Test ServiceLifecycle enum.""" + + def test_all_states_exist(self): + """Test all expected lifecycle states exist.""" + assert ServiceLifecycle.CREATED.value == "created" + assert ServiceLifecycle.INITIALIZING.value == "initializing" + assert ServiceLifecycle.RUNNING.value == "running" + assert ServiceLifecycle.STOPPING.value == "stopping" + assert ServiceLifecycle.STOPPED.value == "stopped" + assert ServiceLifecycle.ERROR.value == "error" + + def test_enum_count(self): + """Test the total number of lifecycle states.""" + assert len(ServiceLifecycle) == 6 + + +@pytest.mark.unit +@pytest.mark.core +class TestLearningStrategyTypeEnum: + """Test LearningStrategyType enum.""" + + def test_all_strategies_exist(self): + """Test all expected strategy types exist.""" + assert LearningStrategyType.PROGRESSIVE.value == "progressive" + assert LearningStrategyType.BATCH.value == "batch" + assert LearningStrategyType.REALTIME.value == "realtime" + assert LearningStrategyType.HYBRID.value == "hybrid" + + def test_enum_count(self): + """Test the total number of strategy types.""" + assert len(LearningStrategyType) == 4 + + +@pytest.mark.unit +@pytest.mark.core +class TestAnalysisTypeEnum: + """Test AnalysisType enum.""" + + def test_all_types_exist(self): + """Test all expected analysis types exist.""" + assert AnalysisType.STYLE.value == "style" + assert AnalysisType.SENTIMENT.value == "sentiment" + assert AnalysisType.TOPIC.value == "topic" + assert AnalysisType.BEHAVIOR.value == "behavior" + assert AnalysisType.QUALITY.value == "quality" + + def test_enum_count(self): + """Test the total number of analysis types.""" + assert len(AnalysisType) == 5 diff --git a/tests/unit/test_json_utils.py b/tests/unit/test_json_utils.py new file mode 100644 index 0000000..35b1dac --- /dev/null +++ b/tests/unit/test_json_utils.py @@ -0,0 +1,391 @@ +""" +Unit tests for JSON utilities module + +Tests LLM response parsing, markdown cleanup, and JSON validation: +- remove_thinking_content for various LLM thinking tags +- clean_markdown_blocks for code block removal +- clean_control_characters for sanitization +- extract_json_content for boundary detection +- fix_common_json_errors for auto-repair +- safe_parse_llm_json for end-to-end parsing +- validate_json_structure for schema validation +- detect_llm_provider for model name detection +""" +import pytest + +from utils.json_utils import ( + remove_thinking_content, + extract_thinking_content, + clean_markdown_blocks, + clean_control_characters, + extract_json_content, + fix_common_json_errors, + clean_llm_json_response, + safe_parse_llm_json, + validate_json_structure, + detect_llm_provider, + LLMProvider, + _convert_single_quotes, +) + + +@pytest.mark.unit +@pytest.mark.utils +class TestRemoveThinkingContent: + """Test removal of LLM thinking tags.""" + + def test_empty_input(self): + """Test empty input returns as-is.""" + assert remove_thinking_content("") == "" + assert remove_thinking_content(None) is None + + def test_remove_thinking_tags(self): + """Test removal of tags.""" + text = "Internal reasoningFinal answer" + result = remove_thinking_content(text) + assert "Internal reasoning" not in result + assert "Final answer" in result + + def test_remove_thought_tags(self): + """Test removal of tags.""" + text = "Analysis hereResult" + result = remove_thinking_content(text) + assert "Analysis here" not in result + assert "Result" in result + + def test_remove_reasoning_tags(self): + """Test removal of tags.""" + text = "Step 1, Step 2Output" + result = remove_thinking_content(text) + assert "Step 1" not in result + assert "Output" in result + + def test_remove_think_tags(self): + """Test removal of tags.""" + text = "Hmm let me thinkAnswer is 42" + result = remove_thinking_content(text) + assert "Hmm let me think" not in result + assert "Answer is 42" in result + + def test_remove_chinese_thinking_tags(self): + """Test removal of Chinese thinking tags.""" + text = "<思考>这是思考过程最终结果" + result = remove_thinking_content(text) + assert "这是思考过程" not in result + assert "最终结果" in result + + def test_multiline_thinking_content(self): + """Test removal of multiline thinking content.""" + text = "\nLine 1\nLine 2\nLine 3\nFinal" + result = remove_thinking_content(text) + assert "Line 1" not in result + assert "Final" in result + + def test_text_without_thinking_tags(self): + """Test text without thinking tags is unchanged.""" + text = "Just a regular response without any tags" + result = remove_thinking_content(text) + assert result == text + + +@pytest.mark.unit +@pytest.mark.utils +class TestExtractThinkingContent: + """Test extraction and separation of thinking content.""" + + def test_extract_thinking(self): + """Test extracting thinking content.""" + text = "My thoughtsAnswer" + cleaned, thoughts = extract_thinking_content(text) + + assert "Answer" in cleaned + assert len(thoughts) >= 1 + + def test_no_thinking_content(self): + """Test text without thinking content.""" + text = "Plain text response" + cleaned, thoughts = extract_thinking_content(text) + + assert cleaned == "Plain text response" + assert thoughts == [] + + def test_empty_input(self): + """Test empty input.""" + cleaned, thoughts = extract_thinking_content("") + assert cleaned == "" + assert thoughts == [] + + +@pytest.mark.unit +@pytest.mark.utils +class TestCleanMarkdownBlocks: + """Test markdown code block cleaning.""" + + def test_clean_json_code_block(self): + """Test cleaning ```json code block.""" + text = '```json\n{"key": "value"}\n```' + result = clean_markdown_blocks(text) + assert result == '{"key": "value"}' + + def test_clean_plain_code_block(self): + """Test cleaning plain ``` code block.""" + text = '```\n{"key": "value"}\n```' + result = clean_markdown_blocks(text) + assert result == '{"key": "value"}' + + def test_no_code_blocks(self): + """Test text without code blocks is unchanged.""" + text = '{"key": "value"}' + result = clean_markdown_blocks(text) + assert result == text + + def test_empty_input(self): + """Test empty input returns as-is.""" + assert clean_markdown_blocks("") == "" + assert clean_markdown_blocks(None) is None + + +@pytest.mark.unit +@pytest.mark.utils +class TestCleanControlCharacters: + """Test control character cleaning.""" + + def test_remove_null_bytes(self): + """Test removal of null bytes.""" + text = "hello\x00world" + result = clean_control_characters(text) + assert result == "helloworld" + + def test_preserve_tabs_and_newlines(self): + """Test preservation of tabs and newlines.""" + text = "hello\tworld\nfoo" + result = clean_control_characters(text) + assert result == text + + def test_empty_input(self): + """Test empty input.""" + assert clean_control_characters("") == "" + assert clean_control_characters(None) is None + + +@pytest.mark.unit +@pytest.mark.utils +class TestExtractJsonContent: + """Test JSON content extraction from mixed text.""" + + def test_extract_json_object(self): + """Test extracting JSON object from text.""" + text = 'Some text {"key": "value"} more text' + result = extract_json_content(text) + assert result == '{"key": "value"}' + + def test_extract_json_array(self): + """Test extracting JSON array from text.""" + text = 'Prefix [1, 2, 3] suffix' + result = extract_json_content(text) + assert result == '[1, 2, 3]' + + def test_no_json_content(self): + """Test text without JSON returns original.""" + text = "no json here" + result = extract_json_content(text) + assert result == text + + def test_nested_json(self): + """Test extracting nested JSON object.""" + text = '{"outer": {"inner": "value"}}' + result = extract_json_content(text) + assert result == text + + def test_empty_input(self): + """Test empty input.""" + assert extract_json_content("") == "" + + +@pytest.mark.unit +@pytest.mark.utils +class TestFixCommonJsonErrors: + """Test JSON error auto-repair.""" + + def test_fix_trailing_comma_object(self): + """Test fixing trailing comma in object.""" + text = '{"key": "value",}' + result = fix_common_json_errors(text) + assert result == '{"key": "value"}' + + def test_fix_trailing_comma_array(self): + """Test fixing trailing comma in array.""" + text = '[1, 2, 3,]' + result = fix_common_json_errors(text) + assert result == '[1, 2, 3]' + + def test_fix_python_true_false(self): + """Test fixing Python True/False/None to JSON equivalents.""" + text = '{"flag": True, "empty": None, "off": False}' + result = fix_common_json_errors(text) + assert ": true" in result + assert ": null" in result + assert ": false" in result + + def test_fix_nan_value(self): + """Test fixing NaN to null.""" + text = '{"score": nan}' + result = fix_common_json_errors(text) + assert ": null" in result + + def test_empty_input(self): + """Test empty input.""" + assert fix_common_json_errors("") == "" + + +@pytest.mark.unit +@pytest.mark.utils +class TestSafeParseLlmJson: + """Test end-to-end safe JSON parsing.""" + + def test_parse_clean_json(self): + """Test parsing clean JSON.""" + result = safe_parse_llm_json('{"key": "value"}') + assert result == {"key": "value"} + + def test_parse_json_in_markdown(self): + """Test parsing JSON wrapped in markdown code block.""" + text = '```json\n{"key": "value"}\n```' + result = safe_parse_llm_json(text) + assert result == {"key": "value"} + + def test_parse_json_with_thinking_tags(self): + """Test parsing JSON with thinking tags.""" + text = 'Analysis{"result": 42}' + result = safe_parse_llm_json(text) + assert result == {"result": 42} + + def test_parse_json_with_trailing_comma(self): + """Test parsing JSON with trailing comma.""" + text = '{"key": "value",}' + result = safe_parse_llm_json(text) + assert result == {"key": "value"} + + def test_parse_invalid_json_returns_fallback(self): + """Test invalid JSON returns fallback result.""" + result = safe_parse_llm_json("not json at all", fallback_result={"default": True}) + assert result == {"default": True} + + def test_parse_empty_input(self): + """Test empty input returns fallback.""" + result = safe_parse_llm_json("", fallback_result=None) + assert result is None + + def test_parse_json_array(self): + """Test parsing JSON array.""" + result = safe_parse_llm_json('[1, 2, 3]') + assert result == [1, 2, 3] + + def test_parse_with_single_quotes(self): + """Test parsing JSON with single quotes.""" + text = "{'key': 'value'}" + result = safe_parse_llm_json(text) + assert result == {"key": "value"} + + def test_parse_nested_json(self): + """Test parsing nested JSON structure.""" + text = '{"outer": {"inner": [1, 2, 3]}, "flag": true}' + result = safe_parse_llm_json(text) + assert result["outer"]["inner"] == [1, 2, 3] + assert result["flag"] is True + + +@pytest.mark.unit +@pytest.mark.utils +class TestValidateJsonStructure: + """Test JSON structure validation.""" + + def test_valid_with_required_fields(self): + """Test validation with all required fields present.""" + data = {"name": "Alice", "age": 30} + valid, msg = validate_json_structure( + data, required_fields=["name", "age"] + ) + assert valid is True + assert msg == "" + + def test_missing_required_fields(self): + """Test validation with missing required fields.""" + data = {"name": "Alice"} + valid, msg = validate_json_structure( + data, required_fields=["name", "age"] + ) + assert valid is False + assert "age" in msg + + def test_none_data(self): + """Test validation with None data.""" + valid, msg = validate_json_structure(None) + assert valid is False + + def test_type_check_success(self): + """Test validation with correct expected type.""" + valid, msg = validate_json_structure( + {"key": "value"}, expected_type=dict + ) + assert valid is True + + def test_type_check_failure(self): + """Test validation with incorrect expected type.""" + valid, msg = validate_json_structure( + [1, 2, 3], expected_type=dict + ) + assert valid is False + assert "dict" in msg + + +@pytest.mark.unit +@pytest.mark.utils +class TestDetectLlmProvider: + """Test LLM provider detection from model names.""" + + def test_detect_deepseek(self): + """Test detecting DeepSeek provider.""" + assert detect_llm_provider("deepseek-chat") == LLMProvider.DEEPSEEK + assert detect_llm_provider("deepseek-reasoner") == LLMProvider.DEEPSEEK + + def test_detect_anthropic(self): + """Test detecting Anthropic provider.""" + assert detect_llm_provider("claude-3-opus") == LLMProvider.ANTHROPIC + assert detect_llm_provider("claude-3.5-sonnet") == LLMProvider.ANTHROPIC + + def test_detect_openai(self): + """Test detecting OpenAI provider.""" + assert detect_llm_provider("gpt-4") == LLMProvider.OPENAI + assert detect_llm_provider("gpt-4o-mini") == LLMProvider.OPENAI + + def test_detect_unknown(self): + """Test detecting unknown provider.""" + assert detect_llm_provider("some-custom-model") == LLMProvider.GENERIC + + def test_detect_empty_input(self): + """Test detecting from empty input.""" + assert detect_llm_provider("") == LLMProvider.GENERIC + assert detect_llm_provider(None) == LLMProvider.GENERIC + + +@pytest.mark.unit +@pytest.mark.utils +class TestConvertSingleQuotes: + """Test single-to-double quote conversion.""" + + def test_basic_conversion(self): + """Test basic single quote to double quote conversion.""" + result = _convert_single_quotes("{'key': 'value'}") + assert result == '{"key": "value"}' + + def test_already_double_quotes(self): + """Test text with double quotes is unchanged.""" + text = '{"key": "value"}' + result = _convert_single_quotes(text) + assert result == text + + def test_empty_input(self): + """Test empty input.""" + assert _convert_single_quotes("") == "" + assert _convert_single_quotes(None) is None diff --git a/tests/unit/test_learning_quality_monitor.py b/tests/unit/test_learning_quality_monitor.py new file mode 100644 index 0000000..2d55baa --- /dev/null +++ b/tests/unit/test_learning_quality_monitor.py @@ -0,0 +1,586 @@ +""" +Unit tests for LearningQualityMonitor + +Tests the learning quality monitoring service: +- PersonaMetrics and LearningAlert dataclasses +- Consistency calculation (text similarity fallback) +- Style stability calculation +- Vocabulary diversity calculation +- Emotional balance calculation (simple fallback) +- Coherence calculation +- Quality alert generation +- Style drift detection +- Threshold dynamic adjustment +- Pause learning decision +- Quality report generation +""" +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from datetime import datetime, timedelta + +from services.quality.learning_quality_monitor import ( + LearningQualityMonitor, + PersonaMetrics, + LearningAlert, +) + + +def _create_monitor( + consistency_threshold=0.5, + stability_threshold=0.4, + drift_threshold=0.4, +) -> LearningQualityMonitor: + """Create a LearningQualityMonitor with mocked dependencies.""" + config = MagicMock() + context = MagicMock() + + monitor = LearningQualityMonitor( + config=config, + context=context, + llm_adapter=None, + prompts=None, + ) + monitor.consistency_threshold = consistency_threshold + monitor.stability_threshold = stability_threshold + monitor.drift_threshold = drift_threshold + + return monitor + + +def _make_messages(texts): + """Helper to create message dicts from text list.""" + return [{"message": text} for text in texts] + + +@pytest.mark.unit +@pytest.mark.quality +class TestPersonaMetrics: + """Test PersonaMetrics dataclass.""" + + def test_default_values(self): + """Test default metric values.""" + metrics = PersonaMetrics() + + assert metrics.consistency_score == 0.0 + assert metrics.style_stability == 0.0 + assert metrics.vocabulary_diversity == 0.0 + assert metrics.emotional_balance == 0.0 + assert metrics.coherence_score == 0.0 + + def test_custom_values(self): + """Test custom metric values.""" + metrics = PersonaMetrics( + consistency_score=0.85, + style_stability=0.9, + vocabulary_diversity=0.7, + emotional_balance=0.65, + coherence_score=0.8, + ) + + assert metrics.consistency_score == 0.85 + assert metrics.style_stability == 0.9 + + +@pytest.mark.unit +@pytest.mark.quality +class TestLearningAlert: + """Test LearningAlert dataclass.""" + + def test_alert_creation(self): + """Test creating a learning alert.""" + alert = LearningAlert( + alert_type="consistency", + severity="high", + message="Consistency dropped below threshold", + timestamp=datetime.now().isoformat(), + metrics={"consistency_score": 0.3}, + suggestions=["Review persona changes"], + ) + + assert alert.alert_type == "consistency" + assert alert.severity == "high" + assert len(alert.suggestions) == 1 + + +@pytest.mark.unit +@pytest.mark.quality +class TestConsistencyCalculation: + """Test persona consistency score calculation.""" + + @pytest.mark.asyncio + async def test_both_empty_personas(self): + """Test consistency when both personas are empty.""" + monitor = _create_monitor() + + score = await monitor._calculate_consistency( + {"prompt": ""}, {"prompt": ""} + ) + assert score == 0.7 + + @pytest.mark.asyncio + async def test_one_empty_persona(self): + """Test consistency when one persona is empty.""" + monitor = _create_monitor() + + score = await monitor._calculate_consistency( + {"prompt": "I am a helpful bot"}, {"prompt": ""} + ) + assert score == 0.6 + + @pytest.mark.asyncio + async def test_identical_personas(self): + """Test consistency when personas are identical.""" + monitor = _create_monitor() + prompt = "I am a friendly chatbot." + + score = await monitor._calculate_consistency( + {"prompt": prompt}, {"prompt": prompt} + ) + assert score == 0.95 + + @pytest.mark.asyncio + async def test_similar_personas_fallback(self): + """Test consistency using text similarity fallback (no LLM).""" + monitor = _create_monitor() + + score = await monitor._calculate_consistency( + {"prompt": "I am a helpful assistant."}, + {"prompt": "I am a helpful assistant. I like chatting."}, + ) + assert 0.4 <= score <= 1.0 + + +@pytest.mark.unit +@pytest.mark.quality +class TestTextSimilarity: + """Test text similarity fallback method.""" + + def test_identical_texts(self): + """Test identical texts return high similarity.""" + monitor = _create_monitor() + + score = monitor._calculate_text_similarity("hello world", "hello world") + assert score == 0.95 + + def test_empty_texts(self): + """Test empty texts return default.""" + monitor = _create_monitor() + + score = monitor._calculate_text_similarity("", "") + assert score == 0.6 + + def test_one_empty_text(self): + """Test one empty text returns default.""" + monitor = _create_monitor() + + score = monitor._calculate_text_similarity("hello", "") + assert score == 0.6 + + def test_different_texts(self): + """Test different texts return lower similarity.""" + monitor = _create_monitor() + + score = monitor._calculate_text_similarity("abc", "xyz") + assert 0.4 <= score <= 1.0 + + +@pytest.mark.unit +@pytest.mark.quality +class TestStyleStability: + """Test style stability calculation.""" + + @pytest.mark.asyncio + async def test_single_message_perfect_stability(self): + """Test single message returns perfect stability.""" + monitor = _create_monitor() + messages = _make_messages(["Hello!"]) + + score = await monitor._calculate_style_stability(messages) + assert score == 1.0 + + @pytest.mark.asyncio + async def test_identical_messages_high_stability(self): + """Test identical messages have high stability.""" + monitor = _create_monitor() + messages = _make_messages(["Hello!", "Hello!", "Hello!"]) + + score = await monitor._calculate_style_stability(messages) + assert score >= 0.8 + + @pytest.mark.asyncio + async def test_diverse_messages_lower_stability(self): + """Test diverse messages have lower stability.""" + monitor = _create_monitor() + messages = _make_messages([ + "Hi", + "This is a very long message with lots of words and punctuation! Really?", + "Ok", + ]) + + score = await monitor._calculate_style_stability(messages) + assert 0.0 <= score <= 1.0 + + +@pytest.mark.unit +@pytest.mark.quality +class TestVocabularyDiversity: + """Test vocabulary diversity calculation.""" + + @pytest.mark.asyncio + async def test_empty_messages(self): + """Test empty messages return zero diversity.""" + monitor = _create_monitor() + + score = await monitor._calculate_vocabulary_diversity([]) + assert score == 0.0 + + @pytest.mark.asyncio + async def test_single_word_messages(self): + """Test messages with same word have low diversity (actually 1.0).""" + monitor = _create_monitor() + messages = _make_messages(["hello", "hello", "hello"]) + + score = await monitor._calculate_vocabulary_diversity(messages) + # All same word: unique=1, total=3, ratio=0.33, *2=0.66 + assert 0.5 <= score <= 1.0 + + @pytest.mark.asyncio + async def test_unique_words_high_diversity(self): + """Test messages with all unique words have high diversity.""" + monitor = _create_monitor() + messages = _make_messages(["apple banana", "cherry date", "elderberry fig"]) + + score = await monitor._calculate_vocabulary_diversity(messages) + assert score >= 0.8 + + +@pytest.mark.unit +@pytest.mark.quality +class TestEmotionalBalance: + """Test emotional balance calculation (simple fallback).""" + + def test_neutral_messages(self): + """Test messages without emotional words return high balance.""" + monitor = _create_monitor() + messages = _make_messages(["今天天气不错", "我去了公园"]) + + score = monitor._simple_emotional_balance(messages) + assert score == 0.8 # No emotional words = neutral + + def test_positive_messages(self): + """Test messages with positive words.""" + monitor = _create_monitor() + messages = _make_messages(["好棒啊!", "真的很开心!喜欢!"]) + + score = monitor._simple_emotional_balance(messages) + assert 0.0 <= score <= 1.0 + + def test_negative_messages(self): + """Test messages with negative words.""" + monitor = _create_monitor() + messages = _make_messages(["不好", "真烦人,讨厌"]) + + score = monitor._simple_emotional_balance(messages) + assert 0.0 <= score <= 1.0 + + def test_balanced_messages(self): + """Test balanced positive and negative messages.""" + monitor = _create_monitor() + messages = _make_messages(["好开心", "不好"]) + + score = monitor._simple_emotional_balance(messages) + assert 0.0 <= score <= 1.0 + + +@pytest.mark.unit +@pytest.mark.quality +class TestCoherence: + """Test coherence calculation.""" + + @pytest.mark.asyncio + async def test_empty_persona(self): + """Test empty persona returns zero coherence.""" + monitor = _create_monitor() + + score = await monitor._calculate_coherence({"prompt": ""}) + assert score == 0.0 + + @pytest.mark.asyncio + async def test_single_sentence(self): + """Test single sentence returns high coherence.""" + monitor = _create_monitor() + + score = await monitor._calculate_coherence({"prompt": "我是一个友好的助手"}) + assert score == 0.8 + + @pytest.mark.asyncio + async def test_multiple_sentences(self): + """Test multiple sentences are evaluated.""" + monitor = _create_monitor() + prompt = "我是一个友好的助手。我喜欢帮助人。我会用中文交流。" + + score = await monitor._calculate_coherence({"prompt": prompt}) + assert 0.0 <= score <= 1.0 + + +@pytest.mark.unit +@pytest.mark.quality +class TestStyleDrift: + """Test style drift detection.""" + + def test_no_drift_identical_metrics(self): + """Test no drift when metrics are identical.""" + monitor = _create_monitor() + metrics = PersonaMetrics( + consistency_score=0.8, + style_stability=0.7, + vocabulary_diversity=0.6, + ) + + drift = monitor._calculate_style_drift(metrics, metrics) + assert drift == 0.0 + + def test_large_drift(self): + """Test large drift detection.""" + monitor = _create_monitor() + prev = PersonaMetrics( + consistency_score=0.9, + style_stability=0.8, + vocabulary_diversity=0.7, + ) + curr = PersonaMetrics( + consistency_score=0.3, + style_stability=0.2, + vocabulary_diversity=0.1, + ) + + drift = monitor._calculate_style_drift(prev, curr) + assert drift > 0.4 + + +@pytest.mark.unit +@pytest.mark.quality +class TestQualityAlerts: + """Test quality alert generation.""" + + @pytest.mark.asyncio + async def test_consistency_alert(self): + """Test alert is generated when consistency is below threshold.""" + monitor = _create_monitor(consistency_threshold=0.5) + metrics = PersonaMetrics(consistency_score=0.3) + + await monitor._check_quality_alerts(metrics) + + assert len(monitor.alerts_history) >= 1 + assert any(a.alert_type == "consistency" for a in monitor.alerts_history) + + @pytest.mark.asyncio + async def test_stability_alert(self): + """Test alert is generated when stability is below threshold.""" + monitor = _create_monitor(stability_threshold=0.4) + metrics = PersonaMetrics(style_stability=0.2) + + await monitor._check_quality_alerts(metrics) + + assert any(a.alert_type == "stability" for a in monitor.alerts_history) + + @pytest.mark.asyncio + async def test_no_alert_when_above_thresholds(self): + """Test no alerts when all metrics are above thresholds.""" + monitor = _create_monitor() + metrics = PersonaMetrics( + consistency_score=0.9, + style_stability=0.8, + vocabulary_diversity=0.7, + ) + + await monitor._check_quality_alerts(metrics) + assert len(monitor.alerts_history) == 0 + + @pytest.mark.asyncio + async def test_drift_alert_with_history(self): + """Test drift alert when historical metrics exist.""" + monitor = _create_monitor(drift_threshold=0.1) + # Add previous metrics + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.9, style_stability=0.9, vocabulary_diversity=0.9) + ) + # Current metrics show significant change + current = PersonaMetrics( + consistency_score=0.3, + style_stability=0.3, + vocabulary_diversity=0.3, + ) + monitor.historical_metrics.append(current) + + await monitor._check_quality_alerts(current) + + assert any(a.alert_type == "drift" for a in monitor.alerts_history) + + +@pytest.mark.unit +@pytest.mark.quality +class TestShouldPauseLearning: + """Test learning pause decision logic.""" + + @pytest.mark.asyncio + async def test_no_history_no_pause(self): + """Test no pause with empty history.""" + monitor = _create_monitor() + + should_pause, reason = await monitor.should_pause_learning() + assert should_pause is False + + @pytest.mark.asyncio + async def test_pause_on_low_consistency(self): + """Test pause when consistency is critically low.""" + monitor = _create_monitor() + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.3) + ) + + should_pause, reason = await monitor.should_pause_learning() + assert should_pause is True + assert "一致性" in reason + + @pytest.mark.asyncio + async def test_no_pause_above_threshold(self): + """Test no pause when metrics are acceptable.""" + monitor = _create_monitor() + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.8) + ) + + should_pause, reason = await monitor.should_pause_learning() + assert should_pause is False + + +@pytest.mark.unit +@pytest.mark.quality +class TestQualityReport: + """Test quality report generation.""" + + @pytest.mark.asyncio + async def test_report_no_history(self): + """Test report with no historical data.""" + monitor = _create_monitor() + + report = await monitor.get_quality_report() + assert "error" in report + + @pytest.mark.asyncio + async def test_report_with_single_metric(self): + """Test report with single historical metric.""" + monitor = _create_monitor() + monitor.historical_metrics.append( + PersonaMetrics( + consistency_score=0.8, + style_stability=0.7, + vocabulary_diversity=0.6, + emotional_balance=0.5, + coherence_score=0.9, + ) + ) + + report = await monitor.get_quality_report() + + assert "current_metrics" in report + assert report["current_metrics"]["consistency_score"] == 0.8 + assert "trends" in report + assert "recommendations" in report + + @pytest.mark.asyncio + async def test_report_with_trends(self): + """Test report includes trend data when sufficient history exists.""" + monitor = _create_monitor() + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.6, style_stability=0.5, vocabulary_diversity=0.4) + ) + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.8, style_stability=0.7, vocabulary_diversity=0.6) + ) + + report = await monitor.get_quality_report() + + assert report["trends"]["consistency_trend"] == pytest.approx(0.2) + assert report["trends"]["stability_trend"] == pytest.approx(0.2) + + +@pytest.mark.unit +@pytest.mark.quality +class TestDynamicThresholdAdjustment: + """Test dynamic threshold adjustment based on history.""" + + @pytest.mark.asyncio + async def test_no_adjustment_insufficient_history(self): + """Test no adjustment with less than 5 historical entries.""" + monitor = _create_monitor(consistency_threshold=0.5) + + for _ in range(3): + monitor.historical_metrics.append(PersonaMetrics(consistency_score=0.9)) + + await monitor.adjust_thresholds_based_on_history() + + # Should remain unchanged + assert monitor.consistency_threshold == 0.5 + + @pytest.mark.asyncio + async def test_threshold_increases_on_good_performance(self): + """Test threshold increases when performance is consistently good.""" + monitor = _create_monitor(consistency_threshold=0.5) + + for _ in range(5): + monitor.historical_metrics.append( + PersonaMetrics(consistency_score=0.85, style_stability=0.75) + ) + + await monitor.adjust_thresholds_based_on_history() + + assert monitor.consistency_threshold == 0.55 # Increased by 0.05 + + +@pytest.mark.unit +@pytest.mark.quality +class TestHelperMethods: + """Test helper methods.""" + + def test_punctuation_ratio(self): + """Test punctuation ratio calculation.""" + monitor = _create_monitor() + + assert monitor._get_punctuation_ratio("") == 0.0 + assert monitor._get_punctuation_ratio("hello") == 0.0 + assert monitor._get_punctuation_ratio("你好,世界!") > 0.0 + + def test_count_emoji(self): + """Test emoji counting.""" + monitor = _create_monitor() + + assert monitor._count_emoji("hello") == 0 + # The emoji patterns defined in the source are empty strings, + # so this tests the current behavior + assert isinstance(monitor._count_emoji("hello 😀"), int) + + def test_recommendations_low_consistency(self): + """Test recommendations for low consistency.""" + monitor = _create_monitor() + metrics = PersonaMetrics(consistency_score=0.5) + + recs = monitor._generate_recommendations(metrics, []) + assert any("一致性" in r for r in recs) + + def test_recommendations_good_quality(self): + """Test recommendations for good quality.""" + monitor = _create_monitor() + metrics = PersonaMetrics(consistency_score=0.9, style_stability=0.8) + + recs = monitor._generate_recommendations(metrics, []) + assert any("良好" in r for r in recs) + + @pytest.mark.asyncio + async def test_stop(self): + """Test service stop.""" + monitor = _create_monitor() + + result = await monitor.stop() + assert result is True diff --git a/tests/unit/test_patterns.py b/tests/unit/test_patterns.py new file mode 100644 index 0000000..3e76501 --- /dev/null +++ b/tests/unit/test_patterns.py @@ -0,0 +1,446 @@ +""" +Unit tests for core design patterns module + +Tests the design pattern implementations: +- AsyncServiceBase lifecycle management +- LearningContextBuilder (builder pattern) +- StrategyFactory (factory + strategy patterns) +- ServiceRegistry (singleton + registry pattern) +- ProgressiveLearningStrategy execution +- BatchLearningStrategy execution +""" +import asyncio +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from datetime import datetime + +from core.interfaces import ( + ServiceLifecycle, + LearningStrategyType, + MessageData, +) +from core.patterns import ( + AsyncServiceBase, + LearningContext, + LearningContextBuilder, + ProgressiveLearningStrategy, + BatchLearningStrategy, + StrategyFactory, + ServiceRegistry, + SingletonABCMeta, +) + + +@pytest.mark.unit +@pytest.mark.core +class TestAsyncServiceBase: + """Test AsyncServiceBase lifecycle management.""" + + def _create_service(self, name: str = "test_service") -> AsyncServiceBase: + """Helper to create a concrete service instance.""" + return AsyncServiceBase(name) + + def test_initial_status_is_created(self): + """Test service starts in CREATED state.""" + service = self._create_service() + assert service.status == ServiceLifecycle.CREATED + + @pytest.mark.asyncio + async def test_start_transitions_to_running(self): + """Test starting service transitions to RUNNING.""" + service = self._create_service() + + result = await service.start() + + assert result is True + assert service.status == ServiceLifecycle.RUNNING + + @pytest.mark.asyncio + async def test_start_already_running(self): + """Test starting already running service returns True.""" + service = self._create_service() + await service.start() + + result = await service.start() + + assert result is True + assert service.status == ServiceLifecycle.RUNNING + + @pytest.mark.asyncio + async def test_stop_transitions_to_stopped(self): + """Test stopping service transitions to STOPPED.""" + service = self._create_service() + await service.start() + + result = await service.stop() + + assert result is True + assert service.status == ServiceLifecycle.STOPPED + + @pytest.mark.asyncio + async def test_stop_already_stopped(self): + """Test stopping already stopped service returns True.""" + service = self._create_service() + await service.start() + await service.stop() + + result = await service.stop() + + assert result is True + assert service.status == ServiceLifecycle.STOPPED + + @pytest.mark.asyncio + async def test_restart(self): + """Test restarting service.""" + service = self._create_service() + await service.start() + + result = await service.restart() + + assert result is True + assert service.status == ServiceLifecycle.RUNNING + + @pytest.mark.asyncio + async def test_is_running(self): + """Test is_running check.""" + service = self._create_service() + + assert await service.is_running() is False + + await service.start() + assert await service.is_running() is True + + await service.stop() + assert await service.is_running() is False + + @pytest.mark.asyncio + async def test_health_check(self): + """Test health check reflects running state.""" + service = self._create_service() + + assert await service.health_check() is False + + await service.start() + assert await service.health_check() is True + + @pytest.mark.asyncio + async def test_start_failure_transitions_to_error(self): + """Test service transitions to ERROR on start failure.""" + service = self._create_service() + service._do_start = AsyncMock(side_effect=RuntimeError("init failed")) + + result = await service.start() + + assert result is False + assert service.status == ServiceLifecycle.ERROR + + @pytest.mark.asyncio + async def test_stop_failure_transitions_to_error(self): + """Test service transitions to ERROR on stop failure.""" + service = self._create_service() + await service.start() + service._do_stop = AsyncMock(side_effect=RuntimeError("cleanup failed")) + + result = await service.stop() + + assert result is False + assert service.status == ServiceLifecycle.ERROR + + +@pytest.mark.unit +@pytest.mark.core +class TestLearningContextBuilder: + """Test LearningContextBuilder (builder pattern).""" + + def test_build_default_context(self): + """Test building context with default values.""" + context = LearningContextBuilder().build() + + assert isinstance(context, LearningContext) + assert context.messages == [] + assert context.strategy_type == LearningStrategyType.PROGRESSIVE + assert context.quality_threshold == 0.7 + assert context.max_iterations == 3 + assert context.metadata == {} + + def test_builder_chain(self): + """Test fluent builder chaining.""" + messages = [ + MessageData( + sender_id="u1", sender_name="Alice", + message="Hello", group_id="g1", + timestamp=1.0, platform="qq", + ) + ] + + context = ( + LearningContextBuilder() + .with_messages(messages) + .with_strategy(LearningStrategyType.BATCH) + .with_quality_threshold(0.9) + .with_max_iterations(5) + .with_metadata("source", "test") + .build() + ) + + assert len(context.messages) == 1 + assert context.strategy_type == LearningStrategyType.BATCH + assert context.quality_threshold == 0.9 + assert context.max_iterations == 5 + assert context.metadata["source"] == "test" + + +@pytest.mark.unit +@pytest.mark.core +class TestStrategyFactory: + """Test StrategyFactory (factory + strategy patterns).""" + + def test_create_progressive_strategy(self): + """Test creating progressive learning strategy.""" + config = {"batch_size": 25, "min_messages": 10} + strategy = StrategyFactory.create_strategy( + LearningStrategyType.PROGRESSIVE, config + ) + + assert isinstance(strategy, ProgressiveLearningStrategy) + assert strategy.config == config + + def test_create_batch_strategy(self): + """Test creating batch learning strategy.""" + config = {"batch_size": 100} + strategy = StrategyFactory.create_strategy( + LearningStrategyType.BATCH, config + ) + + assert isinstance(strategy, BatchLearningStrategy) + + def test_create_unsupported_strategy_raises(self): + """Test creating unsupported strategy raises ValueError.""" + with pytest.raises(ValueError, match="不支持的策略类型"): + StrategyFactory.create_strategy( + LearningStrategyType.REALTIME, {} + ) + + def test_register_custom_strategy(self): + """Test registering a custom strategy type.""" + + class CustomStrategy: + def __init__(self, config): + self.config = config + + StrategyFactory.register_strategy( + LearningStrategyType.REALTIME, CustomStrategy + ) + strategy = StrategyFactory.create_strategy( + LearningStrategyType.REALTIME, {"custom": True} + ) + + assert isinstance(strategy, CustomStrategy) + + # Cleanup: remove custom strategy to avoid test pollution + del StrategyFactory._strategies[LearningStrategyType.REALTIME] + + +@pytest.mark.unit +@pytest.mark.core +class TestProgressiveLearningStrategy: + """Test ProgressiveLearningStrategy execution logic.""" + + def _make_messages(self, count: int): + """Helper to create test messages.""" + return [ + MessageData( + sender_id=f"u{i}", sender_name=f"User{i}", + message=f"Message {i}", group_id="g1", + timestamp=float(i), platform="qq", + ) + for i in range(count) + ] + + @pytest.mark.asyncio + async def test_execute_learning_cycle_success(self): + """Test progressive learning cycle executes successfully.""" + strategy = ProgressiveLearningStrategy({"batch_size": 10}) + messages = self._make_messages(25) + + result = await strategy.execute_learning_cycle(messages) + + assert result.success is True + assert result.confidence > 0 + assert result.data["total_processed"] == 25 + assert result.data["batch_count"] == 3 + + @pytest.mark.asyncio + async def test_execute_learning_cycle_empty_messages(self): + """Test progressive learning cycle with empty messages.""" + strategy = ProgressiveLearningStrategy({"batch_size": 10}) + + result = await strategy.execute_learning_cycle([]) + + assert result.success is True + assert result.data["total_processed"] == 0 + + @pytest.mark.asyncio + async def test_should_learn_sufficient_messages(self): + """Test should_learn returns True when conditions are met.""" + strategy = ProgressiveLearningStrategy({ + "min_messages": 5, + "min_interval_hours": 0, + }) + context = { + "message_count": 10, + "last_learning_time": 0, + } + + result = await strategy.should_learn(context) + assert result is True + + @pytest.mark.asyncio + async def test_should_learn_insufficient_messages(self): + """Test should_learn returns False with insufficient messages.""" + strategy = ProgressiveLearningStrategy({ + "min_messages": 20, + "min_interval_hours": 0, + }) + context = { + "message_count": 5, + "last_learning_time": 0, + } + + result = await strategy.should_learn(context) + assert result is False + + +@pytest.mark.unit +@pytest.mark.core +class TestBatchLearningStrategy: + """Test BatchLearningStrategy execution logic.""" + + def _make_messages(self, count: int): + """Helper to create test messages.""" + return [ + MessageData( + sender_id=f"u{i}", sender_name=f"User{i}", + message=f"Message {i}", group_id="g1", + timestamp=float(i), platform="qq", + ) + for i in range(count) + ] + + @pytest.mark.asyncio + async def test_execute_learning_cycle_success(self): + """Test batch learning cycle executes successfully.""" + strategy = BatchLearningStrategy({"batch_size": 100}) + messages = self._make_messages(50) + + result = await strategy.execute_learning_cycle(messages) + + assert result.success is True + assert result.confidence > 0 + assert result.data["processed_count"] == 50 + + @pytest.mark.asyncio + async def test_should_learn_above_threshold(self): + """Test should_learn returns True when batch threshold is met.""" + strategy = BatchLearningStrategy({"batch_size": 20}) + context = {"message_count": 25} + + result = await strategy.should_learn(context) + assert result is True + + @pytest.mark.asyncio + async def test_should_learn_below_threshold(self): + """Test should_learn returns False below batch threshold.""" + strategy = BatchLearningStrategy({"batch_size": 100}) + context = {"message_count": 50} + + result = await strategy.should_learn(context) + assert result is False + + +@pytest.mark.unit +@pytest.mark.core +class TestServiceRegistry: + """Test ServiceRegistry (singleton + registry pattern).""" + + def _create_fresh_registry(self) -> ServiceRegistry: + """Create a fresh registry by clearing singleton cache.""" + # Clear singleton instance to avoid test pollution + SingletonABCMeta._instances.pop(ServiceRegistry, None) + return ServiceRegistry(service_stop_timeout=2) + + def test_register_service(self): + """Test registering a service.""" + registry = self._create_fresh_registry() + service = AsyncServiceBase("test_svc") + + registry.register_service("test_svc", service) + + assert registry.get_service("test_svc") is service + + def test_get_nonexistent_service(self): + """Test getting a nonexistent service returns None.""" + registry = self._create_fresh_registry() + + assert registry.get_service("nonexistent") is None + + def test_unregister_service(self): + """Test unregistering a service.""" + registry = self._create_fresh_registry() + service = AsyncServiceBase("test_svc") + registry.register_service("test_svc", service) + + result = registry.unregister_service("test_svc") + + assert result is True + assert registry.get_service("test_svc") is None + + def test_unregister_nonexistent_service(self): + """Test unregistering a nonexistent service returns False.""" + registry = self._create_fresh_registry() + + result = registry.unregister_service("nonexistent") + + assert result is False + + @pytest.mark.asyncio + async def test_start_all_services(self): + """Test starting all registered services.""" + registry = self._create_fresh_registry() + svc1 = AsyncServiceBase("svc1") + svc2 = AsyncServiceBase("svc2") + registry.register_service("svc1", svc1) + registry.register_service("svc2", svc2) + + result = await registry.start_all_services() + + assert result is True + assert svc1.status == ServiceLifecycle.RUNNING + assert svc2.status == ServiceLifecycle.RUNNING + + @pytest.mark.asyncio + async def test_stop_all_services(self): + """Test stopping all registered services.""" + registry = self._create_fresh_registry() + svc1 = AsyncServiceBase("svc1") + svc2 = AsyncServiceBase("svc2") + registry.register_service("svc1", svc1) + registry.register_service("svc2", svc2) + await registry.start_all_services() + + result = await registry.stop_all_services() + + assert result is True + assert svc1.status == ServiceLifecycle.STOPPED + assert svc2.status == ServiceLifecycle.STOPPED + + def test_get_service_status(self): + """Test getting status of all registered services.""" + registry = self._create_fresh_registry() + svc = AsyncServiceBase("svc1") + registry.register_service("svc1", svc) + + status = registry.get_service_status() + + assert "svc1" in status + assert status["svc1"] == "created" diff --git a/tests/unit/test_security_utils.py b/tests/unit/test_security_utils.py new file mode 100644 index 0000000..bd4b10e --- /dev/null +++ b/tests/unit/test_security_utils.py @@ -0,0 +1,306 @@ +""" +Unit tests for security utilities module + +Tests the security infrastructure: +- PasswordHasher: hash generation, salt handling, verification +- LoginAttemptTracker: attempt recording, lockout, rate limiting +- SecurityValidator: password strength, input sanitization, token validation +- Password migration: plaintext to hashed format +""" +import time +import pytest +from unittest.mock import patch + +from utils.security_utils import ( + PasswordHasher, + LoginAttemptTracker, + SecurityValidator, + migrate_password_to_hashed, + verify_password_with_migration, +) + + +@pytest.mark.unit +@pytest.mark.security +class TestPasswordHasher: + """Test PasswordHasher functionality.""" + + def test_hash_password_returns_tuple(self): + """Test hash_password returns (hash, salt) tuple.""" + hashed, salt = PasswordHasher.hash_password("test_password") + + assert isinstance(hashed, str) + assert isinstance(salt, str) + assert len(hashed) == 32 # MD5 hex digest length + assert len(salt) == 32 # 16 bytes = 32 hex chars + + def test_hash_password_with_custom_salt(self): + """Test hash_password with a provided salt.""" + hashed, salt = PasswordHasher.hash_password("test_password", salt="fixed_salt") + + assert salt == "fixed_salt" + assert len(hashed) == 32 + + def test_same_password_same_salt_same_hash(self): + """Test deterministic hashing with same password and salt.""" + h1, _ = PasswordHasher.hash_password("password", salt="salt123") + h2, _ = PasswordHasher.hash_password("password", salt="salt123") + + assert h1 == h2 + + def test_same_password_different_salt_different_hash(self): + """Test different salts produce different hashes.""" + h1, _ = PasswordHasher.hash_password("password", salt="salt_a") + h2, _ = PasswordHasher.hash_password("password", salt="salt_b") + + assert h1 != h2 + + def test_different_password_same_salt_different_hash(self): + """Test different passwords produce different hashes.""" + h1, _ = PasswordHasher.hash_password("password1", salt="same_salt") + h2, _ = PasswordHasher.hash_password("password2", salt="same_salt") + + assert h1 != h2 + + def test_verify_correct_password(self): + """Test password verification with correct password.""" + hashed, salt = PasswordHasher.hash_password("correct_password") + + result = PasswordHasher.verify_password("correct_password", hashed, salt) + assert result is True + + def test_verify_incorrect_password(self): + """Test password verification with incorrect password.""" + hashed, salt = PasswordHasher.hash_password("correct_password") + + result = PasswordHasher.verify_password("wrong_password", hashed, salt) + assert result is False + + +@pytest.mark.unit +@pytest.mark.security +class TestLoginAttemptTracker: + """Test LoginAttemptTracker functionality.""" + + def test_initial_state_not_locked(self): + """Test new IP is not locked.""" + tracker = LoginAttemptTracker(max_attempts=3) + + locked, remaining = tracker.is_locked("192.168.1.1") + assert locked is False + assert remaining == 0 + + def test_full_remaining_attempts_for_new_ip(self): + """Test new IP has full remaining attempts.""" + tracker = LoginAttemptTracker(max_attempts=5) + + remaining = tracker.get_remaining_attempts("192.168.1.1") + assert remaining == 5 + + def test_failed_attempt_decreases_remaining(self): + """Test failed attempt decreases remaining count.""" + tracker = LoginAttemptTracker(max_attempts=5) + + tracker.record_attempt("192.168.1.1", success=False) + + remaining = tracker.get_remaining_attempts("192.168.1.1") + assert remaining == 4 + + def test_lockout_after_max_attempts(self): + """Test IP is locked after max failed attempts.""" + tracker = LoginAttemptTracker(max_attempts=3, lockout_duration=300) + + for _ in range(3): + tracker.record_attempt("192.168.1.1", success=False) + + locked, remaining_seconds = tracker.is_locked("192.168.1.1") + assert locked is True + assert remaining_seconds > 0 + + def test_successful_login_clears_attempts(self): + """Test successful login clears failed attempt history.""" + tracker = LoginAttemptTracker(max_attempts=5) + + tracker.record_attempt("192.168.1.1", success=False) + tracker.record_attempt("192.168.1.1", success=False) + tracker.record_attempt("192.168.1.1", success=True) + + remaining = tracker.get_remaining_attempts("192.168.1.1") + assert remaining == 5 + + def test_different_ips_independent(self): + """Test tracking is independent per IP.""" + tracker = LoginAttemptTracker(max_attempts=3) + + tracker.record_attempt("192.168.1.1", success=False) + tracker.record_attempt("192.168.1.1", success=False) + + remaining_ip1 = tracker.get_remaining_attempts("192.168.1.1") + remaining_ip2 = tracker.get_remaining_attempts("192.168.1.2") + + assert remaining_ip1 == 1 + assert remaining_ip2 == 3 + + def test_clear_ip_record(self): + """Test clearing a specific IP record.""" + tracker = LoginAttemptTracker(max_attempts=3) + + tracker.record_attempt("192.168.1.1", success=False) + tracker.clear_ip_record("192.168.1.1") + + remaining = tracker.get_remaining_attempts("192.168.1.1") + assert remaining == 3 + + def test_clear_all_records(self): + """Test clearing all IP records.""" + tracker = LoginAttemptTracker(max_attempts=3) + + tracker.record_attempt("192.168.1.1", success=False) + tracker.record_attempt("192.168.1.2", success=False) + tracker.clear_all_records() + + assert tracker.get_remaining_attempts("192.168.1.1") == 3 + assert tracker.get_remaining_attempts("192.168.1.2") == 3 + + +@pytest.mark.unit +@pytest.mark.security +class TestSecurityValidator: + """Test SecurityValidator functionality.""" + + def test_strong_password(self): + """Test strong password validation.""" + result = SecurityValidator.validate_password_strength("Str0ng!P@ss") + + assert result['valid'] is True + assert result['strength'] == 'strong' + assert result['checks']['length'] is True + assert result['checks']['lowercase'] is True + assert result['checks']['uppercase'] is True + assert result['checks']['numbers'] is True + assert result['checks']['symbols'] is True + + def test_weak_password_short(self): + """Test weak password: too short.""" + result = SecurityValidator.validate_password_strength("abc") + + assert result['valid'] is False + assert result['strength'] == 'weak' + assert result['checks']['length'] is False + + def test_medium_password(self): + """Test medium strength password.""" + result = SecurityValidator.validate_password_strength("Password1") + + assert result['valid'] is True + assert result['strength'] in ('medium', 'strong') + + def test_extra_long_password_bonus(self): + """Test extra long password gets bonus score.""" + result = SecurityValidator.validate_password_strength("ThisIsAVeryLongPassword123!") + + assert result['score'] > 80 + + def test_sanitize_input_strips_whitespace(self): + """Test input sanitization strips whitespace.""" + result = SecurityValidator.sanitize_input(" hello world ") + assert result == "hello world" + + def test_sanitize_input_truncates(self): + """Test input sanitization truncates to max_length.""" + result = SecurityValidator.sanitize_input("a" * 300, max_length=10) + assert len(result) == 10 + + def test_sanitize_empty_input(self): + """Test sanitizing empty input.""" + assert SecurityValidator.sanitize_input("") == "" + assert SecurityValidator.sanitize_input(None) == "" + + def test_valid_session_token(self): + """Test valid session token validation.""" + token = "a" * 32 # 32-char hex string + assert SecurityValidator.is_valid_session_token(token) is True + + def test_invalid_session_token_too_short(self): + """Test invalid session token: too short.""" + assert SecurityValidator.is_valid_session_token("abc123") is False + + def test_invalid_session_token_non_hex(self): + """Test invalid session token: non-hex characters.""" + assert SecurityValidator.is_valid_session_token("z" * 32) is False + + def test_empty_session_token(self): + """Test empty session token is invalid.""" + assert SecurityValidator.is_valid_session_token("") is False + assert SecurityValidator.is_valid_session_token(None) is False + + +@pytest.mark.unit +@pytest.mark.security +class TestPasswordMigration: + """Test password migration from plaintext to hashed format.""" + + def test_migrate_plaintext_password(self): + """Test migrating plaintext password to hashed format.""" + old_config = {'password': 'my_password', 'must_change': False} + + new_config = migrate_password_to_hashed(old_config) + + assert 'password_hash' in new_config + assert 'salt' in new_config + assert new_config['version'] == 2 + assert new_config['migrated_from_plaintext'] is True + + def test_already_hashed_not_migrated(self): + """Test already hashed config is not re-migrated.""" + hashed_config = { + 'password_hash': 'existing_hash', + 'salt': 'existing_salt', + 'version': 2, + } + + result = migrate_password_to_hashed(hashed_config) + + assert result is hashed_config # Same object + + def test_verify_with_old_format(self): + """Test verification with old plaintext format triggers migration.""" + old_config = {'password': 'test_pass'} + + is_valid, new_config = verify_password_with_migration('test_pass', old_config) + + assert is_valid is True + assert 'password_hash' in new_config + + def test_verify_with_old_format_wrong_password(self): + """Test verification with wrong password in old format.""" + old_config = {'password': 'correct_pass'} + + is_valid, config = verify_password_with_migration('wrong_pass', old_config) + + assert is_valid is False + + def test_verify_with_new_format(self): + """Test verification with new hashed format.""" + hashed, salt = PasswordHasher.hash_password("secure_pwd") + new_config = {'password_hash': hashed, 'salt': salt} + + is_valid, config = verify_password_with_migration('secure_pwd', new_config) + + assert is_valid is True + + def test_verify_with_new_format_wrong_password(self): + """Test verification with wrong password in new format.""" + hashed, salt = PasswordHasher.hash_password("correct_pwd") + new_config = {'password_hash': hashed, 'salt': salt} + + is_valid, config = verify_password_with_migration('wrong_pwd', new_config) + + assert is_valid is False + + def test_verify_with_missing_hash_or_salt(self): + """Test verification fails gracefully when hash or salt is missing.""" + config = {'password_hash': '', 'salt': ''} + + is_valid, _ = verify_password_with_migration('any_pwd', config) + assert is_valid is False diff --git a/tests/unit/test_tiered_learning_trigger.py b/tests/unit/test_tiered_learning_trigger.py new file mode 100644 index 0000000..7fd600c --- /dev/null +++ b/tests/unit/test_tiered_learning_trigger.py @@ -0,0 +1,441 @@ +""" +Unit tests for TieredLearningTrigger + +Tests the two-tier learning trigger mechanism: +- Tier 1 registration and execution (per-message, concurrent) +- Tier 2 registration and gated execution (batch, cooldown/threshold) +- Error isolation between Tier 1 operations +- BatchTriggerPolicy configuration +- force_tier2 fast-path triggering +- Per-group state tracking and statistics +""" +import asyncio +import time +import pytest +from unittest.mock import AsyncMock, patch + +from core.interfaces import MessageData +from services.quality.tiered_learning_trigger import ( + TieredLearningTrigger, + BatchTriggerPolicy, + TriggerResult, + _GroupTriggerState, +) + + +def _make_message(text: str = "test message", group_id: str = "g1") -> MessageData: + """Helper to create a test MessageData instance.""" + return MessageData( + sender_id="user1", + sender_name="TestUser", + message=text, + group_id=group_id, + timestamp=time.time(), + platform="test", + ) + + +@pytest.mark.unit +@pytest.mark.quality +class TestTieredLearningTriggerRegistration: + """Test callback registration for both tiers.""" + + def test_register_tier1_success(self): + """Test registering a valid Tier 1 async callback.""" + trigger = TieredLearningTrigger() + + async def callback(msg, gid): + pass + + trigger.register_tier1("test_op", callback) + assert "test_op" in trigger._tier1_ops + + def test_register_tier1_none_callback_ignored(self): + """Test registering None callback is silently ignored.""" + trigger = TieredLearningTrigger() + trigger.register_tier1("test_op", None) + assert "test_op" not in trigger._tier1_ops + + def test_register_tier1_sync_callback_raises(self): + """Test registering a sync callback raises TypeError.""" + trigger = TieredLearningTrigger() + + def sync_callback(msg, gid): + pass + + with pytest.raises(TypeError, match="must be an async function"): + trigger.register_tier1("bad_op", sync_callback) + + def test_register_tier2_success(self): + """Test registering a valid Tier 2 async callback.""" + trigger = TieredLearningTrigger() + + async def callback(gid): + pass + + trigger.register_tier2("batch_op", callback) + assert "batch_op" in trigger._tier2_ops + + def test_register_tier2_with_custom_policy(self): + """Test registering Tier 2 with custom policy.""" + trigger = TieredLearningTrigger() + policy = BatchTriggerPolicy(message_threshold=50, cooldown_seconds=300.0) + + async def callback(gid): + pass + + trigger.register_tier2("batch_op", callback, policy=policy) + + stored_callback, stored_policy = trigger._tier2_ops["batch_op"] + assert stored_policy.message_threshold == 50 + assert stored_policy.cooldown_seconds == 300.0 + + def test_register_tier2_default_policy(self): + """Test registering Tier 2 uses default policy when none specified.""" + trigger = TieredLearningTrigger() + + async def callback(gid): + pass + + trigger.register_tier2("batch_op", callback) + + _, stored_policy = trigger._tier2_ops["batch_op"] + assert stored_policy.message_threshold == 15 + assert stored_policy.cooldown_seconds == 120.0 + + def test_register_tier2_sync_callback_raises(self): + """Test registering a sync Tier 2 callback raises TypeError.""" + trigger = TieredLearningTrigger() + + def sync_callback(gid): + pass + + with pytest.raises(TypeError, match="must be an async function"): + trigger.register_tier2("bad_op", sync_callback) + + +@pytest.mark.unit +@pytest.mark.quality +class TestTier1Execution: + """Test Tier 1 per-message execution.""" + + @pytest.mark.asyncio + async def test_tier1_executes_for_every_message(self): + """Test Tier 1 callbacks execute for every incoming message.""" + trigger = TieredLearningTrigger() + call_count = 0 + + async def tier1_callback(msg, gid): + nonlocal call_count + call_count += 1 + + trigger.register_tier1("counter", tier1_callback) + + for _ in range(5): + await trigger.process_message(_make_message(), "g1") + + assert call_count == 5 + + @pytest.mark.asyncio + async def test_tier1_concurrent_execution(self): + """Test multiple Tier 1 callbacks run concurrently.""" + trigger = TieredLearningTrigger() + execution_order = [] + + async def op_a(msg, gid): + execution_order.append("a") + + async def op_b(msg, gid): + execution_order.append("b") + + trigger.register_tier1("op_a", op_a) + trigger.register_tier1("op_b", op_b) + + await trigger.process_message(_make_message(), "g1") + + assert "a" in execution_order + assert "b" in execution_order + + @pytest.mark.asyncio + async def test_tier1_error_isolation(self): + """Test one Tier 1 failure does not affect others.""" + trigger = TieredLearningTrigger() + healthy_called = False + + async def failing_op(msg, gid): + raise RuntimeError("Tier 1 failure") + + async def healthy_op(msg, gid): + nonlocal healthy_called + healthy_called = True + + trigger.register_tier1("failing", failing_op) + trigger.register_tier1("healthy", healthy_op) + + result = await trigger.process_message(_make_message(), "g1") + + assert healthy_called is True + assert result.tier1_details["failing"] is False + assert result.tier1_details["healthy"] is True + assert result.tier1_ok is False + + @pytest.mark.asyncio + async def test_tier1_all_success(self): + """Test tier1_ok is True when all operations succeed.""" + trigger = TieredLearningTrigger() + + async def ok_op(msg, gid): + pass + + trigger.register_tier1("op1", ok_op) + trigger.register_tier1("op2", ok_op) + + result = await trigger.process_message(_make_message(), "g1") + + assert result.tier1_ok is True + assert all(result.tier1_details.values()) + + +@pytest.mark.unit +@pytest.mark.quality +class TestTier2Execution: + """Test Tier 2 batch execution with gating.""" + + @pytest.mark.asyncio + async def test_tier2_triggers_on_message_threshold(self): + """Test Tier 2 fires when message count reaches threshold.""" + trigger = TieredLearningTrigger() + tier2_called = False + + async def tier1_noop(msg, gid): + pass + + async def tier2_callback(gid): + nonlocal tier2_called + tier2_called = True + + trigger.register_tier1("noop", tier1_noop) + trigger.register_tier2( + "batch", tier2_callback, + policy=BatchTriggerPolicy(message_threshold=3, cooldown_seconds=9999), + ) + + for _ in range(3): + result = await trigger.process_message(_make_message(), "g1") + + assert tier2_called is True + + @pytest.mark.asyncio + async def test_tier2_does_not_trigger_below_threshold(self): + """Test Tier 2 does not fire below message threshold.""" + trigger = TieredLearningTrigger() + tier2_called = False + + async def tier1_noop(msg, gid): + pass + + async def tier2_callback(gid): + nonlocal tier2_called + tier2_called = True + + trigger.register_tier1("noop", tier1_noop) + trigger.register_tier2( + "batch", tier2_callback, + policy=BatchTriggerPolicy(message_threshold=100, cooldown_seconds=9999), + ) + + # Process only 2 messages + for _ in range(2): + await trigger.process_message(_make_message(), "g1") + + assert tier2_called is False + + @pytest.mark.asyncio + async def test_tier2_triggers_on_cooldown_expiry(self): + """Test Tier 2 fires when cooldown expires even below threshold.""" + trigger = TieredLearningTrigger() + tier2_called = False + + async def tier1_noop(msg, gid): + pass + + async def tier2_callback(gid): + nonlocal tier2_called + tier2_called = True + + trigger.register_tier1("noop", tier1_noop) + trigger.register_tier2( + "batch", tier2_callback, + policy=BatchTriggerPolicy(message_threshold=9999, cooldown_seconds=0.0), + ) + + # First message initializes state; cooldown=0 means always ready + # But _get_state initializes last_op_times to now, so the first + # process_message won't trigger. We need to manually adjust. + state = trigger._get_state("g1") + state.last_op_times["batch"] = 0.0 # Long ago + + result = await trigger.process_message(_make_message(), "g1") + + assert tier2_called is True + assert result.tier2_triggered is True + + @pytest.mark.asyncio + async def test_tier2_resets_counter_after_trigger(self): + """Test message counter resets after Tier 2 trigger.""" + trigger = TieredLearningTrigger() + + async def tier1_noop(msg, gid): + pass + + async def tier2_callback(gid): + pass + + trigger.register_tier1("noop", tier1_noop) + trigger.register_tier2( + "batch", tier2_callback, + policy=BatchTriggerPolicy(message_threshold=2, cooldown_seconds=9999), + ) + + await trigger.process_message(_make_message(), "g1") + await trigger.process_message(_make_message(), "g1") + + state = trigger._states["g1"] + assert state.message_count == 0 # Reset after trigger + + @pytest.mark.asyncio + async def test_tier2_error_handling(self): + """Test Tier 2 failure is captured in result.""" + trigger = TieredLearningTrigger() + + async def tier1_noop(msg, gid): + pass + + async def failing_tier2(gid): + raise RuntimeError("Batch failure") + + trigger.register_tier1("noop", tier1_noop) + trigger.register_tier2( + "batch", failing_tier2, + policy=BatchTriggerPolicy(message_threshold=1, cooldown_seconds=9999), + ) + + result = await trigger.process_message(_make_message(), "g1") + + assert result.tier2_triggered is True + assert result.tier2_details["batch"] is False + + +@pytest.mark.unit +@pytest.mark.quality +class TestForceTier2: + """Test force_tier2 fast-path triggering.""" + + @pytest.mark.asyncio + async def test_force_tier2_success(self): + """Test force triggering a registered Tier 2 operation.""" + trigger = TieredLearningTrigger() + called_with_group = None + + async def tier2_callback(gid): + nonlocal called_with_group + called_with_group = gid + + trigger.register_tier2("batch", tier2_callback) + + result = await trigger.force_tier2("batch", "g1") + + assert result is True + assert called_with_group == "g1" + + @pytest.mark.asyncio + async def test_force_tier2_unregistered_operation(self): + """Test forcing an unregistered operation returns False.""" + trigger = TieredLearningTrigger() + + result = await trigger.force_tier2("nonexistent", "g1") + assert result is False + + @pytest.mark.asyncio + async def test_force_tier2_failure(self): + """Test force_tier2 returns False on callback failure.""" + trigger = TieredLearningTrigger() + + async def failing_callback(gid): + raise RuntimeError("force failure") + + trigger.register_tier2("batch", failing_callback) + + result = await trigger.force_tier2("batch", "g1") + assert result is False + + +@pytest.mark.unit +@pytest.mark.quality +class TestGroupStats: + """Test per-group statistics.""" + + def test_stats_for_unknown_group(self): + """Test stats for unknown group returns inactive.""" + trigger = TieredLearningTrigger() + + stats = trigger.get_group_stats("unknown_group") + assert stats == {"active": False} + + @pytest.mark.asyncio + async def test_stats_after_processing(self): + """Test stats reflect processing state.""" + trigger = TieredLearningTrigger() + + async def tier1_noop(msg, gid): + pass + + trigger.register_tier1("noop", tier1_noop) + + await trigger.process_message(_make_message(), "g1") + await trigger.process_message(_make_message(), "g1") + + stats = trigger.get_group_stats("g1") + assert stats["active"] is True + assert stats["total_processed"] == 2 + assert stats["message_count"] == 2 # No Tier 2 to reset + + @pytest.mark.asyncio + async def test_stats_consecutive_tier1_errors(self): + """Test consecutive Tier 1 error tracking.""" + trigger = TieredLearningTrigger() + + async def failing_op(msg, gid): + raise RuntimeError("fail") + + trigger.register_tier1("failing", failing_op) + + await trigger.process_message(_make_message(), "g1") + await trigger.process_message(_make_message(), "g1") + + stats = trigger.get_group_stats("g1") + assert stats["consecutive_tier1_errors"] == 2 + + +@pytest.mark.unit +@pytest.mark.quality +class TestBatchTriggerPolicy: + """Test BatchTriggerPolicy dataclass.""" + + def test_default_values(self): + """Test default policy values.""" + policy = BatchTriggerPolicy() + assert policy.message_threshold == 15 + assert policy.cooldown_seconds == 120.0 + + def test_custom_values(self): + """Test custom policy values.""" + policy = BatchTriggerPolicy(message_threshold=50, cooldown_seconds=600.0) + assert policy.message_threshold == 50 + assert policy.cooldown_seconds == 600.0 + + def test_policy_is_frozen(self): + """Test policy dataclass is immutable.""" + policy = BatchTriggerPolicy() + with pytest.raises(AttributeError): + policy.message_threshold = 99 diff --git a/web_res/static/js/macos/apps/Dashboard.js b/web_res/static/js/macos/apps/Dashboard.js index f6a59e5..0e26727 100644 --- a/web_res/static/js/macos/apps/Dashboard.js +++ b/web_res/static/js/macos/apps/Dashboard.js @@ -555,14 +555,15 @@ window.AppDashboard = { llmModels.length) * 100 : 0; - // 响应速度 + // 响应速度(无 LLM 数据时不显示为 0,而是显示为中性值) var avgResponseTime = llmModels.length > 0 ? llmModels.reduce(function (s, m) { return s + (m.avg_response_time_ms || 0); }, 0) / llmModels.length - : 2000; - var responseSpeed = Math.max(0, 100 - avgResponseTime / 20); + : 0; + var responseSpeed = + llmModels.length > 0 ? Math.max(0, 100 - avgResponseTime / 20) : 50; // 系统稳定性 var sm = stats.system_metrics || {}; var systemStability = diff --git a/web_res/static/js/script.js b/web_res/static/js/script.js index 99d808b..81c7b25 100644 --- a/web_res/static/js/script.js +++ b/web_res/static/js/script.js @@ -938,8 +938,9 @@ function initializeSystemStatusRadar() { (sum, model) => sum + (model.avg_response_time_ms || 0), 0, ) / llmModels.length - : 2000; - const responseSpeed = Math.max(0, 100 - avgResponseTime / 20); // 2000ms = 0分,0ms = 100分 + : 0; + const responseSpeed = + llmModels.length > 0 ? Math.max(0, 100 - avgResponseTime / 20) : 50; // 系统稳定性 (基于CPU和内存使用率) const systemMetrics = stats.system_metrics || {}; diff --git a/webui/blueprints/persona_reviews.py b/webui/blueprints/persona_reviews.py index 7d07261..70f4570 100644 --- a/webui/blueprints/persona_reviews.py +++ b/webui/blueprints/persona_reviews.py @@ -37,6 +37,8 @@ async def review_persona_update(update_id: str): """审查人格更新内容 (批准/拒绝)""" try: data = await request.get_json() + if not data: + return jsonify({"error": "Request body is required"}), 400 action = data.get("action") comment = data.get("comment", "") modified_content = data.get("modified_content") @@ -84,7 +86,7 @@ async def get_reviewed_persona_updates(): async def revert_persona_update(update_id: str): """撤回人格更新审查""" try: - data = await request.get_json() + data = await request.get_json() or {} reason = data.get("reason", "撤回审查决定") container = get_container() @@ -128,6 +130,8 @@ async def batch_delete_persona_updates(): """批量删除人格更新审查记录""" try: data = await request.get_json() + if not data: + return jsonify({"error": "Request body is required"}), 400 update_ids = data.get('update_ids', []) if not update_ids or not isinstance(update_ids, list): @@ -153,6 +157,8 @@ async def batch_review_persona_updates(): """批量审查人格更新记录""" try: data = await request.get_json() + if not data: + return jsonify({"error": "Request body is required"}), 400 update_ids = data.get('update_ids', []) action = data.get('action') comment = data.get('comment', '') diff --git a/webui/manager.py b/webui/manager.py index f0bd910..5b78bca 100644 --- a/webui/manager.py +++ b/webui/manager.py @@ -1,6 +1,5 @@ """WebUI 服务器全生命周期管理 — 创建、启动、停止、服务注册""" import asyncio -import gc import sys from typing import Optional, Any, Dict, TYPE_CHECKING @@ -181,7 +180,6 @@ async def stop(self) -> None: try: logger.info(f"正在停止 Web 服务器 (端口: {_server_instance.port})...") await _server_instance.stop() - gc.collect() if sys.platform == "win32": logger.info("Windows 环境:等待端口资源释放...") diff --git a/webui/server.py b/webui/server.py index cd0b107..ccfc747 100644 --- a/webui/server.py +++ b/webui/server.py @@ -4,7 +4,6 @@ """ import os import sys -import gc import asyncio import socket import threading @@ -192,7 +191,6 @@ async def stop(self): Server._instance = None self._initialized = False - gc.collect() logger.info("[WebUI] 服务器已停止") except Exception as e: diff --git a/webui/services/learning_service.py b/webui/services/learning_service.py index ffbb697..7a53c48 100644 --- a/webui/services/learning_service.py +++ b/webui/services/learning_service.py @@ -40,9 +40,13 @@ async def get_style_learning_results(self) -> Dict[str, Any]: if self.db_manager: try: - # 优先使用ORM Repository获取统计数据 - if hasattr(self.db_manager, 'get_session'): - # 使用ORM方式获取统计 + # 优先使用 Facade 方法获取统计数据 + if hasattr(self.db_manager, 'get_style_learning_statistics'): + real_stats = await self.db_manager.get_style_learning_statistics() + if real_stats: + results_data['statistics'].update(real_stats) + elif hasattr(self.db_manager, 'get_session'): + # 降级到 Repository 方式 from ...repositories.learning_repository import StyleLearningReviewRepository async with self.db_manager.get_session() as session: @@ -52,11 +56,6 @@ async def get_style_learning_results(self) -> Dict[str, Any]: results_data['statistics'].update(real_stats) logger.debug(f"使用ORM获取风格学习统计: {real_stats}") - else: - # 降级到传统数据库方法 - real_stats = await self.db_manager.get_style_learning_statistics() - if real_stats: - results_data['statistics'].update(real_stats) # 获取进度数据(保持原有逻辑) real_progress = await self.db_manager.get_style_progress_data() diff --git a/webui/services/persona_review_service.py b/webui/services/persona_review_service.py index 8e6dcd7..89e7ee7 100644 --- a/webui/services/persona_review_service.py +++ b/webui/services/persona_review_service.py @@ -596,25 +596,53 @@ async def get_reviewed_persona_updates( # 从传统人格更新审查获取 if self.persona_updater: - traditional_updates = await self.persona_updater.get_reviewed_persona_updates(limit, offset, status_filter) - reviewed_updates.extend(traditional_updates) + try: + traditional_updates = await self.persona_updater.get_reviewed_persona_updates(limit, offset, status_filter) + if traditional_updates: + reviewed_updates.extend(traditional_updates) + except Exception as e: + logger.warning(f"获取传统已审查人格更新失败: {e}") # 从人格学习审查获取 if self.database_manager: - persona_learning_updates = await self.database_manager.get_reviewed_persona_learning_updates(limit, offset, status_filter) - reviewed_updates.extend(persona_learning_updates) + try: + persona_learning_updates = await self.database_manager.get_reviewed_persona_learning_updates(limit, offset, status_filter) + if persona_learning_updates: + # 为人格学习记录添加前缀 ID + for update in persona_learning_updates: + if update.get('id') is not None: + update['id'] = f"persona_learning_{update['id']}" + reviewed_updates.extend(persona_learning_updates) + except Exception as e: + logger.warning(f"获取已审查人格学习更新失败: {e}") # 从风格学习审查获取 if self.database_manager: - style_updates = await self.database_manager.get_reviewed_style_learning_updates(limit, offset, status_filter) - # 将风格审查转换为统一格式 - for update in style_updates: - if 'id' in update: - update['id'] = f"style_{update['id']}" - reviewed_updates.extend(style_updates) + try: + style_updates = await self.database_manager.get_reviewed_style_learning_updates(limit, offset, status_filter) + if style_updates: + for update in style_updates: + # 转换风格学习字段为前端统一格式 + update['id'] = f"style_{update['id']}" if update.get('id') is not None else None + update['update_type'] = update.get('type', UPDATE_TYPE_STYLE_LEARNING) + update['original_content'] = update.get('original_content', '') + update['new_content'] = update.get('few_shots_content', '') + update['proposed_content'] = update.get('few_shots_content', '') + update['confidence_score'] = update.get('confidence_score', 0.9) + update['reason'] = update.get('description', '') + update['review_source'] = 'style_learning' + reviewed_updates.extend(style_updates) + except Exception as e: + logger.warning(f"获取已审查风格学习更新失败: {e}") + + # 过滤掉无效记录(id 为 None 或空的条目) + reviewed_updates = [ + u for u in reviewed_updates + if u and u.get('id') is not None + ] # 按审查时间排序 - reviewed_updates.sort(key=lambda x: x.get('review_time', 0), reverse=True) + reviewed_updates.sort(key=lambda x: x.get('review_time') or 0, reverse=True) return { "success": True,