diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c52cb5..b052273 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,50 @@
所有重要更改都将记录在此文件中。
+## [Next-2.0.6] - 2026-03-02
+
+### 新功能
+
+#### WebUI 监听地址可配置
+- 新增 `web_interface_host` 配置项,允许用户自定义 WebUI 监听地址(默认 `0.0.0.0`)
+
+### Bug 修复
+
+#### SQLite 并发访问错误
+- 修复 SQLite 连接池使用 `StaticPool`(共享单连接)导致 WebUI 并发请求事务状态污染的问题
+- 改为 `NullPool`,每个会话获取独立连接,消除 "Cannot operate on a closed database" 错误
+
+#### 插件卸载 CPU 100%
+- 移除 WebUI 关停流程中 `server.py` 和 `manager.py` 的两次 `gc.collect()` 调用
+- 每次 `gc.collect()` 遍历 ~200 个模块的对象图耗时 80+ 秒,导致卸载期间 CPU 满载
+
+#### 命令处理器空指针
+- 为全部 6 个管理命令(`learning_status`、`start_learning`、`stop_learning`、`force_learning`、`affection_status`、`set_mood`)添加空值守卫
+- 当 `bootstrap()` 失败导致 `_command_handlers` 为 `None` 时,返回友好提示而非抛出 `'NoneType' object has no attribute` 异常
+
+#### 人格审查系统
+- 修复撤回操作崩溃和已审查列表数据缺失问题
+- 修复风格学习审查记录在已审查历史中显示空内容、类型"未知"、置信度 0.0% 的问题,补全 `StyleLearningReview` 到前端统一格式的字段映射
+- WebUI 风格统计查询改用 Facade 而非直接 Repository 调用
+
+#### MySQL 8 连接
+- 禁用 MySQL 8 默认 SSL 要求,解决 `ssl.SSLError` 连接失败
+- 强化会话生命周期管理
+
+#### ORM 字段映射
+- 修正心理状态和情绪持久化的 ORM 字段映射
+- 使用防御性 `getattr` 处理 ORM-to-dataclass 组件映射中的缺失属性
+
+#### 其他修复
+- WebUI 使用全局默认人格代替随机 UMO
+- WebUI 响应速度指标无 LLM 数据时使用中性回退值
+- 黑话 meaning 字段 dict/list 类型序列化为 JSON 字符串后写入数据库
+- 批量学习路径中正确保存筛选后的消息到数据库
+- 防护 `background_tasks` 在关停序列中的访问安全
+
+### 测试
+- 新增核心模块单元测试,扩展覆盖率配置
+
## [Next-2.0.5] - 2026-02-24
### Bug 修复
diff --git a/README.md b/README.md
index 5464e80..6af5a80 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
-[](https://github.com/NickCharlie/astrbot_plugin_self_learning) [](LICENSE) [](https://github.com/Soulter/AstrBot) [](https://www.python.org/)
+[](https://github.com/NickCharlie/astrbot_plugin_self_learning) [](LICENSE) [](https://github.com/Soulter/AstrBot) [](https://www.python.org/)
[核心功能](#-我们能做什么) · [快速开始](#-快速开始) · [管理界面](#-可视化管理界面) · [社区交流](#-社区交流) · [贡献指南](CONTRIBUTING.md)
@@ -229,10 +229,18 @@ http://localhost:7833
---
+---
+
**如果觉得有帮助,欢迎 Star 支持!**
+### 赞助支持
+
+如果这个项目对你有帮助,欢迎通过爱发电赞助支持开发者持续维护:
+
+

+
[回到顶部](#astrbot-自主学习插件)
diff --git a/README_EN.md b/README_EN.md
index 250354c..1a993ec 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -14,7 +14,7 @@
-[](https://github.com/NickCharlie/astrbot_plugin_self_learning) [](LICENSE) [](https://github.com/Soulter/AstrBot) [](https://www.python.org/)
+[](https://github.com/NickCharlie/astrbot_plugin_self_learning) [](LICENSE) [](https://github.com/Soulter/AstrBot) [](https://www.python.org/)
[Features](#what-we-can-do) · [Quick Start](#quick-start) · [Web UI](#visual-management-interface) · [Community](#community) · [Contributing](CONTRIBUTING.md)
diff --git a/core/database/engine.py b/core/database/engine.py
index 0208aec..a2a99f2 100644
--- a/core/database/engine.py
+++ b/core/database/engine.py
@@ -8,7 +8,7 @@
避免 "Task got Future attached to a different loop" 错误。
"""
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
-from sqlalchemy.pool import StaticPool
+from sqlalchemy.pool import NullPool
from sqlalchemy import types as sa_types
from astrbot.api import logger
from typing import Optional
@@ -90,11 +90,13 @@ def _create_sqlite_engine(self):
logger.info(f"[DatabaseEngine] 创建数据库目录: {db_dir}")
# SQLite 配置
- # StaticPool reuses a single connection, avoiding per-query overhead
+ # NullPool: 每个 session 独立创建/关闭连接,避免 StaticPool
+ # 单连接共享导致的并发事务状态污染和 "closed database" 错误。
+ # SQLite 建连成本极低(打开文件句柄),配合 WAL 模式可安全并发读。
engine = create_async_engine(
db_url,
echo=self.echo,
- poolclass=StaticPool,
+ poolclass=NullPool,
connect_args={
'check_same_thread': False,
'timeout': 30,
@@ -138,6 +140,7 @@ def _create_mysql_engine(self):
connect_args={
'connect_timeout': 10,
'charset': 'utf8mb4',
+ 'ssl': False,
}
)
diff --git a/core/framework_llm_adapter.py b/core/framework_llm_adapter.py
index d5ffe2a..c5f8df7 100644
--- a/core/framework_llm_adapter.py
+++ b/core/framework_llm_adapter.py
@@ -243,6 +243,10 @@ async def filter_chat_completion(
# 尝试延迟初始化
self._try_lazy_init()
+ # 确保 contexts 不为 None,避免 Provider 内部调用 len(None)
+ if contexts is None:
+ contexts = []
+
if not self.filter_provider:
logger.warning("筛选Provider未配置,尝试使用备选Provider或降级处理")
# 尝试使用其他可用的Provider作为备选
@@ -301,6 +305,10 @@ async def refine_chat_completion(
# 尝试延迟初始化
self._try_lazy_init()
+ # 确保 contexts 不为 None,避免 Provider 内部调用 len(None)
+ if contexts is None:
+ contexts = []
+
if not self.refine_provider:
logger.warning("提炼Provider未配置,尝试使用备选Provider或降级处理")
# 尝试使用其他可用的Provider作为备选
@@ -359,6 +367,10 @@ async def reinforce_chat_completion(
# 尝试延迟初始化
self._try_lazy_init()
+ # 确保 contexts 不为 None,避免 Provider 内部调用 len(None)
+ if contexts is None:
+ contexts = []
+
if not self.reinforce_provider:
logger.warning("强化Provider未配置,尝试使用备选Provider或降级处理")
# 尝试使用其他可用的Provider作为备选
diff --git a/image/afdian-NickMo.jpeg b/image/afdian-NickMo.jpeg
new file mode 100644
index 0000000..3152387
Binary files /dev/null and b/image/afdian-NickMo.jpeg differ
diff --git a/main.py b/main.py
index de11aa2..3fe3071 100644
--- a/main.py
+++ b/main.py
@@ -231,6 +231,9 @@ async def inject_diversity_to_llm_request(self, event: AstrMessageEvent, req=Non
@filter.permission_type(PermissionType.ADMIN)
async def learning_status_command(self, event: AstrMessageEvent):
"""查看学习状态"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.learning_status(event):
yield result
@@ -238,6 +241,9 @@ async def learning_status_command(self, event: AstrMessageEvent):
@filter.permission_type(PermissionType.ADMIN)
async def start_learning_command(self, event: AstrMessageEvent):
"""手动启动学习"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.start_learning(event):
yield result
@@ -245,6 +251,9 @@ async def start_learning_command(self, event: AstrMessageEvent):
@filter.permission_type(PermissionType.ADMIN)
async def stop_learning_command(self, event: AstrMessageEvent):
"""停止学习"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.stop_learning(event):
yield result
@@ -252,6 +261,9 @@ async def stop_learning_command(self, event: AstrMessageEvent):
@filter.permission_type(PermissionType.ADMIN)
async def force_learning_command(self, event: AstrMessageEvent):
"""强制执行一次学习周期"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.force_learning(event):
yield result
@@ -259,6 +271,9 @@ async def force_learning_command(self, event: AstrMessageEvent):
@filter.permission_type(PermissionType.ADMIN)
async def affection_status_command(self, event: AstrMessageEvent):
"""查看好感度状态"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.affection_status(event):
yield result
@@ -266,5 +281,8 @@ async def affection_status_command(self, event: AstrMessageEvent):
@filter.permission_type(PermissionType.ADMIN)
async def set_mood_command(self, event: AstrMessageEvent):
"""手动设置bot情绪"""
+ if not self._command_handlers:
+ yield event.plain_result("插件服务未就绪,请检查启动日志")
+ return
async for result in self._command_handlers.set_mood(event):
yield result
diff --git a/metadata.yaml b/metadata.yaml
index 5634dc3..9070d2d 100644
--- a/metadata.yaml
+++ b/metadata.yaml
@@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning"
author: "NickMo"
display_name: "self-learning"
description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)"
-version: "Next-2.0.5"
+version: "Next-2.0.6"
repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning"
tags:
- "自学习"
diff --git a/persona_web_manager.py b/persona_web_manager.py
index 010ccd3..3930853 100644
--- a/persona_web_manager.py
+++ b/persona_web_manager.py
@@ -143,8 +143,9 @@ async def get_all_personas_for_web(self) -> List[Dict[str, Any]]:
async def get_default_persona_for_web(self) -> Dict[str, Any]:
"""获取默认人格,格式化为Web界面需要的格式
- 使用 group_id_to_unified_origin 映射中的 UMO 来获取当前活跃配置的人格,
- 而非始终返回 default 配置的人格。
+ 始终传入 None 调用 get_default_persona_v3 以获取 AstrBot 全局默认人格,
+ 避免在多配置文件场景下因随机选取 UMO 而导致每次返回不同人格。
+ 如需查看特定配置的人格,应通过 get_persona_for_group 并明确指定 group_id。
"""
fallback = {
"persona_id": "default",
@@ -157,14 +158,9 @@ async def get_default_persona_for_web(self) -> Dict[str, Any]:
return fallback
try:
- # 尝试从映射中获取一个 UMO,以加载当前活跃配置的人格
- umo = None
- if self.group_id_to_unified_origin:
- # 取任意一个 UMO(通常同一配置文件下的群组共享同一配置)
- umo = next(iter(self.group_id_to_unified_origin.values()), None)
-
+ # 获取全局默认人格,不依赖 group_id_to_unified_origin 映射
default_persona = await self._run_on_main_loop(
- self.persona_manager.get_default_persona_v3(umo)
+ self.persona_manager.get_default_persona_v3(None)
)
if default_persona:
diff --git a/pytest.ini b/pytest.ini
index f51256a..d4008a1 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -23,6 +23,12 @@ addopts =
# Strict markers (only registered markers allowed)
--strict-markers
# Coverage options
+ --cov=config
+ --cov=constants
+ --cov=exceptions
+ --cov=core
+ --cov=utils
+ --cov=services
--cov=webui
--cov-report=html
--cov-report=term-missing
@@ -40,7 +46,10 @@ markers =
auth: Authentication related tests
service: Service layer tests
blueprint: Blueprint/route tests
- security: Security-related tests
+ core: Core module tests
+ config: Configuration tests
+ utils: Utility module tests
+ quality: Quality monitoring tests
# Log output
log_cli = true
@@ -53,7 +62,14 @@ asyncio_mode = auto
# Coverage options
[coverage:run]
-source = webui
+source =
+ config
+ constants
+ exceptions
+ core
+ utils
+ services
+ webui
omit =
*/tests/*
*/test_*.py
diff --git a/repositories/bot_mood_repository.py b/repositories/bot_mood_repository.py
index b8f9d99..7048b1e 100644
--- a/repositories/bot_mood_repository.py
+++ b/repositories/bot_mood_repository.py
@@ -55,7 +55,14 @@ async def save(self, mood_data: Dict[str, Any]) -> Optional[BotMood]:
mood = BotMood(**mood_data)
self.session.add(mood)
await self.session.commit()
- await self.session.refresh(mood)
+ try:
+ await self.session.refresh(mood)
+ except Exception as refresh_err:
+ # 部分 async 驱动在 commit 后 refresh 可能失败,
+ # 此时 mood 对象已持久化且 ID 已赋值,可安全返回
+ logger.debug(
+ f"[BotMoodRepository] refresh after commit skipped: {refresh_err}"
+ )
return mood
except Exception as e:
await self.session.rollback()
diff --git a/repositories/psychological_repository.py b/repositories/psychological_repository.py
index ccc74f3..96b07ec 100644
--- a/repositories/psychological_repository.py
+++ b/repositories/psychological_repository.py
@@ -98,51 +98,57 @@ async def get_components(
获取状态的所有组件
Args:
- state_id: 状态 ID
+ state_id: 复合心理状态记录的主键 ID
Returns:
List[PsychologicalStateComponent]: 组件列表
"""
- return await self.find_many(state_id=state_id)
+ return await self.find_many(composite_state_id=state_id)
async def update_component(
self,
state_id: int,
component_name: str,
value: float,
- threshold: float = None
+ threshold: float = None,
+ group_id: str = "",
+ state_id_str: str = ""
) -> Optional[PsychologicalStateComponent]:
"""
更新组件值
Args:
- state_id: 状态 ID
- component_name: 组件名称
+ state_id: 复合心理状态记录的主键 ID
+ component_name: 组件类别名称(category)
value: 组件值
threshold: 阈值
+ group_id: 群组 ID (创建新组件时使用)
+ state_id_str: 状态标识符 (创建新组件时使用)
Returns:
Optional[PsychologicalStateComponent]: 组件对象
"""
component = await self.find_one(
- state_id=state_id,
- component_name=component_name
+ composite_state_id=state_id,
+ category=component_name
)
if component:
component.value = value
if threshold is not None:
component.threshold = threshold
- component.updated_at = int(time.time())
return await self.update(component)
else:
+ now = int(time.time())
return await self.create(
- state_id=state_id,
- component_name=component_name,
+ composite_state_id=state_id,
+ group_id=group_id,
+ state_id=state_id_str or f"{group_id}:{component_name}",
+ category=component_name,
+ state_type=component_name,
value=value,
- threshold=threshold or 0.5,
- created_at=int(time.time()),
- updated_at=int(time.time())
+ threshold=threshold or 0.3,
+ start_time=now,
)
@@ -158,7 +164,9 @@ async def add_history(
from_state: str,
to_state: str,
trigger_event: str = None,
- intensity_change: float = 0.0
+ intensity_change: float = 0.0,
+ group_id: str = "",
+ category: str = ""
) -> Optional[PsychologicalStateHistory]:
"""
添加历史记录
@@ -169,16 +177,21 @@ async def add_history(
to_state: 结束状态
trigger_event: 触发事件
intensity_change: 强度变化
+ group_id: 群组 ID
+ category: 状态类别
Returns:
Optional[PsychologicalStateHistory]: 历史记录
"""
return await self.create(
- state_id=state_id,
- from_state=from_state,
- to_state=to_state,
- trigger_event=trigger_event,
- intensity_change=intensity_change,
+ group_id=group_id,
+ state_id=str(state_id),
+ category=category or "unknown",
+ old_state_type=from_state,
+ new_state_type=to_state or "",
+ old_value=0.0,
+ new_value=intensity_change,
+ change_reason=trigger_event,
timestamp=int(time.time())
)
diff --git a/services/commands/handlers.py b/services/commands/handlers.py
index 102d52a..19b3a54 100644
--- a/services/commands/handlers.py
+++ b/services/commands/handlers.py
@@ -257,6 +257,10 @@ async def affection_status(self, event: Any) -> AsyncGenerator:
yield event.plain_result(CommandMessages.AFFECTION_DISABLED)
return
+ if not self._affection_manager:
+ yield event.plain_result("好感度管理器未初始化,请检查启动日志")
+ return
+
affection_status = await self._affection_manager.get_affection_status(group_id)
current_mood = None
@@ -324,6 +328,14 @@ async def set_mood(self, event: Any) -> AsyncGenerator:
yield event.plain_result(CommandMessages.AFFECTION_DISABLED)
return
+ if not self._temporary_persona_updater:
+ yield event.plain_result("临时人格更新器未初始化,无法设置情绪")
+ return
+
+ if not self._affection_manager:
+ yield event.plain_result("好感度管理器未初始化,无法设置情绪")
+ return
+
args = event.get_message_str().split()[1:]
if len(args) < 1:
yield event.plain_result(
diff --git a/services/core_learning/progressive_learning.py b/services/core_learning/progressive_learning.py
index 4851c82..951ac4c 100644
--- a/services/core_learning/progressive_learning.py
+++ b/services/core_learning/progressive_learning.py
@@ -250,7 +250,26 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals
logger.debug("没有通过筛选的消息")
await self._mark_messages_processed(unprocessed_messages)
return
-
+
+ # 2.5 将筛选后的消息写入 FilteredMessage 表(供 WebUI 统计)
+ saved_count = 0
+ for msg in filtered_messages:
+ try:
+ await self.message_collector.add_filtered_message({
+ "raw_message_id": msg.get("id"),
+ "message": msg.get("message", ""),
+ "sender_id": msg.get("sender_id", ""),
+ "group_id": msg.get("group_id", group_id),
+ "timestamp": msg.get("timestamp", int(time.time())),
+ "confidence": msg.get("relevance_score", 1.0),
+ "filter_reason": msg.get("filter_reason", "batch_learning"),
+ })
+ saved_count += 1
+ except Exception:
+ pass # best-effort, don't block learning
+ if saved_count:
+ logger.debug(f"已保存 {saved_count}/{len(filtered_messages)} 条筛选消息到 FilteredMessage 表")
+
# 3. 获取当前人格设置 (针对特定群组)
current_persona = await self._get_current_persona(group_id)
diff --git a/services/database/facades/_base.py b/services/database/facades/_base.py
index c877aa2..19aea50 100644
--- a/services/database/facades/_base.py
+++ b/services/database/facades/_base.py
@@ -26,13 +26,20 @@ def __init__(self, engine: DatabaseEngine, config: PluginConfig):
@asynccontextmanager
async def get_session(self):
- """获取异步数据库会话(上下文管理器)"""
+ """获取异步数据库会话(上下文管理器)
+
+ 自动处理会话的创建、提交和回滚。
+ 使用 ``async with session`` 确保事务完整性,
+ 不再在 finally 中重复调用 close 以避免连接状态异常。
+ """
+ if self.engine is None or self.engine.engine is None:
+ raise RuntimeError("数据库引擎未初始化或已关闭")
session = self.engine.get_session()
try:
async with session:
yield session
- finally:
- await session.close()
+ except Exception:
+ raise
@staticmethod
def _row_to_dict(obj: Any, fields: Optional[List[str]] = None) -> Dict[str, Any]:
diff --git a/services/database/facades/jargon_facade.py b/services/database/facades/jargon_facade.py
index 7ca2282..a7f3961 100644
--- a/services/database/facades/jargon_facade.py
+++ b/services/database/facades/jargon_facade.py
@@ -152,7 +152,13 @@ async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool:
if 'raw_content' in jargon_data:
record.raw_content = jargon_data['raw_content']
if 'meaning' in jargon_data:
- record.meaning = jargon_data['meaning']
+ meaning_val = jargon_data['meaning']
+ if isinstance(meaning_val, dict):
+ record.meaning = json.dumps(meaning_val, ensure_ascii=False)
+ elif isinstance(meaning_val, list):
+ record.meaning = json.dumps(meaning_val, ensure_ascii=False)
+ else:
+ record.meaning = str(meaning_val) if meaning_val is not None else None
if 'is_jargon' in jargon_data:
record.is_jargon = jargon_data['is_jargon']
if 'count' in jargon_data:
@@ -743,8 +749,12 @@ async def get_jargon_groups(self) -> List[Dict]:
groups = []
for row in rows:
try:
+ chat_id = row.chat_id or ''
groups.append({
- 'chat_id': row.chat_id,
+ 'group_id': chat_id,
+ 'group_name': chat_id,
+ 'id': chat_id,
+ 'chat_id': chat_id,
'count': row.count or 0
})
except Exception as row_error:
diff --git a/services/database/sqlalchemy_database_manager.py b/services/database/sqlalchemy_database_manager.py
index 99ec3a6..586d30f 100644
--- a/services/database/sqlalchemy_database_manager.py
+++ b/services/database/sqlalchemy_database_manager.py
@@ -150,37 +150,53 @@ def _get_database_url(self) -> str:
return f"sqlite:///{db_path}"
async def _ensure_mysql_database_exists(self):
- """确保 MySQL 数据库存在"""
- try:
- import aiomysql
- host = getattr(self.config, 'mysql_host', 'localhost')
- port = getattr(self.config, 'mysql_port', 3306)
- user = getattr(self.config, 'mysql_user', 'root')
- password = getattr(self.config, 'mysql_password', '')
- database = getattr(self.config, 'mysql_database', 'astrbot_self_learning')
+ """确保 MySQL 数据库存在
+
+ 使用 aiomysql 直连 MySQL 服务器以检查/创建目标数据库。
+ 显式禁用 SSL 以避免 MySQL 8 默认 TLS 握手导致的
+ struct.unpack 解包异常。
+ """
+ import aiomysql
+ host = getattr(self.config, 'mysql_host', 'localhost')
+ port = getattr(self.config, 'mysql_port', 3306)
+ user = getattr(self.config, 'mysql_user', 'root')
+ password = getattr(self.config, 'mysql_password', '')
+ database = getattr(self.config, 'mysql_database', 'astrbot_self_learning')
- conn = await aiomysql.connect(
- host=host, port=port, user=user,
- password=password, charset='utf8mb4',
+ try:
+ conn = await asyncio.wait_for(
+ aiomysql.connect(
+ host=host, port=port, user=user,
+ password=password, charset='utf8mb4',
+ ssl=False, connect_timeout=10,
+ ),
+ timeout=15,
)
- try:
- async with conn.cursor() as cursor:
+ except asyncio.TimeoutError:
+ logger.error("[DomainRouter] 连接 MySQL 超时 (15s)")
+ raise
+ except Exception as e:
+ logger.error(f"[DomainRouter] 连接 MySQL 失败: {e}")
+ raise
+
+ try:
+ async with conn.cursor() as cursor:
+ await cursor.execute(
+ "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s",
+ (database,),
+ )
+ if not await cursor.fetchone():
+ logger.info(f"[DomainRouter] 数据库 {database} 不存在,正在创建...")
await cursor.execute(
- "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s",
- (database,),
+ f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
)
- if not await cursor.fetchone():
- logger.info(f"[DomainRouter] 数据库 {database} 不存在,正在创建…")
- await cursor.execute(
- f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
- )
- await conn.commit()
- logger.info(f"[DomainRouter] 数据库 {database} 创建成功")
- finally:
- conn.close()
+ await conn.commit()
+ logger.info(f"[DomainRouter] 数据库 {database} 创建成功")
except Exception as e:
logger.error(f"[DomainRouter] 确保 MySQL 数据库存在失败: {e}")
raise
+ finally:
+ conn.close()
# Infrastructure: session
@@ -206,8 +222,8 @@ async def get_session(self):
try:
async with session:
yield session
- finally:
- await session.close()
+ except Exception:
+ raise
# Domain delegates: AffectionFacade
diff --git a/services/jargon/jargon_miner.py b/services/jargon/jargon_miner.py
index 331a4b9..9f85733 100644
--- a/services/jargon/jargon_miner.py
+++ b/services/jargon/jargon_miner.py
@@ -116,7 +116,13 @@ async def infer_meaning(
logger.info(f"黑话 {content} 信息不足,等待下次推断")
return {'no_info': True}
- meaning1 = inference1.get('meaning', '').strip()
+ meaning1_raw = inference1.get('meaning', '')
+ if isinstance(meaning1_raw, dict):
+ meaning1 = json.dumps(meaning1_raw, ensure_ascii=False)
+ elif isinstance(meaning1_raw, list):
+ meaning1 = json.dumps(meaning1_raw, ensure_ascii=False)
+ else:
+ meaning1 = str(meaning1_raw).strip() if meaning1_raw else ''
if not meaning1:
return {'no_info': True}
@@ -153,9 +159,18 @@ async def infer_meaning(
is_similar = comparison.get('is_similar', False)
is_jargon = not is_similar
+ if is_jargon:
+ final_meaning = meaning1
+ else:
+ meaning2_raw = inference2.get('meaning', '')
+ if isinstance(meaning2_raw, (dict, list)):
+ final_meaning = json.dumps(meaning2_raw, ensure_ascii=False)
+ else:
+ final_meaning = str(meaning2_raw).strip() if meaning2_raw else ''
+
return {
'is_jargon': is_jargon,
- 'meaning': meaning1 if is_jargon else inference2.get('meaning', ''),
+ 'meaning': final_meaning,
'no_info': False
}
diff --git a/services/state/enhanced_psychological_state_manager.py b/services/state/enhanced_psychological_state_manager.py
index af4ffc2..71f4842 100644
--- a/services/state/enhanced_psychological_state_manager.py
+++ b/services/state/enhanced_psychological_state_manager.py
@@ -196,15 +196,26 @@ async def get_current_state(
# 获取组件
components = await component_repo.get_components(state.id)
- # 转换为 CompositePsychologicalState
+ # 转换 ORM 组件对象为 dataclass 实例
+ # 使用 getattr 进行防御性读取,防止 ORM 属性未加载时抛出
+ # AttributeError(如 lazy-load 失败或 schema 不一致)
state_components = []
for comp in components:
- state_components.append(PsychologicalStateComponent(
- dimension=comp.component_name,
- state_type=comp.component_name, # TODO: 需要解析类型
- value=comp.value,
- threshold=comp.threshold
- ))
+ try:
+ state_components.append(PsychologicalStateComponent(
+ category=getattr(comp, 'category', 'unknown'),
+ state_type=getattr(comp, 'state_type', 'unknown'),
+ value=float(getattr(comp, 'value', 0.5)),
+ threshold=float(getattr(comp, 'threshold', 0.3)),
+ description=getattr(comp, 'description', '') or "",
+ start_time=float(getattr(comp, 'start_time', 0))
+ if getattr(comp, 'start_time', None) else time.time()
+ ))
+ except Exception as comp_err:
+ self._logger.warning(
+ f"[增强型心理状态] 转换组件失败: {comp_err}, "
+ f"comp_type={type(comp).__name__}"
+ )
composite_state = CompositePsychologicalState(
group_id=group_id,
@@ -267,7 +278,9 @@ async def update_state(
await component_repo.update_component(
state.id,
dimension,
- new_value
+ new_value,
+ group_id=group_id,
+ state_id_str=f"{group_id}:{user_id}"
)
# 记录历史
@@ -276,7 +289,9 @@ async def update_state(
from_state=state.overall_state,
to_state=str(new_state_type),
trigger_event=trigger_event,
- intensity_change=0.0
+ intensity_change=0.0,
+ group_id=group_id,
+ category=dimension
)
# 清除缓存
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/test_cache_manager.py b/tests/unit/test_cache_manager.py
new file mode 100644
index 0000000..01b93f0
--- /dev/null
+++ b/tests/unit/test_cache_manager.py
@@ -0,0 +1,251 @@
+"""
+Unit tests for CacheManager
+
+Tests the unified cache management system:
+- Cache get/set/delete/clear operations
+- Named cache isolation (affection, memory, state, etc.)
+- Hit rate statistics tracking
+- Cache stats reporting
+- Global singleton management
+- Unknown cache name handling
+"""
+import pytest
+from unittest.mock import patch
+
+from utils.cache_manager import CacheManager, get_cache_manager, cached, async_cached
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCacheManagerOperations:
+ """Test basic CacheManager CRUD operations."""
+
+ def test_set_and_get(self):
+ """Test setting and getting a cache value."""
+ mgr = CacheManager()
+ mgr.set("general", "key1", "value1")
+
+ result = mgr.get("general", "key1")
+ assert result == "value1"
+
+ def test_get_nonexistent_key(self):
+ """Test getting a nonexistent key returns None."""
+ mgr = CacheManager()
+
+ result = mgr.get("general", "nonexistent")
+ assert result is None
+
+ def test_delete_existing_key(self):
+ """Test deleting an existing cache entry."""
+ mgr = CacheManager()
+ mgr.set("general", "key1", "value1")
+
+ mgr.delete("general", "key1")
+
+ assert mgr.get("general", "key1") is None
+
+ def test_delete_nonexistent_key(self):
+ """Test deleting a nonexistent key does not raise."""
+ mgr = CacheManager()
+ mgr.delete("general", "nonexistent") # Should not raise
+
+ def test_clear_specific_cache(self):
+ """Test clearing a specific named cache."""
+ mgr = CacheManager()
+ mgr.set("affection", "k1", "v1")
+ mgr.set("affection", "k2", "v2")
+
+ mgr.clear("affection")
+
+ assert mgr.get("affection", "k1") is None
+ assert mgr.get("affection", "k2") is None
+
+ def test_clear_all_caches(self):
+ """Test clearing all caches at once."""
+ mgr = CacheManager()
+ mgr.set("affection", "k1", "v1")
+ mgr.set("memory", "k2", "v2")
+ mgr.set("general", "k3", "v3")
+
+ mgr.clear_all()
+
+ assert mgr.get("affection", "k1") is None
+ assert mgr.get("memory", "k2") is None
+ assert mgr.get("general", "k3") is None
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCacheManagerIsolation:
+ """Test cache name isolation between different caches."""
+
+ def test_different_caches_are_isolated(self):
+ """Test same key in different caches are independent."""
+ mgr = CacheManager()
+ mgr.set("affection", "shared_key", "affection_value")
+ mgr.set("memory", "shared_key", "memory_value")
+
+ assert mgr.get("affection", "shared_key") == "affection_value"
+ assert mgr.get("memory", "shared_key") == "memory_value"
+
+ @pytest.mark.parametrize("cache_name", [
+ "affection", "memory", "state", "relation",
+ "context", "embedding_query",
+ "conversation", "summary", "general",
+ ])
+ def test_all_named_caches_accessible(self, cache_name):
+ """Test all named caches are accessible."""
+ mgr = CacheManager()
+ mgr.set(cache_name, "test_key", "test_value")
+
+ result = mgr.get(cache_name, "test_key")
+ assert result == "test_value"
+
+ def test_unknown_cache_name_returns_none(self):
+ """Test accessing an unknown cache name returns None."""
+ mgr = CacheManager()
+
+ result = mgr.get("unknown_cache", "key1")
+ assert result is None
+
+ def test_set_to_unknown_cache_does_not_raise(self):
+ """Test setting to an unknown cache does not raise."""
+ mgr = CacheManager()
+ mgr.set("unknown_cache", "key1", "value1") # Should not raise
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCacheManagerStats:
+ """Test cache statistics and hit rate tracking."""
+
+ def test_hit_rate_empty(self):
+ """Test hit rates with no operations."""
+ mgr = CacheManager()
+
+ stats = mgr.get_hit_rates()
+ assert stats == {}
+
+ def test_hit_rate_tracking(self):
+ """Test hit/miss tracking across operations."""
+ mgr = CacheManager()
+ mgr.set("general", "key1", "value1")
+
+ # Hit
+ mgr.get("general", "key1")
+ # Miss
+ mgr.get("general", "nonexistent")
+
+ stats = mgr.get_hit_rates()
+ assert "general" in stats
+ assert stats["general"]["hits"] == 1
+ assert stats["general"]["misses"] == 1
+ assert stats["general"]["hit_rate"] == 0.5
+
+ def test_get_stats_for_cache(self):
+ """Test getting stats for a specific cache."""
+ mgr = CacheManager()
+ mgr.set("affection", "k1", "v1")
+
+ stats = mgr.get_stats("affection")
+
+ assert "size" in stats
+ assert "maxsize" in stats
+ assert stats["size"] == 1
+
+ def test_get_stats_unknown_cache(self):
+ """Test getting stats for unknown cache returns empty dict."""
+ mgr = CacheManager()
+
+ stats = mgr.get_stats("unknown")
+ assert stats == {}
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCachedDecorator:
+ """Test the synchronous cached decorator."""
+
+ def test_cached_decorator_caches_result(self):
+ """Test cached decorator returns cached result on second call."""
+ mgr = CacheManager()
+ call_count = 0
+
+ @cached(cache_name="general", key_func=lambda x: f"key_{x}", manager=mgr)
+ def expensive_func(x):
+ nonlocal call_count
+ call_count += 1
+ return x * 2
+
+ result1 = expensive_func(5)
+ result2 = expensive_func(5)
+
+ assert result1 == 10
+ assert result2 == 10
+ assert call_count == 1 # Only called once
+
+ def test_cached_decorator_different_keys(self):
+ """Test cached decorator uses correct keys for different inputs."""
+ mgr = CacheManager()
+
+ @cached(cache_name="general", key_func=lambda x: f"key_{x}", manager=mgr)
+ def add_one(x):
+ return x + 1
+
+ assert add_one(1) == 2
+ assert add_one(2) == 3
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestAsyncCachedDecorator:
+ """Test the asynchronous cached decorator."""
+
+ @pytest.mark.asyncio
+ async def test_async_cached_decorator(self):
+ """Test async cached decorator caches result."""
+ mgr = CacheManager()
+ call_count = 0
+
+ @async_cached(
+ cache_name="general",
+ key_func=lambda x: f"async_key_{x}",
+ manager=mgr,
+ )
+ async def async_expensive_func(x):
+ nonlocal call_count
+ call_count += 1
+ return x * 3
+
+ result1 = await async_expensive_func(7)
+ result2 = await async_expensive_func(7)
+
+ assert result1 == 21
+ assert result2 == 21
+ assert call_count == 1
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestGlobalCacheManager:
+ """Test global singleton cache manager."""
+
+ def test_get_cache_manager_returns_instance(self):
+ """Test get_cache_manager returns a CacheManager instance."""
+ # Reset global to ensure clean state
+ import utils.cache_manager as module
+ module._global_cache_manager = None
+
+ mgr = get_cache_manager()
+
+ assert isinstance(mgr, CacheManager)
+
+ def test_get_cache_manager_returns_same_instance(self):
+ """Test get_cache_manager always returns the same singleton."""
+ import utils.cache_manager as module
+ module._global_cache_manager = None
+
+ mgr1 = get_cache_manager()
+ mgr2 = get_cache_manager()
+
+ assert mgr1 is mgr2
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
new file mode 100644
index 0000000..cdc9d4f
--- /dev/null
+++ b/tests/unit/test_config.py
@@ -0,0 +1,379 @@
+"""
+Unit tests for PluginConfig
+
+Tests the plugin configuration management including:
+- Default value initialization
+- Configuration creation from dict
+- Configuration validation
+- File persistence (save/load)
+- Boundary value verification
+"""
+import os
+import json
+import tempfile
+import pytest
+from unittest.mock import patch, MagicMock
+
+from config import PluginConfig
+
+
+@pytest.mark.unit
+@pytest.mark.config
+class TestPluginConfigDefaults:
+ """Test PluginConfig default value initialization."""
+
+ def test_create_default_instance(self):
+ """Test creating a default PluginConfig instance."""
+ config = PluginConfig()
+
+ assert config.enable_message_capture is True
+ assert config.enable_auto_learning is True
+ assert config.enable_realtime_learning is False
+ assert config.enable_web_interface is True
+ assert config.web_interface_port == 7833
+ assert config.web_interface_host == "0.0.0.0"
+
+ def test_create_default_classmethod(self):
+ """Test the create_default classmethod."""
+ config = PluginConfig.create_default()
+
+ assert isinstance(config, PluginConfig)
+ assert config.learning_interval_hours == 6
+ assert config.min_messages_for_learning == 50
+ assert config.max_messages_per_batch == 200
+
+ def test_default_learning_parameters(self):
+ """Test default learning parameter values."""
+ config = PluginConfig()
+
+ assert config.message_min_length == 5
+ assert config.message_max_length == 500
+ assert config.confidence_threshold == 0.7
+ assert config.relevance_threshold == 0.6
+ assert config.style_analysis_batch_size == 100
+ assert config.style_update_threshold == 0.6
+
+ def test_default_database_settings(self):
+ """Test default database configuration values."""
+ config = PluginConfig()
+
+ assert config.db_type == "sqlite"
+ assert config.mysql_host == "localhost"
+ assert config.mysql_port == 3306
+ assert config.postgresql_host == "localhost"
+ assert config.postgresql_port == 5432
+ assert config.max_connections == 10
+
+ def test_default_affection_settings(self):
+ """Test default affection system configuration."""
+ config = PluginConfig()
+
+ assert config.enable_affection_system is True
+ assert config.max_total_affection == 250
+ assert config.max_user_affection == 100
+ assert config.affection_decay_rate == 0.95
+
+ def test_default_provider_ids_none(self):
+ """Test provider IDs default to None."""
+ config = PluginConfig()
+
+ assert config.filter_provider_id is None
+ assert config.refine_provider_id is None
+ assert config.reinforce_provider_id is None
+ assert config.embedding_provider_id is None
+ assert config.rerank_provider_id is None
+
+ def test_sqlalchemy_always_true(self):
+ """Test that use_sqlalchemy is always True (hardcoded)."""
+ config = PluginConfig()
+ assert config.use_sqlalchemy is True
+
+
+@pytest.mark.unit
+@pytest.mark.config
+class TestPluginConfigFromDict:
+ """Test PluginConfig creation from configuration dict."""
+
+ def test_create_from_basic_config(self):
+ """Test creating config from a basic configuration dict."""
+ raw_config = {
+ 'Self_Learning_Basic': {
+ 'enable_message_capture': False,
+ 'enable_auto_learning': False,
+ 'web_interface_port': 8080,
+ }
+ }
+
+ config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test")
+
+ assert config.enable_message_capture is False
+ assert config.enable_auto_learning is False
+ assert config.web_interface_port == 8080
+ assert config.data_dir == "/tmp/test"
+
+ def test_create_from_config_with_model_settings(self):
+ """Test config creation with model configuration."""
+ raw_config = {
+ 'Model_Configuration': {
+ 'filter_provider_id': 'provider_1',
+ 'refine_provider_id': 'provider_2',
+ 'reinforce_provider_id': 'provider_3',
+ }
+ }
+
+ config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test")
+
+ assert config.filter_provider_id == 'provider_1'
+ assert config.refine_provider_id == 'provider_2'
+ assert config.reinforce_provider_id == 'provider_3'
+
+ def test_create_from_config_missing_data_dir(self):
+ """Test config creation with empty data_dir uses fallback."""
+ config = PluginConfig.create_from_config({}, data_dir="")
+
+ assert config.data_dir == "./data/self_learning_data"
+
+ def test_create_from_config_with_database_settings(self):
+ """Test config creation with database settings."""
+ raw_config = {
+ 'Database_Settings': {
+ 'db_type': 'mysql',
+ 'mysql_host': '192.168.1.100',
+ 'mysql_port': 3307,
+ 'mysql_user': 'admin',
+ 'mysql_password': 'secret',
+ 'mysql_database': 'test_db',
+ }
+ }
+
+ config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test")
+
+ assert config.db_type == 'mysql'
+ assert config.mysql_host == '192.168.1.100'
+ assert config.mysql_port == 3307
+ assert config.mysql_user == 'admin'
+ assert config.mysql_database == 'test_db'
+
+ def test_create_from_config_with_v2_settings(self):
+ """Test config creation with v2 architecture settings."""
+ raw_config = {
+ 'V2_Architecture_Settings': {
+ 'embedding_provider_id': 'embed_provider',
+ 'rerank_provider_id': 'rerank_provider',
+ 'knowledge_engine': 'lightrag',
+ 'memory_engine': 'mem0',
+ }
+ }
+
+ config = PluginConfig.create_from_config(raw_config, data_dir="/tmp/test")
+
+ assert config.embedding_provider_id == 'embed_provider'
+ assert config.rerank_provider_id == 'rerank_provider'
+ assert config.knowledge_engine == 'lightrag'
+ assert config.memory_engine == 'mem0'
+
+ def test_create_from_empty_config(self):
+ """Test config creation from empty dict uses all defaults."""
+ config = PluginConfig.create_from_config({}, data_dir="/tmp/test")
+
+ assert config.enable_message_capture is True
+ assert config.learning_interval_hours == 6
+ assert config.db_type == 'sqlite'
+
+ def test_extra_fields_ignored(self):
+ """Test that extra/unknown fields are ignored."""
+ config = PluginConfig(
+ unknown_field_1="value1",
+ unknown_field_2=42,
+ )
+ assert not hasattr(config, 'unknown_field_1')
+
+
+@pytest.mark.unit
+@pytest.mark.config
+class TestPluginConfigValidation:
+ """Test PluginConfig validation logic."""
+
+ def test_valid_config_no_errors(self):
+ """Test validation of a valid default config."""
+ config = PluginConfig(
+ filter_provider_id="provider_1",
+ refine_provider_id="provider_2",
+ )
+ errors = config.validate_config()
+
+ # Should have no blocking errors (may have warnings for reinforce)
+ blocking_errors = [e for e in errors if not e.startswith(" ")]
+ assert len(blocking_errors) == 0
+
+ def test_invalid_learning_interval(self):
+ """Test validation catches invalid learning interval."""
+ config = PluginConfig(learning_interval_hours=0)
+ errors = config.validate_config()
+
+ assert any("学习间隔必须大于0" in e for e in errors)
+
+ def test_invalid_min_messages(self):
+ """Test validation catches invalid min messages for learning."""
+ config = PluginConfig(min_messages_for_learning=0)
+ errors = config.validate_config()
+
+ assert any("最少学习消息数量必须大于0" in e for e in errors)
+
+ def test_invalid_max_batch_size(self):
+ """Test validation catches invalid max batch size."""
+ config = PluginConfig(max_messages_per_batch=-1)
+ errors = config.validate_config()
+
+ assert any("每批最大消息数量必须大于0" in e for e in errors)
+
+ def test_invalid_message_length_range(self):
+ """Test validation catches min_length >= max_length."""
+ config = PluginConfig(message_min_length=500, message_max_length=100)
+ errors = config.validate_config()
+
+ assert any("最小长度必须小于最大长度" in e for e in errors)
+
+ def test_invalid_confidence_threshold(self):
+ """Test validation catches confidence threshold out of range."""
+ config = PluginConfig(confidence_threshold=1.5)
+ errors = config.validate_config()
+
+ assert any("置信度阈值必须在0-1之间" in e for e in errors)
+
+ def test_invalid_style_threshold(self):
+ """Test validation catches style update threshold out of range."""
+ config = PluginConfig(style_update_threshold=-0.1)
+ errors = config.validate_config()
+
+ assert any("风格更新阈值必须在0-1之间" in e for e in errors)
+
+ def test_no_providers_configured(self):
+ """Test validation warns when no providers are configured."""
+ config = PluginConfig(
+ filter_provider_id=None,
+ refine_provider_id=None,
+ reinforce_provider_id=None,
+ )
+ errors = config.validate_config()
+
+ assert any("至少需要配置一个模型提供商ID" in e for e in errors)
+
+ def test_partial_providers_configured(self):
+ """Test validation with only some providers configured."""
+ config = PluginConfig(
+ filter_provider_id="provider_1",
+ refine_provider_id=None,
+ reinforce_provider_id=None,
+ )
+ errors = config.validate_config()
+
+ # Should have warnings but no blocking errors
+ blocking_errors = [e for e in errors if not e.startswith(" ")]
+ assert len(blocking_errors) == 0
+
+
+@pytest.mark.unit
+@pytest.mark.config
+class TestPluginConfigSerialization:
+ """Test PluginConfig serialization and deserialization."""
+
+ def test_to_dict(self):
+ """Test converting config to dict."""
+ config = PluginConfig(
+ enable_message_capture=False,
+ web_interface_port=9090,
+ )
+
+ config_dict = config.to_dict()
+
+ assert isinstance(config_dict, dict)
+ assert config_dict['enable_message_capture'] is False
+ assert config_dict['web_interface_port'] == 9090
+ assert 'learning_interval_hours' in config_dict
+
+ def test_save_to_file_success(self):
+ """Test saving config to file."""
+ config = PluginConfig()
+
+ with tempfile.NamedTemporaryFile(
+ mode='w', suffix='.json', delete=False
+ ) as f:
+ filepath = f.name
+
+ try:
+ result = config.save_to_file(filepath)
+
+ assert result is True
+ assert os.path.exists(filepath)
+
+ with open(filepath, 'r', encoding='utf-8') as f:
+ saved_data = json.load(f)
+ assert saved_data['enable_message_capture'] is True
+ finally:
+ os.unlink(filepath)
+
+ def test_load_from_file_success(self):
+ """Test loading config from existing file."""
+ config_data = {
+ 'enable_message_capture': False,
+ 'web_interface_port': 9999,
+ 'learning_interval_hours': 12,
+ }
+
+ with tempfile.NamedTemporaryFile(
+ mode='w', suffix='.json', delete=False
+ ) as f:
+ json.dump(config_data, f)
+ filepath = f.name
+
+ try:
+ loaded_config = PluginConfig.load_from_file(filepath)
+
+ assert loaded_config.enable_message_capture is False
+ assert loaded_config.web_interface_port == 9999
+ assert loaded_config.learning_interval_hours == 12
+ finally:
+ os.unlink(filepath)
+
+ def test_load_from_nonexistent_file(self):
+ """Test loading config from nonexistent file returns defaults."""
+ loaded_config = PluginConfig.load_from_file("/nonexistent/path.json")
+
+ assert loaded_config.enable_message_capture is True
+ assert loaded_config.learning_interval_hours == 6
+
+ def test_load_from_file_with_data_dir(self):
+ """Test loading config with explicit data_dir override."""
+ config_data = {'enable_message_capture': True}
+
+ with tempfile.NamedTemporaryFile(
+ mode='w', suffix='.json', delete=False
+ ) as f:
+ json.dump(config_data, f)
+ filepath = f.name
+
+ try:
+ loaded_config = PluginConfig.load_from_file(
+ filepath, data_dir="/custom/data/dir"
+ )
+
+ assert loaded_config.data_dir == "/custom/data/dir"
+ finally:
+ os.unlink(filepath)
+
+ def test_load_from_corrupt_file(self):
+ """Test loading config from corrupt file returns defaults."""
+ with tempfile.NamedTemporaryFile(
+ mode='w', suffix='.json', delete=False
+ ) as f:
+ f.write("this is not valid json {{{")
+ filepath = f.name
+
+ try:
+ loaded_config = PluginConfig.load_from_file(filepath)
+
+ # Should return default config
+ assert loaded_config.enable_message_capture is True
+ finally:
+ os.unlink(filepath)
diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py
new file mode 100644
index 0000000..badadb2
--- /dev/null
+++ b/tests/unit/test_constants.py
@@ -0,0 +1,147 @@
+"""
+Unit tests for constants module
+
+Tests the update type normalization and review source resolution:
+- normalize_update_type exact and fuzzy matching
+- get_review_source_from_update_type classification
+- Legacy format backward compatibility
+"""
+import pytest
+
+from constants import (
+ UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING,
+ UPDATE_TYPE_STYLE_LEARNING,
+ UPDATE_TYPE_EXPRESSION_LEARNING,
+ UPDATE_TYPE_TRADITIONAL,
+ LEGACY_UPDATE_TYPE_MAPPING,
+ normalize_update_type,
+ get_review_source_from_update_type,
+)
+
+
+@pytest.mark.unit
+class TestNormalizeUpdateType:
+ """Test normalize_update_type function."""
+
+ def test_empty_input_returns_traditional(self):
+ """Test empty or None input returns traditional type."""
+ assert normalize_update_type("") == UPDATE_TYPE_TRADITIONAL
+ assert normalize_update_type(None) == UPDATE_TYPE_TRADITIONAL
+
+ def test_exact_match_progressive_learning(self):
+ """Test exact match for progressive_learning legacy key."""
+ result = normalize_update_type("progressive_learning")
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+ def test_exact_match_style_learning(self):
+ """Test exact match for style_learning."""
+ result = normalize_update_type("style_learning")
+ assert result == UPDATE_TYPE_STYLE_LEARNING
+
+ def test_exact_match_expression_learning(self):
+ """Test exact match for expression_learning."""
+ result = normalize_update_type("expression_learning")
+ assert result == UPDATE_TYPE_EXPRESSION_LEARNING
+
+ def test_legacy_chinese_progressive_style(self):
+ """Test legacy Chinese format for progressive style analysis."""
+ result = normalize_update_type("渐进式学习-风格分析")
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+ def test_legacy_chinese_progressive_persona(self):
+ """Test legacy Chinese format for progressive persona update."""
+ result = normalize_update_type("渐进式学习-人格更新")
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+ def test_fuzzy_match_chinese_progressive(self):
+ """Test fuzzy match with Chinese progressive learning keyword."""
+ result = normalize_update_type("渐进式学习-新类型")
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+ def test_fuzzy_match_english_progressive(self):
+ """Test fuzzy match with English progressive keyword."""
+ result = normalize_update_type("PROGRESSIVE_update")
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+ def test_unknown_type_returns_traditional(self):
+ """Test unknown type returns traditional."""
+ result = normalize_update_type("some_unknown_type")
+ assert result == UPDATE_TYPE_TRADITIONAL
+
+ def test_already_normalized_value(self):
+ """Test passing an already normalized constant."""
+ result = normalize_update_type(UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING)
+ assert result == UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+
+
+@pytest.mark.unit
+class TestGetReviewSourceFromUpdateType:
+ """Test get_review_source_from_update_type function."""
+
+ def test_progressive_persona_learning_source(self):
+ """Test progressive persona learning maps to persona_learning."""
+ result = get_review_source_from_update_type(
+ UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING
+ )
+ assert result == 'persona_learning'
+
+ def test_style_learning_source(self):
+ """Test style learning maps to style_learning."""
+ result = get_review_source_from_update_type(UPDATE_TYPE_STYLE_LEARNING)
+ assert result == 'style_learning'
+
+ def test_expression_learning_source(self):
+ """Test expression learning maps to persona_learning."""
+ result = get_review_source_from_update_type(
+ UPDATE_TYPE_EXPRESSION_LEARNING
+ )
+ assert result == 'persona_learning'
+
+ def test_traditional_source(self):
+ """Test traditional update maps to traditional."""
+ result = get_review_source_from_update_type(UPDATE_TYPE_TRADITIONAL)
+ assert result == 'traditional'
+
+ def test_unknown_type_defaults_to_traditional(self):
+ """Test unknown update type defaults to traditional source."""
+ result = get_review_source_from_update_type("random_unknown_type")
+ assert result == 'traditional'
+
+ def test_legacy_format_normalization(self):
+ """Test legacy Chinese format is normalized before classification."""
+ result = get_review_source_from_update_type("渐进式学习-风格分析")
+ assert result == 'persona_learning'
+
+ def test_empty_string(self):
+ """Test empty string defaults to traditional."""
+ result = get_review_source_from_update_type("")
+ assert result == 'traditional'
+
+
+@pytest.mark.unit
+class TestLegacyMapping:
+ """Test legacy update type mapping completeness."""
+
+ def test_all_legacy_keys_mapped(self):
+ """Test all legacy keys exist in the mapping."""
+ expected_keys = {
+ "渐进式学习-风格分析",
+ "渐进式学习-人格更新",
+ "progressive_learning",
+ "style_learning",
+ "expression_learning",
+ }
+
+ assert set(LEGACY_UPDATE_TYPE_MAPPING.keys()) == expected_keys
+
+ def test_all_legacy_values_are_valid_constants(self):
+ """Test all legacy values map to valid update type constants."""
+ valid_types = {
+ UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING,
+ UPDATE_TYPE_STYLE_LEARNING,
+ UPDATE_TYPE_EXPRESSION_LEARNING,
+ UPDATE_TYPE_TRADITIONAL,
+ }
+
+ for value in LEGACY_UPDATE_TYPE_MAPPING.values():
+ assert value in valid_types, f"Invalid mapping value: {value}"
diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py
new file mode 100644
index 0000000..e0667a9
--- /dev/null
+++ b/tests/unit/test_exceptions.py
@@ -0,0 +1,110 @@
+"""
+Unit tests for custom exception hierarchy
+
+Tests the exception class inheritance chain and instantiation:
+- Base exception class
+- All derived exception types
+- Exception message propagation
+- isinstance checks for polymorphic handling
+"""
+import pytest
+
+from exceptions import (
+ SelfLearningError,
+ ConfigurationError,
+ MessageCollectionError,
+ StyleAnalysisError,
+ PersonaUpdateError,
+ ModelAccessError,
+ DataStorageError,
+ LearningSchedulerError,
+ LearningError,
+ ServiceError,
+ ResponseError,
+ BackupError,
+ ExpressionLearningError,
+ MemoryGraphError,
+ TimeDecayError,
+ MessageAnalysisError,
+ KnowledgeGraphError,
+)
+
+
+@pytest.mark.unit
+class TestExceptionHierarchy:
+ """Test exception class hierarchy and inheritance."""
+
+ def test_base_exception_is_exception(self):
+ """Test SelfLearningError is a proper Exception subclass."""
+ assert issubclass(SelfLearningError, Exception)
+
+ @pytest.mark.parametrize("exc_class", [
+ ConfigurationError,
+ MessageCollectionError,
+ StyleAnalysisError,
+ PersonaUpdateError,
+ ModelAccessError,
+ DataStorageError,
+ LearningSchedulerError,
+ LearningError,
+ ServiceError,
+ ResponseError,
+ BackupError,
+ ExpressionLearningError,
+ MemoryGraphError,
+ TimeDecayError,
+ MessageAnalysisError,
+ KnowledgeGraphError,
+ ])
+ def test_subclass_inherits_from_base(self, exc_class):
+ """Test all custom exceptions inherit from SelfLearningError."""
+ assert issubclass(exc_class, SelfLearningError)
+
+ @pytest.mark.parametrize("exc_class", [
+ ConfigurationError,
+ MessageCollectionError,
+ StyleAnalysisError,
+ PersonaUpdateError,
+ ModelAccessError,
+ DataStorageError,
+ LearningSchedulerError,
+ LearningError,
+ ServiceError,
+ ResponseError,
+ BackupError,
+ ExpressionLearningError,
+ MemoryGraphError,
+ TimeDecayError,
+ MessageAnalysisError,
+ KnowledgeGraphError,
+ ])
+ def test_exception_instantiation(self, exc_class):
+ """Test all exception classes can be instantiated with a message."""
+ msg = f"Test error message for {exc_class.__name__}"
+ exc = exc_class(msg)
+
+ assert str(exc) == msg
+
+ def test_catch_specific_exception(self):
+ """Test catching a specific exception type."""
+ with pytest.raises(ConfigurationError):
+ raise ConfigurationError("invalid config")
+
+ def test_catch_base_exception_for_derived(self):
+ """Test catching base exception catches derived types."""
+ with pytest.raises(SelfLearningError):
+ raise DataStorageError("storage failure")
+
+ def test_exception_with_no_message(self):
+ """Test exception can be raised without a message."""
+ exc = SelfLearningError()
+ assert str(exc) == ""
+
+ def test_exception_chain(self):
+ """Test exception chaining with __cause__."""
+ original = ValueError("original cause")
+ try:
+ raise ConfigurationError("config error") from original
+ except ConfigurationError as e:
+ assert e.__cause__ is original
+ assert str(e.__cause__) == "original cause"
diff --git a/tests/unit/test_guardrails_models.py b/tests/unit/test_guardrails_models.py
new file mode 100644
index 0000000..c72b012
--- /dev/null
+++ b/tests/unit/test_guardrails_models.py
@@ -0,0 +1,264 @@
+"""
+Unit tests for Guardrails Pydantic validation models
+
+Tests the Pydantic model definitions used for structured LLM output:
+- PsychologicalStateTransition validation
+- GoalAnalysisResult validation
+- ConversationIntentAnalysis defaults and validation
+- RelationChange and SocialRelationAnalysis validation
+- Field range constraints (ge, le, min_length, max_length)
+"""
+import pytest
+from pydantic import ValidationError
+
+from utils.guardrails_manager import (
+ PsychologicalStateTransition,
+ GoalAnalysisResult,
+ ConversationIntentAnalysis,
+ RelationChange,
+ SocialRelationAnalysis,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestPsychologicalStateTransition:
+ """Test PsychologicalStateTransition Pydantic model."""
+
+ def test_valid_creation(self):
+ """Test creating a valid state transition."""
+ state = PsychologicalStateTransition(
+ new_state="愉悦",
+ confidence=0.85,
+ reason="Positive conversation detected",
+ )
+
+ assert state.new_state == "愉悦"
+ assert state.confidence == 0.85
+ assert state.reason == "Positive conversation detected"
+
+ def test_default_values(self):
+ """Test default confidence and reason values."""
+ state = PsychologicalStateTransition(new_state="平静")
+
+ assert state.confidence == 0.8
+ assert state.reason == ""
+
+ def test_state_name_too_long(self):
+ """Test state name longer than 20 chars is rejected."""
+ with pytest.raises(ValidationError):
+ PsychologicalStateTransition(new_state="a" * 21)
+
+ def test_empty_state_name(self):
+ """Test empty state name is rejected."""
+ with pytest.raises(ValidationError):
+ PsychologicalStateTransition(new_state="")
+
+ def test_confidence_below_zero(self):
+ """Test confidence below 0.0 is rejected."""
+ with pytest.raises(ValidationError):
+ PsychologicalStateTransition(new_state="测试", confidence=-0.1)
+
+ def test_confidence_above_one(self):
+ """Test confidence above 1.0 is rejected."""
+ with pytest.raises(ValidationError):
+ PsychologicalStateTransition(new_state="测试", confidence=1.1)
+
+ def test_confidence_boundary_values(self):
+ """Test confidence at exact boundaries (0.0 and 1.0)."""
+ s_low = PsychologicalStateTransition(new_state="低", confidence=0.0)
+ s_high = PsychologicalStateTransition(new_state="高", confidence=1.0)
+
+ assert s_low.confidence == 0.0
+ assert s_high.confidence == 1.0
+
+ def test_state_name_whitespace_stripped(self):
+ """Test state name with whitespace is stripped."""
+ state = PsychologicalStateTransition(new_state=" 愉悦 ")
+ assert state.new_state == "愉悦"
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestGoalAnalysisResult:
+ """Test GoalAnalysisResult Pydantic model."""
+
+ def test_valid_creation(self):
+ """Test creating a valid goal analysis result."""
+ result = GoalAnalysisResult(
+ goal_type="emotional_support",
+ topic="工作压力",
+ confidence=0.85,
+ reasoning="User seems stressed",
+ )
+
+ assert result.goal_type == "emotional_support"
+ assert result.topic == "工作压力"
+ assert result.confidence == 0.85
+
+ def test_default_values(self):
+ """Test default values for optional fields."""
+ result = GoalAnalysisResult(
+ goal_type="casual_chat",
+ topic="日常",
+ )
+
+ assert result.confidence == 0.7
+ assert result.reasoning == ""
+
+ def test_goal_type_too_long(self):
+ """Test goal_type exceeding 50 chars is rejected."""
+ with pytest.raises(ValidationError):
+ GoalAnalysisResult(
+ goal_type="a" * 51,
+ topic="test",
+ )
+
+ def test_topic_too_long(self):
+ """Test topic exceeding 100 chars is rejected."""
+ with pytest.raises(ValidationError):
+ GoalAnalysisResult(
+ goal_type="test",
+ topic="a" * 101,
+ )
+
+ def test_empty_goal_type(self):
+ """Test empty goal_type is rejected."""
+ with pytest.raises(ValidationError):
+ GoalAnalysisResult(goal_type="", topic="test")
+
+ def test_empty_topic(self):
+ """Test empty topic is rejected."""
+ with pytest.raises(ValidationError):
+ GoalAnalysisResult(goal_type="test", topic="")
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestConversationIntentAnalysis:
+ """Test ConversationIntentAnalysis Pydantic model."""
+
+ def test_default_values(self):
+ """Test all default values are correctly set."""
+ intent = ConversationIntentAnalysis()
+
+ assert intent.goal_switch_needed is False
+ assert intent.new_goal_type is None
+ assert intent.new_topic is None
+ assert intent.topic_completed is False
+ assert intent.stage_completed is False
+ assert intent.stage_adjustment_needed is False
+ assert intent.suggested_stage is None
+ assert intent.completion_signals == 0
+ assert intent.user_engagement == 0.5
+ assert intent.reasoning == ""
+
+ def test_custom_values(self):
+ """Test setting custom values."""
+ intent = ConversationIntentAnalysis(
+ goal_switch_needed=True,
+ new_goal_type="knowledge_sharing",
+ user_engagement=0.9,
+ completion_signals=3,
+ )
+
+ assert intent.goal_switch_needed is True
+ assert intent.new_goal_type == "knowledge_sharing"
+ assert intent.user_engagement == 0.9
+ assert intent.completion_signals == 3
+
+ def test_engagement_below_zero(self):
+ """Test user_engagement below 0.0 is rejected."""
+ with pytest.raises(ValidationError):
+ ConversationIntentAnalysis(user_engagement=-0.1)
+
+ def test_engagement_above_one(self):
+ """Test user_engagement above 1.0 is rejected."""
+ with pytest.raises(ValidationError):
+ ConversationIntentAnalysis(user_engagement=1.1)
+
+ def test_negative_completion_signals(self):
+ """Test negative completion_signals is rejected."""
+ with pytest.raises(ValidationError):
+ ConversationIntentAnalysis(completion_signals=-1)
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestRelationChange:
+ """Test RelationChange Pydantic model."""
+
+ def test_valid_creation(self):
+ """Test creating a valid relation change."""
+ change = RelationChange(
+ relation_type="挚友",
+ value_delta=0.1,
+ reason="Shared positive experience",
+ )
+
+ assert change.relation_type == "挚友"
+ assert change.value_delta == 0.1
+
+ def test_relation_type_too_long(self):
+ """Test relation_type exceeding 30 chars is rejected."""
+ with pytest.raises(ValidationError):
+ RelationChange(
+ relation_type="a" * 31,
+ value_delta=0.1,
+ )
+
+ def test_value_delta_below_negative_one(self):
+ """Test value_delta below -1.0 is rejected."""
+ with pytest.raises(ValidationError):
+ RelationChange(relation_type="test", value_delta=-1.1)
+
+ def test_value_delta_above_one(self):
+ """Test value_delta above 1.0 is rejected."""
+ with pytest.raises(ValidationError):
+ RelationChange(relation_type="test", value_delta=1.1)
+
+ def test_boundary_values(self):
+ """Test boundary values for value_delta."""
+ low = RelationChange(relation_type="低", value_delta=-1.0)
+ high = RelationChange(relation_type="高", value_delta=1.0)
+
+ assert low.value_delta == -1.0
+ assert high.value_delta == 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestSocialRelationAnalysis:
+ """Test SocialRelationAnalysis Pydantic model."""
+
+ def test_valid_creation(self):
+ """Test creating a valid social relation analysis."""
+ analysis = SocialRelationAnalysis(
+ relations=[
+ RelationChange(relation_type="友情", value_delta=0.05),
+ RelationChange(relation_type="信任", value_delta=0.02),
+ ],
+ overall_sentiment="positive",
+ )
+
+ assert len(analysis.relations) == 2
+ assert analysis.overall_sentiment == "positive"
+
+ def test_empty_relations(self):
+ """Test empty relations list is valid."""
+ analysis = SocialRelationAnalysis(relations=[])
+ assert len(analysis.relations) == 0
+
+ def test_default_sentiment(self):
+ """Test default overall_sentiment is neutral."""
+ analysis = SocialRelationAnalysis(relations=[])
+ assert analysis.overall_sentiment == "neutral"
+
+ def test_max_five_relations(self):
+ """Test relations are capped at 5."""
+ relations = [
+ RelationChange(relation_type=f"type_{i}", value_delta=0.01)
+ for i in range(7)
+ ]
+ analysis = SocialRelationAnalysis(relations=relations)
+ assert len(analysis.relations) == 5
diff --git a/tests/unit/test_interfaces.py b/tests/unit/test_interfaces.py
new file mode 100644
index 0000000..0643489
--- /dev/null
+++ b/tests/unit/test_interfaces.py
@@ -0,0 +1,244 @@
+"""
+Unit tests for core interfaces module
+
+Tests the core data classes, enums, and interface definitions:
+- MessageData dataclass construction and defaults
+- AnalysisResult dataclass construction and defaults
+- PersonaUpdateRecord dataclass construction and defaults
+- ServiceLifecycle enum values
+- LearningStrategyType enum values
+- AnalysisType enum values
+"""
+import pytest
+from unittest.mock import MagicMock
+
+from core.interfaces import (
+ ServiceLifecycle,
+ MessageData,
+ AnalysisResult,
+ PersonaUpdateRecord,
+ LearningStrategyType,
+ AnalysisType,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestMessageData:
+ """Test MessageData dataclass."""
+
+ def test_required_fields(self):
+ """Test creating MessageData with all required fields."""
+ msg = MessageData(
+ sender_id="user_001",
+ sender_name="Alice",
+ message="Hello world",
+ group_id="group_001",
+ timestamp=1700000000.0,
+ platform="qq",
+ )
+
+ assert msg.sender_id == "user_001"
+ assert msg.sender_name == "Alice"
+ assert msg.message == "Hello world"
+ assert msg.group_id == "group_001"
+ assert msg.timestamp == 1700000000.0
+ assert msg.platform == "qq"
+
+ def test_optional_fields_default_none(self):
+ """Test optional fields default to None."""
+ msg = MessageData(
+ sender_id="user_001",
+ sender_name="Alice",
+ message="Hello",
+ group_id="group_001",
+ timestamp=1700000000.0,
+ platform="qq",
+ )
+
+ assert msg.message_id is None
+ assert msg.reply_to is None
+
+ def test_optional_fields_set_explicitly(self):
+ """Test optional fields can be set explicitly."""
+ msg = MessageData(
+ sender_id="user_001",
+ sender_name="Alice",
+ message="Hello",
+ group_id="group_001",
+ timestamp=1700000000.0,
+ platform="qq",
+ message_id="msg_123",
+ reply_to="msg_100",
+ )
+
+ assert msg.message_id == "msg_123"
+ assert msg.reply_to == "msg_100"
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestAnalysisResult:
+ """Test AnalysisResult dataclass."""
+
+ def test_required_fields(self):
+ """Test creating AnalysisResult with required fields."""
+ result = AnalysisResult(
+ success=True,
+ confidence=0.85,
+ data={"key": "value"},
+ )
+
+ assert result.success is True
+ assert result.confidence == 0.85
+ assert result.data == {"key": "value"}
+
+ def test_default_values(self):
+ """Test AnalysisResult default values."""
+ result = AnalysisResult(
+ success=True,
+ confidence=0.9,
+ data={},
+ )
+
+ assert result.timestamp == 0.0
+ assert result.error is None
+ assert result.consistency_score is None
+
+ def test_with_error(self):
+ """Test AnalysisResult with error information."""
+ result = AnalysisResult(
+ success=False,
+ confidence=0.0,
+ data={},
+ error="Analysis failed due to insufficient data",
+ )
+
+ assert result.success is False
+ assert result.error == "Analysis failed due to insufficient data"
+
+ def test_with_consistency_score(self):
+ """Test AnalysisResult with consistency score."""
+ result = AnalysisResult(
+ success=True,
+ confidence=0.8,
+ data={"metrics": [1, 2, 3]},
+ consistency_score=0.75,
+ )
+
+ assert result.consistency_score == 0.75
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestPersonaUpdateRecord:
+ """Test PersonaUpdateRecord dataclass."""
+
+ def test_required_fields(self):
+ """Test creating PersonaUpdateRecord with required fields."""
+ record = PersonaUpdateRecord(
+ timestamp=1700000000.0,
+ group_id="group_001",
+ update_type="prompt_update",
+ original_content="Original prompt",
+ new_content="New prompt",
+ reason="Style analysis update",
+ )
+
+ assert record.timestamp == 1700000000.0
+ assert record.group_id == "group_001"
+ assert record.update_type == "prompt_update"
+ assert record.original_content == "Original prompt"
+ assert record.new_content == "New prompt"
+ assert record.reason == "Style analysis update"
+
+ def test_default_values(self):
+ """Test PersonaUpdateRecord default values."""
+ record = PersonaUpdateRecord(
+ timestamp=0.0,
+ group_id="g1",
+ update_type="test",
+ original_content="",
+ new_content="",
+ reason="",
+ )
+
+ assert record.confidence_score == 0.5
+ assert record.id is None
+ assert record.status == "pending"
+ assert record.reviewer_comment is None
+ assert record.review_time is None
+
+ def test_approved_record(self):
+ """Test PersonaUpdateRecord with approved status."""
+ record = PersonaUpdateRecord(
+ timestamp=1700000000.0,
+ group_id="g1",
+ update_type="prompt_update",
+ original_content="old",
+ new_content="new",
+ reason="update",
+ id=42,
+ status="approved",
+ reviewer_comment="Looks good",
+ review_time=1700001000.0,
+ )
+
+ assert record.id == 42
+ assert record.status == "approved"
+ assert record.reviewer_comment == "Looks good"
+ assert record.review_time == 1700001000.0
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestServiceLifecycleEnum:
+ """Test ServiceLifecycle enum."""
+
+ def test_all_states_exist(self):
+ """Test all expected lifecycle states exist."""
+ assert ServiceLifecycle.CREATED.value == "created"
+ assert ServiceLifecycle.INITIALIZING.value == "initializing"
+ assert ServiceLifecycle.RUNNING.value == "running"
+ assert ServiceLifecycle.STOPPING.value == "stopping"
+ assert ServiceLifecycle.STOPPED.value == "stopped"
+ assert ServiceLifecycle.ERROR.value == "error"
+
+ def test_enum_count(self):
+ """Test the total number of lifecycle states."""
+ assert len(ServiceLifecycle) == 6
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestLearningStrategyTypeEnum:
+ """Test LearningStrategyType enum."""
+
+ def test_all_strategies_exist(self):
+ """Test all expected strategy types exist."""
+ assert LearningStrategyType.PROGRESSIVE.value == "progressive"
+ assert LearningStrategyType.BATCH.value == "batch"
+ assert LearningStrategyType.REALTIME.value == "realtime"
+ assert LearningStrategyType.HYBRID.value == "hybrid"
+
+ def test_enum_count(self):
+ """Test the total number of strategy types."""
+ assert len(LearningStrategyType) == 4
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestAnalysisTypeEnum:
+ """Test AnalysisType enum."""
+
+ def test_all_types_exist(self):
+ """Test all expected analysis types exist."""
+ assert AnalysisType.STYLE.value == "style"
+ assert AnalysisType.SENTIMENT.value == "sentiment"
+ assert AnalysisType.TOPIC.value == "topic"
+ assert AnalysisType.BEHAVIOR.value == "behavior"
+ assert AnalysisType.QUALITY.value == "quality"
+
+ def test_enum_count(self):
+ """Test the total number of analysis types."""
+ assert len(AnalysisType) == 5
diff --git a/tests/unit/test_json_utils.py b/tests/unit/test_json_utils.py
new file mode 100644
index 0000000..35b1dac
--- /dev/null
+++ b/tests/unit/test_json_utils.py
@@ -0,0 +1,391 @@
+"""
+Unit tests for JSON utilities module
+
+Tests LLM response parsing, markdown cleanup, and JSON validation:
+- remove_thinking_content for various LLM thinking tags
+- clean_markdown_blocks for code block removal
+- clean_control_characters for sanitization
+- extract_json_content for boundary detection
+- fix_common_json_errors for auto-repair
+- safe_parse_llm_json for end-to-end parsing
+- validate_json_structure for schema validation
+- detect_llm_provider for model name detection
+"""
+import pytest
+
+from utils.json_utils import (
+ remove_thinking_content,
+ extract_thinking_content,
+ clean_markdown_blocks,
+ clean_control_characters,
+ extract_json_content,
+ fix_common_json_errors,
+ clean_llm_json_response,
+ safe_parse_llm_json,
+ validate_json_structure,
+ detect_llm_provider,
+ LLMProvider,
+ _convert_single_quotes,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestRemoveThinkingContent:
+ """Test removal of LLM thinking tags."""
+
+ def test_empty_input(self):
+ """Test empty input returns as-is."""
+ assert remove_thinking_content("") == ""
+ assert remove_thinking_content(None) is None
+
+ def test_remove_thinking_tags(self):
+ """Test removal of tags."""
+ text = "Internal reasoningFinal answer"
+ result = remove_thinking_content(text)
+ assert "Internal reasoning" not in result
+ assert "Final answer" in result
+
+ def test_remove_thought_tags(self):
+ """Test removal of tags."""
+ text = "Analysis hereResult"
+ result = remove_thinking_content(text)
+ assert "Analysis here" not in result
+ assert "Result" in result
+
+ def test_remove_reasoning_tags(self):
+ """Test removal of tags."""
+ text = "Step 1, Step 2Output"
+ result = remove_thinking_content(text)
+ assert "Step 1" not in result
+ assert "Output" in result
+
+ def test_remove_think_tags(self):
+ """Test removal of tags."""
+ text = "Hmm let me thinkAnswer is 42"
+ result = remove_thinking_content(text)
+ assert "Hmm let me think" not in result
+ assert "Answer is 42" in result
+
+ def test_remove_chinese_thinking_tags(self):
+ """Test removal of Chinese thinking tags."""
+ text = "<思考>这是思考过程思考>最终结果"
+ result = remove_thinking_content(text)
+ assert "这是思考过程" not in result
+ assert "最终结果" in result
+
+ def test_multiline_thinking_content(self):
+ """Test removal of multiline thinking content."""
+ text = "\nLine 1\nLine 2\nLine 3\nFinal"
+ result = remove_thinking_content(text)
+ assert "Line 1" not in result
+ assert "Final" in result
+
+ def test_text_without_thinking_tags(self):
+ """Test text without thinking tags is unchanged."""
+ text = "Just a regular response without any tags"
+ result = remove_thinking_content(text)
+ assert result == text
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestExtractThinkingContent:
+ """Test extraction and separation of thinking content."""
+
+ def test_extract_thinking(self):
+ """Test extracting thinking content."""
+ text = "My thoughtsAnswer"
+ cleaned, thoughts = extract_thinking_content(text)
+
+ assert "Answer" in cleaned
+ assert len(thoughts) >= 1
+
+ def test_no_thinking_content(self):
+ """Test text without thinking content."""
+ text = "Plain text response"
+ cleaned, thoughts = extract_thinking_content(text)
+
+ assert cleaned == "Plain text response"
+ assert thoughts == []
+
+ def test_empty_input(self):
+ """Test empty input."""
+ cleaned, thoughts = extract_thinking_content("")
+ assert cleaned == ""
+ assert thoughts == []
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCleanMarkdownBlocks:
+ """Test markdown code block cleaning."""
+
+ def test_clean_json_code_block(self):
+ """Test cleaning ```json code block."""
+ text = '```json\n{"key": "value"}\n```'
+ result = clean_markdown_blocks(text)
+ assert result == '{"key": "value"}'
+
+ def test_clean_plain_code_block(self):
+ """Test cleaning plain ``` code block."""
+ text = '```\n{"key": "value"}\n```'
+ result = clean_markdown_blocks(text)
+ assert result == '{"key": "value"}'
+
+ def test_no_code_blocks(self):
+ """Test text without code blocks is unchanged."""
+ text = '{"key": "value"}'
+ result = clean_markdown_blocks(text)
+ assert result == text
+
+ def test_empty_input(self):
+ """Test empty input returns as-is."""
+ assert clean_markdown_blocks("") == ""
+ assert clean_markdown_blocks(None) is None
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestCleanControlCharacters:
+ """Test control character cleaning."""
+
+ def test_remove_null_bytes(self):
+ """Test removal of null bytes."""
+ text = "hello\x00world"
+ result = clean_control_characters(text)
+ assert result == "helloworld"
+
+ def test_preserve_tabs_and_newlines(self):
+ """Test preservation of tabs and newlines."""
+ text = "hello\tworld\nfoo"
+ result = clean_control_characters(text)
+ assert result == text
+
+ def test_empty_input(self):
+ """Test empty input."""
+ assert clean_control_characters("") == ""
+ assert clean_control_characters(None) is None
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestExtractJsonContent:
+ """Test JSON content extraction from mixed text."""
+
+ def test_extract_json_object(self):
+ """Test extracting JSON object from text."""
+ text = 'Some text {"key": "value"} more text'
+ result = extract_json_content(text)
+ assert result == '{"key": "value"}'
+
+ def test_extract_json_array(self):
+ """Test extracting JSON array from text."""
+ text = 'Prefix [1, 2, 3] suffix'
+ result = extract_json_content(text)
+ assert result == '[1, 2, 3]'
+
+ def test_no_json_content(self):
+ """Test text without JSON returns original."""
+ text = "no json here"
+ result = extract_json_content(text)
+ assert result == text
+
+ def test_nested_json(self):
+ """Test extracting nested JSON object."""
+ text = '{"outer": {"inner": "value"}}'
+ result = extract_json_content(text)
+ assert result == text
+
+ def test_empty_input(self):
+ """Test empty input."""
+ assert extract_json_content("") == ""
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestFixCommonJsonErrors:
+ """Test JSON error auto-repair."""
+
+ def test_fix_trailing_comma_object(self):
+ """Test fixing trailing comma in object."""
+ text = '{"key": "value",}'
+ result = fix_common_json_errors(text)
+ assert result == '{"key": "value"}'
+
+ def test_fix_trailing_comma_array(self):
+ """Test fixing trailing comma in array."""
+ text = '[1, 2, 3,]'
+ result = fix_common_json_errors(text)
+ assert result == '[1, 2, 3]'
+
+ def test_fix_python_true_false(self):
+ """Test fixing Python True/False/None to JSON equivalents."""
+ text = '{"flag": True, "empty": None, "off": False}'
+ result = fix_common_json_errors(text)
+ assert ": true" in result
+ assert ": null" in result
+ assert ": false" in result
+
+ def test_fix_nan_value(self):
+ """Test fixing NaN to null."""
+ text = '{"score": nan}'
+ result = fix_common_json_errors(text)
+ assert ": null" in result
+
+ def test_empty_input(self):
+ """Test empty input."""
+ assert fix_common_json_errors("") == ""
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestSafeParseLlmJson:
+ """Test end-to-end safe JSON parsing."""
+
+ def test_parse_clean_json(self):
+ """Test parsing clean JSON."""
+ result = safe_parse_llm_json('{"key": "value"}')
+ assert result == {"key": "value"}
+
+ def test_parse_json_in_markdown(self):
+ """Test parsing JSON wrapped in markdown code block."""
+ text = '```json\n{"key": "value"}\n```'
+ result = safe_parse_llm_json(text)
+ assert result == {"key": "value"}
+
+ def test_parse_json_with_thinking_tags(self):
+ """Test parsing JSON with thinking tags."""
+ text = 'Analysis{"result": 42}'
+ result = safe_parse_llm_json(text)
+ assert result == {"result": 42}
+
+ def test_parse_json_with_trailing_comma(self):
+ """Test parsing JSON with trailing comma."""
+ text = '{"key": "value",}'
+ result = safe_parse_llm_json(text)
+ assert result == {"key": "value"}
+
+ def test_parse_invalid_json_returns_fallback(self):
+ """Test invalid JSON returns fallback result."""
+ result = safe_parse_llm_json("not json at all", fallback_result={"default": True})
+ assert result == {"default": True}
+
+ def test_parse_empty_input(self):
+ """Test empty input returns fallback."""
+ result = safe_parse_llm_json("", fallback_result=None)
+ assert result is None
+
+ def test_parse_json_array(self):
+ """Test parsing JSON array."""
+ result = safe_parse_llm_json('[1, 2, 3]')
+ assert result == [1, 2, 3]
+
+ def test_parse_with_single_quotes(self):
+ """Test parsing JSON with single quotes."""
+ text = "{'key': 'value'}"
+ result = safe_parse_llm_json(text)
+ assert result == {"key": "value"}
+
+ def test_parse_nested_json(self):
+ """Test parsing nested JSON structure."""
+ text = '{"outer": {"inner": [1, 2, 3]}, "flag": true}'
+ result = safe_parse_llm_json(text)
+ assert result["outer"]["inner"] == [1, 2, 3]
+ assert result["flag"] is True
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestValidateJsonStructure:
+ """Test JSON structure validation."""
+
+ def test_valid_with_required_fields(self):
+ """Test validation with all required fields present."""
+ data = {"name": "Alice", "age": 30}
+ valid, msg = validate_json_structure(
+ data, required_fields=["name", "age"]
+ )
+ assert valid is True
+ assert msg == ""
+
+ def test_missing_required_fields(self):
+ """Test validation with missing required fields."""
+ data = {"name": "Alice"}
+ valid, msg = validate_json_structure(
+ data, required_fields=["name", "age"]
+ )
+ assert valid is False
+ assert "age" in msg
+
+ def test_none_data(self):
+ """Test validation with None data."""
+ valid, msg = validate_json_structure(None)
+ assert valid is False
+
+ def test_type_check_success(self):
+ """Test validation with correct expected type."""
+ valid, msg = validate_json_structure(
+ {"key": "value"}, expected_type=dict
+ )
+ assert valid is True
+
+ def test_type_check_failure(self):
+ """Test validation with incorrect expected type."""
+ valid, msg = validate_json_structure(
+ [1, 2, 3], expected_type=dict
+ )
+ assert valid is False
+ assert "dict" in msg
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestDetectLlmProvider:
+ """Test LLM provider detection from model names."""
+
+ def test_detect_deepseek(self):
+ """Test detecting DeepSeek provider."""
+ assert detect_llm_provider("deepseek-chat") == LLMProvider.DEEPSEEK
+ assert detect_llm_provider("deepseek-reasoner") == LLMProvider.DEEPSEEK
+
+ def test_detect_anthropic(self):
+ """Test detecting Anthropic provider."""
+ assert detect_llm_provider("claude-3-opus") == LLMProvider.ANTHROPIC
+ assert detect_llm_provider("claude-3.5-sonnet") == LLMProvider.ANTHROPIC
+
+ def test_detect_openai(self):
+ """Test detecting OpenAI provider."""
+ assert detect_llm_provider("gpt-4") == LLMProvider.OPENAI
+ assert detect_llm_provider("gpt-4o-mini") == LLMProvider.OPENAI
+
+ def test_detect_unknown(self):
+ """Test detecting unknown provider."""
+ assert detect_llm_provider("some-custom-model") == LLMProvider.GENERIC
+
+ def test_detect_empty_input(self):
+ """Test detecting from empty input."""
+ assert detect_llm_provider("") == LLMProvider.GENERIC
+ assert detect_llm_provider(None) == LLMProvider.GENERIC
+
+
+@pytest.mark.unit
+@pytest.mark.utils
+class TestConvertSingleQuotes:
+ """Test single-to-double quote conversion."""
+
+ def test_basic_conversion(self):
+ """Test basic single quote to double quote conversion."""
+ result = _convert_single_quotes("{'key': 'value'}")
+ assert result == '{"key": "value"}'
+
+ def test_already_double_quotes(self):
+ """Test text with double quotes is unchanged."""
+ text = '{"key": "value"}'
+ result = _convert_single_quotes(text)
+ assert result == text
+
+ def test_empty_input(self):
+ """Test empty input."""
+ assert _convert_single_quotes("") == ""
+ assert _convert_single_quotes(None) is None
diff --git a/tests/unit/test_learning_quality_monitor.py b/tests/unit/test_learning_quality_monitor.py
new file mode 100644
index 0000000..2d55baa
--- /dev/null
+++ b/tests/unit/test_learning_quality_monitor.py
@@ -0,0 +1,586 @@
+"""
+Unit tests for LearningQualityMonitor
+
+Tests the learning quality monitoring service:
+- PersonaMetrics and LearningAlert dataclasses
+- Consistency calculation (text similarity fallback)
+- Style stability calculation
+- Vocabulary diversity calculation
+- Emotional balance calculation (simple fallback)
+- Coherence calculation
+- Quality alert generation
+- Style drift detection
+- Threshold dynamic adjustment
+- Pause learning decision
+- Quality report generation
+"""
+import pytest
+from unittest.mock import patch, MagicMock, AsyncMock
+from datetime import datetime, timedelta
+
+from services.quality.learning_quality_monitor import (
+ LearningQualityMonitor,
+ PersonaMetrics,
+ LearningAlert,
+)
+
+
+def _create_monitor(
+ consistency_threshold=0.5,
+ stability_threshold=0.4,
+ drift_threshold=0.4,
+) -> LearningQualityMonitor:
+ """Create a LearningQualityMonitor with mocked dependencies."""
+ config = MagicMock()
+ context = MagicMock()
+
+ monitor = LearningQualityMonitor(
+ config=config,
+ context=context,
+ llm_adapter=None,
+ prompts=None,
+ )
+ monitor.consistency_threshold = consistency_threshold
+ monitor.stability_threshold = stability_threshold
+ monitor.drift_threshold = drift_threshold
+
+ return monitor
+
+
+def _make_messages(texts):
+ """Helper to create message dicts from text list."""
+ return [{"message": text} for text in texts]
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestPersonaMetrics:
+ """Test PersonaMetrics dataclass."""
+
+ def test_default_values(self):
+ """Test default metric values."""
+ metrics = PersonaMetrics()
+
+ assert metrics.consistency_score == 0.0
+ assert metrics.style_stability == 0.0
+ assert metrics.vocabulary_diversity == 0.0
+ assert metrics.emotional_balance == 0.0
+ assert metrics.coherence_score == 0.0
+
+ def test_custom_values(self):
+ """Test custom metric values."""
+ metrics = PersonaMetrics(
+ consistency_score=0.85,
+ style_stability=0.9,
+ vocabulary_diversity=0.7,
+ emotional_balance=0.65,
+ coherence_score=0.8,
+ )
+
+ assert metrics.consistency_score == 0.85
+ assert metrics.style_stability == 0.9
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestLearningAlert:
+ """Test LearningAlert dataclass."""
+
+ def test_alert_creation(self):
+ """Test creating a learning alert."""
+ alert = LearningAlert(
+ alert_type="consistency",
+ severity="high",
+ message="Consistency dropped below threshold",
+ timestamp=datetime.now().isoformat(),
+ metrics={"consistency_score": 0.3},
+ suggestions=["Review persona changes"],
+ )
+
+ assert alert.alert_type == "consistency"
+ assert alert.severity == "high"
+ assert len(alert.suggestions) == 1
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestConsistencyCalculation:
+ """Test persona consistency score calculation."""
+
+ @pytest.mark.asyncio
+ async def test_both_empty_personas(self):
+ """Test consistency when both personas are empty."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_consistency(
+ {"prompt": ""}, {"prompt": ""}
+ )
+ assert score == 0.7
+
+ @pytest.mark.asyncio
+ async def test_one_empty_persona(self):
+ """Test consistency when one persona is empty."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_consistency(
+ {"prompt": "I am a helpful bot"}, {"prompt": ""}
+ )
+ assert score == 0.6
+
+ @pytest.mark.asyncio
+ async def test_identical_personas(self):
+ """Test consistency when personas are identical."""
+ monitor = _create_monitor()
+ prompt = "I am a friendly chatbot."
+
+ score = await monitor._calculate_consistency(
+ {"prompt": prompt}, {"prompt": prompt}
+ )
+ assert score == 0.95
+
+ @pytest.mark.asyncio
+ async def test_similar_personas_fallback(self):
+ """Test consistency using text similarity fallback (no LLM)."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_consistency(
+ {"prompt": "I am a helpful assistant."},
+ {"prompt": "I am a helpful assistant. I like chatting."},
+ )
+ assert 0.4 <= score <= 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestTextSimilarity:
+ """Test text similarity fallback method."""
+
+ def test_identical_texts(self):
+ """Test identical texts return high similarity."""
+ monitor = _create_monitor()
+
+ score = monitor._calculate_text_similarity("hello world", "hello world")
+ assert score == 0.95
+
+ def test_empty_texts(self):
+ """Test empty texts return default."""
+ monitor = _create_monitor()
+
+ score = monitor._calculate_text_similarity("", "")
+ assert score == 0.6
+
+ def test_one_empty_text(self):
+ """Test one empty text returns default."""
+ monitor = _create_monitor()
+
+ score = monitor._calculate_text_similarity("hello", "")
+ assert score == 0.6
+
+ def test_different_texts(self):
+ """Test different texts return lower similarity."""
+ monitor = _create_monitor()
+
+ score = monitor._calculate_text_similarity("abc", "xyz")
+ assert 0.4 <= score <= 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestStyleStability:
+ """Test style stability calculation."""
+
+ @pytest.mark.asyncio
+ async def test_single_message_perfect_stability(self):
+ """Test single message returns perfect stability."""
+ monitor = _create_monitor()
+ messages = _make_messages(["Hello!"])
+
+ score = await monitor._calculate_style_stability(messages)
+ assert score == 1.0
+
+ @pytest.mark.asyncio
+ async def test_identical_messages_high_stability(self):
+ """Test identical messages have high stability."""
+ monitor = _create_monitor()
+ messages = _make_messages(["Hello!", "Hello!", "Hello!"])
+
+ score = await monitor._calculate_style_stability(messages)
+ assert score >= 0.8
+
+ @pytest.mark.asyncio
+ async def test_diverse_messages_lower_stability(self):
+ """Test diverse messages have lower stability."""
+ monitor = _create_monitor()
+ messages = _make_messages([
+ "Hi",
+ "This is a very long message with lots of words and punctuation! Really?",
+ "Ok",
+ ])
+
+ score = await monitor._calculate_style_stability(messages)
+ assert 0.0 <= score <= 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestVocabularyDiversity:
+ """Test vocabulary diversity calculation."""
+
+ @pytest.mark.asyncio
+ async def test_empty_messages(self):
+ """Test empty messages return zero diversity."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_vocabulary_diversity([])
+ assert score == 0.0
+
+ @pytest.mark.asyncio
+ async def test_single_word_messages(self):
+ """Test messages with same word have low diversity (actually 1.0)."""
+ monitor = _create_monitor()
+ messages = _make_messages(["hello", "hello", "hello"])
+
+ score = await monitor._calculate_vocabulary_diversity(messages)
+ # All same word: unique=1, total=3, ratio=0.33, *2=0.66
+ assert 0.5 <= score <= 1.0
+
+ @pytest.mark.asyncio
+ async def test_unique_words_high_diversity(self):
+ """Test messages with all unique words have high diversity."""
+ monitor = _create_monitor()
+ messages = _make_messages(["apple banana", "cherry date", "elderberry fig"])
+
+ score = await monitor._calculate_vocabulary_diversity(messages)
+ assert score >= 0.8
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestEmotionalBalance:
+ """Test emotional balance calculation (simple fallback)."""
+
+ def test_neutral_messages(self):
+ """Test messages without emotional words return high balance."""
+ monitor = _create_monitor()
+ messages = _make_messages(["今天天气不错", "我去了公园"])
+
+ score = monitor._simple_emotional_balance(messages)
+ assert score == 0.8 # No emotional words = neutral
+
+ def test_positive_messages(self):
+ """Test messages with positive words."""
+ monitor = _create_monitor()
+ messages = _make_messages(["好棒啊!", "真的很开心!喜欢!"])
+
+ score = monitor._simple_emotional_balance(messages)
+ assert 0.0 <= score <= 1.0
+
+ def test_negative_messages(self):
+ """Test messages with negative words."""
+ monitor = _create_monitor()
+ messages = _make_messages(["不好", "真烦人,讨厌"])
+
+ score = monitor._simple_emotional_balance(messages)
+ assert 0.0 <= score <= 1.0
+
+ def test_balanced_messages(self):
+ """Test balanced positive and negative messages."""
+ monitor = _create_monitor()
+ messages = _make_messages(["好开心", "不好"])
+
+ score = monitor._simple_emotional_balance(messages)
+ assert 0.0 <= score <= 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestCoherence:
+ """Test coherence calculation."""
+
+ @pytest.mark.asyncio
+ async def test_empty_persona(self):
+ """Test empty persona returns zero coherence."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_coherence({"prompt": ""})
+ assert score == 0.0
+
+ @pytest.mark.asyncio
+ async def test_single_sentence(self):
+ """Test single sentence returns high coherence."""
+ monitor = _create_monitor()
+
+ score = await monitor._calculate_coherence({"prompt": "我是一个友好的助手"})
+ assert score == 0.8
+
+ @pytest.mark.asyncio
+ async def test_multiple_sentences(self):
+ """Test multiple sentences are evaluated."""
+ monitor = _create_monitor()
+ prompt = "我是一个友好的助手。我喜欢帮助人。我会用中文交流。"
+
+ score = await monitor._calculate_coherence({"prompt": prompt})
+ assert 0.0 <= score <= 1.0
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestStyleDrift:
+ """Test style drift detection."""
+
+ def test_no_drift_identical_metrics(self):
+ """Test no drift when metrics are identical."""
+ monitor = _create_monitor()
+ metrics = PersonaMetrics(
+ consistency_score=0.8,
+ style_stability=0.7,
+ vocabulary_diversity=0.6,
+ )
+
+ drift = monitor._calculate_style_drift(metrics, metrics)
+ assert drift == 0.0
+
+ def test_large_drift(self):
+ """Test large drift detection."""
+ monitor = _create_monitor()
+ prev = PersonaMetrics(
+ consistency_score=0.9,
+ style_stability=0.8,
+ vocabulary_diversity=0.7,
+ )
+ curr = PersonaMetrics(
+ consistency_score=0.3,
+ style_stability=0.2,
+ vocabulary_diversity=0.1,
+ )
+
+ drift = monitor._calculate_style_drift(prev, curr)
+ assert drift > 0.4
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestQualityAlerts:
+ """Test quality alert generation."""
+
+ @pytest.mark.asyncio
+ async def test_consistency_alert(self):
+ """Test alert is generated when consistency is below threshold."""
+ monitor = _create_monitor(consistency_threshold=0.5)
+ metrics = PersonaMetrics(consistency_score=0.3)
+
+ await monitor._check_quality_alerts(metrics)
+
+ assert len(monitor.alerts_history) >= 1
+ assert any(a.alert_type == "consistency" for a in monitor.alerts_history)
+
+ @pytest.mark.asyncio
+ async def test_stability_alert(self):
+ """Test alert is generated when stability is below threshold."""
+ monitor = _create_monitor(stability_threshold=0.4)
+ metrics = PersonaMetrics(style_stability=0.2)
+
+ await monitor._check_quality_alerts(metrics)
+
+ assert any(a.alert_type == "stability" for a in monitor.alerts_history)
+
+ @pytest.mark.asyncio
+ async def test_no_alert_when_above_thresholds(self):
+ """Test no alerts when all metrics are above thresholds."""
+ monitor = _create_monitor()
+ metrics = PersonaMetrics(
+ consistency_score=0.9,
+ style_stability=0.8,
+ vocabulary_diversity=0.7,
+ )
+
+ await monitor._check_quality_alerts(metrics)
+ assert len(monitor.alerts_history) == 0
+
+ @pytest.mark.asyncio
+ async def test_drift_alert_with_history(self):
+ """Test drift alert when historical metrics exist."""
+ monitor = _create_monitor(drift_threshold=0.1)
+ # Add previous metrics
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.9, style_stability=0.9, vocabulary_diversity=0.9)
+ )
+ # Current metrics show significant change
+ current = PersonaMetrics(
+ consistency_score=0.3,
+ style_stability=0.3,
+ vocabulary_diversity=0.3,
+ )
+ monitor.historical_metrics.append(current)
+
+ await monitor._check_quality_alerts(current)
+
+ assert any(a.alert_type == "drift" for a in monitor.alerts_history)
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestShouldPauseLearning:
+ """Test learning pause decision logic."""
+
+ @pytest.mark.asyncio
+ async def test_no_history_no_pause(self):
+ """Test no pause with empty history."""
+ monitor = _create_monitor()
+
+ should_pause, reason = await monitor.should_pause_learning()
+ assert should_pause is False
+
+ @pytest.mark.asyncio
+ async def test_pause_on_low_consistency(self):
+ """Test pause when consistency is critically low."""
+ monitor = _create_monitor()
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.3)
+ )
+
+ should_pause, reason = await monitor.should_pause_learning()
+ assert should_pause is True
+ assert "一致性" in reason
+
+ @pytest.mark.asyncio
+ async def test_no_pause_above_threshold(self):
+ """Test no pause when metrics are acceptable."""
+ monitor = _create_monitor()
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.8)
+ )
+
+ should_pause, reason = await monitor.should_pause_learning()
+ assert should_pause is False
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestQualityReport:
+ """Test quality report generation."""
+
+ @pytest.mark.asyncio
+ async def test_report_no_history(self):
+ """Test report with no historical data."""
+ monitor = _create_monitor()
+
+ report = await monitor.get_quality_report()
+ assert "error" in report
+
+ @pytest.mark.asyncio
+ async def test_report_with_single_metric(self):
+ """Test report with single historical metric."""
+ monitor = _create_monitor()
+ monitor.historical_metrics.append(
+ PersonaMetrics(
+ consistency_score=0.8,
+ style_stability=0.7,
+ vocabulary_diversity=0.6,
+ emotional_balance=0.5,
+ coherence_score=0.9,
+ )
+ )
+
+ report = await monitor.get_quality_report()
+
+ assert "current_metrics" in report
+ assert report["current_metrics"]["consistency_score"] == 0.8
+ assert "trends" in report
+ assert "recommendations" in report
+
+ @pytest.mark.asyncio
+ async def test_report_with_trends(self):
+ """Test report includes trend data when sufficient history exists."""
+ monitor = _create_monitor()
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.6, style_stability=0.5, vocabulary_diversity=0.4)
+ )
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.8, style_stability=0.7, vocabulary_diversity=0.6)
+ )
+
+ report = await monitor.get_quality_report()
+
+ assert report["trends"]["consistency_trend"] == pytest.approx(0.2)
+ assert report["trends"]["stability_trend"] == pytest.approx(0.2)
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestDynamicThresholdAdjustment:
+ """Test dynamic threshold adjustment based on history."""
+
+ @pytest.mark.asyncio
+ async def test_no_adjustment_insufficient_history(self):
+ """Test no adjustment with less than 5 historical entries."""
+ monitor = _create_monitor(consistency_threshold=0.5)
+
+ for _ in range(3):
+ monitor.historical_metrics.append(PersonaMetrics(consistency_score=0.9))
+
+ await monitor.adjust_thresholds_based_on_history()
+
+ # Should remain unchanged
+ assert monitor.consistency_threshold == 0.5
+
+ @pytest.mark.asyncio
+ async def test_threshold_increases_on_good_performance(self):
+ """Test threshold increases when performance is consistently good."""
+ monitor = _create_monitor(consistency_threshold=0.5)
+
+ for _ in range(5):
+ monitor.historical_metrics.append(
+ PersonaMetrics(consistency_score=0.85, style_stability=0.75)
+ )
+
+ await monitor.adjust_thresholds_based_on_history()
+
+ assert monitor.consistency_threshold == 0.55 # Increased by 0.05
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestHelperMethods:
+ """Test helper methods."""
+
+ def test_punctuation_ratio(self):
+ """Test punctuation ratio calculation."""
+ monitor = _create_monitor()
+
+ assert monitor._get_punctuation_ratio("") == 0.0
+ assert monitor._get_punctuation_ratio("hello") == 0.0
+ assert monitor._get_punctuation_ratio("你好,世界!") > 0.0
+
+ def test_count_emoji(self):
+ """Test emoji counting."""
+ monitor = _create_monitor()
+
+ assert monitor._count_emoji("hello") == 0
+ # The emoji patterns defined in the source are empty strings,
+ # so this tests the current behavior
+ assert isinstance(monitor._count_emoji("hello 😀"), int)
+
+ def test_recommendations_low_consistency(self):
+ """Test recommendations for low consistency."""
+ monitor = _create_monitor()
+ metrics = PersonaMetrics(consistency_score=0.5)
+
+ recs = monitor._generate_recommendations(metrics, [])
+ assert any("一致性" in r for r in recs)
+
+ def test_recommendations_good_quality(self):
+ """Test recommendations for good quality."""
+ monitor = _create_monitor()
+ metrics = PersonaMetrics(consistency_score=0.9, style_stability=0.8)
+
+ recs = monitor._generate_recommendations(metrics, [])
+ assert any("良好" in r for r in recs)
+
+ @pytest.mark.asyncio
+ async def test_stop(self):
+ """Test service stop."""
+ monitor = _create_monitor()
+
+ result = await monitor.stop()
+ assert result is True
diff --git a/tests/unit/test_patterns.py b/tests/unit/test_patterns.py
new file mode 100644
index 0000000..3e76501
--- /dev/null
+++ b/tests/unit/test_patterns.py
@@ -0,0 +1,446 @@
+"""
+Unit tests for core design patterns module
+
+Tests the design pattern implementations:
+- AsyncServiceBase lifecycle management
+- LearningContextBuilder (builder pattern)
+- StrategyFactory (factory + strategy patterns)
+- ServiceRegistry (singleton + registry pattern)
+- ProgressiveLearningStrategy execution
+- BatchLearningStrategy execution
+"""
+import asyncio
+import pytest
+from unittest.mock import patch, MagicMock, AsyncMock
+from datetime import datetime
+
+from core.interfaces import (
+ ServiceLifecycle,
+ LearningStrategyType,
+ MessageData,
+)
+from core.patterns import (
+ AsyncServiceBase,
+ LearningContext,
+ LearningContextBuilder,
+ ProgressiveLearningStrategy,
+ BatchLearningStrategy,
+ StrategyFactory,
+ ServiceRegistry,
+ SingletonABCMeta,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestAsyncServiceBase:
+ """Test AsyncServiceBase lifecycle management."""
+
+ def _create_service(self, name: str = "test_service") -> AsyncServiceBase:
+ """Helper to create a concrete service instance."""
+ return AsyncServiceBase(name)
+
+ def test_initial_status_is_created(self):
+ """Test service starts in CREATED state."""
+ service = self._create_service()
+ assert service.status == ServiceLifecycle.CREATED
+
+ @pytest.mark.asyncio
+ async def test_start_transitions_to_running(self):
+ """Test starting service transitions to RUNNING."""
+ service = self._create_service()
+
+ result = await service.start()
+
+ assert result is True
+ assert service.status == ServiceLifecycle.RUNNING
+
+ @pytest.mark.asyncio
+ async def test_start_already_running(self):
+ """Test starting already running service returns True."""
+ service = self._create_service()
+ await service.start()
+
+ result = await service.start()
+
+ assert result is True
+ assert service.status == ServiceLifecycle.RUNNING
+
+ @pytest.mark.asyncio
+ async def test_stop_transitions_to_stopped(self):
+ """Test stopping service transitions to STOPPED."""
+ service = self._create_service()
+ await service.start()
+
+ result = await service.stop()
+
+ assert result is True
+ assert service.status == ServiceLifecycle.STOPPED
+
+ @pytest.mark.asyncio
+ async def test_stop_already_stopped(self):
+ """Test stopping already stopped service returns True."""
+ service = self._create_service()
+ await service.start()
+ await service.stop()
+
+ result = await service.stop()
+
+ assert result is True
+ assert service.status == ServiceLifecycle.STOPPED
+
+ @pytest.mark.asyncio
+ async def test_restart(self):
+ """Test restarting service."""
+ service = self._create_service()
+ await service.start()
+
+ result = await service.restart()
+
+ assert result is True
+ assert service.status == ServiceLifecycle.RUNNING
+
+ @pytest.mark.asyncio
+ async def test_is_running(self):
+ """Test is_running check."""
+ service = self._create_service()
+
+ assert await service.is_running() is False
+
+ await service.start()
+ assert await service.is_running() is True
+
+ await service.stop()
+ assert await service.is_running() is False
+
+ @pytest.mark.asyncio
+ async def test_health_check(self):
+ """Test health check reflects running state."""
+ service = self._create_service()
+
+ assert await service.health_check() is False
+
+ await service.start()
+ assert await service.health_check() is True
+
+ @pytest.mark.asyncio
+ async def test_start_failure_transitions_to_error(self):
+ """Test service transitions to ERROR on start failure."""
+ service = self._create_service()
+ service._do_start = AsyncMock(side_effect=RuntimeError("init failed"))
+
+ result = await service.start()
+
+ assert result is False
+ assert service.status == ServiceLifecycle.ERROR
+
+ @pytest.mark.asyncio
+ async def test_stop_failure_transitions_to_error(self):
+ """Test service transitions to ERROR on stop failure."""
+ service = self._create_service()
+ await service.start()
+ service._do_stop = AsyncMock(side_effect=RuntimeError("cleanup failed"))
+
+ result = await service.stop()
+
+ assert result is False
+ assert service.status == ServiceLifecycle.ERROR
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestLearningContextBuilder:
+ """Test LearningContextBuilder (builder pattern)."""
+
+ def test_build_default_context(self):
+ """Test building context with default values."""
+ context = LearningContextBuilder().build()
+
+ assert isinstance(context, LearningContext)
+ assert context.messages == []
+ assert context.strategy_type == LearningStrategyType.PROGRESSIVE
+ assert context.quality_threshold == 0.7
+ assert context.max_iterations == 3
+ assert context.metadata == {}
+
+ def test_builder_chain(self):
+ """Test fluent builder chaining."""
+ messages = [
+ MessageData(
+ sender_id="u1", sender_name="Alice",
+ message="Hello", group_id="g1",
+ timestamp=1.0, platform="qq",
+ )
+ ]
+
+ context = (
+ LearningContextBuilder()
+ .with_messages(messages)
+ .with_strategy(LearningStrategyType.BATCH)
+ .with_quality_threshold(0.9)
+ .with_max_iterations(5)
+ .with_metadata("source", "test")
+ .build()
+ )
+
+ assert len(context.messages) == 1
+ assert context.strategy_type == LearningStrategyType.BATCH
+ assert context.quality_threshold == 0.9
+ assert context.max_iterations == 5
+ assert context.metadata["source"] == "test"
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestStrategyFactory:
+ """Test StrategyFactory (factory + strategy patterns)."""
+
+ def test_create_progressive_strategy(self):
+ """Test creating progressive learning strategy."""
+ config = {"batch_size": 25, "min_messages": 10}
+ strategy = StrategyFactory.create_strategy(
+ LearningStrategyType.PROGRESSIVE, config
+ )
+
+ assert isinstance(strategy, ProgressiveLearningStrategy)
+ assert strategy.config == config
+
+ def test_create_batch_strategy(self):
+ """Test creating batch learning strategy."""
+ config = {"batch_size": 100}
+ strategy = StrategyFactory.create_strategy(
+ LearningStrategyType.BATCH, config
+ )
+
+ assert isinstance(strategy, BatchLearningStrategy)
+
+ def test_create_unsupported_strategy_raises(self):
+ """Test creating unsupported strategy raises ValueError."""
+ with pytest.raises(ValueError, match="不支持的策略类型"):
+ StrategyFactory.create_strategy(
+ LearningStrategyType.REALTIME, {}
+ )
+
+ def test_register_custom_strategy(self):
+ """Test registering a custom strategy type."""
+
+ class CustomStrategy:
+ def __init__(self, config):
+ self.config = config
+
+ StrategyFactory.register_strategy(
+ LearningStrategyType.REALTIME, CustomStrategy
+ )
+ strategy = StrategyFactory.create_strategy(
+ LearningStrategyType.REALTIME, {"custom": True}
+ )
+
+ assert isinstance(strategy, CustomStrategy)
+
+ # Cleanup: remove custom strategy to avoid test pollution
+ del StrategyFactory._strategies[LearningStrategyType.REALTIME]
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestProgressiveLearningStrategy:
+ """Test ProgressiveLearningStrategy execution logic."""
+
+ def _make_messages(self, count: int):
+ """Helper to create test messages."""
+ return [
+ MessageData(
+ sender_id=f"u{i}", sender_name=f"User{i}",
+ message=f"Message {i}", group_id="g1",
+ timestamp=float(i), platform="qq",
+ )
+ for i in range(count)
+ ]
+
+ @pytest.mark.asyncio
+ async def test_execute_learning_cycle_success(self):
+ """Test progressive learning cycle executes successfully."""
+ strategy = ProgressiveLearningStrategy({"batch_size": 10})
+ messages = self._make_messages(25)
+
+ result = await strategy.execute_learning_cycle(messages)
+
+ assert result.success is True
+ assert result.confidence > 0
+ assert result.data["total_processed"] == 25
+ assert result.data["batch_count"] == 3
+
+ @pytest.mark.asyncio
+ async def test_execute_learning_cycle_empty_messages(self):
+ """Test progressive learning cycle with empty messages."""
+ strategy = ProgressiveLearningStrategy({"batch_size": 10})
+
+ result = await strategy.execute_learning_cycle([])
+
+ assert result.success is True
+ assert result.data["total_processed"] == 0
+
+ @pytest.mark.asyncio
+ async def test_should_learn_sufficient_messages(self):
+ """Test should_learn returns True when conditions are met."""
+ strategy = ProgressiveLearningStrategy({
+ "min_messages": 5,
+ "min_interval_hours": 0,
+ })
+ context = {
+ "message_count": 10,
+ "last_learning_time": 0,
+ }
+
+ result = await strategy.should_learn(context)
+ assert result is True
+
+ @pytest.mark.asyncio
+ async def test_should_learn_insufficient_messages(self):
+ """Test should_learn returns False with insufficient messages."""
+ strategy = ProgressiveLearningStrategy({
+ "min_messages": 20,
+ "min_interval_hours": 0,
+ })
+ context = {
+ "message_count": 5,
+ "last_learning_time": 0,
+ }
+
+ result = await strategy.should_learn(context)
+ assert result is False
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestBatchLearningStrategy:
+ """Test BatchLearningStrategy execution logic."""
+
+ def _make_messages(self, count: int):
+ """Helper to create test messages."""
+ return [
+ MessageData(
+ sender_id=f"u{i}", sender_name=f"User{i}",
+ message=f"Message {i}", group_id="g1",
+ timestamp=float(i), platform="qq",
+ )
+ for i in range(count)
+ ]
+
+ @pytest.mark.asyncio
+ async def test_execute_learning_cycle_success(self):
+ """Test batch learning cycle executes successfully."""
+ strategy = BatchLearningStrategy({"batch_size": 100})
+ messages = self._make_messages(50)
+
+ result = await strategy.execute_learning_cycle(messages)
+
+ assert result.success is True
+ assert result.confidence > 0
+ assert result.data["processed_count"] == 50
+
+ @pytest.mark.asyncio
+ async def test_should_learn_above_threshold(self):
+ """Test should_learn returns True when batch threshold is met."""
+ strategy = BatchLearningStrategy({"batch_size": 20})
+ context = {"message_count": 25}
+
+ result = await strategy.should_learn(context)
+ assert result is True
+
+ @pytest.mark.asyncio
+ async def test_should_learn_below_threshold(self):
+ """Test should_learn returns False below batch threshold."""
+ strategy = BatchLearningStrategy({"batch_size": 100})
+ context = {"message_count": 50}
+
+ result = await strategy.should_learn(context)
+ assert result is False
+
+
+@pytest.mark.unit
+@pytest.mark.core
+class TestServiceRegistry:
+ """Test ServiceRegistry (singleton + registry pattern)."""
+
+ def _create_fresh_registry(self) -> ServiceRegistry:
+ """Create a fresh registry by clearing singleton cache."""
+ # Clear singleton instance to avoid test pollution
+ SingletonABCMeta._instances.pop(ServiceRegistry, None)
+ return ServiceRegistry(service_stop_timeout=2)
+
+ def test_register_service(self):
+ """Test registering a service."""
+ registry = self._create_fresh_registry()
+ service = AsyncServiceBase("test_svc")
+
+ registry.register_service("test_svc", service)
+
+ assert registry.get_service("test_svc") is service
+
+ def test_get_nonexistent_service(self):
+ """Test getting a nonexistent service returns None."""
+ registry = self._create_fresh_registry()
+
+ assert registry.get_service("nonexistent") is None
+
+ def test_unregister_service(self):
+ """Test unregistering a service."""
+ registry = self._create_fresh_registry()
+ service = AsyncServiceBase("test_svc")
+ registry.register_service("test_svc", service)
+
+ result = registry.unregister_service("test_svc")
+
+ assert result is True
+ assert registry.get_service("test_svc") is None
+
+ def test_unregister_nonexistent_service(self):
+ """Test unregistering a nonexistent service returns False."""
+ registry = self._create_fresh_registry()
+
+ result = registry.unregister_service("nonexistent")
+
+ assert result is False
+
+ @pytest.mark.asyncio
+ async def test_start_all_services(self):
+ """Test starting all registered services."""
+ registry = self._create_fresh_registry()
+ svc1 = AsyncServiceBase("svc1")
+ svc2 = AsyncServiceBase("svc2")
+ registry.register_service("svc1", svc1)
+ registry.register_service("svc2", svc2)
+
+ result = await registry.start_all_services()
+
+ assert result is True
+ assert svc1.status == ServiceLifecycle.RUNNING
+ assert svc2.status == ServiceLifecycle.RUNNING
+
+ @pytest.mark.asyncio
+ async def test_stop_all_services(self):
+ """Test stopping all registered services."""
+ registry = self._create_fresh_registry()
+ svc1 = AsyncServiceBase("svc1")
+ svc2 = AsyncServiceBase("svc2")
+ registry.register_service("svc1", svc1)
+ registry.register_service("svc2", svc2)
+ await registry.start_all_services()
+
+ result = await registry.stop_all_services()
+
+ assert result is True
+ assert svc1.status == ServiceLifecycle.STOPPED
+ assert svc2.status == ServiceLifecycle.STOPPED
+
+ def test_get_service_status(self):
+ """Test getting status of all registered services."""
+ registry = self._create_fresh_registry()
+ svc = AsyncServiceBase("svc1")
+ registry.register_service("svc1", svc)
+
+ status = registry.get_service_status()
+
+ assert "svc1" in status
+ assert status["svc1"] == "created"
diff --git a/tests/unit/test_security_utils.py b/tests/unit/test_security_utils.py
new file mode 100644
index 0000000..bd4b10e
--- /dev/null
+++ b/tests/unit/test_security_utils.py
@@ -0,0 +1,306 @@
+"""
+Unit tests for security utilities module
+
+Tests the security infrastructure:
+- PasswordHasher: hash generation, salt handling, verification
+- LoginAttemptTracker: attempt recording, lockout, rate limiting
+- SecurityValidator: password strength, input sanitization, token validation
+- Password migration: plaintext to hashed format
+"""
+import time
+import pytest
+from unittest.mock import patch
+
+from utils.security_utils import (
+ PasswordHasher,
+ LoginAttemptTracker,
+ SecurityValidator,
+ migrate_password_to_hashed,
+ verify_password_with_migration,
+)
+
+
+@pytest.mark.unit
+@pytest.mark.security
+class TestPasswordHasher:
+ """Test PasswordHasher functionality."""
+
+ def test_hash_password_returns_tuple(self):
+ """Test hash_password returns (hash, salt) tuple."""
+ hashed, salt = PasswordHasher.hash_password("test_password")
+
+ assert isinstance(hashed, str)
+ assert isinstance(salt, str)
+ assert len(hashed) == 32 # MD5 hex digest length
+ assert len(salt) == 32 # 16 bytes = 32 hex chars
+
+ def test_hash_password_with_custom_salt(self):
+ """Test hash_password with a provided salt."""
+ hashed, salt = PasswordHasher.hash_password("test_password", salt="fixed_salt")
+
+ assert salt == "fixed_salt"
+ assert len(hashed) == 32
+
+ def test_same_password_same_salt_same_hash(self):
+ """Test deterministic hashing with same password and salt."""
+ h1, _ = PasswordHasher.hash_password("password", salt="salt123")
+ h2, _ = PasswordHasher.hash_password("password", salt="salt123")
+
+ assert h1 == h2
+
+ def test_same_password_different_salt_different_hash(self):
+ """Test different salts produce different hashes."""
+ h1, _ = PasswordHasher.hash_password("password", salt="salt_a")
+ h2, _ = PasswordHasher.hash_password("password", salt="salt_b")
+
+ assert h1 != h2
+
+ def test_different_password_same_salt_different_hash(self):
+ """Test different passwords produce different hashes."""
+ h1, _ = PasswordHasher.hash_password("password1", salt="same_salt")
+ h2, _ = PasswordHasher.hash_password("password2", salt="same_salt")
+
+ assert h1 != h2
+
+ def test_verify_correct_password(self):
+ """Test password verification with correct password."""
+ hashed, salt = PasswordHasher.hash_password("correct_password")
+
+ result = PasswordHasher.verify_password("correct_password", hashed, salt)
+ assert result is True
+
+ def test_verify_incorrect_password(self):
+ """Test password verification with incorrect password."""
+ hashed, salt = PasswordHasher.hash_password("correct_password")
+
+ result = PasswordHasher.verify_password("wrong_password", hashed, salt)
+ assert result is False
+
+
+@pytest.mark.unit
+@pytest.mark.security
+class TestLoginAttemptTracker:
+ """Test LoginAttemptTracker functionality."""
+
+ def test_initial_state_not_locked(self):
+ """Test new IP is not locked."""
+ tracker = LoginAttemptTracker(max_attempts=3)
+
+ locked, remaining = tracker.is_locked("192.168.1.1")
+ assert locked is False
+ assert remaining == 0
+
+ def test_full_remaining_attempts_for_new_ip(self):
+ """Test new IP has full remaining attempts."""
+ tracker = LoginAttemptTracker(max_attempts=5)
+
+ remaining = tracker.get_remaining_attempts("192.168.1.1")
+ assert remaining == 5
+
+ def test_failed_attempt_decreases_remaining(self):
+ """Test failed attempt decreases remaining count."""
+ tracker = LoginAttemptTracker(max_attempts=5)
+
+ tracker.record_attempt("192.168.1.1", success=False)
+
+ remaining = tracker.get_remaining_attempts("192.168.1.1")
+ assert remaining == 4
+
+ def test_lockout_after_max_attempts(self):
+ """Test IP is locked after max failed attempts."""
+ tracker = LoginAttemptTracker(max_attempts=3, lockout_duration=300)
+
+ for _ in range(3):
+ tracker.record_attempt("192.168.1.1", success=False)
+
+ locked, remaining_seconds = tracker.is_locked("192.168.1.1")
+ assert locked is True
+ assert remaining_seconds > 0
+
+ def test_successful_login_clears_attempts(self):
+ """Test successful login clears failed attempt history."""
+ tracker = LoginAttemptTracker(max_attempts=5)
+
+ tracker.record_attempt("192.168.1.1", success=False)
+ tracker.record_attempt("192.168.1.1", success=False)
+ tracker.record_attempt("192.168.1.1", success=True)
+
+ remaining = tracker.get_remaining_attempts("192.168.1.1")
+ assert remaining == 5
+
+ def test_different_ips_independent(self):
+ """Test tracking is independent per IP."""
+ tracker = LoginAttemptTracker(max_attempts=3)
+
+ tracker.record_attempt("192.168.1.1", success=False)
+ tracker.record_attempt("192.168.1.1", success=False)
+
+ remaining_ip1 = tracker.get_remaining_attempts("192.168.1.1")
+ remaining_ip2 = tracker.get_remaining_attempts("192.168.1.2")
+
+ assert remaining_ip1 == 1
+ assert remaining_ip2 == 3
+
+ def test_clear_ip_record(self):
+ """Test clearing a specific IP record."""
+ tracker = LoginAttemptTracker(max_attempts=3)
+
+ tracker.record_attempt("192.168.1.1", success=False)
+ tracker.clear_ip_record("192.168.1.1")
+
+ remaining = tracker.get_remaining_attempts("192.168.1.1")
+ assert remaining == 3
+
+ def test_clear_all_records(self):
+ """Test clearing all IP records."""
+ tracker = LoginAttemptTracker(max_attempts=3)
+
+ tracker.record_attempt("192.168.1.1", success=False)
+ tracker.record_attempt("192.168.1.2", success=False)
+ tracker.clear_all_records()
+
+ assert tracker.get_remaining_attempts("192.168.1.1") == 3
+ assert tracker.get_remaining_attempts("192.168.1.2") == 3
+
+
+@pytest.mark.unit
+@pytest.mark.security
+class TestSecurityValidator:
+ """Test SecurityValidator functionality."""
+
+ def test_strong_password(self):
+ """Test strong password validation."""
+ result = SecurityValidator.validate_password_strength("Str0ng!P@ss")
+
+ assert result['valid'] is True
+ assert result['strength'] == 'strong'
+ assert result['checks']['length'] is True
+ assert result['checks']['lowercase'] is True
+ assert result['checks']['uppercase'] is True
+ assert result['checks']['numbers'] is True
+ assert result['checks']['symbols'] is True
+
+ def test_weak_password_short(self):
+ """Test weak password: too short."""
+ result = SecurityValidator.validate_password_strength("abc")
+
+ assert result['valid'] is False
+ assert result['strength'] == 'weak'
+ assert result['checks']['length'] is False
+
+ def test_medium_password(self):
+ """Test medium strength password."""
+ result = SecurityValidator.validate_password_strength("Password1")
+
+ assert result['valid'] is True
+ assert result['strength'] in ('medium', 'strong')
+
+ def test_extra_long_password_bonus(self):
+ """Test extra long password gets bonus score."""
+ result = SecurityValidator.validate_password_strength("ThisIsAVeryLongPassword123!")
+
+ assert result['score'] > 80
+
+ def test_sanitize_input_strips_whitespace(self):
+ """Test input sanitization strips whitespace."""
+ result = SecurityValidator.sanitize_input(" hello world ")
+ assert result == "hello world"
+
+ def test_sanitize_input_truncates(self):
+ """Test input sanitization truncates to max_length."""
+ result = SecurityValidator.sanitize_input("a" * 300, max_length=10)
+ assert len(result) == 10
+
+ def test_sanitize_empty_input(self):
+ """Test sanitizing empty input."""
+ assert SecurityValidator.sanitize_input("") == ""
+ assert SecurityValidator.sanitize_input(None) == ""
+
+ def test_valid_session_token(self):
+ """Test valid session token validation."""
+ token = "a" * 32 # 32-char hex string
+ assert SecurityValidator.is_valid_session_token(token) is True
+
+ def test_invalid_session_token_too_short(self):
+ """Test invalid session token: too short."""
+ assert SecurityValidator.is_valid_session_token("abc123") is False
+
+ def test_invalid_session_token_non_hex(self):
+ """Test invalid session token: non-hex characters."""
+ assert SecurityValidator.is_valid_session_token("z" * 32) is False
+
+ def test_empty_session_token(self):
+ """Test empty session token is invalid."""
+ assert SecurityValidator.is_valid_session_token("") is False
+ assert SecurityValidator.is_valid_session_token(None) is False
+
+
+@pytest.mark.unit
+@pytest.mark.security
+class TestPasswordMigration:
+ """Test password migration from plaintext to hashed format."""
+
+ def test_migrate_plaintext_password(self):
+ """Test migrating plaintext password to hashed format."""
+ old_config = {'password': 'my_password', 'must_change': False}
+
+ new_config = migrate_password_to_hashed(old_config)
+
+ assert 'password_hash' in new_config
+ assert 'salt' in new_config
+ assert new_config['version'] == 2
+ assert new_config['migrated_from_plaintext'] is True
+
+ def test_already_hashed_not_migrated(self):
+ """Test already hashed config is not re-migrated."""
+ hashed_config = {
+ 'password_hash': 'existing_hash',
+ 'salt': 'existing_salt',
+ 'version': 2,
+ }
+
+ result = migrate_password_to_hashed(hashed_config)
+
+ assert result is hashed_config # Same object
+
+ def test_verify_with_old_format(self):
+ """Test verification with old plaintext format triggers migration."""
+ old_config = {'password': 'test_pass'}
+
+ is_valid, new_config = verify_password_with_migration('test_pass', old_config)
+
+ assert is_valid is True
+ assert 'password_hash' in new_config
+
+ def test_verify_with_old_format_wrong_password(self):
+ """Test verification with wrong password in old format."""
+ old_config = {'password': 'correct_pass'}
+
+ is_valid, config = verify_password_with_migration('wrong_pass', old_config)
+
+ assert is_valid is False
+
+ def test_verify_with_new_format(self):
+ """Test verification with new hashed format."""
+ hashed, salt = PasswordHasher.hash_password("secure_pwd")
+ new_config = {'password_hash': hashed, 'salt': salt}
+
+ is_valid, config = verify_password_with_migration('secure_pwd', new_config)
+
+ assert is_valid is True
+
+ def test_verify_with_new_format_wrong_password(self):
+ """Test verification with wrong password in new format."""
+ hashed, salt = PasswordHasher.hash_password("correct_pwd")
+ new_config = {'password_hash': hashed, 'salt': salt}
+
+ is_valid, config = verify_password_with_migration('wrong_pwd', new_config)
+
+ assert is_valid is False
+
+ def test_verify_with_missing_hash_or_salt(self):
+ """Test verification fails gracefully when hash or salt is missing."""
+ config = {'password_hash': '', 'salt': ''}
+
+ is_valid, _ = verify_password_with_migration('any_pwd', config)
+ assert is_valid is False
diff --git a/tests/unit/test_tiered_learning_trigger.py b/tests/unit/test_tiered_learning_trigger.py
new file mode 100644
index 0000000..7fd600c
--- /dev/null
+++ b/tests/unit/test_tiered_learning_trigger.py
@@ -0,0 +1,441 @@
+"""
+Unit tests for TieredLearningTrigger
+
+Tests the two-tier learning trigger mechanism:
+- Tier 1 registration and execution (per-message, concurrent)
+- Tier 2 registration and gated execution (batch, cooldown/threshold)
+- Error isolation between Tier 1 operations
+- BatchTriggerPolicy configuration
+- force_tier2 fast-path triggering
+- Per-group state tracking and statistics
+"""
+import asyncio
+import time
+import pytest
+from unittest.mock import AsyncMock, patch
+
+from core.interfaces import MessageData
+from services.quality.tiered_learning_trigger import (
+ TieredLearningTrigger,
+ BatchTriggerPolicy,
+ TriggerResult,
+ _GroupTriggerState,
+)
+
+
+def _make_message(text: str = "test message", group_id: str = "g1") -> MessageData:
+ """Helper to create a test MessageData instance."""
+ return MessageData(
+ sender_id="user1",
+ sender_name="TestUser",
+ message=text,
+ group_id=group_id,
+ timestamp=time.time(),
+ platform="test",
+ )
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestTieredLearningTriggerRegistration:
+ """Test callback registration for both tiers."""
+
+ def test_register_tier1_success(self):
+ """Test registering a valid Tier 1 async callback."""
+ trigger = TieredLearningTrigger()
+
+ async def callback(msg, gid):
+ pass
+
+ trigger.register_tier1("test_op", callback)
+ assert "test_op" in trigger._tier1_ops
+
+ def test_register_tier1_none_callback_ignored(self):
+ """Test registering None callback is silently ignored."""
+ trigger = TieredLearningTrigger()
+ trigger.register_tier1("test_op", None)
+ assert "test_op" not in trigger._tier1_ops
+
+ def test_register_tier1_sync_callback_raises(self):
+ """Test registering a sync callback raises TypeError."""
+ trigger = TieredLearningTrigger()
+
+ def sync_callback(msg, gid):
+ pass
+
+ with pytest.raises(TypeError, match="must be an async function"):
+ trigger.register_tier1("bad_op", sync_callback)
+
+ def test_register_tier2_success(self):
+ """Test registering a valid Tier 2 async callback."""
+ trigger = TieredLearningTrigger()
+
+ async def callback(gid):
+ pass
+
+ trigger.register_tier2("batch_op", callback)
+ assert "batch_op" in trigger._tier2_ops
+
+ def test_register_tier2_with_custom_policy(self):
+ """Test registering Tier 2 with custom policy."""
+ trigger = TieredLearningTrigger()
+ policy = BatchTriggerPolicy(message_threshold=50, cooldown_seconds=300.0)
+
+ async def callback(gid):
+ pass
+
+ trigger.register_tier2("batch_op", callback, policy=policy)
+
+ stored_callback, stored_policy = trigger._tier2_ops["batch_op"]
+ assert stored_policy.message_threshold == 50
+ assert stored_policy.cooldown_seconds == 300.0
+
+ def test_register_tier2_default_policy(self):
+ """Test registering Tier 2 uses default policy when none specified."""
+ trigger = TieredLearningTrigger()
+
+ async def callback(gid):
+ pass
+
+ trigger.register_tier2("batch_op", callback)
+
+ _, stored_policy = trigger._tier2_ops["batch_op"]
+ assert stored_policy.message_threshold == 15
+ assert stored_policy.cooldown_seconds == 120.0
+
+ def test_register_tier2_sync_callback_raises(self):
+ """Test registering a sync Tier 2 callback raises TypeError."""
+ trigger = TieredLearningTrigger()
+
+ def sync_callback(gid):
+ pass
+
+ with pytest.raises(TypeError, match="must be an async function"):
+ trigger.register_tier2("bad_op", sync_callback)
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestTier1Execution:
+ """Test Tier 1 per-message execution."""
+
+ @pytest.mark.asyncio
+ async def test_tier1_executes_for_every_message(self):
+ """Test Tier 1 callbacks execute for every incoming message."""
+ trigger = TieredLearningTrigger()
+ call_count = 0
+
+ async def tier1_callback(msg, gid):
+ nonlocal call_count
+ call_count += 1
+
+ trigger.register_tier1("counter", tier1_callback)
+
+ for _ in range(5):
+ await trigger.process_message(_make_message(), "g1")
+
+ assert call_count == 5
+
+ @pytest.mark.asyncio
+ async def test_tier1_concurrent_execution(self):
+ """Test multiple Tier 1 callbacks run concurrently."""
+ trigger = TieredLearningTrigger()
+ execution_order = []
+
+ async def op_a(msg, gid):
+ execution_order.append("a")
+
+ async def op_b(msg, gid):
+ execution_order.append("b")
+
+ trigger.register_tier1("op_a", op_a)
+ trigger.register_tier1("op_b", op_b)
+
+ await trigger.process_message(_make_message(), "g1")
+
+ assert "a" in execution_order
+ assert "b" in execution_order
+
+ @pytest.mark.asyncio
+ async def test_tier1_error_isolation(self):
+ """Test one Tier 1 failure does not affect others."""
+ trigger = TieredLearningTrigger()
+ healthy_called = False
+
+ async def failing_op(msg, gid):
+ raise RuntimeError("Tier 1 failure")
+
+ async def healthy_op(msg, gid):
+ nonlocal healthy_called
+ healthy_called = True
+
+ trigger.register_tier1("failing", failing_op)
+ trigger.register_tier1("healthy", healthy_op)
+
+ result = await trigger.process_message(_make_message(), "g1")
+
+ assert healthy_called is True
+ assert result.tier1_details["failing"] is False
+ assert result.tier1_details["healthy"] is True
+ assert result.tier1_ok is False
+
+ @pytest.mark.asyncio
+ async def test_tier1_all_success(self):
+ """Test tier1_ok is True when all operations succeed."""
+ trigger = TieredLearningTrigger()
+
+ async def ok_op(msg, gid):
+ pass
+
+ trigger.register_tier1("op1", ok_op)
+ trigger.register_tier1("op2", ok_op)
+
+ result = await trigger.process_message(_make_message(), "g1")
+
+ assert result.tier1_ok is True
+ assert all(result.tier1_details.values())
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestTier2Execution:
+ """Test Tier 2 batch execution with gating."""
+
+ @pytest.mark.asyncio
+ async def test_tier2_triggers_on_message_threshold(self):
+ """Test Tier 2 fires when message count reaches threshold."""
+ trigger = TieredLearningTrigger()
+ tier2_called = False
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ async def tier2_callback(gid):
+ nonlocal tier2_called
+ tier2_called = True
+
+ trigger.register_tier1("noop", tier1_noop)
+ trigger.register_tier2(
+ "batch", tier2_callback,
+ policy=BatchTriggerPolicy(message_threshold=3, cooldown_seconds=9999),
+ )
+
+ for _ in range(3):
+ result = await trigger.process_message(_make_message(), "g1")
+
+ assert tier2_called is True
+
+ @pytest.mark.asyncio
+ async def test_tier2_does_not_trigger_below_threshold(self):
+ """Test Tier 2 does not fire below message threshold."""
+ trigger = TieredLearningTrigger()
+ tier2_called = False
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ async def tier2_callback(gid):
+ nonlocal tier2_called
+ tier2_called = True
+
+ trigger.register_tier1("noop", tier1_noop)
+ trigger.register_tier2(
+ "batch", tier2_callback,
+ policy=BatchTriggerPolicy(message_threshold=100, cooldown_seconds=9999),
+ )
+
+ # Process only 2 messages
+ for _ in range(2):
+ await trigger.process_message(_make_message(), "g1")
+
+ assert tier2_called is False
+
+ @pytest.mark.asyncio
+ async def test_tier2_triggers_on_cooldown_expiry(self):
+ """Test Tier 2 fires when cooldown expires even below threshold."""
+ trigger = TieredLearningTrigger()
+ tier2_called = False
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ async def tier2_callback(gid):
+ nonlocal tier2_called
+ tier2_called = True
+
+ trigger.register_tier1("noop", tier1_noop)
+ trigger.register_tier2(
+ "batch", tier2_callback,
+ policy=BatchTriggerPolicy(message_threshold=9999, cooldown_seconds=0.0),
+ )
+
+        # With cooldown=0 the op is always ready, but _get_state seeds
+        # last_op_times to "now", so the very first process_message would
+        # not trigger; backdate the timestamp to simulate an expired cooldown.
+ state = trigger._get_state("g1")
+ state.last_op_times["batch"] = 0.0 # Long ago
+
+ result = await trigger.process_message(_make_message(), "g1")
+
+ assert tier2_called is True
+ assert result.tier2_triggered is True
+
+ @pytest.mark.asyncio
+ async def test_tier2_resets_counter_after_trigger(self):
+ """Test message counter resets after Tier 2 trigger."""
+ trigger = TieredLearningTrigger()
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ async def tier2_callback(gid):
+ pass
+
+ trigger.register_tier1("noop", tier1_noop)
+ trigger.register_tier2(
+ "batch", tier2_callback,
+ policy=BatchTriggerPolicy(message_threshold=2, cooldown_seconds=9999),
+ )
+
+ await trigger.process_message(_make_message(), "g1")
+ await trigger.process_message(_make_message(), "g1")
+
+ state = trigger._states["g1"]
+ assert state.message_count == 0 # Reset after trigger
+
+ @pytest.mark.asyncio
+ async def test_tier2_error_handling(self):
+ """Test Tier 2 failure is captured in result."""
+ trigger = TieredLearningTrigger()
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ async def failing_tier2(gid):
+ raise RuntimeError("Batch failure")
+
+ trigger.register_tier1("noop", tier1_noop)
+ trigger.register_tier2(
+ "batch", failing_tier2,
+ policy=BatchTriggerPolicy(message_threshold=1, cooldown_seconds=9999),
+ )
+
+ result = await trigger.process_message(_make_message(), "g1")
+
+ assert result.tier2_triggered is True
+ assert result.tier2_details["batch"] is False
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestForceTier2:
+ """Test force_tier2 fast-path triggering."""
+
+ @pytest.mark.asyncio
+ async def test_force_tier2_success(self):
+ """Test force triggering a registered Tier 2 operation."""
+ trigger = TieredLearningTrigger()
+ called_with_group = None
+
+ async def tier2_callback(gid):
+ nonlocal called_with_group
+ called_with_group = gid
+
+ trigger.register_tier2("batch", tier2_callback)
+
+ result = await trigger.force_tier2("batch", "g1")
+
+ assert result is True
+ assert called_with_group == "g1"
+
+ @pytest.mark.asyncio
+ async def test_force_tier2_unregistered_operation(self):
+ """Test forcing an unregistered operation returns False."""
+ trigger = TieredLearningTrigger()
+
+ result = await trigger.force_tier2("nonexistent", "g1")
+ assert result is False
+
+ @pytest.mark.asyncio
+ async def test_force_tier2_failure(self):
+ """Test force_tier2 returns False on callback failure."""
+ trigger = TieredLearningTrigger()
+
+ async def failing_callback(gid):
+ raise RuntimeError("force failure")
+
+ trigger.register_tier2("batch", failing_callback)
+
+ result = await trigger.force_tier2("batch", "g1")
+ assert result is False
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestGroupStats:
+ """Test per-group statistics."""
+
+ def test_stats_for_unknown_group(self):
+ """Test stats for unknown group returns inactive."""
+ trigger = TieredLearningTrigger()
+
+ stats = trigger.get_group_stats("unknown_group")
+ assert stats == {"active": False}
+
+ @pytest.mark.asyncio
+ async def test_stats_after_processing(self):
+ """Test stats reflect processing state."""
+ trigger = TieredLearningTrigger()
+
+ async def tier1_noop(msg, gid):
+ pass
+
+ trigger.register_tier1("noop", tier1_noop)
+
+ await trigger.process_message(_make_message(), "g1")
+ await trigger.process_message(_make_message(), "g1")
+
+ stats = trigger.get_group_stats("g1")
+ assert stats["active"] is True
+ assert stats["total_processed"] == 2
+ assert stats["message_count"] == 2 # No Tier 2 to reset
+
+ @pytest.mark.asyncio
+ async def test_stats_consecutive_tier1_errors(self):
+ """Test consecutive Tier 1 error tracking."""
+ trigger = TieredLearningTrigger()
+
+ async def failing_op(msg, gid):
+ raise RuntimeError("fail")
+
+ trigger.register_tier1("failing", failing_op)
+
+ await trigger.process_message(_make_message(), "g1")
+ await trigger.process_message(_make_message(), "g1")
+
+ stats = trigger.get_group_stats("g1")
+ assert stats["consecutive_tier1_errors"] == 2
+
+
+@pytest.mark.unit
+@pytest.mark.quality
+class TestBatchTriggerPolicy:
+ """Test BatchTriggerPolicy dataclass."""
+
+ def test_default_values(self):
+ """Test default policy values."""
+ policy = BatchTriggerPolicy()
+ assert policy.message_threshold == 15
+ assert policy.cooldown_seconds == 120.0
+
+ def test_custom_values(self):
+ """Test custom policy values."""
+ policy = BatchTriggerPolicy(message_threshold=50, cooldown_seconds=600.0)
+ assert policy.message_threshold == 50
+ assert policy.cooldown_seconds == 600.0
+
+ def test_policy_is_frozen(self):
+ """Test policy dataclass is immutable."""
+ policy = BatchTriggerPolicy()
+ with pytest.raises(AttributeError):
+ policy.message_threshold = 99
diff --git a/web_res/static/js/macos/apps/Dashboard.js b/web_res/static/js/macos/apps/Dashboard.js
index f6a59e5..0e26727 100644
--- a/web_res/static/js/macos/apps/Dashboard.js
+++ b/web_res/static/js/macos/apps/Dashboard.js
@@ -555,14 +555,15 @@ window.AppDashboard = {
llmModels.length) *
100
: 0;
- // 响应速度
+ // 响应速度(无 LLM 数据时不显示为 0,而是显示为中性值)
var avgResponseTime =
llmModels.length > 0
? llmModels.reduce(function (s, m) {
return s + (m.avg_response_time_ms || 0);
}, 0) / llmModels.length
- : 2000;
- var responseSpeed = Math.max(0, 100 - avgResponseTime / 20);
+ : 0;
+ var responseSpeed =
+ llmModels.length > 0 ? Math.max(0, 100 - avgResponseTime / 20) : 50;
// 系统稳定性
var sm = stats.system_metrics || {};
var systemStability =
diff --git a/web_res/static/js/script.js b/web_res/static/js/script.js
index 99d808b..81c7b25 100644
--- a/web_res/static/js/script.js
+++ b/web_res/static/js/script.js
@@ -938,8 +938,9 @@ function initializeSystemStatusRadar() {
(sum, model) => sum + (model.avg_response_time_ms || 0),
0,
) / llmModels.length
- : 2000;
- const responseSpeed = Math.max(0, 100 - avgResponseTime / 20); // 2000ms = 0分,0ms = 100分
+ : 0;
+ const responseSpeed =
+ llmModels.length > 0 ? Math.max(0, 100 - avgResponseTime / 20) : 50;
// 系统稳定性 (基于CPU和内存使用率)
const systemMetrics = stats.system_metrics || {};
diff --git a/webui/blueprints/persona_reviews.py b/webui/blueprints/persona_reviews.py
index 7d07261..70f4570 100644
--- a/webui/blueprints/persona_reviews.py
+++ b/webui/blueprints/persona_reviews.py
@@ -37,6 +37,8 @@ async def review_persona_update(update_id: str):
"""审查人格更新内容 (批准/拒绝)"""
try:
data = await request.get_json()
+ if not data:
+ return jsonify({"error": "Request body is required"}), 400
action = data.get("action")
comment = data.get("comment", "")
modified_content = data.get("modified_content")
@@ -84,7 +86,7 @@ async def get_reviewed_persona_updates():
async def revert_persona_update(update_id: str):
"""撤回人格更新审查"""
try:
- data = await request.get_json()
+ data = await request.get_json() or {}
reason = data.get("reason", "撤回审查决定")
container = get_container()
@@ -128,6 +130,8 @@ async def batch_delete_persona_updates():
"""批量删除人格更新审查记录"""
try:
data = await request.get_json()
+ if not data:
+ return jsonify({"error": "Request body is required"}), 400
update_ids = data.get('update_ids', [])
if not update_ids or not isinstance(update_ids, list):
@@ -153,6 +157,8 @@ async def batch_review_persona_updates():
"""批量审查人格更新记录"""
try:
data = await request.get_json()
+ if not data:
+ return jsonify({"error": "Request body is required"}), 400
update_ids = data.get('update_ids', [])
action = data.get('action')
comment = data.get('comment', '')
diff --git a/webui/manager.py b/webui/manager.py
index f0bd910..5b78bca 100644
--- a/webui/manager.py
+++ b/webui/manager.py
@@ -1,6 +1,5 @@
"""WebUI 服务器全生命周期管理 — 创建、启动、停止、服务注册"""
import asyncio
-import gc
import sys
from typing import Optional, Any, Dict, TYPE_CHECKING
@@ -181,7 +180,6 @@ async def stop(self) -> None:
try:
logger.info(f"正在停止 Web 服务器 (端口: {_server_instance.port})...")
await _server_instance.stop()
- gc.collect()
if sys.platform == "win32":
logger.info("Windows 环境:等待端口资源释放...")
diff --git a/webui/server.py b/webui/server.py
index cd0b107..ccfc747 100644
--- a/webui/server.py
+++ b/webui/server.py
@@ -4,7 +4,6 @@
"""
import os
import sys
-import gc
import asyncio
import socket
import threading
@@ -192,7 +191,6 @@ async def stop(self):
Server._instance = None
self._initialized = False
- gc.collect()
logger.info("[WebUI] 服务器已停止")
except Exception as e:
diff --git a/webui/services/learning_service.py b/webui/services/learning_service.py
index ffbb697..7a53c48 100644
--- a/webui/services/learning_service.py
+++ b/webui/services/learning_service.py
@@ -40,9 +40,13 @@ async def get_style_learning_results(self) -> Dict[str, Any]:
if self.db_manager:
try:
- # 优先使用ORM Repository获取统计数据
- if hasattr(self.db_manager, 'get_session'):
- # 使用ORM方式获取统计
+ # 优先使用 Facade 方法获取统计数据
+ if hasattr(self.db_manager, 'get_style_learning_statistics'):
+ real_stats = await self.db_manager.get_style_learning_statistics()
+ if real_stats:
+ results_data['statistics'].update(real_stats)
+ elif hasattr(self.db_manager, 'get_session'):
+ # 降级到 Repository 方式
from ...repositories.learning_repository import StyleLearningReviewRepository
async with self.db_manager.get_session() as session:
@@ -52,11 +56,6 @@ async def get_style_learning_results(self) -> Dict[str, Any]:
results_data['statistics'].update(real_stats)
logger.debug(f"使用ORM获取风格学习统计: {real_stats}")
- else:
- # 降级到传统数据库方法
- real_stats = await self.db_manager.get_style_learning_statistics()
- if real_stats:
- results_data['statistics'].update(real_stats)
# 获取进度数据(保持原有逻辑)
real_progress = await self.db_manager.get_style_progress_data()
diff --git a/webui/services/persona_review_service.py b/webui/services/persona_review_service.py
index 8e6dcd7..89e7ee7 100644
--- a/webui/services/persona_review_service.py
+++ b/webui/services/persona_review_service.py
@@ -596,25 +596,53 @@ async def get_reviewed_persona_updates(
# 从传统人格更新审查获取
if self.persona_updater:
- traditional_updates = await self.persona_updater.get_reviewed_persona_updates(limit, offset, status_filter)
- reviewed_updates.extend(traditional_updates)
+ try:
+ traditional_updates = await self.persona_updater.get_reviewed_persona_updates(limit, offset, status_filter)
+ if traditional_updates:
+ reviewed_updates.extend(traditional_updates)
+ except Exception as e:
+ logger.warning(f"获取传统已审查人格更新失败: {e}")
# 从人格学习审查获取
if self.database_manager:
- persona_learning_updates = await self.database_manager.get_reviewed_persona_learning_updates(limit, offset, status_filter)
- reviewed_updates.extend(persona_learning_updates)
+ try:
+ persona_learning_updates = await self.database_manager.get_reviewed_persona_learning_updates(limit, offset, status_filter)
+ if persona_learning_updates:
+ # 为人格学习记录添加前缀 ID
+ for update in persona_learning_updates:
+ if update.get('id') is not None:
+ update['id'] = f"persona_learning_{update['id']}"
+ reviewed_updates.extend(persona_learning_updates)
+ except Exception as e:
+ logger.warning(f"获取已审查人格学习更新失败: {e}")
# 从风格学习审查获取
if self.database_manager:
- style_updates = await self.database_manager.get_reviewed_style_learning_updates(limit, offset, status_filter)
- # 将风格审查转换为统一格式
- for update in style_updates:
- if 'id' in update:
- update['id'] = f"style_{update['id']}"
- reviewed_updates.extend(style_updates)
+ try:
+ style_updates = await self.database_manager.get_reviewed_style_learning_updates(limit, offset, status_filter)
+ if style_updates:
+ for update in style_updates:
+ # 转换风格学习字段为前端统一格式
+ update['id'] = f"style_{update['id']}" if update.get('id') is not None else None
+ update['update_type'] = update.get('type', UPDATE_TYPE_STYLE_LEARNING)
+ update['original_content'] = update.get('original_content', '')
+ update['new_content'] = update.get('few_shots_content', '')
+ update['proposed_content'] = update.get('few_shots_content', '')
+ update['confidence_score'] = update.get('confidence_score', 0.9)
+ update['reason'] = update.get('description', '')
+ update['review_source'] = 'style_learning'
+ reviewed_updates.extend(style_updates)
+ except Exception as e:
+ logger.warning(f"获取已审查风格学习更新失败: {e}")
+
+ # 过滤掉无效记录(id 为 None 或空的条目)
+ reviewed_updates = [
+ u for u in reviewed_updates
+ if u and u.get('id') is not None
+ ]
# 按审查时间排序
- reviewed_updates.sort(key=lambda x: x.get('review_time', 0), reverse=True)
+ reviewed_updates.sort(key=lambda x: x.get('review_time') or 0, reverse=True)
return {
"success": True,