diff --git a/CHANGELOG.md b/CHANGELOG.md index 41df23a..e6bfa37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,111 @@ -# 🧧 新年快乐!Happy Lunar New Year! +# Changelog -> 祝所有用户和社区贡献者马年大吉、万事如意! +所有重要更改都将记录在此文件中。 ---- +## [Next-2.0.0] - 2026-02-21 + +### 🏗️ 架构重构 + +#### 全量 ORM 迁移(消除所有硬编码 SQL) +- 将 7 个服务文件中残留的硬编码 raw SQL 全部迁移至 SQLAlchemy ORM +- `expression_pattern_learner`:`_apply_time_decay`、`_limit_max_expressions`、`get_expression_patterns` 改用 `ExpressionPatternORM` 模型 +- `time_decay_manager`:完全重写,消除 f-string SQL 注入风险,用显式 ORM 模型处理器替代动态表名拼接,移除对不存在表的引用 +- `enhanced_social_relation_manager`:4 个方法改用 `UserSocialProfile`、`UserSocialRelationComponent`、`SocialRelationHistory` 模型 +- `intelligent_responder`:3 个方法改用 `FilteredMessage`、`RawMessage` 模型及 `func.count`/`func.avg` 聚合 +- `multidimensional_analyzer`:2 个 GROUP BY/HAVING 查询改用 ORM `select().group_by().having()` +- `affection_manager`:3 层级联查询改用 `RawMessage`、`FilteredMessage`、`LearningBatch` 模型 +- `dialog_analyzer`:`get_pending_style_reviews` 改用 `StyleLearningReview` 模型 +- `progressive_learning`、`message_facade`、`webui/learning` 蓝图同步迁移 + +#### 遗留数据库层清理(-7600 行) +- 删除 `services/database/database_manager.py`(6035 行硬编码 SQL 单体) +- 删除 `core/database/` 下 5 个遗留后端文件:`backend_interface.py`、`sqlite_backend.py`、`mysql_backend.py`、`postgresql_backend.py`、`factory.py`(共 1530 行) +- DomainRouter 移除 `_legacy_db` 回退、`get_db_connection()`/`get_connection()` shim、`__getattr__` 安全网 +- `core/database/__init__.py` 精简为仅导出 `DatabaseEngine` +- `services/database/__init__.py` 移除 `DatabaseManager` 导出 + +#### 未使用资源清理 +- 删除 `web_res/static/MacOS-Web-UI/` 源码目录(已迁移至 `static/js/macos/` 和 `static/css/macos/`) + +#### 服务层重组 +- 将 `services/` 下 51 个平铺文件重组为 14 个领域子包,提升内聚性和可维护性 +- 每个子包职责明确:`learning/`、`social/`、`jargon/`、`persona/`、`expression/`、`affection/`、`psychological/`、`reinforcement/`、`message/` 等 + +#### 主模块瘦身 +- 将 `main.py` 业务逻辑提取至独立生命周期模块(`initializer`、`event_handler`、`learning_scheduler` 等) +- 代码量从 2518 行精简至 207 行(减少 92%) + +#### 数据库单体拆分 +- 将 
4308 行的 `SQLAlchemyDatabaseManager` 重写为约 800 行的薄路由层(DomainRouter) +- 引入 `BaseFacade` 基类和 11 个领域 Facade,实现关注点分离 +- 所有 62 个消费者方法显式路由到对应 Facade,消除隐式回退 + +#### 领域 Facade 清单 +| Facade | 职责 | 方法数 | +|--------|------|--------| +| `MessageFacade` | 消息存储、查询、统计 | 17 | +| `LearningFacade` | 学习记录、审查、批次、风格学习 | 29 | +| `JargonFacade` | 黑话 CRUD、搜索、统计、全局同步 | 14 | +| `SocialFacade` | 社交关系、用户画像、偏好 | 9 | +| `PersonaFacade` | 人格备份、恢复、更新历史 | 4 | +| `AffectionFacade` | 好感度、Bot 情绪状态 | 6 | +| `PsychologicalFacade` | 情绪画像 | 2 | +| `ExpressionFacade` | 表达模式、风格画像 | 8 | +| `ReinforcementFacade` | 强化学习、人格融合、策略优化 | 6 | +| `MetricsFacade` | 跨域统计聚合 | 3 | +| `AdminFacade` | 数据清理与导出 | 2 | + +#### Repository 层扩展 +- 新增 10 个类型化 Repository 类,总数从 29 增至 39 +- 新增:`RawMessageRepository`、`FilteredMessageRepository`、`BotMessageRepository`、`UserProfileRepository`、`UserPreferencesRepository`、`EmotionProfileRepository`、`StyleProfileRepository`、`BotMoodRepository`、`PersonaBackupRepository`、`KnowledgeGraphRepository` -# Changelog +### 🔧 重构 -所有重要更改都将记录在此文件中。 +#### PluginConfig 迁移 +- 从 `dataclass` 迁移至 pydantic `BaseModel` +- 采用 `ConfigDict(extra="ignore", populate_by_name=True)` 实现健壮验证和未知字段容忍 + +#### 服务缓存优化 +- 新增 `@cached_service` 装饰器,消除冗余服务实例化 +- 替换手工单例模式,减少样板代码 + +#### 数据库连接清理 +- 移除旧版 `DatabaseConnectionPool`,改用 SQLAlchemy 异步引擎内置连接池管理 +- 移除未使用的 `EventBus`、`EventType`、`EventManager` 等事件基础设施 + +### ⚡ 性能优化 + +#### LLM 缓存命中率提升 +- 上下文注入从 `system_prompt` 拼接改为 AstrBot 框架 `extra_user_content_parts` API +- 动态上下文(社交关系、黑话、多样性、V2 学习)作为额外内容块附加在用户消息之后,不再修改系统提示词 +- **system_prompt 保持稳定不变**,最大化 LLM API 前缀缓存(prefix caching)命中率,显著降低 token 消耗和响应延迟 +- 旧版 AstrBot 自动回退到 system_prompt 注入(附带缓存命中率下降警告) + +#### 上下文检索并行化 +- LLM Hook 的 4 个上下文提供者(社交、V2 学习、多样性、黑话)通过 `asyncio.gather` 并行执行 +- Hook 总延迟降低约 60-70%(从串行累加改为取最慢单项) +- 每个提供者独立计时,便于识别性能瓶颈 + +#### 服务实例化缓存 +- 29 个服务方法通过 `@cached_service` 装饰器缓存,避免重复创建服务实例 +- `ServiceFactory` 和 `ComponentFactory` 共享同一缓存字典,跨工厂复用 + +#### 数据处理流水线优化 +- 消息批量写入改为 `asyncio.gather` 并发插入 +- 渐进式学习中消息筛选与人格检索并行执行 +- 
强化学习与风格分析并行执行 +- DomainRouter 显式方法路由消除 `__getattr__` 运行时属性查找开销 + +### 📊 统计 +- **净代码减少**:约 21,700 行(ORM 迁移 + 遗留层删除 + 未使用资源清理) +- **遗留 SQL 层**:6035 + 1530 = 7565 行硬编码 SQL 代码删除 +- **ORM 迁移**:7 个服务文件、约 800 行 raw SQL 替换为类型安全的 ORM 查询 +- **安全修复**:`time_decay_manager` f-string SQL 注入漏洞已消除 +- **新增文件**:11 个 Facade + 10 个 Repository + 1 个 BaseFacade = 22 个文件 +- **`SQLAlchemyDatabaseManager`**:4308 行 → ~777 行(减少 82%),零遗留回退 +- **变更文件**:51+ 个服务文件重组、`main.py` 重构、数据库层完全重写 + +--- ## [Next-1.2.9] - 2026-02-19 diff --git a/README.md b/README.md index fbe3a38..5464e80 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-Next--1.2.8-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-Next--2.0.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [核心功能](#-我们能做什么) · [快速开始](#-快速开始) · [管理界面](#-可视化管理界面) · [社区交流](#-社区交流) · [贡献指南](CONTRIBUTING.md) diff --git a/README_EN.md b/README_EN.md index 8055634..250354c 100644 --- a/README_EN.md +++ b/README_EN.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-Next--1.2.8-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-Next--2.0.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-GPLv3-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [Features](#what-we-can-do) · [Quick Start](#quick-start) · [Web UI](#visual-management-interface) · [Community](#community) · [Contributing](CONTRIBUTING.md) diff --git a/VIDEO_SCRIPT.md b/VIDEO_SCRIPT.md new file mode 100644 index 0000000..b4f30b3 --- /dev/null +++ b/VIDEO_SCRIPT.md @@ -0,0 +1,131 @@ +# Self-Learning 插件 Bilibili 视频脚本 + +> 预计时长:约 5 分钟 | 定位:功能讲解 + 使用教程 | 语气:口语化、轻松 + +--- + +## 一、开场 Hook(约 30 秒) + +各位好,先问大家一个问题—— + +你有没有觉得,现在的 AI 聊天机器人,虽然什么都能答,但说话总是……一股"AI味儿"? + +你跟它说"6",它回你"您的意思是数字六吗?"。你在群里发个梗,它给你来一段百科解释。 + +说白了,它不懂你们群的"语言",也不知道你们平时怎么聊天。 + +那有没有办法,让 AI 自己去"偷师",学会像真人一样说话? + +今天给大家介绍的这个插件,就是专门干这件事的。 + +--- + +## 二、一句话介绍(约 20 秒) + +这个插件叫 **Self-Learning**,是给 AstrBot 用的一个自主学习插件。 + +它能做的事情用一句话概括就是——**让你的 Bot 潜伏在群里,自动学习大家的说话方式,然后越聊越像真人。** + +不需要你手动喂数据,不需要你写什么 prompt 模板,它全自动完成。 + +--- + +## 三、核心功能讲解(约 3 分钟) + +接下来我挨个说说它到底能干什么。 + +### 1. 自动学群友说话(约 40 秒) + +首先,最核心的能力——**表达模式学习**。 + +插件开启之后,它会在后台默默收集群里的聊天消息。然后定时触发一次学习,用大模型分析这些对话:**在什么场景下,大家会用什么样的表达方式。** + +比如它可能会学到:表示赞同的时候,群友喜欢说"确实"而不是"我同意你的看法"。 + +这些学到的表达模式会被自动注入到 Bot 的回复里,这样 Bot 说话就不会那么正式、那么像机器了。 + +而且它有一个时间衰减机制,过时的表达会自动降权,新学到的会优先使用。所以 Bot 的说话风格会跟着群聊氛围一起"进化"。 + +### 2. 
听得懂黑话(约 30 秒) + +第二个功能——**黑话挖掘**。 + +每个群都有自己的"黑话"对吧?比如某个群里"发财了"可能是表示"太好了"的意思,"下次一定"其实是在拒绝。 + +这些东西你不教,AI 是真不懂。 + +这个插件会自动检测群里的高频特殊用语,然后调用大模型根据上下文推断它的真实含义,保存下来。之后 Bot 在回复消息的时候,就能正确理解这些黑话了,不会再闹笑话。 + +### 3. 社交关系网络(约 30 秒) + +第三个——**社交关系分析**。 + +插件不光学说话,它还会"看人"。它会自动记录群里谁跟谁聊得多、谁 at 了谁、谁经常回复谁,把这些互动关系整理成一张社交网络图。 + +这有什么用呢?Bot 回复的时候,它知道这个群里谁跟谁关系好、谁是活跃分子、谁是边缘人。这样它聊天的时候就能更"懂事",不会在两个关系很好的人面前说不合时宜的话。 + +在管理后台里你还能看到一张可视化的关系图谱,节点越大说明这个人越活跃,连线越粗说明两个人互动越频繁,挺有意思的。 + +### 4. 好感度系统(约 30 秒) + +第四个——**好感度和情绪系统**。 + +插件会记录每个人跟 Bot 的互动。经常夸它、跟它友好聊天,好感度就会涨;反过来骂它,好感度就掉。 + +好感度会影响 Bot 的回复态度。对喜欢的人说话更热情、更主动,对不喜欢的人就冷淡一些。 + +Bot 自己还有一套情绪系统,每天会自动切换心情。开心的时候活泼一点,低落的时候话少一点。这样聊起来就更有"人味儿"了。 + +### 5. 人格审查——你说了算(约 30 秒) + +可能有人会担心:Bot 自己学习,万一学歪了怎么办? + +放心,插件有一个**人格审查机制**。Bot 学完之后生成的人格更新建议,不会直接应用,而是先提交到审查队列里。 + +你可以在管理后台看到它打算怎么改、改了哪些内容,觉得没问题就批准,觉得不对就驳回。**最终决定权始终在你手里。** + +### 6. 
可视化管理后台(约 30 秒) + +最后说一下**WebUI 管理界面**。 + +插件带了一个完整的网页后台,默认端口 7833,浏览器直接就能访问。 + +里面能看到消息收集的数据统计、学习进度、社交关系网络图、好感度排行榜,还有刚才说的人格审查和风格学习详情。 + +基本上 Bot 在做什么、学到了什么、效果怎么样,一目了然。不需要去翻日志,也不用敲命令,全部可视化搞定。 + +--- + +## 四、安装使用(约 30 秒) + +说了这么多,怎么用呢?非常简单。 + +第一步,确保你已经在跑 AstrBot,版本不低于 4.11.4。 + +第二步,在 AstrBot 的插件商店里搜索 **self-learning**,一键安装。 + +第三步,到 AstrBot 后台的插件配置里,把"启用消息抓取"和"启用自动学习"打开。 + +然后你就不用管了。它会自动开始收集消息、自动学习。过几个小时你再去看,Bot 说话就已经开始变了。 + +如果你想看详细数据,浏览器打开 `你的服务器地址:7833`,登录管理后台就行。默认密码在配置里,记得第一次登录之后改掉。 + +--- + +## 收尾(约 20 秒) + +总结一下,这个插件做的事情就是:**让你的 AI Bot 从一个"什么都会但不会说人话"的机器,变成一个能融入群聊、听得懂黑话、记得住谁对它好的"拟人化 Bot"。** + +感兴趣的话,GitHub 搜 **astrbot_plugin_self_learning** 就能找到,觉得有用欢迎给个 Star。 + +遇到问题或者想交流,可以加 QQ 群 **1021544792**。 + +好,这期就到这里,我们下期见。 diff --git a/_conf_schema.json b/_conf_schema.json index f140287..1e8089b 100644 --- a/_conf_schema.json +++ b/_conf_schema.json @@ -552,5 +552,42 @@ "default": 7 } } + }, + "V2_Architecture_Settings": { + "description": "v2架构升级配置", + "type": "object", + "hint": "高级功能配置:Embedding向量化、Reranker重排序、知识引擎和记忆引擎。需要先在AstrBot中配置对应类型的Provider", + "items": { + "embedding_provider_id": { + "description": "Embedding 提供商 ID", + "type": "string", + "hint": "填写Embedding提供商的完整ID(如 'openai/text-embedding-3-large')。需要先在AstrBot的Provider管理中创建Embedding类型的提供商,然后将其ID填写到此处。格式通常为 '来源名/模型名'", + "default": "" + }, + "rerank_provider_id": { + "description": "Reranker 提供商 ID", + "type": "string", + "hint": "填写Reranker提供商的完整ID(如 'openai/qwen3-rerank')。需要先在AstrBot的Provider管理中创建Reranker类型的提供商,然后将其ID填写到此处。格式通常为 '来源名/模型名'", + "default": "" + }, + "rerank_top_k": { + "description": "重排序保留结果数", + "type": "int", + "hint": "Reranker重排序后保留的Top-K结果数量,值越小越精准但可能遗漏", + "default": 5 + }, + "knowledge_engine": { + "description": "知识引擎", + "type": "string", + "hint": "知识存储引擎类型。legacy=现有NetworkX实现,lightrag=使用LightRAG进行向量+图谱混合检索(需配置Embedding提供商)", + "default": "legacy" + }, + "memory_engine": { + "description": "记忆引擎", + "type": "string", + "hint": 
"记忆管理引擎类型。legacy=现有实现,mem0=使用mem0进行自动记忆提取和检索(需配置Embedding提供商)", + "default": "legacy" + } + } } } \ No newline at end of file diff --git a/config.py b/config.py index a5b9f22..22d9e77 100644 --- a/config.py +++ b/config.py @@ -4,219 +4,226 @@ import os import json from typing import List, Optional -from dataclasses import dataclass, field, asdict + +from pydantic import BaseModel, Field, ConfigDict from astrbot.api import logger -@dataclass -class PluginConfig: +class PluginConfig(BaseModel): """插件配置类""" - + + model_config = ConfigDict(extra="ignore", populate_by_name=True) + # 基础开关 enable_message_capture: bool = True - enable_auto_learning: bool = True + enable_auto_learning: bool = True enable_realtime_learning: bool = False - enable_realtime_llm_filter: bool = False # 新增:控制实时LLM筛选 + enable_realtime_llm_filter: bool = False # 新增:控制实时LLM筛选 enable_web_interface: bool = True web_interface_port: int = 7833 # 新增 Web 界面端口配置 - + # MaiBot增强功能(默认启用) - enable_maibot_features: bool = True # 启用MaiBot增强功能 - enable_expression_patterns: bool = True # 启用表达模式学习 - enable_memory_graph: bool = True # 启用记忆图系统 - enable_knowledge_graph: bool = True # 启用知识图谱 - enable_time_decay: bool = True # 启用时间衰减机制 - + enable_maibot_features: bool = True # 启用MaiBot增强功能 + enable_expression_patterns: bool = True # 启用表达模式学习 + enable_memory_graph: bool = True # 启用记忆图系统 + enable_knowledge_graph: bool = True # 启用知识图谱 + enable_time_decay: bool = True # 启用时间衰减机制 + # QQ号设置 - target_qq_list: List[str] = field(default_factory=list) - target_blacklist: List[str] = field(default_factory=list) # 学习黑名单 - + target_qq_list: List[str] = Field(default_factory=list) + target_blacklist: List[str] = Field(default_factory=list) # 学习黑名单 + # LLM 提供商 ID(使用 AstrBot 框架的 Provider 系统) - filter_provider_id: Optional[str] = None # 筛选模型使用的提供商ID - refine_provider_id: Optional[str] = None # 提炼模型使用的提供商ID + filter_provider_id: Optional[str] = None # 筛选模型使用的提供商ID + refine_provider_id: Optional[str] = None # 提炼模型使用的提供商ID 
reinforce_provider_id: Optional[str] = None # 强化模型使用的提供商ID - + + # v2 Architecture: Embedding provider (framework-managed) + embedding_provider_id: Optional[str] = None + + # v2 Architecture: Reranker provider (framework-managed) + rerank_provider_id: Optional[str] = None + rerank_top_k: int = 5 + + # v2 Architecture: Knowledge engine + knowledge_engine: str = "legacy" # "lightrag" | "legacy" + + # v2 Architecture: Memory engine + memory_engine: str = "legacy" # "mem0" | "legacy" + # 当前人格设置 current_persona_name: str = "default" - + # 学习参数 - learning_interval_hours: int = 6 # 学习间隔(小时) - min_messages_for_learning: int = 50 # 最少消息数量才开始学习 - max_messages_per_batch: int = 200 # 每批处理的最大消息数量 - + learning_interval_hours: int = 6 # 学习间隔(小时) + min_messages_for_learning: int = 50 # 最少消息数量才开始学习 + max_messages_per_batch: int = 200 # 每批处理的最大消息数量 + # 筛选参数 - message_min_length: int = 5 # 消息最小长度 - message_max_length: int = 500 # 消息最大长度 - confidence_threshold: float = 0.7 # 筛选置信度阈值 - relevance_threshold: float = 0.6 # 相关性阈值 - + message_min_length: int = 5 # 消息最小长度 + message_max_length: int = 500 # 消息最大长度 + confidence_threshold: float = 0.7 # 筛选置信度阈值 + relevance_threshold: float = 0.6 # 相关性阈值 + # 风格分析参数 - style_analysis_batch_size: int = 100 # 风格分析批次大小 - style_update_threshold: float = 0.6 # 风格更新阈值 (降低阈值,从0.8改为0.6) - + style_analysis_batch_size: int = 100 # 风格分析批次大小 + style_update_threshold: float = 0.6 # 风格更新阈值 (降低阈值,从0.8改为0.6) + # 消息统计 - total_messages_collected: int = 0 # 收集到的消息总数 - + total_messages_collected: int = 0 # 收集到的消息总数 + # 机器学习设置 - enable_ml_analysis: bool = True # 启用ML分析 - max_ml_sample_size: int = 100 # ML样本最大数量 - ml_cache_timeout_hours: int = 1 # ML缓存超时 - + enable_ml_analysis: bool = True # 启用ML分析 + max_ml_sample_size: int = 100 # ML样本最大数量 + ml_cache_timeout_hours: int = 1 # ML缓存超时 + # 人格备份设置 - auto_backup_enabled: bool = True # 启用自动备份 - backup_interval_hours: int = 24 # 备份间隔 - max_backups_per_group: int = 10 # 每群最大备份数 - auto_apply_approved_persona: bool = False # 
审查批准后自动应用到默认人格(危险功能,默认关闭) - + auto_backup_enabled: bool = True # 启用自动备份 + backup_interval_hours: int = 24 # 备份间隔 + max_backups_per_group: int = 10 # 每群最大备份数 + auto_apply_approved_persona: bool = False # 审查批准后自动应用到默认人格(危险功能,默认关闭) + # 高级设置 - debug_mode: bool = False # 调试模式 - save_raw_messages: bool = True # 保存原始消息 - auto_backup_interval_days: int = 7 # 自动备份间隔 - + debug_mode: bool = False # 调试模式 + save_raw_messages: bool = True # 保存原始消息 + auto_backup_interval_days: int = 7 # 自动备份间隔 + # PersonaUpdater配置 - persona_merge_strategy: str = "smart" # 人格合并策略: "replace", "append", "prepend", "smart" - max_mood_imitation_dialogs: int = 20 # 最大对话风格模仿数量 - enable_persona_evolution: bool = True # 启用人格演化跟踪 - persona_compatibility_threshold: float = 0.6 # 人格兼容性阈值 - + persona_merge_strategy: str = "smart" # 人格合并策略: "replace", "append", "prepend", "smart" + max_mood_imitation_dialogs: int = 20 # 最大对话风格模仿数量 + enable_persona_evolution: bool = True # 启用人格演化跟踪 + persona_compatibility_threshold: float = 0.6 # 人格兼容性阈值 + # 人格更新方式配置 - use_persona_manager_updates: bool = True # 使用PersonaManager进行增量更新(False=使用文件临时存储,True=使用PersonaManager) - auto_apply_persona_updates: bool = True # 自动应用人格更新(仅在use_persona_manager_updates=True时生效) - persona_update_backup_enabled: bool = True # 启用更新前备份 - + use_persona_manager_updates: bool = True # 使用PersonaManager进行增量更新(False=使用文件临时存储,True=使用PersonaManager) + auto_apply_persona_updates: bool = True # 自动应用人格更新(仅在use_persona_manager_updates=True时生效) + persona_update_backup_enabled: bool = True # 启用更新前备份 + # 好感度系统配置 - enable_affection_system: bool = True # 启用好感度系统 - max_total_affection: int = 250 # bot总好感度满分值 - max_user_affection: int = 100 # 单个用户最大好感度 - affection_decay_rate: float = 0.95 # 好感度衰减比例 - daily_mood_change: bool = True # 启用每日情绪变化 - mood_affect_affection: bool = True # 情绪影响好感度变化 - + enable_affection_system: bool = True # 启用好感度系统 + max_total_affection: int = 250 # bot总好感度满分值 + max_user_affection: int = 100 # 单个用户最大好感度 + affection_decay_rate: float = 0.95 # 
好感度衰减比例 + daily_mood_change: bool = True # 启用每日情绪变化 + mood_affect_affection: bool = True # 情绪影响好感度变化 + # 情绪系统配置 - enable_daily_mood: bool = True # 启用每日情绪 + enable_daily_mood: bool = True # 启用每日情绪 enable_startup_random_mood: bool = True # 启用启动时随机情绪初始化 - mood_change_hour: int = 6 # 情绪更新时间(24小时制) - mood_persistence_hours: int = 24 # 情绪持续时间 - + mood_change_hour: int = 6 # 情绪更新时间(24小时制) + mood_persistence_hours: int = 24 # 情绪持续时间 + # 存储路径(内部配置,用户通常不需要修改) messages_db_path: Optional[str] = None learning_log_path: Optional[str] = None - + # 用户可配置的存储路径(放在最后,用户可以自定义) - data_dir: str = "./data/self_learning_data" # 插件数据存储目录 + data_dir: str = "./data/self_learning_data" # 插件数据存储目录 # API设置 - api_key: str = "" # 外部API访问密钥 - enable_api_auth: bool = False # 是否启用API密钥认证 + api_key: str = "" # 外部API访问密钥 + enable_api_auth: bool = False # 是否启用API密钥认证 # 数据库设置 - db_type: str = "sqlite" # 数据库类型: sqlite、mysql 或 postgresql + db_type: str = "sqlite" # 数据库类型: sqlite、mysql 或 postgresql # MySQL 配置 - mysql_host: str = "localhost" # MySQL主机地址 - mysql_port: int = 3306 # MySQL端口 - mysql_user: str = "root" # MySQL用户名 - mysql_password: str = "" # MySQL密码 - mysql_database: str = "astrbot_self_learning" # MySQL数据库名 + mysql_host: str = "localhost" # MySQL主机地址 + mysql_port: int = 3306 # MySQL端口 + mysql_user: str = "root" # MySQL用户名 + mysql_password: str = "" # MySQL密码 + mysql_database: str = "astrbot_self_learning" # MySQL数据库名 # PostgreSQL 配置 - postgresql_host: str = "localhost" # PostgreSQL主机地址 - postgresql_port: int = 5432 # PostgreSQL端口 - postgresql_user: str = "postgres" # PostgreSQL用户名 - postgresql_password: str = "" # PostgreSQL密码 - postgresql_database: str = "astrbot_self_learning" # PostgreSQL数据库名 - postgresql_schema: str = "public" # PostgreSQL Schema + postgresql_host: str = "localhost" # PostgreSQL主机地址 + postgresql_port: int = 5432 # PostgreSQL端口 + postgresql_user: str = "postgres" # PostgreSQL用户名 + postgresql_password: str = "" # PostgreSQL密码 + postgresql_database: str = 
"astrbot_self_learning" # PostgreSQL数据库名 + postgresql_schema: str = "public" # PostgreSQL Schema # 连接池配置 - max_connections: int = 10 # 数据库连接池最大连接数 - min_connections: int = 2 # 数据库连接池最小连接数 + max_connections: int = 10 # 数据库连接池最大连接数 + min_connections: int = 2 # 数据库连接池最小连接数 # 社交关系注入设置(与_conf_schema.json一致) - enable_social_context_injection: bool = True # 启用社交关系上下文注入到prompt - include_social_relations: bool = True # 注入用户社交关系网络信息 - include_affection_info: bool = True # 注入好感度信息 - include_mood_info: bool = True # 注入Bot情绪信息 - context_injection_position: str = "start" # 上下文注入位置: "start" 或 "end" + enable_social_context_injection: bool = True # 启用社交关系上下文注入到prompt + include_social_relations: bool = True # 注入用户社交关系网络信息 + include_affection_info: bool = True # 注入好感度信息 + include_mood_info: bool = True # 注入Bot情绪信息 + context_injection_position: str = "start" # 上下文注入位置: "start" 或 "end" # LLM Hook 注入位置设置(v1.1.1新增) # 控制注入内容添加到 req.system_prompt 还是 req.prompt # - "system_prompt": 注入到系统提示(推荐,不会被保存到对话历史) # - "prompt": 注入到用户消息(旧版行为,会导致对话历史膨胀) - llm_hook_injection_target: str = "system_prompt" # 可选值: "system_prompt" 或 "prompt" + llm_hook_injection_target: str = "system_prompt" # 可选值: "system_prompt" 或 "prompt" # 目标驱动对话配置 - enable_goal_driven_chat: bool = False # 启用目标驱动对话 - goal_session_timeout_hours: int = 24 # 会话超时时间(小时) - goal_auto_detect: bool = True # 自动检测对话目标 - goal_max_conversation_history: int = 40 # 最大对话历史(轮次*2) + enable_goal_driven_chat: bool = False # 启用目标驱动对话 + goal_session_timeout_hours: int = 24 # 会话超时时间(小时) + goal_auto_detect: bool = True # 自动检测对话目标 + goal_max_conversation_history: int = 40 # 最大对话历史(轮次*2) # 重构功能配置(新增) - # ⚠️ 强制使用 SQLAlchemy ORM:统一 SQLite 和 MySQL 的表结构定义 - use_sqlalchemy: bool = True # ✨ 硬编码为 True,确保所有数据库操作使用 ORM 模型 - use_enhanced_managers: bool = False # 使用增强型管理器(False=使用原始实现) - enable_memory_cleanup: bool = True # 启用记忆自动清理(每天凌晨3点) - memory_cleanup_days: int = 30 # 记忆保留天数(低于阈值的旧记忆会被清理) - memory_importance_threshold: float = 0.3 # 记忆重要性阈值(低于此值的会被清理) + # 强制使用 
SQLAlchemy ORM:统一 SQLite 和 MySQL 的表结构定义 + use_sqlalchemy: bool = True # 硬编码为 True,确保所有数据库操作使用 ORM 模型 + enable_memory_cleanup: bool = True # 启用记忆自动清理(每天凌晨3点) + memory_cleanup_days: int = 30 # 记忆保留天数(低于阈值的旧记忆会被清理) + memory_importance_threshold: float = 0.3 # 记忆重要性阈值(低于此值的会被清理) # Repository数据访问层配置(新增) - default_review_limit: int = 50 # 默认审查记录查询数量 - default_pattern_limit: int = 10 # 默认表达模式查询数量 - default_memory_limit: int = 50 # 默认记忆查询数量 - default_affection_limit: int = 50 # 默认好感度记录查询数量 - default_social_limit: int = 50 # 默认社交记录查询数量 - default_psychological_limit: int = 20 # 默认心理状态记录查询数量 - max_interaction_batch_size: int = 100 # 最大交互批处理数量 - top_patterns_limit: int = 10 # 顶级模式查询数量 - recent_interactions_limit: int = 20 # 近期交互查询数量 - trend_analysis_days: int = 7 # 趋势分析天数 - - - def __post_init__(self): - """初始化后处理""" - # 这些路径的默认值和目录创建应在外部(如主插件类)处理 - pass + default_review_limit: int = 50 # 默认审查记录查询数量 + default_pattern_limit: int = 10 # 默认表达模式查询数量 + default_memory_limit: int = 50 # 默认记忆查询数量 + default_affection_limit: int = 50 # 默认好感度记录查询数量 + default_social_limit: int = 50 # 默认社交记录查询数量 + default_psychological_limit: int = 20 # 默认心理状态记录查询数量 + max_interaction_batch_size: int = 100 # 最大交互批处理数量 + top_patterns_limit: int = 10 # 顶级模式查询数量 + recent_interactions_limit: int = 20 # 近期交互查询数量 + trend_analysis_days: int = 7 # 趋势分析天数 @classmethod def create_from_config(cls, config: dict, data_dir: Optional[str] = None) -> 'PluginConfig': """从AstrBot配置创建插件配置""" - + # 确保 data_dir 不为空 if not data_dir: data_dir = "./data/self_learning_data" logger.warning(f"data_dir 为空,使用默认值: {data_dir}") - + # 从配置中提取各个配置组 # 根据 _conf_schema.json 的结构,配置项是直接在顶层,而不是嵌套在 'self_learning_settings' 下 basic_settings = config.get('Self_Learning_Basic', {}) target_settings = config.get('Target_Settings', {}) - model_config = config.get('Model_Configuration', {}) + model_configuration = config.get('Model_Configuration', {}) - # ✅ 添加调试日志:显示原始配置数据 - logger.info(f"🔍 [配置加载] Model_Configuration原始数据: {model_config}") - 
logger.info(f"🔍 [配置加载] filter_provider_id: {model_config.get('filter_provider_id', 'NOT_FOUND')}") - logger.info(f"🔍 [配置加载] refine_provider_id: {model_config.get('refine_provider_id', 'NOT_FOUND')}") - logger.info(f"🔍 [配置加载] reinforce_provider_id: {model_config.get('reinforce_provider_id', 'NOT_FOUND')}") + # 添加调试日志:显示原始配置数据 + logger.info(f" [配置加载] Model_Configuration原始数据: {model_configuration}") + logger.info(f" [配置加载] filter_provider_id: {model_configuration.get('filter_provider_id', 'NOT_FOUND')}") + logger.info(f" [配置加载] refine_provider_id: {model_configuration.get('refine_provider_id', 'NOT_FOUND')}") + logger.info(f" [配置加载] reinforce_provider_id: {model_configuration.get('reinforce_provider_id', 'NOT_FOUND')}") learning_params = config.get('Learning_Parameters', {}) filter_params = config.get('Filter_Parameters', {}) style_analysis = config.get('Style_Analysis', {}) advanced_settings = config.get('Advanced_Settings', {}) ml_settings = config.get('Machine_Learning_Settings', {}) - # 删除智能回复设置的获取 - # intelligent_reply_settings = config.get('Intelligent_Reply_Settings', {}) persona_backup_settings = config.get('Persona_Backup_Settings', {}) affection_settings = config.get('Affection_System_Settings', {}) mood_settings = config.get('Mood_System_Settings', {}) storage_settings = config.get('Storage_Settings', {}) api_settings = config.get('API_Settings', {}) - database_settings = config.get('Database_Settings', {}) # 新增:数据库设置 - social_context_settings = config.get('Social_Context_Settings', {}) # 新增:社交上下文设置 - repository_settings = config.get('Repository_Settings', {}) # 新增:Repository配置 - goal_driven_chat_settings = config.get('Goal_Driven_Chat_Settings', {}) # 新增:目标驱动对话设置 + database_settings = config.get('Database_Settings', {}) # 新增:数据库设置 + social_context_settings = config.get('Social_Context_Settings', {}) # 新增:社交上下文设置 + repository_settings = config.get('Repository_Settings', {}) # 新增:Repository配置 + goal_driven_chat_settings = 
config.get('Goal_Driven_Chat_Settings', {}) # 新增:目标驱动对话设置 + v2_settings = config.get('V2_Architecture_Settings', {}) # v2架构升级设置 - # ✅ 添加调试日志:显示目标驱动对话配置数据 - logger.info(f"🔍 [配置加载] Goal_Driven_Chat_Settings原始数据: {goal_driven_chat_settings}") - logger.info(f"🔍 [配置加载] enable_goal_driven_chat: {goal_driven_chat_settings.get('enable_goal_driven_chat', 'NOT_FOUND')}") + # 添加调试日志:显示目标驱动对话配置数据 + logger.info(f" [配置加载] Goal_Driven_Chat_Settings原始数据: {goal_driven_chat_settings}") + logger.info(f" [配置加载] enable_goal_driven_chat: {goal_driven_chat_settings.get('enable_goal_driven_chat', 'NOT_FOUND')}") return cls( enable_message_capture=basic_settings.get('enable_message_capture', True), @@ -224,45 +231,50 @@ def create_from_config(cls, config: dict, data_dir: Optional[str] = None) -> 'Pl enable_realtime_learning=basic_settings.get('enable_realtime_learning', False), enable_web_interface=basic_settings.get('enable_web_interface', True), web_interface_port=basic_settings.get('web_interface_port', 7833), # Web 界面端口配置 - + target_qq_list=target_settings.get('target_qq_list', []), target_blacklist=target_settings.get('target_blacklist', []), current_persona_name=target_settings.get('current_persona_name', 'default'), - - filter_provider_id=model_config.get('filter_provider_id', None), - refine_provider_id=model_config.get('refine_provider_id', None), - reinforce_provider_id=model_config.get('reinforce_provider_id', None), - + + filter_provider_id=model_configuration.get('filter_provider_id', None), + refine_provider_id=model_configuration.get('refine_provider_id', None), + reinforce_provider_id=model_configuration.get('reinforce_provider_id', None), + + # v2 Architecture + embedding_provider_id=v2_settings.get('embedding_provider_id', None), + rerank_provider_id=v2_settings.get('rerank_provider_id', None), + rerank_top_k=v2_settings.get('rerank_top_k', 5), + knowledge_engine=v2_settings.get('knowledge_engine', 'legacy'), + memory_engine=v2_settings.get('memory_engine', 'legacy'), + 
learning_interval_hours=learning_params.get('learning_interval_hours', 6), min_messages_for_learning=learning_params.get('min_messages_for_learning', 50), max_messages_per_batch=learning_params.get('max_messages_per_batch', 200), - + message_min_length=filter_params.get('message_min_length', 5), message_max_length=filter_params.get('message_max_length', 500), confidence_threshold=filter_params.get('confidence_threshold', 0.7), relevance_threshold=filter_params.get('relevance_threshold', 0.6), - + style_analysis_batch_size=style_analysis.get('style_analysis_batch_size', 100), style_update_threshold=style_analysis.get('style_update_threshold', 0.8), - + # 消息统计 (这个字段通常不是从外部配置加载,而是内部维护的,这里保留默认值) - total_messages_collected=0, - + total_messages_collected=0, + enable_ml_analysis=ml_settings.get('enable_ml_analysis', True), max_ml_sample_size=ml_settings.get('max_ml_sample_size', 100), ml_cache_timeout_hours=ml_settings.get('ml_cache_timeout_hours', 1), - - # 删除了智能回复相关配置 - + auto_backup_enabled=persona_backup_settings.get('auto_backup_enabled', True), backup_interval_hours=persona_backup_settings.get('backup_interval_hours', 24), max_backups_per_group=persona_backup_settings.get('max_backups_per_group', 10), auto_apply_approved_persona=advanced_settings.get('auto_apply_approved_persona', False), - + debug_mode=advanced_settings.get('debug_mode', False), save_raw_messages=advanced_settings.get('save_raw_messages', True), auto_backup_interval_days=advanced_settings.get('auto_backup_interval_days', 7), - + # 好感度系统配置 enable_affection_system=affection_settings.get('enable_affection_system', True), max_total_affection=affection_settings.get('max_total_affection', 250), @@ -270,13 +282,13 @@ def create_from_config(cls, config: dict, data_dir: Optional[str] = None) -> 'Pl affection_decay_rate=affection_settings.get('affection_decay_rate', 0.95), daily_mood_change=affection_settings.get('daily_mood_change', True), 
mood_affect_affection=affection_settings.get('mood_affect_affection', True), - + # 情绪系统配置 enable_daily_mood=mood_settings.get('enable_daily_mood', True), enable_startup_random_mood=mood_settings.get('enable_startup_random_mood', True), mood_change_hour=mood_settings.get('mood_change_hour', 6), mood_persistence_hours=mood_settings.get('mood_persistence_hours', 24), - + # PersonaUpdater配置 (这些可能不是直接从 _conf_schema.json 的顶层获取,而是从其他地方或默认值) persona_merge_strategy=config.get('persona_merge_strategy', 'smart'), max_mood_imitation_dialogs=config.get('max_mood_imitation_dialogs', 20), @@ -304,9 +316,8 @@ def create_from_config(cls, config: dict, data_dir: Optional[str] = None) -> 'Pl min_connections=database_settings.get('min_connections', 2), # 重构功能配置 - # ⚠️ 强制使用 SQLAlchemy ORM,忽略配置文件中的设置 - use_sqlalchemy=True, # 硬编码为 True - use_enhanced_managers=advanced_settings.get('use_enhanced_managers', False), + # 强制使用 SQLAlchemy ORM,忽略配置文件中的设置 + use_sqlalchemy=True, # 硬编码为 True enable_memory_cleanup=advanced_settings.get('enable_memory_cleanup', True), memory_cleanup_days=advanced_settings.get('memory_cleanup_days', 30), memory_importance_threshold=advanced_settings.get('memory_importance_threshold', 0.3), @@ -347,64 +358,62 @@ def create_default(cls) -> 'PluginConfig': def to_dict(self) -> dict: """转换为字典格式""" - # 使用 asdict 可以确保所有字段都被包含 - return asdict(self) + return self.model_dump() - def validate(self) -> List[str]: + def validate_config(self) -> List[str]: """验证配置有效性,返回错误信息列表""" errors = [] - + if self.learning_interval_hours <= 0: errors.append("学习间隔必须大于0小时") - + if self.min_messages_for_learning <= 0: errors.append("最少学习消息数量必须大于0") - + if self.max_messages_per_batch <= 0: errors.append("每批最大消息数量必须大于0") - + if self.message_min_length >= self.message_max_length: errors.append("消息最小长度必须小于最大长度") - + if not 0 <= self.confidence_threshold <= 1: errors.append("置信度阈值必须在0-1之间") - + if not 0 <= self.style_update_threshold <= 1: errors.append("风格更新阈值必须在0-1之间") - + # 提示性警告而非错误 
provider_warnings = [] if not self.filter_provider_id: provider_warnings.append("未配置筛选模型提供商ID,将尝试自动配置或使用备选模型") - + if not self.refine_provider_id: provider_warnings.append("未配置提炼模型提供商ID,将尝试自动配置或使用备选模型") - + if not self.reinforce_provider_id: provider_warnings.append("未配置强化模型提供商ID,将尝试自动配置或使用备选模型") - + # 只有当没有配置任何Provider时才作为错误 if not self.filter_provider_id and not self.refine_provider_id and not self.reinforce_provider_id: errors.append("至少需要配置一个模型提供商ID,建议在AstrBot中配置Provider并在插件配置中指定") elif provider_warnings: # 将警告添加到错误列表用于信息展示(但不会阻止插件运行) - errors.extend([f"⚠️ {warning}" for warning in provider_warnings]) - + errors.extend([f" {warning}" for warning in provider_warnings]) + return errors - + def save_to_file(self, filepath: str) -> bool: """保存配置到文件""" try: - config_data = asdict(self) os.makedirs(os.path.dirname(filepath), exist_ok=True) with open(filepath, 'w', encoding='utf-8') as f: - json.dump(config_data, f, indent=2, ensure_ascii=False) + f.write(self.model_dump_json(indent=2)) logger.info(f"配置已保存到: {filepath}") return True except Exception as e: logger.error(f"保存配置失败: {e}") return False - + @classmethod def load_from_file(cls, filepath: str, data_dir: Optional[str] = None) -> 'PluginConfig': """从文件加载配置""" @@ -412,13 +421,13 @@ def load_from_file(cls, filepath: str, data_dir: Optional[str] = None) -> 'Plugi if os.path.exists(filepath): with open(filepath, 'r', encoding='utf-8') as f: config_data = json.load(f) - + # 设置 data_dir if data_dir: config_data['data_dir'] = data_dir - - # 创建配置实例 - config = cls(**config_data) + + # 创建配置实例(extra="ignore" 会忽略未知字段) + config = cls.model_validate(config_data) logger.info(f"配置已从文件加载: {filepath}") return config else: diff --git a/constants.py b/constants.py index aa37a7a..1d5b05f 100644 --- a/constants.py +++ b/constants.py @@ -3,7 +3,7 @@ 避免字符串匹配混淆,使用明确的枚举常量 """ -# ============= 人格审查更新类型常量 ============= +# 人格审查更新类型常量 # 渐进式人格学习(从对话中学习的人格更新) UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING = "progressive_persona_learning" diff --git 
a/core/__init__.py b/core/__init__.py index a9c4474..f68d264 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -3,27 +3,23 @@ """ from .factory import ServiceFactory -from .patterns import EventBus, ServiceRegistry, AsyncServiceBase, LearningContext, LearningContextBuilder, StrategyFactory, ConfigurationManager, MetricsCollector +from .patterns import ServiceRegistry, AsyncServiceBase, LearningContext, LearningContextBuilder, StrategyFactory from .interfaces import ( - IMessageCollector, IMessageFilter, IStyleAnalyzer, ILearningStrategy, - IQualityMonitor, IPersonaManager, IPersonaUpdater, IPersonaBackupManager, - IDataStorage, IObserver, IEventPublisher, IServiceFactory, IAsyncService, - IMLAnalyzer, IIntelligentResponder, ServiceLifecycle, MessageData, - AnalysisResult, LearningStrategyType, AnalysisType, EventType, + IMessageCollector, IMessageFilter, IStyleAnalyzer, ILearningStrategy, + IQualityMonitor, IPersonaManager, IPersonaUpdater, IPersonaBackupManager, + IDataStorage, IServiceFactory, IAsyncService, + IMLAnalyzer, IIntelligentResponder, ServiceLifecycle, MessageData, + AnalysisResult, LearningStrategyType, AnalysisType, ServiceError, StyleAnalysisError, ConfigurationError, DataStorageError, PersonaUpdateError - ) __all__ = [ 'ServiceFactory', - 'EventBus', 'ServiceRegistry', 'AsyncServiceBase', 'LearningContext', 'LearningContextBuilder', 'StrategyFactory', - 'ConfigurationManager', - 'MetricsCollector', 'IMessageCollector', 'IMessageFilter', 'IStyleAnalyzer', @@ -33,8 +29,6 @@ 'IPersonaUpdater', 'IPersonaBackupManager', 'IDataStorage', - 'IObserver', - 'IEventPublisher', 'IServiceFactory', 'IAsyncService', 'IMLAnalyzer', @@ -44,9 +38,7 @@ 'AnalysisResult', 'LearningStrategyType', 'AnalysisType', - 'EventType', 'ServiceError', - # 'AnalysisError', 'ConfigurationError', 'DataStorageError', 'PersonaUpdateError' diff --git a/core/compatibility_extensions.py b/core/compatibility_extensions.py deleted file mode 100644 index 342b8cd..0000000 --- 
a/core/compatibility_extensions.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -方法接口兼容性扩展 - 为新服务提供必要的接口方法 -""" -import json -import time -from typing import Dict, List, Optional, Any - - -class LLMClientExtension: - """LLM客户端扩展,提供统一的生成接口 - 已弃用,建议使用FrameworkLLMAdapter""" - - def __init__(self, llm_client, config, persona_manager=None, llm_adapter=None): - self.llm_client = llm_client - self.config = config - self.persona_manager = persona_manager - self.llm_adapter = llm_adapter # 新增适配器支持 - - async def generate_response(self, prompt: str, model_name: Optional[str] = None, - group_id: Optional[str] = None, **kwargs) -> str: - """生成响应的统一接口,自动包含当前人格信息""" - try: - # 获取当前人格信息 - system_prompt = None - if self.persona_manager and group_id: - try: - if hasattr(self.persona_manager, 'get_current_persona_description'): - persona_description = await self.persona_manager.get_current_persona_description(group_id) - else: - # 兼容性处理 - persona_ext = PersonaManagerExtension(self.persona_manager) - persona_description = await persona_ext.get_current_persona_description(group_id) - - if persona_description: - system_prompt = f"你的人格特征:{persona_description}\n\n请根据上述人格特征来回应用户。" - except Exception as e: - from astrbot.api import logger - logger.error(f"获取人格描述失败: {e}") - - # 优先使用新的适配器 - if self.llm_adapter and self.llm_adapter.has_filter_provider(): - response = await self.llm_adapter.filter_chat_completion( - prompt=prompt, - system_prompt=system_prompt - ) - else: - # 向后兼容:使用老式API配置 - api_url = getattr(self.config, 'filter_api_url', 'http://localhost:1234/v1/chat/completions') - api_key = getattr(self.config, 'filter_api_key', 'not-needed') - # 如果没有传入模型名称,使用默认值 - if not model_name: - model_name = 'gpt-4o' - - # 调用LLM - response = await self.llm_client.chat_completion( - api_url=api_url, - api_key=api_key, - model_name=model_name, - prompt=prompt, - system_prompt=system_prompt, - **kwargs - ) - - if response and hasattr(response, 'text'): - return response.text() - else: - return 
"抱歉,我暂时无法理解您的问题。" - - except Exception as e: - from astrbot.api import logger - logger.error(f"LLM响应生成失败: {e}") - return "抱歉,我暂时无法理解您的问题。" - - -class DatabaseManagerExtension: - """数据库管理器扩展,提供缺失的方法""" - - def __init__(self, db_manager): - self.db_manager = db_manager - - async def get_persona_update_history(self, group_id: str, days: int) -> List[Dict]: - """获取人格更新历史(基于真实数据库查询)""" - try: - # 使用数据库管理器的专门方法获取学习会话记录 - sessions = await self.db_manager.get_recent_learning_sessions(group_id, days) - - # 转换为人格更新历史格式 - history = [] - for session in sessions: - history.append({ - 'timestamp': session.get('start_time', time.time()), - 'group_id': group_id, - 'style_profile': { - 'quality_score': session.get('quality_score', 0.5), - 'messages_processed': session.get('messages_processed', 0), - 'success': session.get('success', False) - }, - 'update_type': 'learning_session', - 'backup_reason': f"学习会话 {session.get('session_id', 'unknown')}" - }) - - return history - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取人格更新历史失败: {e}") - return [] - - async def get_learning_batch_history(self, group_id: str, days: int) -> List[Dict]: - """获取学习批次历史(基于真实数据库查询)""" - try: - # 从全局消息数据库查询学习批次记录 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - start_timestamp = time.time() - (days * 24 * 3600) - - await cursor.execute(''' - SELECT * FROM learning_batches - WHERE start_time >= ? AND group_id = ? 
- ORDER BY start_time DESC - LIMIT 30 - ''', (start_timestamp, group_id)) - - rows = await cursor.fetchall() - history = [] - - for row in rows: - history.append({ - 'start_time': row[2], # start_time column - 'end_time': row[3], # end_time column - 'group_id': row[1], # group_id column - 'quality_score': row[4] if row[4] else 0.5, # quality_score column - 'processed_messages': row[5] if row[5] else 0, # processed_messages column - 'processing_time': (row[3] - row[2]) if (row[3] and row[2]) else 0 # calculate from timestamps - }) - - return history - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取学习批次历史失败: {e}") - # 如果表不存在或查询失败,返回空列表 - return [] - - async def get_messages_by_timerange(self, group_id: str, start_time, end_time) -> List[Dict]: - """根据时间范围获取消息(基于真实数据库查询)""" - try: - # 从全局消息数据库查询指定时间范围内的消息 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - start_timestamp = start_time.timestamp() - end_timestamp = end_time.timestamp() - - await cursor.execute(''' - SELECT sender_id, sender_name, message, group_id, platform, timestamp - FROM raw_messages - WHERE timestamp >= ? AND timestamp <= ? AND group_id = ? 
- ORDER BY timestamp ASC - LIMIT 1000 - ''', (start_timestamp, end_timestamp, group_id)) - - rows = await cursor.fetchall() - messages = [] - - for row in rows: - messages.append({ - 'timestamp': row[5], # timestamp column - 'group_id': row[3], # group_id column - 'sender_id': row[0], # sender_id column - 'sender_name': row[1], # sender_name column - 'message': row[2], # message column - 'platform': row[4] # platform column - }) - - return messages - - except Exception as e: - from astrbot.api import logger - logger.error(f"根据时间范围获取消息失败: {e}") - # 如果查询失败,返回空列表 - return [] - - async def get_social_relationships(self, group_id: str, days: int) -> List[Dict]: - """获取社交关系数据(基于真实数据库查询)""" - try: - # 使用数据库管理器的现有方法 - relationships = await self.db_manager.load_social_graph(group_id) - - # 过滤最近几天的关系 - start_timestamp = time.time() - (days * 24 * 3600) - filtered_relationships = [ - { - 'user1_id': rel['from_user'], - 'user2_id': rel['to_user'], - 'relationship_type': rel['relation_type'], - 'interaction_count': rel['frequency'], - 'strength': rel['strength'], - 'last_interaction': rel['last_interaction'] - } - for rel in relationships - if rel['last_interaction'] >= start_timestamp - ] - - return filtered_relationships - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取社交关系失败: {e}") - return [] - - async def get_message_statistics(self) -> Dict[str, int]: - """获取消息统计(基于真实数据库查询)""" - try: - # 从全局消息数据库查询真实统计 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 查询原始消息总数 - await cursor.execute('SELECT COUNT(*) FROM raw_messages') - total_messages = (await cursor.fetchone())[0] - - # 查询筛选后消息数 - await cursor.execute('SELECT COUNT(*) FROM filtered_messages') - filtered_messages = (await cursor.fetchone())[0] - - # 查询已用于学习的消息数 - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE used_for_learning = 1') - processed_messages = (await cursor.fetchone())[0] - - return { - 'total_messages': 
total_messages, - 'filtered_messages': filtered_messages, - 'processed_messages': processed_messages - } - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取消息统计失败: {e}") - return {'total_messages': 0, 'filtered_messages': 0, 'processed_messages': 0} - - -class PersonaManagerExtension: - """人格管理器扩展,提供缺失的方法""" - - def __init__(self, persona_manager): - self.persona_manager = persona_manager - - async def get_current_persona(self, group_id: str) -> Optional[Dict[str, Any]]: - """获取当前人格配置""" - try: - # 尝试调用原有方法 - if hasattr(self.persona_manager, 'get_current_persona'): - result = await self.persona_manager.get_current_persona(group_id) - if isinstance(result, dict): - return result - - # 返回默认人格配置 - return { - 'name': '默认人格', - 'description': '友好、智能的AI助手', - 'style_profile': { - 'creativity': 0.7, - 'formality': 0.5, - 'emotional_intensity': 0.6, - 'vocabulary_richness': 0.6, - 'empathy': 0.8 - }, - 'group_id': group_id, - 'last_updated': time.time() - } - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取当前人格配置失败: {e}") - return None - - async def get_current_persona_description(self, group_id: str = None) -> str: - """获取当前人格描述""" - try: - if hasattr(self.persona_manager, 'get_current_persona_description'): - result = await self.persona_manager.get_current_persona_description(group_id) - if result: - return result - - # 返回默认描述 - return "我是一个友好、智能的AI助手,能够理解您的需求并提供有用的回答。" - - except Exception as e: - from astrbot.api import logger - logger.error(f"获取人格描述失败: {e}") - return "我是一个AI助手。" - - -def create_compatibility_extensions(config, llm_client, db_manager, persona_manager): - """创建兼容性扩展""" - return { - 'llm_client': LLMClientExtension(llm_client, config, persona_manager), - 'db_manager': DatabaseManagerExtension(db_manager), - 'persona_manager': PersonaManagerExtension(persona_manager) if persona_manager else None - } \ No newline at end of file diff --git a/core/database/__init__.py b/core/database/__init__.py index 
8ebb998..4d2a5a1 100644 --- a/core/database/__init__.py +++ b/core/database/__init__.py @@ -1,19 +1,4 @@ -""" -数据库后端模块 - 支持 SQLite、MySQL 和 PostgreSQL -""" -from .backend_interface import IDatabaseBackend, DatabaseConfig, ConnectionPool, DatabaseType -from .sqlite_backend import SQLiteBackend -from .mysql_backend import MySQLBackend -from .postgresql_backend import PostgreSQLBackend -from .factory import DatabaseFactory +"""数据库引擎模块 - SQLAlchemy ORM""" +from .engine import DatabaseEngine -__all__ = [ - 'IDatabaseBackend', - 'DatabaseConfig', - 'ConnectionPool', - 'DatabaseType', - 'SQLiteBackend', - 'MySQLBackend', - 'PostgreSQLBackend', - 'DatabaseFactory' -] +__all__ = ['DatabaseEngine'] diff --git a/core/database/backend_interface.py b/core/database/backend_interface.py deleted file mode 100644 index dfb5b31..0000000 --- a/core/database/backend_interface.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -数据库后端抽象接口 - 定义统一的数据库操作接口 -""" -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple -from dataclasses import dataclass -from enum import Enum -import asyncio - - -class DatabaseType(Enum): - """数据库类型枚举""" - SQLITE = "sqlite" - MYSQL = "mysql" - POSTGRESQL = "postgresql" - - -@dataclass -class DatabaseConfig: - """数据库配置""" - db_type: DatabaseType - - # SQLite 配置 - sqlite_path: Optional[str] = None - - # MySQL 配置 - mysql_host: Optional[str] = None - mysql_port: int = 3306 - mysql_user: Optional[str] = None - mysql_password: Optional[str] = None - mysql_database: Optional[str] = None - mysql_charset: str = "utf8mb4" - - # PostgreSQL 配置 - postgresql_host: Optional[str] = None - postgresql_port: int = 5432 - postgresql_user: Optional[str] = None - postgresql_password: Optional[str] = None - postgresql_database: Optional[str] = None - postgresql_schema: str = "public" - - # 连接池配置 - max_connections: int = 10 - min_connections: int = 2 - connection_timeout: int = 30 - - def validate(self) -> Tuple[bool, Optional[str]]: - """验证配置是否有效""" - if 
self.db_type == DatabaseType.SQLITE: - if not self.sqlite_path: - return False, "SQLite path is required" - elif self.db_type == DatabaseType.MYSQL: - if not all([self.mysql_host, self.mysql_user, self.mysql_database]): - return False, "MySQL host, user, and database are required" - elif self.db_type == DatabaseType.POSTGRESQL: - if not all([self.postgresql_host, self.postgresql_user, self.postgresql_database]): - return False, "PostgreSQL host, user, and database are required" - else: - return False, f"Unsupported database type: {self.db_type}" - - return True, None - - -class ConnectionPool(ABC): - """数据库连接池抽象基类""" - - @abstractmethod - async def initialize(self): - """初始化连接池""" - pass - - @abstractmethod - async def get_connection(self): - """获取数据库连接""" - pass - - @abstractmethod - async def return_connection(self, conn): - """归还数据库连接""" - pass - - @abstractmethod - async def close_all(self): - """关闭所有连接""" - pass - - -class IDatabaseBackend(ABC): - """数据库后端接口""" - - @abstractmethod - async def initialize(self) -> bool: - """初始化数据库连接""" - pass - - @abstractmethod - async def close(self) -> bool: - """关闭数据库连接""" - pass - - @abstractmethod - async def execute(self, sql: str, params: Optional[Tuple] = None) -> int: - """ - 执行SQL语句(INSERT, UPDATE, DELETE) - - Args: - sql: SQL语句 - params: SQL参数 - - Returns: - 影响的行数 - """ - pass - - @abstractmethod - async def execute_many(self, sql: str, params_list: List[Tuple]) -> int: - """ - 批量执行SQL语句 - - Args: - sql: SQL语句 - params_list: 参数列表 - - Returns: - 影响的总行数 - """ - pass - - @abstractmethod - async def fetch_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Tuple]: - """ - 查询单行数据 - - Args: - sql: SQL语句 - params: SQL参数 - - Returns: - 查询结果(单行)或 None - """ - pass - - @abstractmethod - async def fetch_all(self, sql: str, params: Optional[Tuple] = None) -> List[Tuple]: - """ - 查询所有数据 - - Args: - sql: SQL语句 - params: SQL参数 - - Returns: - 查询结果列表 - """ - pass - - @abstractmethod - async def begin_transaction(self): - 
"""开始事务""" - pass - - @abstractmethod - async def commit(self): - """提交事务""" - pass - - @abstractmethod - async def rollback(self): - """回滚事务""" - pass - - @abstractmethod - async def create_table(self, table_name: str, schema: str) -> bool: - """ - 创建表 - - Args: - table_name: 表名 - schema: 表结构SQL(DDL) - - Returns: - 是否创建成功 - """ - pass - - @abstractmethod - async def table_exists(self, table_name: str) -> bool: - """ - 检查表是否存在 - - Args: - table_name: 表名 - - Returns: - 表是否存在 - """ - pass - - @abstractmethod - async def get_table_list(self) -> List[str]: - """ - 获取所有表名列表 - - Returns: - 表名列表 - """ - pass - - @abstractmethod - async def export_table_data(self, table_name: str) -> List[Dict[str, Any]]: - """ - 导出表数据 - - Args: - table_name: 表名 - - Returns: - 表数据列表(字典格式) - """ - pass - - @abstractmethod - async def import_table_data(self, table_name: str, data: List[Dict[str, Any]]) -> int: - """ - 导入表数据 - - Args: - table_name: 表名 - data: 数据列表(字典格式) - - Returns: - 导入的行数 - """ - pass - - @abstractmethod - def get_connection_context(self): - """ - 获取连接上下文管理器 - - Returns: - 异步上下文管理器 - """ - pass - - @property - @abstractmethod - def db_type(self) -> DatabaseType: - """获取数据库类型""" - pass - - @abstractmethod - def convert_ddl(self, sqlite_ddl: str) -> str: - """ - 转换DDL语句(SQLite -> 目标数据库) - - Args: - sqlite_ddl: SQLite DDL语句 - - Returns: - 转换后的DDL语句 - """ - pass diff --git a/core/database/engine.py b/core/database/engine.py index fa51abb..d912b4d 100644 --- a/core/database/engine.py +++ b/core/database/engine.py @@ -340,9 +340,7 @@ def _mask_password(url: str) -> str: return url -# ============================================================ # 便捷函数 -# ============================================================ def create_database_engine(database_url: str, echo: bool = False) -> DatabaseEngine: """ diff --git a/core/database/factory.py b/core/database/factory.py deleted file mode 100644 index c82cb46..0000000 --- a/core/database/factory.py +++ /dev/null @@ -1,93 +0,0 @@ -""" 
-数据库工厂 - 根据配置创建对应的数据库后端 -""" -from typing import Optional -from astrbot.api import logger - -from .backend_interface import IDatabaseBackend, DatabaseConfig, DatabaseType -from .sqlite_backend import SQLiteBackend -from .mysql_backend import MySQLBackend -from .postgresql_backend import PostgreSQLBackend - - -class DatabaseFactory: - """数据库工厂类""" - - @staticmethod - def create_backend(config: DatabaseConfig) -> Optional[IDatabaseBackend]: - """ - 根据配置创建数据库后端 - - Args: - config: 数据库配置 - - Returns: - 数据库后端实例,失败返回None - """ - try: - # 验证配置 - valid, error = config.validate() - if not valid: - logger.error(f"[DatabaseFactory] 配置验证失败: {error}") - return None - - # 根据类型创建后端 - if config.db_type == DatabaseType.SQLITE: - logger.info(f"[DatabaseFactory] 创建SQLite后端: {config.sqlite_path}") - return SQLiteBackend(config) - elif config.db_type == DatabaseType.MYSQL: - logger.info(f"[DatabaseFactory] 创建MySQL后端: {config.mysql_host}:{config.mysql_port}/{config.mysql_database}") - return MySQLBackend(config) - elif config.db_type == DatabaseType.POSTGRESQL: - logger.info(f"[DatabaseFactory] 创建PostgreSQL后端: {config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}") - return PostgreSQLBackend(config) - else: - logger.error(f"[DatabaseFactory] 不支持的数据库类型: {config.db_type}") - return None - - except Exception as e: - logger.error(f"[DatabaseFactory] 创建数据库后端失败: {e}", exc_info=True) - return None - - @staticmethod - def create_from_dict(config_dict: dict) -> Optional[IDatabaseBackend]: - """ - 从字典配置创建数据库后端 - - Args: - config_dict: 配置字典 - - Returns: - 数据库后端实例 - """ - try: - # 解析数据库类型 - db_type_str = config_dict.get('db_type', 'sqlite') - db_type = DatabaseType(db_type_str.lower()) - - # 创建配置对象 - config = DatabaseConfig( - db_type=db_type, - sqlite_path=config_dict.get('sqlite_path'), - mysql_host=config_dict.get('mysql_host'), - mysql_port=config_dict.get('mysql_port', 3306), - mysql_user=config_dict.get('mysql_user'), - 
mysql_password=config_dict.get('mysql_password'), - mysql_database=config_dict.get('mysql_database'), - mysql_charset=config_dict.get('mysql_charset', 'utf8mb4'), - postgresql_host=config_dict.get('postgresql_host'), - postgresql_port=config_dict.get('postgresql_port', 5432), - postgresql_user=config_dict.get('postgresql_user'), - postgresql_password=config_dict.get('postgresql_password'), - postgresql_database=config_dict.get('postgresql_database'), - postgresql_schema=config_dict.get('postgresql_schema', 'public'), - max_connections=config_dict.get('max_connections', 10), - min_connections=config_dict.get('min_connections', 2), - connection_timeout=config_dict.get('connection_timeout', 30) - ) - - return DatabaseFactory.create_backend(config) - - except Exception as e: - logger.error(f"[DatabaseFactory] 从字典创建后端失败: {e}", exc_info=True) - return None diff --git a/core/database/mysql_backend.py b/core/database/mysql_backend.py deleted file mode 100644 index 1af9be0..0000000 --- a/core/database/mysql_backend.py +++ /dev/null @@ -1,383 +0,0 @@ -""" -MySQL 数据库后端实现 -""" -import re -import asyncio -from typing import Any, Dict, List, Optional, Tuple, Callable, TypeVar -from contextlib import asynccontextmanager - -try: - import aiomysql - AIOMYSQL_AVAILABLE = True -except ImportError: - AIOMYSQL_AVAILABLE = False - aiomysql = None - -from astrbot.api import logger - -from .backend_interface import ( - IDatabaseBackend, - DatabaseType, - DatabaseConfig, - ConnectionPool -) - -T = TypeVar('T') - - -async def retry_on_mysql_error(func: Callable[..., T], max_retries: int = 3, initial_delay: float = 0.1) -> T: - """ - 对 MySQL 数据库操作进行重试,处理临时性错误 - - Args: - func: 要执行的异步函数 - max_retries: 最大重试次数 - initial_delay: 初始延迟时间(秒) - - Returns: - 函数执行结果 - """ - delay = initial_delay - last_error = None - - # MySQL 可重试的错误码 - RETRYABLE_ERRORS = { - 1205, # Lock wait timeout - 1213, # Deadlock - 2013, # Lost connection - 2006, # MySQL server has gone away - 2014, # Command Out of Sync - } - - 
for attempt in range(max_retries + 1): - try: - return await func() - except Exception as e: - error_msg = str(e) - - # 检查是否是可重试的错误 - is_retryable = False - if hasattr(e, 'args') and len(e.args) > 0: - error_code = e.args[0] if isinstance(e.args[0], int) else None - if error_code in RETRYABLE_ERRORS: - is_retryable = True - - # 也检查错误消息 - if any(keyword in error_msg.lower() for keyword in ['deadlock', 'lock wait', 'lost connection', 'gone away', 'command out of sync', 'out of sync', 'packet sequence number wrong']): - is_retryable = True - - if not is_retryable: - # 不是可重试的错误,直接抛出 - raise - - last_error = e - if attempt < max_retries: - logger.warning(f"[MySQL] 遇到临时错误,第 {attempt + 1}/{max_retries} 次重试(延迟 {delay:.2f}s): {error_msg}") - await asyncio.sleep(delay) - delay *= 2 # 指数退避 - else: - logger.error(f"[MySQL] 重试 {max_retries} 次后仍失败: {error_msg}") - - # 所有重试都失败 - raise last_error - - -class MySQLConnectionPool(ConnectionPool): - """MySQL连接池""" - - def __init__(self, config: DatabaseConfig): - self.config = config - self.pool: Optional[aiomysql.Pool] = None - self._is_closed = False # ✅ 添加关闭状态标记 - - async def initialize(self): - """初始化连接池""" - if not AIOMYSQL_AVAILABLE: - raise ImportError("aiomysql is not installed. 
Please install it: pip install aiomysql") - - self.pool = await aiomysql.create_pool( - host=self.config.mysql_host, - port=self.config.mysql_port, - user=self.config.mysql_user, - password=self.config.mysql_password, - db=self.config.mysql_database, - charset=self.config.mysql_charset, - minsize=self.config.min_connections, - maxsize=self.config.max_connections, - autocommit=False - ) - self._is_closed = False - logger.info(f"[MySQL] 连接池初始化成功: {self.config.mysql_host}:{self.config.mysql_port}/{self.config.mysql_database}") - - async def get_connection(self): - """获取数据库连接""" - # ✅ 添加状态检查,防止使用已关闭的连接池 - if self._is_closed or not self.pool: - logger.warning("[MySQL] 尝试从已关闭的连接池获取连接,跳过操作") - raise RuntimeError("连接池已关闭或未初始化,无法获取连接") - return await self.pool.acquire() - - async def return_connection(self, conn): - """归还数据库连接""" - if conn and self.pool and not self._is_closed: - self.pool.release(conn) - - async def close_all(self): - """关闭所有连接""" - if self.pool and not self._is_closed: - self._is_closed = True # ✅ 先设置关闭标记 - self.pool.close() - await self.pool.wait_closed() - logger.info("[MySQL] 连接池已关闭") - - -class MySQLBackend(IDatabaseBackend): - """MySQL数据库后端实现""" - - def __init__(self, config: DatabaseConfig): - self.config = config - self.connection_pool: Optional[MySQLConnectionPool] = None - self._current_transaction_conn: Optional[aiomysql.Connection] = None - - async def initialize(self) -> bool: - """初始化数据库连接""" - try: - if not AIOMYSQL_AVAILABLE: - logger.error("[MySQL] aiomysql未安装,请运行: pip install aiomysql") - return False - - valid, error = self.config.validate() - if not valid: - logger.error(f"[MySQL] 配置验证失败: {error}") - return False - - # 1. 构建MySQL连接URL用于迁移工具 - mysql_url = ( - f"mysql://{self.config.mysql_user}:{self.config.mysql_password}" - f"@{self.config.mysql_host}:{self.config.mysql_port}/{self.config.mysql_database}" - ) - - # 2. 初始化连接池 - self.connection_pool = MySQLConnectionPool(self.config) - await self.connection_pool.initialize() - - # 3. 
验证并修复表结构 - try: - from ...utils.schema_validator import validate_and_fix_schema - schema_valid = await validate_and_fix_schema( - db_url=mysql_url, - db_type='mysql', - auto_fix=True - ) - if not schema_valid: - logger.warning("[MySQL] 表结构验证发现问题,已尝试修复") - except Exception as e: - logger.warning(f"[MySQL] 表结构验证失败: {e}") - - logger.info("[MySQL] 数据库初始化成功") - return True - except Exception as e: - logger.error(f"[MySQL] 初始化失败: {e}", exc_info=True) - return False - - async def close(self) -> bool: - """关闭数据库连接""" - try: - if self.connection_pool: - await self.connection_pool.close_all() - logger.info("[MySQL] 数据库连接已关闭") - return True - except Exception as e: - logger.error(f"[MySQL] 关闭失败: {e}", exc_info=True) - return False - - async def execute(self, sql: str, params: Optional[Tuple] = None) -> int: - """执行SQL语句(带重试机制)""" - async def _do_execute(): - async with self.get_connection_context() as conn: - async with conn.cursor() as cursor: - await cursor.execute(sql, params or ()) - await conn.commit() - return cursor.rowcount - return await retry_on_mysql_error(_do_execute, max_retries=3) - - async def execute_many(self, sql: str, params_list: List[Tuple]) -> int: - """批量执行SQL语句(带重试机制)""" - async def _do_execute_many(): - async with self.get_connection_context() as conn: - async with conn.cursor() as cursor: - await cursor.executemany(sql, params_list) - await conn.commit() - return cursor.rowcount - return await retry_on_mysql_error(_do_execute_many, max_retries=3) - - async def fetch_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Tuple]: - """查询单行数据(带重试机制)""" - async def _do_fetch_one(): - async with self.get_connection_context() as conn: - async with conn.cursor() as cursor: - await cursor.execute(sql, params or ()) - return await cursor.fetchone() - return await retry_on_mysql_error(_do_fetch_one, max_retries=2) - - async def fetch_all(self, sql: str, params: Optional[Tuple] = None) -> List[Tuple]: - """查询所有数据(带重试机制)""" - async def _do_fetch_all(): 
- async with self.get_connection_context() as conn: - async with conn.cursor() as cursor: - await cursor.execute(sql, params or ()) - return await cursor.fetchall() - return await retry_on_mysql_error(_do_fetch_all, max_retries=2) - - async def begin_transaction(self): - """开始事务""" - if self._current_transaction_conn is None: - self._current_transaction_conn = await self.connection_pool.get_connection() - await self._current_transaction_conn.begin() - - async def commit(self): - """提交事务""" - if self._current_transaction_conn: - await self._current_transaction_conn.commit() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - - async def rollback(self): - """回滚事务""" - if self._current_transaction_conn: - await self._current_transaction_conn.rollback() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - - async def create_table(self, table_name: str, schema: str) -> bool: - """创建表""" - try: - # 转换SQLite DDL到MySQL DDL - mysql_schema = self.convert_ddl(schema) - await self.execute(mysql_schema) - logger.info(f"[MySQL] 创建表成功: {table_name}") - return True - except Exception as e: - logger.error(f"[MySQL] 创建表失败 {table_name}: {e}") - return False - - async def table_exists(self, table_name: str) -> bool: - """检查表是否存在""" - sql = "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = %s" - result = await self.fetch_one(sql, (self.config.mysql_database, table_name)) - return result and result[0] > 0 - - async def get_table_list(self) -> List[str]: - """获取所有表名列表""" - sql = "SELECT table_name FROM information_schema.tables WHERE table_schema = %s ORDER BY table_name" - results = await self.fetch_all(sql, (self.config.mysql_database,)) - return [row[0] for row in results] - - async def export_table_data(self, table_name: str) -> List[Dict[str, Any]]: - """导出表数据""" - sql = f"SELECT * FROM {table_name}" - async 
with self.get_connection_context() as conn: - async with conn.cursor(aiomysql.DictCursor) as cursor: - await cursor.execute(sql) - rows = await cursor.fetchall() - return rows - - async def import_table_data(self, table_name: str, data: List[Dict[str, Any]], replace: bool = False) -> int: - """ - 导入表数据 - - Args: - table_name: 表名 - data: 数据列表 - replace: 是否使用 REPLACE INTO(解决主键冲突) - """ - if not data: - return 0 - - # 获取列名 - columns = list(data[0].keys()) - - # 转换时间戳格式(从 Unix 时间戳转为 DATETIME) - datetime_columns = {'created_at', 'updated_at', 'timestamp', 'review_time'} - - converted_data = [] - for row in data: - new_row = {} - for col, val in row.items(): - # 检查是否是需要转换的时间戳列 - if col in datetime_columns and isinstance(val, (int, float)) and val > 1000000000: - # Unix 时间戳 -> DATETIME 字符串 - from datetime import datetime - new_row[col] = datetime.fromtimestamp(val).strftime('%Y-%m-%d %H:%M:%S') - else: - new_row[col] = val - converted_data.append(new_row) - - placeholders = ','.join(['%s' for _ in columns]) - - # 使用 REPLACE INTO 或 INSERT INTO - insert_type = "REPLACE" if replace else "INSERT" - sql = f"{insert_type} INTO {table_name} ({','.join(columns)}) VALUES ({placeholders})" - - # 准备参数 - params_list = [tuple(row[col] for col in columns) for row in converted_data] - - return await self.execute_many(sql, params_list) - - @asynccontextmanager - async def get_connection_context(self): - """获取连接上下文管理器""" - # 如果在事务中,使用事务连接 - if self._current_transaction_conn: - yield self._current_transaction_conn - else: - # 否则从池中获取连接 - conn = await self.connection_pool.get_connection() - try: - yield conn - finally: - await self.connection_pool.return_connection(conn) - - @property - def db_type(self) -> DatabaseType: - """获取数据库类型""" - return DatabaseType.MYSQL - - def convert_ddl(self, sqlite_ddl: str) -> str: - """ - 转换SQLite DDL到MySQL DDL - - 主要转换: - 1. INTEGER PRIMARY KEY AUTOINCREMENT -> INT PRIMARY KEY AUTO_INCREMENT - 2. INTEGER -> INT - 3. REAL -> DOUBLE - 4. 
BOOLEAN -> TINYINT(1) - 5. TEXT -> TEXT/VARCHAR - 6. TIMESTAMP DEFAULT CURRENT_TIMESTAMP -> TIMESTAMP DEFAULT CURRENT_TIMESTAMP - 7. DATETIME DEFAULT CURRENT_TIMESTAMP -> DATETIME DEFAULT CURRENT_TIMESTAMP - """ - mysql_ddl = sqlite_ddl - - # 替换数据类型 - mysql_ddl = re.sub( - r'\bINTEGER PRIMARY KEY AUTOINCREMENT\b', - 'INT PRIMARY KEY AUTO_INCREMENT', - mysql_ddl, - flags=re.IGNORECASE - ) - mysql_ddl = re.sub(r'\bINTEGER\b', 'INT', mysql_ddl, flags=re.IGNORECASE) - mysql_ddl = re.sub(r'\bREAL\b', 'DOUBLE', mysql_ddl, flags=re.IGNORECASE) - mysql_ddl = re.sub(r'\bBOOLEAN\b', 'TINYINT(1)', mysql_ddl, flags=re.IGNORECASE) - - # 移除SQLite特有的PRAGMA - mysql_ddl = re.sub(r'PRAGMA\s+\w+\s*=\s*\w+;?', '', mysql_ddl, flags=re.IGNORECASE) - - # 替换IF NOT EXISTS (MySQL支持) - # 无需修改,MySQL也支持 - - # 添加ENGINE和CHARSET - if 'CREATE TABLE' in mysql_ddl.upper() and 'ENGINE=' not in mysql_ddl.upper(): - mysql_ddl = mysql_ddl.rstrip().rstrip(';') - mysql_ddl += ' ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci' - - return mysql_ddl diff --git a/core/database/postgresql_backend.py b/core/database/postgresql_backend.py deleted file mode 100644 index 875ee37..0000000 --- a/core/database/postgresql_backend.py +++ /dev/null @@ -1,445 +0,0 @@ -""" -PostgreSQL 数据库后端实现 -""" -import re -import asyncio -from typing import Any, Dict, List, Optional, Tuple, Callable, TypeVar -from contextlib import asynccontextmanager - -try: - import asyncpg - ASYNCPG_AVAILABLE = True -except ImportError: - ASYNCPG_AVAILABLE = False - asyncpg = None - -from astrbot.api import logger - -from .backend_interface import ( - IDatabaseBackend, - DatabaseType, - DatabaseConfig, - ConnectionPool -) - -T = TypeVar('T') - - -async def retry_on_postgres_error(func: Callable[..., T], max_retries: int = 3, initial_delay: float = 0.1) -> T: - """ - 对 PostgreSQL 数据库操作进行重试,处理临时性错误 - - Args: - func: 要执行的异步函数 - max_retries: 最大重试次数 - initial_delay: 初始延迟时间(秒) - - Returns: - 函数执行结果 - """ - delay = initial_delay - 
last_error = None - - # PostgreSQL 可重试的错误码 - RETRYABLE_SQLSTATES = { - '40001', # serialization_failure - '40P01', # deadlock_detected - '08003', # connection_does_not_exist - '08006', # connection_failure - '08000', # connection_exception - '57P03', # cannot_connect_now - } - - for attempt in range(max_retries + 1): - try: - return await func() - except Exception as e: - error_msg = str(e) - - # 检查是否是可重试的错误 - is_retryable = False - - # asyncpg 的异常有 sqlstate 属性 - if hasattr(e, 'sqlstate') and e.sqlstate in RETRYABLE_SQLSTATES: - is_retryable = True - - # 也检查错误消息 - if any(keyword in error_msg.lower() for keyword in ['deadlock', 'serialization', 'connection', 'timeout']): - is_retryable = True - - if not is_retryable: - # 不是可重试的错误,直接抛出 - raise - - last_error = e - if attempt < max_retries: - logger.warning(f"[PostgreSQL] 遇到临时错误,第 {attempt + 1}/{max_retries} 次重试(延迟 {delay:.2f}s): {error_msg}") - await asyncio.sleep(delay) - delay *= 2 # 指数退避 - else: - logger.error(f"[PostgreSQL] 重试 {max_retries} 次后仍失败: {error_msg}") - - # 所有重试都失败 - raise last_error - - -class PostgreSQLConnectionPool(ConnectionPool): - """PostgreSQL连接池""" - - def __init__(self, config: DatabaseConfig): - self.config = config - self.pool: Optional[asyncpg.Pool] = None - - async def initialize(self): - """初始化连接池""" - if not ASYNCPG_AVAILABLE: - raise ImportError("asyncpg is not installed. 
Please install it: pip install asyncpg") - - # 构建连接字符串或使用参数字典 - self.pool = await asyncpg.create_pool( - host=self.config.postgresql_host, - port=self.config.postgresql_port, - user=self.config.postgresql_user, - password=self.config.postgresql_password, - database=self.config.postgresql_database, - min_size=self.config.min_connections, - max_size=self.config.max_connections, - command_timeout=self.config.connection_timeout, - # PostgreSQL 特定设置 - server_settings={ - 'search_path': self.config.postgresql_schema, - } - ) - logger.info(f"[PostgreSQL] 连接池初始化成功: {self.config.postgresql_host}:{self.config.postgresql_port}/{self.config.postgresql_database}") - - async def get_connection(self): - """获取数据库连接""" - return await self.pool.acquire() - - async def return_connection(self, conn): - """归还数据库连接""" - if conn: - await self.pool.release(conn) - - async def close_all(self): - """关闭所有连接""" - if self.pool: - await self.pool.close() - logger.info("[PostgreSQL] 连接池已关闭") - - -class PostgreSQLBackend(IDatabaseBackend): - """PostgreSQL数据库后端实现""" - - def __init__(self, config: DatabaseConfig): - self.config = config - self.connection_pool: Optional[PostgreSQLConnectionPool] = None - self._current_transaction_conn: Optional[asyncpg.Connection] = None - - async def initialize(self) -> bool: - """初始化数据库连接""" - try: - if not ASYNCPG_AVAILABLE: - logger.error("[PostgreSQL] asyncpg未安装,请运行: pip install asyncpg") - return False - - valid, error = self.config.validate() - if not valid: - logger.error(f"[PostgreSQL] 配置验证失败: {error}") - return False - - self.connection_pool = PostgreSQLConnectionPool(self.config) - await self.connection_pool.initialize() - logger.info("[PostgreSQL] 数据库初始化成功") - return True - except Exception as e: - logger.error(f"[PostgreSQL] 初始化失败: {e}", exc_info=True) - return False - - async def close(self) -> bool: - """关闭数据库连接""" - try: - if self.connection_pool: - await self.connection_pool.close_all() - logger.info("[PostgreSQL] 数据库连接已关闭") - return True - except 
Exception as e: - logger.error(f"[PostgreSQL] 关闭失败: {e}", exc_info=True) - return False - - async def execute(self, sql: str, params: Optional[Tuple] = None) -> int: - """执行SQL语句(带重试机制)""" - async def _do_execute(): - async with self.get_connection_context() as conn: - # PostgreSQL 使用 $1, $2 而不是 ? - converted_sql = self._convert_placeholders(sql) - result = await conn.execute(converted_sql, *(params or ())) - # asyncpg 的 execute 返回状态字符串,如 "INSERT 0 1" - # 我们需要解析出影响的行数 - return self._parse_row_count(result) - return await retry_on_postgres_error(_do_execute, max_retries=3) - - async def execute_many(self, sql: str, params_list: List[Tuple]) -> int: - """批量执行SQL语句(带重试机制)""" - async def _do_execute_many(): - async with self.get_connection_context() as conn: - converted_sql = self._convert_placeholders(sql) - # asyncpg 使用 executemany - await conn.executemany(converted_sql, params_list) - # executemany 不返回行数,返回参数列表长度 - return len(params_list) - return await retry_on_postgres_error(_do_execute_many, max_retries=3) - - async def fetch_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Tuple]: - """查询单行数据(带重试机制)""" - async def _do_fetch_one(): - async with self.get_connection_context() as conn: - converted_sql = self._convert_placeholders(sql) - row = await conn.fetchrow(converted_sql, *(params or ())) - # asyncpg 返回 Record 对象,转为 tuple - return tuple(row) if row else None - return await retry_on_postgres_error(_do_fetch_one, max_retries=2) - - async def fetch_all(self, sql: str, params: Optional[Tuple] = None) -> List[Tuple]: - """查询所有数据(带重试机制)""" - async def _do_fetch_all(): - async with self.get_connection_context() as conn: - converted_sql = self._convert_placeholders(sql) - rows = await conn.fetch(converted_sql, *(params or ())) - # 转换为 tuple 列表 - return [tuple(row) for row in rows] - return await retry_on_postgres_error(_do_fetch_all, max_retries=2) - - async def begin_transaction(self): - """开始事务""" - if self._current_transaction_conn is None: - 
self._current_transaction_conn = await self.connection_pool.get_connection() - # asyncpg 使用 transaction() 上下文管理器,这里手动开始 - self._transaction = self._current_transaction_conn.transaction() - await self._transaction.start() - - async def commit(self): - """提交事务""" - if self._current_transaction_conn and hasattr(self, '_transaction'): - await self._transaction.commit() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - self._transaction = None - - async def rollback(self): - """回滚事务""" - if self._current_transaction_conn and hasattr(self, '_transaction'): - await self._transaction.rollback() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - self._transaction = None - - async def create_table(self, table_name: str, schema: str) -> bool: - """创建表""" - try: - # 转换SQLite DDL到PostgreSQL DDL - postgres_schema = self.convert_ddl(schema) - await self.execute(postgres_schema) - logger.info(f"[PostgreSQL] 创建表成功: {table_name}") - return True - except Exception as e: - logger.error(f"[PostgreSQL] 创建表失败 {table_name}: {e}") - return False - - async def table_exists(self, table_name: str) -> bool: - """检查表是否存在""" - sql = """ - SELECT COUNT(*) - FROM information_schema.tables - WHERE table_schema = $1 AND table_name = $2 - """ - result = await self.fetch_one(sql, (self.config.postgresql_schema, table_name)) - return result and result[0] > 0 - - async def get_table_list(self) -> List[str]: - """获取所有表名列表""" - sql = """ - SELECT table_name - FROM information_schema.tables - WHERE table_schema = $1 - ORDER BY table_name - """ - results = await self.fetch_all(sql, (self.config.postgresql_schema,)) - return [row[0] for row in results] - - async def export_table_data(self, table_name: str) -> List[Dict[str, Any]]: - """导出表数据""" - sql = f"SELECT * FROM {table_name}" - async with self.get_connection_context() as conn: - converted_sql = 
self._convert_placeholders(sql) - rows = await conn.fetch(converted_sql) - # asyncpg Record 可以直接转为 dict - return [dict(row) for row in rows] - - async def import_table_data(self, table_name: str, data: List[Dict[str, Any]], replace: bool = False) -> int: - """ - 导入表数据 - - Args: - table_name: 表名 - data: 数据列表 - replace: 是否使用 UPSERT(ON CONFLICT) - """ - if not data: - return 0 - - # 获取列名 - columns = list(data[0].keys()) - - # 转换时间戳格式(从 Unix 时间戳转为 TIMESTAMP) - datetime_columns = {'created_at', 'updated_at', 'timestamp', 'review_time'} - - converted_data = [] - for row in data: - new_row = {} - for col, val in row.items(): - # 检查是否是需要转换的时间戳列 - if col in datetime_columns and isinstance(val, (int, float)) and val > 1000000000: - # Unix 时间戳 -> TIMESTAMP - from datetime import datetime - new_row[col] = datetime.fromtimestamp(val) - else: - new_row[col] = val - converted_data.append(new_row) - - # PostgreSQL 使用 $1, $2, ... 占位符 - placeholders = ', '.join([f'${i+1}' for i in range(len(columns))]) - - if replace: - # PostgreSQL 使用 ON CONFLICT 实现 UPSERT - # 需要知道主键列名,这里假设第一个列是主键 - primary_key = columns[0] - update_cols = ', '.join([f"{col} = EXCLUDED.{col}" for col in columns[1:]]) - sql = f""" - INSERT INTO {table_name} ({','.join(columns)}) - VALUES ({placeholders}) - ON CONFLICT ({primary_key}) - DO UPDATE SET {update_cols} - """ - else: - sql = f"INSERT INTO {table_name} ({','.join(columns)}) VALUES ({placeholders})" - - # 准备参数 - params_list = [tuple(row[col] for col in columns) for row in converted_data] - - return await self.execute_many(sql, params_list) - - @asynccontextmanager - async def get_connection_context(self): - """获取连接上下文管理器""" - # 如果在事务中,使用事务连接 - if self._current_transaction_conn: - yield self._current_transaction_conn - else: - # 否则从池中获取连接 - conn = await self.connection_pool.get_connection() - try: - yield conn - finally: - await self.connection_pool.return_connection(conn) - - @property - def db_type(self) -> DatabaseType: - """获取数据库类型""" - return 
DatabaseType.POSTGRESQL - - def _convert_placeholders(self, sql: str) -> str: - """ - 将 ? 占位符转换为 PostgreSQL 的 $1, $2, ... 格式 - - 注意:这个简单实现不处理字符串中的 ?,实际使用中可能需要更复杂的解析 - """ - # 简单替换:按顺序替换所有 ? - counter = 1 - result = [] - in_string = False - escape_next = False - - for char in sql: - if escape_next: - result.append(char) - escape_next = False - continue - - if char == '\\': - escape_next = True - result.append(char) - continue - - if char in ("'", '"'): - in_string = not in_string - result.append(char) - continue - - if char == '?' and not in_string: - result.append(f'${counter}') - counter += 1 - else: - result.append(char) - - return ''.join(result) - - def _parse_row_count(self, status: str) -> int: - """ - 解析 PostgreSQL 返回的状态字符串,提取受影响的行数 - - 例如: "INSERT 0 1" -> 1, "UPDATE 3" -> 3, "DELETE 5" -> 5 - """ - try: - parts = status.split() - if len(parts) >= 2: - # 最后一个数字通常是行数 - return int(parts[-1]) - return 0 - except (ValueError, IndexError): - return 0 - - def convert_ddl(self, sqlite_ddl: str) -> str: - """ - 转换SQLite DDL到PostgreSQL DDL - - 主要转换: - 1. INTEGER PRIMARY KEY AUTOINCREMENT -> SERIAL PRIMARY KEY - 2. INTEGER -> INTEGER (PostgreSQL 也支持) - 3. REAL -> DOUBLE PRECISION - 4. BOOLEAN -> BOOLEAN (PostgreSQL 原生支持) - 5. TEXT -> TEXT (PostgreSQL 支持) - 6. DATETIME -> TIMESTAMP - 7. 
移除 IF NOT EXISTS(PostgreSQL 9.1+ 支持,保留) - """ - postgres_ddl = sqlite_ddl - - # 替换 AUTOINCREMENT 为 SERIAL - postgres_ddl = re.sub( - r'\bINTEGER\s+PRIMARY\s+KEY\s+AUTOINCREMENT\b', - 'SERIAL PRIMARY KEY', - postgres_ddl, - flags=re.IGNORECASE - ) - - # 替换 REAL 为 DOUBLE PRECISION - postgres_ddl = re.sub(r'\bREAL\b', 'DOUBLE PRECISION', postgres_ddl, flags=re.IGNORECASE) - - # 替换 DATETIME 为 TIMESTAMP - postgres_ddl = re.sub(r'\bDATETIME\b', 'TIMESTAMP', postgres_ddl, flags=re.IGNORECASE) - - # 移除SQLite特有的PRAGMA - postgres_ddl = re.sub(r'PRAGMA\s+\w+\s*=\s*\w+;?', '', postgres_ddl, flags=re.IGNORECASE) - - # 替换 strftime('%s', 'now') 为 extract(epoch from now()) - postgres_ddl = re.sub( - r"strftime\s*\(\s*'%s'\s*,\s*'now'\s*\)", - "extract(epoch from now())", - postgres_ddl, - flags=re.IGNORECASE - ) - - # 替换 CURRENT_TIMESTAMP - # PostgreSQL 支持 CURRENT_TIMESTAMP,无需修改 - - return postgres_ddl diff --git a/core/database/sqlite_backend.py b/core/database/sqlite_backend.py deleted file mode 100644 index 75fe36a..0000000 --- a/core/database/sqlite_backend.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -SQLite 数据库后端实现 -""" -import os -import asyncio -import aiosqlite -import sqlite3 -from typing import Any, Dict, List, Optional, Tuple, Callable, TypeVar -from contextlib import asynccontextmanager - -from astrbot.api import logger - -from .backend_interface import ( - IDatabaseBackend, - DatabaseType, - DatabaseConfig, - ConnectionPool -) - -T = TypeVar('T') - - -async def retry_on_lock(func: Callable[..., T], max_retries: int = 3, initial_delay: float = 0.1) -> T: - """ - 对数据库操作进行重试,处理 database is locked 错误 - - Args: - func: 要执行的异步函数 - max_retries: 最大重试次数 - initial_delay: 初始延迟时间(秒) - - Returns: - 函数执行结果 - """ - delay = initial_delay - last_error = None - - for attempt in range(max_retries + 1): - try: - return await func() - except (sqlite3.OperationalError, Exception) as e: - error_msg = str(e) - if 'database is locked' not in error_msg.lower(): - # 不是锁定错误,直接抛出 - raise - - 
last_error = e - if attempt < max_retries: - logger.warning(f"[SQLite] 数据库锁定,第 {attempt + 1}/{max_retries} 次重试(延迟 {delay:.2f}s)") - await asyncio.sleep(delay) - delay *= 2 # 指数退避 - else: - logger.error(f"[SQLite] 重试 {max_retries} 次后仍失败: {error_msg}") - - # 所有重试都失败 - raise last_error - - -class SQLiteConnectionPool(ConnectionPool): - """SQLite连接池""" - - def __init__(self, db_path: str, max_connections: int = 10, min_connections: int = 2): - self.db_path = db_path - self.max_connections = max_connections - self.min_connections = min_connections - self.pool: asyncio.Queue = asyncio.Queue(maxsize=max_connections) - self.active_connections = 0 - self.total_connections = 0 - self._lock = asyncio.Lock() - - async def initialize(self): - """初始化连接池""" - async with self._lock: - # 确保目录存在 - db_dir = os.path.dirname(self.db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - - # 创建最小数量的连接 - for _ in range(self.min_connections): - conn = await self._create_connection() - await self.pool.put(conn) - - async def _create_connection(self) -> aiosqlite.Connection: - """创建新的数据库连接""" - # 设置超时时间为30秒,避免database is locked错误 - conn = await aiosqlite.connect(self.db_path, timeout=30.0) - - # 设置连接参数 - await conn.execute('PRAGMA foreign_keys = ON') - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA synchronous = NORMAL') - await conn.execute('PRAGMA cache_size = 10000') - await conn.execute('PRAGMA temp_store = memory') - await conn.execute('PRAGMA busy_timeout = 30000') # 设置忙等待超时为30秒(毫秒) - await conn.commit() - - self.total_connections += 1 - logger.debug(f"[SQLite] 创建新连接,总连接数: {self.total_connections}") - return conn - - async def get_connection(self) -> aiosqlite.Connection: - """获取数据库连接""" - try: - # 尝试从池中获取连接(非阻塞) - conn = self.pool.get_nowait() - self.active_connections += 1 - return conn - except asyncio.QueueEmpty: - # 池中无可用连接 - async with self._lock: - if self.total_connections < self.max_connections: - # 可以创建新连接 - conn = await 
self._create_connection() - self.active_connections += 1 - return conn - else: - # 达到最大连接数,等待连接归还 - logger.debug("[SQLite] 连接池已满,等待连接归还...") - conn = await self.pool.get() - self.active_connections += 1 - return conn - - async def return_connection(self, conn: aiosqlite.Connection): - """归还数据库连接""" - if conn: - try: - # 检查连接是否仍然有效 - await conn.execute('SELECT 1') - await self.pool.put(conn) - self.active_connections -= 1 - except Exception as e: - # 连接已损坏,关闭并减少计数 - logger.warning(f"[SQLite] 连接已损坏,关闭连接: {e}") - try: - await conn.close() - except: - pass - self.total_connections -= 1 - self.active_connections -= 1 - - async def close_all(self): - """关闭所有连接""" - logger.info("[SQLite] 开始关闭连接池...") - - # 关闭池中的所有连接 - while not self.pool.empty(): - try: - conn = self.pool.get_nowait() - await conn.close() - self.total_connections -= 1 - except asyncio.QueueEmpty: - break - except Exception as e: - logger.error(f"[SQLite] 关闭连接时出错: {e}") - - logger.info(f"[SQLite] 连接池已关闭,剩余连接数: {self.total_connections}") - - -class SQLiteBackend(IDatabaseBackend): - """SQLite数据库后端实现""" - - def __init__(self, config: DatabaseConfig): - self.config = config - self.connection_pool: Optional[SQLiteConnectionPool] = None - self._current_transaction_conn: Optional[aiosqlite.Connection] = None - - async def initialize(self) -> bool: - """初始化数据库连接""" - try: - valid, error = self.config.validate() - if not valid: - logger.error(f"[SQLite] 配置验证失败: {error}") - return False - - # 1. 初始化连接池 - self.connection_pool = SQLiteConnectionPool( - db_path=self.config.sqlite_path, - max_connections=self.config.max_connections, - min_connections=self.config.min_connections - ) - - await self.connection_pool.initialize() - - # 2. 
验证并修复表结构 - try: - from ...utils.schema_validator import validate_and_fix_schema - schema_valid = await validate_and_fix_schema( - db_url=self.config.sqlite_path, - db_type='sqlite', - auto_fix=True - ) - if not schema_valid: - logger.warning("[SQLite] 表结构验证发现问题,已尝试修复") - except Exception as e: - logger.warning(f"[SQLite] 表结构验证失败: {e}") - - logger.info(f"[SQLite] 数据库初始化成功: {self.config.sqlite_path}") - return True - except Exception as e: - logger.error(f"[SQLite] 初始化失败: {e}", exc_info=True) - return False - - async def close(self) -> bool: - """关闭数据库连接""" - try: - if self.connection_pool: - await self.connection_pool.close_all() - logger.info("[SQLite] 数据库连接已关闭") - return True - except Exception as e: - logger.error(f"[SQLite] 关闭失败: {e}", exc_info=True) - return False - - async def execute(self, sql: str, params: Optional[Tuple] = None) -> int: - """执行SQL语句(带重试机制)""" - async def _do_execute(): - async with self.get_connection_context() as conn: - cursor = await conn.execute(sql, params or ()) - await conn.commit() - return cursor.rowcount - return await retry_on_lock(_do_execute, max_retries=5) - - async def execute_many(self, sql: str, params_list: List[Tuple]) -> int: - """批量执行SQL语句(带重试机制)""" - async def _do_execute_many(): - async with self.get_connection_context() as conn: - cursor = await conn.executemany(sql, params_list) - await conn.commit() - return cursor.rowcount - return await retry_on_lock(_do_execute_many, max_retries=5) - - async def fetch_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Tuple]: - """查询单行数据(带重试机制)""" - async def _do_fetch_one(): - async with self.get_connection_context() as conn: - cursor = await conn.execute(sql, params or ()) - return await cursor.fetchone() - return await retry_on_lock(_do_fetch_one, max_retries=3) - - async def fetch_all(self, sql: str, params: Optional[Tuple] = None) -> List[Tuple]: - """查询所有数据(带重试机制)""" - async def _do_fetch_all(): - async with self.get_connection_context() as conn: - cursor = 
await conn.execute(sql, params or ()) - return await cursor.fetchall() - return await retry_on_lock(_do_fetch_all, max_retries=3) - - async def begin_transaction(self): - """开始事务""" - if self._current_transaction_conn is None: - self._current_transaction_conn = await self.connection_pool.get_connection() - await self._current_transaction_conn.execute('BEGIN') - - async def commit(self): - """提交事务""" - if self._current_transaction_conn: - await self._current_transaction_conn.commit() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - - async def rollback(self): - """回滚事务""" - if self._current_transaction_conn: - await self._current_transaction_conn.rollback() - await self.connection_pool.return_connection(self._current_transaction_conn) - self._current_transaction_conn = None - - async def create_table(self, table_name: str, schema: str) -> bool: - """创建表""" - try: - await self.execute(schema) - logger.info(f"[SQLite] 创建表成功: {table_name}") - return True - except Exception as e: - logger.error(f"[SQLite] 创建表失败 {table_name}: {e}") - return False - - async def table_exists(self, table_name: str) -> bool: - """检查表是否存在""" - sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" 
- result = await self.fetch_one(sql, (table_name,)) - return result is not None - - async def get_table_list(self) -> List[str]: - """获取所有表名列表""" - sql = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" - results = await self.fetch_all(sql) - return [row[0] for row in results] - - async def export_table_data(self, table_name: str) -> List[Dict[str, Any]]: - """导出表数据""" - sql = f"SELECT * FROM {table_name}" - async with self.get_connection_context() as conn: - conn.row_factory = aiosqlite.Row - cursor = await conn.execute(sql) - rows = await cursor.fetchall() - return [dict(row) for row in rows] - - async def import_table_data(self, table_name: str, data: List[Dict[str, Any]], replace: bool = False) -> int: - """ - 导入表数据 - - Args: - table_name: 表名 - data: 数据列表 - replace: SQLite 不支持,忽略此参数 - """ - if not data: - return 0 - - # 获取列名 - columns = list(data[0].keys()) - placeholders = ','.join(['?' for _ in columns]) - - # SQLite 使用 INSERT OR REPLACE 代替 REPLACE INTO - insert_type = "INSERT OR REPLACE" if replace else "INSERT" - sql = f"{insert_type} INTO {table_name} ({','.join(columns)}) VALUES ({placeholders})" - - # 准备参数 - params_list = [tuple(row[col] for col in columns) for row in data] - - return await self.execute_many(sql, params_list) - - @asynccontextmanager - async def get_connection_context(self): - """获取连接上下文管理器""" - # 如果在事务中,使用事务连接 - if self._current_transaction_conn: - yield self._current_transaction_conn - else: - # 否则从池中获取连接 - conn = await self.connection_pool.get_connection() - try: - yield conn - finally: - await self.connection_pool.return_connection(conn) - - @property - def db_type(self) -> DatabaseType: - """获取数据库类型""" - return DatabaseType.SQLITE - - def convert_ddl(self, sqlite_ddl: str) -> str: - """SQLite DDL不需要转换""" - return sqlite_ddl diff --git a/core/factory.py b/core/factory.py index b83d7d2..b43f072 100644 --- a/core/factory.py +++ b/core/factory.py @@ -3,6 +3,7 @@ """ from typing import Dict, Any, Optional import asyncio 
+import functools import json # 导入json模块,因为MessageFilter中使用了 from astrbot.api.star import Context @@ -13,7 +14,7 @@ IQualityMonitor, IPersonaManager, IPersonaUpdater, IMLAnalyzer, IIntelligentResponder, IMessageRelationshipAnalyzer, LearningStrategyType ) -from .patterns import StrategyFactory, ServiceRegistry, EventBus +from .patterns import StrategyFactory, ServiceRegistry from .framework_llm_adapter import FrameworkLLMAdapter # 导入框架LLM适配器 # 使用单例模式导入配置和异常 @@ -23,6 +24,21 @@ from ..utils.json_utils import safe_parse_llm_json +def cached_service(key): + """Decorator that caches create_* return values in self._service_cache.""" + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if key in self._service_cache: + return self._service_cache[key] + result = func(self, *args, **kwargs) + if result is not None: + self._service_cache[key] = result + return result + return wrapper + return decorator + + class ServiceFactory(IServiceFactory): """主要服务工厂 - 创建和管理所有服务实例""" @@ -31,8 +47,7 @@ def __init__(self, config: PluginConfig, context: Context): self.context = context self._logger = logger self._registry = ServiceRegistry() - self._event_bus = EventBus() - + # 服务实例缓存 self._service_cache: Dict[str, Any] = {} @@ -50,25 +65,25 @@ def create_framework_llm_adapter(self) -> FrameworkLLMAdapter: # 检查是否成功配置了至少一个提供商 if self._framework_llm_adapter.providers_configured > 0: - self._logger.info(f"✅ 框架LLM适配器初始化成功,已配置 {self._framework_llm_adapter.providers_configured} 个提供商") + self._logger.info(f" 框架LLM适配器初始化成功,已配置 {self._framework_llm_adapter.providers_configured} 个提供商") else: - # ⚠️ 重要变更:Provider未配置时不抛出异常,允许延迟初始化 + # 重要变更:Provider未配置时不抛出异常,允许延迟初始化 self._logger.warning( - "⚠️ 框架LLM适配器初始化时未找到可用的Provider。\n" - " 原因可能是:\n" - " 1. AstrBot的Provider系统尚未完全初始化(插件加载时序问题)\n" - " 2. 配置文件中未指定filter_provider_id/refine_provider_id\n" - " 3. 指定的Provider ID不存在\n" - " 插件将继续加载,Provider会在实际使用时自动重试初始化。" + " 框架LLM适配器初始化时未找到可用的Provider。\n" + " 原因可能是:\n" + " 1. 
AstrBot的Provider系统尚未完全初始化(插件加载时序问题)\n" + " 2. 配置文件中未指定filter_provider_id/refine_provider_id\n" + " 3. 指定的Provider ID不存在\n" + " 插件将继续加载,Provider会在实际使用时自动重试初始化。" ) # 标记为需要延迟初始化 self._framework_llm_adapter._needs_lazy_init = True except Exception as e: self._logger.warning( - f"⚠️ 初始化LLM适配器时发生异常: {e}\n" - " 插件将继续加载,LLM功能会在实际调用时重试初始化。", - exc_info=self.config.debug_mode # 仅在debug模式显示完整堆栈 + f" 初始化LLM适配器时发生异常: {e}\n" + " 插件将继续加载,LLM功能会在实际调用时重试初始化。", + exc_info=self.config.debug_mode # 仅在debug模式显示完整堆栈 ) # 创建一个最小化的适配器实例,允许插件继续加载 self._framework_llm_adapter = FrameworkLLMAdapter(self.context) @@ -80,19 +95,14 @@ def get_prompts(self) -> Any: """获取 Prompt 静态数据""" return prompts + @cached_service("message_collector") def create_message_collector(self) -> IMessageCollector: """创建消息收集器""" - cache_key = "message_collector" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: # 单例模式动态导入避免循环依赖 - from ..services.message_collector import MessageCollectorService - + from ..services.core_learning import MessageCollectorService + service = MessageCollectorService(self.config, self.context, self.create_database_manager()) # 传递 DatabaseManager - self._service_cache[cache_key] = service self._registry.register_service("message_collector", service) self._logger.info("创建消息收集器成功") @@ -102,43 +112,37 @@ def create_message_collector(self) -> IMessageCollector: self._logger.error(f"导入消息收集器失败: {e}", exc_info=True) raise ServiceError(f"创建消息收集器失败: {str(e)}") + @cached_service("style_analyzer") def create_style_analyzer(self) -> IStyleAnalyzer: """创建风格分析器 - 优先使用MaiBot增强版本""" - cache_key = "style_analyzer" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: # 如果启用了MaiBot增强功能,使用MaiBot适配器 if getattr(self.config, 'enable_maibot_features', False): try: - from ..services.maibot_adapters import MaiBotStyleAnalyzer + from ..services.integration import MaiBotStyleAnalyzer service = MaiBotStyleAnalyzer( - self.config, + self.config, 
self.create_database_manager(), context=self.context, llm_adapter=self.create_framework_llm_adapter() ) - self._service_cache[cache_key] = service self._registry.register_service("style_analyzer", service) self._logger.info("创建MaiBot风格分析器成功") return service except ImportError as e: self._logger.warning(f"MaiBot适配器不可用,回退到默认实现: {e}") - + # 回退到默认实现 - from ..services.style_analyzer import StyleAnalyzerService - + from ..services.response import StyleAnalyzerService + # 传递 DatabaseManager 和框架适配器 service = StyleAnalyzerService( - self.config, - self.context, + self.config, + self.context, self.create_database_manager(), - llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 - prompts=self.get_prompts() # 传递 prompts - ) - self._service_cache[cache_key] = service + llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 + prompts=self.get_prompts() # 传递 prompts + ) self._registry.register_service("style_analyzer", service) self._logger.info("创建风格分析器成功") @@ -148,22 +152,17 @@ def create_style_analyzer(self) -> IStyleAnalyzer: self._logger.error(f"导入风格分析器失败: {e}", exc_info=True) raise ServiceError(f"创建风格分析器失败: {str(e)}") + @cached_service("message_relationship_analyzer") def create_message_relationship_analyzer(self): """创建消息关系分析器""" - cache_key = "message_relationship_analyzer" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.message_relationship_analyzer import MessageRelationshipAnalyzer - + from ..services.social import MessageRelationshipAnalyzer + service = MessageRelationshipAnalyzer( self.config, - self.context, + self.context, llm_adapter=self.create_framework_llm_adapter() ) - self._service_cache[cache_key] = service self._registry.register_service("message_relationship_analyzer", service) self._logger.info("创建消息关系分析器成功") @@ -179,7 +178,7 @@ def create_learning_strategy(self, strategy_type: str) -> ILearningStrategy: # 如果启用了MaiBot增强功能,使用MaiBot学习策略 if getattr(self.config, 'enable_maibot_features', 
False): try: - from ..services.maibot_adapters import MaiBotLearningStrategy + from ..services.integration import MaiBotLearningStrategy strategy = MaiBotLearningStrategy(self.config, self.create_database_manager()) self._logger.info("创建MaiBot学习策略成功") return strategy @@ -208,36 +207,30 @@ def create_learning_strategy(self, strategy_type: str) -> ILearningStrategy: self._logger.error(f"不支持的策略类型: {strategy_type}", exc_info=True) raise ServiceError(f"创建学习策略失败: {str(e)}") + @cached_service("quality_monitor") def create_quality_monitor(self) -> IQualityMonitor: """创建质量监控器 - 优先使用MaiBot增强版本""" - cache_key = "quality_monitor" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: # 如果启用了MaiBot增强功能,使用MaiBot质量监控器 if getattr(self.config, 'enable_maibot_features', False): try: - from ..services.maibot_adapters import MaiBotQualityMonitor + from ..services.integration import MaiBotQualityMonitor service = MaiBotQualityMonitor(self.config, self.create_database_manager()) - self._service_cache[cache_key] = service self._registry.register_service("quality_monitor", service) self._logger.info("创建MaiBot质量监控器成功") return service except ImportError as e: self._logger.warning(f"MaiBot质量监控器不可用,回退到默认实现: {e}") - + # 回退到默认实现 - from ..services.learning_quality_monitor import LearningQualityMonitor - + from ..services.quality import LearningQualityMonitor + service = LearningQualityMonitor( - self.config, - self.context, - llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 - prompts=self.get_prompts() # 传递 prompts - ) - self._service_cache[cache_key] = service + self.config, + self.context, + llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 + prompts=self.get_prompts() # 传递 prompts + ) self._registry.register_service("quality_monitor", service) self._logger.info("创建质量监控器成功") @@ -247,54 +240,41 @@ def create_quality_monitor(self) -> IQualityMonitor: self._logger.error(f"导入质量监控器失败: {e}", exc_info=True) raise ServiceError(f"创建质量监控器失败: {str(e)}") 
+ @cached_service("database_manager") def create_database_manager(self): """创建数据库管理器 - 根据配置选择实现""" - cache_key = "database_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - # 使用数据库工厂创建管理器(根据配置选择实现) - from ..services.database_factory import create_database_manager + from ..services.database import SQLAlchemyDatabaseManager - service = create_database_manager(self.config, self.context) - self._service_cache[cache_key] = service + service = SQLAlchemyDatabaseManager(self.config, self.context) self._registry.register_service("database_manager", service) - # 记录使用的实现类型 - impl_type = type(service).__name__ - self._logger.info(f"创建数据库管理器成功 (实现: {impl_type})") + self._logger.info(f"创建数据库管理器成功 (实现: SQLAlchemyDatabaseManager)") return service except ImportError as e: self._logger.error(f"导入数据库管理器失败: {e}", exc_info=True) raise ServiceError(f"创建数据库管理器失败: {str(e)}") + @cached_service("ml_analyzer") def create_ml_analyzer(self) -> IMLAnalyzer: """创建ML分析器""" - cache_key = "ml_analyzer" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.ml_analyzer import LightweightMLAnalyzer - + from ..services.analysis import LightweightMLAnalyzer + # 需要数据库管理器 db_manager = self.create_database_manager() - + # 获取临时人格更新器实例 temporary_persona_updater = self.create_temporary_persona_updater() service = LightweightMLAnalyzer( - self.config, - db_manager, - llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 + self.config, + db_manager, + llm_adapter=self.create_framework_llm_adapter(), # 使用框架适配器 prompts=self.get_prompts(), # 传递 prompts temporary_persona_updater=temporary_persona_updater # 传递临时人格更新器 ) - self._service_cache[cache_key] = service self._logger.info("创建ML分析器成功") return service @@ -303,15 +283,11 @@ def create_ml_analyzer(self) -> IMLAnalyzer: self._logger.error(f"导入ML分析器失败: {e}", exc_info=True) raise ServiceError(f"创建ML分析器失败: {str(e)}") + @cached_service("intelligent_responder") 
def create_intelligent_responder(self) -> IIntelligentResponder: """创建智能回复器""" - cache_key = "intelligent_responder" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.intelligent_responder import IntelligentResponder + from ..services.response import IntelligentResponder # 需要数据库管理器 db_manager = self.create_database_manager() @@ -335,7 +311,6 @@ def create_intelligent_responder(self) -> IIntelligentResponder: diversity_manager=diversity_manager, # 传递多样性管理器 social_context_injector=social_context_injector # 传递社交上下文注入器 ) - self._service_cache[cache_key] = service self._logger.info("创建智能回复器成功") return service @@ -344,22 +319,17 @@ def create_intelligent_responder(self) -> IIntelligentResponder: self._logger.error(f"导入智能回复器失败: {e}", exc_info=True) raise ServiceError(f"创建智能回复器失败: {str(e)}") + @cached_service("persona_manager") def create_persona_manager(self) -> IPersonaManager: """创建人格管理器""" - cache_key = "persona_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.persona_manager import PersonaManagerService # 导入 PersonaManagerService - + from ..services.persona import PersonaManagerService # 导入 PersonaManagerService + # 创建依赖的服务 persona_updater = self.create_persona_updater() persona_backup_manager = self.create_persona_backup_manager() - + service = PersonaManagerService(self.config, self.context, persona_updater, persona_backup_manager) - self._service_cache[cache_key] = service self._registry.register_service("persona_manager", service) # 注册服务 self._logger.info("创建人格管理器成功") @@ -369,18 +339,13 @@ def create_persona_manager(self) -> IPersonaManager: self._logger.error(f"导入人格管理器失败: {e}", exc_info=True) raise ServiceError(f"创建人格管理器失败: {str(e)}") + @cached_service("persona_manager_updater") def create_persona_manager_updater(self): """创建PersonaManager增量更新器""" - cache_key = "persona_manager_updater" - - if cache_key in self._service_cache: - return 
self._service_cache[cache_key] - try: - from ..services.persona_manager_updater import PersonaManagerUpdater - + from ..services.persona import PersonaManagerUpdater + service = PersonaManagerUpdater(self.config, self.context) - self._service_cache[cache_key] = service self._registry.register_service("persona_manager_updater", service) self._logger.info("创建PersonaManager更新器成功") @@ -390,33 +355,28 @@ def create_persona_manager_updater(self): self._logger.error(f"导入PersonaManager更新器失败: {e}", exc_info=True) raise ServiceError(f"创建PersonaManager更新器失败: {str(e)}") + @cached_service("multidimensional_analyzer") def create_multidimensional_analyzer(self): """创建多维度分析器""" - cache_key = "multidimensional_analyzer" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.multidimensional_analyzer import MultidimensionalAnalyzer - + from ..services.analysis import MultidimensionalAnalyzer + db_manager = self.create_database_manager() # 获取 DatabaseManager 实例 - + # 使用框架LLM适配器 llm_adapter = self.create_framework_llm_adapter() - + # 获取临时人格更新器实例 temporary_persona_updater = self.create_temporary_persona_updater() service = MultidimensionalAnalyzer( - self.config, - db_manager, + self.config, + db_manager, self.context, - llm_adapter=llm_adapter, # 传递框架适配器 + llm_adapter=llm_adapter, # 传递框架适配器 prompts=self.get_prompts(), # 传递 prompts temporary_persona_updater=temporary_persona_updater # 传递临时人格更新器 ) - self._service_cache[cache_key] = service self._logger.info("创建多维度分析器成功") return service @@ -425,21 +385,17 @@ def create_multidimensional_analyzer(self): self._logger.error(f"导入多维度分析器失败: {e}", exc_info=True) raise ServiceError(f"创建多维度分析器失败: {str(e)}") + @cached_service("progressive_learning") def create_progressive_learning(self): """创建渐进式学习服务""" - cache_key = "progressive_learning" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.progressive_learning import ProgressiveLearningService 
- + from ..services.core_learning import ProgressiveLearningService + # Directly pass the database manager db_manager = self.create_database_manager() - + service = ProgressiveLearningService( - self.config, + self.config, self.context, db_manager=db_manager, # 传递 db_manager 实例 message_collector=self.create_message_collector(), @@ -450,7 +406,6 @@ def create_progressive_learning(self): ml_analyzer=self.create_ml_analyzer(), # 传递 ml_analyzer 实例 prompts=self.get_prompts() # 传递 prompts ) - self._service_cache[cache_key] = service self._registry.register_service("progressive_learning", service) self._logger.info("创建渐进式学习服务成功") @@ -461,18 +416,13 @@ def create_progressive_learning(self): raise ServiceError(f"创建渐进式学习服务失败: {str(e)}") + @cached_service("persona_backup_manager") def create_persona_backup_manager(self): """创建人格备份管理器""" - cache_key = "persona_backup_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.persona_backup_manager import PersonaBackupManager + from ..services.persona import PersonaBackupManager db_manager = self.create_database_manager() service = PersonaBackupManager(self.config, self.context, db_manager) - self._service_cache[cache_key] = service self._registry.register_service("persona_backup_manager", service) self._logger.info("创建人格备份管理器成功") return service @@ -480,21 +430,17 @@ def create_persona_backup_manager(self): self._logger.error(f"导入人格备份管理器失败: {e}", exc_info=True) raise ServiceError(f"创建人格备份管理器失败: {str(e)}") + @cached_service("temporary_persona_updater") def create_temporary_persona_updater(self): """创建临时人格更新器""" - cache_key = "temporary_persona_updater" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.temporary_persona_updater import TemporaryPersonaUpdater - + from ..services.persona import TemporaryPersonaUpdater + # 获取依赖的服务 persona_updater = self.create_persona_updater() backup_manager = 
self.create_persona_backup_manager() db_manager = self.create_database_manager() - + service = TemporaryPersonaUpdater( self.config, self.context, @@ -502,7 +448,6 @@ def create_temporary_persona_updater(self): backup_manager, db_manager ) - self._service_cache[cache_key] = service self._registry.register_service("temporary_persona_updater", service) self._logger.info("创建临时人格更新器成功") @@ -512,24 +457,19 @@ def create_temporary_persona_updater(self): self._logger.error(f"导入临时人格更新器失败: {e}", exc_info=True) raise ServiceError(f"创建临时人格更新器失败: {str(e)}") + @cached_service("persona_updater") def create_persona_updater(self) -> IPersonaUpdater: # 修改返回类型为 IPersonaUpdater """创建人格更新器""" - cache_key = "persona_updater" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.persona_updater import PersonaUpdater + from ..services.persona import PersonaUpdater backup_manager = self.create_persona_backup_manager() service = PersonaUpdater( - self.config, - self.context, - backup_manager, - None, # llm_client参数保持为可选 - self.create_database_manager() # 传递正确的db_manager + self.config, + self.context, + backup_manager, + None, # llm_client参数保持为可选 + self.create_database_manager() # 传递正确的db_manager ) - self._service_cache[cache_key] = service self._registry.register_service("persona_updater", service) self._logger.info("创建人格更新器成功") return service @@ -555,11 +495,7 @@ def get_persona_updater(self) -> Optional[IPersonaUpdater]: def get_service_registry(self) -> ServiceRegistry: """获取服务注册表""" return self._registry - - def get_event_bus(self) -> EventBus: - """获取事件总线""" - return self._event_bus - + async def initialize_all_services(self) -> bool: """初始化所有服务""" self._logger.info("开始初始化所有服务") @@ -567,7 +503,7 @@ async def initialize_all_services(self) -> bool: try: # 按依赖顺序创建服务 self.create_database_manager() - self.create_temporary_persona_updater() # 临时人格更新器需要优先创建 + self.create_temporary_persona_updater() # 临时人格更新器需要优先创建 
self.create_message_collector() self.create_style_analyzer() self.create_quality_monitor() @@ -575,7 +511,7 @@ async def initialize_all_services(self) -> bool: # 创建响应多样性管理器(在intelligent_responder之前)- 使用工厂方法 try: - self.create_response_diversity_manager() # 使用ServiceFactory的方法 + self.create_response_diversity_manager() # 使用ServiceFactory的方法 except Exception as e: self._logger.warning(f"创建响应多样性管理器失败(继续使用默认行为): {e}") @@ -585,7 +521,7 @@ async def initialize_all_services(self) -> bool: except Exception as e: self._logger.warning(f"创建社交上下文注入器失败(继续使用默认行为): {e}") - self.create_intelligent_responder() # 重新启用智能回复器 + self.create_intelligent_responder() # 重新启用智能回复器 self.create_persona_manager() self.create_multidimensional_analyzer() self.create_progressive_learning() @@ -634,22 +570,17 @@ def clear_cache(self): self._service_cache.clear() self._logger.info("服务缓存已清理") + @cached_service("response_diversity_manager") def create_response_diversity_manager(self): """创建响应多样性管理器""" - cache_key = "response_diversity_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.response_diversity_manager import ResponseDiversityManager + from ..services.response import ResponseDiversityManager service = ResponseDiversityManager( config=self.config, db_manager=self.create_database_manager() ) - self._service_cache[cache_key] = service self._registry.register_service("response_diversity_manager", service) self._logger.info("创建响应多样性管理器成功") @@ -742,7 +673,7 @@ class MessageFilter: def __init__(self, config: PluginConfig, context: Context, prompts: Any = None): self.config = config self.context = context - self.prompts = prompts # 保存 prompts + self.prompts = prompts # 保存 prompts self._logger = logger async def is_suitable_for_learning(self, message: str) -> bool: @@ -772,7 +703,7 @@ async def is_suitable_for_learning(self, message: str) -> bool: ) # 不再使用LLM进行筛选,返回默认结果 - return False # 默认认为不适合学习 + return False # 默认认为不适合学习 except 
Exception as e: self._logger.error(f"LLM 筛选消息失败: {e}", exc_info=True) return False # LLM 调用失败,认为不适合 @@ -815,7 +746,7 @@ async def _learning_loop(self): break except Exception as e: self._logger.error(f"学习循环异常: {e}", exc_info=True) - await asyncio.sleep(60) # 错误后等待1分钟再重试 + await asyncio.sleep(60) # 错误后等待1分钟再重试 class ComponentFactory: @@ -844,25 +775,20 @@ def create_learning_scheduler(self, plugin_instance): def create_persona_updater(self, context: Context, backup_manager): """创建人格更新器""" - from ..services.persona_updater import PersonaUpdater as ActualPersonaUpdater # 导入实际的 PersonaUpdater + from ..services.persona import PersonaUpdater as ActualPersonaUpdater # 导入实际的 PersonaUpdater prompts = self.service_factory.get_prompts() # 获取 prompts return ActualPersonaUpdater(self.config, context, backup_manager, None, prompts) + @cached_service("data_analytics") def create_data_analytics_service(self): """创建数据分析与可视化服务""" - cache_key = "data_analytics" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.data_analytics import DataAnalyticsService - + from ..services.analysis import DataAnalyticsService + service = DataAnalyticsService( self.config, self.service_factory.create_database_manager() ) - self._service_cache[cache_key] = service self._registry.register_service("data_analytics", service) self._logger.info("创建数据分析服务成功") @@ -872,23 +798,18 @@ def create_data_analytics_service(self): self._logger.error(f"导入数据分析服务失败: {e}", exc_info=True) raise ServiceError(f"创建数据分析服务失败: {str(e)}") + @cached_service("advanced_learning") def create_advanced_learning_service(self): """创建高级学习机制服务""" - cache_key = "advanced_learning" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.advanced_learning import AdvancedLearningService - + from ..services.core_learning import AdvancedLearningService + service = AdvancedLearningService( self.config, 
database_manager=self.service_factory.create_database_manager(), persona_manager=self.service_factory.create_persona_manager(), - llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 + llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 ) - self._service_cache[cache_key] = service self._registry.register_service("advanced_learning", service) self._logger.info("创建高级学习服务成功") @@ -898,22 +819,17 @@ def create_advanced_learning_service(self): self._logger.error(f"导入高级学习服务失败: {e}", exc_info=True) raise ServiceError(f"创建高级学习服务失败: {str(e)}") + @cached_service("enhanced_interaction") def create_enhanced_interaction_service(self): """创建增强交互服务""" - cache_key = "enhanced_interaction" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.enhanced_interaction import EnhancedInteractionService - + from ..services.state import EnhancedInteractionService + service = EnhancedInteractionService( self.config, database_manager=self.service_factory.create_database_manager(), - llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 + llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 ) - self._service_cache[cache_key] = service self._registry.register_service("enhanced_interaction", service) self._logger.info("创建增强交互服务成功") @@ -923,23 +839,18 @@ def create_enhanced_interaction_service(self): self._logger.error(f"导入增强交互服务失败: {e}", exc_info=True) raise ServiceError(f"创建增强交互服务失败: {str(e)}") + @cached_service("intelligence_enhancement") def create_intelligence_enhancement_service(self): """创建智能化提升服务""" - cache_key = "intelligence_enhancement" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.intelligence_enhancement import IntelligenceEnhancementService - + from ..services.analysis import IntelligenceEnhancementService + service = IntelligenceEnhancementService( self.config, 
database_manager=self.service_factory.create_database_manager(), persona_manager=self.service_factory.create_persona_manager(), - llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 + llm_adapter=self.service_factory.create_framework_llm_adapter() # 使用框架适配器 ) - self._service_cache[cache_key] = service self._registry.register_service("intelligence_enhancement", service) self._logger.info("创建智能化提升服务成功") @@ -949,16 +860,12 @@ def create_intelligence_enhancement_service(self): self._logger.error(f"导入智能化提升服务失败: {e}", exc_info=True) raise ServiceError(f"创建智能化提升服务失败: {str(e)}") + @cached_service("affection_manager") def create_affection_manager_service(self): """创建好感度管理服务 - 根据配置选择实现""" - cache_key = "affection_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: # 使用管理器工厂创建好感度管理器(根据配置选择实现) - from ..services.manager_factory import get_manager_factory + from ..services.database import get_manager_factory # 获取或创建管理器工厂 manager_factory = get_manager_factory(self.config) @@ -969,7 +876,6 @@ def create_affection_manager_service(self): llm_adapter=self.service_factory.create_framework_llm_adapter() ) - self._service_cache[cache_key] = service self._registry.register_service("affection_manager", service) # 记录使用的实现类型 @@ -981,15 +887,11 @@ def create_affection_manager_service(self): self._logger.error(f"导入好感度管理服务失败: {e}", exc_info=True) raise ServiceError(f"创建好感度管理服务失败: {str(e)}") + @cached_service("expression_pattern_learner") def create_expression_pattern_learner(self): """创建表达模式学习器""" - cache_key = "expression_pattern_learner" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.expression_pattern_learner import ExpressionPatternLearner + from ..services.analysis import ExpressionPatternLearner # 使用单例模式获取实例 service = ExpressionPatternLearner.get_instance( @@ -999,7 +901,6 @@ def create_expression_pattern_learner(self): 
llm_adapter=self.service_factory.create_framework_llm_adapter() ) - self._service_cache[cache_key] = service self._registry.register_service("expression_pattern_learner", service) self._logger.info("创建表达模式学习器成功") @@ -1009,16 +910,12 @@ def create_expression_pattern_learner(self): self._logger.error(f"导入表达模式学习器失败: {e}", exc_info=True) raise ServiceError(f"创建表达模式学习器失败: {str(e)}") + @cached_service("social_context_injector") def create_social_context_injector(self): """创建社交上下文注入器(整合了心理状态和行为指导功能)""" - cache_key = "social_context_injector" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.social_context_injector import SocialContextInjector - from ..services.manager_factory import ManagerFactory + from ..services.social import SocialContextInjector + from ..services.database import ManagerFactory db_manager = self.service_factory.create_database_manager() llm_adapter = self.service_factory.create_framework_llm_adapter() @@ -1038,33 +935,32 @@ def create_social_context_injector(self): try: # 创建心理状态管理器 psychological_state_manager = manager_factory.create_psychological_manager( - database_manager=db_manager, # ✅ 使用正确的参数名 database_manager + database_manager=db_manager, # 使用正确的参数名 database_manager llm_adapter=llm_adapter, - affection_manager=None # 避免循环依赖 + affection_manager=None # 避免循环依赖 ) # 创建社交关系管理器 social_relation_manager = manager_factory.create_social_relation_manager( - database_manager=db_manager, # ✅ 使用正确的参数名 database_manager + database_manager=db_manager, # 使用正确的参数名 database_manager llm_adapter=llm_adapter ) - self._logger.info("✅ 成功创建心理状态和社交关系管理器(整合到SocialContextInjector)") + self._logger.info(" 成功创建心理状态和社交关系管理器(整合到SocialContextInjector)") except Exception as e: self._logger.warning(f"创建心理状态/社交关系管理器失败: {e},将使用基础功能") service = SocialContextInjector( database_manager=db_manager, affection_manager=affection_manager, - mood_manager=affection_manager, # AffectionManager同时也管理情绪 - config=self.config, # ✅ 
传递config以读取expression_patterns_hours配置 - psychological_state_manager=psychological_state_manager, # 新增:心理状态管理器 - social_relation_manager=social_relation_manager, # 新增:社交关系管理器(但使用原有实现) - llm_adapter=llm_adapter, # 新增:LLM适配器 - goal_manager=goal_manager # 新增:对话目标管理器 + mood_manager=affection_manager, # AffectionManager同时也管理情绪 + config=self.config, # 传递config以读取expression_patterns_hours配置 + psychological_state_manager=psychological_state_manager, # 新增:心理状态管理器 + social_relation_manager=social_relation_manager, # 新增:社交关系管理器(但使用原有实现) + llm_adapter=llm_adapter, # 新增:LLM适配器 + goal_manager=goal_manager # 新增:对话目标管理器 ) - self._service_cache[cache_key] = service self._registry.register_service("social_context_injector", service) if goal_manager: @@ -1077,15 +973,11 @@ def create_social_context_injector(self): self._logger.error(f"导入社交上下文注入器失败: {e}", exc_info=True) raise ServiceError(f"创建社交上下文注入器失败: {str(e)}") + @cached_service("conversation_goal_manager") def create_conversation_goal_manager(self): """创建对话目标管理器""" - cache_key = "conversation_goal_manager" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - try: - from ..services.conversation_goal_manager import ConversationGoalManager + from ..services.quality import ConversationGoalManager service = ConversationGoalManager( database_manager=self.service_factory.create_database_manager(), @@ -1093,7 +985,6 @@ def create_conversation_goal_manager(self): config=self.config ) - self._service_cache[cache_key] = service self._registry.register_service("conversation_goal_manager", service) self._logger.info("创建对话目标管理器成功") @@ -1103,16 +994,12 @@ def create_conversation_goal_manager(self): self._logger.error(f"导入对话目标管理器失败: {e}", exc_info=True) raise ServiceError(f"创建对话目标管理器失败: {str(e)}") + @cached_service("intelligent_chat_service") def create_intelligent_chat_service(self): """创建智能对话服务""" - cache_key = "intelligent_chat_service" - - if cache_key in self._service_cache: - return 
self._service_cache[cache_key] - try: - from ..services.intelligent_chat_service import IntelligentChatService - from ..services.manager_factory import ManagerFactory + from ..services.response import IntelligentChatService + from ..services.database import ManagerFactory # 创建必要的依赖 db_manager = self.service_factory.create_database_manager() @@ -1134,7 +1021,7 @@ def create_intelligent_chat_service(self): llm_adapter=llm_adapter, affection_manager=None ) - self._logger.info("✅ 为智能对话服务创建心理状态管理器成功") + self._logger.info(" 为智能对话服务创建心理状态管理器成功") except Exception as e: self._logger.warning(f"创建心理状态管理器失败: {e},智能对话服务将使用基础功能") @@ -1146,7 +1033,6 @@ def create_intelligent_chat_service(self): config=self.config ) - self._service_cache[cache_key] = service self._registry.register_service("intelligent_chat_service", service) self._logger.info("创建智能对话服务成功") @@ -1156,72 +1042,6 @@ def create_intelligent_chat_service(self): self._logger.error(f"导入智能对话服务失败: {e}", exc_info=True) raise ServiceError(f"创建智能对话服务失败: {str(e)}") - def create_psychological_social_context_injector(self): - """ - 创建心理社交上下文注入器 - - 该注入器整合了心理状态、社交关系、好感度等多维度信息, - 并使用LLM动态生成行为指导prompt - """ - cache_key = "psychological_social_context_injector" - - if cache_key in self._service_cache: - return self._service_cache[cache_key] - - try: - from ..services.psychological_social_context_injector import PsychologicalSocialContextInjector - from ..services.manager_factory import ManagerFactory - - # 获取必要的依赖 - db_manager = self.service_factory.create_database_manager() - llm_adapter = self.service_factory.create_framework_llm_adapter() - - # 使用ManagerFactory创建心理状态和社交关系管理器 - manager_factory = ManagerFactory(self.config) - - # 创建心理状态管理器(传递affection_manager=None避免循环依赖) - psychological_state_manager = manager_factory.create_psychological_manager( - database_manager=db_manager, # ✅ 使用正确的参数名 database_manager - llm_adapter=llm_adapter, - affection_manager=None - ) - - # 创建社交关系管理器 - social_relation_manager = 
manager_factory.create_social_relation_manager( - database_manager=db_manager, # ✅ 使用正确的参数名 database_manager - llm_adapter=llm_adapter - ) - - # 获取好感度管理器(如果已创建) - affection_manager = self._service_cache.get("affection_manager") - - # 获取响应多样性管理器(如果已创建) - diversity_manager = self._service_cache.get("response_diversity_manager") - - # 创建注入器实例 - service = PsychologicalSocialContextInjector( - database_manager=db_manager, - psychological_state_manager=psychological_state_manager, - social_relation_manager=social_relation_manager, - affection_manager=affection_manager, - diversity_manager=diversity_manager, - llm_adapter=llm_adapter, - config=self.config - ) - - # 缓存和注册 - self._service_cache[cache_key] = service - self._registry.register_service("psychological_social_context_injector", service) - - self._logger.info("✅ 创建心理社交上下文注入器成功") - return service - - except ImportError as e: - self._logger.error(f"❌ 导入心理社交上下文注入器失败: {e}", exc_info=True) - raise ServiceError(f"创建心理社交上下文注入器失败: {str(e)}") - except Exception as e: - self._logger.error(f"❌ 创建心理社交上下文注入器异常: {e}", exc_info=True) - raise ServiceError(f"创建心理社交上下文注入器失败: {str(e)}") # 全局工厂实例管理器 diff --git a/core/framework_llm_adapter.py b/core/framework_llm_adapter.py index 353a9cc..559ca34 100644 --- a/core/framework_llm_adapter.py +++ b/core/framework_llm_adapter.py @@ -18,9 +18,9 @@ def __init__(self, context): self.refine_provider: Optional[Provider] = None self.reinforce_provider: Optional[Provider] = None self.providers_configured = 0 - self._needs_lazy_init = False # 延迟初始化标记 - self._lazy_init_attempted = False # 避免重复尝试 - self._config = None # 保存配置用于延迟初始化 + self._needs_lazy_init = False # 延迟初始化标记 + self._lazy_init_attempted = False # 避免重复尝试 + self._config = None # 保存配置用于延迟初始化 # 添加调用统计 self.call_stats = { @@ -41,26 +41,26 @@ def initialize_providers(self, config): self.refine_provider = None self.reinforce_provider = None - # ✅ 添加配置调试日志 - logger.info(f"🔧 [LLM适配器] 开始初始化Provider,配置信息:") - logger.info(f" - filter_provider_id: 
{config.filter_provider_id}") - logger.info(f" - refine_provider_id: {config.refine_provider_id}") - logger.info(f" - reinforce_provider_id: {config.reinforce_provider_id}") + # 添加配置调试日志 + logger.info(f" [LLM适配器] 开始初始化Provider,配置信息:") + logger.info(f" - filter_provider_id: {config.filter_provider_id}") + logger.info(f" - refine_provider_id: {config.refine_provider_id}") + logger.info(f" - reinforce_provider_id: {config.reinforce_provider_id}") # 获取所有可用的Provider列表作为备选 available_providers = [] try: # 使用 get_all_providers() 方法获取所有 CHAT_COMPLETION 类型的 Provider all_providers = self.context.get_all_providers() - logger.info(f" - 发现 {len(all_providers)} 个 Provider") + logger.info(f" - 发现 {len(all_providers)} 个 Provider") for provider in all_providers: provider_meta = provider.meta() if provider_meta.provider_type == ProviderType.CHAT_COMPLETION: available_providers.append(provider) - logger.debug(f" ✅ Provider {provider_meta.id} 可用 (类型: {provider_meta.provider_type.value})") + logger.debug(f" Provider {provider_meta.id} 可用 (类型: {provider_meta.provider_type.value})") - logger.info(f"🔍 发现 {len(available_providers)} 个可用的 CHAT_COMPLETION 类型 Provider") + logger.info(f" 发现 {len(available_providers)} 个可用的 CHAT_COMPLETION 类型 Provider") except Exception as e: logger.warning(f"获取可用Provider列表失败: {e}") @@ -75,12 +75,12 @@ def initialize_providers(self, config): self._needs_lazy_init = True if has_configured_provider_ids: logger.warning( - "⏳ [LLM适配器] Provider 注册表尚未就绪(当前 0 个)," + " [LLM适配器] Provider 注册表尚未就绪(当前 0 个)," "跳过本次绑定并等待延迟重试。" ) else: logger.warning( - "⏳ [LLM适配器] 当前没有可用 Provider,且未配置 provider_id," + " [LLM适配器] 当前没有可用 Provider,且未配置 provider_id," "稍后将重试初始化。" ) return @@ -188,11 +188,11 @@ def initialize_providers(self, config): # 友好的配置状态提示 if self.providers_configured == 0: - logger.error("❌ 没有可用的AI模型Provider。请在AstrBot中配置至少一个CHAT_COMPLETION类型的Provider,并在插件配置中指定Provider ID。") + logger.error(" 没有可用的AI模型Provider。请在AstrBot中配置至少一个CHAT_COMPLETION类型的Provider,并在插件配置中指定Provider ID。") 
elif self.providers_configured < 3: - logger.info(f"ℹ️ 已配置 {self.providers_configured}/3 个AI模型Provider。部分高级功能可能使用简化算法。") + logger.info(f" 已配置 {self.providers_configured}/3 个AI模型Provider。部分高级功能可能使用简化算法。") else: - logger.info(f"✅ 已成功配置所有 {self.providers_configured} 个AI模型Provider!") + logger.info(f" 已成功配置所有 {self.providers_configured} 个AI模型Provider!") if self.providers_configured > 0: self._needs_lazy_init = False @@ -207,24 +207,24 @@ def initialize_providers(self, config): config_summary.append(f"强化: {self.reinforce_provider.meta().id}") if config_summary: - logger.info(f"📋 Provider配置摘要: {' | '.join(config_summary)}") + logger.info(f" Provider配置摘要: {' | '.join(config_summary)}") else: - logger.warning("⚠️ 所有Provider均未配置,插件功能将受限") + logger.warning(" 所有Provider均未配置,插件功能将受限") def _try_lazy_init(self): """尝试延迟初始化Provider(仅执行一次)""" if self._needs_lazy_init and not self._lazy_init_attempted and self._config: self._lazy_init_attempted = True - logger.info("🔄 [LLM适配器] 尝试延迟初始化Provider...") + logger.info(" [LLM适配器] 尝试延迟初始化Provider...") try: self.initialize_providers(self._config) if self.providers_configured > 0: self._needs_lazy_init = False - logger.info(f"✅ [LLM适配器] 延迟初始化成功,已配置 {self.providers_configured} 个Provider") + logger.info(f" [LLM适配器] 延迟初始化成功,已配置 {self.providers_configured} 个Provider") else: - logger.warning("⚠️ [LLM适配器] 延迟初始化仍未找到可用Provider") + logger.warning(" [LLM适配器] 延迟初始化仍未找到可用Provider") except Exception as e: - logger.warning(f"⚠️ [LLM适配器] 延迟初始化失败: {e}") + logger.warning(f" [LLM适配器] 延迟初始化失败: {e}") async def filter_chat_completion( self, diff --git a/core/interfaces.py b/core/interfaces.py index 4614683..de3d914 100644 --- a/core/interfaces.py +++ b/core/interfaces.py @@ -231,34 +231,6 @@ async def delete_data(self, key: str) -> bool: pass -class IObserver(ABC): - """观察者接口""" - - @abstractmethod - async def on_event(self, event_type: str, data: Dict[str, Any]) -> None: - """处理事件""" - pass - - -class IEventPublisher(ABC): - """事件发布器接口""" - - @abstractmethod - 
async def publish_event(self, event_type: str, data: Dict[str, Any]) -> None: - """发布事件""" - pass - - @abstractmethod - def subscribe(self, event_type: str, observer: IObserver) -> None: - """订阅事件""" - pass - - @abstractmethod - def unsubscribe(self, event_type: str, observer: IObserver) -> None: - """取消订阅""" - pass - - class IMessageRelationshipAnalyzer(ABC): """消息关系分析器接口""" @@ -480,16 +452,5 @@ class AnalysisType(Enum): QUALITY = "quality" -class EventType(Enum): - """事件类型""" - MESSAGE_COLLECTED = "message_collected" - MESSAGE_FILTERED = "message_filtered" - STYLE_ANALYZED = "style_analyzed" - PERSONA_UPDATED = "persona_updated" - LEARNING_COMPLETED = "learning_completed" - QUALITY_ISSUE_DETECTED = "quality_issue_detected" - SERVICE_STATUS_CHANGED = "service_status_changed" - - # 异常类型 (从 exceptions.py 导入,避免重复定义) from ..exceptions import SelfLearningError, ConfigurationError, DataStorageError, MessageCollectionError, StyleAnalysisError, PersonaUpdateError, ModelAccessError, LearningSchedulerError, ServiceError diff --git a/core/patterns.py b/core/patterns.py index 55f13f3..5226fa7 100644 --- a/core/patterns.py +++ b/core/patterns.py @@ -1,6 +1,5 @@ import abc -import asyncio from typing import Dict, List, Optional, Any, Type from dataclasses import dataclass, field from datetime import datetime @@ -8,8 +7,8 @@ from astrbot.api import logger # 导入 logger from .interfaces import ( - IObserver, IEventPublisher, IServiceFactory, ILearningStrategy, - IAsyncService, ServiceLifecycle, EventType, LearningStrategyType, + IServiceFactory, ILearningStrategy, + IAsyncService, ServiceLifecycle, LearningStrategyType, MessageData, AnalysisResult, IMessageCollector, IStyleAnalyzer, IQualityMonitor, IPersonaManager, ServiceError ) @@ -25,48 +24,6 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] -class EventBus(IEventPublisher, metaclass=SingletonABCMeta): - """事件总线 - 观察者模式实现""" - - def __init__(self): - self._observers: Dict[str, List[IObserver]] = {} - 
self._logger = logger - - def subscribe(self, event_type: str, observer: IObserver) -> None: - """订阅事件""" - if event_type not in self._observers: - self._observers[event_type] = [] - - if observer not in self._observers[event_type]: - self._observers[event_type].append(observer) - self._logger.debug(f"订阅事件 {event_type}: {observer.__class__.__name__}") - - def unsubscribe(self, event_type: str, observer: IObserver) -> None: - """取消订阅""" - if event_type in self._observers and observer in self._observers[event_type]: - self._observers[event_type].remove(observer) - self._logger.debug(f"取消订阅事件 {event_type}: {observer.__class__.__name__}") - - async def publish_event(self, event_type: str, data: Dict[str, Any]) -> None: - """发布事件""" - if event_type not in self._observers: - return - - self._logger.debug(f"发布事件 {event_type}, 观察者数量: {len(self._observers[event_type])}") - - # 并发通知所有观察者 - tasks = [] - for observer in self._observers[event_type]: - try: - task = asyncio.create_task(observer.on_event(event_type, data)) - tasks.append(task) - except Exception as e: - self._logger.error(f"通知观察者失败: {e}") - - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - class AsyncServiceBase(IAsyncService): """异步服务基类""" @@ -74,7 +31,6 @@ def __init__(self, name: str): self.name = name self._status = ServiceLifecycle.CREATED self._logger = logger - self._event_bus = EventBus() @property def status(self) -> ServiceLifecycle: @@ -85,17 +41,6 @@ async def _change_status(self, new_status: ServiceLifecycle): old_status = self._status self._status = new_status self._logger.info(f"服务状态变更: {old_status.value} -> {new_status.value}") - - # 发布状态变更事件 - await self._event_bus.publish_event( - EventType.SERVICE_STATUS_CHANGED.value, - { - 'service_name': self.name, - 'old_status': old_status.value, - 'new_status': new_status.value, - 'timestamp': datetime.now().isoformat() - } - ) async def start(self) -> bool: """启动服务""" @@ -438,84 +383,3 @@ def get_service_status(self) -> Dict[str, 
str]: } -class ConfigurationManager(metaclass=SingletonABCMeta): - """配置管理器 - 单例模式""" - - def __init__(self): - self._config: Dict[str, Any] = {} - self._observers: List[callable] = [] - self._logger = logger - - def update_config(self, key: str, value: Any): - """更新配置""" - old_value = self._config.get(key) - self._config[key] = value - - self._logger.info(f"配置更新: {key} = {value}") - - # 通知观察者 - for observer in self._observers: - try: - observer(key, old_value, value) - except Exception as e: - self._logger.error(f"通知配置观察者失败: {e}") - - def get_config(self, key: str, default: Any = None) -> Any: - """获取配置""" - return self._config.get(key, default) - - def add_observer(self, observer: callable): - """添加配置变更观察者""" - self._observers.append(observer) - - def remove_observer(self, observer: callable): - """移除配置变更观察者""" - if observer in self._observers: - self._observers.remove(observer) - - -class MetricsCollector(metaclass=SingletonABCMeta): - """指标收集器""" - - def __init__(self): - self._metrics: Dict[str, Any] = {} - self._logger = logger - - def record_metric(self, name: str, value: Any, tags: Dict[str, str] = None): - """记录指标""" - timestamp = datetime.now().timestamp() - - if name not in self._metrics: - self._metrics[name] = [] - - self._metrics[name].append({ - 'value': value, - 'timestamp': timestamp, - 'tags': tags or {} - }) - - # 保持最近1000条记录 - if len(self._metrics[name]) > 1000: - self._metrics[name] = self._metrics[name][-1000:] - - def get_metrics(self) -> Dict[str, Any]: - """获取所有指标""" - return self._metrics.copy() - - def get_metric_summary(self, name: str) -> Dict[str, Any]: - """获取指标摘要""" - if name not in self._metrics: - return {} - - values = [m['value'] for m in self._metrics[name] if isinstance(m['value'], (int, float))] - - if not values: - return {} - - return { - 'count': len(values), - 'min': min(values), - 'max': max(values), - 'avg': sum(values) / len(values), - 'latest': values[-1] if values else None - } diff --git a/core/plugin/__init__.py 
b/core/plugin/__init__.py new file mode 100644 index 0000000..a838bb4 --- /dev/null +++ b/core/plugin/__init__.py @@ -0,0 +1 @@ +"""Plugin orchestration layer — initialization, lifecycle, WebUI management.""" \ No newline at end of file diff --git a/core/plugin_lifecycle.py b/core/plugin_lifecycle.py new file mode 100644 index 0000000..77bb905 --- /dev/null +++ b/core/plugin_lifecycle.py @@ -0,0 +1,506 @@ +"""插件全生命周期编排 — 服务初始化 → 异步启动 → 有序关停""" +import os +import json +import asyncio +from typing import Any, Dict, TYPE_CHECKING + +from astrbot.api import logger + +from .factory import FactoryManager +from ..exceptions import SelfLearningError +from ..statics.messages import StatusMessages, LogMessages + +if TYPE_CHECKING: + pass # 避免循环导入 + + +class PluginLifecycle: + """插件全生命周期编排:初始化 → 启动 → 关停 + + 将 main.py 中的 _initialize_services / _setup_internal_components / + on_load / terminate 逻辑统一到一处。 + """ + + def __init__(self, plugin: Any): + """ + Args: + plugin: SelfLearningPlugin 实例(回引,用于设置属性) + """ + self._plugin = plugin + self._webui_manager = None # Phase 2 WebUIManager 延迟创建 + + # Phase 1: 同步初始化(__init__ 阶段调用) + + def bootstrap( + self, + plugin_config: Any, + context: Any, + group_id_to_unified_origin: Dict[str, str], + ) -> None: + """同步初始化:创建全部服务并注入到 plugin 实例上""" + p = self._plugin # 简写 + + try: + # ------ FactoryManager 初始化 ------ + p.factory_manager = FactoryManager() + p.factory_manager.initialize_factories(plugin_config, context) + p.service_factory = p.factory_manager.get_service_factory() + + # ------ ServiceFactory 创建核心服务 ------ + p.db_manager = p.service_factory.create_database_manager() + p.message_collector = p.service_factory.create_message_collector() + p.multidimensional_analyzer = p.service_factory.create_multidimensional_analyzer() + p.style_analyzer = p.service_factory.create_style_analyzer() + p.quality_monitor = p.service_factory.create_quality_monitor() + p.progressive_learning = p.service_factory.create_progressive_learning() + p.ml_analyzer 
= p.service_factory.create_ml_analyzer() + p.persona_manager = p.service_factory.create_persona_manager() + p.diversity_manager = p.service_factory.create_response_diversity_manager() + + # ------ ComponentFactory 创建高级服务 ------ + component_factory = p.factory_manager.get_component_factory() + p.data_analytics = component_factory.create_data_analytics_service() + p.advanced_learning = component_factory.create_advanced_learning_service() + p.enhanced_interaction = component_factory.create_enhanced_interaction_service() + p.intelligence_enhancement = component_factory.create_intelligence_enhancement_service() + p.affection_manager = component_factory.create_affection_manager_service() + + # ------ 条件创建:对话目标管理器 ------ + logger.info( + f"[初始化] enable_goal_driven_chat={plugin_config.enable_goal_driven_chat}" + ) + if plugin_config.enable_goal_driven_chat: + try: + p.conversation_goal_manager = ( + component_factory.create_conversation_goal_manager() + ) + logger.info("对话目标管理器已初始化") + except Exception as e: + logger.error(f"创建对话目标管理器失败: {e}", exc_info=True) + p.conversation_goal_manager = None + else: + p.conversation_goal_manager = None + logger.info("对话目标管理器未启用") + + # ------ 社交上下文注入器(必须在 intelligent_responder 之前)------ + p.social_context_injector = component_factory.create_social_context_injector() + + # ------ 黑话服务 ------ + from ..services.jargon import ( + JargonQueryService, + JargonMinerManager, + JargonStatisticalFilter, + ) + + p.jargon_query_service = JargonQueryService( + db_manager=p.db_manager, cache_ttl=60 + ) + logger.info("黑话查询服务已初始化(带60秒缓存)") + + p.jargon_miner_manager = JargonMinerManager( + llm_adapter=p.service_factory.create_framework_llm_adapter(), + db_manager=p.db_manager, + config=plugin_config, + ) + logger.info("黑话挖掘管理器已初始化") + + p.jargon_statistical_filter = JargonStatisticalFilter() + logger.info("黑话统计预筛器已初始化") + + # ------ V2 架构集成(条件创建)------ + p.v2_integration = None + logger.info( + f"[V2] Config check: 
knowledge_engine='{plugin_config.knowledge_engine}', " + f"memory_engine='{plugin_config.memory_engine}'" + ) + if ( + plugin_config.knowledge_engine != "legacy" + or plugin_config.memory_engine != "legacy" + ): + try: + from ..services.core_learning import V2LearningIntegration + + llm_adapter = p.service_factory.create_framework_llm_adapter() + p.v2_integration = V2LearningIntegration( + config=plugin_config, + llm_adapter=llm_adapter, + db_manager=p.db_manager, + context=context, + ) + logger.info( + f"V2LearningIntegration initialised " + f"(knowledge={plugin_config.knowledge_engine}, " + f"memory={plugin_config.memory_engine})" + ) + except Exception as exc: + logger.warning( + f"V2LearningIntegration init failed, v2 features disabled: {exc}" + ) + p.v2_integration = None + + # ------ 依赖后创建的服务 ------ + p.intelligent_responder = p.service_factory.create_intelligent_responder() + p.temporary_persona_updater = p.service_factory.create_temporary_persona_updater() + + # ------ group_id 映射表传递 ------ + p.temporary_persona_updater.group_id_to_unified_origin = ( + group_id_to_unified_origin + ) + if p.progressive_learning: + p.progressive_learning.group_id_to_unified_origin = ( + group_id_to_unified_origin + ) + if p.persona_manager: + p.persona_manager.group_id_to_unified_origin = ( + group_id_to_unified_origin + ) + logger.info("已将 group_id 映射表传递给服务组件") + + # ------ LLM 适配器(状态报告用)------ + p.llm_adapter = p.service_factory.create_framework_llm_adapter() + + # ------ 内部组件(QQ过滤/消息过滤/人格更新/调度器)------ + self._setup_internal_components(plugin_config, context, group_id_to_unified_origin) + + # ------ 提取的服务模块 ------ + from ..services.learning.dialog_analyzer import DialogAnalyzer + from ..services.learning.realtime_processor import RealtimeProcessor + from ..services.learning.group_orchestrator import GroupLearningOrchestrator + from ..services.hooks.llm_hook_handler import LLMHookHandler + + p._dialog_analyzer = DialogAnalyzer(p.factory_manager, p.db_manager) + 
p._realtime_processor = RealtimeProcessor( + plugin_config=plugin_config, + message_collector=p.message_collector, + multidimensional_analyzer=p.multidimensional_analyzer, + persona_manager=p.persona_manager, + temporary_persona_updater=p.temporary_persona_updater, + dialog_analyzer=p._dialog_analyzer, + learning_stats=p.learning_stats, + factory_manager=p.factory_manager, + db_manager=p.db_manager, + ) + p._group_orchestrator = GroupLearningOrchestrator( + plugin_config=plugin_config, + message_collector=p.message_collector, + progressive_learning=p.progressive_learning, + qq_filter=p.qq_filter, + db_manager=p.db_manager, + ) + p._hook_handler = LLMHookHandler( + plugin_config=plugin_config, + diversity_manager=getattr(p, "diversity_manager", None), + social_context_injector=getattr(p, "social_context_injector", None), + v2_integration=getattr(p, "v2_integration", None), + jargon_query_service=getattr(p, "jargon_query_service", None), + temporary_persona_updater=getattr(p, "temporary_persona_updater", None), + perf_tracker=p._perf_tracker, + group_id_to_unified_origin=group_id_to_unified_origin, + db_manager=getattr(p, "db_manager", None), + ) + + # ------ 消息处理流水线 ------ + from ..services.learning.message_pipeline import MessagePipeline + + p._pipeline = MessagePipeline( + plugin_config=plugin_config, + message_collector=p.message_collector, + enhanced_interaction=p.enhanced_interaction, + jargon_miner_manager=getattr(p, "jargon_miner_manager", None), + jargon_statistical_filter=getattr(p, "jargon_statistical_filter", None), + v2_integration=getattr(p, "v2_integration", None), + realtime_processor=p._realtime_processor, + group_orchestrator=p._group_orchestrator, + conversation_goal_manager=getattr(p, "conversation_goal_manager", None), + affection_manager=p.affection_manager, + db_manager=p.db_manager, + ) + + # ------ 命令处理器 ------ + from ..services.commands import PluginCommandHandlers, CommandFilter + + p._command_handlers = PluginCommandHandlers( + 
plugin_config=plugin_config, + service_factory=p.service_factory, + message_collector=p.message_collector, + persona_manager=p.persona_manager, + progressive_learning=p.progressive_learning, + affection_manager=p.affection_manager, + temporary_persona_updater=p.temporary_persona_updater, + db_manager=p.db_manager, + llm_adapter=p.llm_adapter, + ) + p._command_filter = CommandFilter() + + # ------ WebUI 管理器 ------ + from ..webui.manager import WebUIManager + + self._webui_manager = WebUIManager( + plugin_config=plugin_config, + context=context, + factory_manager=p.factory_manager, + perf_tracker=p._perf_tracker, + group_id_to_unified_origin=group_id_to_unified_origin, + ) + need_immediate_start = self._webui_manager.create_server() + if need_immediate_start: + asyncio.create_task(self._webui_manager.immediate_start(p.db_manager)) + + # ------ 自动学习启动(必须在 _group_orchestrator 创建之后)------ + if plugin_config.enable_auto_learning: + asyncio.create_task(p._group_orchestrator.delayed_auto_start_learning()) + + logger.info(StatusMessages.FACTORY_SERVICES_INIT_COMPLETE) + + except SelfLearningError as sle: + logger.error(StatusMessages.SERVICES_INIT_FAILED.format(error=sle)) + raise + except (TypeError, ValueError) as e: + logger.error( + StatusMessages.CONFIG_TYPE_ERROR.format(error=e), exc_info=True + ) + raise SelfLearningError( + StatusMessages.INIT_FAILED_GENERIC.format(error=str(e)) + ) from e + except Exception as e: + logger.error( + StatusMessages.UNKNOWN_INIT_ERROR.format(error=e), exc_info=True + ) + raise SelfLearningError( + StatusMessages.INIT_FAILED_GENERIC.format(error=str(e)) + ) from e + + def _setup_internal_components( + self, + plugin_config: Any, + context: Any, + group_id_to_unified_origin: Dict[str, str], + ) -> None: + """设置内部组件 — QQ 过滤 / 消息过滤 / 人格更新器 / 学习调度器""" + p = self._plugin + component_factory = p.factory_manager.get_component_factory() + p.component_factory = component_factory + + p.qq_filter = component_factory.create_qq_filter() + 
p.message_filter = component_factory.create_message_filter(context) + + persona_backup_manager_instance = p.service_factory.create_persona_backup_manager() + p.persona_updater = component_factory.create_persona_updater( + context, persona_backup_manager_instance + ) + + p.persona_updater.group_id_to_unified_origin = group_id_to_unified_origin + persona_backup_manager_instance.group_id_to_unified_origin = ( + group_id_to_unified_origin + ) + + p.learning_scheduler = component_factory.create_learning_scheduler(p) + p.background_tasks = set() + + asyncio.create_task(self._delayed_provider_reinitialization()) + + # Phase 2: 异步启动(on_load 阶段调用) + + async def on_load(self) -> None: + """异步启动:DB(带重试)+ 服务 + WebUI""" + p = self._plugin + plugin_config = p.plugin_config + + logger.info(StatusMessages.ON_LOAD_START) + + # ------ DB 启动(带重试)------ + db_started = False + max_retries = 3 + retry_delay = 2 + + for attempt in range(max_retries): + try: + logger.info(f"尝试启动数据库管理器 (第 {attempt + 1}/{max_retries} 次)") + db_started = await p.db_manager.start() + if db_started: + logger.info(StatusMessages.DB_MANAGER_STARTED) + break + else: + logger.warning( + f"数据库管理器启动返回 False (尝试 {attempt + 1}/{max_retries})" + ) + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + except Exception as e: + logger.error( + f"数据库启动异常 (尝试 {attempt + 1}/{max_retries}): {e}", + exc_info=True, + ) + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + + if not db_started: + logger.error( + StatusMessages.DB_MANAGER_START_FAILED.format(error="所有重试均失败") + ) + logger.warning("插件将在数据库功能受限的情况下继续运行") + + # ------ 好感度管理服务 ------ + if plugin_config.enable_affection_system: + try: + await p.affection_manager.start() + logger.info("好感度管理服务启动成功") + except Exception as e: + logger.error(f"好感度管理服务启动失败: {e}", exc_info=True) + + # ------ V2 学习集成 ------ + if getattr(p, "v2_integration", None): + try: + await p.v2_integration.start() + logger.info("V2LearningIntegration started successfully") + 
except Exception as e: + logger.error(f"V2LearningIntegration start failed: {e}", exc_info=True) + + # ------ WebUI ------ + if self._webui_manager: + await self._webui_manager.setup_and_start() + + logger.info(StatusMessages.PLUGIN_LOAD_COMPLETE) + + # Phase 3: 有序关停(terminate 阶段调用) + + async def shutdown(self) -> None: + """有序关停所有服务""" + p = self._plugin + try: + logger.info("开始插件清理工作...") + + # 1. 停止学习任务 + logger.info("停止所有学习任务...") + await p._group_orchestrator.cancel_all() + + # 2. 停止学习调度器 + if hasattr(p, "learning_scheduler"): + try: + await p.learning_scheduler.stop() + logger.info("学习调度器已停止") + except Exception as e: + logger.error(f"停止学习调度器失败: {e}") + + # 3. 取消后台任务 + logger.info("取消所有后台任务...") + for task in list(p.background_tasks): + try: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.error( + LogMessages.BACKGROUND_TASK_CANCEL_ERROR.format(error=e) + ) + p.background_tasks.clear() + + # 4. 停止服务工厂 + logger.info("停止所有服务...") + if hasattr(p, "factory_manager"): + try: + await p.factory_manager.cleanup() + logger.info("服务工厂已清理") + except Exception as e: + logger.error(f"清理服务工厂失败: {e}") + + # 4.5 停止 V2 + if getattr(p, "v2_integration", None): + try: + await p.v2_integration.stop() + logger.info("V2LearningIntegration stopped") + except Exception as e: + logger.error(f"V2LearningIntegration stop failed: {e}") + + # 4.6 重置单例 + try: + from ..services.state import EnhancedMemoryGraphManager + + EnhancedMemoryGraphManager._instance = None + EnhancedMemoryGraphManager._initialized = False + logger.info("MemoryGraphManager 单例已重置") + except Exception: + pass + + # 5. 清理临时人格 + if hasattr(p, "temporary_persona_updater"): + try: + await p.temporary_persona_updater.cleanup_temp_personas() + logger.info("临时人格已清理") + except Exception as e: + logger.error(f"清理临时人格失败: {e}") + + # 6. 
保存状态 + if hasattr(p, "message_collector"): + try: + await p.message_collector.save_state() + logger.info("消息收集器状态已保存") + except Exception as e: + logger.error(f"保存消息收集器状态失败: {e}") + + # 7. 停止 WebUI + if self._webui_manager: + await self._webui_manager.stop() + + # 8. 保存配置 + try: + config_path = os.path.join(p.plugin_config.data_dir, "config.json") + with open(config_path, "w", encoding="utf-8") as f: + json.dump(p.plugin_config.to_dict(), f, ensure_ascii=False, indent=2) + logger.info(LogMessages.PLUGIN_CONFIG_SAVED) + except Exception as e: + logger.error(f"保存配置失败: {e}") + + logger.info(LogMessages.PLUGIN_UNLOAD_SUCCESS) + + except Exception as e: + logger.error( + LogMessages.PLUGIN_UNLOAD_CLEANUP_FAILED.format(error=e), + exc_info=True, + ) + + # 辅助异步方法 + + async def _delayed_provider_reinitialization(self) -> None: + """延迟重新初始化提供商配置,解决重启后配置丢失问题""" + p = self._plugin + try: + await asyncio.sleep(10) + + if getattr(p, "llm_adapter", None): + p.llm_adapter.initialize_providers(p.plugin_config) + logger.info("延迟重新初始化提供商配置完成") + + if p.llm_adapter.providers_configured == 0: + logger.warning("重新初始化后仍然没有配置任何提供商,请检查配置") + await asyncio.sleep(30) + p.llm_adapter.initialize_providers(p.plugin_config) + logger.info("第二次尝试重新初始化提供商配置") + else: + logger.info( + f"成功配置了 {p.llm_adapter.providers_configured} 个提供商" + ) + except Exception as e: + logger.error(f"延迟重新初始化提供商配置失败: {e}") + + async def _delayed_start_learning(self, group_id: str) -> None: + """延迟启动学习服务""" + p = self._plugin + try: + await asyncio.sleep(3) + await p.service_factory.initialize_all_services() + await p.progressive_learning.start_learning(group_id) + logger.info( + StatusMessages.AUTO_LEARNING_SCHEDULER_STARTED.format( + group_id=group_id + ) + ) + except Exception as e: + logger.error( + StatusMessages.LEARNING_SERVICE_START_FAILED.format( + group_id=group_id, error=e + ) + ) diff --git a/main.py b/main.py index 356372f..69ba0d9 100644 --- a/main.py +++ b/main.py @@ -2,31 +2,23 @@ AstrBot 自学习插件 - 
智能对话风格学习与人格优化 """ import os -import json # 导入 json 模块 import asyncio -import time -import re # 导入正则表达式模块 -from datetime import datetime -from typing import List, Dict, Optional, Any +from typing import Dict, Optional from dataclasses import dataclass from astrbot.api.event import AstrMessageEvent from astrbot.api.event import filter from astrbot.api.event.filter import PermissionType import astrbot.api.star as star -from astrbot.api.star import register, Context +from astrbot.api.star import Context from astrbot.api import logger, AstrBotConfig from astrbot.core.utils.astrbot_path import get_astrbot_data_path from .config import PluginConfig -from .core.factory import FactoryManager -from .core.interfaces import MessageData -from .exceptions import SelfLearningError -from .webui import Server, set_plugin_services # 导入 FastAPI 服务器相关 -from .statics.messages import StatusMessages, CommandMessages, LogMessages, FileNames, DefaultValues +from .core.plugin_lifecycle import PluginLifecycle +from .services.hooks.perf_tracker import PerfTracker +from .statics.messages import StatusMessages, FileNames -server_instance: Optional[Server] = None # 全局服务器实例 -_server_cleanup_lock = asyncio.Lock() # 服务器清理锁,防止并发清理 @dataclass class LearningStats: @@ -38,6 +30,7 @@ class LearningStats: last_learning_time: Optional[str] = None last_persona_update: Optional[str] = None + class SelfLearningPlugin(star.Star): """AstrBot 自学习插件 - 智能学习用户对话风格并优化人格设置""" @@ -45,2260 +38,170 @@ def __init__(self, context: Context, config: AstrBotConfig = None) -> None: super().__init__(context) self.context = context self.config = config or {} - - # 初始化插件配置 - # 获取插件数据目录,并传递给 PluginConfig + + # ------ 插件配置加载 ------ try: astrbot_data_path = get_astrbot_data_path() if astrbot_data_path is None: - # 回退到当前目录下的 data 目录 astrbot_data_path = os.path.join(os.path.dirname(__file__), "data") logger.warning("无法获取 AstrBot 数据路径,使用插件目录下的 data 目录") - # 检查用户是否在配置中自定义了数据存储路径 - # 从 Storage_Settings.data_dir 读取配置 storage_settings = 
self.config.get('Storage_Settings', {}) if self.config else {} user_data_dir = storage_settings.get('data_dir') if user_data_dir: - # 用户自定义了数据路径,使用用户指定的路径 logger.info(f"使用用户自定义数据路径 (从Storage_Settings.data_dir): {user_data_dir}") plugin_data_dir = user_data_dir - # 确保路径是绝对路径 if not os.path.isabs(plugin_data_dir): plugin_data_dir = os.path.abspath(plugin_data_dir) else: - # 使用 plugin_data 目录而不是 plugins 目录,这样数据不会在插件卸载时被删除 - # 根据 AstrBot 框架规范,插件持久化数据应存储在 data/plugin_data/{plugin_name}/ - plugin_data_dir = os.path.join(astrbot_data_path, "plugin_data", "astrbot_plugin_self_learning") + plugin_data_dir = os.path.join( + astrbot_data_path, "plugin_data", "astrbot_plugin_self_learning" + ) logger.info(f"使用默认数据路径: {plugin_data_dir}") logger.info(f"最终插件数据目录: {plugin_data_dir}") self.plugin_config = PluginConfig.create_from_config(self.config, data_dir=plugin_data_dir) - # ✅ 添加Provider配置加载日志 - logger.info(f"🔧 [插件初始化] Provider配置已加载:") - logger.info(f" - filter_provider_id: {self.plugin_config.filter_provider_id}") - logger.info(f" - refine_provider_id: {self.plugin_config.refine_provider_id}") - logger.info(f" - reinforce_provider_id: {self.plugin_config.reinforce_provider_id}") + logger.info(f"[插件初始化] Provider配置已加载:") + logger.info(f" - filter_provider_id: {self.plugin_config.filter_provider_id}") + logger.info(f" - refine_provider_id: {self.plugin_config.refine_provider_id}") + logger.info(f" - reinforce_provider_id: {self.plugin_config.reinforce_provider_id}") except Exception as e: logger.error(f"初始化插件配置失败: {e}") - # 使用最保险的默认配置 default_data_dir = os.path.join(os.path.dirname(__file__), "data") logger.warning(f"使用默认数据目录: {default_data_dir}") self.plugin_config = PluginConfig.create_from_config(self.config, data_dir=default_data_dir) - - # 确保数据目录存在 + os.makedirs(self.plugin_config.data_dir, exist_ok=True) - - # 初始化 messages_db_path 和 learning_log_path + if not self.plugin_config.messages_db_path: - self.plugin_config.messages_db_path = 
os.path.join(self.plugin_config.data_dir, FileNames.MESSAGES_DB_FILE) + self.plugin_config.messages_db_path = os.path.join( + self.plugin_config.data_dir, FileNames.MESSAGES_DB_FILE + ) if not self.plugin_config.learning_log_path: - self.plugin_config.learning_log_path = os.path.join(self.plugin_config.data_dir, FileNames.LEARNING_LOG_FILE) + self.plugin_config.learning_log_path = os.path.join( + self.plugin_config.data_dir, FileNames.LEARNING_LOG_FILE + ) - # 学习统计 + # ------ 运行时状态 ------ self.learning_stats = LearningStats() - - # 消息去重缓存 - 防止合并消息插件导致的重复处理 - self.message_dedup_cache = {} + self.message_dedup_cache: dict = {} self.max_cache_size = 1000 - - # ✅ group_id到unified_msg_origin的映射表 - 用于会话隔离 - # key: group_id, value: unified_msg_origin self.group_id_to_unified_origin: Dict[str, str] = {} - - # 设置增量更新回调 - 在服务初始化前设置,避免AttributeError self.update_system_prompt_callback = None + self._perf_tracker = PerfTracker(maxlen=200) - # 初始化服务层 - self._initialize_services() - - # 初始化 Web 服务器(但不启动,等待 on_load) - global server_instance - if self.plugin_config.enable_web_interface: - logger.info(f"Debug: 准备创建Server实例,端口: {self.plugin_config.web_interface_port}") - try: - # 检查是否已经有服务器实例在运行(处理插件重载场景) - if server_instance is not None: - logger.warning("检测到已存在的Web服务器实例,可能是插件重载") - # 检查服务器是否仍在运行 - if server_instance.server_task and not server_instance.server_task.done(): - logger.warning("旧的Web服务器仍在运行,将复用该实例") - logger.info(f"Web服务器地址: http://{server_instance.host}:{server_instance.port}") - else: - logger.info("旧的Web服务器已停止,创建新实例") - server_instance = None # 清除旧实例引用 + # ------ 委托生命周期编排 ------ + self._lifecycle = PluginLifecycle(self) + self._lifecycle.bootstrap( + self.plugin_config, self.context, self.group_id_to_unified_origin + ) - # 只有在没有运行中的服务器时才创建新实例 - if server_instance is None: - server_instance = Server(port=self.plugin_config.web_interface_port) - if server_instance: - logger.info(StatusMessages.WEB_INTERFACE_ENABLED.format(host=server_instance.host, 
port=server_instance.port)) - logger.info("Web服务器实例已创建,将在on_load中启动") - - # 立即尝试启动Web服务器而不等待on_load - logger.info("Debug: 尝试立即启动Web服务器") - asyncio.create_task(self._immediate_start_web_server()) - else: - logger.error(StatusMessages.WEB_INTERFACE_INIT_FAILED) - except Exception as e: - logger.error(f"创建Web服务器实例失败: {e}", exc_info=True) - else: - logger.info(StatusMessages.WEB_INTERFACE_DISABLED) - logger.info(StatusMessages.PLUGIN_INITIALIZED) - async def _immediate_start_web_server(self): - """立即启动Web服务器,不等待on_load""" - logger.info("Debug: _immediate_start_web_server 被调用") - - # 等待一小段时间让插件完全初始化 - await asyncio.sleep(1) - - global server_instance - if server_instance and self.plugin_config.enable_web_interface: - logger.info("Debug: 开始立即设置并启动Web服务器") - - # 启动数据库管理器 - try: - logger.info("Debug: 启动数据库管理器") - db_started = await self.db_manager.start() - if db_started: - logger.info("Debug: 数据库管理器启动成功") - else: - logger.error("❌ 数据库管理器启动失败,但没有抛出异常") - raise RuntimeError("数据库管理器启动失败") - except Exception as e: - logger.error(f"启动数据库管理器失败: {e}", exc_info=True) - raise # 重新抛出异常,停止插件启动 - - # 设置插件服务 - try: - logger.info("Debug: 开始设置插件服务") - - # 尝试获取AstrBot框架的PersonaManager - astrbot_persona_manager = None - try: - # 通过context的persona_manager属性获取框架的PersonaManager - if hasattr(self.context, 'persona_manager'): - astrbot_persona_manager = self.context.persona_manager - if astrbot_persona_manager: - logger.info(f"立即启动: 成功获取AstrBot框架PersonaManager: {type(astrbot_persona_manager)}") - # 检查PersonaManager是否已初始化 - if hasattr(astrbot_persona_manager, 'personas'): - logger.info(f"立即启动: PersonaManager已有personas属性,人格数量: {len(getattr(astrbot_persona_manager, 'personas', []))}") - else: - logger.info("立即启动: PersonaManager还没有personas属性,可能需要初始化") - else: - logger.warning("立即启动: Context中persona_manager为None") - else: - logger.warning("立即启动: Context中没有persona_manager属性") - - # 额外尝试:如果persona_manager为None,尝试延迟获取 - if not astrbot_persona_manager: - logger.info("立即启动: 尝试延迟获取PersonaManager...") - 
await asyncio.sleep(3) # 等待3秒,给AstrBot更多初始化时间 - if hasattr(self.context, 'persona_manager') and self.context.persona_manager: - astrbot_persona_manager = self.context.persona_manager - logger.info(f"立即启动: 延迟获取成功: {type(astrbot_persona_manager)}") - else: - logger.warning("立即启动: 延迟获取PersonaManager仍然失败,可能AstrBot还在初始化中") - - except Exception as pe: - logger.error(f"立即启动: 获取AstrBot框架PersonaManager失败: {pe}", exc_info=True) - - await set_plugin_services( - self.plugin_config, - self.factory_manager, - None, - astrbot_persona_manager, - self.group_id_to_unified_origin - ) - logger.info("Debug: 插件服务设置完成") - except Exception as e: - logger.error(f"设置插件服务失败: {e}", exc_info=True) - return - - # 启动Web服务器 - try: - logger.info("Debug: 调用 server_instance.start()") - await server_instance.start() - logger.info("🌐 Web服务器已成功启动!") - except Exception as e: - logger.error(f"Web服务器启动失败: {e}", exc_info=True) - logger.error("提示: 端口可能仍被占用。AstrBot将尝试继续运行,但WebUI不可用。") - # 将实例置空,防止后续错误调用 - server_instance = None - else: - logger.error("Debug: server_instance 为空或 web_interface 未启用") - - async def _start_web_server(self): - """启动Web服务器的异步方法""" - global server_instance - if server_instance: - logger.info(StatusMessages.WEB_SERVER_STARTING) - try: - await server_instance.start() - logger.info(StatusMessages.WEB_SERVER_STARTED) - - # 启动数据库管理器 - db_started = await self.db_manager.start() - if db_started: - logger.info(StatusMessages.DB_MANAGER_STARTED) - else: - logger.error("❌ 数据库管理器启动失败,但没有抛出异常") - raise RuntimeError("数据库管理器启动失败") - except Exception as e: - logger.error(StatusMessages.WEB_SERVER_START_FAILED.format(error=e), exc_info=True) - - def _initialize_services(self): - """初始化所有服务层组件 - 使用工厂模式""" - try: - # 初始化工厂管理器 - self.factory_manager = FactoryManager() - self.factory_manager.initialize_factories(self.plugin_config, self.context) - - # 获取服务工厂 - self.service_factory = self.factory_manager.get_service_factory() - - # 使用工厂创建核心服务 - self.db_manager = 
self.service_factory.create_database_manager() - self.message_collector = self.service_factory.create_message_collector() - self.multidimensional_analyzer = self.service_factory.create_multidimensional_analyzer() - self.style_analyzer = self.service_factory.create_style_analyzer() - self.quality_monitor = self.service_factory.create_quality_monitor() - self.progressive_learning = self.service_factory.create_progressive_learning() - self.ml_analyzer = self.service_factory.create_ml_analyzer() - self.persona_manager = self.service_factory.create_persona_manager() - - # ✅ 创建响应多样性管理器 - 用于防止LLM回复同质化 - self.diversity_manager = self.service_factory.create_response_diversity_manager() - - # 获取组件工厂并创建新的高级服务 - component_factory = self.factory_manager.get_component_factory() - self.data_analytics = component_factory.create_data_analytics_service() - self.advanced_learning = component_factory.create_advanced_learning_service() - self.enhanced_interaction = component_factory.create_enhanced_interaction_service() - self.intelligence_enhancement = component_factory.create_intelligence_enhancement_service() - self.affection_manager = component_factory.create_affection_manager_service() - - # ✅ 创建对话目标管理器 - 用于智能对话目标检测和管理 - # 必须在social_context_injector之前创建,这样才能被注入器引用 - logger.info(f"🔍 [初始化] 检查enable_goal_driven_chat配置: {self.plugin_config.enable_goal_driven_chat}") - if self.plugin_config.enable_goal_driven_chat: - try: - self.conversation_goal_manager = component_factory.create_conversation_goal_manager() - logger.info("✅ 对话目标管理器已初始化(目标驱动对话系统已启用)") - except Exception as e: - logger.error(f"❌ 创建对话目标管理器失败: {e}", exc_info=True) - self.conversation_goal_manager = None - else: - self.conversation_goal_manager = None - logger.info("⚠️ 对话目标管理器未启用(配置中 enable_goal_driven_chat=False)") - - # ✅ 创建社交上下文注入器(已整合心理状态、行为指导功能) - # 包含:表达模式学习、深度心理状态、社交关系、好感度、行为指导 - # 必须在intelligent_responder之前创建,这样才能被正确注入 - self.social_context_injector = component_factory.create_social_context_injector() - - # ✅ 
创建黑话查询服务 - 用于在LLM请求时注入黑话理解 - from .services.jargon_query import JargonQueryService - self.jargon_query_service = JargonQueryService( - db_manager=self.db_manager, - cache_ttl=60 # 60秒缓存TTL - ) - logger.info("黑话查询服务已初始化(带60秒缓存)") - - # ✅ 创建黑话挖掘管理器 - 用于后台学习黑话 - from .services.jargon_miner import JargonMinerManager - self.jargon_miner_manager = JargonMinerManager( - llm_adapter=self.service_factory.create_framework_llm_adapter(), - db_manager=self.db_manager, - config=self.plugin_config - ) - logger.info("黑话挖掘管理器已初始化") - - # 在affection_manager和social_context_injector创建后再创建智能回复器 - self.intelligent_responder = self.service_factory.create_intelligent_responder() # 重新启用智能回复器 - - # 创建临时人格更新器 - self.temporary_persona_updater = self.service_factory.create_temporary_persona_updater() - - # ✅ 传递group_id到unified_origin映射表的引用 - if hasattr(self, 'group_id_to_unified_origin'): - self.temporary_persona_updater.group_id_to_unified_origin = self.group_id_to_unified_origin - if hasattr(self, 'progressive_learning') and self.progressive_learning: - self.progressive_learning.group_id_to_unified_origin = self.group_id_to_unified_origin - if hasattr(self, 'persona_manager') and self.persona_manager: - self.persona_manager.group_id_to_unified_origin = self.group_id_to_unified_origin - logger.info("已将group_id映射表传递给服务组件") - - # 创建并保存LLM适配器实例,用于状态报告 - self.llm_adapter = self.service_factory.create_framework_llm_adapter() - - # 初始化内部组件 - self._setup_internal_components() - - logger.info(StatusMessages.FACTORY_SERVICES_INIT_COMPLETE) - - except SelfLearningError as sle: - logger.error(StatusMessages.SERVICES_INIT_FAILED.format(error=sle)) - raise # Re-raise as this is an expected initialization failure - except (TypeError, ValueError) as e: # Catch common initialization errors - logger.error(StatusMessages.CONFIG_TYPE_ERROR.format(error=e), exc_info=True) - raise SelfLearningError(StatusMessages.INIT_FAILED_GENERIC.format(error=str(e))) from e - except Exception as e: # Catch any other 
unexpected errors - logger.error(StatusMessages.UNKNOWN_INIT_ERROR.format(error=e), exc_info=True) - raise SelfLearningError(StatusMessages.INIT_FAILED_GENERIC.format(error=str(e))) from e - - def _setup_internal_components(self): - """设置内部组件 - 使用工厂模式""" - # 获取组件工厂 - self.component_factory = self.factory_manager.get_component_factory() - - # QQ号过滤器 - self.qq_filter = self.component_factory.create_qq_filter() - - # 消息过滤器 - self.message_filter = self.component_factory.create_message_filter(self.context) - - # 人格更新器 - # PersonaUpdater 的创建现在需要 backup_manager,它是一个服务,也应该通过 ServiceFactory 获取 - persona_backup_manager_instance = self.service_factory.create_persona_backup_manager() - self.persona_updater = self.component_factory.create_persona_updater(self.context, persona_backup_manager_instance) - - # ✅ 传递group_id到unified_origin映射表(多配置文件支持) - if hasattr(self, 'group_id_to_unified_origin'): - self.persona_updater.group_id_to_unified_origin = self.group_id_to_unified_origin - persona_backup_manager_instance.group_id_to_unified_origin = self.group_id_to_unified_origin - - # 学习调度器 - self.learning_scheduler = self.component_factory.create_learning_scheduler(self) - - # 异步任务管理 - 增强后台任务管理 - self.background_tasks = set() - self.learning_tasks = {} # 按group_id管理学习任务 - - # 启动自动学习(如果启用) - if self.plugin_config.enable_auto_learning: - # 延迟启动,避免在初始化时启动大量任务 - asyncio.create_task(self._delayed_auto_start_learning()) - - # 添加延迟重新初始化提供商配置,解决重启后配置问题 - asyncio.create_task(self._delayed_provider_reinitialization()) + # 生命周期 async def on_load(self): - """插件加载时启动 Web 服务器和数据库管理器""" - global server_instance - logger.info(StatusMessages.ON_LOAD_START) - logger.info(f"Debug: enable_web_interface = {self.plugin_config.enable_web_interface}") - logger.info(f"Debug: server_instance = {server_instance}") - logger.info(f"Debug: web_interface_port = {self.plugin_config.web_interface_port}") - - # 启动数据库管理器,确保数据库表被创建 - db_started = False - max_retries = 3 - retry_delay = 2 # 秒 - - for attempt in 
range(max_retries): - try: - logger.info(f"尝试启动数据库管理器 (第 {attempt + 1}/{max_retries} 次)") - db_started = await self.db_manager.start() + """插件加载时启动 DB / 服务 / WebUI""" + await self._lifecycle.on_load() - if db_started: - logger.info(StatusMessages.DB_MANAGER_STARTED) - break - else: - logger.warning(f"数据库管理器启动返回False (尝试 {attempt + 1}/{max_retries})") - if attempt < max_retries - 1: - logger.info(f"等待 {retry_delay} 秒后重试...") - await asyncio.sleep(retry_delay) - - except Exception as e: - logger.error(f"数据库启动异常 (尝试 {attempt + 1}/{max_retries}): {e}", exc_info=True) - if attempt < max_retries - 1: - logger.info(f"等待 {retry_delay} 秒后重试...") - await asyncio.sleep(retry_delay) - - # 检查数据库是否成功启动 - if not db_started: - logger.error(StatusMessages.DB_MANAGER_START_FAILED.format(error="所有重试均失败")) - logger.warning("⚠️ 插件将在数据库功能受限的情况下继续运行") - - # 启动好感度管理服务(包含随机情绪初始化) - if self.plugin_config.enable_affection_system: - try: - await self.affection_manager.start() - logger.info("好感度管理服务启动成功") - except Exception as e: - logger.error(f"好感度管理服务启动失败: {e}", exc_info=True) - - # 设置Web服务器的插件服务实例和启动Web服务器 - logger.info(f"Debug: 进入Web服务器启动逻辑") - logger.info(f"Debug: enable_web_interface = {self.plugin_config.enable_web_interface}") - logger.info(f"Debug: server_instance is None = {server_instance is None}") - - if self.plugin_config.enable_web_interface and server_instance: - logger.info("Debug: 开始设置Web服务器插件服务") - # 设置插件服务 - try: - # 尝试获取AstrBot框架的PersonaManager - astrbot_persona_manager = None - try: - # 通过context的persona_manager属性获取框架的PersonaManager - if hasattr(self.context, 'persona_manager'): - astrbot_persona_manager = self.context.persona_manager - if astrbot_persona_manager: - logger.info(f"成功获取AstrBot框架PersonaManager: {type(astrbot_persona_manager)}") - # 检查PersonaManager是否已初始化 - if hasattr(astrbot_persona_manager, 'personas'): - logger.info(f"PersonaManager已有personas属性,人格数量: {len(getattr(astrbot_persona_manager, 'personas', []))}") - else: - 
logger.info("PersonaManager还没有personas属性,可能需要初始化") - else: - logger.warning("Context中persona_manager为None") - else: - logger.warning("Context中没有persona_manager属性") - - # 额外尝试:如果persona_manager为None,尝试延迟获取 - if not astrbot_persona_manager: - logger.info("尝试延迟获取PersonaManager...") - await asyncio.sleep(2) # 等待2秒 - if hasattr(self.context, 'persona_manager') and self.context.persona_manager: - astrbot_persona_manager = self.context.persona_manager - logger.info(f"延迟获取成功: {type(astrbot_persona_manager)}") - else: - logger.warning("延迟获取PersonaManager仍然失败") - - except Exception as pe: - logger.error(f"获取AstrBot框架PersonaManager失败: {pe}", exc_info=True) - - await set_plugin_services( - self.plugin_config, - self.factory_manager, - None, - astrbot_persona_manager, - self.group_id_to_unified_origin - ) - logger.info("Web服务器插件服务设置完成") - except Exception as e: - logger.error(f"设置Web服务器插件服务失败: {e}", exc_info=True) - return # 如果服务设置失败,就不要继续启动Web服务器 - - # 启动Web服务器 - logger.info(f"Debug: 准备启动Web服务器") - logger.info(StatusMessages.WEB_SERVER_PREPARE.format(host=server_instance.host, port=server_instance.port)) - try: - logger.info("Debug: 调用 server_instance.start()") - await server_instance.start() - logger.info(StatusMessages.WEB_SERVER_STARTED) - logger.info("Debug: Web服务器启动完成") - except Exception as e: - logger.error(StatusMessages.WEB_SERVER_START_FAILED.format(error=e), exc_info=True) - logger.error(f"Debug: Web服务器启动异常详���: {type(e).__name__}: {str(e)}") - import traceback - logger.error(f"Debug: 异常堆栈: {traceback.format_exc()}") - else: - logger.info("Debug: Web服务器启动条件不满足") - if not self.plugin_config.enable_web_interface: - logger.info(StatusMessages.WEB_INTERFACE_DISABLED_SKIP) - if not server_instance: - logger.error(StatusMessages.SERVER_INSTANCE_NULL) - logger.error(f"Debug: server_instance为空,无法启动Web服务器") - - logger.info(StatusMessages.PLUGIN_LOAD_COMPLETE) - - async def _delayed_start_learning(self, group_id: str): - """延迟启动学习服务""" - try: - await asyncio.sleep(3) # 
等待初始化完成 - await self.service_factory.initialize_all_services() # 确保所有服务初始化完成 - # 启动针对特定 group_id 的渐进式学习 - await self.progressive_learning.start_learning(group_id) - logger.info(StatusMessages.AUTO_LEARNING_SCHEDULER_STARTED.format(group_id=group_id)) - except Exception as e: - logger.error(StatusMessages.LEARNING_SERVICE_START_FAILED.format(group_id=group_id, error=e)) - - async def _priority_update_incremental_content(self, group_id: str, sender_id: str, message_text: str, event: AstrMessageEvent): - """ - 优先更新增量内容 - 每收到一条消息都会立即调用 - 确保所有增量更新内容都能优先加入到system_prompt中 - """ - try: - logger.info(f"开始优先更新增量内容: group_id={group_id}, sender_id={sender_id[:8]}") - - # 1. 立即进行消息的多维度分析(实时分析) - if hasattr(self, 'multidimensional_analyzer') and self.multidimensional_analyzer: - try: - # 立即分析当前消息的上下文 - analysis_result = await self.multidimensional_analyzer.analyze_message_context( - event, message_text - ) - if analysis_result: - logger.info(f"实时多维度分析完成,包含 {len(analysis_result)} 个维度") - except Exception as e: - logger.error(f"实时多维度分析失败: {e}") - - # 2. 立即更新用户画像和社交关系 - if hasattr(self, 'affection_manager') and self.affection_manager: - try: - # 立即更新好感度和社交关系 - affection_result = await self.affection_manager.process_message_interaction( - group_id, sender_id, message_text - ) - if affection_result and affection_result.get('success'): - logger.debug(f"实时好感度更新完成: {affection_result}") - except Exception as e: - logger.error(f"实时好感度更新失败: {e}") - - # 3. 
立即进行情绪和风格分析 - if hasattr(self, 'style_analyzer') and self.style_analyzer: - try: - # 获取最近的消息进行风格分析 - recent_messages_dict = await self.db_manager.get_recent_filtered_messages(group_id, limit=5) - # 添加当前消息 - current_message_dict = { - 'message': message_text, - 'sender_id': sender_id, - 'timestamp': time.time() - } - all_messages_dict = recent_messages_dict + [current_message_dict] - - # 转换字典数据为MessageData对象 - analysis_messages = [] - for msg_dict in all_messages_dict: - message_data = MessageData( - sender_id=msg_dict.get('sender_id', ''), - sender_name=msg_dict.get('sender_name', ''), - message=msg_dict.get('message', ''), - group_id=group_id, - timestamp=msg_dict.get('timestamp', time.time()), - platform=msg_dict.get('platform', 'default'), - message_id=msg_dict.get('message_id'), - reply_to=msg_dict.get('reply_to') - ) - analysis_messages.append(message_data) - - # 立即分析消息的风格 - style_result = await self.style_analyzer.analyze_conversation_style( - group_id, analysis_messages - ) - # ✅ 正确检查 AnalysisResult 的 success 属性 - if style_result and (style_result.success if hasattr(style_result, 'success') else True): - logger.debug(f"实时风格分析完成,置信度: {style_result.confidence if hasattr(style_result, 'confidence') else 'N/A'}") - except Exception as e: - logger.error(f"实时风格分析失败: {e}") - - # 4. 
如果启用实时学习,立即进行深度分析 - if self.plugin_config.enable_realtime_learning: - try: - await self._process_message_realtime(group_id, message_text, sender_id) - logger.debug(f"实时学习处理完成: {group_id}") - except Exception as e: - logger.error(f"实时学习处理失败: {e}") - - logger.info(f"增量内容优先更新流程完成: {group_id}") - - except Exception as e: - logger.error(f"优先更新增量内容异常: {e}", exc_info=True) - - def _is_astrbot_command(self, event: AstrMessageEvent) -> bool: - """ - 判断用户输入是否为AstrBot命令(包括插件命令和其他命令) - - 融合了AstrBot框架的命令检测机制和插件特定的命令检测 - - 注意:唤醒词消息(is_at_or_wake_command)应该被收集用于学习, - 因为这些是最有价值的对话数据。只过滤明确的命令格式。 - - Args: - event: AstrBot消息事件 - - Returns: - bool: True表示是命令,False表示是普通消息 - """ - message_text = event.get_message_str() - if not message_text: - return False - - # 1. 检查是否为本插件的特定命令 - if self._is_plugin_command(message_text): - return True - - # 2. 检查是否为其他AstrBot命令(以命令前缀开头) - # 注意:不再使用 is_at_or_wake_command 来过滤,因为唤醒词消息应该被收集 - command_prefixes = ['/', '!', '#', '.'] # 常见命令前缀 - stripped_text = message_text.strip() - if stripped_text and stripped_text[0] in command_prefixes: - # 检查是否像命令格式(前缀+字母开头的命令名) - if len(stripped_text) > 1 and stripped_text[1].isalpha(): - return True + async def terminate(self): + """插件卸载时的清理工作""" + await self._lifecycle.shutdown() - return False - - def _is_plugin_command(self, message_text: str) -> bool: - """检查消息是否为本插件的命令""" - if not message_text: - return False - - # 定义所有插件命令(不包含前缀符号) - plugin_commands = [ - 'learning_status', - 'start_learning', - 'stop_learning', - 'force_learning', - 'affection_status', - 'set_mood' - ] - - # 去除首尾空白 - message_text = message_text.strip() - - # 方案1: 检查带前缀的命令 - # 创建命令的正则表达式模式 - 匹配: [任意单个字符][命令名][可选的空格和参数] - commands_pattern = '|'.join(re.escape(cmd) for cmd in plugin_commands) - pattern_with_prefix = rf'^.{{1}}({commands_pattern})(\s.*)?$' - - # 方案2: 检查不带前缀的命令(被AstrBot框架处理后的) - # 直接匹配命令名,可能带参数 - pattern_without_prefix = rf'^({commands_pattern})(\s.*)?$' - - # 使用正则表达式匹配,忽略大小写 - # 如果匹配任一模式,都认为是插件命令 - return 
bool(re.match(pattern_with_prefix, message_text, re.IGNORECASE)) or \ - bool(re.match(pattern_without_prefix, message_text, re.IGNORECASE)) + # 消息监听 @filter.event_message_type(filter.EventMessageType.ALL) async def on_message(self, event: AstrMessageEvent): """监听所有消息,收集用户对话数据(非阻塞优化版)""" - try: - # 检查数据库是否就绪(避免在 on_load 完成前处理消息) if not self.db_manager or not self.db_manager.engine: return - # 获取消息文本 message_text = event.get_message_str() if not message_text or len(message_text.strip()) == 0: return - group_id = event.get_group_id() or event.get_sender_id() # 使用群组ID或发送者ID作为会话ID + group_id = event.get_group_id() or event.get_sender_id() sender_id = event.get_sender_id() - # ⚡ 优化1: 好感度处理改为后台任务,不阻塞消息回复 - # 只对at消息和唤醒消息处理好感度(不包括插件命令) + # 好感度处理(后台,仅 at/唤醒消息) if event.is_at_or_wake_command and self.plugin_config.enable_affection_system: - asyncio.create_task(self._process_affection_background(group_id, sender_id, message_text)) + asyncio.create_task( + self._pipeline.process_affection(group_id, sender_id, message_text) + ) - # 检查是否启用消息抓取 - 用于学习数据收集 if not self.plugin_config.enable_message_capture: return - # 使用融合的命令检测机制 - 过滤所有AstrBot命令(仅用于学习数据收集,不影响好感度) - if self._is_astrbot_command(event): + # 命令过滤 + if self._command_filter.is_astrbot_command(event): logger.debug(f"检测到AstrBot命令,跳过学习数据收集: {message_text}") return - # QQ号过滤(仅用于学习数据收集) if not self.qq_filter.should_collect_message(sender_id, group_id): return - # ⚡ 优化2: 所有学习相关操作改为后台任务,完全不阻塞消息回复 - asyncio.create_task(self._process_learning_background( - group_id, sender_id, message_text, event - )) + # 后台学习流水线 + asyncio.create_task( + self._pipeline.process_learning(group_id, sender_id, message_text, event) + ) - # ⚡ 统计更新可以同步进行(非常快) self.learning_stats.total_messages_collected += 1 self.plugin_config.total_messages_collected = self.learning_stats.total_messages_collected except Exception as e: logger.error(StatusMessages.MESSAGE_COLLECTION_ERROR.format(error=e), exc_info=True) - async def _mine_jargon_background(self, group_id: 
str): - """ - 后台黑话挖掘 - 完全异步,不阻塞主流程 - - 工作流程: - 1. 检查是否应该触发挖掘(频率控制) - 2. 获取最近的消息 - 3. 使用JargonMiner进行黑话提取和推断 - 4. 保存到数据库 - """ - try: - if not hasattr(self, 'jargon_miner_manager'): - logger.debug("[黑话挖掘] JargonMinerManager未初始化,跳过") - return - - # 获取或创建该群组的黑话挖掘器 - jargon_miner = self.jargon_miner_manager.get_or_create_miner(group_id) - - # 获取最近的消息用于挖掘 - stats = await self.message_collector.get_statistics(group_id) - recent_message_count = stats.get('raw_messages', 0) - - # 检查是否应该触发学习(频率控制) - if not jargon_miner.should_trigger(recent_message_count): - logger.debug(f"[黑话挖掘] 群组 {group_id} 未达到触发条件") - return - - # 获取最近20-50条消息用于黑话挖掘 - recent_messages = await self.db_manager.get_recent_raw_messages( - group_id, limit=30 - ) - - if len(recent_messages) < 10: - logger.debug(f"[黑话挖掘] 群组 {group_id} 消息数量不足({len(recent_messages)}<10)") - return - - logger.info(f"🔍 [黑话挖掘] 开始分析群组 {group_id} 的 {len(recent_messages)} 条消息") - - # 将消息列表转换为聊天文本 - chat_messages = "\n".join([ - f"{msg.get('sender_id', 'unknown')}: {msg.get('message', '')}" - for msg in recent_messages - ]) - - # 执行黑话学习(包括候选提取、推断、保存) - await jargon_miner.run_once(chat_messages, len(recent_messages)) - - logger.debug(f"[黑话挖掘] 群组 {group_id} 学习完成") - - except Exception as e: - logger.error(f"❌ [黑话挖掘] 后台任务失败 (group={group_id}): {e}", exc_info=True) - - async def _process_affection_background(self, group_id: str, sender_id: str, message_text: str): - """后台处理好感度更新(非阻塞)""" - try: - affection_result = await self.affection_manager.process_message_interaction( - group_id, sender_id, message_text - ) - if affection_result.get('success'): - logger.debug(LogMessages.AFFECTION_PROCESSING_SUCCESS.format(result=affection_result)) - except Exception as e: - logger.error(LogMessages.AFFECTION_PROCESSING_FAILED.format(error=e)) - - async def _process_learning_background(self, group_id: str, sender_id: str, message_text: str, event: AstrMessageEvent): - """后台处理学习相关操作(非阻塞) - - ⚠️ 注意:此函数通过 asyncio.create_task() 在后台运行 - 为避免 'Future attached 
to different loop' 错误,数据库操作需要特殊处理 - """ - try: - # 1. ✅ 修复事件循环问题:将数据库写入操作包装在异常处理中 - # 对于 MySQL,可能会遇到事件循环绑定问题,捕获并记录而不是崩溃 - try: - await self.message_collector.collect_message({ - 'sender_id': sender_id, - 'sender_name': event.get_sender_name(), - 'message': message_text, - 'group_id': group_id, - 'timestamp': time.time(), - 'platform': event.get_platform_name() - }) - except RuntimeError as e: - if "attached to a different loop" in str(e): - # 这是已知的事件循环问题,记录警告但不中断流程 - logger.warning(f"消息收集遇到事件循环问题(已知MySQL限制),消息将被跳过: {str(e)[:100]}") - else: - raise # 其他 RuntimeError 继续抛出 - except Exception as e: - # 其他异常也记录但不中断 - logger.error(f"消息收集失败: {e}") - - - # 2. 处理增强交互(多轮对话管理) - try: - await self.enhanced_interaction.update_conversation_context( - group_id, sender_id, message_text - ) - except Exception as e: - logger.error(LogMessages.ENHANCED_INTERACTION_FAILED.format(error=e)) - - # 3. ✅ 黑话挖掘 - 每收集10条消息触发一次(完全后台执行) - stats = await self.message_collector.get_statistics(group_id) - raw_message_count = stats.get('raw_messages', 0) - if raw_message_count % 10 == 0 and raw_message_count >= 10: - asyncio.create_task(self._mine_jargon_background(group_id)) - - # 4. 如果启用实时学习,每条消息都学习(完全后台执行,不阻塞) - if self.plugin_config.enable_realtime_learning: - # ⚡ 使用 asyncio.create_task 确保完全后台执行 - asyncio.create_task(self._process_message_realtime_background(group_id, message_text, sender_id)) - - # 5. 智能启动学习任务(基于消息活动,添加频率限制) - await self._smart_start_learning_for_group(group_id) - - # 6. 
对话目标管理(如果启用) - if self.plugin_config.enable_goal_driven_chat: - try: - if hasattr(self, 'conversation_goal_manager') and self.conversation_goal_manager: - # 创建或获取对话目标 - goal = await self.conversation_goal_manager.get_or_create_conversation_goal( - user_id=sender_id, - group_id=group_id, - user_message=message_text - ) - if goal: - goal_type = goal['final_goal'].get('type', 'unknown') - goal_name = goal['final_goal'].get('name', '未知目标') - topic = goal['final_goal'].get('topic', '未知话题') - current_stage = goal['current_stage'].get('task', '初始化') - logger.info(f"✅ [对话目标] 会话目标: {goal_name} (类型: {goal_type}), 话题: {topic}, 当前阶段: {current_stage}") - except Exception as e: - logger.error(f"对话目标处理失败: {e}", exc_info=True) - - except Exception as e: - logger.error(f"后台学习处理失败: {e}", exc_info=True) - - async def _smart_start_learning_for_group(self, group_id: str): - """智能启动群组学习任务 - 不阻塞主线程,添加频率限制""" - try: - # 检查该群组是否已有学习任务 - if group_id in self.learning_tasks: - return - - # 添加学习间隔检查:防止频繁启动学习 - current_time = time.time() - last_learning_key = f"last_learning_start_{group_id}" - last_learning_start = getattr(self, last_learning_key, 0) - learning_interval_seconds = self.plugin_config.learning_interval_hours * 3600 - - if current_time - last_learning_start < learning_interval_seconds: - time_remaining = learning_interval_seconds - (current_time - last_learning_start) - logger.debug(f"群组 {group_id} 学习间隔未到,剩余时间: {time_remaining/60:.1f}分钟") - return - - # 检查群组消息数量是否达到学习阈值 (确保类型转换) - stats = await self.message_collector.get_statistics(group_id) - - # 验证 stats 是否为字典 - if not isinstance(stats, dict): - logger.warning(f"get_statistics 返回了非字典类型: {type(stats)}, 值: {stats}, 跳过学习启动") - return - - # 安全获取并转换数值 - total_messages_raw = stats.get('total_messages', 0) - min_messages_raw = self.plugin_config.min_messages_for_learning - - # 类型转换带详细日志 - try: - if isinstance(total_messages_raw, str) and not total_messages_raw.replace('-', '').isdigit(): - logger.warning(f"total_messages 是非数字字符串: 
'{total_messages_raw}', 跳过学习启动") - return - total_messages = int(total_messages_raw) if total_messages_raw else 0 - except (ValueError, TypeError) as e: - logger.warning(f"total_messages 转换失败: 原始值={total_messages_raw}, 类型={type(total_messages_raw)}, 错误={e}") - return - - try: - if isinstance(min_messages_raw, str) and not min_messages_raw.replace('-', '').isdigit(): - logger.warning(f"min_messages_for_learning 是非数字字符串: '{min_messages_raw}', 使用默认值10") - min_messages = 10 - else: - min_messages = int(min_messages_raw) if min_messages_raw else 0 - except (ValueError, TypeError) as e: - logger.warning(f"min_messages 转换失败: 原始值={min_messages_raw}, 类型={type(min_messages_raw)}, 错误={e}, 使用默认值10") - min_messages = 10 - - if total_messages < min_messages: - logger.debug(f"群组 {group_id} 消息数量未达到学习阈值: {total_messages}/{min_messages}") - return - - # 记录学习启动时间 - setattr(self, last_learning_key, current_time) - - # 创建学习任务 - learning_task = asyncio.create_task(self._start_group_learning(group_id)) - - # 设置完成回调 - def on_learning_task_complete(task): - if group_id in self.learning_tasks: - del self.learning_tasks[group_id] - if task.exception(): - logger.error(f"群组 {group_id} 学习任务异常: {task.exception()}") - else: - logger.info(f"群组 {group_id} 学习任务完成") - - learning_task.add_done_callback(on_learning_task_complete) - self.learning_tasks[group_id] = learning_task - - logger.info(f"为群组 {group_id} 启动了智能学习任务") - - except Exception as e: - logger.error(f"智能启动学习失败: {e}") - - async def _start_group_learning(self, group_id: str): - """启动特定群组的学习任务""" - try: - success = await self.progressive_learning.start_learning(group_id) - if success: - logger.info(f"群组 {group_id} 学习任务启动成功") - else: - logger.warning(f"群组 {group_id} 学习任务启动失败") - except Exception as e: - logger.error(f"群组 {group_id} 学习任务启动异常: {e}") - - async def _delayed_provider_reinitialization(self): - """延迟重新初始化提供商配置,解决重启后配置丢失问题""" - try: - # 等待系统完全初始化 - await asyncio.sleep(10) - - # 重新初始化LLM适配器的提供商配置 - if hasattr(self, 'llm_adapter') and 
self.llm_adapter: - self.llm_adapter.initialize_providers(self.plugin_config) - logger.info("延迟重新初始化提供商配置完成") - - # 检查配置状态 - if self.llm_adapter.providers_configured == 0: - logger.warning("重新初始化后仍然没有配置任何提供商,请检查配置") - # 再次尝试,间隔更长时间 - await asyncio.sleep(30) - self.llm_adapter.initialize_providers(self.plugin_config) - logger.info("第二次尝试重新初始化提供商配置") - else: - logger.info(f"成功配置了 {self.llm_adapter.providers_configured} 个提供商") - - except Exception as e: - logger.error(f"延迟重新初始化提供商配置失败: {e}") - - async def _delayed_auto_start_learning(self): - """延迟自动启动学习 - 避免初始化时阻塞""" - try: - # 等待系统初始化完成 - await asyncio.sleep(30) - - # 获取活跃群组列表 - active_groups = await self._get_active_groups() - - for group_id in active_groups: - try: - await self._smart_start_learning_for_group(group_id) - # 避免同时启动过多任务 - await asyncio.sleep(5) - except Exception as e: - logger.error(f"延迟启动群组 {group_id} 学习失败: {e}") - - except Exception as e: - logger.error(f"延迟自动启动学习失败: {e}") - - async def _get_active_groups(self) -> List[str]: - """获取活跃群组列表(使用ORM)""" - try: - # 检查数据库管理器是否可用和已启动 - if not self.db_manager: - logger.warning("数据库管理器未初始化,无法获取活跃群组") - return [] - - # 对于 SQLAlchemy 数据库管理器,检查是否已启动 - if hasattr(self.db_manager, '_started') and not self.db_manager._started: - logger.warning("SQLAlchemy 数据库管理器未启动,无法获取活跃群组") - return [] - - # 根据白名单/黑名单配置构建群组过滤条件 - allowed_groups = self.qq_filter.get_allowed_group_ids() - blocked_groups = self.qq_filter.get_blocked_group_ids() - - if allowed_groups: - logger.info(f"应用群组白名单过滤,仅查询: {allowed_groups}") - if blocked_groups: - logger.info(f"应用群组黑名单过滤,排除: {blocked_groups}") - - # 使用 ORM 方式查询活跃群组 - async with self.db_manager.get_session() as session: - from sqlalchemy import select, func - from .models.orm import RawMessage - - def _apply_group_filter(stmt): - """对查询语句应用白名单/黑名单过滤""" - if allowed_groups: - stmt = stmt.where(RawMessage.group_id.in_(allowed_groups)) - if blocked_groups: - stmt = stmt.where(RawMessage.group_id.notin_(blocked_groups)) - return stmt - - # 
首先尝试获取最近24小时内有消息的群组 - cutoff_time = int(time.time() - 86400) - - stmt = select( - RawMessage.group_id, - func.count(RawMessage.id).label('msg_count') - ).where( - RawMessage.timestamp > cutoff_time, - RawMessage.group_id.isnot(None), - RawMessage.group_id != '' - ) - stmt = _apply_group_filter(stmt) - stmt = stmt.group_by( - RawMessage.group_id - ).having( - func.count(RawMessage.id) >= self.plugin_config.min_messages_for_learning - ).order_by( - func.count(RawMessage.id).desc() - ).limit(10) - - result = await session.execute(stmt) - active_groups = [row.group_id for row in result if row.group_id] - - # 如果最近24小时没有活跃群组,扩大时间范围到7天 - if not active_groups: - logger.warning("最近24小时内没有活跃群组,扩大搜索范围到7天...") - cutoff_time = int(time.time() - (86400 * 7)) # 7天 - - stmt = select( - RawMessage.group_id, - func.count(RawMessage.id).label('msg_count') - ).where( - RawMessage.timestamp > cutoff_time, - RawMessage.group_id.isnot(None), - RawMessage.group_id != '' - ) - stmt = _apply_group_filter(stmt) - stmt = stmt.group_by( - RawMessage.group_id - ).having( - func.count(RawMessage.id) >= max(1, self.plugin_config.min_messages_for_learning // 2) - ).order_by( - func.count(RawMessage.id).desc() - ).limit(10) - - result = await session.execute(stmt) - active_groups = [row.group_id for row in result if row.group_id] - - # 如果还是没有,获取所有有消息的群组(无时间限制) - if not active_groups: - logger.warning("7天内也没有活跃群组,获取所有有消息记录的群组...") - - stmt = select( - RawMessage.group_id, - func.count(RawMessage.id).label('msg_count') - ).where( - RawMessage.group_id.isnot(None), - RawMessage.group_id != '' - ) - stmt = _apply_group_filter(stmt) - stmt = stmt.group_by( - RawMessage.group_id - ).order_by( - func.count(RawMessage.id).desc() - ).limit(10) - - result = await session.execute(stmt) - active_groups = [row.group_id for row in result if row.group_id] - - logger.info(f"发现 {len(active_groups)} 个活跃群组: {active_groups if active_groups else '无'}") - return active_groups - - except Exception as e: - 
logger.error(f"获取活跃群组失败: {e}") - return [] - - async def _process_message_realtime_background(self, group_id: str, message_text: str, sender_id: str): - """实时处理消息的后台包装方法 - 完全异步,不阻塞主流程""" - try: - await self._process_message_realtime(group_id, message_text, sender_id) - except Exception as e: - logger.error(f"实时学习后台处理失败 (group={group_id}): {e}", exc_info=True) - - async def _process_message_realtime(self, group_id: str, message_text: str, sender_id: str): - """实时处理消息 - 优化LLM调用频率,表达风格学习不经过消息筛选""" - try: - # 先进行基础过滤,避免不必要的LLM调用 - if len(message_text.strip()) < self.plugin_config.message_min_length: - return - - if len(message_text) > self.plugin_config.message_max_length: - return - - # 简单关键词过滤,避免明显无意义的消息 - if message_text.strip() in ['', '???', '。。。', '...', '嗯', '哦', '额']: - return - - # 【新增】表达风格学习 - 直接使用原始消息,无需筛选 - await self._process_expression_style_learning(group_id, message_text, sender_id) - - # 基于配置的批处理模式:不是每条消息都调用LLM - if not self.plugin_config.enable_realtime_llm_filter: - # 如果禁用实时LLM筛选,直接添加到筛选消息 - await self.message_collector.add_filtered_message({ - 'message': message_text, - 'sender_id': sender_id, - 'group_id': group_id, - 'timestamp': time.time(), - 'confidence': 0.6 # 无LLM筛选的置信度较低 - }) - self.learning_stats.filtered_messages += 1 - - # 确保配置中的统计也得到更新,用于WebUI显示 - if not hasattr(self.plugin_config, 'filtered_messages'): - self.plugin_config.filtered_messages = 0 - self.plugin_config.filtered_messages = self.learning_stats.filtered_messages - - # 如果启用LLM筛选,则获取当前人格描述并进行筛选 - current_persona_description = await self.persona_manager.get_current_persona_description(group_id) - - # 删除了智能回复相关处理 - # 原智能回复功能已移除 - - if await self.multidimensional_analyzer.filter_message_with_llm(message_text, current_persona_description): - await self.message_collector.add_filtered_message({ - 'message': message_text, - 'sender_id': sender_id, - 'group_id': group_id, - 'timestamp': time.time(), - 'confidence': 0.8 # 实时筛选置信度 - }) - self.learning_stats.filtered_messages += 1 - - # 
确保配置中的统计也得到更新,用于WebUI显示 - if not hasattr(self.plugin_config, 'filtered_messages'): - self.plugin_config.filtered_messages = 0 - self.plugin_config.filtered_messages = self.learning_stats.filtered_messages - - except Exception as e: - logger.error(StatusMessages.REALTIME_PROCESSING_ERROR.format(error=e), exc_info=True) - - async def _process_expression_style_learning(self, group_id: str, message_text: str, sender_id: str): - """处理表达风格学习 - 直接学习,无需消息筛选""" - try: - # 检查是否有足够的消息进行学习 - stats = await self.message_collector.get_statistics(group_id) - raw_message_count = stats.get('raw_messages', 0) - - # 需要至少5条消息才开始表达风格学习 - if raw_message_count < 5: - logger.debug(f"群组 {group_id} 原始消息数量不足,当前:{raw_message_count},需要至少5条") - return - - logger.info(f"群组 {group_id} 开始表达风格学习,当前消息数:{raw_message_count}") - - # 获取最近的原始消息用于学习(不使用筛选后的消息) - recent_raw_messages = await self.db_manager.get_recent_raw_messages(group_id, limit=25) - - if not recent_raw_messages or len(recent_raw_messages) < 3: # 降低阈值 - logger.debug(f"群组 {group_id} 原始消息数量不足,数据库中只有 {len(recent_raw_messages) if recent_raw_messages else 0} 条") - return - - # 转换为 MessageData 格式,并应用正则表达式过滤 - from .core.interfaces import MessageData - import re - - message_data_list = [] - for msg in recent_raw_messages: - if msg.get('sender_id') != sender_id: # 不学习自己的消息 - message_content = msg.get('message', '') - - # 应用与webui.py相同的过滤逻辑 - # 1. 基础过滤:长度检查 - if len(message_content.strip()) < 5: - continue - if len(message_content) > 500: - continue - - # 2. 关键词过滤:无意义消息 - if message_content.strip() in ['', '???', '。。。', '...', '嗯', '哦', '额']: - continue - - # 3. 
@符号处理:提取@用户名后的消息内容 - processed_message = message_content - if '@' in message_content: - # 使用正则表达式匹配 @用户名 后的内容 - at_pattern = r'@[^\s]+\s+' - processed_message = re.sub(at_pattern, '', message_content).strip() - - # 如果处理后消息为空或过短,跳过 - if len(processed_message.strip()) < 5: - continue - - message_data = MessageData( - sender_id=msg.get('sender_id', ''), - sender_name=msg.get('sender_name', ''), - message=processed_message, # 使用处理后的消息内容 - group_id=group_id, - timestamp=msg.get('timestamp', time.time()), - platform=msg.get('platform', 'default'), - message_id=msg.get('id'), # 使用id而不是message_id - reply_to=None # raw_messages表中没有reply_to字段 - ) - message_data_list.append(message_data) - - if len(message_data_list) < 3: # 降低阈值 - logger.debug(f"群组 {group_id} 有效学习消息不足3条,跳过表达风格学习,当前:{len(message_data_list)}") - return - - logger.info(f"群组 {group_id} 准备进行表达风格学习,有效消息数:{len(message_data_list)}") - - # 调用表达模式学习器进行学习 - expression_learner = self.factory_manager.get_component_factory().create_expression_pattern_learner() - - if expression_learner: - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - - if learning_success: - logger.info(f"群组 {group_id} 表达风格学习成功") - - # 获取学习到的表达模式 - try: - learned_patterns = await expression_learner.get_expression_patterns(group_id, limit=5) - if learned_patterns: - # 动态临时加入prompt(不加入人格) - await self._apply_style_to_prompt_temporarily(group_id, learned_patterns) - - # 同时生成Few Shots对话格式并创建审查请求(用于正式加入人格) - few_shots_content = await self._generate_few_shots_dialog(group_id, message_data_list) - - if few_shots_content: - # 创建审查请求用于正式加入人格 - await self._create_style_learning_review_request( - group_id, learned_patterns, few_shots_content - ) - logger.info(f"群组 {group_id} 表达风格学习结果已临时应用到prompt,并已提交人格审查") - else: - logger.info(f"群组 {group_id} 表达风格学习结果已临时应用到prompt") - except Exception as e: - logger.error(f"处理表达风格学习结果失败: {e}") - - # 统计更新 - self.learning_stats.style_updates += 1 - - # 触发增量更新回调(动态临时更新prompt) - if 
self.update_system_prompt_callback: - await self.update_system_prompt_callback(group_id) - logger.info(f"群组 {group_id} 表达风格学习结果已应用到system_prompt") - else: - logger.debug(f"群组 {group_id} 表达风格学习未产生有效结果") - else: - logger.warning("表达模式学习器未正确初始化") - - except Exception as e: - logger.error(f"群组 {group_id} 表达风格学习处理失败: {e}") + # LLM Hook - async def _apply_style_to_prompt_temporarily(self, group_id: str, learned_patterns: List[Any]): - """临时将风格应用到prompt中(不修改人格文件)""" - try: - if not learned_patterns: - return - - # 构建风格描述 - style_descriptions = [] - for pattern in learned_patterns[:3]: # 只取前3个最重要的 - situation = pattern.situation if hasattr(pattern, 'situation') else pattern.get('situation', '') - expression = pattern.expression if hasattr(pattern, 'expression') else pattern.get('expression', '') - - if situation and expression: - style_descriptions.append(f"当{situation}时,可以使用\"{expression}\"这样的表达") - - if style_descriptions: - # 构建临时风格提示 - style_prompt = f""" -【临时表达风格特征】(基于最近学习) -在回复时可以参考以下表达方式: -{chr(10).join(f'• {desc}' for desc in style_descriptions)} - -注意:这些是临时学习的风格特征,应自然融入回复,不要刻意模仿。 -""" - - # 应用到临时prompt(通过临时人格更新器的动态更新功能) - success = await self.temporary_persona_updater.apply_temporary_style_update(group_id, style_prompt.strip()) - - if success: - logger.info(f"群组 {group_id} 表达风格已临时应用到prompt,包含 {len(style_descriptions)} 个风格特征") - else: - logger.warning(f"群组 {group_id} 表达风格临时应用失败") - - except Exception as e: - logger.error(f"临时应用风格到prompt失败: {e}") - - async def _generate_few_shots_dialog(self, group_id: str, message_data_list: List[Any]) -> str: - """生成Few Shots对话格式的内容 - 需要至少10条消息才调用LLM处理""" - try: - # 要求至少10条消息才进行Few Shots生成 - if len(message_data_list) < 10: - logger.debug(f"群组 {group_id} 消息数量不足10条(当前{len(message_data_list)}条),跳过Few Shots生成") - return "" - - # 筛选出有效的对话片段 - dialog_pairs = [] - - # 将消息按时间排序 - sorted_messages = sorted(message_data_list, key=lambda x: x.timestamp) - - # 使用LLM智能识别真实的对话关系 - for i in range(len(sorted_messages) - 1): - current_msg = 
sorted_messages[i] - next_msg = sorted_messages[i + 1] - - # 1. 确保是不同用户的消息(排除同一人连续发送) - if current_msg.sender_id == next_msg.sender_id: - continue - - # 2. 基础过滤:长度检查 - user_msg = current_msg.message.strip() - bot_response = next_msg.message.strip() - - if (len(user_msg) < 5 or len(bot_response) < 5 or - user_msg in ['?', '??', '...', '。。。'] or - bot_response in ['?', '??', '...', '。。。']): - continue - - # 3. 过滤重复内容(A重复B的话不算对话) - if user_msg == bot_response or user_msg in bot_response or bot_response in user_msg: - logger.debug(f"过滤重复内容: A='{user_msg[:30]}...' B='{bot_response[:30]}...'") - continue - - # 4. 调用专业的消息关系分析器判断两条消息是否构成真实对话关系 - if await self._is_valid_dialog_pair(current_msg, next_msg, group_id): - dialog_pairs.append({ - 'user': user_msg, - 'assistant': bot_response - }) - - # 选择最佳的对话片段(取前5个) - if len(dialog_pairs) >= 3: - selected_pairs = dialog_pairs[:5] - - # 生成Few Shots格式 - few_shots_lines = [ - "*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:" - ] - - for pair in selected_pairs: - few_shots_lines.append(f"A: {pair['user']}") - few_shots_lines.append(f"B: {pair['assistant']}") - - logger.info(f"群组 {group_id} 生成了 {len(selected_pairs)} 组Few Shots对话") - return '\n'.join(few_shots_lines) - - logger.debug(f"群组 {group_id} 未找到足够的有效对话片段(需要至少3组,当前{len(dialog_pairs)}组)") - return "" - - except Exception as e: - logger.error(f"生成Few Shots对话失败: {e}") - return "" - - async def _is_valid_dialog_pair(self, msg1: Any, msg2: Any, group_id: str) -> bool: - """ - 使用专业的消息关系分析器判断两条消息是否构成真实的对话关系 - - Args: - msg1: 第一条消息(MessageData对象) - msg2: 第二条消息(MessageData对象) - group_id: 群组ID - - Returns: - bool: True表示构成对话关系,False表示不构成 - """ - try: - # 检查服务工厂是否已初始化 - if not self.factory_manager or not hasattr(self.factory_manager, '_service_factory') or not self.factory_manager._service_factory: - # 服务工厂未初始化,使用简单规则 - return msg1.message != msg2.message - - # 获取消息关系分析器 - relationship_analyzer = 
self.factory_manager.get_service_factory().create_message_relationship_analyzer() - - if not relationship_analyzer: - # 降级方案:简单规则 - return msg1.message != msg2.message - - # 构造分析器需要的消息格式 - msg1_dict = { - 'message_id': msg1.message_id or str(hash(f"{msg1.timestamp}{msg1.sender_id}")), - 'sender_id': msg1.sender_id, - 'message': msg1.message, - 'timestamp': msg1.timestamp - } - - msg2_dict = { - 'message_id': msg2.message_id or str(hash(f"{msg2.timestamp}{msg2.sender_id}")), - 'sender_id': msg2.sender_id, - 'message': msg2.message, - 'timestamp': msg2.timestamp - } - - # 调用专业分析器 - relationship = await relationship_analyzer._analyze_message_pair(msg1_dict, msg2_dict, group_id) - - # 判断结果 - if relationship: - # 关系类型为direct_reply或topic_continuation,且置信度>0.5,则认为是有效对话 - is_valid = ( - relationship.relationship_type in ['direct_reply', 'topic_continuation'] and - relationship.confidence > 0.5 - ) - - if is_valid: - logger.debug(f"识别对话关系: {relationship.relationship_type} (置信度: {relationship.confidence:.2f})") - - return is_valid - - return False - - except Exception as e: - logger.error(f"消息关系判断失败: {e}", exc_info=True) - # 出错时保守判断,返回False - return False - - async def _create_style_learning_review_request(self, group_id: str, learned_patterns: List[Any], few_shots_content: str): - """创建对话风格学习结果的审查请求 - 包含去重逻辑""" - try: - # 1. 检查是否有重复的待审查记录(避免重复提交) - existing_reviews = await self._get_pending_style_reviews(group_id) - - if existing_reviews: - # 检查内容是否相似 - for existing in existing_reviews: - existing_content = existing.get('few_shots_content', '') - # 如果Few Shots内容完全相同,跳过创建 - if existing_content == few_shots_content: - logger.info(f"群组 {group_id} 已存在相同的待审查风格学习记录,跳过重复创建") - return - - # 2. 
构建审查内容 - review_data = { - 'type': 'style_learning', - 'group_id': group_id, - 'timestamp': time.time(), - 'learned_patterns': [pattern.to_dict() for pattern in learned_patterns], - 'few_shots_content': few_shots_content, - 'status': 'pending', # pending, approved, rejected - 'description': f'群组 {group_id} 的对话风格学习结果(包含 {len(learned_patterns)} 个表达模式)' - } - - # 3. 保存到数据库的审查表 - await self.db_manager.create_style_learning_review(review_data) - - logger.info(f"对话风格学习审查请求已创建: {group_id}") - - except Exception as e: - logger.error(f"创建对话风格学习审查请求失败: {e}") - - async def _get_pending_style_reviews(self, group_id: str) -> List[Dict[str, Any]]: - """获取指定群组的待审查风格学习记录""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 查询该群组的pending状态的风格学习审查记录 - await cursor.execute(''' - SELECT id, group_id, few_shots_content, timestamp - FROM style_learning_reviews - WHERE group_id = ? AND status = 'pending' AND type = 'style_learning' - ORDER BY timestamp DESC - LIMIT 10 - ''', (group_id,)) - - rows = await cursor.fetchall() - - reviews = [] - for row in rows: - reviews.append({ - 'id': row[0], - 'group_id': row[1], - 'few_shots_content': row[2], - 'timestamp': row[3] - }) - - return reviews + @filter.on_llm_request() + async def inject_diversity_to_llm_request(self, event: AstrMessageEvent, req=None): + """LLM Hook — inject diversity, social context, V2, jargon into request.""" + await self._hook_handler.handle(event, req) - except Exception as e: - logger.error(f"获取待审查风格学习记录失败: {e}") - return [] + # 命令处理器(薄委托) @filter.command("learning_status") @filter.permission_type(PermissionType.ADMIN) async def learning_status_command(self, event: AstrMessageEvent): """查看学习状态""" - try: - group_id = event.get_group_id() or event.get_sender_id() # 获取当前会话ID - - # 获取收集统计 - collector_stats = await self.message_collector.get_statistics(group_id) # 传入 group_id - - # 确保 collector_stats 不为 None - if collector_stats is None: - collector_stats = { - 
'total_messages': 0, - 'filtered_messages': 0, - 'raw_messages': 0, - 'unprocessed_messages': 0, - } - - # 获取当前人格设置 - current_persona_info = await self.persona_manager.get_current_persona(group_id) - current_persona_name = CommandMessages.STATUS_UNKNOWN - if current_persona_info and isinstance(current_persona_info, dict): - current_persona_name = current_persona_info.get('name', CommandMessages.STATUS_UNKNOWN) - - # 获取渐进式学习服务的状态 - learning_status = await self.progressive_learning.get_learning_status() - - # 确保 learning_status 不为 None - if learning_status is None: - learning_status = { - 'learning_active': False, - 'current_session': None, - 'total_sessions': 0, - } - - # 构建状态信息 - status_info = CommandMessages.STATUS_REPORT_HEADER.format(group_id=group_id) - - # 基础配置 - persona_update_mode = "PersonaManager模式" if self.plugin_config.use_persona_manager_updates else "传统文件模式" - status_info += CommandMessages.STATUS_BASIC_CONFIG.format( - message_capture=CommandMessages.STATUS_ENABLED if self.plugin_config.enable_message_capture else CommandMessages.STATUS_DISABLED, - auto_learning=CommandMessages.STATUS_ENABLED if self.plugin_config.enable_auto_learning else CommandMessages.STATUS_DISABLED, - realtime_learning=CommandMessages.STATUS_ENABLED if self.plugin_config.enable_realtime_learning else CommandMessages.STATUS_DISABLED, - web_interface=CommandMessages.STATUS_ENABLED if self.plugin_config.enable_web_interface else CommandMessages.STATUS_DISABLED - ) - - # 人格更新方式信息 - status_info += f"\n\n📊 人格更新配置:\n" - status_info += f"• 更新方式: {persona_update_mode}\n" - if self.plugin_config.use_persona_manager_updates: - # 检查PersonaManager可用性 - persona_manager_updater = self.service_factory.create_persona_manager_updater() - pm_status = "✅ 可用" if persona_manager_updater.is_available() else "❌ 不可用" - status_info += f"• PersonaManager状态: {pm_status}\n" - status_info += f"• 自动应用更新: {'启用' if self.plugin_config.auto_apply_persona_updates else '禁用'}\n" - status_info += f"• 更新前备份: {'启用' if 
self.plugin_config.persona_update_backup_enabled else '禁用'}\n" - - # 抓取设置 - status_info += CommandMessages.STATUS_CAPTURE_SETTINGS.format( - target_qq=self.plugin_config.target_qq_list if self.plugin_config.target_qq_list else CommandMessages.STATUS_ALL_USERS, - current_persona=current_persona_name - ) - - # Provider配置信息 - if hasattr(self, 'llm_adapter') and self.llm_adapter: - provider_info = self.llm_adapter.get_provider_info() - status_info += CommandMessages.STATUS_MODEL_CONFIG.format( - filter_model=provider_info.get('filter', '未配置'), - refine_model=provider_info.get('refine', '未配置') - ) - else: - status_info += CommandMessages.STATUS_MODEL_CONFIG.format( - filter_model='未配置框架Provider', - refine_model='未配置框架Provider' - ) - - # 学习统计 - 安全处理嵌套的None值 - current_session = learning_status.get('current_session') or {} - status_info += CommandMessages.STATUS_LEARNING_STATS.format( - total_messages=collector_stats.get('total_messages', 0), - filtered_messages=collector_stats.get('filtered_messages', 0), - style_updates=current_session.get('style_updates', 0), - last_learning_time=current_session.get('end_time', CommandMessages.STATUS_NEVER_EXECUTED) - ) - - # 存储统计 - status_info += CommandMessages.STATUS_STORAGE_STATS.format( - raw_messages=collector_stats.get('raw_messages', 0), - unprocessed_messages=collector_stats.get('unprocessed_messages', 0), - filtered_messages=collector_stats.get('filtered_messages', 0) - ) - - # 调度状态 - scheduler_status = CommandMessages.STATUS_RUNNING if learning_status.get('learning_active') else CommandMessages.STATUS_STOPPED - status_info += "\n\n" + CommandMessages.STATUS_SCHEDULER.format(status=scheduler_status) - - yield event.plain_result(status_info.strip()) - - except Exception as e: - logger.error(CommandMessages.ERROR_GET_LEARNING_STATUS.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.STATUS_QUERY_FAILED.format(error=str(e))) + async for result in self._command_handlers.learning_status(event): + yield 
result @filter.command("start_learning") @filter.permission_type(PermissionType.ADMIN) async def start_learning_command(self, event: AstrMessageEvent): """手动启动学习""" - try: - group_id = event.get_group_id() or event.get_sender_id() - - # 检查是否有足够的消息进行学习 - stats = await self.message_collector.get_statistics(group_id) - unprocessed_count = stats.get('unprocessed_messages', 0) - - if unprocessed_count < self.plugin_config.min_messages_for_learning: - yield event.plain_result(f"❌ 未处理消息数量不足({unprocessed_count}/{self.plugin_config.min_messages_for_learning}),无法开始学习") - return - - # 执行一次学习批次而不是启动持续循环 - yield event.plain_result(f"🔄 开始执行学习批次,处理 {unprocessed_count} 条未处理消息...") - - try: - await self.progressive_learning._execute_learning_batch(group_id) - yield event.plain_result(f"✅ 学习批次执行完成") - except Exception as batch_error: - yield event.plain_result(f"❌ 学习批次执行失败: {str(batch_error)}") - - except Exception as e: - logger.error(CommandMessages.ERROR_START_LEARNING.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.STARTUP_FAILED.format(error=str(e))) + async for result in self._command_handlers.start_learning(event): + yield result @filter.command("stop_learning") @filter.permission_type(PermissionType.ADMIN) async def stop_learning_command(self, event: AstrMessageEvent): """停止学习""" - try: - group_id = event.get_group_id() or event.get_sender_id() - - # ProgressiveLearningService 的 stop_learning 目前没有 group_id 参数 - # 如果需要停止特定 group_id 的学习,ProgressiveLearningService 需要修改 - # 暂时调用全局停止,或者假设 stop_learning 会停止当前活跃的会话 - await self.progressive_learning.stop_learning() - yield event.plain_result(CommandMessages.LEARNING_STOPPED.format(group_id=group_id)) - - except Exception as e: - logger.error(CommandMessages.ERROR_STOP_LEARNING.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.STOP_FAILED.format(error=str(e))) + async for result in self._command_handlers.stop_learning(event): + yield result @filter.command("force_learning") 
@filter.permission_type(PermissionType.ADMIN) async def force_learning_command(self, event: AstrMessageEvent): """强制执行一次学习周期""" - try: - group_id = event.get_group_id() or event.get_sender_id() - yield event.plain_result(CommandMessages.FORCE_LEARNING_START.format(group_id=group_id)) - - # 设置标志位防止无限循环 - self._force_learning_in_progress = getattr(self, '_force_learning_in_progress', set()) - if group_id in self._force_learning_in_progress: - yield event.plain_result(f"❌ 群组 {group_id} 的强制学习正在进行中,请等待完成") - return - - self._force_learning_in_progress.add(group_id) - - try: - # 直接调用 ProgressiveLearningService 的批处理方法 - await self.progressive_learning._execute_learning_batch(group_id) - yield event.plain_result(CommandMessages.FORCE_LEARNING_COMPLETE.format(group_id=group_id)) - finally: - # 无论成功失败都要清理标志位 - self._force_learning_in_progress.discard(group_id) - - except Exception as e: - logger.error(CommandMessages.ERROR_FORCE_LEARNING.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.ERROR_FORCE_LEARNING.format(error=str(e))) + async for result in self._command_handlers.force_learning(event): + yield result @filter.command("affection_status") @filter.permission_type(PermissionType.ADMIN) async def affection_status_command(self, event: AstrMessageEvent): """查看好感度状态""" - try: - group_id = event.get_group_id() or event.get_sender_id() - user_id = event.get_sender_id() - - if not self.plugin_config.enable_affection_system: - yield event.plain_result(CommandMessages.AFFECTION_DISABLED) - return - - # 获取好感度状态 - affection_status = await self.affection_manager.get_affection_status(group_id) - - # 确保当前群组有情绪状态(如果没有会自动创建随机情绪) - current_mood = None - if self.plugin_config.enable_startup_random_mood: - current_mood = await self.affection_manager.ensure_mood_for_group(group_id) - else: - current_mood = await self.affection_manager.get_current_mood(group_id) - - # 获取用户个人好感度 - user_affection = await self.db_manager.get_user_affection(group_id, user_id) - 
user_level = user_affection['affection_level'] if user_affection else 0 - - status_info = CommandMessages.AFFECTION_STATUS_HEADER.format(group_id=group_id) - status_info += "\n\n" + CommandMessages.AFFECTION_USER_LEVEL.format( - user_level=user_level, max_affection=self.plugin_config.max_user_affection - ) - status_info += "\n" + CommandMessages.AFFECTION_TOTAL_STATUS.format( - total_affection=affection_status['total_affection'], - max_total_affection=affection_status['max_total_affection'] - ) - status_info += "\n" + CommandMessages.AFFECTION_USER_COUNT.format(user_count=affection_status['user_count']) - status_info += "\n\n" + CommandMessages.AFFECTION_CURRENT_MOOD - - if current_mood: - mood_info = current_mood - status_info += "\n" + CommandMessages.AFFECTION_MOOD_TYPE.format(mood_type=mood_info.mood_type.value) - status_info += "\n" + CommandMessages.AFFECTION_MOOD_INTENSITY.format(intensity=mood_info.intensity) - status_info += "\n" + CommandMessages.AFFECTION_MOOD_DESCRIPTION.format(description=mood_info.description) - else: - status_info += "\n" + CommandMessages.AFFECTION_NO_MOOD - - if affection_status['top_users']: - status_info += "\n\n" + CommandMessages.AFFECTION_TOP_USERS - for i, user in enumerate(affection_status['top_users'][:3], 1): - status_info += "\n" + CommandMessages.AFFECTION_USER_RANK.format( - rank=i, user_id=user['user_id'], affection_level=user['affection_level'] - ) - - yield event.plain_result(status_info) - - except Exception as e: - logger.error(CommandMessages.ERROR_GET_AFFECTION_STATUS.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.ERROR_GET_AFFECTION_STATUS.format(error=str(e))) + async for result in self._command_handlers.affection_status(event): + yield result @filter.command("set_mood") @filter.permission_type(PermissionType.ADMIN) async def set_mood_command(self, event: AstrMessageEvent): - """手动设置bot情绪(通过增量人格更新)""" - try: - if not self.plugin_config.enable_affection_system: - yield 
event.plain_result(CommandMessages.AFFECTION_DISABLED) - return - - args = event.get_message_str().split()[1:] # 获取命令参数 - if len(args) < 1: - yield event.plain_result("使用方法:/set_mood \n可用情绪: happy, sad, excited, calm, angry, anxious, playful, serious, nostalgic, curious") - return - - group_id = event.get_group_id() or event.get_sender_id() - mood_type = args[0].lower() - - # 验证情绪类型 - valid_moods = { - 'happy': '心情很好,说话比较活泼开朗,容易表达正面情感', - 'sad': '心情有些低落,说话比较温和,需要更多的理解和安慰', - 'excited': '很兴奋,说话比较有活力,对很多事情都很感兴趣', - 'calm': '心情平静,说话比较稳重,给人安全感', - 'angry': '心情不太好,说话可能比较直接,不太有耐心', - 'anxious': '有些紧张不安,说话可能比较谨慎,需要更多确认', - 'playful': '心情很调皮,喜欢开玩笑,说话比较幽默风趣', - 'serious': '比较严肃认真,说话简洁直接,专注于重要的事情', - 'nostalgic': '有些怀旧情绪,说话带有回忆色彩,比较感性', - 'curious': '对很多事情都很好奇,喜欢提问和探索新事物' - } - - if mood_type not in valid_moods: - yield event.plain_result(f"❌ 无效的情绪类型。支持的情绪: {', '.join(valid_moods.keys())}") - return - - # 通过增量更新的方式设置情绪 - mood_description = valid_moods[mood_type] - - # 统一使用apply_mood_based_persona_update方法,它会同时处理文件和prompt更新 - persona_success = await self.temporary_persona_updater.apply_mood_based_persona_update( - group_id, mood_type, mood_description - ) - - # 同时在affection_manager中记录情绪状态(但不重复添加到prompt) - from .services.affection_manager import MoodType - try: - mood_enum = MoodType(mood_type) - # 只记录到affection_manager的数据库,不更新prompt(避免重复) - await self.affection_manager.db_manager.save_bot_mood( - group_id, mood_type, 0.7, mood_description, - self.plugin_config.mood_persistence_hours or 24 - ) - # 更新内存缓存 - from .services.affection_manager import BotMood - import time - mood_obj = BotMood( - mood_type=mood_enum, - intensity=0.7, - description=mood_description, - start_time=time.time(), - duration_hours=self.plugin_config.mood_persistence_hours or 24 - ) - self.affection_manager.current_moods[group_id] = mood_obj - affection_success = True - except Exception as e: - logger.warning(f"设置affection_manager情绪失败: {e}") - affection_success = False - - if persona_success: - status_msg = 
f"✅ 情绪状态已设置为: {mood_type}\n描述: {mood_description}" - if not affection_success: - status_msg += "\n⚠️ 注意:情绪状态可能无法在状态查询中正确显示" - yield event.plain_result(status_msg) - else: - yield event.plain_result(f"❌ 设置情绪状态失败") - - except Exception as e: - logger.error(CommandMessages.ERROR_SET_MOOD.format(error=e), exc_info=True) - yield event.plain_result(CommandMessages.ERROR_SET_MOOD.format(error=str(e))) - - @filter.on_llm_request() - async def inject_diversity_to_llm_request(self, event: AstrMessageEvent, req=None): - """在所有LLM请求前注入多样性增强prompt - 框架层面Hook (始终生效,不需要开启自动学习) - - 重要改进 (v1.1.1): - - 将注入内容添加到 req.system_prompt 而不是 req.prompt - - 解决对话历史膨胀问题:AstrBot 只保存 req.prompt 到对话历史,不保存 system_prompt - - 避免 token 超限:每次对话不再累积注入的人格设定、社交上下文、多样性提示 - - 注入内容包括: - 1. 社交上下文(表达模式学习、社交关系、好感度、深度心理状态、行为指导) - 2. 多样性增强(语言风格、回复模式、表达变化、历史Bot消息避重) - 3. 黑话理解(如果用户消息中包含黑话) - 4. 会话级增量更新(临时人格调整) - """ - try: - # 检查 req 参数是否存在 - if req is None: - logger.warning("[LLM Hook] req 参数为 None,跳过注入") - return - - # 如果diversity_manager不存在,跳过注入 - if not hasattr(self, 'diversity_manager') or not self.diversity_manager: - logger.debug("[LLM Hook] diversity_manager未初始化,跳过多样性注入") - return - - group_id = event.get_group_id() or event.get_sender_id() - user_id = event.get_sender_id() - - # ✅ 维护group_id到unified_msg_origin的映射 - if hasattr(event, 'unified_msg_origin') and event.unified_msg_origin: - self.group_id_to_unified_origin[group_id] = event.unified_msg_origin - logger.debug(f"[LLM Hook] 更新映射: {group_id} -> {event.unified_msg_origin}") - - # 检查是否有内容可注入 - if not req.prompt: - logger.debug("[LLM Hook] req.prompt为空,跳过多样性注入") - return - - original_prompt_length = len(req.prompt) - logger.info(f"✅ [LLM Hook] 开始注入多样性增强 (group: {group_id}, 原prompt长度: {original_prompt_length})") - - # 收集要注入的内容 - 所有增量内容都注入到 req.prompt(用户消息上下文) - prompt_injections = [] - - # ❌ 移除重复的人格注入 - 框架已经在 req.system_prompt 中注入了 persona["prompt"] - # 如果需要查看当前人格,可以通过 req.system_prompt 访问 - # session_persona_prompt = await 
self._get_active_persona_prompt(event) - logger.debug("[LLM Hook] 跳过基础人格注入(框架已处理),专注于增量内容") - - # ✅ 1. 注入社交上下文(已整合所有功能) - # SocialContextInjector 现在包含: - # - 表达模式学习(原有) - # - 社交关系(原有) - # - 好感度(原有) - # - 基础情绪(原有) - # - 深度心理状态(整合自 PsychologicalSocialContextInjector) - # - 行为模式指导(整合自 PsychologicalSocialContextInjector) - - if hasattr(self, 'social_context_injector') and self.social_context_injector: - try: - social_context = await self.social_context_injector.format_complete_context( - group_id=group_id, - user_id=user_id, - include_social_relations=self.plugin_config.include_social_relations, # 社交关系 - include_affection=self.plugin_config.include_affection_info, # 好感度 - include_mood=False, # 基础情绪(已被深度心理状态包含,避免重复) - include_expression_patterns=True, # ⭐ 表达模式学习结果 - include_psychological=True, # ⭐ 深度心理状态分析 - include_behavior_guidance=True, # ⭐ 行为模式指导 - include_conversation_goal=self.plugin_config.enable_goal_driven_chat, # ⭐ 对话目标上下文 - enable_protection=True - ) - if social_context: - prompt_injections.append(social_context) - logger.info(f"✅ [LLM Hook] 已准备完整社交上下文 (长度: {len(social_context)})") - else: - logger.debug(f"[LLM Hook] 群组 {group_id} 暂无社交上下文") - except Exception as e: - logger.warning(f"[LLM Hook] 注入社交上下文失败: {e}") - else: - logger.debug("[LLM Hook] social_context_injector未初始化,跳过社交上下文注入") - - # ✅ 2. 构建多样性增强内容 (不传入base_prompt,只生成注入内容) - 注入到 prompt - diversity_content = await self.diversity_manager.build_diversity_prompt_injection( - "", # 传空字符串,只生成注入内容 - group_id=group_id, # 传入group_id以获取历史消息 - inject_style=True, - inject_pattern=True, - inject_variation=True, - inject_history=True # 注入历史Bot消息,避免重复 - ) - - # 提取纯注入内容(去除空的base_prompt) - diversity_content = diversity_content.strip() - if diversity_content: - prompt_injections.append(diversity_content) - logger.info(f"✅ [LLM Hook] 已准备多样性增强内容 (长度: {len(diversity_content)})") - - # ✅ 3. 
注入黑话理解(如果用户消息中包含黑话)- 注入到 prompt - if hasattr(self, 'jargon_query_service') and self.jargon_query_service: - try: - # 获取用户消息文本 - user_message = event.message_str if hasattr(event, 'message_str') else str(event.get_message()) - - # 检查消息中是否包含黑话,并获取解释 - jargon_explanation = await self.jargon_query_service.check_and_explain_jargon( - text=user_message, - chat_id=group_id - ) - - if jargon_explanation: - prompt_injections.append(jargon_explanation) - logger.info(f"✅ [LLM Hook] 已准备黑话理解内容 (长度: {len(jargon_explanation)})") - else: - logger.debug(f"[LLM Hook] 用户消息中未检测到已知黑话") - except Exception as e: - logger.warning(f"[LLM Hook] 注入黑话理解失败: {e}") - else: - logger.debug("[LLM Hook] jargon_query_service未初始化,跳过黑话注入") - - # ✅ 4. 注入会话级增量更新 (修复会话串流bug) - 注入到 prompt - if hasattr(self, 'temporary_persona_updater') and self.temporary_persona_updater: - try: - session_updates = self.temporary_persona_updater.session_updates.get(group_id, []) - if session_updates: - updates_text = '\n\n'.join(session_updates) - prompt_injections.append(updates_text) - logger.info(f"✅ [LLM Hook] 已准备会话级更新 (会话: {group_id}, 更新数: {len(session_updates)}, 长度: {len(updates_text)})") - else: - logger.debug(f"[LLM Hook] 会话 {group_id} 暂无增量更新") - except Exception as e: - logger.warning(f"[LLM Hook] 注入会话级更新失败: {e}") - else: - logger.debug("[LLM Hook] temporary_persona_updater未初始化,跳过会话级更新注入") - - # ✅ 5. 
注入所有增量内容(根据配置选择注入位置) - # 关键改进 (v1.1.1):支持将注入内容添加到 system_prompt 或 prompt - # - system_prompt: 不会被 AstrBot 保存到对话历史,避免历史膨胀 (推荐) - # - prompt: 会被保存到对话历史,导致 token 累积和超限 (旧版行为) - if prompt_injections: - prompt_injection_text = '\n\n'.join(prompt_injections) - - # 根据配置决定注入位置 - injection_target = getattr(self.plugin_config, 'llm_hook_injection_target', 'system_prompt') - - if injection_target == 'system_prompt': - # 注入到 system_prompt(推荐,不会被保存到对话历史) - if not req.system_prompt: - req.system_prompt = "" - - original_length = len(req.system_prompt) - req.system_prompt += '\n\n' + prompt_injection_text - final_length = len(req.system_prompt) - injected_length = final_length - original_length - - logger.info(f"✅ [LLM Hook] System Prompt 注入完成 - 原长度: {original_length}, 新增: {injected_length}, 总长度: {final_length}") - logger.info(f"💡 [LLM Hook] 注入位置: system_prompt (不会被保存到对话历史)") - - else: - # 注入到 prompt(旧版行为,会导致对话历史膨胀) - original_length = len(req.prompt) - req.prompt += '\n\n' + prompt_injection_text - final_length = len(req.prompt) - injected_length = final_length - original_length - - logger.info(f"✅ [LLM Hook] Prompt 注入完成 - 原长度: {original_length}, 新增: {injected_length}, 总长度: {final_length}") - logger.warning(f"⚠️ [LLM Hook] 注入位置: prompt (会被保存到对话历史,可能导致token超限)") - - # 统计和日志 - current_language_style = self.diversity_manager.get_current_style() - current_response_pattern = self.diversity_manager.get_current_pattern() - - logger.info(f"✅ [LLM Hook] 当前语言风格: {current_language_style}, 回复模式: {current_response_pattern}") - logger.info(f"✅ [LLM Hook] 注入内容数量: {len(prompt_injections)}项") - logger.debug(f"✅ [LLM Hook] 注入内容预览: {prompt_injection_text[:200]}...") - else: - logger.debug("[LLM Hook] 没有可注入的增量内容") - - except Exception as e: - logger.error(f"❌ [LLM Hook] 框架层面注入多样性失败: {e}", exc_info=True) - - async def terminate(self): - """插件卸载时的清理工作 - 增强后台任务管理""" - try: - logger.info("开始插件清理工作...") - - # 1. 
停止所有学习任务 - logger.info("停止所有学习任务...") - for group_id, task in list(self.learning_tasks.items()): - try: - # 先停止学习流程 - await self.progressive_learning.stop_learning() - - # 取消学习任务 - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - logger.info(f"群组 {group_id} 学习任务已停止") - except Exception as e: - logger.error(f"停止群组 {group_id} 学习任务失败: {e}") - - self.learning_tasks.clear() - - # 2. 停止学习调度器 - if hasattr(self, 'learning_scheduler'): - try: - await self.learning_scheduler.stop() - logger.info("学习调度器已停止") - except Exception as e: - logger.error(f"停止学习调度器失败: {e}") - - # 3. 取消所有后台任务 - logger.info("取消所有后台任务...") - for task in list(self.background_tasks): - try: - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(LogMessages.BACKGROUND_TASK_CANCEL_ERROR.format(error=e)) - - self.background_tasks.clear() - - # 4. 停止所有服务 - logger.info("停止所有服务...") - if hasattr(self, 'factory_manager'): - try: - await self.factory_manager.cleanup() - logger.info("服务工厂已清理") - except Exception as e: - logger.error(f"清理服务工厂失败: {e}") - - # 4.5 重置单例管理器,确保重启时重新初始化 - try: - from .services.memory_graph_manager import MemoryGraphManager - MemoryGraphManager._instance = None - MemoryGraphManager._initialized = False - logger.info("MemoryGraphManager 单例已重置") - except Exception: - pass - - # 5. 清理临时人格 - if hasattr(self, 'temporary_persona_updater'): - try: - await self.temporary_persona_updater.cleanup_temp_personas() - logger.info("临时人格已清理") - except Exception as e: - logger.error(f"清理临时人格失败: {e}") - - # 6. 保存最终状态 - if hasattr(self, 'message_collector'): - try: - await self.message_collector.save_state() - logger.info("消息收集器状态已保存") - except Exception as e: - logger.error(f"保存消息收集器状态失败: {e}") - - # 7. 
停止 Web 服务器 (终极修正) - global server_instance, _server_cleanup_lock - async with _server_cleanup_lock: - if server_instance: - try: - logger.info(f"正在停止Web服务器 (端口: {server_instance.port})...") - - # [A] 停止服务 (跨线程通知退出) - await server_instance.stop() - - # [B] 关键新增:强制垃圾回收 - # 确保 Socket 句柄立即释放,而不是等待 Python 自动回收 - # 这对 Windows 这种 Socket 敏感的系统至关重要 - import gc - gc.collect() - - # [C] 平台差异化等待 - import sys - if sys.platform == 'win32': - logger.info("Windows环境:等待端口资源释放...") - # Windows 需要给内核一点时间把 TIME_WAIT 清理掉 - await asyncio.sleep(2.0) - - server_instance = None - logger.info("Web服务器实例已清理") - except Exception as e: - logger.error(f"停止Web服务器失败: {e}", exc_info=True) - server_instance = None - - # 8. 保存配置到文件 - try: - config_path = os.path.join(self.plugin_config.data_dir, 'config.json') - with open(config_path, 'w', encoding='utf-8') as f: - json.dump(self.plugin_config.to_dict(), f, ensure_ascii=False, indent=2) - logger.info(LogMessages.PLUGIN_CONFIG_SAVED) - except Exception as e: - logger.error(f"保存配置失败: {e}") - - logger.info(LogMessages.PLUGIN_UNLOAD_SUCCESS) - - except Exception as e: - logger.error(LogMessages.PLUGIN_UNLOAD_CLEANUP_FAILED.format(error=e), exc_info=True) - - async def _get_active_persona_prompt(self, event: AstrMessageEvent) -> Optional[str]: - """ - 获取当前会话配置的人格提示词 - - 优先读取 AstrBot 框架中的会话 -> 人格映射,回退到默认人格 - """ - try: - if not event or not hasattr(self, "context"): - return None - - conv_manager = getattr(self.context, "conversation_manager", None) - astr_persona_manager = getattr(self.context, "persona_manager", None) - if not conv_manager or not astr_persona_manager: - return None - - unified_origin = getattr(event, "unified_msg_origin", None) - if not unified_origin: - return None - - conv_id = await conv_manager.get_curr_conversation_id(unified_origin) - if not conv_id: - conv_id = await conv_manager.new_conversation(unified_origin) - - conv = await conv_manager.get_conversation( - unified_msg_origin=unified_origin, - conversation_id=conv_id, - 
create_if_not_exists=True, - ) - - persona_id = None - if conv: - conv_persona_id = getattr(conv, "persona_id", None) - if conv_persona_id and conv_persona_id != "[%None]": - persona_id = conv_persona_id - - persona_data = None - if persona_id: - persona_data = await astr_persona_manager.get_persona(persona_id) - else: - persona_data = await astr_persona_manager.get_default_persona_v3(umo=unified_origin) - - if not persona_data: - return None - - if isinstance(persona_data, dict): - return persona_data.get("system_prompt") or persona_data.get("prompt") - - return getattr(persona_data, "system_prompt", None) - - except Exception as exc: - logger.warning(f"获取会话人格失败: {exc}") - return None - - def _format_communication_style(self, communication_style: dict) -> str: - """ - 将沟通风格字典转换为可读描述 - - Args: - communication_style: 沟通风格字典 - - Returns: - str: 可读的描述文本 - """ - try: - if not communication_style or not isinstance(communication_style, dict): - return "" - - descriptions = [] - - # 解析各种沟通风格特征 - if 'formality' in communication_style: - formality = communication_style['formality'] - if formality > 0.7: - descriptions.append("正式礼貌") - elif formality < 0.3: - descriptions.append("随意轻松") - else: - descriptions.append("适中得体") - - if 'enthusiasm' in communication_style: - enthusiasm = communication_style['enthusiasm'] - if enthusiasm > 0.7: - descriptions.append("热情活跃") - elif enthusiasm < 0.3: - descriptions.append("冷静内敛") - - if 'directness' in communication_style: - directness = communication_style['directness'] - if directness > 0.7: - descriptions.append("直接坦率") - elif directness < 0.3: - descriptions.append("委婉含蓄") - - if 'humor_usage' in communication_style: - humor = communication_style['humor_usage'] - if humor > 0.6: - descriptions.append("幽默风趣") - - if 'emoji_usage' in communication_style: - emoji = communication_style['emoji_usage'] - if emoji > 0.6: - descriptions.append("表情丰富") - - return ",".join(descriptions) if descriptions else "普通交流风格" - - except Exception as 
e: - logger.debug(f"格式化沟通风格失败: {e}") - return "" - - def _format_emotional_tendency(self, emotional_tendency: dict) -> str: - """ - 将情感倾向字典转换为可读描述 - - Args: - emotional_tendency: 情感倾向字典 - - Returns: - str: 可读的描述文本 - """ - try: - if not emotional_tendency or not isinstance(emotional_tendency, dict): - return "" - - descriptions = [] - - # 解析情感倾向特征 - if 'positivity' in emotional_tendency: - positivity = emotional_tendency['positivity'] - if positivity > 0.7: - descriptions.append("积极乐观") - elif positivity < 0.3: - descriptions.append("情绪较低") - - if 'stability' in emotional_tendency: - stability = emotional_tendency['stability'] - if stability > 0.7: - descriptions.append("情绪稳定") - elif stability < 0.3: - descriptions.append("情绪波动") - - if 'empathy' in emotional_tendency: - empathy = emotional_tendency['empathy'] - if empathy > 0.6: - descriptions.append("善解人意") - - if 'expressiveness' in emotional_tendency: - expressiveness = emotional_tendency['expressiveness'] - if expressiveness > 0.6: - descriptions.append("表达丰富") - elif expressiveness < 0.3: - descriptions.append("表达内敛") - - if 'dominant_emotion' in emotional_tendency: - dominant = emotional_tendency['dominant_emotion'] - emotion_map = { - 'happy': '快乐', - 'calm': '平静', - 'excited': '兴奋', - 'serious': '严肃', - 'playful': '活泼', - 'thoughtful': '深思', - 'caring': '关怀' - } - if dominant in emotion_map: - descriptions.append(f"偏向{emotion_map[dominant]}") - - return ",".join(descriptions) if descriptions else "情感表达平和" - - except Exception as e: - logger.debug(f"格式化情感倾向失败: {e}") - return "" + """手动设置bot情绪""" + async for result in self._command_handlers.set_mood(event): + yield result diff --git a/metadata.yaml b/metadata.yaml index ddc4f6b..79bb374 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning" author: "NickMo" display_name: "self-learning" description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)" -version: 
"Next-1.2.9" +version: "Next-2.0.0" repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning" tags: - "自学习" diff --git a/models/orm/__init__.py b/models/orm/__init__.py index 3020532..8de2bf5 100644 --- a/models/orm/__init__.py +++ b/models/orm/__init__.py @@ -19,13 +19,18 @@ PsychologicalStateHistory, PersonaDiversityScore, PersonaAttributeWeight, - PersonaEvolutionSnapshot + PersonaEvolutionSnapshot, + EmotionProfile, + BotMood, + PersonaBackup, ) from .social_relation import ( SocialRelation, UserSocialProfile, UserSocialRelationComponent, - SocialRelationHistory + SocialRelationHistory, + UserProfile, + UserPreferences, ) from .social_analysis import ( SocialRelationAnalysisResult, @@ -45,7 +50,10 @@ from .expression import ( ExpressionPattern, ExpressionGenerationResult, - AdaptiveResponseTemplate + AdaptiveResponseTemplate, + StyleProfile, + StyleLearningRecord, + LanguageStylePattern, ) from .performance import ( LearningPerformanceHistory @@ -76,35 +84,43 @@ KGRelation, KGParagraphHash ) +from .exemplar import ( + Exemplar +) __all__ = [ 'Base', - # 好感度系统 + # Affection 'UserAffection', 'AffectionInteraction', 'UserConversationHistory', 'UserDiversity', - # 记忆系统 + # Memory 'Memory', 'MemoryEmbedding', 'MemorySummary', - # 心理状态系统 + # Psychological 'CompositePsychologicalState', 'PsychologicalStateComponent', 'PsychologicalStateHistory', 'PersonaDiversityScore', 'PersonaAttributeWeight', 'PersonaEvolutionSnapshot', - # 社交关系系统 + 'EmotionProfile', + 'BotMood', + 'PersonaBackup', + # Social 'SocialRelation', 'UserSocialProfile', 'UserSocialRelationComponent', 'SocialRelationHistory', - # 社交分析 + 'UserProfile', + 'UserPreferences', + # Social analysis 'SocialRelationAnalysisResult', 'SocialNetworkNode', 'SocialNetworkEdge', - # 学习系统 + # Learning 'PersonaLearningReview', 'StyleLearningReview', 'StyleLearningPattern', @@ -113,13 +129,16 @@ 'LearningSession', 'LearningReinforcementFeedback', 'LearningOptimizationLog', - # 表达模式 + # Expression 
'ExpressionPattern', 'ExpressionGenerationResult', 'AdaptiveResponseTemplate', - # 性能记录 + 'StyleProfile', + 'StyleLearningRecord', + 'LanguageStylePattern', + # Performance 'LearningPerformanceHistory', - # 消息系统 + # Message 'RawMessage', 'FilteredMessage', 'BotMessage', @@ -127,17 +146,19 @@ 'ConversationTopicClustering', 'ConversationQualityMetrics', 'ContextSimilarityCache', - # 黑话系统 + # Jargon 'Jargon', 'JargonUsageFrequency', - # 对话目标系统 + # Conversation goal 'ConversationGoal', - # 强化学习系统 + # Reinforcement learning 'ReinforcementLearningResult', 'PersonaFusionHistory', 'StrategyOptimizationResult', - # 知识图谱系统 + # Knowledge graph 'KGEntity', 'KGRelation', 'KGParagraphHash', + # Exemplar + 'Exemplar', ] diff --git a/models/orm/exemplar.py b/models/orm/exemplar.py new file mode 100644 index 0000000..087e37f --- /dev/null +++ b/models/orm/exemplar.py @@ -0,0 +1,60 @@ +""" +Exemplar ORM model. + +Stores high-quality message examples used for few-shot style imitation. +Each exemplar captures the original text along with its embedding vector +for similarity-based retrieval. +""" + +import time + +from sqlalchemy import ( + BigInteger, + Column, + Float, + Index, + Integer, + String, + Text, +) +from sqlalchemy.dialects.mysql import MEDIUMTEXT + +from .base import Base + +# MEDIUMTEXT on MySQL (16 MB), plain TEXT on SQLite (no size limit). +# Required for high-dimensional embedding vectors (e.g. 3072-dim ≈ 69 KB JSON). +_EmbeddingText = Text().with_variant(MEDIUMTEXT(), "mysql") + + +class Exemplar(Base): + """Few-shot style exemplar record. + + Attributes: + id: Auto-increment primary key. + content: The original message text serving as style example. + sender_id: ID of the message sender. + group_id: Chat group identifier. + embedding_json: Serialised embedding vector (JSON float array). + weight: Quality weight (adjusted by feedback, default 1.0). + dimensions: Embedding vector dimensionality (for validation). + created_at: Unix timestamp of record creation. 
+ updated_at: Unix timestamp of last update. + """ + + __tablename__ = "exemplar" + + id = Column(Integer, primary_key=True, autoincrement=True) + content = Column(Text, nullable=False) + sender_id = Column(String(255), nullable=True) + group_id = Column(String(255), nullable=False) + embedding_json = Column(_EmbeddingText, nullable=True) + weight = Column(Float, default=1.0) + dimensions = Column(Integer, default=0) + created_at = Column(BigInteger, nullable=False, default=lambda: int(time.time())) + updated_at = Column(BigInteger, nullable=False, default=lambda: int(time.time())) + + __table_args__ = ( + Index("idx_exemplar_group_id", "group_id"), + Index("idx_exemplar_weight", "weight"), + Index("idx_exemplar_group_weight", "group_id", "weight"), + ) diff --git a/models/orm/expression.py b/models/orm/expression.py index 5011b2c..d296f4d 100644 --- a/models/orm/expression.py +++ b/models/orm/expression.py @@ -112,3 +112,59 @@ def to_dict(self): 'created_at': self.created_at.isoformat() if self.created_at else None } + +class StyleProfile(Base): + """Aggregate style profile for a persona or learning context.""" + __tablename__ = 'style_profiles' + + id = Column(Integer, primary_key=True, autoincrement=True) + profile_name = Column(String(255), nullable=False) + vocabulary_richness = Column(Float) + sentence_complexity = Column(Float) + emotional_expression = Column(Float) + interaction_tendency = Column(Float) + topic_diversity = Column(Float) + formality_level = Column(Float) + creativity_score = Column(Float) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_style_profile_name', 'profile_name'), + ) + + +class StyleLearningRecord(Base): + """Record of a style learning session.""" + __tablename__ = 'style_learning_records' + + id = Column(Integer, primary_key=True, autoincrement=True) + style_type = Column(String(100), nullable=False) + learned_patterns = Column(Text) # JSON + confidence_score = Column(Float) + sample_count = 
Column(Integer) + last_updated = Column(Float) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_style_record_type', 'style_type'), + ) + + +class LanguageStylePattern(Base): + """Reusable language style pattern with example phrases.""" + __tablename__ = 'language_style_patterns' + + id = Column(Integer, primary_key=True, autoincrement=True) + language_style = Column(String(255), nullable=False) + example_phrases = Column(Text) # JSON + usage_frequency = Column(Integer, default=0) + context_type = Column(String(100), default='general') + confidence_score = Column(Float) + last_updated = Column(Float) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_lang_style', 'language_style'), + Index('idx_lang_context', 'context_type'), + ) + diff --git a/models/orm/learning.py b/models/orm/learning.py index 79da8f5..9633029 100644 --- a/models/orm/learning.py +++ b/models/orm/learning.py @@ -11,18 +11,18 @@ class PersonaLearningReview(Base): __tablename__ = 'persona_update_reviews' id = Column(Integer, primary_key=True, autoincrement=True) - timestamp = Column(Float, nullable=False) # 使用 REAL/Float 以匹配传统数据库 + timestamp = Column(Float, nullable=False) # 使用 REAL/Float 以匹配传统数据库 group_id = Column(String(255), nullable=False, index=True) - update_type = Column(String(255), nullable=False) # personality_trait, background_story, speaking_style, etc. + update_type = Column(String(255), nullable=False) # personality_trait, background_story, speaking_style, etc. 
original_content = Column(Text) new_content = Column(Text) - proposed_content = Column(Text) # 建议的新内容(兼容字段) - confidence_score = Column(Float) # 置信度得分 - reason = Column(Text) # 学习原因 - status = Column(String(50), default='pending', nullable=False) # pending/approved/rejected + proposed_content = Column(Text) # 建议的新内容(兼容字段) + confidence_score = Column(Float) # 置信度得分 + reason = Column(Text) # 学习原因 + status = Column(String(50), default='pending', nullable=False) # pending/approved/rejected reviewer_comment = Column(Text) - review_time = Column(Float) # 使用 REAL/Float 以匹配传统数据库 - metadata_ = Column('metadata', Text) # JSON格式的元数据,使用 metadata_ 避免与 SQLAlchemy 保留字冲突 + review_time = Column(Float) # 使用 REAL/Float 以匹配传统数据库 + metadata_ = Column('metadata', Text) # JSON格式的元数据,使用 metadata_ 避免与 SQLAlchemy 保留字冲突 __table_args__ = ( Index('idx_group_persona_review', 'group_id', 'status'), @@ -36,19 +36,19 @@ class StyleLearningReview(Base): __tablename__ = 'style_learning_reviews' id = Column(Integer, primary_key=True, autoincrement=True) - type = Column(String(100), nullable=False) # 学习类型 + type = Column(String(100), nullable=False) # 学习类型 group_id = Column(String(255), nullable=False, index=True) - timestamp = Column(Float, nullable=False) # 使用 REAL/Float 以匹配传统数据库 - learned_patterns = Column(Text) # JSON格式存储学习的模式 - few_shots_content = Column(Text) # Few-shot 示例内容 - status = Column(String(50), default='pending') # pending/approved/rejected - description = Column(Text) # 描述信息 - reviewer_comment = Column(Text) # 审查评论 - review_time = Column(Float) # 审查时间 - # ✅ 修改为 DateTime 类型以兼容 MySQL 的 DATETIME + timestamp = Column(Float, nullable=False) # 使用 REAL/Float 以匹配传统数据库 + learned_patterns = Column(Text) # JSON格式存储学习的模式 + few_shots_content = Column(Text) # Few-shot 示例内容 + status = Column(String(50), default='pending') # pending/approved/rejected + description = Column(Text) # 描述信息 + reviewer_comment = Column(Text) # 审查评论 + review_time = Column(Float) # 审查时间 + # 修改为 DateTime 类型以兼容 MySQL 的 
DATETIME # SQLite 使用 TIMESTAMP,MySQL 使用 DATETIME,SQLAlchemy 的 DateTime 可以自动适配 - created_at = Column(DateTime) # 创建时间 - updated_at = Column(DateTime) # 更新时间 + created_at = Column(DateTime) # 创建时间 + updated_at = Column(DateTime) # 更新时间 __table_args__ = ( Index('idx_status', 'status'), @@ -65,9 +65,9 @@ class StyleLearningPattern(Base): group_id = Column(String(100), nullable=False, index=True) pattern_type = Column(String(50), nullable=False) pattern = Column(Text, nullable=False) - usage_count = Column(Integer, default=0) # 使用次数 - confidence = Column(Float, default=1.0) # 置信度 - last_used = Column(BigInteger) # 最后使用时间 + usage_count = Column(Integer, default=0) # 使用次数 + confidence = Column(Float, default=1.0) # 置信度 + last_used = Column(BigInteger) # 最后使用时间 created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) @@ -85,7 +85,7 @@ class InteractionRecord(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(100), nullable=False, index=True) user_id = Column(String(100), nullable=False, index=True) - interaction_type = Column(String(50), nullable=False) # message, reaction, mention, etc. + interaction_type = Column(String(50), nullable=False) # message, reaction, mention, etc. 
content_preview = Column(String(200)) timestamp = Column(BigInteger, nullable=False) @@ -148,7 +148,7 @@ class LearningSession(Base): id = Column(Integer, primary_key=True, autoincrement=True) session_id = Column(String(255), unique=True, nullable=False, index=True) group_id = Column(String(255), nullable=False, index=True) - batch_id = Column(String(255), nullable=True) # 外键到 learning_batches.batch_id + batch_id = Column(String(255), nullable=True) # 外键到 learning_batches.batch_id start_time = Column(Float, nullable=False) end_time = Column(Float, nullable=True) message_count = Column(Integer, default=0) @@ -184,10 +184,10 @@ class LearningReinforcementFeedback(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True) - feedback_type = Column(String(100), nullable=False) # positive, negative, neutral - feedback_content = Column(Text, nullable=True) # 详细反馈内容 - effectiveness_score = Column(Float, nullable=True) # 反馈有效性评分 - applied_at = Column(Float, nullable=False) # 应用时间戳 + feedback_type = Column(String(100), nullable=False) # positive, negative, neutral + feedback_content = Column(Text, nullable=True) # 详细反馈内容 + effectiveness_score = Column(Float, nullable=True) # 反馈有效性评分 + applied_at = Column(Float, nullable=False) # 应用时间戳 created_at = Column(DateTime, default=func.now()) __table_args__ = ( @@ -215,12 +215,12 @@ class LearningOptimizationLog(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True) - optimization_type = Column(String(100), nullable=False) # parameter_tuning, strategy_adjustment, etc. 
- parameters = Column(Text, nullable=True) # JSON格式的参数配置 - before_metrics = Column(Text, nullable=True) # JSON格式的优化前指标 - after_metrics = Column(Text, nullable=True) # JSON格式的优化后指标 - improvement_rate = Column(Float, nullable=True) # 改进率 - applied_at = Column(Float, nullable=False) # 应用时间戳 + optimization_type = Column(String(100), nullable=False) # parameter_tuning, strategy_adjustment, etc. + parameters = Column(Text, nullable=True) # JSON格式的参数配置 + before_metrics = Column(Text, nullable=True) # JSON格式的优化前指标 + after_metrics = Column(Text, nullable=True) # JSON格式的优化后指标 + improvement_rate = Column(Float, nullable=True) # 改进率 + applied_at = Column(Float, nullable=False) # 应用时间戳 created_at = Column(DateTime, default=func.now()) __table_args__ = ( diff --git a/models/orm/psychological.py b/models/orm/psychological.py index d9f0fb5..471d73c 100644 --- a/models/orm/psychological.py +++ b/models/orm/psychological.py @@ -14,8 +14,8 @@ class CompositePsychologicalState(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True, unique=True) state_id = Column(String(255), nullable=False, unique=True) - triggering_events = Column(Text) # JSON 格式 - context = Column(Text) # JSON 格式 + triggering_events = Column(Text) # JSON 格式 + context = Column(Text) # JSON 格式 created_at = Column(BigInteger, nullable=False) last_updated = Column(BigInteger, nullable=False) @@ -32,7 +32,7 @@ class PsychologicalStateComponent(Base): __tablename__ = 'psychological_state_components' id = Column(Integer, primary_key=True, autoincrement=True) - composite_state_id = Column(Integer, ForeignKey('composite_psychological_states.id'), nullable=True) # ✅ 允许 NULL 兼容传统数据 + composite_state_id = Column(Integer, ForeignKey('composite_psychological_states.id'), nullable=True) # 允许 NULL 兼容传统数据 group_id = Column(String(255), nullable=False, index=True) state_id = Column(String(255), nullable=False, index=True) category = Column(String(50), 
nullable=False) @@ -80,9 +80,9 @@ class PersonaDiversityScore(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True) persona_id = Column(String(255), nullable=False, index=True) - diversity_dimension = Column(String(100), nullable=False) # emotion, topic, style, etc. - score = Column(Float, nullable=False) # 多样性分数 0-1 - calculated_at = Column(Float, nullable=False) # 计算时间戳 + diversity_dimension = Column(String(100), nullable=False) # emotion, topic, style, etc. + score = Column(Float, nullable=False) # 多样性分数 0-1 + calculated_at = Column(Float, nullable=False) # 计算时间戳 created_at = Column(DateTime, default=func.now()) __table_args__ = ( @@ -112,10 +112,10 @@ class PersonaAttributeWeight(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True) persona_id = Column(String(255), nullable=False, index=True) - attribute_name = Column(String(100), nullable=False) # 属性名称 - weight = Column(Float, nullable=False) # 权重值 0-1 - adjustment_reason = Column(Text, nullable=True) # 调整原因 - updated_at = Column(Float, nullable=False) # 更新时间戳 + attribute_name = Column(String(100), nullable=False) # 属性名称 + weight = Column(Float, nullable=False) # 权重值 0-1 + adjustment_reason = Column(Text, nullable=True) # 调整原因 + updated_at = Column(Float, nullable=False) # 更新时间戳 created_at = Column(DateTime, default=func.now()) __table_args__ = ( @@ -146,10 +146,10 @@ class PersonaEvolutionSnapshot(Base): id = Column(Integer, primary_key=True, autoincrement=True) group_id = Column(String(255), nullable=False, index=True) persona_id = Column(String(255), nullable=False, index=True) - snapshot_data = Column(Text, nullable=False) # JSON格式的完整人格状态 - version = Column(Integer, nullable=False) # 版本号 - snapshot_timestamp = Column(Float, nullable=False) # 快照时间戳 - trigger_event = Column(Text, nullable=True) # 触发事件描述 + snapshot_data = Column(Text, nullable=False) # JSON格式的完整人格状态 + 
version = Column(Integer, nullable=False) # 版本号 + snapshot_timestamp = Column(Float, nullable=False) # 快照时间戳 + trigger_event = Column(Text, nullable=True) # 触发事件描述 created_at = Column(DateTime, default=func.now()) __table_args__ = ( @@ -171,3 +171,60 @@ def to_dict(self): 'trigger_event': self.trigger_event, 'created_at': self.created_at.isoformat() if self.created_at else None } + + +class EmotionProfile(Base): + """Emotion profile per user per group.""" + __tablename__ = 'emotion_profiles' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(String(255), nullable=False, index=True) + group_id = Column(String(255), nullable=False, index=True) + dominant_emotions = Column(Text) # JSON + emotion_patterns = Column(Text) # JSON + empathy_level = Column(Float, default=0.5) + emotional_stability = Column(Float, default=0.5) + last_updated = Column(Float, nullable=False) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_emotion_user_group', 'user_id', 'group_id', unique=True), + ) + + +class BotMood(Base): + """Bot mood state per group.""" + __tablename__ = 'bot_mood' + + id = Column(Integer, primary_key=True, autoincrement=True) + group_id = Column(String(255), nullable=False, index=True) + mood_type = Column(String(100), nullable=False) + mood_intensity = Column(Float, default=0.5) + mood_description = Column(Text) + start_time = Column(Float, nullable=False) + end_time = Column(Float) + is_active = Column(Integer, default=1) # Boolean as int for SQLite compat + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_mood_group_active', 'group_id', 'is_active'), + ) + + +class PersonaBackup(Base): + """Persona configuration backup.""" + __tablename__ = 'persona_backups' + + id = Column(Integer, primary_key=True, autoincrement=True) + backup_name = Column(String(255), nullable=False) + timestamp = Column(Float, nullable=False) + reason = Column(Text) + persona_config = 
Column(Text) # JSON + original_persona = Column(Text) # JSON + imitation_dialogues = Column(Text) # JSON + backup_reason = Column(Text) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_backup_timestamp', 'timestamp'), + ) diff --git a/models/orm/social_relation.py b/models/orm/social_relation.py index 381f3d5..fa16f49 100644 --- a/models/orm/social_relation.py +++ b/models/orm/social_relation.py @@ -1,8 +1,9 @@ """ 社交关系系统相关的 ORM 模型 """ -from sqlalchemy import Column, Integer, String, Text, Float, Index, BigInteger, ForeignKey +from sqlalchemy import Column, Integer, String, Text, Float, Index, BigInteger, ForeignKey, DateTime from sqlalchemy.orm import relationship +from sqlalchemy.sql import func from .base import Base @@ -111,3 +112,38 @@ class SocialRelationHistory(Base): Index('idx_social_history_from_to', 'from_user_id', 'to_user_id', 'group_id'), Index('idx_social_history_timestamp', 'timestamp'), ) + + +class UserProfile(Base): + """User profile with JSON-stored behavioral data.""" + __tablename__ = 'user_profiles' + + qq_id = Column(String(255), primary_key=True) + qq_name = Column(String(255)) + nicknames = Column(Text) # JSON + activity_pattern = Column(Text) # JSON + communication_style = Column(Text) # JSON + topic_preferences = Column(Text) # JSON + emotional_tendency = Column(Text) # JSON + last_active = Column(Float) + created_at = Column(DateTime, default=func.now()) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) + + +class UserPreferences(Base): + """User learning/interaction preferences per group.""" + __tablename__ = 'user_preferences' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(String(255), nullable=False, index=True) + group_id = Column(String(255), nullable=False, index=True) + favorite_topics = Column(Text) # JSON + interaction_style = Column(Text) # JSON + learning_preferences = Column(Text) # JSON + adaptive_rate = Column(Float, 
default=0.5) + updated_at = Column(Float, nullable=False) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('idx_pref_user_group', 'user_id', 'group_id', unique=True), + ) diff --git a/models/psychological_state.py b/models/psychological_state.py index 4182376..16e27ff 100644 --- a/models/psychological_state.py +++ b/models/psychological_state.py @@ -8,7 +8,7 @@ import time -# ==================== 情绪情感类心理状态 ==================== +# 情绪情感类心理状态 class EmotionPositiveType(Enum): """积极情绪类型""" @@ -133,7 +133,7 @@ class EmotionSpecialType(Enum): MIXED_FEELINGS = "百感交集" -# ==================== 认知类心理状态 ==================== +# 认知类心理状态 class AttentionState(Enum): """注意力状态""" @@ -243,7 +243,7 @@ class DecisionState(Enum): FOLLOWING_CROWD = "随波逐流" -# ==================== 意志与行为倾向类心理状态 ==================== +# 意志与行为倾向类心理状态 class WillStrengthState(Enum): """意志强度状态""" @@ -320,7 +320,7 @@ class GoalOrientationState(Enum): UTILITARIAN = "功利性" -# ==================== 自我认知与人格倾向类心理状态 ==================== +# 自我认知与人格倾向类心理状态 class SelfAcceptanceState(Enum): """自我接纳状态""" @@ -404,7 +404,7 @@ class PersonalityTendencyState(Enum): ADAPTABLE = "灵活应变" -# ==================== 社交互动类心理状态 ==================== +# 社交互动类心理状态 class SocialAttitudeState(Enum): """社交态度状态""" @@ -492,7 +492,7 @@ class InterpersonalRoleState(Enum): EQUAL_DIALOGUE = "平等对话" -# ==================== 适应与应激类心理状态 ==================== +# 适应与应激类心理状态 class EnvironmentalAdaptationState(Enum): """环境适应状态""" @@ -559,7 +559,7 @@ class BodyMindCoordinationState(Enum): PSYCHOSOMATIC = "心因性躯体症状" -# ==================== 其他维度心理状态 ==================== +# 其他维度心理状态 class EnergyState(Enum): """精力状态""" @@ -611,7 +611,7 @@ class TimePerceptionState(Enum): STEADY_PACE = "按部就班" -# ==================== 复合心理状态 ==================== +# 复合心理状态 @dataclass class PsychologicalStateComponent: @@ -720,7 +720,7 @@ def to_prompt_injection(self) -> str: return "\n".join(prompt_parts) -# ==================== 状态转换规则 ==================== 
+# 状态转换规则 @dataclass class StateTransitionRule: diff --git a/models/social_relation.py b/models/social_relation.py index e00007f..15b7beb 100644 --- a/models/social_relation.py +++ b/models/social_relation.py @@ -8,7 +8,7 @@ import time -# ==================== 核心联结基础类关系 ==================== +# 核心联结基础类关系 class BloodRelationType(Enum): """血缘关系类型""" @@ -155,7 +155,7 @@ class InterestRelationType(Enum): COMPANION = "搭子关系" -# ==================== 按亲密度与情感深度分类 ==================== +# 按亲密度与情感深度分类 class IntimacyLevel(Enum): """亲密度等级""" @@ -175,7 +175,7 @@ class IntimacyLevel(Enum): AVOIDANT = "回避型疏远" # 有矛盾、刻意保持距离 -# ==================== 按社会功能与互动场景分类 ==================== +# 按社会功能与互动场景分类 class FamilyRelationType(Enum): """家庭关系类型""" @@ -216,7 +216,7 @@ class PublicRelationType(Enum): STRANGER_INTERACTION = "陌生人互动" # 超市收银员、公交司机 -# ==================== 按法律与契约属性分类 ==================== +# 按法律与契约属性分类 class LegalRelationType(Enum): """法定关系类型""" @@ -239,7 +239,7 @@ class NonContractualRelationType(Enum): TEMPORARY_RELATION = "临时类" # 同车乘客、活动参与者 -# ==================== 按其他关键维度分类 ==================== +# 按其他关键维度分类 class RelationDuration(Enum): """关系存续时间""" @@ -271,7 +271,7 @@ class CrossDimensional(Enum): CONFLICT = "冲突型关系" # 仇人 -# ==================== 社交关系数值化数据模型 ==================== +# 社交关系数值化数据模型 @dataclass class SocialRelationComponent: @@ -378,7 +378,7 @@ def to_prompt_injection(self) -> str: return "\n".join(prompt_parts) -# ==================== 社交关系变化规则 ==================== +# 社交关系变化规则 @dataclass class RelationChangeRule: diff --git a/repositories/__init__.py b/repositories/__init__.py index d3f749e..4e7f8df 100644 --- a/repositories/__init__.py +++ b/repositories/__init__.py @@ -71,6 +71,32 @@ AdaptiveResponseTemplateRepository ) +# --- Phase 1 新增 Repository --- + +# 原始消息/筛选消息/Bot消息 +from .raw_message_repository import RawMessageRepository +from .filtered_message_repository import FilteredMessageRepository +from .bot_message_repository import BotMessageRepository + +# 用户画像/偏好 
+from .user_profile_repository import UserProfileRepository +from .user_preferences_repository import UserPreferencesRepository + +# 情绪画像 / 风格画像 / Bot 情绪 +from .emotion_profile_repository import EmotionProfileRepository +from .style_profile_repository import StyleProfileRepository +from .bot_mood_repository import BotMoodRepository + +# 人格备份 +from .persona_backup_repository import PersonaBackupRepository + +# 知识图谱 +from .knowledge_graph_repository import ( + KnowledgeEntityRepository, + KnowledgeRelationRepository, + KnowledgeParagraphHashRepository +) + __all__ = [ # 基础 'BaseRepository', @@ -122,4 +148,28 @@ 'JargonUsageFrequencyRepository', 'ExpressionGenerationResultRepository', 'AdaptiveResponseTemplateRepository', + + # --- Phase 1 新增 (12个) --- + + # 消息三层 (3个) + 'RawMessageRepository', + 'FilteredMessageRepository', + 'BotMessageRepository', + + # 用户画像/偏好 (2个) + 'UserProfileRepository', + 'UserPreferencesRepository', + + # 情绪/风格/情绪状态 (3个) + 'EmotionProfileRepository', + 'StyleProfileRepository', + 'BotMoodRepository', + + # 人格备份 (1个) + 'PersonaBackupRepository', + + # 知识图谱 (3个) + 'KnowledgeEntityRepository', + 'KnowledgeRelationRepository', + 'KnowledgeParagraphHashRepository', ] diff --git a/repositories/bot_message_repository.py b/repositories/bot_message_repository.py new file mode 100644 index 0000000..4ad1f00 --- /dev/null +++ b/repositories/bot_message_repository.py @@ -0,0 +1,149 @@ +""" +Bot 消息 Repository — BotMessage 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, desc, func, delete +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.message import BotMessage + + +class BotMessageRepository(BaseRepository[BotMessage]): + """Bot 消息 Repository""" + + def __init__(self, session: AsyncSession): + super().__init__(session, BotMessage) + + async def save(self, message_data: Dict[str, Any]) -> 
Optional[BotMessage]: + """ + 保存一条 Bot 消息 + + Args: + message_data: 消息字段字典,包含 group_id, message, timestamp 等 + + Returns: + Optional[BotMessage]: 创建的记录 + """ + try: + now = int(time.time()) + return await self.create( + group_id=message_data.get('group_id', ''), + message=message_data.get('message', ''), + timestamp=message_data.get('timestamp', now), + created_at=now, + ) + except Exception as e: + logger.error(f"[BotMessageRepository] 保存 Bot 消息失败: {e}") + return None + + async def get_recent_responses( + self, + group_id: str, + limit: int = 50 + ) -> List[BotMessage]: + """ + 获取最近的 Bot 回复 + + Args: + group_id: 群组 ID + limit: 最大返回数量 + + Returns: + List[BotMessage]: Bot 消息列表(按时间倒序) + """ + try: + stmt = ( + select(BotMessage) + .where(BotMessage.group_id == group_id) + .order_by(desc(BotMessage.timestamp)) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[BotMessageRepository] 获取最近 Bot 回复失败: {e}") + return [] + + async def get_statistics(self, group_id: Optional[str] = None) -> Dict[str, Any]: + """ + 获取 Bot 消息统计信息 + + Args: + group_id: 群组 ID(为 None 时统计全部) + + Returns: + Dict: {"total": ..., "groups": ...} + """ + try: + # 总数 + total_stmt = select(func.count()).select_from(BotMessage) + if group_id: + total_stmt = total_stmt.where(BotMessage.group_id == group_id) + total_result = await self.session.execute(total_stmt) + total = total_result.scalar() or 0 + + # 按群组统计 + group_stmt = ( + select( + BotMessage.group_id, + func.count().label('count') + ) + .group_by(BotMessage.group_id) + .order_by(desc('count')) + ) + if group_id: + group_stmt = group_stmt.where(BotMessage.group_id == group_id) + + group_result = await self.session.execute(group_stmt) + groups = [ + {"group_id": row.group_id, "count": row.count} + for row in group_result.fetchall() + ] + + return {"total": total, "groups": groups} + except Exception as e: + logger.error(f"[BotMessageRepository] 获取统计信息失败: 
{e}") + return {"total": 0, "groups": []} + + async def count_all(self, group_id: Optional[str] = None) -> int: + """ + 统计 Bot 消息总数 + + Args: + group_id: 群组 ID(为 None 时统计全部) + + Returns: + int: 消息数量 + """ + try: + stmt = select(func.count()).select_from(BotMessage) + if group_id: + stmt = stmt.where(BotMessage.group_id == group_id) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[BotMessageRepository] 统计消息失败: {e}") + return 0 + + async def delete_by_group(self, group_id: str) -> int: + """ + 删除指定群组的所有 Bot 消息 + + Args: + group_id: 群组 ID + + Returns: + int: 删除的行数 + """ + try: + stmt = delete(BotMessage).where(BotMessage.group_id == group_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[BotMessageRepository] 删除群组 Bot 消息失败: {e}") + return 0 diff --git a/repositories/bot_mood_repository.py b/repositories/bot_mood_repository.py new file mode 100644 index 0000000..6dddb07 --- /dev/null +++ b/repositories/bot_mood_repository.py @@ -0,0 +1,178 @@ +""" +Bot 情绪 Repository — BotMood 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, update, and_, desc, func +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.psychological import BotMood + + +class BotMoodRepository(BaseRepository[BotMood]): + """Bot 情绪 Repository + + BotMood 使用 (group_id, is_active) 索引来快速查找当前情绪。 + 设置新情绪时需先将旧情绪设为非活跃。 + """ + + def __init__(self, session: AsyncSession): + super().__init__(session, BotMood) + + async def save(self, mood_data: Dict[str, Any]) -> Optional[BotMood]: + """ + 保存新情绪(自动将同群组的旧情绪设为非活跃) + + Args: + mood_data: 情绪字段字典,必须包含 group_id, mood_type + + Returns: + Optional[BotMood]: 创建的记录 + """ + group_id = mood_data.get('group_id') + if not group_id: + 
logger.error("[BotMoodRepository] 保存情绪失败: 缺少 group_id") + return None + + try: + # 先将该群组的活跃情绪设为非活跃 + deactivate_stmt = ( + update(BotMood) + .where(and_( + BotMood.group_id == group_id, + BotMood.is_active == 1, + )) + .values(is_active=0, end_time=time.time()) + ) + await self.session.execute(deactivate_stmt) + + # 创建新的活跃情绪 + mood_data.setdefault('start_time', time.time()) + mood_data.setdefault('is_active', 1) + mood = BotMood(**mood_data) + self.session.add(mood) + await self.session.commit() + await self.session.refresh(mood) + return mood + except Exception as e: + await self.session.rollback() + logger.error(f"[BotMoodRepository] 保存情绪失败: {e}") + return None + + async def get_current(self, group_id: str) -> Optional[BotMood]: + """ + 获取当前活跃情绪 + + Args: + group_id: 群组 ID + + Returns: + Optional[BotMood]: 当前情绪对象 + """ + try: + stmt = ( + select(BotMood) + .where(and_( + BotMood.group_id == group_id, + BotMood.is_active == 1, + )) + .order_by(desc(BotMood.start_time)) + .limit(1) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[BotMoodRepository] 获取当前情绪失败: {e}") + return None + + async def get_history( + self, + group_id: str, + limit: int = 20 + ) -> List[BotMood]: + """ + 获取情绪历史 + + Args: + group_id: 群组 ID + limit: 最大返回数量 + + Returns: + List[BotMood]: 情绪历史列表(按时间倒序) + """ + try: + stmt = ( + select(BotMood) + .where(BotMood.group_id == group_id) + .order_by(desc(BotMood.start_time)) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[BotMoodRepository] 获取情绪历史失败: {e}") + return [] + + async def deactivate_all(self, group_id: str) -> int: + """ + 将指定群组的所有活跃情绪设为非活跃 + + Args: + group_id: 群组 ID + + Returns: + int: 更新的行数 + """ + try: + stmt = ( + update(BotMood) + .where(and_( + BotMood.group_id == group_id, + BotMood.is_active == 1, + )) + .values(is_active=0, end_time=time.time()) + ) + result = 
await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[BotMoodRepository] 停用情绪失败: {e}") + return 0 + + async def get_mood_statistics(self, group_id: str) -> Dict[str, Any]: + """ + 获取情绪统计信息 + + Args: + group_id: 群组 ID + + Returns: + Dict: {"total": ..., "mood_distribution": {type: count, ...}} + """ + try: + total_stmt = select(func.count()).select_from(BotMood).where( + BotMood.group_id == group_id + ) + total_result = await self.session.execute(total_stmt) + total = total_result.scalar() or 0 + + dist_stmt = ( + select( + BotMood.mood_type, + func.count().label('count') + ) + .where(BotMood.group_id == group_id) + .group_by(BotMood.mood_type) + ) + dist_result = await self.session.execute(dist_stmt) + distribution = { + row.mood_type: row.count for row in dist_result.fetchall() + } + + return {"total": total, "mood_distribution": distribution} + except Exception as e: + logger.error(f"[BotMoodRepository] 获取情绪统计失败: {e}") + return {"total": 0, "mood_distribution": {}} diff --git a/repositories/emotion_profile_repository.py b/repositories/emotion_profile_repository.py new file mode 100644 index 0000000..05ad03d --- /dev/null +++ b/repositories/emotion_profile_repository.py @@ -0,0 +1,136 @@ +""" +情绪画像 Repository — EmotionProfile 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, and_, func +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.psychological import EmotionProfile + + +class EmotionProfileRepository(BaseRepository[EmotionProfile]): + """情绪画像 Repository + + EmotionProfile 以 (user_id, group_id) 唯一约束。 + """ + + def __init__(self, session: AsyncSession): + super().__init__(session, EmotionProfile) + + async def load(self, user_id: str, group_id: str) -> Optional[EmotionProfile]: + """ + 加载情绪画像 + + Args: + 
user_id: 用户 ID + group_id: 群组 ID + + Returns: + Optional[EmotionProfile]: 情绪画像对象 + """ + try: + stmt = select(EmotionProfile).where(and_( + EmotionProfile.user_id == user_id, + EmotionProfile.group_id == group_id, + )) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[EmotionProfileRepository] 加载情绪画像失败: {e}") + return None + + async def save(self, profile_data: Dict[str, Any]) -> Optional[EmotionProfile]: + """ + 保存情绪画像(upsert:存在则更新,不存在则创建) + + Args: + profile_data: 画像字段字典,必须包含 user_id 和 group_id + + Returns: + Optional[EmotionProfile]: 保存后的记录 + """ + user_id = profile_data.get('user_id') + group_id = profile_data.get('group_id') + if not user_id or not group_id: + logger.error("[EmotionProfileRepository] 保存画像失败: 缺少 user_id 或 group_id") + return None + + try: + existing = await self.load(user_id, group_id) + if existing: + for key, value in profile_data.items(): + if key not in ('user_id', 'group_id', 'id') and hasattr(existing, key): + setattr(existing, key, value) + existing.last_updated = time.time() + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + profile_data.setdefault('last_updated', time.time()) + profile = EmotionProfile(**profile_data) + self.session.add(profile) + await self.session.commit() + await self.session.refresh(profile) + return profile + except Exception as e: + await self.session.rollback() + logger.error(f"[EmotionProfileRepository] 保存情绪画像失败: {e}") + return None + + async def get_by_group(self, group_id: str) -> List[EmotionProfile]: + """ + 获取群组内所有情绪画像 + + Args: + group_id: 群组 ID + + Returns: + List[EmotionProfile]: 情绪画像列表 + """ + try: + stmt = select(EmotionProfile).where( + EmotionProfile.group_id == group_id + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[EmotionProfileRepository] 获取群组情绪画像失败: {e}") + return [] + + async def 
get_by_user(self, user_id: str) -> List[EmotionProfile]: + """ + 获取用户在所有群组的情绪画像 + + Args: + user_id: 用户 ID + + Returns: + List[EmotionProfile]: 情绪画像列表 + """ + try: + stmt = select(EmotionProfile).where( + EmotionProfile.user_id == user_id + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[EmotionProfileRepository] 获取用户情绪画像失败: {e}") + return [] + + async def count_all(self) -> int: + """ + 统计情绪画像总数 + + Returns: + int: 画像数量 + """ + try: + stmt = select(func.count()).select_from(EmotionProfile) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[EmotionProfileRepository] 统计画像失败: {e}") + return 0 diff --git a/repositories/filtered_message_repository.py b/repositories/filtered_message_repository.py new file mode 100644 index 0000000..7a195bf --- /dev/null +++ b/repositories/filtered_message_repository.py @@ -0,0 +1,221 @@ +""" +筛选后消息 Repository — FilteredMessage 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, update, and_, desc, func, delete +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.message import FilteredMessage + + +class FilteredMessageRepository(BaseRepository[FilteredMessage]): + """筛选后消息 Repository""" + + def __init__(self, session: AsyncSession): + super().__init__(session, FilteredMessage) + + async def add(self, message_data: Dict[str, Any]) -> Optional[FilteredMessage]: + """ + 添加一条筛选后的消息 + + Args: + message_data: 消息字段字典 + + Returns: + Optional[FilteredMessage]: 创建的记录 + """ + try: + now = int(time.time()) + return await self.create( + raw_message_id=message_data.get('raw_message_id'), + message=message_data.get('message', ''), + sender_id=message_data.get('sender_id', ''), + group_id=message_data.get('group_id', ''), + timestamp=message_data.get('timestamp', 
now), + confidence=message_data.get('confidence'), + quality_scores=message_data.get('quality_scores'), + filter_reason=message_data.get('filter_reason'), + created_at=now, + processed=False, + ) + except Exception as e: + logger.error(f"[FilteredMessageRepository] 添加筛选消息失败: {e}") + return None + + async def get_for_learning(self, limit: int = 200) -> List[FilteredMessage]: + """ + 获取待学习的筛选消息(未处理的) + + Args: + limit: 最大返回数量 + + Returns: + List[FilteredMessage]: 待学习消息列表(按时间升序) + """ + try: + stmt = ( + select(FilteredMessage) + .where(FilteredMessage.processed == False) # noqa: E712 + .order_by(FilteredMessage.timestamp.asc()) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[FilteredMessageRepository] 获取待学习消息失败: {e}") + return [] + + async def mark_processed(self, message_id: int) -> bool: + """ + 标记为已处理 + + Args: + message_id: 消息 ID + + Returns: + bool: 是否成功 + """ + try: + stmt = ( + update(FilteredMessage) + .where(FilteredMessage.id == message_id) + .values(processed=True) + ) + await self.session.execute(stmt) + await self.session.commit() + return True + except Exception as e: + await self.session.rollback() + logger.error(f"[FilteredMessageRepository] 标记已处理失败: {e}") + return False + + async def mark_batch_processed(self, message_ids: List[int]) -> int: + """ + 批量标记为已处理 + + Args: + message_ids: 消息 ID 列表 + + Returns: + int: 成功标记的数量 + """ + if not message_ids: + return 0 + try: + stmt = ( + update(FilteredMessage) + .where(FilteredMessage.id.in_(message_ids)) + .values(processed=True) + ) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[FilteredMessageRepository] 批量标记已处理失败: {e}") + return 0 + + async def get_recent( + self, + group_id: Optional[str] = None, + limit: int = 50 + ) -> List[FilteredMessage]: + """ + 获取最近的筛选消息 + + Args: + group_id: 群组 
ID(为 None 时不过滤) + limit: 最大返回数量 + + Returns: + List[FilteredMessage]: 消息列表(按时间倒序) + """ + try: + stmt = select(FilteredMessage) + if group_id: + stmt = stmt.where(FilteredMessage.group_id == group_id) + stmt = stmt.order_by(desc(FilteredMessage.timestamp)).limit(limit) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[FilteredMessageRepository] 获取最近筛选消息失败: {e}") + return [] + + async def count_all(self, group_id: Optional[str] = None) -> int: + """ + 统计消息总数 + + Args: + group_id: 群组 ID(为 None 时统计全部) + + Returns: + int: 消息数量 + """ + try: + stmt = select(func.count()).select_from(FilteredMessage) + if group_id: + stmt = stmt.where(FilteredMessage.group_id == group_id) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[FilteredMessageRepository] 统计消息失败: {e}") + return 0 + + async def delete_by_group(self, group_id: str) -> int: + """ + 删除指定群组的所有筛选消息 + + Args: + group_id: 群组 ID + + Returns: + int: 删除的行数 + """ + try: + stmt = delete(FilteredMessage).where(FilteredMessage.group_id == group_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[FilteredMessageRepository] 删除群组筛选消息失败: {e}") + return 0 + + async def get_by_confidence_range( + self, + group_id: str, + min_confidence: float = 0.0, + max_confidence: float = 1.0, + limit: int = 100 + ) -> List[FilteredMessage]: + """ + 按置信度范围获取消息 + + Args: + group_id: 群组 ID + min_confidence: 最小置信度 + max_confidence: 最大置信度 + limit: 最大返回数量 + + Returns: + List[FilteredMessage]: 消息列表 + """ + try: + stmt = ( + select(FilteredMessage) + .where(and_( + FilteredMessage.group_id == group_id, + FilteredMessage.confidence >= min_confidence, + FilteredMessage.confidence <= max_confidence, + )) + .order_by(desc(FilteredMessage.confidence)) + .limit(limit) + ) + result = await 
self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[FilteredMessageRepository] 按置信度获取消息失败: {e}") + return [] diff --git a/repositories/knowledge_graph_repository.py b/repositories/knowledge_graph_repository.py new file mode 100644 index 0000000..b526386 --- /dev/null +++ b/repositories/knowledge_graph_repository.py @@ -0,0 +1,332 @@ +""" +知识图谱 Repository — KGEntity / KGRelation / KGParagraphHash 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, and_, or_, desc, func, update +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.knowledge_graph import KGEntity, KGRelation, KGParagraphHash + + +class KnowledgeEntityRepository(BaseRepository[KGEntity]): + """知识图谱实体 Repository""" + + def __init__(self, session: AsyncSession): + super().__init__(session, KGEntity) + + async def save_entity( + self, + name: str, + group_id: str, + entity_type: str = 'general' + ) -> Optional[KGEntity]: + """ + 保存实体(upsert:已存在则增加 appear_count) + + Args: + name: 实体名称 + group_id: 群组 ID + entity_type: 实体类型 + + Returns: + Optional[KGEntity]: 实体对象 + """ + try: + existing = await self._find_by_name_group(name, group_id) + if existing: + existing.appear_count = (existing.appear_count or 0) + 1 + existing.last_active_time = time.time() + if entity_type != 'general': + existing.entity_type = entity_type + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + return await self.create( + name=name, + entity_type=entity_type, + appear_count=1, + last_active_time=time.time(), + group_id=group_id, + ) + except Exception as e: + await self.session.rollback() + logger.error(f"[KnowledgeEntityRepository] 保存实体失败: {e}") + return None + + async def _find_by_name_group( + self, + name: str, + group_id: str + ) -> Optional[KGEntity]: + """按名称和群组查找实体""" + try: + stmt = 
select(KGEntity).where(and_( + KGEntity.name == name, + KGEntity.group_id == group_id, + )) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception: + return None + + async def get_entities( + self, + group_id: str, + entity_type: Optional[str] = None, + limit: int = 100 + ) -> List[KGEntity]: + """ + 获取群组的实体列表 + + Args: + group_id: 群组 ID + entity_type: 实体类型过滤(可选) + limit: 最大返回数量 + + Returns: + List[KGEntity]: 实体列表(按出现次数倒序) + """ + try: + stmt = select(KGEntity).where(KGEntity.group_id == group_id) + if entity_type: + stmt = stmt.where(KGEntity.entity_type == entity_type) + stmt = stmt.order_by(desc(KGEntity.appear_count)).limit(limit) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[KnowledgeEntityRepository] 获取实体列表失败: {e}") + return [] + + async def get_entity_count(self, group_id: str) -> int: + """ + 统计群组的实体数量 + + Args: + group_id: 群组 ID + + Returns: + int: 实体数量 + """ + try: + stmt = select(func.count()).select_from(KGEntity).where( + KGEntity.group_id == group_id + ) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[KnowledgeEntityRepository] 统计实体失败: {e}") + return 0 + + async def search_entities( + self, + group_id: str, + keyword: str, + limit: int = 20 + ) -> List[KGEntity]: + """ + 搜索实体 + + Args: + group_id: 群组 ID + keyword: 搜索关键词 + limit: 最大返回数量 + + Returns: + List[KGEntity]: 匹配的实体列表 + """ + try: + stmt = ( + select(KGEntity) + .where(and_( + KGEntity.group_id == group_id, + KGEntity.name.contains(keyword), + )) + .order_by(desc(KGEntity.appear_count)) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[KnowledgeEntityRepository] 搜索实体失败: {e}") + return [] + + +class KnowledgeRelationRepository(BaseRepository[KGRelation]): + """知识图谱关系 Repository""" + + def __init__(self, session: 
AsyncSession): + super().__init__(session, KGRelation) + + async def save_relation( + self, + subject: str, + predicate: str, + object_: str, + group_id: str, + confidence: float = 1.0 + ) -> Optional[KGRelation]: + """ + 保存关系(upsert:已存在则更新 confidence) + + Args: + subject: 主体 + predicate: 谓词 + object_: 客体 + group_id: 群组 ID + confidence: 置信度 + + Returns: + Optional[KGRelation]: 关系对象 + """ + try: + existing = await self._find_relation(subject, predicate, object_, group_id) + if existing: + existing.confidence = confidence + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + return await self.create( + subject=subject, + predicate=predicate, + object=object_, + confidence=confidence, + created_time=time.time(), + group_id=group_id, + ) + except Exception as e: + await self.session.rollback() + logger.error(f"[KnowledgeRelationRepository] 保存关系失败: {e}") + return None + + async def _find_relation( + self, + subject: str, + predicate: str, + object_: str, + group_id: str + ) -> Optional[KGRelation]: + """精确查找关系""" + try: + stmt = select(KGRelation).where(and_( + KGRelation.subject == subject, + KGRelation.predicate == predicate, + KGRelation.object == object_, + KGRelation.group_id == group_id, + )) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception: + return None + + async def get_relations_by_entity( + self, + entity_name: str, + group_id: str, + limit: int = 50 + ) -> List[KGRelation]: + """ + 获取与实体相关的所有关系(实体可以是主体或客体) + + Args: + entity_name: 实体名称 + group_id: 群组 ID + limit: 最大返回数量 + + Returns: + List[KGRelation]: 关系列表 + """ + try: + stmt = ( + select(KGRelation) + .where(and_( + KGRelation.group_id == group_id, + or_( + KGRelation.subject == entity_name, + KGRelation.object == entity_name, + ), + )) + .order_by(desc(KGRelation.confidence)) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + 
logger.error(f"[KnowledgeRelationRepository] 获取实体关系失败: {e}") + return [] + + async def get_relation_count(self, group_id: str) -> int: + """ + 统计群组的关系数量 + + Args: + group_id: 群组 ID + + Returns: + int: 关系数量 + """ + try: + stmt = select(func.count()).select_from(KGRelation).where( + KGRelation.group_id == group_id + ) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[KnowledgeRelationRepository] 统计关系失败: {e}") + return 0 + + +class KnowledgeParagraphHashRepository(BaseRepository[KGParagraphHash]): + """知识图谱段落 Hash Repository(去重用)""" + + def __init__(self, session: AsyncSession): + super().__init__(session, KGParagraphHash) + + async def save_hash(self, hash_value: str, group_id: str) -> Optional[KGParagraphHash]: + """ + 保存段落 hash + + Args: + hash_value: Hash 值 + group_id: 群组 ID + + Returns: + Optional[KGParagraphHash]: 记录对象 + """ + try: + return await self.create( + hash_value=hash_value, + group_id=group_id, + created_time=time.time(), + ) + except Exception as e: + # 唯一约束冲突表示已存在 + await self.session.rollback() + logger.debug(f"[KnowledgeParagraphHashRepository] 保存 hash 失败(可能已存在): {e}") + return None + + async def exists_hash(self, hash_value: str, group_id: str) -> bool: + """ + 检查段落 hash 是否已存在 + + Args: + hash_value: Hash 值 + group_id: 群组 ID + + Returns: + bool: 是否存在 + """ + try: + stmt = select(func.count()).select_from(KGParagraphHash).where(and_( + KGParagraphHash.hash_value == hash_value, + KGParagraphHash.group_id == group_id, + )) + result = await self.session.execute(stmt) + return (result.scalar() or 0) > 0 + except Exception as e: + logger.error(f"[KnowledgeParagraphHashRepository] 检查 hash 失败: {e}") + return False diff --git a/repositories/learning_repository.py b/repositories/learning_repository.py index 1850fc3..d3fd83d 100644 --- a/repositories/learning_repository.py +++ b/repositories/learning_repository.py @@ -311,7 +311,7 @@ async def get_statistics(self) -> Dict[str, Any]: # 3. 
获取原始消息总数 (total_samples) # 从 style_learning_reviews 表获取累计的消息数量 # 注意:这个字段可能不存在,需要根据实际情况调整 - total_samples = total_patterns # 暂时用总模式数代替 + total_samples = total_patterns # 暂时用总模式数代替 # 4. 最后更新时间 (latest_update) # 使用 timestamp 而不是 updated_at,因为 timestamp 是数值类型 @@ -319,7 +319,7 @@ async def get_statistics(self) -> Dict[str, Any]: last_update_result = await self.session.execute(last_update_stmt) latest_timestamp = last_update_result.scalar() - # ✅ 转换 Unix 时间戳为可读格式 + # 转换 Unix 时间戳为可读格式 latest_update = None if latest_timestamp: latest_update = datetime.fromtimestamp(latest_timestamp).strftime('%Y-%m-%d %H:%M:%S') diff --git a/repositories/persona_backup_repository.py b/repositories/persona_backup_repository.py new file mode 100644 index 0000000..43e8a4e --- /dev/null +++ b/repositories/persona_backup_repository.py @@ -0,0 +1,181 @@ +""" +人格备份 Repository — PersonaBackup 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, desc, func, delete +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.psychological import PersonaBackup + + +class PersonaBackupRepository(BaseRepository[PersonaBackup]): + """人格备份 Repository""" + + def __init__(self, session: AsyncSession): + super().__init__(session, PersonaBackup) + + async def create_backup( + self, + backup_data: Dict[str, Any] + ) -> Optional[PersonaBackup]: + """ + 创建人格备份 + + Args: + backup_data: 备份字段字典,至少包含 backup_name + + Returns: + Optional[PersonaBackup]: 创建的记录 + """ + try: + backup_data.setdefault('timestamp', time.time()) + return await self.create(**backup_data) + except Exception as e: + logger.error(f"[PersonaBackupRepository] 创建备份失败: {e}") + return None + + async def list_backups( + self, + limit: int = 50, + offset: int = 0 + ) -> List[PersonaBackup]: + """ + 列出所有备份(按时间倒序) + + Args: + limit: 最大返回数量 + offset: 偏移量 + + Returns: + List[PersonaBackup]: 备份列表 + """ + try: + stmt = ( + 
select(PersonaBackup) + .order_by(desc(PersonaBackup.timestamp)) + .offset(offset) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[PersonaBackupRepository] 列出备份失败: {e}") + return [] + + async def get_backup(self, backup_id: int) -> Optional[PersonaBackup]: + """ + 获取指定备份 + + Args: + backup_id: 备份 ID + + Returns: + Optional[PersonaBackup]: 备份对象 + """ + return await self.get_by_id(backup_id) + + async def get_by_name(self, backup_name: str) -> Optional[PersonaBackup]: + """ + 按名称获取最近的备份 + + Args: + backup_name: 备份名称 + + Returns: + Optional[PersonaBackup]: 备份对象 + """ + try: + stmt = ( + select(PersonaBackup) + .where(PersonaBackup.backup_name == backup_name) + .order_by(desc(PersonaBackup.timestamp)) + .limit(1) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[PersonaBackupRepository] 按名称获取备份失败: {e}") + return None + + async def delete_backup(self, backup_id: int) -> bool: + """ + 删除指定备份 + + Args: + backup_id: 备份 ID + + Returns: + bool: 是否成功 + """ + return await self.delete_by_id(backup_id) + + async def count_backups(self) -> int: + """ + 统计备份总数 + + Returns: + int: 备份数量 + """ + try: + stmt = select(func.count()).select_from(PersonaBackup) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[PersonaBackupRepository] 统计备份失败: {e}") + return 0 + + async def delete_oldest(self, keep_count: int = 10) -> int: + """ + 删除最旧的备份,只保留最新的 N 条 + + Args: + keep_count: 保留数量 + + Returns: + int: 删除的行数 + """ + try: + # 获取需要保留的 ID + keep_stmt = ( + select(PersonaBackup.id) + .order_by(desc(PersonaBackup.timestamp)) + .limit(keep_count) + ) + keep_result = await self.session.execute(keep_stmt) + keep_ids = [row[0] for row in keep_result.fetchall()] + + if not keep_ids: + return 0 + + del_stmt = delete(PersonaBackup).where( + 
PersonaBackup.id.notin_(keep_ids) + ) + del_result = await self.session.execute(del_stmt) + await self.session.commit() + return del_result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[PersonaBackupRepository] 清理旧备份失败: {e}") + return 0 + + async def get_latest_backup(self) -> Optional[PersonaBackup]: + """ + 获取最新的备份 + + Returns: + Optional[PersonaBackup]: 最新的备份对象 + """ + try: + stmt = ( + select(PersonaBackup) + .order_by(desc(PersonaBackup.timestamp)) + .limit(1) + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[PersonaBackupRepository] 获取最新备份失败: {e}") + return None diff --git a/repositories/raw_message_repository.py b/repositories/raw_message_repository.py new file mode 100644 index 0000000..4d627cc --- /dev/null +++ b/repositories/raw_message_repository.py @@ -0,0 +1,257 @@ +""" +原始消息 Repository — RawMessage 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, update, and_, desc, func, delete +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.message import RawMessage + + +class RawMessageRepository(BaseRepository[RawMessage]): + """原始消息 Repository""" + + def __init__(self, session: AsyncSession): + super().__init__(session, RawMessage) + + async def save(self, message_data: Dict[str, Any]) -> Optional[RawMessage]: + """ + 保存一条原始消息 + + Args: + message_data: 消息字段字典,至少包含 sender_id, message, timestamp + + Returns: + Optional[RawMessage]: 创建的记录 + """ + try: + now = int(time.time()) + return await self.create( + sender_id=message_data.get('sender_id', ''), + sender_name=message_data.get('sender_name', ''), + message=message_data.get('message', ''), + group_id=message_data.get('group_id', ''), + timestamp=message_data.get('timestamp', now), + platform=message_data.get('platform', ''), + 
message_id=message_data.get('message_id'), + reply_to=message_data.get('reply_to'), + created_at=now, + processed=False, + ) + except Exception as e: + logger.error(f"[RawMessageRepository] 保存原始消息失败: {e}") + return None + + async def get_unprocessed(self, limit: int = 100) -> List[RawMessage]: + """ + 获取未处理的消息 + + Args: + limit: 最大返回数量 + + Returns: + List[RawMessage]: 未处理消息列表(按时间升序) + """ + try: + stmt = ( + select(RawMessage) + .where(RawMessage.processed == False) # noqa: E712 + .order_by(RawMessage.timestamp.asc()) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[RawMessageRepository] 获取未处理消息失败: {e}") + return [] + + async def mark_processed(self, message_id: int) -> bool: + """ + 将消息标记为已处理 + + Args: + message_id: 消息 ID + + Returns: + bool: 是否成功 + """ + try: + stmt = ( + update(RawMessage) + .where(RawMessage.id == message_id) + .values(processed=True) + ) + await self.session.execute(stmt) + await self.session.commit() + return True + except Exception as e: + await self.session.rollback() + logger.error(f"[RawMessageRepository] 标记消息已处理失败: {e}") + return False + + async def mark_batch_processed(self, message_ids: List[int]) -> int: + """ + 批量标记消息为已处理 + + Args: + message_ids: 消息 ID 列表 + + Returns: + int: 成功标记的数量 + """ + if not message_ids: + return 0 + try: + stmt = ( + update(RawMessage) + .where(RawMessage.id.in_(message_ids)) + .values(processed=True) + ) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[RawMessageRepository] 批量标记已处理失败: {e}") + return 0 + + async def get_recent( + self, + group_id: Optional[str] = None, + limit: int = 50 + ) -> List[RawMessage]: + """ + 获取最近的消息 + + Args: + group_id: 群组 ID(为 None 时不过滤) + limit: 最大返回数量 + + Returns: + List[RawMessage]: 消息列表(按时间倒序) + """ + try: + stmt = select(RawMessage) + if group_id: + stmt = 
stmt.where(RawMessage.group_id == group_id) + stmt = stmt.order_by(desc(RawMessage.timestamp)).limit(limit) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[RawMessageRepository] 获取最近消息失败: {e}") + return [] + + async def get_by_timerange( + self, + group_id: str, + start_ts: int, + end_ts: int, + limit: int = 500 + ) -> List[RawMessage]: + """ + 按时间范围获取消息 + + Args: + group_id: 群组 ID + start_ts: 开始时间戳 + end_ts: 结束时间戳 + limit: 最大返回数量 + + Returns: + List[RawMessage]: 消息列表 + """ + try: + stmt = ( + select(RawMessage) + .where(and_( + RawMessage.group_id == group_id, + RawMessage.timestamp >= start_ts, + RawMessage.timestamp <= end_ts, + )) + .order_by(RawMessage.timestamp.asc()) + .limit(limit) + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[RawMessageRepository] 按时间范围获取消息失败: {e}") + return [] + + async def count_all(self, group_id: Optional[str] = None) -> int: + """ + 统计消息总数 + + Args: + group_id: 群组 ID(为 None 时统计全部) + + Returns: + int: 消息数量 + """ + try: + stmt = select(func.count()).select_from(RawMessage) + if group_id: + stmt = stmt.where(RawMessage.group_id == group_id) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[RawMessageRepository] 统计消息失败: {e}") + return 0 + + async def delete_by_group(self, group_id: str) -> int: + """ + 删除指定群组的所有消息 + + Args: + group_id: 群组 ID + + Returns: + int: 删除的行数 + """ + try: + stmt = delete(RawMessage).where(RawMessage.group_id == group_id) + result = await self.session.execute(stmt) + await self.session.commit() + return result.rowcount + except Exception as e: + await self.session.rollback() + logger.error(f"[RawMessageRepository] 删除群组消息失败: {e}") + return 0 + + async def get_sender_statistics( + self, + group_id: str, + limit: int = 20 + ) -> List[Dict[str, Any]]: + """ + 获取发送者统计信息 + + Args: + group_id: 群组 ID 
+ limit: 最大返回数量 + + Returns: + List[Dict]: [{"sender_id": ..., "count": ...}, ...] + """ + try: + stmt = ( + select( + RawMessage.sender_id, + RawMessage.sender_name, + func.count().label('count') + ) + .where(RawMessage.group_id == group_id) + .group_by(RawMessage.sender_id, RawMessage.sender_name) + .order_by(desc('count')) + .limit(limit) + ) + result = await self.session.execute(stmt) + return [ + {"sender_id": row.sender_id, "sender_name": row.sender_name or row.sender_id, "count": row.count} + for row in result.fetchall() + ] + except Exception as e: + logger.error(f"[RawMessageRepository] 获取发送者统计失败: {e}") + return [] diff --git a/repositories/reinforcement_repository.py b/repositories/reinforcement_repository.py index 4c09bef..8291296 100644 --- a/repositories/reinforcement_repository.py +++ b/repositories/reinforcement_repository.py @@ -54,11 +54,11 @@ async def save_reinforcement_result( self.session.add(result) await self.session.commit() - logger.info(f"✅ 保存强化学习结果成功 (group: {group_id})") + logger.info(f" 保存强化学习结果成功 (group: {group_id})") return True except Exception as e: - logger.error(f"❌ 保存强化学习结果失败: {e}", exc_info=True) + logger.error(f" 保存强化学习结果失败: {e}", exc_info=True) await self.session.rollback() return False @@ -91,7 +91,7 @@ async def get_recent_results( return [r.to_dict() for r in results] except Exception as e: - logger.error(f"❌ 获取强化学习结果失败: {e}", exc_info=True) + logger.error(f" 获取强化学习结果失败: {e}", exc_info=True) return [] @@ -129,11 +129,11 @@ async def save_fusion_result( self.session.add(fusion) await self.session.commit() - logger.info(f"✅ 保存人格融合结果成功 (group: {group_id})") + logger.info(f" 保存人格融合结果成功 (group: {group_id})") return True except Exception as e: - logger.error(f"❌ 保存人格融合结果失败: {e}", exc_info=True) + logger.error(f" 保存人格融合结果失败: {e}", exc_info=True) await self.session.rollback() return False @@ -166,7 +166,7 @@ async def get_fusion_history( return [h.to_dict() for h in histories] except Exception as e: - logger.error(f"❌ 获取人格融合历史失败: 
{e}", exc_info=True) + logger.error(f" 获取人格融合历史失败: {e}", exc_info=True) return [] @@ -203,11 +203,11 @@ async def save_optimization_result( self.session.add(result) await self.session.commit() - logger.info(f"✅ 保存策略优化结果成功 (group: {group_id})") + logger.info(f" 保存策略优化结果成功 (group: {group_id})") return True except Exception as e: - logger.error(f"❌ 保存策略优化结果失败: {e}", exc_info=True) + logger.error(f" 保存策略优化结果失败: {e}", exc_info=True) await self.session.rollback() return False @@ -240,5 +240,5 @@ async def get_recent_optimizations( return [r.to_dict() for r in results] except Exception as e: - logger.error(f"❌ 获取策略优化结果失败: {e}", exc_info=True) + logger.error(f" 获取策略优化结果失败: {e}", exc_info=True) return [] diff --git a/repositories/style_profile_repository.py b/repositories/style_profile_repository.py new file mode 100644 index 0000000..945c4d1 --- /dev/null +++ b/repositories/style_profile_repository.py @@ -0,0 +1,130 @@ +""" +风格画像 Repository — StyleProfile 表的数据访问 +""" +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.expression import StyleProfile + + +class StyleProfileRepository(BaseRepository[StyleProfile]): + """风格画像 Repository + + StyleProfile 以 profile_name 为逻辑键。 + """ + + def __init__(self, session: AsyncSession): + super().__init__(session, StyleProfile) + + async def load(self, profile_name: str) -> Optional[StyleProfile]: + """ + 加载风格画像 + + Args: + profile_name: 画像名称 + + Returns: + Optional[StyleProfile]: 风格画像对象 + """ + try: + stmt = select(StyleProfile).where( + StyleProfile.profile_name == profile_name + ) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[StyleProfileRepository] 加载风格画像失败: {e}") + return None + + async def save(self, profile_data: Dict[str, Any]) -> Optional[StyleProfile]: + """ + 
保存风格画像(upsert:存在则更新,不存在则创建) + + Args: + profile_data: 画像字段字典,必须包含 profile_name + + Returns: + Optional[StyleProfile]: 保存后的记录 + """ + profile_name = profile_data.get('profile_name') + if not profile_name: + logger.error("[StyleProfileRepository] 保存画像失败: 缺少 profile_name") + return None + + try: + existing = await self.load(profile_name) + if existing: + for key, value in profile_data.items(): + if key not in ('profile_name', 'id') and hasattr(existing, key): + setattr(existing, key, value) + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + profile = StyleProfile(**profile_data) + self.session.add(profile) + await self.session.commit() + await self.session.refresh(profile) + return profile + except Exception as e: + await self.session.rollback() + logger.error(f"[StyleProfileRepository] 保存风格画像失败: {e}") + return None + + async def get_all_profiles(self, limit: int = 100) -> List[StyleProfile]: + """ + 获取所有风格画像 + + Args: + limit: 最大返回数量 + + Returns: + List[StyleProfile]: 画像列表 + """ + try: + stmt = select(StyleProfile).limit(limit) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[StyleProfileRepository] 获取所有画像失败: {e}") + return [] + + async def delete_profile(self, profile_name: str) -> bool: + """ + 删除风格画像 + + Args: + profile_name: 画像名称 + + Returns: + bool: 是否成功 + """ + try: + profile = await self.load(profile_name) + if profile: + await self.session.delete(profile) + await self.session.commit() + return True + return False + except Exception as e: + await self.session.rollback() + logger.error(f"[StyleProfileRepository] 删除画像失败: {e}") + return False + + async def count_all(self) -> int: + """ + 统计风格画像总数 + + Returns: + int: 画像数量 + """ + try: + stmt = select(func.count()).select_from(StyleProfile) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[StyleProfileRepository] 统计画像失败: {e}") + 
return 0 diff --git a/repositories/user_preferences_repository.py b/repositories/user_preferences_repository.py new file mode 100644 index 0000000..bc09657 --- /dev/null +++ b/repositories/user_preferences_repository.py @@ -0,0 +1,136 @@ +""" +用户偏好 Repository — UserPreferences 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, and_, func +from typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.social_relation import UserPreferences + + +class UserPreferencesRepository(BaseRepository[UserPreferences]): + """用户偏好 Repository + + UserPreferences 以 (user_id, group_id) 唯一约束。 + """ + + def __init__(self, session: AsyncSession): + super().__init__(session, UserPreferences) + + async def load(self, user_id: str, group_id: str) -> Optional[UserPreferences]: + """ + 加载用户偏好 + + Args: + user_id: 用户 ID + group_id: 群组 ID + + Returns: + Optional[UserPreferences]: 偏好对象 + """ + try: + stmt = select(UserPreferences).where(and_( + UserPreferences.user_id == user_id, + UserPreferences.group_id == group_id, + )) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[UserPreferencesRepository] 加载偏好失败: {e}") + return None + + async def save(self, pref_data: Dict[str, Any]) -> Optional[UserPreferences]: + """ + 保存用户偏好(upsert:存在则更新,不存在则创建) + + Args: + pref_data: 偏好字段字典,必须包含 user_id 和 group_id + + Returns: + Optional[UserPreferences]: 保存后的记录 + """ + user_id = pref_data.get('user_id') + group_id = pref_data.get('group_id') + if not user_id or not group_id: + logger.error("[UserPreferencesRepository] 保存偏好失败: 缺少 user_id 或 group_id") + return None + + try: + existing = await self.load(user_id, group_id) + if existing: + for key, value in pref_data.items(): + if key not in ('user_id', 'group_id', 'id') and hasattr(existing, key): + setattr(existing, key, value) + existing.updated_at = 
time.time() + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + pref_data.setdefault('updated_at', time.time()) + pref = UserPreferences(**pref_data) + self.session.add(pref) + await self.session.commit() + await self.session.refresh(pref) + return pref + except Exception as e: + await self.session.rollback() + logger.error(f"[UserPreferencesRepository] 保存偏好失败: {e}") + return None + + async def get_by_user(self, user_id: str) -> List[UserPreferences]: + """ + 获取用户在所有群组的偏好 + + Args: + user_id: 用户 ID + + Returns: + List[UserPreferences]: 偏好列表 + """ + try: + stmt = select(UserPreferences).where( + UserPreferences.user_id == user_id + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[UserPreferencesRepository] 获取用户偏好失败: {e}") + return [] + + async def get_by_group(self, group_id: str) -> List[UserPreferences]: + """ + 获取群组内所有用户的偏好 + + Args: + group_id: 群组 ID + + Returns: + List[UserPreferences]: 偏好列表 + """ + try: + stmt = select(UserPreferences).where( + UserPreferences.group_id == group_id + ) + result = await self.session.execute(stmt) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"[UserPreferencesRepository] 获取群组偏好失败: {e}") + return [] + + async def count_all(self) -> int: + """ + 统计偏好总数 + + Returns: + int: 偏好数量 + """ + try: + stmt = select(func.count()).select_from(UserPreferences) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[UserPreferencesRepository] 统计偏好失败: {e}") + return 0 diff --git a/repositories/user_profile_repository.py b/repositories/user_profile_repository.py new file mode 100644 index 0000000..ee3d7db --- /dev/null +++ b/repositories/user_profile_repository.py @@ -0,0 +1,131 @@ +""" +用户画像 Repository — UserProfile 表的数据访问 +""" +import time +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func +from 
typing import List, Optional, Dict, Any + +from astrbot.api import logger +from .base_repository import BaseRepository +from ..models.orm.social_relation import UserProfile + + +class UserProfileRepository(BaseRepository[UserProfile]): + """用户画像 Repository + + UserProfile 以 qq_id 为主键(String),不使用自增 ID。 + """ + + def __init__(self, session: AsyncSession): + super().__init__(session, UserProfile) + + async def load(self, qq_id: str) -> Optional[UserProfile]: + """ + 加载用户画像 + + Args: + qq_id: 用户 QQ 号 + + Returns: + Optional[UserProfile]: 用户画像对象 + """ + try: + stmt = select(UserProfile).where(UserProfile.qq_id == qq_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"[UserProfileRepository] 加载用户画像失败: {e}") + return None + + async def save(self, profile_data: Dict[str, Any]) -> Optional[UserProfile]: + """ + 保存用户画像(upsert:存在则更新,不存在则创建) + + Args: + profile_data: 画像字段字典,必须包含 qq_id + + Returns: + Optional[UserProfile]: 保存后的记录 + """ + qq_id = profile_data.get('qq_id') + if not qq_id: + logger.error("[UserProfileRepository] 保存画像失败: 缺少 qq_id") + return None + + try: + existing = await self.load(qq_id) + if existing: + # 更新已有记录 + for key, value in profile_data.items(): + if key != 'qq_id' and hasattr(existing, key): + setattr(existing, key, value) + await self.session.commit() + await self.session.refresh(existing) + return existing + else: + # 创建新记录 + profile = UserProfile(**profile_data) + self.session.add(profile) + await self.session.commit() + await self.session.refresh(profile) + return profile + except Exception as e: + await self.session.rollback() + logger.error(f"[UserProfileRepository] 保存用户画像失败: {e}") + return None + + async def get_all_profiles(self, limit: int = 100) -> List[UserProfile]: + """ + 获取所有用户画像 + + Args: + limit: 最大返回数量 + + Returns: + List[UserProfile]: 画像列表 + """ + try: + stmt = select(UserProfile).limit(limit) + result = await self.session.execute(stmt) + return 
list(result.scalars().all()) + except Exception as e: + logger.error(f"[UserProfileRepository] 获取所有画像失败: {e}") + return [] + + async def count_all(self) -> int: + """ + 统计用户画像总数 + + Returns: + int: 画像数量 + """ + try: + stmt = select(func.count()).select_from(UserProfile) + result = await self.session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + logger.error(f"[UserProfileRepository] 统计画像失败: {e}") + return 0 + + async def delete_profile(self, qq_id: str) -> bool: + """ + 删除用户画像 + + Args: + qq_id: 用户 QQ 号 + + Returns: + bool: 是否成功 + """ + try: + profile = await self.load(qq_id) + if profile: + await self.session.delete(profile) + await self.session.commit() + return True + return False + except Exception as e: + await self.session.rollback() + logger.error(f"[UserProfileRepository] 删除画像失败: {e}") + return False diff --git a/requirements.txt b/requirements.txt index dac3da1..bd64b0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,11 @@ seaborn==0.13.2 wordcloud==1.9.4 aiomysql guardrails-ai +pydantic>=2.0.0 sqlalchemy[asyncio]>=2.0.0 cachetools>=5.3.0 apscheduler>=3.10.0 asyncpg>=0.29.0 +lightrag-hku>=1.4.0 +mem0ai>=1.0.0 +qdrant-client>=1.7.0 diff --git a/scripts/MYSQL_SETUP.md b/scripts/MYSQL_SETUP.md deleted file mode 100644 index 6ef36e1..0000000 --- a/scripts/MYSQL_SETUP.md +++ /dev/null @@ -1,177 +0,0 @@ -# MySQL 数据库表结构初始化指南 - -## 问题说明 - -由于已废弃自动迁移功能,MySQL 数据库表需要手动创建。本文档提供了从 ORM 模型生成的完整建表 SQL 脚本。 - -## 表结构来源 - -所有表结构统一由 SQLAlchemy ORM 模型定义,位于: - -- `models/orm/message.py` - 消息相关表 -- `models/orm/psychological.py` - 心理状态表 -- `models/orm/social_relation.py` - 社交关系表 -- `models/orm/affection.py` - 好感度表 -- `models/orm/memory.py` - 记忆表 -- `models/orm/learning.py` - 学习记录表 -- `models/orm/expression.py` - 表达模式表 -- `models/orm/jargon.py` - 黑话表 -- `models/orm/social_analysis.py` - 社交分析表 -- `models/orm/performance.py` - 性能记录表 - -## 初始化步骤 - -### 方法 1: 执行完整建表脚本(推荐) - -```bash -# 1. 
执行 ORM 模型表(27个表) -mysql -h 47.121.138.217 -P 13307 -u root -p < scripts/mysql_schema.sql - -# 2. 执行传统表(23个表) -mysql -h 47.121.138.217 -P 13307 -u root -p < scripts/mysql_schema_additional.sql -``` - -**说明**: -- `mysql_schema.sql` 包含从 ORM 模型生成的 27 个核心表 -- `mysql_schema_additional.sql` 包含尚未迁移到 ORM 的 23 个传统表 - -### 方法 2: 通过 MySQL 客户端导入 - -```bash -# 登录 MySQL -mysql -h 47.121.138.217 -P 13307 -u root -p - -# 执行脚本 -mysql> source /path/to/scripts/mysql_schema.sql; -``` - -### 方法 3: 重新生成 SQL 脚本 - -如果修改了 ORM 模型,需要重新生成 SQL: - -```bash -# 运行生成脚本 -python scripts/generate_mysql_schema.py - -# 执行新生成的 SQL -mysql -h 47.121.138.217 -P 13307 -u root -p < scripts/mysql_schema.sql -``` - -## 包含的表(共 27 个) - -### 消息系统 (3) -- `raw_messages` - 原始消息 -- `filtered_messages` - 筛选后消息 -- `bot_messages` - Bot 消息 - -### 好感度系统 (4) -- `user_affections` - 用户好感度 -- `affection_interactions` - 好感度交互记录 -- `user_conversation_history` - 对话历史 -- `user_diversity` - 用户多样性 - -### 记忆系统 (3) -- `memories` - 记忆 -- `memory_embeddings` - 记忆向量 -- `memory_summaries` - 记忆摘要 - -### 心理状态系统 (3) -- `composite_psychological_states` - 复合心理状态 -- `psychological_state_components` - 心理状态组件 -- `psychological_state_history` - 心理状态历史 - -### 社交关系系统 (6) -- `social_relations` - 社交关系 -- `user_social_profiles` - 用户社交档案 -- `user_social_relation_components` - 用户社交关系组件 -- `social_relation_history` - 社交关系历史 -- `social_relation_analysis_results` - 社交关系分析结果 -- `social_network_nodes` - 社交网络节点 -- `social_network_edges` - 社交网络边 - -### 学习系统 (4) -- `persona_update_reviews` - 人格更新审查 -- `style_learning_reviews` - 风格学习审查 -- `style_learning_patterns` - 风格学习模式 -- `interaction_records` - 交互记录 - -### 其他系统 (4) -- `expression_patterns` - 表达模式 -- `jargon` - 黑话 -- `learning_performance_history` - 学习性能历史 - -## 验证安装 - -执行 SQL 后,验证表是否创建成功: - -```sql --- 查看所有表 -SHOW TABLES; - --- 应该看到 27 个表 - --- 检查某个表的结构 -DESC raw_messages; -DESC composite_psychological_states; -``` - -## 注意事项 - -1. **字符集**: 所有表使用 `utf8mb4` 字符集,支持完整的 Unicode 字符(包括 emoji) -2. 
**引擎**: 所有表使用 `InnoDB` 引擎,支持事务和外键 -3. **索引**: SQL 脚本包含所有必要的索引,无需手动添加 -4. **外键**: 部分表有外键约束,删除表时需注意顺序 - -## 故障排除 - -### 问题 1: 表已存在 - -如果表已存在,SQL 脚本会先执行 `DROP TABLE IF EXISTS`,自动删除旧表。 - -**警告**: 这会删除所有数据!如需保留数据,请先备份: - -```bash -mysqldump -h 47.121.138.217 -P 13307 -u root -p astrbot_self_learning > backup.sql -``` - -### 问题 2: 权限不足 - -确保 MySQL 用户有足够权限: - -```sql -GRANT ALL PRIVILEGES ON astrbot_self_learning.* TO 'root'@'%'; -FLUSH PRIVILEGES; -``` - -### 问题 3: 连接失败 - -检查配置文件 `_conf_schema.json` 中的 MySQL 连接参数: - -```json -{ - "mysql_host": "47.121.138.217", - "mysql_port": 13307, - "mysql_user": "root", - "mysql_password": "your_password", - "mysql_database": "astrbot_self_learning" -} -``` - -## 更新表结构 - -如果未来修改了 ORM 模型(添加/删除字段),需要: - -1. 重新生成 SQL 脚本: - ```bash - python scripts/generate_mysql_schema.py - ``` - -2. **手动迁移数据**(如果需要保留数据): - - 导出旧数据 - - 执行新的 SQL 脚本 - - 导入数据(可能需要调整) - -3. 或者删除重建(**会丢失所有数据**): - ```bash - mysql -h 47.121.138.217 -P 13307 -u root -p < scripts/mysql_schema.sql - ``` diff --git a/scripts/check_refactoring_status.py b/scripts/check_refactoring_status.py deleted file mode 100644 index 27a5afc..0000000 --- a/scripts/check_refactoring_status.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python3 -""" -验证重构功能启用状态 -""" -import json -import os - -def check_refactoring_status(): - """检查重构功能启用状态""" - - print("=" * 70) - print("🔍 检查重构功能启用状态") - print("=" * 70) - print() - - # 检查配置 schema - schema_path = "_conf_schema.json" - if os.path.exists(schema_path): - with open(schema_path, 'r', encoding='utf-8') as f: - schema = json.load(f) - - print("📋 配置 Schema 检查:") - print() - - # 检查 Database_Settings - db_settings = schema.get('Database_Settings', {}).get('items', {}) - use_sqlalchemy = db_settings.get('use_sqlalchemy', {}) - if use_sqlalchemy: - default_value = use_sqlalchemy.get('default', False) - print(f" ✅ use_sqlalchemy: 已添加 (默认值: {default_value})") - print(f" 描述: {use_sqlalchemy.get('description')}") - print(f" 提示: 
{use_sqlalchemy.get('hint')}") - else: - print(" ❌ use_sqlalchemy: 未找到") - - print() - - # 检查 Advanced_Settings - adv_settings = schema.get('Advanced_Settings', {}).get('items', {}) - - use_enhanced = adv_settings.get('use_enhanced_managers', {}) - if use_enhanced: - default_value = use_enhanced.get('default', False) - print(f" ✅ use_enhanced_managers: 已添加 (默认值: {default_value})") - print(f" 描述: {use_enhanced.get('description')}") - else: - print(" ❌ use_enhanced_managers: 未找到") - - print() - - enable_cleanup = adv_settings.get('enable_memory_cleanup', {}) - if enable_cleanup: - print(f" ✅ enable_memory_cleanup: 已添加 (默认值: {enable_cleanup.get('default')})") - else: - print(" ❌ enable_memory_cleanup: 未找到") - - cleanup_days = adv_settings.get('memory_cleanup_days', {}) - if cleanup_days: - print(f" ✅ memory_cleanup_days: 已添加 (默认值: {cleanup_days.get('default')})") - else: - print(" ❌ memory_cleanup_days: 未找到") - - threshold = adv_settings.get('memory_importance_threshold', {}) - if threshold: - print(f" ✅ memory_importance_threshold: 已添加 (默认值: {threshold.get('default')})") - else: - print(" ❌ memory_importance_threshold: 未找到") - else: - print("❌ 配置文件不存在: _conf_schema.json") - - print() - print("=" * 70) - print("📊 总结") - print("=" * 70) - print() - - # 检查默认值 - all_enabled = all([ - use_sqlalchemy.get('default') == True, - use_enhanced.get('default') == True, - enable_cleanup.get('default') == True - ]) - - if all_enabled: - print("✅ 所有重构功能默认启用!") - print() - print("下次启动插件时将自动使用:") - print(" • SQLAlchemy 数据库管理器") - print(" • 增强型好感度管理器") - print(" • 增强型记忆图管理器") - print(" • 增强型心理状态管理器") - print(" • 统一缓存管理") - print(" • APScheduler 任务调度") - print(" • 自动数据库迁移") - print() - print("🎉 无需手动配置,直接重启 AstrBot 即可!") - else: - print("⚠️ 部分功能未默认启用") - print() - print("当前默认值:") - print(f" • use_sqlalchemy: {use_sqlalchemy.get('default', False)}") - print(f" • use_enhanced_managers: {use_enhanced.get('default', False)}") - print(f" • enable_memory_cleanup: {enable_cleanup.get('default', 
False)}") - print() - print("如需启用,请在 AstrBot 配置文件中设置为 true") - - print() - print("=" * 70) - - # 检查迁移标记 - migration_marker = "./data/self_learning_data/.migration_completed" - if os.path.exists(migration_marker): - print() - print("📌 数据库迁移状态:") - print(f" ✅ 已完成迁移") - print(f" 标记文件: {migration_marker}") - try: - with open(migration_marker, 'r', encoding='utf-8') as f: - migration_info = json.load(f) - print(f" 迁移时间: {migration_info.get('timestamp')}") - print(f" 迁移表数: {migration_info.get('tables_migrated', 0)}") - print(f" 总行数: {migration_info.get('total_rows_migrated', 0)}") - except: - pass - else: - print() - print("📌 数据库迁移状态:") - print(" ⏳ 尚未迁移(首次启动时会自动执行)") - - print() - print("=" * 70) - - -if __name__ == "__main__": - check_refactoring_status() diff --git a/scripts/generate_mysql_schema.py b/scripts/generate_mysql_schema.py deleted file mode 100644 index 7291667..0000000 --- a/scripts/generate_mysql_schema.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -""" -从 ORM 模型生成 MySQL 建表 SQL 脚本 - -使用方法: - python scripts/generate_mysql_schema.py - -生成的 SQL 文件位于: scripts/mysql_schema.sql -可以直接在 MySQL 中执行此文件创建所有表 -""" -import sys -import os - -# 添加项目根目录到 Python 路径 -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, project_root) - -from sqlalchemy import create_engine -from sqlalchemy.schema import CreateTable -from models.orm import Base - - -def generate_mysql_schema(output_file: str = "scripts/mysql_schema.sql"): - """ - 生成 MySQL 建表 SQL 脚本 - - Args: - output_file: 输出文件路径 - """ - # 创建一个临时的 MySQL engine(不需要真实连接) - engine = create_engine( - "mysql+pymysql://user:pass@localhost/dummy", - strategy='mock', - executor=lambda sql, *_: None - ) - - # 生成建表语句 - sql_statements = [] - - # 添加数据库创建语句 - sql_statements.append("-- =====================================================") - sql_statements.append("-- AstrBot Self Learning Plugin - MySQL Schema") - sql_statements.append("-- 从 SQLAlchemy ORM 模型自动生成") - 
sql_statements.append("-- =====================================================") - sql_statements.append("") - sql_statements.append("-- 创建数据库(如果不存在)") - sql_statements.append("CREATE DATABASE IF NOT EXISTS astrbot_self_learning DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") - sql_statements.append("USE astrbot_self_learning;") - sql_statements.append("") - - # 按表名排序,确保依赖关系正确 - tables = sorted(Base.metadata.tables.values(), key=lambda t: t.name) - - for table in tables: - # 生成 CREATE TABLE 语句 - create_table_sql = str(CreateTable(table).compile(engine)) - - # 替换引擎为 InnoDB - if "ENGINE=" not in create_table_sql: - create_table_sql = create_table_sql.rstrip() + " ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci" - - sql_statements.append(f"-- 表: {table.name}") - sql_statements.append(f"DROP TABLE IF EXISTS `{table.name}`;") - sql_statements.append(create_table_sql + ";") - sql_statements.append("") - - # 写入文件 - output_path = os.path.join(project_root, output_file) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - with open(output_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(sql_statements)) - - print(f"✅ MySQL 建表脚本已生成: {output_path}") - print(f"📋 包含 {len(tables)} 个表") - print("\n表列表:") - for table in tables: - print(f" - {table.name}") - print(f"\n使用方法:") - print(f" mysql -h 47.121.138.217 -P 13307 -u root -p < {output_file}") - - -if __name__ == "__main__": - generate_mysql_schema() diff --git a/scripts/migrate_database.py b/scripts/migrate_database.py deleted file mode 100644 index 606f6d4..0000000 --- a/scripts/migrate_database.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -""" -数据库迁移命令行工具 -""" -import asyncio -import sys -import os - -# 添加项目路径 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from utils.migration_tool import migrate_database - - -async def main(): - print("=" * 70) - print(" AstrBot 自学习插件 - 数据库自动迁移工具") - print("=" * 70) - print() - - # 检查命令行参数 - if 
len(sys.argv) < 2: - print("📖 用法:") - print(f" python {sys.argv[0]} ") - print() - print("📝 示例:") - print(f" # SQLite") - print(f" python {sys.argv[0]} ./data/database.db") - print() - print(f" # MySQL") - print(f" python {sys.argv[0]} mysql+aiomysql://user:password@localhost/dbname") - print() - sys.exit(1) - - db_path = sys.argv[1] - - # 处理 SQLite 路径 - if not db_path.startswith('mysql') and not db_path.startswith('sqlite'): - # 相对路径 - if not os.path.isabs(db_path): - db_path = os.path.abspath(db_path) - db_url = f"sqlite:///{db_path}" - else: - db_url = db_path - - print(f"🔗 数据库: {db_url}") - print() - - # 确认 - confirm = input("⚠️ 确认开始迁移? 这将创建新表并复制数据 (y/N): ") - if confirm.lower() != 'y': - print("❌ 已取消") - sys.exit(0) - - print() - print("=" * 70) - - # 执行迁移 - try: - await migrate_database(db_url, backup=True) - print() - print("=" * 70) - print("🎉 迁移完成!") - print("=" * 70) - print() - print("📋 后续步骤:") - print(" 1. 检查迁移日志,确认数据完整性") - print(" 2. 测试应用功能是否正常") - print(" 3. 如果一切正常,可以删除旧表备份") - print() - - except Exception as e: - print() - print("=" * 70) - print(f"❌ 迁移失败: {e}") - print("=" * 70) - print() - print("🔧 故障排查:") - print(" 1. 检查数据库连接是否正常") - print(" 2. 确认数据库用户有足够权限") - print(" 3. 
查看完整错误日志") - print() - sys.exit(1) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/scripts/mysql_schema.sql b/scripts/mysql_schema.sql deleted file mode 100644 index 26daac6..0000000 --- a/scripts/mysql_schema.sql +++ /dev/null @@ -1,437 +0,0 @@ --- ===================================================== --- AstrBot Self Learning Plugin - MySQL Schema --- 从 SQLAlchemy ORM 模型自动生成 --- ===================================================== - --- 创建数据库(如果不存在) -CREATE DATABASE IF NOT EXISTS astrbot_self_learning DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; -USE astrbot_self_learning; - --- 表: affection_interactions -DROP TABLE IF EXISTS `affection_interactions`; - -CREATE TABLE affection_interactions ( - id INTEGER NOT NULL AUTO_INCREMENT, - user_affection_id INTEGER NOT NULL, - interaction_type VARCHAR(50) NOT NULL, - affection_delta INTEGER NOT NULL, - message_content TEXT, - timestamp BIGINT NOT NULL, - PRIMARY KEY (id), - FOREIGN KEY(user_affection_id) REFERENCES user_affections (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: bot_messages -DROP TABLE IF EXISTS `bot_messages`; - -CREATE TABLE bot_messages ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - message TEXT NOT NULL, - timestamp BIGINT NOT NULL, - created_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: composite_psychological_states -DROP TABLE IF EXISTS `composite_psychological_states`; - -CREATE TABLE composite_psychological_states ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - state_id VARCHAR(255) NOT NULL, - triggering_events TEXT, - context TEXT, - created_at BIGINT NOT NULL, - last_updated BIGINT NOT NULL, - PRIMARY KEY (id), - UNIQUE (state_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: expression_patterns -DROP TABLE IF EXISTS `expression_patterns`; - -CREATE TABLE 
expression_patterns ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight FLOAT NOT NULL, - last_active_time FLOAT NOT NULL, - create_time FLOAT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: filtered_messages -DROP TABLE IF EXISTS `filtered_messages`; - -CREATE TABLE filtered_messages ( - id INTEGER NOT NULL AUTO_INCREMENT, - raw_message_id INTEGER, - message TEXT NOT NULL, - sender_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255), - timestamp BIGINT NOT NULL, - confidence FLOAT, - quality_scores TEXT, - filter_reason TEXT, - created_at BIGINT NOT NULL, - processed BOOL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: interaction_records -DROP TABLE IF EXISTS `interaction_records`; - -CREATE TABLE interaction_records ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(100) NOT NULL, - user_id VARCHAR(100) NOT NULL, - interaction_type VARCHAR(50) NOT NULL, - content_preview VARCHAR(200), - timestamp BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: jargon -DROP TABLE IF EXISTS `jargon`; - -CREATE TABLE jargon ( - id INTEGER NOT NULL AUTO_INCREMENT, - content TEXT NOT NULL, - raw_content TEXT, - meaning TEXT, - is_jargon BOOL, - count INTEGER, - last_inference_count INTEGER, - is_complete BOOL, - is_global BOOL, - chat_id VARCHAR(255) NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: learning_performance_history -DROP TABLE IF EXISTS `learning_performance_history`; - -CREATE TABLE learning_performance_history ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - session_id VARCHAR(255), - timestamp BIGINT NOT NULL, - quality_score FLOAT, - learning_time FLOAT, - success 
BOOL, - successful_pattern TEXT, - failed_pattern TEXT, - created_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: memories -DROP TABLE IF EXISTS `memories`; - -CREATE TABLE memories ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - content TEXT NOT NULL, - importance INTEGER NOT NULL, - memory_type VARCHAR(50), - created_at BIGINT NOT NULL, - last_accessed BIGINT NOT NULL, - access_count INTEGER NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: memory_embeddings -DROP TABLE IF EXISTS `memory_embeddings`; - -CREATE TABLE memory_embeddings ( - id INTEGER NOT NULL AUTO_INCREMENT, - memory_id INTEGER NOT NULL, - embedding_model VARCHAR(100) NOT NULL, - embedding_data TEXT NOT NULL, - created_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: memory_summaries -DROP TABLE IF EXISTS `memory_summaries`; - -CREATE TABLE memory_summaries ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - summary_type VARCHAR(50) NOT NULL, - summary_content TEXT NOT NULL, - memory_count INTEGER, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: persona_update_reviews -DROP TABLE IF EXISTS `persona_update_reviews`; - -CREATE TABLE persona_update_reviews ( - id INTEGER NOT NULL AUTO_INCREMENT, - timestamp FLOAT NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(255) NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score FLOAT, - reason TEXT, - status VARCHAR(50) NOT NULL, - reviewer_comment TEXT, - review_time FLOAT, - metadata TEXT, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: 
psychological_state_components -DROP TABLE IF EXISTS `psychological_state_components`; - -CREATE TABLE psychological_state_components ( - id INTEGER NOT NULL AUTO_INCREMENT, - composite_state_id INTEGER, - group_id VARCHAR(255) NOT NULL, - state_id VARCHAR(255) NOT NULL, - category VARCHAR(50) NOT NULL, - state_type VARCHAR(100) NOT NULL, - value FLOAT NOT NULL, - threshold FLOAT NOT NULL, - description TEXT, - start_time BIGINT NOT NULL, - PRIMARY KEY (id), - FOREIGN KEY(composite_state_id) REFERENCES composite_psychological_states (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: psychological_state_history -DROP TABLE IF EXISTS `psychological_state_history`; - -CREATE TABLE psychological_state_history ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - state_id VARCHAR(255) NOT NULL, - category VARCHAR(50) NOT NULL, - old_state_type VARCHAR(100), - new_state_type VARCHAR(100) NOT NULL, - old_value FLOAT, - new_value FLOAT NOT NULL, - change_reason TEXT, - timestamp BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: raw_messages -DROP TABLE IF EXISTS `raw_messages`; - -CREATE TABLE raw_messages ( - id INTEGER NOT NULL AUTO_INCREMENT, - sender_id VARCHAR(255) NOT NULL, - sender_name VARCHAR(255), - message TEXT NOT NULL, - group_id VARCHAR(255), - timestamp BIGINT NOT NULL, - platform VARCHAR(100), - message_id VARCHAR(255), - reply_to VARCHAR(255), - created_at BIGINT NOT NULL, - processed BOOL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: social_network_edges -DROP TABLE IF EXISTS `social_network_edges`; - -CREATE TABLE social_network_edges ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - from_user_id VARCHAR(255) NOT NULL, - to_user_id VARCHAR(255) NOT NULL, - edge_type VARCHAR(50) NOT NULL, - weight FLOAT, - properties TEXT, - created_at BIGINT NOT NULL, - 
updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: social_network_nodes -DROP TABLE IF EXISTS `social_network_nodes`; - -CREATE TABLE social_network_nodes ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - node_type VARCHAR(50), - display_name VARCHAR(255), - properties TEXT, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: social_relation_analysis_results -DROP TABLE IF EXISTS `social_relation_analysis_results`; - -CREATE TABLE social_relation_analysis_results ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - analysis_type VARCHAR(50) NOT NULL, - result_data TEXT NOT NULL, - created_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: social_relation_history -DROP TABLE IF EXISTS `social_relation_history`; - -CREATE TABLE social_relation_history ( - id INTEGER NOT NULL AUTO_INCREMENT, - from_user_id VARCHAR(255) NOT NULL, - to_user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - relation_type VARCHAR(100) NOT NULL, - old_value FLOAT, - new_value FLOAT NOT NULL, - change_reason TEXT, - timestamp BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: social_relations -DROP TABLE IF EXISTS `social_relations`; - -CREATE TABLE social_relations ( - id INTEGER NOT NULL AUTO_INCREMENT, - user_id VARCHAR(255), - from_user VARCHAR(255), - to_user VARCHAR(255), - group_id VARCHAR(255), - relation_type VARCHAR(100), - affection_score FLOAT, - interaction_count INTEGER, - strength FLOAT, - frequency INTEGER, - last_interaction FLOAT, - metadata TEXT, - created_at BIGINT, - updated_at BIGINT, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - 
--- 表: style_learning_patterns -DROP TABLE IF EXISTS `style_learning_patterns`; - -CREATE TABLE style_learning_patterns ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(100) NOT NULL, - pattern_type VARCHAR(50) NOT NULL, - pattern TEXT NOT NULL, - usage_count INTEGER, - confidence FLOAT, - last_used BIGINT, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: style_learning_reviews -DROP TABLE IF EXISTS `style_learning_reviews`; - -CREATE TABLE style_learning_reviews ( - id INTEGER NOT NULL AUTO_INCREMENT, - type VARCHAR(100) NOT NULL, - group_id VARCHAR(255) NOT NULL, - timestamp FLOAT NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - status VARCHAR(50), - description TEXT, - reviewer_comment TEXT, - review_time FLOAT, - created_at DATETIME, - updated_at DATETIME, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: user_affections -DROP TABLE IF EXISTS `user_affections`; - -CREATE TABLE user_affections ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - affection_level INTEGER NOT NULL, - max_affection INTEGER NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: user_conversation_history -DROP TABLE IF EXISTS `user_conversation_history`; - -CREATE TABLE user_conversation_history ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - `role` VARCHAR(20) NOT NULL, - content TEXT NOT NULL, - timestamp BIGINT NOT NULL, - turn_index INTEGER NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: user_diversity -DROP TABLE IF EXISTS `user_diversity`; - -CREATE TABLE user_diversity ( - id INTEGER NOT NULL AUTO_INCREMENT, - group_id 
VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - response_hash VARCHAR(64) NOT NULL, - response_preview VARCHAR(200), - timestamp BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: user_social_profiles -DROP TABLE IF EXISTS `user_social_profiles`; - -CREATE TABLE user_social_profiles ( - id INTEGER NOT NULL AUTO_INCREMENT, - user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - total_relations INTEGER NOT NULL, - significant_relations INTEGER NOT NULL, - dominant_relation_type VARCHAR(100), - created_at BIGINT NOT NULL, - last_updated BIGINT NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 表: user_social_relation_components -DROP TABLE IF EXISTS `user_social_relation_components`; - -CREATE TABLE user_social_relation_components ( - id INTEGER NOT NULL AUTO_INCREMENT, - profile_id INTEGER NOT NULL, - from_user_id VARCHAR(255) NOT NULL, - to_user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - relation_type VARCHAR(100) NOT NULL, - value FLOAT NOT NULL, - frequency INTEGER NOT NULL, - last_interaction BIGINT NOT NULL, - description TEXT, - tags TEXT, - created_at BIGINT NOT NULL, - PRIMARY KEY (id), - FOREIGN KEY(profile_id) REFERENCES user_social_profiles (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; diff --git a/scripts/mysql_schema_additional.sql b/scripts/mysql_schema_additional.sql deleted file mode 100644 index 2860706..0000000 --- a/scripts/mysql_schema_additional.sql +++ /dev/null @@ -1,289 +0,0 @@ --- ===================================================== --- 传统表(未迁移到 ORM 的表) --- ===================================================== - --- 选择数据库 -USE bot_db_migrated; - --- =================================================== --- 学习批次表(如果已存在则确保结构正确) --- =================================================== --- 先创建表(如果不存在) -CREATE TABLE IF NOT EXISTS learning_batches ( - id INT PRIMARY 
KEY AUTO_INCREMENT, - batch_id VARCHAR(255) UNIQUE, - batch_name VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - start_time DOUBLE NOT NULL, - end_time DOUBLE, - quality_score DOUBLE, - processed_messages INT DEFAULT 0, - message_count INT DEFAULT 0, - filtered_count INT DEFAULT 0, - success BOOLEAN DEFAULT 1, - error_message TEXT, - status VARCHAR(50) DEFAULT 'pending', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id), - INDEX idx_batch_id (batch_id), - INDEX idx_batch_name (batch_name) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 确保 batch_name 列存在(用于向后兼容) --- 如果表已存在但缺少该列,则添加 -SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = DATABASE() - AND TABLE_NAME = 'learning_batches' - AND COLUMN_NAME = 'batch_name'); - -SET @alter_sql = IF(@column_exists = 0, - 'ALTER TABLE learning_batches ADD COLUMN batch_name VARCHAR(255) NOT NULL AFTER batch_id', - 'SELECT "Column batch_name already exists"'); - -PREPARE stmt FROM @alter_sql; -EXECUTE stmt; -DEALLOCATE PREPARE stmt; - --- =================================================== --- 其他传统表 --- =================================================== - --- 强化学习结果表 -CREATE TABLE IF NOT EXISTS reinforcement_learning_results ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - replay_analysis TEXT, - optimization_strategy TEXT, - reinforcement_feedback TEXT, - next_action TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 策略优化结果表 -CREATE TABLE IF NOT EXISTS strategy_optimization_results ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - exploration_type VARCHAR(100), - effectiveness_score DOUBLE, - detailed_metrics TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB 
DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 人格融合历史表 -CREATE TABLE IF NOT EXISTS persona_fusion_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - base_persona_hash BIGINT, - incremental_hash BIGINT, - fusion_result TEXT, - compatibility_score DOUBLE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 学习会话表 -CREATE TABLE IF NOT EXISTS learning_sessions ( - id INT PRIMARY KEY AUTO_INCREMENT, - session_id VARCHAR(255) UNIQUE NOT NULL, - group_id VARCHAR(255) NOT NULL, - batch_id VARCHAR(255), - start_time DOUBLE NOT NULL, - end_time DOUBLE, - metrics TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id), - INDEX idx_session (session_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 人格备份表 -CREATE TABLE IF NOT EXISTS persona_backups ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - backup_time DOUBLE NOT NULL, - persona_content TEXT NOT NULL, - persona_hash BIGINT, - backup_reason VARCHAR(255), - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 人格更新记录表 -CREATE TABLE IF NOT EXISTS persona_update_records ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - update_time DOUBLE NOT NULL, - old_persona_hash BIGINT, - new_persona_hash BIGINT, - update_type VARCHAR(50), - update_content TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- Bot 心情表 -CREATE TABLE IF NOT EXISTS bot_mood ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - mood_type VARCHAR(50) NOT NULL, - intensity DOUBLE DEFAULT 0.5, - trigger_event TEXT, - timestamp DOUBLE NOT NULL, - created_at DATETIME 
DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 对话上下文表 -CREATE TABLE IF NOT EXISTS conversation_contexts ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - context_data TEXT, - last_update DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 情感模式表 -CREATE TABLE IF NOT EXISTS emotion_patterns ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - pattern_type VARCHAR(100), - pattern_data TEXT, - confidence DOUBLE DEFAULT 0.5, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 情感档案表 -CREATE TABLE IF NOT EXISTS emotion_profiles ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - emotion_data TEXT, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group_user (group_id, user_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 知识实体表 -CREATE TABLE IF NOT EXISTS knowledge_entities ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - entity_type VARCHAR(100), - entity_name VARCHAR(255), - entity_data TEXT, - confidence DOUBLE DEFAULT 0.5, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 语言风格模式表 -CREATE TABLE IF NOT EXISTS language_style_patterns ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - pattern_type VARCHAR(100), - pattern_content TEXT, - frequency INT DEFAULT 0, - last_used DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 
COLLATE=utf8mb4_unicode_ci; - --- 风格档案表 -CREATE TABLE IF NOT EXISTS style_profiles ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - style_data TEXT, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 主题偏好表 -CREATE TABLE IF NOT EXISTS topic_preferences ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - topic VARCHAR(255), - preference_score DOUBLE DEFAULT 0.5, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 主题摘要表 -CREATE TABLE IF NOT EXISTS topic_summaries ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - topic VARCHAR(255), - summary_content TEXT, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 用户偏好表 -CREATE TABLE IF NOT EXISTS user_preferences ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - preference_data TEXT, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group_user (group_id, user_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 用户档案表 -CREATE TABLE IF NOT EXISTS user_profiles ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - profile_data TEXT, - last_updated DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group_user (group_id, user_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- LLM 调用统计表 -CREATE TABLE IF NOT EXISTS llm_call_statistics ( - id INT PRIMARY KEY AUTO_INCREMENT, - call_type VARCHAR(50), - model_name VARCHAR(100), - 
tokens_used INT, - response_time DOUBLE, - success BOOLEAN DEFAULT TRUE, - error_message TEXT, - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_call_type (call_type), - INDEX idx_timestamp (timestamp) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 风格学习记录表 -CREATE TABLE IF NOT EXISTS style_learning_records ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - learning_type VARCHAR(100), - learning_content TEXT, - effectiveness DOUBLE DEFAULT 0.5, - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - --- 好感度历史表 -CREATE TABLE IF NOT EXISTS affection_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - old_affection INT, - new_affection INT, - change_reason VARCHAR(255), - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group_user (group_id, user_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; diff --git a/scripts/quick_test.sh b/scripts/quick_test.sh deleted file mode 100755 index 9f17055..0000000 --- a/scripts/quick_test.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -# 快速测试脚本 - 检查代码质量和运行测试 - -set -e - -echo "╔════════════════════════════════════════════════════════════════╗" -echo "║ Astrbot Self-Learning Plugin - 测试工具 ║" -echo "╚════════════════════════════════════════════════════════════════╝" -echo "" - -# 检查是否安装了测试工具 -check_tool() { - if ! command -v $1 &> /dev/null; then - echo "⚠️ $1 未安装,跳过..." - return 1 - fi - return 0 -} - -# 1. Python 语法检查 -echo "🔍 [1/6] Python 语法检查..." -python -m py_compile *.py 2>/dev/null && echo "✅ 语法检查通过" || echo "❌ 语法错误" -echo "" - -# 2. 代码风格检查 -echo "🎨 [2/6] 代码风格检查 (flake8)..." -if check_tool flake8; then - flake8 --max-line-length=120 --exclude=venv,__pycache__,.git,web_res --count --statistics . 
|| true -else - echo "💡 安装: pip install flake8" -fi -echo "" - -# 3. 代码复杂度分析 -echo "📊 [3/6] 代码复杂度分析 (radon)..." -if check_tool radon; then - echo "圈复杂度 (推荐 < 10):" - radon cc . -a -s --exclude="venv,__pycache__,web_res" | head -20 - echo "" - echo "可维护性指数 (推荐 > 20):" - radon mi . -s --exclude="venv,__pycache__,web_res" | head -10 -else - echo "💡 安装: pip install radon" -fi -echo "" - -# 4. 安全检查 -echo "🔒 [4/6] 安全漏洞扫描 (bandit)..." -if check_tool bandit; then - bandit -r . -ll -f json -o bandit_report.json 2>/dev/null && \ - echo "✅ 安全检查完成,报告: bandit_report.json" || \ - echo "⚠️ 发现潜在安全问题,查看: bandit_report.json" -else - echo "💡 安装: pip install bandit" -fi -echo "" - -# 5. 运行现有测试 -echo "🧪 [5/6] 运行 API 测试..." -if [ -f "test_api_simple.py" ]; then - echo "运行简化测试..." - timeout 10 python test_api_simple.py 2>&1 | head -20 || echo "⚠️ 测试需要 WebUI 运行" -else - echo "ℹ️ 未找到测试文件" -fi -echo "" - -# 6. 文件统计 -echo "📈 [6/6] 项目统计..." -echo "Python 文件数:" -find . -name "*.py" -not -path "./venv/*" -not -path "./__pycache__/*" | wc -l -echo "总代码行数:" -find . -name "*.py" -not -path "./venv/*" -not -path "./__pycache__/*" -exec wc -l {} + | tail -1 -echo "" - -echo "╔════════════════════════════════════════════════════════════════╗" -echo "║ 测试完成 ║" -echo "╚════════════════════════════════════════════════════════════════╝" -echo "" -echo "💡 建议的下一步:" -echo " 1. 查看 bandit_report.json 处理安全问题" -echo " 2. 运行 'flake8 .' 修复代码风格问题" -echo " 3. 
创建单元测试 (参考 docs/TESTING_GUIDE.md)" -echo "" diff --git a/scripts/webui_refactor_analyzer.py b/scripts/webui_refactor_analyzer.py deleted file mode 100644 index 616ec98..0000000 --- a/scripts/webui_refactor_analyzer.py +++ /dev/null @@ -1,165 +0,0 @@ -#!/usr/bin/env python3 -""" -WebUI 自动重构工具 -分析原 webui.py 并生成重构后的蓝图代码 -""" -import re -import os -from typing import List, Dict, Tuple - - -class WebUIRefactorTool: - """WebUI 重构工具""" - - def __init__(self, source_file: str = "webui.py"): - self.source_file = source_file - self.routes = [] - self.functions = [] - - def analyze_routes(self) -> Dict[str, List[Tuple[str, str, List[str]]]]: - """ - 分析路由并按功能分组 - - Returns: - Dict[分组名, List[(路由路径, 函数名, HTTP方法)]] - """ - route_groups = { - 'auth': [], # 认证相关 - 'config': [], # 配置管理 - 'personas': [], # 人格管理 - 'learning': [], # 学习功能 - 'metrics': [], # 指标分析 - 'social': [], # 社交关系 - 'jargon': [], # 黑话管理 - 'bug_report': [], # Bug报告 - 'chat': [], # 聊天历史 - 'other': [] # 其他 - } - - with open(self.source_file, 'r', encoding='utf-8') as f: - content = f.read() - - # 查找所有路由定义 - route_pattern = r'@app\.route\([\'"]([^\'"]+)[\'"]\s*(?:,\s*methods=\[(.*?)\])?\s*\)\s*async def (\w+)' - - for match in re.finditer(route_pattern, content): - path = match.group(1) - methods_str = match.group(2) or "'GET'" - func_name = match.group(3) - methods = [m.strip('\'" ') for m in methods_str.split(',')] - - # 根据路径和函数名分组 - if any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['login', 'logout', 'password', 'auth']): - route_groups['auth'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['persona', 'personality']): - route_groups['personas'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['learning', 'style']): - route_groups['learning'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - 
for keyword in ['metrics', 'analytics']): - route_groups['metrics'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['social', 'relation']): - route_groups['social'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['jargon', '黑话']): - route_groups['jargon'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['bug', 'report']): - route_groups['bug_report'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['chat', 'message', 'history']): - route_groups['chat'].append((path, func_name, methods)) - elif any(keyword in path.lower() or keyword in func_name.lower() - for keyword in ['config', 'setting']): - route_groups['config'].append((path, func_name, methods)) - else: - route_groups['other'].append((path, func_name, methods)) - - return route_groups - - def print_analysis(self): - """打印分析结果""" - route_groups = self.analyze_routes() - - print("=" * 70) - print("WebUI 路由分析结果") - print("=" * 70) - print() - - total_routes = 0 - for group_name, routes in route_groups.items(): - if routes: - print(f"📦 {group_name.upper()} ({len(routes)} 个路由)") - print("-" * 70) - for path, func_name, methods in routes: - methods_str = ', '.join(methods) - print(f" {methods_str:15} {path:40} -> {func_name}") - print() - total_routes += len(routes) - - print("=" * 70) - print(f"总计: {total_routes} 个路由") - print("=" * 70) - - def generate_blueprint_template(self, group_name: str, routes: List[Tuple[str, str, List[str]]]) -> str: - """生成蓝图模板代码""" - template = f'''""" -{group_name.capitalize()} 相关路由 -""" -from quart import Blueprint, render_template, request, jsonify, session - -from ..dependencies import get_container -from ..services.{group_name}_service import {group_name.capitalize()}Service -from ..middleware.auth 
import require_auth -from ..utils.response import success_response, error_response - -{group_name}_bp = Blueprint('{group_name}', __name__, url_prefix='/api/{group_name}') - - -''' - - for path, func_name, methods in routes: - # 提取路由参数 - params = re.findall(r'<(\w+)(?::(\w+))?>', path) - param_str = ', '.join([p[1] if p[1] else p[0] for p in params]) if params else '' - - methods_str = ', '.join([f'"{m}"' for m in methods]) - - template += f'''@{group_name}_bp.route('{path}', methods=[{methods_str}]) -@require_auth -async def {func_name}({param_str}): - """TODO: 实现 {func_name}""" - try: - service = {group_name.capitalize()}Service(get_container()) - # TODO: 实现业务逻辑 - return success_response("TODO") - except Exception as e: - return error_response(f"操作失败: {{str(e)}}", 500) - - -''' - - return template - - -def main(): - """主函数""" - tool = WebUIRefactorTool() - tool.print_analysis() - - print() - print("💡 建议的重构步骤:") - print("1. 创建上述每个分组的 blueprint 文件") - print("2. 为每个 blueprint 创建对应的 service 文件") - print("3. 从 webui.py 提取对应的业务逻辑到 service") - print("4. 逐个测试每个 blueprint") - print("5. 
全部迁移完成后删除 webui.py") - print() - - -if __name__ == "__main__": - main() diff --git a/services/analysis/__init__.py b/services/analysis/__init__.py new file mode 100644 index 0000000..4a257b9 --- /dev/null +++ b/services/analysis/__init__.py @@ -0,0 +1,17 @@ +"""Data analysis, ML, and intelligence services.""" + +from .multidimensional_analyzer import MultidimensionalAnalyzer +from .ml_analyzer import LightweightMLAnalyzer +from .intelligence_enhancement import IntelligenceEnhancementService +from .data_analytics import DataAnalyticsService +from .expression_pattern_learner import ExpressionPatternLearner +from .intelligence_metrics import IntelligenceMetricsService + +__all__ = [ + "MultidimensionalAnalyzer", + "LightweightMLAnalyzer", + "IntelligenceEnhancementService", + "DataAnalyticsService", + "ExpressionPatternLearner", + "IntelligenceMetricsService", +] diff --git a/services/data_analytics.py b/services/analysis/data_analytics.py similarity index 96% rename from services/data_analytics.py rename to services/analysis/data_analytics.py index 141c6b0..68ae03f 100644 --- a/services/data_analytics.py +++ b/services/analysis/data_analytics.py @@ -21,18 +21,16 @@ from astrbot.api import logger -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.patterns import AsyncServiceBase +from ...core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage - -from ..core.compatibility_extensions import create_compatibility_extensions +from ...core.interfaces import IDataStorage class DataAnalyticsService(AsyncServiceBase): """数据分析与可视化服务""" - + def __init__(self, config: PluginConfig, database_manager: IDataStorage): super().__init__("data_analytics") self.config = config @@ -40,10 +38,6 @@ def __init__(self, config: PluginConfig, database_manager: IDataStorage): self.analytics_cache = {} self.cache_timeout = 300 # 5分钟缓存 - # 创建兼容性扩展 - extensions = create_compatibility_extensions(config, None, database_manager, None) - 
self.db_ext = extensions['db_manager'] - async def _do_start(self) -> bool: """启动分析服务""" try: @@ -66,7 +60,7 @@ async def generate_learning_trajectory_chart(self, group_id: str, days: int = 30 try: # 获取人格更新历史数据 - persona_updates = await self.db_ext.get_persona_update_history(group_id, days) + persona_updates = await self.db_manager.get_persona_update_history(group_id, days) if not persona_updates: return {"chart": None, "message": "暂无人格更新数据"} @@ -161,7 +155,7 @@ async def generate_learning_quality_curve(self, group_id: str, days: int = 30) - try: # 获取学习批次数据 - learning_batches = await self.db_ext.get_learning_batch_history(group_id, days) + learning_batches = await self.db_manager.get_learning_batch_history(group_id, days) if not learning_batches: return {"chart": None, "message": "暂无学习批次数据"} @@ -246,7 +240,7 @@ async def generate_user_activity_heatmap(self, group_id: str, days: int = 7) -> try: # 获取用户消息数据 - messages = await self.db_ext.get_messages_by_timerange( + messages = await self.db_manager.get_messages_by_timerange( group_id, datetime.now() - timedelta(days=days), datetime.now() @@ -314,7 +308,7 @@ async def generate_topic_trend_analysis(self, group_id: str, days: int = 30) -> try: # 获取消息数据 - messages = await self.db_ext.get_messages_by_timerange( + messages = await self.db_manager.get_messages_by_timerange( group_id, datetime.now() - timedelta(days=days), datetime.now() @@ -412,7 +406,7 @@ async def generate_social_network_graph(self, group_id: str, days: int = 30) -> try: # 获取社交关系数据 - relationships = await self.db_ext.get_social_relationships(group_id, days) + relationships = await self.db_manager.get_social_relationships(group_id, days) if not relationships: return {"chart": None, "message": "暂无社交关系数据"} @@ -513,7 +507,7 @@ async def analyze_user_behavior_patterns(self, group_id: str, days: int = 30) -> try: # 获取用户消息数据 - messages = await self.db_ext.get_messages_by_timerange( + messages = await self.db_manager.get_messages_by_timerange( group_id, 
datetime.now() - timedelta(days=days), datetime.now() diff --git a/services/expression_pattern_learner.py b/services/analysis/expression_pattern_learner.py similarity index 71% rename from services/expression_pattern_learner.py rename to services/analysis/expression_pattern_learner.py index 6f19eb9..e2e8052 100644 --- a/services/expression_pattern_learner.py +++ b/services/analysis/expression_pattern_learner.py @@ -5,30 +5,29 @@ import time import json import random -import sqlite3 from typing import Dict, List, Optional, Tuple, Any from datetime import datetime from dataclasses import dataclass, asdict from astrbot.api import logger -from ..core.interfaces import MessageData, ServiceLifecycle -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..config import PluginConfig -from ..exceptions import ExpressionLearningError, ModelAccessError -from ..utils.json_utils import safe_parse_llm_json -from .database_manager import DatabaseManager +from ...core.interfaces import MessageData, ServiceLifecycle +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...exceptions import ExpressionLearningError, ModelAccessError +from ...utils.json_utils import safe_parse_llm_json +from ..database import DatabaseManager @dataclass class ExpressionPattern: """表达模式数据结构""" - situation: str # 场景描述,如"对某件事表示十分惊叹" - expression: str # 表达方式,如"我嘞个xxxx" - weight: float # 权重(使用频率) - last_active_time: float # 最后活跃时间 - create_time: float # 创建时间 - group_id: str # 所属群组ID + situation: str # 场景描述,如"对某件事表示十分惊叹" + expression: str # 表达方式,如"我嘞个xxxx" + weight: float # 权重(使用频率) + last_active_time: float # 最后活跃时间 + create_time: float # 创建时间 + group_id: str # 所属群组ID def to_dict(self) -> Dict[str, Any]: return asdict(self) @@ -46,11 +45,11 @@ class ExpressionPatternLearner: """ # MaiBot的配置参数 - MAX_EXPRESSION_COUNT = 300 # 最大表达式数量 - DECAY_DAYS = 15 # 15天衰减周期 - DECAY_MIN = 0.01 # 最小衰减值 - MIN_MESSAGES_FOR_LEARNING = 25 # 触发学习所需的最少消息数 - 
MIN_LEARNING_INTERVAL = 300 # 最短学习时间间隔(秒) + MAX_EXPRESSION_COUNT = 300 # 最大表达式数量 + DECAY_DAYS = 15 # 15天衰减周期 + DECAY_MIN = 0.01 # 最小衰减值 + MIN_MESSAGES_FOR_LEARNING = 25 # 触发学习所需的最少消息数 + MIN_LEARNING_INTERVAL = 300 # 最短学习时间间隔(秒) _instance = None _initialized = False @@ -105,57 +104,8 @@ def get_instance(cls, config: PluginConfig = None, db_manager: DatabaseManager = return cls._instance async def _init_expression_patterns_table(self): - """初始化表达模式数据库表(异步)""" - if self._table_initialized: - return - - try: - # 检查是否是 SQLAlchemy 版本 - if hasattr(self.db_manager, 'get_session'): - # SQLAlchemy 版本 - 使用 async session - async with self.db_manager.get_session() as session: - # 使用 SQLAlchemy 原生 SQL - from sqlalchemy import text - await session.execute(text(''' - CREATE TABLE IF NOT EXISTS expression_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight REAL NOT NULL DEFAULT 1.0, - last_active_time REAL NOT NULL, - create_time REAL NOT NULL, - group_id TEXT NOT NULL, - UNIQUE(situation, expression, group_id) - ) - ''')) - await session.commit() - logger.info("表达模式数据库表初始化完成 (SQLAlchemy)") - elif hasattr(self.db_manager, 'get_db_connection'): - # 传统 DatabaseManager - 使用 get_db_connection 上下文管理器 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS expression_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight REAL NOT NULL DEFAULT 1.0, - last_active_time REAL NOT NULL, - create_time REAL NOT NULL, - group_id TEXT NOT NULL, - UNIQUE(situation, expression, group_id) - ) - ''') - await conn.commit() - await cursor.close() - logger.info("表达模式数据库表初始化完成 (传统)") - else: - raise ExpressionLearningError("不支持的数据库管理器类型") - - self._table_initialized = True - except Exception as e: - logger.error(f"初始化表达模式数据库表失败: {e}") - raise ExpressionLearningError(f"数据库初始化失败: {e}") + """表达模式表由 
ORM (models/orm/expression.py) 在引擎启动时自动创建""" + self._table_initialized = True async def start(self) -> bool: """启动服务""" @@ -255,7 +205,7 @@ async def learn_expression_patterns(self, messages: List[MessageData], group_id: 请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格 1. 只考虑文字,不要考虑表情包和图片 -2. 不要涉及具体的人名,但是可以涉及具体名词 +2. 不要涉及具体的人名,但是可以涉及具体名词 3. 思考有没有特殊的梗,一并总结成语言风格 4. 例子仅供参考,请严格根据群聊内容总结!!! @@ -279,8 +229,8 @@ async def learn_expression_patterns(self, messages: List[MessageData], group_id: try: response = await self.llm_adapter.generate_response( prompt, - temperature=0.3, # 使用MaiBot的temperature设置 - model_type="refine" # 使用精炼模型 + temperature=0.3, # 使用MaiBot的temperature设置 + model_type="refine" # 使用精炼模型 ) # 检查response是否有效 @@ -349,7 +299,7 @@ def _generate_fallback_expression_patterns(self, messages: List[MessageData]) -> patterns = [] # 分析消息特征 - for msg in messages[:10]: # 只分析前10条消息 + for msg in messages[:10]: # 只分析前10条消息 # 兼容处理MessageData对象和字典类型 if hasattr(msg, 'message'): # 如果是MessageData对象 @@ -400,7 +350,7 @@ def _generate_fallback_expression_patterns(self, messages: List[MessageData]) -> } # 检测表情符号 - elif any(emoji in content for emoji in ['😊', '😄', '😢', '😂', '🤔', '👍', '❤️']): + elif any(emoji in content for emoji in ['', '', '', '', '', '', '']): pattern_data = { "situation": "表达情感状态", "expression": content[:10] + ('...' 
if len(content) > 10 else ''), @@ -533,83 +483,91 @@ def _parse_expression_response(self, response: str, group_id: str) -> List[Expre return patterns async def _save_expression_patterns(self, patterns: List[ExpressionPattern], group_id: str): - """保存表达模式到数据库(异步版本)""" + """保存表达模式到数据库(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() + from sqlalchemy import select + from ...models.orm.expression import ExpressionPattern as ExpressionPatternORM + async with self.db_manager.get_session() as session: for pattern in patterns: # 查找是否已存在相似模式 - await cursor.execute( - 'SELECT id, weight FROM expression_patterns WHERE situation = ? AND expression = ? AND group_id = ?', - (pattern.situation, pattern.expression, group_id) + stmt = select(ExpressionPatternORM).where( + ExpressionPatternORM.situation == pattern.situation, + ExpressionPatternORM.expression == pattern.expression, + ExpressionPatternORM.group_id == group_id, ) - existing = await cursor.fetchone() + result = await session.execute(stmt) + existing = result.scalar_one_or_none() if existing: - # 更新现有模式,权重增加,50%概率替换内容(参考MaiBot) - new_weight = existing[1] + 1.0 + # 更新现有模式,权重增加,50%概率替换内容 + existing.weight += 1.0 + existing.last_active_time = pattern.last_active_time if random.random() < 0.5: - await cursor.execute( - 'UPDATE expression_patterns SET weight = ?, last_active_time = ?, situation = ?, expression = ? WHERE id = ?', - (new_weight, pattern.last_active_time, pattern.situation, pattern.expression, existing[0]) - ) - else: - await cursor.execute( - 'UPDATE expression_patterns SET weight = ?, last_active_time = ? 
WHERE id = ?', - (new_weight, pattern.last_active_time, existing[0]) - ) + existing.situation = pattern.situation + existing.expression = pattern.expression else: # 插入新模式 - await cursor.execute( - 'INSERT INTO expression_patterns (situation, expression, weight, last_active_time, create_time, group_id) VALUES (?, ?, ?, ?, ?, ?)', - (pattern.situation, pattern.expression, pattern.weight, pattern.last_active_time, pattern.create_time, pattern.group_id) + new_record = ExpressionPatternORM( + situation=pattern.situation, + expression=pattern.expression, + weight=pattern.weight, + last_active_time=pattern.last_active_time, + create_time=pattern.create_time, + group_id=pattern.group_id, ) + session.add(new_record) - await conn.commit() - logger.info(f"✅ 保存了 {len(patterns)} 个表达模式到数据库(群组: {group_id})") + await session.commit() + logger.info(f" 保存了 {len(patterns)} 个表达模式到数据库(群组: {group_id})") except Exception as e: logger.error(f"保存表达模式失败: {e}", exc_info=True) raise ExpressionLearningError(f"保存表达模式失败: {e}") async def _apply_time_decay(self, group_id: str): - """ - 应用时间衰减 - 完全参考MaiBot的衰减机制(异步版本) - """ + """应用时间衰减 - 完全参考MaiBot的衰减机制(ORM 版本)""" try: + from sqlalchemy import select, delete + from ...models.orm.expression import ExpressionPattern as ExpressionPatternORM + current_time = time.time() updated_count = 0 deleted_count = 0 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - + async with self.db_manager.get_session() as session: # 获取所有该群组的表达模式 - await cursor.execute( - 'SELECT id, weight, last_active_time FROM expression_patterns WHERE group_id = ?', - (group_id,) + stmt = select(ExpressionPatternORM).where( + ExpressionPatternORM.group_id == group_id ) - patterns = await cursor.fetchall() + result = await session.execute(stmt) + patterns = result.scalars().all() - for pattern_id, weight, last_active_time in patterns: + ids_to_delete = [] + for pattern in patterns: # 计算时间差(天) - time_diff_days = (current_time - last_active_time) / (24 
* 3600) + time_diff_days = (current_time - pattern.last_active_time) / (24 * 3600) # 计算衰减值 decay_value = self._calculate_decay_factor(time_diff_days) - new_weight = max(self.DECAY_MIN, weight - decay_value) + new_weight = max(self.DECAY_MIN, pattern.weight - decay_value) if new_weight <= self.DECAY_MIN: - # 删除权重过低的模式 - await cursor.execute('DELETE FROM expression_patterns WHERE id = ?', (pattern_id,)) + ids_to_delete.append(pattern.id) deleted_count += 1 else: - # 更新权重 - await cursor.execute('UPDATE expression_patterns SET weight = ? WHERE id = ?', (new_weight, pattern_id)) + pattern.weight = new_weight updated_count += 1 - await conn.commit() + # 批量删除权重过低的模式 + if ids_to_delete: + await session.execute( + delete(ExpressionPatternORM).where( + ExpressionPatternORM.id.in_(ids_to_delete) + ) + ) + + await session.commit() if updated_count > 0 or deleted_count > 0: logger.info(f"群组 {group_id} 时间衰减完成:更新了 {updated_count} 个,删除了 {deleted_count} 个表达模式") @@ -625,10 +583,10 @@ def _calculate_decay_factor(self, time_diff_days: float) -> float: 使用二次函数进行曲线插值 """ if time_diff_days <= 0: - return 0.0 # 刚激活的表达式不衰减 + return 0.0 # 刚激活的表达式不衰减 if time_diff_days >= self.DECAY_DAYS: - return 0.01 # 长时间未活跃的表达式大幅衰减 + return 0.01 # 长时间未活跃的表达式大幅衰减 # 使用二次函数插值:在0-15天之间从0衰减到0.01 a = 0.01 / (self.DECAY_DAYS ** 2) @@ -637,68 +595,70 @@ def _calculate_decay_factor(self, time_diff_days: float) -> float: return min(0.01, decay) async def _limit_max_expressions(self, group_id: str): - """限制最大表达模式数量(异步版本)""" + """限制最大表达模式数量(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() + from sqlalchemy import select, func, delete, asc + from ...models.orm.expression import ExpressionPattern as ExpressionPatternORM + async with self.db_manager.get_session() as session: # 统计当前数量 - await cursor.execute('SELECT COUNT(*) FROM expression_patterns WHERE group_id = ?', (group_id,)) - row = await cursor.fetchone() - count = row[0] if row else 0 + count_stmt = 
select(func.count()).select_from(ExpressionPatternORM).where( + ExpressionPatternORM.group_id == group_id + ) + count = (await session.execute(count_stmt)).scalar() or 0 if count > self.MAX_EXPRESSION_COUNT: - # 删除权重最小的多余模式 - # MySQL 不支持 DELETE ... WHERE id IN (SELECT ... LIMIT) - # 改用 JOIN 方式 excess_count = count - self.MAX_EXPRESSION_COUNT - # 先查询要删除的 ID - await cursor.execute( - 'SELECT id FROM expression_patterns WHERE group_id = ? ORDER BY weight ASC LIMIT ?', - (group_id, excess_count) + # 查询权重最小的 ID + ids_stmt = ( + select(ExpressionPatternORM.id) + .where(ExpressionPatternORM.group_id == group_id) + .order_by(asc(ExpressionPatternORM.weight)) + .limit(excess_count) ) - rows = await cursor.fetchall() - ids_to_delete = [row[0] for row in rows] + result = await session.execute(ids_stmt) + ids_to_delete = [row[0] for row in result.fetchall()] if ids_to_delete: - # 批量删除 - placeholders = ','.join(['?' for _ in ids_to_delete]) - await cursor.execute( - f'DELETE FROM expression_patterns WHERE id IN ({placeholders})', - tuple(ids_to_delete) + await session.execute( + delete(ExpressionPatternORM).where( + ExpressionPatternORM.id.in_(ids_to_delete) + ) ) - await conn.commit() + await session.commit() logger.info(f"群组 {group_id} 删除了 {len(ids_to_delete)} 个权重最小的表达模式") except Exception as e: logger.error(f"限制表达模式数量失败: {e}", exc_info=True) async def get_expression_patterns(self, group_id: str, limit: int = 10) -> List[ExpressionPattern]: - """获取群组的表达模式(异步版本)""" + """获取群组的表达模式(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute( - 'SELECT situation, expression, weight, last_active_time, create_time, group_id FROM expression_patterns WHERE group_id = ? 
ORDER BY weight DESC LIMIT ?', - (group_id, limit) + from sqlalchemy import select, desc + from ...models.orm.expression import ExpressionPattern as ExpressionPatternORM + + async with self.db_manager.get_session() as session: + stmt = ( + select(ExpressionPatternORM) + .where(ExpressionPatternORM.group_id == group_id) + .order_by(desc(ExpressionPatternORM.weight)) + .limit(limit) ) - - rows = await cursor.fetchall() - patterns = [] - for row in rows: - pattern = ExpressionPattern( - situation=row[0], - expression=row[1], - weight=row[2], - last_active_time=row[3], - create_time=row[4], - group_id=row[5] + result = await session.execute(stmt) + rows = result.scalars().all() + + return [ + ExpressionPattern( + situation=row.situation, + expression=row.expression, + weight=row.weight, + last_active_time=row.last_active_time, + create_time=row.create_time, + group_id=row.group_id, ) - patterns.append(pattern) - - return patterns + for row in rows + ] except Exception as e: logger.error(f"获取表达模式失败: {e}", exc_info=True) diff --git a/services/intelligence_enhancement.py b/services/analysis/intelligence_enhancement.py similarity index 99% rename from services/intelligence_enhancement.py rename to services/analysis/intelligence_enhancement.py index 037dc3e..b4d5f02 100644 --- a/services/intelligence_enhancement.py +++ b/services/analysis/intelligence_enhancement.py @@ -16,11 +16,11 @@ from astrbot.api import logger -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..utils.json_utils import safe_parse_llm_json -from ..core.interfaces import IDataStorage, IPersonaManager, ServiceLifecycle -from ..core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...core.patterns import AsyncServiceBase +from ...utils.json_utils import safe_parse_llm_json +from ...core.interfaces import IDataStorage, IPersonaManager, ServiceLifecycle +from ...core.framework_llm_adapter import FrameworkLLMAdapter @dataclass @@ 
-837,9 +837,9 @@ def _find_related_entities(self, entity: KnowledgeEntity) -> List[KnowledgeEntit for neighbor_id in neighbors[:3]: # 限制数量 if neighbor_id in self.knowledge_entities: related.append(self.knowledge_entities[neighbor_id]) - except: + except (KeyError, AttributeError): pass - + return related def _filter_recommendations_by_rate(self, recommendations: List[PersonalizedRecommendation], diff --git a/services/intelligence_metrics.py b/services/analysis/intelligence_metrics.py similarity index 99% rename from services/intelligence_metrics.py rename to services/analysis/intelligence_metrics.py index 328b583..7dfc14b 100644 --- a/services/intelligence_metrics.py +++ b/services/analysis/intelligence_metrics.py @@ -9,8 +9,8 @@ from datetime import datetime, timedelta from astrbot.api import logger -from ..config import PluginConfig -from ..utils.json_utils import safe_parse_llm_json +from ...config import PluginConfig +from ...utils.json_utils import safe_parse_llm_json @dataclass diff --git a/services/ml_analyzer.py b/services/analysis/ml_analyzer.py similarity index 96% rename from services/ml_analyzer.py rename to services/analysis/ml_analyzer.py index 2e95453..c4ca42e 100644 --- a/services/ml_analyzer.py +++ b/services/analysis/ml_analyzer.py @@ -22,15 +22,15 @@ from astrbot.api import logger -from ..config import PluginConfig +from ...config import PluginConfig -from ..exceptions import StyleAnalysisError +from ...exceptions import StyleAnalysisError -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 -from .database_manager import DatabaseManager # 确保 DatabaseManager 被正确导入 +from ..database import DatabaseManager # 确保 DatabaseManager 被正确导入 -from ..utils.json_utils import safe_parse_llm_json, clean_llm_json_response +from ...utils.json_utils import safe_parse_llm_json, clean_llm_json_response class LightweightMLAnalyzer: @@ -41,15 +41,15 @@ def __init__(self, config: 
PluginConfig, db_manager: DatabaseManager, prompts: Any = None, temporary_persona_updater = None): # 使用框架适配器替代LLMClient self.config = config self.db_manager = db_manager - self.llm_adapter = llm_adapter # 使用框架适配器 + self.llm_adapter = llm_adapter # 使用框架适配器 self.prompts = prompts # 保存 prompts self.temporary_persona_updater = temporary_persona_updater # 保存临时人格更新器引用 # 设置分析限制以节省资源 - self.max_sample_size = 100 # 最大样本数量 - self.max_features = 50 # 最大特征数量 - self.analysis_cache = {} # 分析结果缓存 - self.cache_timeout = 3600 # 缓存1小时 + self.max_sample_size = 100 # 最大样本数量 + self.max_features = 50 # 最大特征数量 + self.analysis_cache = {} # 分析结果缓存 + self.cache_timeout = 3600 # 缓存1小时 if not SKLEARN_AVAILABLE: logger.warning("scikit-learn未安装,将使用基础统计分析") @@ -125,7 +125,7 @@ async def reinforcement_memory_replay(self, group_id: str, new_messages: List[Di try: reinforcement_result = safe_parse_llm_json(clean_response) - # ✅ 检查解析结果是否为None + # 检查解析结果是否为None if not reinforcement_result: logger.warning("强化学习记忆重放解析结果为空") return {} @@ -240,7 +240,7 @@ async def reinforcement_strategy_optimization(self, group_id: str) -> Dict[str, """ 强化学习策略优化:基于历史表现数据动态调整学习策略 """ - if (not self.llm_adapter or not self.llm_adapter.has_reinforce_provider()) and self.llm_adapter.providers_configured < 3: + if (not self.llm_adapter or not self.llm_adapter.has_reinforce_provider()) and self.llm_adapter.providers_configured < 3: logger.warning("强化模型未配置,跳过策略优化功能") return {} @@ -343,7 +343,7 @@ async def replay_memory(self, group_id: str, new_messages: List[Dict[str, Any]], 记忆重放:将历史数据与新数据混合,并交给提炼模型进行处理。 这模拟了LLM的"增量微调"过程,通过重新暴露历史数据来巩固学习。 """ - if (not self.llm_adapter or not self.llm_adapter.has_refine_provider()) and self.llm_adapter.providers_configured < 2: + if (not self.llm_adapter or not self.llm_adapter.has_refine_provider()) and self.llm_adapter.providers_configured < 2: logger.warning("提炼模型未配置,跳过记忆重放功能") return [] @@ -645,10 +645,10 @@ async def _get_user_messages(self, group_id: str, user_id: str, limit: int) -> L 
"""获取用户消息(限制数量)""" try: from sqlalchemy import select, desc, and_ - from ..models.orm import RawMessage + from ...models.orm import RawMessage async with self.db_manager.get_session() as session: - cutoff_time = time.time() - 86400 * 7 # 最近7天 + cutoff_time = time.time() - 86400 * 7 # 最近7天 stmt = ( select(RawMessage) .where(and_( @@ -710,7 +710,7 @@ def _analyze_message_frequency(self, messages: List[Dict[str, Any]]) -> Dict[str for i in range(1, len(sorted_messages)): interval = sorted_messages[i]['timestamp'] - sorted_messages[i-1]['timestamp'] - intervals.append(interval / 60) # 转换为分钟 + intervals.append(interval / 60) # 转换为分钟 if not intervals: return {} @@ -718,7 +718,7 @@ def _analyze_message_frequency(self, messages: List[Dict[str, Any]]) -> Dict[str return { 'avg_interval_minutes': np.mean(intervals), 'interval_std': np.std(intervals), - 'burst_tendency': len([x for x in intervals if x < 5]) / len(intervals) # 5分钟内连续消息比例 + 'burst_tendency': len([x for x in intervals if x < 5]) / len(intervals) # 5分钟内连续消息比例 } async def _analyze_interaction_patterns(self, group_id: str, user_id: str, messages: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -758,8 +758,8 @@ def _analyze_topic_clusters(self, messages: List[Dict[str, Any]]) -> Dict[str, A # TF-IDF向量化(限制特征数量) vectorizer = TfidfVectorizer( max_features=min(self.max_features, len(texts) * 2), - stop_words=None, # 不使用停用词以节省内存 - ngram_range=(1, 1) # 只使用单词 + stop_words=None, # 不使用停用词以节省内存 + ngram_range=(1, 1) # 只使用单词 ) tfidf_matrix = vectorizer.fit_transform(texts) @@ -775,7 +775,7 @@ def _analyze_topic_clusters(self, messages: List[Dict[str, Any]]) -> Dict[str, A # 分析聚类结果 clusters = defaultdict(list) for i, label in enumerate(cluster_labels): - clusters[int(label)].append(texts[i][:50]) # 限制文本长度 + clusters[int(label)].append(texts[i][:50]) # 限制文本长度 # 提取关键词 feature_names = vectorizer.get_feature_names_out() @@ -783,7 +783,7 @@ def _analyze_topic_clusters(self, messages: List[Dict[str, Any]]) -> Dict[str, A for i in 
range(n_clusters): center = kmeans.cluster_centers_[i] - top_indices = center.argsort()[-5:][::-1] # 前5个关键词 + top_indices = center.argsort()[-5:][::-1] # 前5个关键词 cluster_keywords[i] = [feature_names[idx] for idx in top_indices] return { @@ -833,10 +833,10 @@ async def _get_recent_group_messages(self, group_id: str, limit: int) -> List[Di """获取群聊最近消息""" try: from sqlalchemy import select, desc, and_ - from ..models.orm import RawMessage + from ...models.orm import RawMessage async with self.db_manager.get_session() as session: - cutoff_time = time.time() - 3600 * 6 # 最近6小时 + cutoff_time = time.time() - 3600 * 6 # 最近6小时 stmt = ( select(RawMessage) .where(and_( @@ -900,8 +900,8 @@ def _simple_sentiment_analysis(self, messages: List[Dict[str, Any]]) -> Dict[str # 确保消息列表已经过滤掉None值 filtered_messages = [msg for msg in messages if msg is not None] - positive_keywords = ['哈哈', '好的', '谢谢', '赞', '棒', '开心', '高兴', '😊', '👍', '❤️'] - negative_keywords = ['不行', '差', '烦', '无聊', '生气', '😢', '😡', '💔'] + positive_keywords = ['哈哈', '好的', '谢谢', '赞', '棒', '开心', '高兴', '', '', ''] + negative_keywords = ['不行', '差', '烦', '无聊', '生气', '', '', ''] positive_count = 0 negative_count = 0 @@ -1014,10 +1014,10 @@ async def _get_most_active_users(self, group_id: str, limit: int) -> List[Dict[s """获取最活跃用户""" try: from sqlalchemy import select, desc, func, and_ - from ..models.orm import RawMessage + from ...models.orm import RawMessage async with self.db_manager.get_session() as session: - cutoff_time = time.time() - 86400 # 最近24小时 + cutoff_time = time.time() - 86400 # 最近24小时 stmt = ( select( RawMessage.sender_id, diff --git a/services/multidimensional_analyzer.py b/services/analysis/multidimensional_analyzer.py similarity index 95% rename from services/multidimensional_analyzer.py rename to services/analysis/multidimensional_analyzer.py index 76f9bed..89a704e 100644 --- a/services/multidimensional_analyzer.py +++ b/services/analysis/multidimensional_analyzer.py @@ -14,15 +14,15 @@ from astrbot.api 
import logger from astrbot.api.event import AstrMessageEvent -from ..config import PluginConfig +from ...config import PluginConfig -from ..exceptions import StyleAnalysisError +from ...exceptions import StyleAnalysisError -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 -from .database_manager import DatabaseManager +from ..database import DatabaseManager -from ..utils.json_utils import safe_parse_llm_json +from ...utils.json_utils import safe_parse_llm_json @dataclass @@ -36,7 +36,7 @@ class UserProfile: social_connections: List[str] = None topic_preferences: Dict[str, float] = None emotional_tendency: Dict[str, float] = None - last_active: float = None # 添加缺失的字段 + last_active: float = None # 添加缺失的字段 def __post_init__(self): if self.nicknames is None: @@ -60,16 +60,16 @@ class SocialRelation: """社交关系""" from_user: str to_user: str - relation_type: str # mention, reply, frequent_interaction - strength: float # 关系强度 0-1 - frequency: int # 交互频次 + relation_type: str # mention, reply, frequent_interaction + strength: float # 关系强度 0-1 + frequency: int # 交互频次 last_interaction: str @dataclass class ContextualPattern: """情境模式""" - context_type: str # time_based, topic_based, social_based + context_type: str # time_based, topic_based, social_based pattern_name: str triggers: List[str] characteristics: Dict[str, Any] @@ -94,13 +94,13 @@ def __init__(self, config: PluginConfig, db_manager: DatabaseManager, context=No # 友好的配置状态提示 if self.llm_adapter: if not self.llm_adapter.has_filter_provider(): - logger.info("💡 筛选模型未配置,将使用简化算法进行消息筛选") + logger.info(" 筛选模型未配置,将使用简化算法进行消息筛选") if not self.llm_adapter.has_refine_provider(): - logger.info("💡 提炼模型未配置,将使用简化算法进行深度分析") + logger.info(" 提炼模型未配置,将使用简化算法进行深度分析") if not self.llm_adapter.has_reinforce_provider(): - logger.info("💡 强化模型未配置,将跳过强化学习功能") + logger.info(" 强化模型未配置,将跳过强化学习功能") else: - logger.info("💡 框架LLM适配器未配置,将使用简化算法进行分析") + 
logger.info(" 框架LLM适配器未配置,将使用简化算法进行分析") # 用户画像存储 self.user_profiles: Dict[str, UserProfile] = {} @@ -109,7 +109,7 @@ def __init__(self, config: PluginConfig, db_manager: DatabaseManager, context=No self.social_graph: Dict[str, List[SocialRelation]] = defaultdict(list) # 昵称映射表 - self.nickname_mapping: Dict[str, str] = {} # nickname -> qq_id + self.nickname_mapping: Dict[str, str] = {} # nickname -> qq_id # 情境模式库 self.contextual_patterns: List[ContextualPattern] = [] @@ -149,7 +149,7 @@ async def start(self): # 初始化分析缓存 self._analysis_cache = {} - self._cache_timeout = 3600 # 1小时缓存 + self._cache_timeout = 3600 # 1小时缓存 # 启动定期清理任务 self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) @@ -163,24 +163,29 @@ async def start(self): async def _load_user_profiles_from_db(self): """从数据库加载用户画像""" try: - # 获取所有活跃群组 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 查询最近活跃的用户 - await cursor.execute(''' - SELECT group_id, sender_id, MAX(sender_name) as sender_name, COUNT(*) as msg_count - FROM raw_messages - WHERE timestamp > ? 
- GROUP BY group_id, sender_id - HAVING msg_count >= 5 - ORDER BY msg_count DESC - LIMIT 500 - ''', (time.time() - 7 * 24 * 3600,)) # 最近7天 - - users = await cursor.fetchall() - - for group_id, sender_id, sender_name, msg_count in users: + cutoff = time.time() - 7 * 24 * 3600 # 最近7天 + + async with self.db_manager.get_session() as session: + from sqlalchemy import select, func + from ...models.orm.message import RawMessage + + stmt = ( + select( + RawMessage.group_id, + RawMessage.sender_id, + func.max(RawMessage.sender_name).label('sender_name'), + func.count().label('msg_count'), + ) + .where(RawMessage.timestamp > cutoff) + .group_by(RawMessage.group_id, RawMessage.sender_id) + .having(func.count() >= 5) + .order_by(func.count().desc()) + .limit(500) + ) + result = await session.execute(stmt) + rows = result.fetchall() + + for group_id, sender_id, sender_name, msg_count in rows: if group_id and sender_id: user_key = f"{group_id}:{sender_id}" self.user_profiles[user_key] = { @@ -193,11 +198,9 @@ async def _load_user_profiles_from_db(self): 'last_activity': time.time(), 'created_at': time.time() } - - await cursor.close() - + logger.info(f"从数据库加载了 {len(self.user_profiles)} 个用户画像") - + except Exception as e: logger.error(f"从数据库加载用户画像失败: {e}") @@ -206,27 +209,32 @@ async def _load_social_relations_from_db(self): try: # 初始化社交图谱 self.social_graph = {} - - # 分析用户间的交互关系 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 查询用户在同一群组中的交互 - await cursor.execute(''' - SELECT group_id, sender_id, COUNT(*) as interaction_count - FROM raw_messages - WHERE timestamp > ? 
AND group_id IS NOT NULL - GROUP BY group_id, sender_id - HAVING interaction_count >= 3 - ''', (time.time() - 7 * 24 * 3600,)) - - interactions = await cursor.fetchall() - + + cutoff = time.time() - 7 * 24 * 3600 # 最近7天 + + async with self.db_manager.get_session() as session: + from sqlalchemy import select, func + from ...models.orm.message import RawMessage + + stmt = ( + select( + RawMessage.group_id, + RawMessage.sender_id, + func.count().label('interaction_count'), + ) + .where(RawMessage.timestamp > cutoff) + .where(RawMessage.group_id.isnot(None)) + .group_by(RawMessage.group_id, RawMessage.sender_id) + .having(func.count() >= 3) + ) + result = await session.execute(stmt) + rows = result.fetchall() + # 构建基础社交关系 - for group_id, sender_id, count in interactions: + for group_id, sender_id, count in rows: if sender_id not in self.social_graph: self.social_graph[sender_id] = [] - + # 为简化,暂时记录用户在各群组的活跃度 relation_info = { 'target_user': group_id, @@ -235,11 +243,9 @@ async def _load_social_relations_from_db(self): 'last_interaction': time.time() } self.social_graph[sender_id].append(relation_info) - - await cursor.close() - + logger.info(f"构建了 {len(self.social_graph)} 个用户的社交关系") - + except Exception as e: logger.error(f"加载社交关系失败: {e}") @@ -247,7 +253,7 @@ async def _periodic_cleanup(self): """定期清理过期缓存和数据""" try: while True: - await asyncio.sleep(3600) # 每小时执行一次 + await asyncio.sleep(3600) # 每小时执行一次 current_time = time.time() @@ -264,7 +270,7 @@ async def _periodic_cleanup(self): logger.debug(f"清理了 {len(expired_keys)} 个过期的分析缓存") # 清理过期的用户活动记录 - cutoff_time = current_time - 30 * 24 * 3600 # 30天前 + cutoff_time = current_time - 30 * 24 * 3600 # 30天前 expired_users = [ k for k, v in self.user_profiles.items() if v.get('last_activity', 0) < cutoff_time @@ -478,7 +484,7 @@ async def analyze_message_context(self, event: AstrMessageEvent, message_text: s sender_id = event.get_sender_id() sender_name = event.get_sender_name() - group_id = event.get_group_id() or 
event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID + group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID # 预先清理user_profiles中的任何问题数据 self._clean_user_profiles() @@ -626,7 +632,7 @@ async def analyze_message_batch(self, # 分析沟通风格(添加限制) style_context = {} - if self._batch_analysis_count[hour_key] <= 50: # 限制风格分析的调用次数 + if self._batch_analysis_count[hour_key] <= 50: # 限制风格分析的调用次数 style_context = await self._analyze_communication_style(message_text) else: # 使用简化的风格分析 @@ -878,7 +884,7 @@ async def _analyze_social_context(self, event: AstrMessageEvent, message_text: s """分析社交关系上下文""" try: sender_id = event.get_sender_id() - group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID + group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID social_context = { 'mentions': [], @@ -913,7 +919,7 @@ async def _analyze_social_context(self, event: AstrMessageEvent, message_text: s else: logger.debug(f"[社交关系] 消息事件不支持get_reply_info或没有回复信息") - # === 新增:基于时间窗口的对话关系分析(去除@限制) === + # 新增:基于时间窗口的对话关系分析(去除@限制) await self._analyze_conversation_interactions(sender_id, group_id, message_text) # 计算与群内成员的交互强度 @@ -952,7 +958,7 @@ async def _analyze_emotional_context(self, message_text: str) -> Dict[str, float cache_key = f"emotion_cache_{hash(message_text)}" if hasattr(self, '_analysis_cache') and cache_key in self._analysis_cache: cached_result = self._analysis_cache[cache_key] - if time.time() - cached_result.get('timestamp', 0) < 300: # 5分钟缓存 + if time.time() - cached_result.get('timestamp', 0) < 300: # 5分钟缓存 logger.debug(f"使用缓存的情感分析结果") return cached_result.get('result', self._simple_emotional_analysis(message_text)) @@ -1006,11 +1012,11 @@ async def _analyze_emotional_context(self, message_text: str) -> Dict[str, float def _simple_emotional_analysis(self, message_text: str) -> Dict[str, float]: """简化的情感分析(备用)""" emotions = { - '积极': ['开心', '高兴', '兴奋', '满意', '喜欢', '爱', '好棒', '太好了', '哈哈', '😄', '😊', '👍'], - '消极': 
['难过', '生气', '失望', '无聊', '烦', '讨厌', '糟糕', '不好', '😭', '😢', '😡'], + '积极': ['开心', '高兴', '兴奋', '满意', '喜欢', '爱', '好棒', '太好了', '哈哈', '', '', ''], + '消极': ['难过', '生气', '失望', '无聊', '烦', '讨厌', '糟糕', '不好', '', '', ''], '中性': ['知道', '明白', '可以', '好的', '嗯', '哦', '这样', '然后'], - '疑问': ['吗', '呢', '?', '什么', '怎么', '为什么', '哪里', '🤔'], - '惊讶': ['哇', '天哪', '真的', '不会吧', '太', '竟然', '居然', '😱', '😯'] + '疑问': ['吗', '呢', '?', '什么', '怎么', '为什么', '哪里', ''], + '惊讶': ['哇', '天哪', '真的', '不会吧', '太', '竟然', '居然', '', ''] } emotion_scores = {} @@ -1049,7 +1055,7 @@ async def _analyze_communication_style(self, message_text: str) -> Dict[str, flo cache_key = f"style_cache_{hash(message_text)}" if hasattr(self, '_analysis_cache') and cache_key in self._analysis_cache: cached_result = self._analysis_cache[cache_key] - if time.time() - cached_result.get('timestamp', 0) < 600: # 10分钟缓存 + if time.time() - cached_result.get('timestamp', 0) < 600: # 10分钟缓存 logger.debug(f"使用缓存的风格分析结果") return cached_result.get('result', {}) @@ -1195,7 +1201,7 @@ async def _analyze_conversation_interactions(self, sender_id: str, group_id: str # 获取最近5分钟内的消息 recent_messages = await self.db_manager.get_messages_by_group_and_timerange( group_id=group_id, - start_time=time.time() - 300, # 5分钟 + start_time=time.time() - 300, # 5分钟 limit=20 ) @@ -1204,7 +1210,7 @@ async def _analyze_conversation_interactions(self, sender_id: str, group_id: str # 找到当前用户之前的最近一条其他人的消息 previous_sender = None - for msg in reversed(recent_messages): # 按时间倒序 + for msg in reversed(recent_messages): # 按时间倒序 if msg['sender_id'] != sender_id and msg['sender_id'] != 'bot': previous_sender = msg['sender_id'] previous_message = msg['message'] @@ -1480,7 +1486,7 @@ async def _calculate_enthusiasm_level(self, text: str) -> float: def _simple_enthusiasm_level(self, text: str) -> float: """简化的热情程度计算(备用)""" - enthusiasm_indicators = ['!', '!', '哈哈', '太好了', '棒', '赞', '😄', '😊', '🎉', '厉害', 'awesome'] + enthusiasm_indicators = ['!', '!', '哈哈', '太好了', '棒', '赞', '', '', '', 
'厉害', 'awesome'] count = sum(text.count(indicator) for indicator in enthusiasm_indicators) return min(count / max(len(text), 1) * 20, 1.0) @@ -1678,10 +1684,10 @@ def _simple_personality_analysis(self, profile) -> Dict[str, float]: return { "openness": min(openness, 1.0), - "conscientiousness": 0.6, # 默认值 + "conscientiousness": 0.6, # 默认值 "extraversion": extraversion, - "agreeableness": 0.7, # 默认值 - "neuroticism": 0.3 # 默认值 + "agreeableness": 0.7, # 默认值 + "neuroticism": 0.3 # 默认值 } async def _analyze_social_behavior(self, qq_id: str) -> Dict[str, Any]: diff --git a/services/commands/__init__.py b/services/commands/__init__.py new file mode 100644 index 0000000..ad22416 --- /dev/null +++ b/services/commands/__init__.py @@ -0,0 +1,9 @@ +"""命令处理器 — 命令检测过滤 + 业务逻辑实现""" + +from .command_filter import CommandFilter +from .handlers import PluginCommandHandlers + +__all__ = [ + "CommandFilter", + "PluginCommandHandlers", +] diff --git a/services/commands/command_filter.py b/services/commands/command_filter.py new file mode 100644 index 0000000..e2f53ac --- /dev/null +++ b/services/commands/command_filter.py @@ -0,0 +1,54 @@ +"""AstrBot 命令检测过滤器 — 区分命令消息与普通消息""" +import re +from typing import Any + + +class CommandFilter: + """判断消息是否为 AstrBot 命令或本插件命令""" + + PLUGIN_COMMANDS = [ + "learning_status", + "start_learning", + "stop_learning", + "force_learning", + "affection_status", + "set_mood", + ] + + def is_astrbot_command(self, event: Any) -> bool: + """判断用户输入是否为 AstrBot 命令(包括插件命令和其他命令) + + 注意:唤醒词消息(is_at_or_wake_command)应该被收集用于学习, + 因为这些是最有价值的对话数据。只过滤明确的命令格式。 + """ + message_text = event.get_message_str() + if not message_text: + return False + + if self.is_plugin_command(message_text): + return True + + command_prefixes = ["/", "!", "#", "."] + stripped_text = message_text.strip() + if stripped_text and stripped_text[0] in command_prefixes: + if len(stripped_text) > 1 and stripped_text[1].isalpha(): + return True + + return False + + def is_plugin_command(self, 
message_text: str) -> bool: + """检查消息是否为本插件的命令""" + if not message_text: + return False + + message_text = message_text.strip() + + commands_pattern = "|".join(re.escape(cmd) for cmd in self.PLUGIN_COMMANDS) + pattern_with_prefix = rf"^.{{1}}({commands_pattern})(\s.*)?$" + pattern_without_prefix = rf"^({commands_pattern})(\s.*)?$" + + return bool( + re.match(pattern_with_prefix, message_text, re.IGNORECASE) + ) or bool( + re.match(pattern_without_prefix, message_text, re.IGNORECASE) + ) diff --git a/services/commands/handlers.py b/services/commands/handlers.py new file mode 100644 index 0000000..cb70917 --- /dev/null +++ b/services/commands/handlers.py @@ -0,0 +1,405 @@ +"""插件命令业务逻辑实现 — 6 个 admin 命令的处理体""" +import time +from typing import Any, AsyncGenerator + +from astrbot.api import logger + +from ...statics.messages import CommandMessages, LogMessages + + +class PluginCommandHandlers: + """6 个 @filter.command 命令的业务逻辑(从 main.py 提取)""" + + def __init__( + self, + plugin_config: Any, + service_factory: Any, + message_collector: Any, + persona_manager: Any, + progressive_learning: Any, + affection_manager: Any, + temporary_persona_updater: Any, + db_manager: Any, + llm_adapter: Any, + ): + self._config = plugin_config + self._service_factory = service_factory + self._message_collector = message_collector + self._persona_manager = persona_manager + self._progressive_learning = progressive_learning + self._affection_manager = affection_manager + self._temporary_persona_updater = temporary_persona_updater + self._db_manager = db_manager + self._llm_adapter = llm_adapter + self._force_learning_in_progress: set = set() + + # learning_status + + async def learning_status(self, event: Any) -> AsyncGenerator: + """查看学习状态""" + try: + group_id = event.get_group_id() or event.get_sender_id() + + collector_stats = await self._message_collector.get_statistics(group_id) + if collector_stats is None: + collector_stats = { + "total_messages": 0, + "filtered_messages": 0, + 
"raw_messages": 0, + "unprocessed_messages": 0, + } + + current_persona_info = await self._persona_manager.get_current_persona(group_id) + current_persona_name = CommandMessages.STATUS_UNKNOWN + if current_persona_info and isinstance(current_persona_info, dict): + current_persona_name = current_persona_info.get("name", CommandMessages.STATUS_UNKNOWN) + + learning_status = await self._progressive_learning.get_learning_status() + if learning_status is None: + learning_status = { + "learning_active": False, + "current_session": None, + "total_sessions": 0, + } + + status_info = CommandMessages.STATUS_REPORT_HEADER.format(group_id=group_id) + + persona_update_mode = ( + "PersonaManager模式" + if self._config.use_persona_manager_updates + else "传统文件模式" + ) + status_info += CommandMessages.STATUS_BASIC_CONFIG.format( + message_capture=( + CommandMessages.STATUS_ENABLED + if self._config.enable_message_capture + else CommandMessages.STATUS_DISABLED + ), + auto_learning=( + CommandMessages.STATUS_ENABLED + if self._config.enable_auto_learning + else CommandMessages.STATUS_DISABLED + ), + realtime_learning=( + CommandMessages.STATUS_ENABLED + if self._config.enable_realtime_learning + else CommandMessages.STATUS_DISABLED + ), + web_interface=( + CommandMessages.STATUS_ENABLED + if self._config.enable_web_interface + else CommandMessages.STATUS_DISABLED + ), + ) + + status_info += f"\n\n 人格更新配置:\n" + status_info += f"• 更新方式: {persona_update_mode}\n" + if self._config.use_persona_manager_updates: + persona_manager_updater = self._service_factory.create_persona_manager_updater() + pm_status = " 可用" if persona_manager_updater.is_available() else " 不可用" + status_info += f"• PersonaManager状态: {pm_status}\n" + status_info += f"• 自动应用更新: {'启用' if self._config.auto_apply_persona_updates else '禁用'}\n" + status_info += f"• 更新前备份: {'启用' if self._config.persona_update_backup_enabled else '禁用'}\n" + + status_info += CommandMessages.STATUS_CAPTURE_SETTINGS.format( + target_qq=( + 
self._config.target_qq_list + if self._config.target_qq_list + else CommandMessages.STATUS_ALL_USERS + ), + current_persona=current_persona_name, + ) + + if self._llm_adapter: + provider_info = self._llm_adapter.get_provider_info() + status_info += CommandMessages.STATUS_MODEL_CONFIG.format( + filter_model=provider_info.get("filter", "未配置"), + refine_model=provider_info.get("refine", "未配置"), + ) + else: + status_info += CommandMessages.STATUS_MODEL_CONFIG.format( + filter_model="未配置框架Provider", + refine_model="未配置框架Provider", + ) + + current_session = learning_status.get("current_session") or {} + status_info += CommandMessages.STATUS_LEARNING_STATS.format( + total_messages=collector_stats.get("total_messages", 0), + filtered_messages=collector_stats.get("filtered_messages", 0), + style_updates=current_session.get("style_updates", 0), + last_learning_time=current_session.get( + "end_time", CommandMessages.STATUS_NEVER_EXECUTED + ), + ) + + status_info += CommandMessages.STATUS_STORAGE_STATS.format( + raw_messages=collector_stats.get("raw_messages", 0), + unprocessed_messages=collector_stats.get("unprocessed_messages", 0), + filtered_messages=collector_stats.get("filtered_messages", 0), + ) + + scheduler_status = ( + CommandMessages.STATUS_RUNNING + if learning_status.get("learning_active") + else CommandMessages.STATUS_STOPPED + ) + status_info += "\n\n" + CommandMessages.STATUS_SCHEDULER.format( + status=scheduler_status + ) + + yield event.plain_result(status_info.strip()) + + except Exception as e: + logger.error( + CommandMessages.ERROR_GET_LEARNING_STATUS.format(error=e), + exc_info=True, + ) + yield event.plain_result( + CommandMessages.STATUS_QUERY_FAILED.format(error=str(e)) + ) + + # start_learning + + async def start_learning(self, event: Any) -> AsyncGenerator: + """手动启动学习""" + try: + group_id = event.get_group_id() or event.get_sender_id() + + stats = await self._message_collector.get_statistics(group_id) + unprocessed_count = 
stats.get("unprocessed_messages", 0) + + if unprocessed_count < self._config.min_messages_for_learning: + yield event.plain_result( + f" 未处理消息数量不足" + f"({unprocessed_count}/{self._config.min_messages_for_learning})," + f"无法开始学习" + ) + return + + yield event.plain_result( + f" 开始执行学习批次,处理 {unprocessed_count} 条未处理消息..." + ) + + try: + await self._progressive_learning._execute_learning_batch(group_id) + yield event.plain_result(" 学习批次执行完成") + except Exception as batch_error: + yield event.plain_result(f" 学习批次执行失败: {str(batch_error)}") + + except Exception as e: + logger.error( + CommandMessages.ERROR_START_LEARNING.format(error=e), exc_info=True + ) + yield event.plain_result( + CommandMessages.STARTUP_FAILED.format(error=str(e)) + ) + + # stop_learning + + async def stop_learning(self, event: Any) -> AsyncGenerator: + """停止学习""" + try: + group_id = event.get_group_id() or event.get_sender_id() + await self._progressive_learning.stop_learning() + yield event.plain_result( + CommandMessages.LEARNING_STOPPED.format(group_id=group_id) + ) + except Exception as e: + logger.error( + CommandMessages.ERROR_STOP_LEARNING.format(error=e), exc_info=True + ) + yield event.plain_result( + CommandMessages.STOP_FAILED.format(error=str(e)) + ) + + # force_learning + + async def force_learning(self, event: Any) -> AsyncGenerator: + """强制执行一次学习周期""" + try: + group_id = event.get_group_id() or event.get_sender_id() + yield event.plain_result( + CommandMessages.FORCE_LEARNING_START.format(group_id=group_id) + ) + + if group_id in self._force_learning_in_progress: + yield event.plain_result( + f" 群组 {group_id} 的强制学习正在进行中,请等待完成" + ) + return + + self._force_learning_in_progress.add(group_id) + try: + await self._progressive_learning._execute_learning_batch(group_id) + yield event.plain_result( + CommandMessages.FORCE_LEARNING_COMPLETE.format(group_id=group_id) + ) + finally: + self._force_learning_in_progress.discard(group_id) + + except Exception as e: + logger.error( + 
CommandMessages.ERROR_FORCE_LEARNING.format(error=e), exc_info=True + ) + yield event.plain_result( + CommandMessages.ERROR_FORCE_LEARNING.format(error=str(e)) + ) + + # affection_status + + async def affection_status(self, event: Any) -> AsyncGenerator: + """查看好感度状态""" + try: + group_id = event.get_group_id() or event.get_sender_id() + user_id = event.get_sender_id() + + if not self._config.enable_affection_system: + yield event.plain_result(CommandMessages.AFFECTION_DISABLED) + return + + affection_status = await self._affection_manager.get_affection_status(group_id) + + current_mood = None + if self._config.enable_startup_random_mood: + current_mood = await self._affection_manager.ensure_mood_for_group(group_id) + else: + current_mood = await self._affection_manager.get_current_mood(group_id) + + user_affection = await self._db_manager.get_user_affection(group_id, user_id) + user_level = user_affection["affection_level"] if user_affection else 0 + + status_info = CommandMessages.AFFECTION_STATUS_HEADER.format(group_id=group_id) + status_info += "\n\n" + CommandMessages.AFFECTION_USER_LEVEL.format( + user_level=user_level, max_affection=self._config.max_user_affection + ) + status_info += "\n" + CommandMessages.AFFECTION_TOTAL_STATUS.format( + total_affection=affection_status["total_affection"], + max_total_affection=affection_status["max_total_affection"], + ) + status_info += "\n" + CommandMessages.AFFECTION_USER_COUNT.format( + user_count=affection_status["user_count"] + ) + status_info += "\n\n" + CommandMessages.AFFECTION_CURRENT_MOOD + + if current_mood: + mood_info = current_mood + status_info += "\n" + CommandMessages.AFFECTION_MOOD_TYPE.format( + mood_type=mood_info.mood_type.value + ) + status_info += "\n" + CommandMessages.AFFECTION_MOOD_INTENSITY.format( + intensity=mood_info.intensity + ) + status_info += "\n" + CommandMessages.AFFECTION_MOOD_DESCRIPTION.format( + description=mood_info.description + ) + else: + status_info += "\n" + 
CommandMessages.AFFECTION_NO_MOOD + + if affection_status["top_users"]: + status_info += "\n\n" + CommandMessages.AFFECTION_TOP_USERS + for i, user in enumerate(affection_status["top_users"][:3], 1): + status_info += "\n" + CommandMessages.AFFECTION_USER_RANK.format( + rank=i, + user_id=user["user_id"], + affection_level=user["affection_level"], + ) + + yield event.plain_result(status_info) + + except Exception as e: + logger.error( + CommandMessages.ERROR_GET_AFFECTION_STATUS.format(error=e), + exc_info=True, + ) + yield event.plain_result( + CommandMessages.ERROR_GET_AFFECTION_STATUS.format(error=str(e)) + ) + + # set_mood + + async def set_mood(self, event: Any) -> AsyncGenerator: + """手动设置 bot 情绪(通过增量人格更新)""" + try: + if not self._config.enable_affection_system: + yield event.plain_result(CommandMessages.AFFECTION_DISABLED) + return + + args = event.get_message_str().split()[1:] + if len(args) < 1: + yield event.plain_result( + "使用方法:/set_mood \n" + "可用情绪: happy, sad, excited, calm, angry, " + "anxious, playful, serious, nostalgic, curious" + ) + return + + group_id = event.get_group_id() or event.get_sender_id() + mood_type = args[0].lower() + + valid_moods = { + "happy": "心情很好,说话比较活泼开朗,容易表达正面情感", + "sad": "心情有些低落,说话比较温和,需要更多的理解和安慰", + "excited": "很兴奋,说话比较有活力,对很多事情都很感兴趣", + "calm": "心情平静,说话比较稳重,给人安全感", + "angry": "心情不太好,说话可能比较直接,不太有耐心", + "anxious": "有些紧张不安,说话可能比较谨慎,需要更多确认", + "playful": "心情很调皮,喜欢开玩笑,说话比较幽默风趣", + "serious": "比较严肃认真,说话简洁直接,专注于重要的事情", + "nostalgic": "有些怀旧情绪,说话带有回忆色彩,比较感性", + "curious": "对很多事情都很好奇,喜欢提问和探索新事物", + } + + if mood_type not in valid_moods: + yield event.plain_result( + f" 无效的情绪类型。支持的情绪: {', '.join(valid_moods.keys())}" + ) + return + + mood_description = valid_moods[mood_type] + + persona_success = ( + await self._temporary_persona_updater.apply_mood_based_persona_update( + group_id, mood_type, mood_description + ) + ) + + # 同时在 affection_manager 中记录情绪状态 + from ...services.state import MoodType, BotMood + + affection_success = False + 
try: + mood_enum = MoodType(mood_type) + await self._affection_manager.db_manager.save_bot_mood( + group_id, + mood_type, + 0.7, + mood_description, + self._config.mood_persistence_hours or 24, + ) + mood_obj = BotMood( + mood_type=mood_enum, + intensity=0.7, + description=mood_description, + start_time=time.time(), + duration_hours=self._config.mood_persistence_hours or 24, + ) + self._affection_manager.current_moods[group_id] = mood_obj + affection_success = True + except Exception as e: + logger.warning(f"设置 affection_manager 情绪失败: {e}") + + if persona_success: + status_msg = f" 情绪状态已设置为: {mood_type}\n描述: {mood_description}" + if not affection_success: + status_msg += "\n 注意:情绪状态可能无法在状态查询中正确显示" + yield event.plain_result(status_msg) + else: + yield event.plain_result(" 设置情绪状态失败") + + except Exception as e: + logger.error( + CommandMessages.ERROR_SET_MOOD.format(error=e), exc_info=True + ) + yield event.plain_result( + CommandMessages.ERROR_SET_MOOD.format(error=str(e)) + ) diff --git a/services/core_learning/__init__.py b/services/core_learning/__init__.py new file mode 100644 index 0000000..5aec69b --- /dev/null +++ b/services/core_learning/__init__.py @@ -0,0 +1,14 @@ +"""Core learning engines -- progressive, advanced, V2, message collection.""" + +from .progressive_learning import ProgressiveLearningService, LearningSession +from .advanced_learning import AdvancedLearningService +from .v2_learning_integration import V2LearningIntegration +from .message_collector import MessageCollectorService + +__all__ = [ + "ProgressiveLearningService", + "LearningSession", + "AdvancedLearningService", + "V2LearningIntegration", + "MessageCollectorService", +] diff --git a/services/advanced_learning.py b/services/core_learning/advanced_learning.py similarity index 98% rename from services/advanced_learning.py rename to services/core_learning/advanced_learning.py index 00df7ec..5de745f 100644 --- a/services/advanced_learning.py +++ 
b/services/core_learning/advanced_learning.py @@ -14,10 +14,10 @@ from astrbot.api import logger -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage, IPersonaManager, ServiceLifecycle -from ..core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...core.patterns import AsyncServiceBase +from ...core.interfaces import IDataStorage, IPersonaManager, ServiceLifecycle +from ...core.framework_llm_adapter import FrameworkLLMAdapter @dataclass diff --git a/services/message_collector.py b/services/core_learning/message_collector.py similarity index 94% rename from services/message_collector.py rename to services/core_learning/message_collector.py index e793d3f..121b13a 100644 --- a/services/message_collector.py +++ b/services/core_learning/message_collector.py @@ -11,15 +11,15 @@ # 简化的单例模式导入 try: - from ..config import PluginConfig - from ..exceptions import MessageCollectionError, DataStorageError - from ..core.interfaces import MessageData + from ...config import PluginConfig + from ...exceptions import MessageCollectionError, DataStorageError + from ...core.interfaces import MessageData except ImportError: - from ..config import PluginConfig - from ..exceptions import MessageCollectionError, DataStorageError - from ..core.interfaces import MessageData + from ...config import PluginConfig + from ...exceptions import MessageCollectionError, DataStorageError + from ...core.interfaces import MessageData -from .database_manager import DatabaseManager +from ..database import DatabaseManager class MessageCollectorService: @@ -34,7 +34,7 @@ def __init__(self, config: PluginConfig, context: Context, database_manager: Dat self._message_cache = [] self._cache_size_limit = 100 self._last_flush_time = time.time() - self._flush_interval = 30 # 30秒强制刷新一次 + self._flush_interval = 30 # 30秒强制刷新一次 logger.info("消息收集服务初始化完成") @@ -63,7 +63,7 @@ async def collect_message(self, 
message_data: Dict[str, Any]) -> bool: ) await self.database_manager.save_raw_message(message_obj) - logger.info(f"✅ 消息已保存: group={message_data.get('group_id')}, sender={message_data.get('sender_name')}, msg_preview={message_data.get('message', '')[:30]}...") + logger.info(f" 消息已保存: group={message_data.get('group_id')}, sender={message_data.get('sender_name')}, msg_preview={message_data.get('message', '')[:30]}...") return True diff --git a/services/progressive_learning.py b/services/core_learning/progressive_learning.py similarity index 74% rename from services/progressive_learning.py rename to services/core_learning/progressive_learning.py index 9c31a3a..0a9b507 100644 --- a/services/progressive_learning.py +++ b/services/core_learning/progressive_learning.py @@ -11,13 +11,13 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..config import PluginConfig -from ..constants import UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING -from ..exceptions import LearningError +from ...config import PluginConfig +from ...constants import UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING +from ...exceptions import LearningError -from ..utils.json_utils import safe_parse_llm_json, clean_llm_json_response +from ...utils.json_utils import safe_parse_llm_json, clean_llm_json_response -from .database_manager import DatabaseManager +from ..database import DatabaseManager @dataclass @@ -56,21 +56,21 @@ def __init__(self, config: PluginConfig, context: Context, self.quality_monitor = quality_monitor self.persona_manager = persona_manager # 注入 persona_manager self.ml_analyzer = ml_analyzer # 注入 ml_analyzer - self.prompts = prompts # 保存 prompts 实例 + self.prompts = prompts # 保存 prompts 实例 # 学习状态 - 使用字典管理每个群组的学习状态 - self.learning_active = {} # 改为字典,按群组ID管理 + self.learning_active = {} # 改为字典,按群组ID管理 # 增量更新回调函数,降低耦合性 self.update_system_prompt_callback = None - self.current_session: Optional[LearningSession] = None + self._group_sessions: Dict[str, LearningSession] = {} 
self.learning_sessions: List[LearningSession] = [] # 历史学习会话,可以从数据库加载 - self.learning_lock = asyncio.Lock() # 添加异步锁防止竞态条件 + self.learning_lock = asyncio.Lock() # 添加异步锁防止竞态条件 # 学习控制参数 self.batch_size = config.max_messages_per_batch - self.learning_interval = config.learning_interval_hours * 3600 # 转换为秒 + self.learning_interval = config.learning_interval_hours * 3600 # 转换为秒 self.quality_threshold = config.style_update_threshold logger.info("渐进式学习服务初始化完成") @@ -104,24 +104,24 @@ async def start(self): async def start_learning(self, group_id: str) -> bool: """启动学习流程 - 优化为后台任务执行""" - async with self.learning_lock: # 使用锁防止竞态条件 + async with self.learning_lock: # 使用锁防止竞态条件 try: # 检查该群组是否已经在学习 if self.learning_active.get(group_id, False): logger.info(f"群组 {group_id} 学习已在进行中,跳过启动") - return True # 返回True表示学习状态正常 + return True # 返回True表示学习状态正常 # 设置该群组为学习状态 self.learning_active[group_id] = True # 创建新的学习会话 session_id = f"session_{group_id}_{int(time.time())}" - self.current_session = LearningSession( + self._group_sessions[group_id] = LearningSession( session_id=session_id, start_time=datetime.now().isoformat() ) # 保存新的学习会话到数据库 - await self.db_manager.save_learning_session_record(group_id, self.current_session.__dict__) + await self.db_manager.save_learning_session_record(group_id, self._group_sessions[group_id].__dict__) logger.info(f"开始学习会话: {session_id} for group {group_id}") @@ -159,15 +159,22 @@ async def stop_learning(self, group_id: str = None): self.learning_active[gid] = False logger.info("停止所有群组的学习任务") - if self.current_session: - self.current_session.end_time = datetime.now().isoformat() - self.current_session.success = True # 假设正常停止即成功 - # 保存更新后的学习会话到数据库 - target_group_id = group_id or "global_learning" # 使用指定的群组ID或默认值 - await self.db_manager.save_learning_session_record(target_group_id, self.current_session.__dict__) - self.learning_sessions.append(self.current_session) # 仍然添加到内存列表 - logger.info(f"学习会话结束: {self.current_session.session_id}") - self.current_session = 
None + if group_id: + session = self._group_sessions.pop(group_id, None) + if session: + session.end_time = datetime.now().isoformat() + session.success = True + await self.db_manager.save_learning_session_record(group_id, session.__dict__) + self.learning_sessions.append(session) + logger.info(f"学习会话结束: {session.session_id}") + else: + for gid, session in list(self._group_sessions.items()): + session.end_time = datetime.now().isoformat() + session.success = True + await self.db_manager.save_learning_session_record(gid, session.__dict__) + self.learning_sessions.append(session) + logger.info(f"学习会话结束: {session.session_id}") + self._group_sessions.clear() async def _learning_loop_safe(self, group_id: str): """安全的学习循环 - 在后台线程执行,包含完整错误处理""" @@ -192,12 +199,13 @@ async def _learning_loop_safe(self, group_id: str): break except Exception as e: logger.error(f"群组 {group_id} 学习循环异常: {e}", exc_info=True) - await asyncio.sleep(60) # 异常时等待1分钟 + await asyncio.sleep(60) # 异常时等待1分钟 finally: # 确保清理资源 - if self.current_session: - self.current_session.end_time = datetime.now().isoformat() - await self.db_manager.save_learning_session_record(group_id, self.current_session.__dict__) + session = self._group_sessions.pop(group_id, None) + if session: + session.end_time = datetime.now().isoformat() + await self.db_manager.save_learning_session_record(group_id, session.__dict__) logger.info(f"学习循环结束 for group {group_id}") async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = False): @@ -212,12 +220,12 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals # 1. 
获取消息(根据模式决定是否忽略"已处理"标记) if relearn_mode: - # ✅ 重新学习模式:获取所有历史消息,忽略已处理标记 - logger.info(f"🔄 重新学习模式:获取群组 {group_id} 的所有历史消息(忽略已处理标记)") + # 重新学习模式:获取所有历史消息,忽略已处理标记 + logger.info(f" 重新学习模式:获取群组 {group_id} 的所有历史消息(忽略已处理标记)") # 使用 get_recent_raw_messages 获取所有历史消息(不考虑已处理标记) unprocessed_messages = await self.db_manager.get_recent_raw_messages( group_id=group_id, - limit=self.batch_size * 10 # 重新学习时获取更多消息 + limit=self.batch_size * 10 # 重新学习时获取更多消息 ) logger.info(f"获取到 {len(unprocessed_messages) if unprocessed_messages else 0} 条历史消息用于重新学习") else: @@ -288,7 +296,7 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals updated_persona = await self._generate_updated_persona_with_refinement(group_id, current_persona, style_analysis) # 7. 【新增】强化学习增量微调 - ml_tuning_info = None # 用于记录强化学习调优信息 + ml_tuning_info = None # 用于记录强化学习调优信息 if self.config.enable_ml_analysis and updated_persona: try: tuning_result = await self.ml_analyzer.reinforcement_incremental_tuning( @@ -299,7 +307,7 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals # 使用强化学习优化后的人格 final_persona = tuning_result.get('updated_persona') - # ✅ 检查 updated_persona 类型,确保是字典才调用 update + # 检查 updated_persona 类型,确保是字典才调用 update if not isinstance(updated_persona, dict): logger.error(f"updated_persona 类型不正确,预期为 dict 但得到 {type(updated_persona)},跳过强化学习调优") elif not isinstance(final_persona, dict): @@ -345,35 +353,35 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals # 9. 
应用学习更新(对话风格学习不判断质量直接应用,人格学习加入审查) # 注意:对话风格(表达模式)学习总是成功,人格学习在_apply_learning_updates中会加入审查 - # ✅ 传递 relearn_mode 和 ml_tuning_info 参数 + # 传递 relearn_mode 和 ml_tuning_info 参数 await self._apply_learning_updates(group_id, style_analysis, filtered_messages, current_persona, updated_persona, quality_metrics, relearn_mode=relearn_mode, ml_tuning_info=ml_tuning_info) logger.info(f"学习更新已应用(对话风格学习已完成,人格学习已加入审查),质量得分: {quality_metrics.consistency_score:.3f} for group {group_id}") - success = True # 对话风格学习总是成功 + success = True # 对话风格学习总是成功 # 10. 【新增】保存学习性能记录 - # ✅ 正确处理 AnalysisResult 对象进行序列化 + # 正确处理 AnalysisResult 对象进行序列化 style_analysis_for_db = style_analysis.data if hasattr(style_analysis, 'data') else style_analysis await self.db_manager.save_learning_performance_record(group_id, { - 'session_id': self.current_session.session_id if self.current_session else '', + 'session_id': self._group_sessions[group_id].session_id if group_id in self._group_sessions else '', 'timestamp': time.time(), 'quality_score': quality_metrics.consistency_score, 'learning_time': (datetime.now() - batch_start_time).total_seconds(), 'success': success, 'successful_pattern': json.dumps(style_analysis_for_db, default=self._json_serializer), - 'failed_pattern': '' # 对话风格学习总是成功,不记录失败 + 'failed_pattern': '' # 对话风格学习总是成功,不记录失败 }) # 11. 标记消息为已处理 await self._mark_messages_processed(unprocessed_messages) # 12. 
更新学习会话统计并持久化 - if self.current_session: - self.current_session.messages_processed += len(unprocessed_messages) - self.current_session.filtered_messages += len(filtered_messages) - self.current_session.quality_score = quality_metrics.consistency_score - self.current_session.success = success - # 每次批次结束都保存当前会话状态 - await self.db_manager.save_learning_session_record(group_id, self.current_session.__dict__) + group_session = self._group_sessions.get(group_id) + if group_session: + group_session.messages_processed += len(unprocessed_messages) + group_session.filtered_messages += len(filtered_messages) + group_session.quality_score = quality_metrics.consistency_score + group_session.success = success + await self.db_manager.save_learning_session_record(group_id, group_session.__dict__) # 13. 【新增】学习成功后更新增量内容到system_prompt if success: @@ -387,8 +395,8 @@ async def _execute_learning_batch(self, group_id: str, relearn_mode: bool = Fals except Exception as e: logger.error(f"定时增量内容更新失败: {e}") - # 14. 【新增】定期执行策略优化 - if success and self.current_session and self.current_session.messages_processed % 500 == 0: + # 14. 定期执行策略优化 + if success and group_session and group_session.messages_processed % 500 == 0: try: await self.ml_analyzer.reinforcement_strategy_optimization(group_id) logger.info("执行了策略优化检查") @@ -475,7 +483,7 @@ async def _execute_learning_batch_background(self, group_id: str): group_id, current_persona, updated_persona ) if tuning_result and tuning_result.get('updated_persona'): - # ✅ 检查 updated_persona 类型,确保是字典才调用 update + # 检查 updated_persona 类型,确保是字典才调用 update if isinstance(updated_persona, dict): updated_persona.update(tuning_result.get('updated_persona')) logger.info(f"应用强化学习优化,预期改进: {tuning_result.get('performance_prediction', {}).get('expected_improvement', 0)}") @@ -485,7 +493,7 @@ async def _execute_learning_batch_background(self, group_id: str): # 7. 
质量评估和应用更新 await self._finalize_learning_batch( group_id, current_persona, updated_persona, filtered_messages, - unprocessed_messages, batch_start_time, style_analysis # ✅ 传递 style_analysis + unprocessed_messages, batch_start_time, style_analysis # 传递 style_analysis ) except Exception as e: @@ -506,7 +514,7 @@ async def _execute_reinforcement_learning_background(self, group_id: str, filter async def _execute_style_analysis_background(self, group_id: str, filtered_messages): """在后台执行风格分析""" - from ..core.interfaces import AnalysisResult + from ...core.interfaces import AnalysisResult try: return await self.style_analyzer.analyze_conversation_style(group_id, filtered_messages) except Exception as e: @@ -546,81 +554,66 @@ async def _finalize_learning_batch(self, group_id: str, current_persona, updated ) # 应用学习更新(对话风格学习不判断质量直接应用,人格学习加入审查) - # ✅ 传递 style_analysis 用于保存对话风格学习记录 - # ✅ 如果 style_analysis 为 None,创建一个空的 AnalysisResult - from ..core.interfaces import AnalysisResult + # 传递 style_analysis 用于保存对话风格学习记录 + # 如果 style_analysis 为 None,创建一个空的 AnalysisResult + from ...core.interfaces import AnalysisResult if style_analysis is None: style_analysis = AnalysisResult(success=True, confidence=0.5, data={}) await self._apply_learning_updates(group_id, style_analysis, filtered_messages, current_persona, updated_persona, quality_metrics, relearn_mode=False, ml_tuning_info=None) logger.info(f"学习更新已应用(对话风格学习已完成,人格学习已加入审查),质量得分: {quality_metrics.consistency_score:.3f} for group {group_id}") - success = True # 对话风格学习总是成功 + success = True # 对话风格学习总是成功 - # 【新增】记录学习批次到数据库,供webui查询使用 - # ✅ 增强错误处理,如果表不存在则跳过记录 + # 记录学习批次到数据库(使用 ORM) try: batch_name = f"batch_{group_id}_{int(time.time())}" start_time = batch_start_time.timestamp() end_time = time.time() - # 连接到全局消息数据库记录学习批次 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO learning_batches - (group_id, batch_name, start_time, end_time, quality_score, 
processed_messages, - message_count, filtered_count, success, error_message) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - group_id, - batch_name, - start_time, - end_time, - quality_metrics.consistency_score, - len(unprocessed_messages), - len(unprocessed_messages), - len(filtered_messages), - success, - None # 对话风格学习总是成功,不记录错误 - )) - await conn.commit() - logger.debug(f"学习批次记录已保存: {batch_name}") - except Exception as e: - error_str = str(e) - if "no such table" in error_str.lower() or "doesn't exist" in error_str.lower() or "unknown column" in error_str.lower(): - logger.debug(f"学习批次表不存在或结构过旧,跳过保存(这不影响学习功能): {e}") - else: - logger.error(f"保存学习批次记录失败: {e}") - finally: - await cursor.close() + async with self.db_manager.get_session() as session: + from ...models.orm.learning import LearningBatch + batch_record = LearningBatch( + batch_id=batch_name, + batch_name=batch_name, + group_id=group_id, + start_time=start_time, + end_time=end_time, + quality_score=quality_metrics.consistency_score, + processed_messages=len(unprocessed_messages), + message_count=len(unprocessed_messages), + filtered_count=len(filtered_messages), + success=success, + ) + session.add(batch_record) + await session.commit() + logger.debug(f"学习批次记录已保存: {batch_name}") except Exception as e: - logger.debug(f"无法记录学习批次(这不影响学习功能): {e}") - + logger.debug(f"无法记录学习批次(不影响学习功能): {e}") + # 保存学习性能记录 await self.db_manager.save_learning_performance_record(group_id, { - 'session_id': self.current_session.session_id if self.current_session else '', + 'session_id': self._group_sessions[group_id].session_id if group_id in self._group_sessions else '', 'timestamp': time.time(), 'quality_score': quality_metrics.consistency_score, 'learning_time': end_time - start_time, 'success': success, 'successful_pattern': json.dumps({}), - 'failed_pattern': '' # 对话风格学习总是成功,不记录失败 + 'failed_pattern': '' # 对话风格学习总是成功,不记录失败 }) # 标记消息为已处理 await self._mark_messages_processed(unprocessed_messages) # 更新会话统计 - if 
self.current_session: - self.current_session.messages_processed += len(unprocessed_messages) - self.current_session.filtered_messages += len(filtered_messages) - self.current_session.quality_score = quality_metrics.consistency_score - self.current_session.success = success - await self.db_manager.save_learning_session_record(group_id, self.current_session.__dict__) - + bg_session = self._group_sessions.get(group_id) + if bg_session: + bg_session.messages_processed += len(unprocessed_messages) + bg_session.filtered_messages += len(filtered_messages) + bg_session.quality_score = quality_metrics.consistency_score + bg_session.success = success + await self.db_manager.save_learning_session_record(group_id, bg_session.__dict__) + # 定期执行策略优化 - 不阻塞主流程 - if success and self.current_session and self.current_session.messages_processed % 500 == 0: + if success and bg_session and bg_session.messages_processed % 500 == 0: asyncio.create_task(self._execute_strategy_optimization_background(group_id)) batch_duration = end_time - start_time @@ -641,7 +634,7 @@ async def _generate_updated_persona_with_refinement(self, group_id: str, current """使用提炼模型生成更新后的人格""" try: # 正确处理AnalysisResult对象和字典类型 - from ..core.interfaces import AnalysisResult + from ...core.interfaces import AnalysisResult if isinstance(style_analysis, AnalysisResult): # 如果是AnalysisResult对象,提取data属性 @@ -723,76 +716,15 @@ def _json_serializer(self, obj): logger.warning(f"JSON序列化对象时出现错误: {e}, 对象类型: {type(obj)}, 转换为字符串") return str(obj) - # async def _execute_learning_batch(self): - # """执行一个学习批次""" - # try: - # batch_start_time = datetime.now() - - # # 1. 获取未处理的消息 - # unprocessed_messages = await self.message_collector.get_unprocessed_messages( - # limit=self.batch_size - # ) - - # if not unprocessed_messages: - # logger.debug("没有未处理的消息,跳过此批次") - # return - - # logger.info(f"开始处理 {len(unprocessed_messages)} 条消息") - - # # 2. 
使用多维度分析器筛选消息 - # filtered_messages = await self._filter_messages_with_context(unprocessed_messages) - - # if not filtered_messages: - # logger.debug("没有通过筛选的消息") - # await self._mark_messages_processed(unprocessed_messages) - # return - - # # 3. 使用风格分析器深度分析 - # style_analysis = await self.style_analyzer.analyze_conversation_style(filtered_messages) - - # # 4. 获取当前人格设置 - # current_persona = await self._get_current_persona() - - # # 5. 质量监控评估 - # quality_metrics = await self.quality_monitor.evaluate_learning_batch( - # current_persona, - # await self._generate_updated_persona(current_persona, style_analysis), - # filtered_messages - # ) - - # # 6. 根据质量评估决定是否应用更新 - # if quality_metrics.consistency_score >= self.quality_threshold: - # await self._apply_learning_updates(style_analysis, filtered_messages) - # logger.info(f"学习更新已应用,质量得分: {quality_metrics.consistency_score:.3f}") - # else: - # logger.warning(f"学习质量不达标,跳过更新,得分: {quality_metrics.consistency_score:.3f}") - - # # 7. 标记消息为已处理 - # await self._mark_messages_processed(unprocessed_messages) - - # # 8. 
更新学习会话统计 - # if self.current_session: - # self.current_session.messages_processed += len(unprocessed_messages) - # self.current_session.filtered_messages += len(filtered_messages) - # self.current_session.quality_score = quality_metrics.consistency_score - - # # 记录批次耗时 - # batch_duration = (datetime.now() - batch_start_time).total_seconds() - # logger.info(f"学习批次完成,耗时: {batch_duration:.2f}秒") - - # except Exception as e: - # logger.error(f"学习批次执行失败: {e}") - # raise LearningError(f"学习批次执行失败: {str(e)}") - async def _filter_messages_with_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """对话风格学习不需要筛选,直接返回所有消息""" - # ✅ 对话风格学习不需要LLM筛选,直接学习所有原始消息 + # 对话风格学习不需要LLM筛选,直接学习所有原始消息 logger.info(f"对话风格学习模式:直接学习 {len(messages)} 条原始消息(跳过LLM筛选)") # 为每条消息添加默认的相关性评分 for message in messages: - message['relevance_score'] = 1.0 # 默认完全相关 + message['relevance_score'] = 1.0 # 默认完全相关 message['filter_reason'] = 'style_learning_no_filter' return messages @@ -850,7 +782,7 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s learning_content = [] # 正确处理AnalysisResult对象和字典类型 - from ..core.interfaces import AnalysisResult + from ...core.interfaces import AnalysisResult if isinstance(style_analysis, AnalysisResult): # 如果是AnalysisResult对象,提取data属性 @@ -867,7 +799,7 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s analysis_data = {} logger.warning(f"style_analysis类型不正确: {type(style_analysis)}, 使用空字典") - # ✅ 修复:从实际的 style_analysis 结构中提取内容 + # 修复:从实际的 style_analysis 结构中提取内容 # 优先提取 enhanced_prompt 和 learning_insights(如果有) if 'enhanced_prompt' in analysis_data: learning_content.append(analysis_data['enhanced_prompt']) @@ -879,7 +811,7 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s learning_content.append(insights) logger.debug("找到 learning_insights 字段") - # ✅ 新增:从 style_analysis 字段提取内容(StyleAnalyzer返回的结构) + # 新增:从 style_analysis 字段提取内容(StyleAnalyzer返回的结构) if not learning_content 
and 'style_analysis' in analysis_data: style_report = analysis_data['style_analysis'] if isinstance(style_report, dict): @@ -914,7 +846,7 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s learning_content.append("【对话风格学习结果】\n" + "\n".join(extracted_parts)) logger.debug(f"从 style_analysis 提取了 {len(extracted_parts)} 个风格特征") - # ✅ 新增:如果还是没有内容,从 style_profile 提取 + # 新增:如果还是没有内容,从 style_profile 提取 if not learning_content and 'style_profile' in analysis_data: style_profile = analysis_data['style_profile'] if isinstance(style_profile, dict): @@ -936,7 +868,7 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s learning_content.append("【风格量化指标】\n" + "\n".join(profile_parts)) logger.debug(f"从 style_profile 提取了 {len(profile_parts)} 个量化指标") - # ✅ 新增:如果还是没有内容,尝试提取任何有用的信息 + # 新增:如果还是没有内容,尝试提取任何有用的信息 if not learning_content: # 尝试从顶层提取任何看起来有用的字段 useful_fields = ['summary', 'description', 'analysis', 'insights', 'findings'] @@ -956,10 +888,10 @@ async def _generate_updated_persona(self, group_id: str, current_persona: Dict[s updated_persona['prompt'] = original_prompt + new_content updated_persona['last_updated'] = timestamp - logger.info(f"✅ 成功追加 {len(learning_content)} 项学习内容到人格 for group {group_id}") + logger.info(f" 成功追加 {len(learning_content)} 项学习内容到人格 for group {group_id}") return updated_persona else: - logger.warning(f"⚠️ style_analysis中没有可提取的学习内容 for group {group_id}, 数据结构: {list(analysis_data.keys())}") + logger.warning(f" style_analysis中没有可提取的学习内容 for group {group_id}, 数据结构: {list(analysis_data.keys())}") # 即使没有学习内容,也返回一个副本以确保有updated_persona用于对比 return dict(default_persona) @@ -998,7 +930,7 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, # 2. 
更新人格prompt(通过 PersonaManagerService) logger.info(f"应用人格更新 for group {group_id}") - # ✅ 正确处理 AnalysisResult 对象 + # 正确处理 AnalysisResult 对象 if hasattr(style_analysis, 'success'): # 这是一个 AnalysisResult 对象 if not style_analysis.success: @@ -1023,7 +955,7 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, logger.error(f"通过 PersonaManagerService 更新人格失败 for group {group_id}") # 2. 创建人格学习审查记录(新增) - # ✅ 重新学习模式:即使内容相同也创建审查记录(作为重新确认) + # 重新学习模式:即使内容相同也创建审查记录(作为重新确认) # 正常模式:只在内容不同时创建审查记录 should_create_review = False if relearn_mode: @@ -1033,17 +965,17 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, # 检查是否有实质性变化 has_changes = updated_persona.get('prompt', '') != current_persona.get('prompt', '') if has_changes: - logger.info(f"🔄 重新学习模式:检测到人格变化,创建审查记录(group: {group_id})") + logger.info(f" 重新学习模式:检测到人格变化,创建审查记录(group: {group_id})") else: - logger.info(f"🔄 重新学习模式:未检测到人格变化,但仍创建审查记录供审核(group: {group_id})") + logger.info(f" 重新学习模式:未检测到人格变化,但仍创建审查记录供审核(group: {group_id})") else: - logger.warning(f"⚠️ 重新学习模式:无法创建审查记录 - updated_persona={bool(updated_persona)}, current_persona={bool(current_persona)}") + logger.warning(f" 重新学习模式:无法创建审查记录 - updated_persona={bool(updated_persona)}, current_persona={bool(current_persona)}") elif updated_persona and current_persona and updated_persona.get('prompt') != current_persona.get('prompt'): # 正常模式:只在内容不同时创建 should_create_review = True - logger.info(f"✅ 正常模式:检测到人格变化,创建审查记录(group: {group_id})") + logger.info(f" 正常模式:检测到人格变化,创建审查记录(group: {group_id})") else: - logger.debug(f"🔹 正常模式:人格未变化,跳过审查记录 - updated={bool(updated_persona)}, current={bool(current_persona)}, same_prompt={updated_persona.get('prompt') == current_persona.get('prompt') if updated_persona and current_persona else 'N/A'}") + logger.debug(f" 正常模式:人格未变化,跳过审查记录 - updated={bool(updated_persona)}, current={bool(current_persona)}, same_prompt={updated_persona.get('prompt') == current_persona.get('prompt') if updated_persona 
and current_persona else 'N/A'}") if should_create_review: try: @@ -1051,32 +983,32 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, original_prompt = current_persona.get('prompt', '') new_prompt = updated_persona.get('prompt', '') - # ✅ 计算新增内容(用于单独标记) + # 计算新增内容(用于单独标记) if len(new_prompt) > len(original_prompt): incremental_content = new_prompt[len(original_prompt):].strip() else: incremental_content = new_prompt - # ✅ 准备元数据(包含高亮信息) + # 准备元数据(包含高亮信息) metadata = { "progressive_learning": True, "message_count": len(messages), "style_analysis_fields": list(style_analysis.data.keys()) if (hasattr(style_analysis, "data") and isinstance(style_analysis.data, dict)) else (list(style_analysis.keys()) if isinstance(style_analysis, dict) else []), "original_prompt_length": len(original_prompt), "new_prompt_length": len(new_prompt), - "incremental_content": incremental_content, # ✅ 单独记录增量内容,用于高亮 - "incremental_start_pos": len(original_prompt), # ✅ 标记新增内容的起始位置 - "relearn_mode": relearn_mode # ✅ 标记是否���重新学习模式 + "incremental_content": incremental_content, # 单独记录增量内容,用于高亮 + "incremental_start_pos": len(original_prompt), # 标记新增内容的起始位置 + "relearn_mode": relearn_mode # 标记是否���重新学习模式 } - # ✅ 添加强化学习调优信息到元数据 + # 添加强化学习调优信息到元数据 if ml_tuning_info: metadata['ml_tuning'] = ml_tuning_info # 获取质量得分 confidence_score = quality_metrics.consistency_score if quality_metrics and hasattr(quality_metrics, 'consistency_score') else 0.5 - # ✅ 构建 raw_analysis 说明(包含强化学习信息) + # 构建 raw_analysis 说明(包含强化学习信息) raw_analysis_parts = [f"基于{len(messages)}条消息的风格分析"] if relearn_mode: raw_analysis_parts.append("(重新学习)") @@ -1087,19 +1019,19 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, raw_analysis_parts.append(f"已应用强化学习优化,预期改进: {ml_tuning_info['expected_improvement']:.2%}") raw_analysis = ";".join(raw_analysis_parts) - # ✅ 创建审查记录 - proposed_content 是完整的新人格(原人格 + 更新内容) + # 创建审查记录 - proposed_content 是完整的新人格(原人格 + 更新内容) review_id = await 
self.db_manager.add_persona_learning_review( group_id=group_id, - proposed_content=new_prompt, # ✅ 修改:proposed_content 是完整新人格 + proposed_content=new_prompt, # 修改:proposed_content 是完整新人格 learning_source=UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, confidence_score=confidence_score, raw_analysis=raw_analysis, metadata=metadata, - original_content=original_prompt, # ✅ 原人格完整文本 - new_content=new_prompt # ✅ 新人格完整文本(与proposed_content相同,保持一致性) + original_content=original_prompt, # 原人格完整文本 + new_content=new_prompt # 新人格完整文本(与proposed_content相同,保持一致性) ) - logger.info(f"✅ 已创建人格学习审查记录 (ID: {review_id}),置信度: {confidence_score:.3f}") + logger.info(f" 已创建人格学习审查记录 (ID: {review_id}),置信度: {confidence_score:.3f}") except Exception as review_error: logger.error(f"创建人格学习审查记录失败: {review_error}", exc_info=True) @@ -1107,8 +1039,8 @@ async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, logger.debug(f"人格未变化或缺少必要参数,跳过审查记录创建") # 3. 记录学习更新 - if self.current_session: - self.current_session.style_updates += 1 + if group_id in self._group_sessions: + self._group_sessions[group_id].style_updates += 1 except Exception as e: logger.error(f"应用学习更新失败 for group {group_id}: {e}") @@ -1126,18 +1058,17 @@ async def get_learning_status(self, group_id: str = None) -> Dict[str, Any]: return { 'learning_active': self.learning_active.get(group_id, False), 'group_id': group_id, - 'current_session': self.current_session.__dict__ if self.current_session else None, + 'current_session': self._group_sessions[group_id].__dict__ if group_id in self._group_sessions else None, 'total_sessions': len(self.learning_sessions), 'statistics': await self.message_collector.get_statistics(), 'quality_report': await self.quality_monitor.get_quality_report(), 'last_update': datetime.now().isoformat() } else: - # 获取所有群组的状态 return { 'learning_active_groups': {gid: active for gid, active in self.learning_active.items()}, 'active_groups_count': sum(1 for active in self.learning_active.values() if active), - 
'current_session': self.current_session.__dict__ if self.current_session else None, + 'group_sessions': {gid: s.__dict__ for gid, s in self._group_sessions.items()}, 'total_sessions': len(self.learning_sessions), 'statistics': await self.message_collector.get_statistics(), 'quality_report': await self.quality_monitor.get_quality_report(), @@ -1182,129 +1113,13 @@ async def get_learning_insights(self) -> Dict[str, Any]: async def stop(self): """停止服务""" try: - await self.stop_learning() # 停止所有群组的学习 + await self.stop_learning() # 停止所有群组的学习 logger.info("渐进式学习服务已停止") return True except Exception as e: logger.error(f"停止渐进式学习服务失败: {e}") return False - async def _create_persona_review_for_low_quality(self, group_id: str, current_persona: str, - updated_persona: str, quality_metrics, filtered_messages): - """为质量不达标的学习结果创建审查记录""" - try: - from ..core.interfaces import PersonaUpdateRecord - import time - - # 将字典类型的人格数据转换为字符串 - if isinstance(current_persona, dict): - current_persona_str = json.dumps(current_persona, ensure_ascii=False, indent=2) - else: - current_persona_str = str(current_persona) if current_persona else "" - - if isinstance(updated_persona, dict): - updated_persona_str = json.dumps(updated_persona, ensure_ascii=False, indent=2) - else: - updated_persona_str = str(updated_persona) if updated_persona else "" - - # 计算变化内容摘要 - current_length = len(current_persona_str) - updated_length = len(updated_persona_str) - - # 构建详细的审查说明 - reason = f"""学习质量评估结果 (得分: {quality_metrics.consistency_score:.3f} < 阈值: {self.quality_threshold}) - -质量分析详情: -- 一致性得分: {quality_metrics.consistency_score:.3f} -- 处理消息数: {len(filtered_messages)} -- 原人格长度: {current_length} 字符 -- 新人格长度: {updated_length} 字符 - -系统建议: 由于学习质量不达标,建议手动审查内容质量后决定是否应用。 -可能的问题包括:内容冗余、逻辑不连贯、与现有人格风格差异过大等。 - -请仔细检查新人格内容是否合理,决定是否应用此次学习结果。""" - - # 保存完整内容,不进行截断(移除之前的500字符限制) - original_content_full = current_persona_str - new_content_full = updated_persona_str - - # 创建审查记录 - review_record = PersonaUpdateRecord( - 
timestamp=time.time(), - group_id=group_id, - update_type="persona_learning_review", - original_content=original_content_full, - new_content=new_content_full, - reason=reason, - confidence_score=quality_metrics.consistency_score, # 使用实际的质量得分 - status='pending' - ) - - # 直接保存到数据库 - 不依赖persona_updater - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 确保审查表存在 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp REAL NOT NULL, - group_id TEXT NOT NULL, - update_type TEXT NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score REAL, - reason TEXT, - status TEXT NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time REAL - ) - ''') - - # 为旧表添加缺失的列(如果不存在) - try: - await cursor.execute('ALTER TABLE persona_update_reviews ADD COLUMN proposed_content TEXT') - except: - pass # 列已存在 - try: - await cursor.execute('ALTER TABLE persona_update_reviews ADD COLUMN confidence_score REAL') - except: - pass # 列已存在 - - # 插入审查记录 - await cursor.execute(''' - INSERT INTO persona_update_reviews - (timestamp, group_id, update_type, original_content, new_content, proposed_content, confidence_score, reason, status) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - review_record.timestamp, - review_record.group_id, - review_record.update_type, - review_record.original_content, - review_record.new_content, - review_record.new_content, # proposed_content使用相同内容 - review_record.confidence_score, - review_record.reason, - review_record.status - )) - - await conn.commit() - record_id = cursor.lastrowid - await cursor.close() - logger.info(f"质量不达标的人格学习审查记录已创建,ID: {record_id}") - return True - - except Exception as db_error: - logger.error(f"保存审查记录到数据库失败: {db_error}") - return False - - except Exception as e: - logger.error(f"创建质量不达标审查记录失败: {e}") - return False - async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[str, Any], messages: List[Dict[str, Any]], quality_metrics=None): """ @@ -1317,7 +1132,7 @@ async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[ quality_metrics: 质量指标 """ try: - # ✅ 处理 AnalysisResult 对象,提取其 data 属性 + # 处理 AnalysisResult 对象,提取其 data 属性 if style_analysis and hasattr(style_analysis, 'data'): style_analysis_dict = style_analysis.data elif isinstance(style_analysis, dict): @@ -1325,7 +1140,7 @@ async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[ else: style_analysis_dict = {} - # ✅ 即使没有 style_analysis,也应该基于消息创建学习记录 + # 即使没有 style_analysis,也应该基于消息创建学习记录 if not style_analysis_dict and not messages: logger.debug(f"群组 {group_id} 没有风格分析结果且没有消息,跳过风格学习记录保存") return @@ -1343,13 +1158,13 @@ async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[ # 如果没有 enhanced_prompt,从 expression_patterns 构建 few_shots_content = self._build_few_shots_from_patterns(expression_patterns) - # ✅ 如果没有 few_shots_content,从消息中构建简单的学习内容 + # 如果没有 few_shots_content,从消息中构建简单的学习内容 if not few_shots_content and messages: few_shots_content = f"基于 {len(messages)} 条对话消息的风格学习" # 3. 
构建学习模式列表 learned_patterns = [] - for pattern in expression_patterns[:10]: # 取前10个模式 + for pattern in expression_patterns[:10]: # 取前10个模式 learned_patterns.append({ 'situation': pattern.get('situation', ''), 'expression': pattern.get('expression', ''), @@ -1368,7 +1183,7 @@ async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[ # 6. 保存风格学习记录(使用 ORM) try: async with self.db_manager.get_session() as session: - from ..models.orm.learning import StyleLearningReview + from ...models.orm.learning import StyleLearningReview from datetime import datetime current_timestamp = time.time() @@ -1379,19 +1194,19 @@ async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[ timestamp=current_timestamp, learned_patterns=json.dumps(learned_patterns, ensure_ascii=False), few_shots_content=few_shots_content, - status='approved', # 直接批准,不需要审查 + status='approved', # 直接批准,不需要审查 description=description, reviewer_comment='自动批准', review_time=current_timestamp, - created_at=datetime.fromtimestamp(current_timestamp), # ✅ 转换为datetime对象 - updated_at=datetime.fromtimestamp(current_timestamp) # ✅ 转换为datetime对象 + created_at=datetime.fromtimestamp(current_timestamp), # 转换为datetime对象 + updated_at=datetime.fromtimestamp(current_timestamp) # 转换为datetime对象 ) session.add(review) await session.commit() await session.refresh(review) - logger.info(f"✅ 对话风格学习记录已保存 (ID: {review.id}),处理 {message_count} 条消息,提取 {pattern_count} 个模式") + logger.info(f" 对话风格学习记录已保存 (ID: {review.id}),处理 {message_count} 条消息,提取 {pattern_count} 个模式") except Exception as e: logger.error(f"保存对话风格学习记录失败: {e}", exc_info=True) @@ -1403,7 +1218,7 @@ def _build_few_shots_from_patterns(self, patterns: List[Dict[str, Any]]) -> str: """从表达模式构建 few-shots 内容""" few_shots = "*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" - for i, pattern in enumerate(patterns[:5], 1): # 只取前5个 + for i, pattern in enumerate(patterns[:5], 1): # 只取前5个 situation = 
pattern.get('situation', '') expression = pattern.get('expression', '') if situation and expression: @@ -1425,7 +1240,7 @@ async def _save_expression_patterns(self, group_id: str, patterns: List[Dict[str # 使用 ORM 保存表达模式 async with self.db_manager.get_session() as session: - from ..models.orm.expression import ExpressionPattern + from ...models.orm.expression import ExpressionPattern import time current_time = time.time() @@ -1443,14 +1258,14 @@ async def _save_expression_patterns(self, group_id: str, patterns: List[Dict[str situation=situation, expression=expression, weight=float(pattern.get('weight', 1.0)), - last_active_time=current_time, # ✅ 使用last_active_time而不是confidence + last_active_time=current_time, # 使用last_active_time而不是confidence create_time=current_time ) session.add(expr_pattern) await session.commit() - logger.info(f"✅ 已保存 {len(patterns)} 个表达模式到数据库 (群组: {group_id})") + logger.info(f" 已保存 {len(patterns)} 个表达模式到数据库 (群组: {group_id})") except Exception as e: logger.error(f"保存表达模式失败: {e}", exc_info=True) diff --git a/services/core_learning/v2_learning_integration.py b/services/core_learning/v2_learning_integration.py new file mode 100644 index 0000000..51c9f9d --- /dev/null +++ b/services/core_learning/v2_learning_integration.py @@ -0,0 +1,552 @@ +""" +V2 learning integration layer. + +Wires together the v2-architecture modules and provides a unified +interface for the ``MaiBotEnhancedLearningManager`` to delegate to. +When v2 features are enabled in ``PluginConfig`` the learning manager +instantiates this class and calls its ``process_message`` and +``get_enhanced_context`` methods alongside (or instead of) the legacy +code paths. 
+ +Modules orchestrated: + * ``TieredLearningTrigger`` — per-message / batch operation scheduling + * ``LightRAGKnowledgeManager`` — knowledge graph (replaces legacy) + * ``Mem0MemoryManager`` — memory management (replaces legacy) + * ``ExemplarLibrary`` — few-shot style exemplar retrieval + * ``SocialGraphAnalyzer`` — community detection / influence ranking + * ``JargonStatisticalFilter`` — statistical jargon pre-filter + * ``IRerankProvider`` — cross-source context reranking + +Design notes: + - All module construction is guarded by the relevant config flags so + that unused modules are never instantiated. + - ``start()`` / ``stop()`` manage the full lifecycle of every active + v2 module. + - Each module that can fail during construction logs a warning and + falls back gracefully (the integration layer keeps working with + the remaining modules). + - Thread-safe for single-event-loop asyncio usage. +""" + +from typing import Any, Dict, List, Optional, Tuple + +from astrbot.api import logger + +from ...config import PluginConfig +from ...core.interfaces import MessageData +from ..quality import ( + BatchTriggerPolicy, + TieredLearningTrigger, + TriggerResult, +) + + +class V2LearningIntegration: + """Facade that initialises, wires, and exposes v2 learning modules. 
+ + Usage:: + + v2 = V2LearningIntegration(config, llm_adapter, db_manager, context) + await v2.start() + result = await v2.process_message(message, group_id) + context = await v2.get_enhanced_context("query", group_id) + await v2.stop() + """ + + def __init__( + self, + config: PluginConfig, + llm_adapter: Optional[Any] = None, + db_manager: Optional[Any] = None, + context: Optional[Any] = None, + ) -> None: + self._config = config + self._llm = llm_adapter + self._db = db_manager + self._context = context + + # --- Resolve framework providers via factories --------------- + self._embedding_provider = self._create_embedding_provider() + self._rerank_provider = self._create_rerank_provider() + + # --- Instantiate v2 modules ---------------------------------- + self._knowledge_manager = self._create_knowledge_manager() + self._memory_manager = self._create_memory_manager() + self._exemplar_library = self._create_exemplar_library() + self._social_analyzer = self._create_social_analyzer() + self._jargon_filter = self._create_jargon_filter() + + # --- Tiered trigger ------------------------------------------ + self._trigger = TieredLearningTrigger() + self._register_trigger_operations() + + logger.info( + "[V2Integration] Initialised — " + f"knowledge={self._config.knowledge_engine}, " + f"memory={self._config.memory_engine}, " + f"embedding={'yes' if self._embedding_provider else 'no'}, " + f"reranker={'yes' if self._rerank_provider else 'no'}" + ) + + # Lifecycle + + async def start(self) -> None: + """Start all active v2 modules that expose a ``start`` method.""" + modules: List[Tuple[str, Any]] = [ + ("knowledge_manager", self._knowledge_manager), + ("memory_manager", self._memory_manager), + ("exemplar_library", self._exemplar_library), + ("social_analyzer", self._social_analyzer), + ("jargon_filter", self._jargon_filter), + ] + for name, module in modules: + if module and hasattr(module, "start"): + try: + await module.start() + except Exception as exc: + 
logger.warning( + f"[V2Integration] {name} start failed: {exc}" + ) + logger.info("[V2Integration] All modules started") + + async def stop(self) -> None: + """Stop all active v2 modules and release resources.""" + modules: List[Tuple[str, Any]] = [ + ("knowledge_manager", self._knowledge_manager), + ("memory_manager", self._memory_manager), + ("exemplar_library", self._exemplar_library), + ("social_analyzer", self._social_analyzer), + ("jargon_filter", self._jargon_filter), + ] + for name, module in modules: + if module and hasattr(module, "stop"): + try: + await module.stop() + except Exception as exc: + logger.warning( + f"[V2Integration] {name} stop failed: {exc}" + ) + + if self._rerank_provider and hasattr(self._rerank_provider, "close"): + try: + await self._rerank_provider.close() + except Exception as exc: + logger.warning(f"[V2Integration] Reranker close failed: {exc}") + + logger.info("[V2Integration] All modules stopped") + + # Public API + + async def process_message( + self, message: MessageData, group_id: str + ) -> TriggerResult: + """Process an incoming message through the tiered trigger. + + Tier 1 operations run concurrently on every message. Tier 2 + operations fire when their policies are satisfied. + """ + return await self._trigger.process_message(message, group_id) + + async def get_enhanced_context( + self, + query: str, + group_id: str, + top_k: int = 5, + ) -> Dict[str, Any]: + """Retrieve v2 enhanced context for response generation. + + Returns a dict with optional keys: + * ``knowledge_context`` (str): Retrieved knowledge graph context. + * ``related_memories`` (List[str]): Semantically related memories. + * ``few_shot_examples`` (List[str]): Style exemplar texts + (not reranked; returned as-is). + * ``graph_stats`` (dict): Social graph summary statistics. + + When a reranker is available, knowledge and memory candidates are + reranked by relevance and only the top-k are returned. 
Few-shot + exemplars and graph stats are returned unmodified. + + All retrieval tasks run concurrently via ``asyncio.gather`` to + minimise total latency. + """ + import asyncio + + context: Dict[str, Any] = {} + + # --- Build concurrent retrieval tasks --- + + async def _fetch_knowledge() -> None: + if not self._knowledge_manager: + return + try: + if hasattr(self._knowledge_manager, "query_knowledge"): + ctx = await self._knowledge_manager.query_knowledge( + query, group_id + ) + elif hasattr( + self._knowledge_manager, + "answer_question_with_knowledge_graph", + ): + ctx = ( + await self._knowledge_manager + .answer_question_with_knowledge_graph(query, group_id) + ) + else: + ctx = "" + if ctx: + context["knowledge_context"] = ctx + except Exception as exc: + logger.debug( + f"[V2Integration] Knowledge retrieval failed: {exc}" + ) + + async def _fetch_memories() -> None: + if not self._memory_manager: + return + try: + memories = await self._memory_manager.get_related_memories( + query, group_id + ) + if memories: + context["related_memories"] = memories + except Exception as exc: + logger.debug( + f"[V2Integration] Memory retrieval failed: {exc}" + ) + + async def _fetch_exemplars() -> None: + if not self._exemplar_library: + return + try: + examples = await self._exemplar_library.get_few_shot_examples( + query, group_id, k=top_k + ) + if examples: + context["few_shot_examples"] = examples + except Exception as exc: + logger.debug( + f"[V2Integration] Exemplar retrieval failed: {exc}" + ) + + async def _fetch_graph_stats() -> None: + if not self._social_analyzer: + return + try: + stats = await self._social_analyzer.get_graph_statistics( + group_id + ) + if stats and stats.get("node_count", 0) > 0: + context["graph_stats"] = stats + except Exception as exc: + logger.debug( + f"[V2Integration] Social graph stats failed: {exc}" + ) + + # --- Run all retrievals concurrently --- + await asyncio.gather( + _fetch_knowledge(), + _fetch_memories(), + 
_fetch_exemplars(), + _fetch_graph_stats(), + ) + + # --- Reranking (optional, knowledge + memory only) --- + if self._rerank_provider and context: + context = await self._rerank_context(query, context, top_k) + + return context + + def get_trigger_stats(self, group_id: str) -> Dict[str, Any]: + """Return tiered trigger statistics for a group.""" + return self._trigger.get_group_stats(group_id) + + # Module factories + + def _create_embedding_provider(self) -> Optional[Any]: + """Resolve embedding provider from the framework.""" + try: + from ..embedding.factory import EmbeddingProviderFactory + return EmbeddingProviderFactory.create(self._config, self._context) + except Exception as exc: + logger.debug( + f"[V2Integration] Embedding provider unavailable: {exc}" + ) + return None + + def _create_rerank_provider(self) -> Optional[Any]: + """Resolve reranker provider from the framework.""" + try: + from ..reranker.factory import RerankProviderFactory + return RerankProviderFactory.create(self._config, self._context) + except Exception as exc: + logger.debug(f"[V2Integration] Reranker unavailable: {exc}") + return None + + def _create_knowledge_manager(self) -> Optional[Any]: + """Create knowledge manager based on configured engine.""" + if self._config.knowledge_engine == "lightrag": + try: + from ..integration import LightRAGKnowledgeManager + return LightRAGKnowledgeManager( + self._config, self._llm, self._embedding_provider + ) + except ImportError: + logger.warning( + "[V2Integration] lightrag-hku not installed, " + "falling back to legacy knowledge engine" + ) + except Exception as exc: + logger.warning( + f"[V2Integration] LightRAG init failed: {exc}" + ) + logger.debug( + "[V2Integration] LightRAG traceback:", exc_info=True + ) + return None + + def _create_memory_manager(self) -> Optional[Any]: + """Create memory manager based on configured engine.""" + if self._config.memory_engine == "mem0": + try: + from ..integration import Mem0MemoryManager + return 
Mem0MemoryManager( + self._config, self._llm, self._embedding_provider + ) + except ImportError: + logger.warning( + "[V2Integration] mem0ai not installed, " + "falling back to legacy memory engine" + ) + except Exception as exc: + logger.warning( + f"[V2Integration] Mem0 init failed: {exc}" + ) + logger.debug( + "[V2Integration] Mem0 traceback:", exc_info=True + ) + return None + + def _create_exemplar_library(self) -> Optional[Any]: + """Create exemplar library if DB and embedding are available.""" + if not self._db: + return None + try: + from ..integration import ExemplarLibrary + return ExemplarLibrary(self._db, self._embedding_provider) + except Exception as exc: + logger.debug( + f"[V2Integration] ExemplarLibrary init failed: {exc}" + ) + return None + + def _create_social_analyzer(self) -> Optional[Any]: + """Create social graph analyzer.""" + try: + from ..social import SocialGraphAnalyzer + return SocialGraphAnalyzer(self._llm, self._db) + except Exception as exc: + logger.debug( + f"[V2Integration] SocialGraphAnalyzer init failed: {exc}" + ) + return None + + def _create_jargon_filter(self) -> Optional[Any]: + """Create jargon statistical filter.""" + try: + from ..jargon import JargonStatisticalFilter + return JargonStatisticalFilter() + except Exception as exc: + logger.debug( + f"[V2Integration] JargonStatisticalFilter init failed: {exc}" + ) + return None + + # Trigger wiring + + def _register_trigger_operations(self) -> None: + """Register all available modules with the tiered trigger.""" + + # ---- Tier 1: per-message lightweight operations ---- + + if self._jargon_filter: + jf = self._jargon_filter + + async def _jargon_update( + message: MessageData, group_id: str + ) -> None: + jf.update_from_message(message, group_id) + + self._trigger.register_tier1("jargon_stats", _jargon_update) + + if self._memory_manager: + self._trigger.register_tier1( + "memory", self._memory_manager.add_memory_from_message + ) + + if self._knowledge_manager: + # Resolve 
the correct ingestion method name. + if hasattr( + self._knowledge_manager, + "process_message_for_knowledge_graph", + ): + method_name = "process_message_for_knowledge_graph" + elif hasattr( + self._knowledge_manager, "process_message_for_knowledge" + ): + method_name = "process_message_for_knowledge" + else: + method_name = None + logger.warning( + "[V2Integration] Knowledge manager has no recognised " + "ingestion method; knowledge tier-1 op skipped" + ) + + if method_name: + self._trigger.register_tier1( + "knowledge", + getattr(self._knowledge_manager, method_name), + ) + + if self._exemplar_library: + lib = self._exemplar_library + + async def _exemplar_add( + message: MessageData, group_id: str + ) -> None: + await lib.add_exemplar( + message.message, group_id, message.sender_id + ) + + self._trigger.register_tier1("exemplar", _exemplar_add) + + # ---- Tier 2: batch operations (LLM-heavy) ---- + + if self._jargon_filter: + jf2 = self._jargon_filter + llm = self._llm + db = self._db + + async def _jargon_batch(group_id: str) -> None: + candidates = jf2.get_jargon_candidates(group_id, top_k=20) + if not candidates or not llm: + return + for candidate in candidates[:10]: + try: + meaning = await llm.generate_response( + f"Explain the slang/jargon term " + f"'{candidate['term']}' in the context of an " + f"online chat group. 
Return a concise definition.", + model_type="filter", + ) + if ( + meaning + and db + and hasattr(db, "save_or_update_jargon") + ): + await db.save_or_update_jargon( + candidate["term"], meaning, group_id + ) + except Exception as exc: + logger.debug( + f"[V2Integration] Jargon inference failed " + f"for '{candidate['term']}': {exc}" + ) + + self._trigger.register_tier2( + "jargon", + _jargon_batch, + BatchTriggerPolicy( + message_threshold=20, cooldown_seconds=180 + ), + ) + + if self._social_analyzer: + sa = self._social_analyzer + + async def _social_batch(group_id: str) -> None: + # Execute independently so one failure does not skip the other. + try: + await sa.detect_communities(group_id) + except Exception as exc: + logger.debug( + f"[V2Integration] detect_communities failed: {exc}" + ) + try: + await sa.get_influence_ranking(group_id) + except Exception as exc: + logger.debug( + f"[V2Integration] get_influence_ranking failed: {exc}" + ) + + self._trigger.register_tier2( + "social", + _social_batch, + BatchTriggerPolicy( + message_threshold=50, cooldown_seconds=600 + ), + ) + + # Reranking + + async def _rerank_context( + self, + query: str, + context: Dict[str, Any], + top_k: int, + ) -> Dict[str, Any]: + """Rerank knowledge and memory candidates by relevance. + + Few-shot exemplars and graph stats are returned unmodified. + """ + try: + documents: List[str] = [] + sources: List[str] = [] + + if "knowledge_context" in context: + documents.append(context["knowledge_context"]) + sources.append("knowledge") + + for mem in context.get("related_memories", []): + documents.append(mem) + sources.append("memory") + + if not documents: + return context + + results = await self._rerank_provider.rerank( + query, documents, top_n=top_k + ) + + # Rebuild context with reranked order. 
+ reranked_memories: List[str] = [] + reranked_knowledge = "" + for r in results: + if r.index >= len(documents): + logger.debug( + f"[V2Integration] Reranker returned out-of-range " + f"index {r.index} (len={len(documents)}); skipping" + ) + continue + src = sources[r.index] + doc = documents[r.index] + if src == "knowledge": + reranked_knowledge = doc + elif src == "memory": + reranked_memories.append(doc) + + if reranked_knowledge: + context["knowledge_context"] = reranked_knowledge + elif "knowledge_context" in context: + del context["knowledge_context"] + + if reranked_memories: + context["related_memories"] = reranked_memories + elif "related_memories" in context: + del context["related_memories"] + + except Exception as exc: + logger.debug( + f"[V2Integration] Reranking failed, using unranked: {exc}" + ) + + return context diff --git a/services/data_export_formatter.py b/services/data_export_formatter.py deleted file mode 100644 index d8e24d6..0000000 --- a/services/data_export_formatter.py +++ /dev/null @@ -1,705 +0,0 @@ -""" -数据导出格式化服务 -用于将插件内部数据转换为标准JSON格式,供外部系统(如liyn-web)使用 - -设计原则: -1. 通用性:支持多种数据类型的导出(情绪、好感度、学习数据等) -2. 扩展性:便于未来添加新的数据类型 -3. 统一格式:所有导出数据遵循统一的响应结构 -4. 
安全性:数据过滤和权限控制 -""" -import time -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass, asdict -from datetime import datetime -from enum import Enum - -from astrbot.api import logger - -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage -from .affection_manager import AffectionManager, MoodType, InteractionType - - -class DataExportType(Enum): - """数据导出类型枚举""" - EMOTION = "emotion" # 情绪数据 - AFFECTION = "affection" # 好感度数据 - LEARNING_STATS = "learning_stats" # 学习统计数据 - STYLE_PATTERNS = "style_patterns" # 风格模式数据 - SOCIAL_RELATIONS = "social_relations" # 社交关系数据 - MESSAGE_STATS = "message_stats" # 消息统计数据 - COMPREHENSIVE = "comprehensive" # 综合数据(包含所有) - - -@dataclass -class EmotionData: - """情绪数据结构""" - group_id: str - mood_type: str # happy, sad, excited, calm, angry, anxious, playful, serious, nostalgic, curious - mood_intensity: float # 0.0 - 1.0 - mood_description: str - start_time: float - end_time: Optional[float] - is_active: bool - created_at: str - - -@dataclass -class UserAffectionData: - """用户好感度数据结构""" - user_id: str - group_id: str - affection_level: int # 0-100 - last_interaction: float - interaction_count: int - last_updated: float - created_at: str - - -@dataclass -class GroupAffectionSummary: - """群组好感度汇总数据""" - group_id: str - total_affection: int - max_total_affection: int # 250 - user_count: int - avg_affection: float - top_users: List[Dict[str, Any]] # 前5名用户 - last_updated: float - - -@dataclass -class StandardResponse: - """标准响应数据结构 - 所有导出数据都遵循此格式""" - success: bool - timestamp: float - data_type: str # 数据类型:emotion, affection, learning_stats等 - group_id: Optional[str] # 群组ID(如果适用) - user_id: Optional[str] # 用户ID(如果适用) - data: Optional[Dict[str, Any]] # 实际数据内容 - metadata: Optional[Dict[str, Any]] # 元数据(统计信息等) - message: Optional[str] - error: Optional[str] - - -class DataExportFormatter(AsyncServiceBase): - """通用数据导出格式化服务 - - 职责: - 1. 
统一数据导出接口 - 2. 支持多种数据类型的格式化 - 3. 提供数据过滤和权限控制 - 4. 便于未来扩展新的数据类型 - """ - - def __init__( - self, - config: PluginConfig, - database_manager: IDataStorage, - affection_manager: Optional[AffectionManager] = None - ): - super().__init__("data_export_formatter") - self.config = config - self.db_manager = database_manager - self.affection_manager = affection_manager - - # 数据导出处理器注册表(使用策略模式) - self._exporters: Dict[DataExportType, Callable] = {} - - async def _do_start(self) -> bool: - """启动服务并注册数据导出处理器""" - # 注册内置数据导出处理器 - self._register_builtin_exporters() - - self._logger.info("通用数据导出格式化服务启动成功") - return True - - async def _do_stop(self) -> bool: - """停止服务""" - return True - - def _register_builtin_exporters(self): - """注册内置的数据导出处理器""" - self._exporters[DataExportType.EMOTION] = self._export_emotion_data - self._exporters[DataExportType.AFFECTION] = self._export_affection_data - self._exporters[DataExportType.LEARNING_STATS] = self._export_learning_stats - self._exporters[DataExportType.STYLE_PATTERNS] = self._export_style_patterns - self._exporters[DataExportType.SOCIAL_RELATIONS] = self._export_social_relations - self._exporters[DataExportType.MESSAGE_STATS] = self._export_message_stats - self._exporters[DataExportType.COMPREHENSIVE] = self._export_comprehensive_data - - def register_custom_exporter( - self, - export_type: str, - exporter_func: Callable - ): - """ - 注册自定义数据导出处理器(用于扩展) - - Args: - export_type: 自定义的导出类型名称 - exporter_func: 导出处理函数,签名应为 async def func(group_id, **kwargs) -> Dict - """ - try: - # 创建动态枚举值(如果不存在) - custom_type = f"custom_{export_type}" - self._exporters[custom_type] = exporter_func - self._logger.info(f"注册自定义导出处理器: {export_type}") - except Exception as e: - self._logger.error(f"注册自定义导出处理器失败: {e}") - - async def export_data( - self, - data_type: str, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> StandardResponse: - """ - 通用数据导出接口 - - Args: - data_type: 数据类型(emotion, affection, learning_stats等) - group_id: 
群组ID(可选) - user_id: 用户ID(可选) - **kwargs: 其他参数,传递给具体的导出处理器 - - Returns: - StandardResponse: 标准响应格式的数据 - """ - try: - # 查找对应的导出处理器 - exporter = None - - # 尝试匹配枚举类型 - for export_enum in DataExportType: - if export_enum.value == data_type: - exporter = self._exporters.get(export_enum) - break - - # 尝试匹配自定义类型 - if not exporter: - custom_key = f"custom_{data_type}" - exporter = self._exporters.get(custom_key) - - if not exporter: - return StandardResponse( - success=False, - timestamp=time.time(), - data_type=data_type, - group_id=group_id, - user_id=user_id, - data=None, - metadata=None, - message=None, - error=f"不支持的数据类型: {data_type}" - ) - - # 调用导出处理器 - result_data = await exporter( - group_id=group_id, - user_id=user_id, - **kwargs - ) - - return StandardResponse( - success=True, - timestamp=time.time(), - data_type=data_type, - group_id=group_id, - user_id=user_id, - data=result_data.get('data'), - metadata=result_data.get('metadata'), - message="数据导出成功", - error=None - ) - - except Exception as e: - self._logger.error(f"导出数据失败 (type={data_type}, group={group_id}): {e}", exc_info=True) - return StandardResponse( - success=False, - timestamp=time.time(), - data_type=data_type, - group_id=group_id, - user_id=user_id, - data=None, - metadata=None, - message=None, - error=f"数据导出失败: {str(e)}" - ) - - # ==================== 内置导出处理器 ==================== - - def _format_timestamp(self, timestamp: float) -> str: - """格式化时间戳为ISO 8601格式""" - return datetime.fromtimestamp(timestamp).isoformat() - - async def _export_emotion_data( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出情绪数据""" - if not self.affection_manager: - return {"data": None, "metadata": {"error": "好感度管理器未初始化"}} - - if not group_id: - return {"data": None, "metadata": {"error": "需要提供群组ID"}} - - emotion_data = await self.get_current_emotion(group_id) - - return { - "data": asdict(emotion_data) if emotion_data else None, - "metadata": { - 
"has_active_emotion": emotion_data is not None if emotion_data else False - } - } - - async def _export_affection_data( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出好感度数据""" - if not group_id: - return {"data": None, "metadata": {"error": "需要提供群组ID"}} - - limit = kwargs.get('limit', 100) - - # 如果指定了用户ID,只返回该用户的数据 - if user_id: - user_affection = await self.db_manager.get_user_affection(group_id, user_id) - return { - "data": { - "user_affection": user_affection, - "interaction_history": await self._get_user_interaction_history(group_id, user_id, limit=10) - }, - "metadata": {"query_type": "single_user"} - } - - # 否则返回所有用户的数据 - affection_list = await self.get_user_affections(group_id, limit) - group_summary = await self.get_group_affection_summary(group_id) - - return { - "data": { - "user_affections": [asdict(a) for a in affection_list], - "group_summary": asdict(group_summary) if group_summary else None - }, - "metadata": { - "total_users": len(affection_list), - "query_type": "group_level" - } - } - - async def _export_learning_stats( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出学习统计数据""" - try: - stats = {} - - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 获取消息统计 - if group_id: - await cursor.execute(''' - SELECT COUNT(*) as total, - COUNT(DISTINCT sender_id) as unique_users - FROM raw_messages - WHERE group_id = ? - ''', (group_id,)) - else: - await cursor.execute(''' - SELECT COUNT(*) as total, - COUNT(DISTINCT sender_id) as unique_users - FROM raw_messages - ''') - - row = await cursor.fetchone() - stats['total_messages'] = row[0] if row else 0 - stats['unique_users'] = row[1] if row else 0 - - # 获取学习会话统计 - await cursor.execute(''' - SELECT COUNT(*) as session_count - FROM learning_sessions - ''' + (' WHERE group_id = ?' 
if group_id else ''), (group_id,) if group_id else ()) - - row = await cursor.fetchone() - stats['learning_sessions'] = row[0] if row else 0 - - await cursor.close() - - return {"data": stats, "metadata": {"data_source": "database"}} - - except Exception as e: - self._logger.error(f"导出学习统计失败: {e}") - return {"data": None, "metadata": {"error": str(e)}} - - async def _export_style_patterns( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出风格模式数据""" - # 这里可以根据实际需求扩展 - return {"data": {"message": "风格模式导出功能待实现"}, "metadata": {}} - - async def _export_social_relations( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出社交关系数据""" - # 这里可以根据实际需求扩展 - return {"data": {"message": "社交关系导出功能待实现"}, "metadata": {}} - - async def _export_message_stats( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出消息统计数据""" - # 这里可以根据实际需求扩展 - return {"data": {"message": "消息统计导出功能待实现"}, "metadata": {}} - - async def _export_comprehensive_data( - self, - group_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ) -> Dict[str, Any]: - """导出综合数据(包含所有类型)""" - comprehensive = { - "emotion": await self._export_emotion_data(group_id, user_id, **kwargs), - "affection": await self._export_affection_data(group_id, user_id, **kwargs), - "learning_stats": await self._export_learning_stats(group_id, user_id, **kwargs) - } - - return { - "data": comprehensive, - "metadata": { - "included_types": ["emotion", "affection", "learning_stats"], - "comprehensive_export": True - } - } - - # ==================== 辅助方法(保持原有实现)==================== - """获取当前群组的情绪状态""" - try: - current_mood = await self.affection_manager.get_current_mood(group_id) - - if not current_mood or not current_mood.is_active(): - self._logger.debug(f"群组 {group_id} 没有活跃的情绪状态") - return None - - return EmotionData( - 
group_id=group_id, - mood_type=current_mood.mood_type.value, - mood_intensity=current_mood.intensity, - mood_description=current_mood.description, - start_time=current_mood.start_time, - end_time=current_mood.start_time + current_mood.duration_hours * 3600, - is_active=current_mood.is_active(), - created_at=self._format_timestamp(current_mood.start_time) - ) - - except Exception as e: - self._logger.error(f"获取群组 {group_id} 情绪状态失败: {e}") - return None - - async def get_user_affections(self, group_id: str, limit: int = 100) -> List[UserAffectionData]: - """获取群组内用户好感度列表""" - try: - affection_list = [] - - # 从数据库获取用户好感度数据 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT - user_id, - group_id, - affection_level, - last_interaction, - interaction_count, - last_updated, - created_at - FROM user_affection - WHERE group_id = ? - ORDER BY affection_level DESC - LIMIT ? - ''', (group_id, limit)) - - rows = await cursor.fetchall() - - for row in rows: - affection_list.append(UserAffectionData( - user_id=row[0], - group_id=row[1], - affection_level=row[2], - last_interaction=row[3], - interaction_count=row[4], - last_updated=row[5], - created_at=row[6] - )) - - await cursor.close() - - return affection_list - - except Exception as e: - self._logger.error(f"获取群组 {group_id} 用户好感度列表失败: {e}") - return [] - - async def get_group_affection_summary(self, group_id: str) -> Optional[GroupAffectionSummary]: - """获取群组好感度汇总信息""" - try: - # 使用affection_manager获取汇总数据 - affection_status = await self.affection_manager.get_affection_status(group_id) - - if not affection_status: - return None - - return GroupAffectionSummary( - group_id=group_id, - total_affection=affection_status['total_affection'], - max_total_affection=affection_status['max_total_affection'], - user_count=affection_status['user_count'], - avg_affection=affection_status['avg_affection'], - top_users=affection_status['top_users'][:5], # 前5名 - 
last_updated=time.time() - ) - - except Exception as e: - self._logger.error(f"获取群组 {group_id} 好感度汇总失败: {e}") - return None - - async def format_emotion_affection_data( - self, - group_id: str, - include_emotion: bool = True, - include_affection: bool = True, - include_summary: bool = True - ) -> EmotionAffectionResponse: - """ - 格式化情绪和好感度数据为标准JSON响应 - - Args: - group_id: 群组ID - include_emotion: 是否包含情绪数据 - include_affection: 是否包含用户好感度数据 - include_summary: 是否包含群组汇总数据 - - Returns: - EmotionAffectionResponse: 标准化响应数据 - """ - try: - current_emotion = None - user_affections = [] - group_summary = None - - # 获取当前情绪 - if include_emotion: - emotion_data = await self.get_current_emotion(group_id) - if emotion_data: - current_emotion = asdict(emotion_data) - - # 获取用户好感度列表 - if include_affection: - affection_list = await self.get_user_affections(group_id) - user_affections = [asdict(affection) for affection in affection_list] - - # 获取群组汇总 - if include_summary: - summary_data = await self.get_group_affection_summary(group_id) - if summary_data: - group_summary = asdict(summary_data) - - return EmotionAffectionResponse( - success=True, - timestamp=time.time(), - group_id=group_id, - current_emotion=current_emotion, - user_affections=user_affections, - group_summary=group_summary, - message="数据获取成功", - error=None - ) - - except Exception as e: - self._logger.error(f"格式化群组 {group_id} 数据失败: {e}", exc_info=True) - return EmotionAffectionResponse( - success=False, - timestamp=time.time(), - group_id=group_id, - current_emotion=None, - user_affections=[], - group_summary=None, - message=None, - error=f"数据获取失败: {str(e)}" - ) - - async def get_all_groups_emotion_affection(self) -> Dict[str, Any]: - """获取所有活跃群组的情绪和好感度数据""" - try: - # 获取所有活跃群组 - active_groups = await self._get_active_groups() - - groups_data = [] - for group_id in active_groups: - group_data = await self.format_emotion_affection_data( - group_id, - include_emotion=True, - include_affection=True, - include_summary=True - 
) - groups_data.append(asdict(group_data)) - - return { - "success": True, - "timestamp": time.time(), - "total_groups": len(groups_data), - "groups": groups_data, - "message": "所有群组数据获取成功", - "error": None - } - - except Exception as e: - self._logger.error(f"获取所有群组数据失败: {e}", exc_info=True) - return { - "success": False, - "timestamp": time.time(), - "total_groups": 0, - "groups": [], - "message": None, - "error": f"数据获取失败: {str(e)}" - } - - async def _get_active_groups(self) -> List[str]: - """获取所有活跃群组ID列表""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 获取最近7天内有消息的群组 - cutoff_time = time.time() - (86400 * 7) - await cursor.execute(''' - SELECT DISTINCT group_id - FROM raw_messages - WHERE timestamp > ? AND group_id IS NOT NULL AND group_id != '' - ORDER BY timestamp DESC - ''', (cutoff_time,)) - - rows = await cursor.fetchall() - await cursor.close() - - return [row[0] for row in rows] - - except Exception as e: - self._logger.error(f"获取活跃群组列表失败: {e}") - return [] - - async def get_user_emotion_affection( - self, - group_id: str, - user_id: str - ) -> Dict[str, Any]: - """获取指定用户在指定群组的情绪和好感度数据""" - try: - # 获取群组情绪 - emotion_data = await self.get_current_emotion(group_id) - - # 获取用户好感度 - user_affection = await self.db_manager.get_user_affection(group_id, user_id) - - # 获取用户最近的交互历史 - interaction_history = await self._get_user_interaction_history( - group_id, user_id, limit=10 - ) - - return { - "success": True, - "timestamp": time.time(), - "group_id": group_id, - "user_id": user_id, - "current_emotion": asdict(emotion_data) if emotion_data else None, - "user_affection": user_affection, - "interaction_history": interaction_history, - "message": "用户数据获取成功", - "error": None - } - - except Exception as e: - self._logger.error(f"获取用户 {user_id} 在群组 {group_id} 的数据失败: {e}") - return { - "success": False, - "timestamp": time.time(), - "group_id": group_id, - "user_id": user_id, - "current_emotion": None, - 
"user_affection": None, - "interaction_history": [], - "message": None, - "error": f"数据获取失败: {str(e)}" - } - - async def _get_user_interaction_history( - self, - group_id: str, - user_id: str, - limit: int = 10 - ) -> List[Dict[str, Any]]: - """获取用户交互历史记录""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT - change_amount, - previous_level, - new_level, - change_reason, - bot_mood, - timestamp, - created_at - FROM affection_history - WHERE group_id = ? AND user_id = ? - ORDER BY timestamp DESC - LIMIT ? - ''', (group_id, user_id, limit)) - - rows = await cursor.fetchall() - await cursor.close() - - history = [] - for row in rows: - history.append({ - "change_amount": row[0], - "previous_level": row[1], - "new_level": row[2], - "change_reason": row[3], - "bot_mood": row[4], - "timestamp": row[5], - "created_at": row[6] - }) - - return history - - except Exception as e: - self._logger.error(f"获取用户 {user_id} 交互历史失败: {e}") - return [] diff --git a/services/database/__init__.py b/services/database/__init__.py new file mode 100644 index 0000000..75c509c --- /dev/null +++ b/services/database/__init__.py @@ -0,0 +1,14 @@ +"""Database access layer -- managers and factory.""" + +from .sqlalchemy_database_manager import SQLAlchemyDatabaseManager +from .manager_factory import ManagerFactory, get_manager_factory + +# 向后兼容别名:大量服务文件以 DatabaseManager 作为类型引用 +DatabaseManager = SQLAlchemyDatabaseManager + +__all__ = [ + "SQLAlchemyDatabaseManager", + "DatabaseManager", + "ManagerFactory", + "get_manager_factory", +] diff --git a/services/database/facades/__init__.py b/services/database/facades/__init__.py new file mode 100644 index 0000000..26e9240 --- /dev/null +++ b/services/database/facades/__init__.py @@ -0,0 +1,29 @@ +"""Domain Facade modules for decoupled data access.""" + +from ._base import BaseFacade +from .affection_facade import AffectionFacade +from .admin_facade import AdminFacade +from 
class BaseFacade:
    """Base class shared by every domain facade.

    Subclasses inherit uniform session handling plus two conversion helpers.
    Facade methods are expected to hand plain ``dict`` / ``list`` data to
    their callers instead of ORM instances.
    """

    def __init__(self, engine: "DatabaseEngine", config: "PluginConfig"):
        """Store the database engine, the plugin configuration and the logger.

        Args:
            engine: Object whose ``get_session()`` yields a database session.
            config: Plugin configuration, kept for use by subclasses.
        """
        self.engine = engine
        self.config = config
        self._logger = logger

    @asynccontextmanager
    async def get_session(self):
        """Yield a database session and close it when the caller is done.

        Usage: ``async with self.get_session() as session: ...``
        """
        session = self.engine.get_session()
        try:
            async with session:
                yield session
        finally:
            # Explicit close in addition to the ``async with`` exit.
            # NOTE(review): presumably a safety net; likely redundant if the
            # engine returns a SQLAlchemy AsyncSession - confirm.
            await session.close()

    @staticmethod
    def _row_to_dict(obj: Any, fields: Optional[List[str]] = None) -> Dict[str, Any]:
        """Convert an ORM instance into a plain dictionary.

        Resolution order:
          1. ``obj`` is ``None``            -> ``{}``
          2. ``fields`` is given/non-empty  -> only those attributes
             (missing attributes map to ``None``)
          3. ``obj`` has ``to_dict()``      -> its result
          4. ``obj`` has ``__table__``      -> one entry per mapped column
          5. otherwise                      -> ``{}``

        An explicit ``fields`` list now takes precedence over ``to_dict()``;
        previously it was silently ignored whenever the model defined
        ``to_dict()``, so a caller could not request a projection.

        Args:
            obj: ORM model instance, or ``None``.
            fields: Attribute names to extract. ``None`` or empty means
                "use ``to_dict()`` or the table columns".

        Returns:
            Dictionary representation of the record.
        """
        if obj is None:
            return {}
        if fields:
            return {name: getattr(obj, name, None) for name in fields}
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        # Fallback: derive the keys from the SQLAlchemy column list.
        if hasattr(obj, '__table__'):
            return {col.name: getattr(obj, col.name, None) for col in obj.__table__.columns}
        return {}

    @staticmethod
    def _to_float_ts(
        value: Union[None, int, float, str, datetime],
        default: Optional[float] = None,
    ) -> Optional[float]:
        """Normalise assorted time representations to a float UNIX timestamp.

        Accepted inputs: ``int`` / ``float`` (passed through), ``datetime``
        (via ``timestamp()``; a naive datetime is interpreted as local time
        by the standard library), ISO-8601 strings and numeric strings.
        A trailing ``"Z"`` on an ISO string is rewritten to ``"+00:00"``
        because ``datetime.fromisoformat`` only accepts it from Python 3.11.

        Args:
            value: Raw time value.
            default: Returned when ``value`` is ``None`` or cannot be
                converted. Pass ``time.time()`` to fall back to "now".

        Returns:
            The timestamp as ``float``, or ``default``.
        """
        if value is None:
            return default
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, datetime):
            return value.timestamp()
        if isinstance(value, str):
            text = value.strip()
            iso_text = text[:-1] + "+00:00" if text.endswith("Z") else text
            try:
                return datetime.fromisoformat(iso_text).timestamp()
            except (ValueError, TypeError):
                pass
            # Not ISO formatted - maybe a stringified number such as "1700000000.5".
            try:
                return float(text)
            except (ValueError, TypeError):
                pass
        return default
async def export_messages_learning_data(
    self, group_id: str = None
) -> Dict[str, Any]:
    """Export raw and filtered messages as lists of plain dictionaries.

    Args:
        group_id: When truthy, only messages of this group are exported;
            otherwise every stored message is included.

    Returns:
        ``{'raw_messages': [...], 'filtered_messages': [...]}``. Both lists
        are empty when the export fails (the error is logged, not raised).
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select
            from ....models.orm.message import RawMessage, FilteredMessage

            raw_query = select(RawMessage)
            filtered_query = select(FilteredMessage)
            if group_id:
                raw_query = raw_query.where(RawMessage.group_id == group_id)
                filtered_query = filtered_query.where(FilteredMessage.group_id == group_id)

            raw_rows = (await session.execute(raw_query)).scalars().all()
            filtered_rows = (await session.execute(filtered_query)).scalars().all()

            # Flatten the ORM rows so no ORM object leaves the facade.
            exported_raw = []
            for row in raw_rows:
                exported_raw.append({
                    'id': row.id,
                    'sender_id': row.sender_id,
                    'message': row.message,
                    'group_id': row.group_id,
                    'timestamp': row.timestamp,
                })

            exported_filtered = []
            for row in filtered_rows:
                exported_filtered.append({
                    'id': row.id,
                    'message': row.message,
                    'group_id': row.group_id,
                    'confidence': row.confidence,
                    'timestamp': row.timestamp,
                })

            return {
                'raw_messages': exported_raw,
                'filtered_messages': exported_filtered,
            }
    except Exception as e:
        self._logger.error(f"[AdminFacade] 导出数据失败: {e}")
        return {'raw_messages': [], 'filtered_messages': []}
async def update_user_affection(
    self,
    group_id: str,
    user_id: str,
    new_level: int,
    change_reason: str = "",
    bot_mood: str = ""
) -> bool:
    """Set a user's affection level to ``new_level``.

    The repository is driven with a delta, so the current level is read
    first (a missing record counts as level 0) and the difference is
    passed to ``update_level``.

    Args:
        group_id: Group the user belongs to.
        user_id: User whose affection is updated.
        new_level: Target affection level.
        change_reason: Accepted for interface compatibility.
            NOTE(review): not used by this method - confirm whether it
            should be persisted.
        bot_mood: Accepted for interface compatibility.
            NOTE(review): not used by this method - confirm.

    Returns:
        ``True`` when the repository returned a record, ``False`` when it
        returned ``None`` or an error occurred (the error is logged).
    """
    try:
        async with self.get_session() as session:
            affection_repo = AffectionRepository(session)
            existing = await affection_repo.get_by_group_and_user(group_id, user_id)
            if existing:
                baseline = existing.affection_level
            else:
                baseline = 0
            delta = new_level - baseline
            # NOTE(review): the cap is fixed at 100 here regardless of the
            # stored record - confirm this is intended.
            updated = await affection_repo.update_level(
                group_id, user_id, delta, max_affection=100
            )
            return updated is not None
    except Exception as e:
        self._logger.error(f"[AffectionFacade] 更新好感度失败: {e}")
        return False
async def save_bot_mood(
    self,
    group_id: str,
    mood_type: str,
    mood_intensity: float,
    mood_description: str,
    duration_hours: int = 24
) -> bool:
    """Persist the bot's mood for a group, stamped with the current time.

    Args:
        group_id: Group the mood applies to.
        mood_type: Mood category identifier.
        mood_intensity: Strength of the mood.
        mood_description: Human-readable description.
        duration_hours: Accepted for interface compatibility.
            NOTE(review): never written to the saved record - confirm
            whether the mood lifetime should be stored.

    Returns:
        ``True`` when the repository returned a saved record, ``False``
        when it returned ``None`` or an error occurred (logged).
    """
    try:
        async with self.get_session() as session:
            mood_repo = BotMoodRepository(session)
            payload = {
                'group_id': group_id,
                'mood_type': mood_type,
                'mood_intensity': mood_intensity,
                'mood_description': mood_description,
                'start_time': time.time(),
            }
            saved = await mood_repo.save(payload)
            return saved is not None
    except Exception as e:
        self._logger.error(f"[AffectionFacade] 保存情绪状态失败: {e}")
        return False
async def get_all_expression_patterns(
    self, limit: int = 1000
) -> Dict[str, List[Dict[str, Any]]]:
    """Return expression patterns of every group, keyed by group id.

    Patterns without a group id are collected under the key ``'global'``.

    Args:
        limit: Maximum number of patterns fetched from the repository.
            Defaults to 1000, the value that used to be hard-coded, so
            existing callers are unaffected; callers that need more rows
            can now ask for them instead of being truncated silently.

    Returns:
        Mapping ``group_id -> list of pattern dicts``; an empty dict when
        the query fails (the error is logged, not raised).
    """
    try:
        async with self.get_session() as session:
            from ....repositories.expression_repository import ExpressionPatternRepository

            repo = ExpressionPatternRepository(session)
            all_patterns = await repo.get_all(limit=limit)

            grouped: Dict[str, List[Dict[str, Any]]] = {}
            for pattern in all_patterns:
                bucket = grouped.setdefault(pattern.group_id or 'global', [])
                bucket.append(self._row_to_dict(pattern))
            return grouped
    except Exception as e:
        self._logger.error(f"[ExpressionFacade] 获取所有表达模式失败: {e}")
        return {}
async def load_style_profile(self, profile_name: str) -> Optional[Dict[str, Any]]:
    """Load a stored style profile by name.

    Args:
        profile_name: Name of the profile to load.

    Returns:
        A dict with ``profile_name`` and the seven style metrics, or
        ``None`` when the profile does not exist or loading fails
        (failures are logged).
    """
    exported_fields = (
        'profile_name',
        'vocabulary_richness',
        'sentence_complexity',
        'emotional_expression',
        'interaction_tendency',
        'topic_diversity',
        'formality_level',
        'creativity_score',
    )
    try:
        async with self.get_session() as session:
            profile_repo = StyleProfileRepository(session)
            profile = await profile_repo.load(profile_name)
            if profile:
                return {name: getattr(profile, name) for name in exported_fields}
            return None
    except Exception as e:
        self._logger.error(f"[ExpressionFacade] 加载风格画像失败: {e}")
        return None
async def save_style_learning_record(self, record_data: Dict[str, Any]) -> bool:
    """Insert one style-learning record.

    Args:
        record_data: Source values. Missing keys fall back to
            ``'unknown'`` (style_type), ``[]`` (learned_patterns),
            ``0.0`` (confidence_score) and ``0`` (sample_count).
            ``learned_patterns`` is stored as a JSON string.

    Returns:
        ``True`` after a successful commit, ``False`` on any error
        (the error is logged).
    """
    try:
        async with self.get_session() as session:
            from ....models.orm.expression import StyleLearningRecord

            patterns_json = json.dumps(
                record_data.get('learned_patterns', []), ensure_ascii=False
            )
            column_values = {
                'style_type': record_data.get('style_type', 'unknown'),
                'learned_patterns': patterns_json,
                'confidence_score': record_data.get('confidence_score', 0.0),
                'sample_count': record_data.get('sample_count', 0),
                'last_updated': time.time(),
            }
            session.add(StyleLearningRecord(**column_values))
            await session.commit()
            return True
    except Exception as e:
        self._logger.error(f"[ExpressionFacade] 保存风格学习记录失败: {e}")
        return False
async def get_jargon(self, chat_id: str, content: str) -> Optional[Dict[str, Any]]:
    """Look up a single jargon entry by group and content.

    Args:
        chat_id: Group id the entry belongs to.
        content: Jargon text.

    Returns:
        The entry as a dict (via the model's ``to_dict()``), or ``None``
        when nothing matches or the query fails (failures are logged).
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select, and_
            from ....models.orm.jargon import Jargon

            match_condition = and_(
                Jargon.chat_id == chat_id,
                Jargon.content == content
            )
            query = select(Jargon).where(match_condition)
            found = (await session.execute(query)).scalars().first()

            if found:
                return found.to_dict()
            return None

    except Exception as e:
        self._logger.error(f"[JargonFacade] 查询黑话失败: {e}", exc_info=True)
        return None
async def insert_jargon(self, jargon_data: Dict[str, Any]) -> Optional[int]:
    """Insert a new jargon record.

    ``created_at`` / ``updated_at`` are normalised to integer UNIX
    timestamps. Numeric values are truncated to ``int``; ``datetime``
    objects and ISO-8601 / numeric strings are converted through the
    inherited ``_to_float_ts`` helper instead of being discarded; missing,
    falsy or unparsable values fall back to the current time.

    Args:
        jargon_data: Field values for the new record. Missing keys use the
            defaults visible below (empty content, ``'[]'`` raw content,
            count 1, flags ``False``, empty chat id).

    Returns:
        The id of the new record, or ``None`` when the insert fails
        (the error is logged, not raised).
    """
    try:
        async with self.get_session() as session:
            from ....models.orm.jargon import Jargon

            now_ts = int(time.time())

            def normalise(raw_value: Any) -> int:
                # One rule for both timestamp columns (previously duplicated).
                if not raw_value:
                    return now_ts
                parsed = self._to_float_ts(raw_value)
                return now_ts if parsed is None else int(parsed)

            record = Jargon(
                content=jargon_data.get('content', ''),
                raw_content=jargon_data.get('raw_content', '[]'),
                meaning=jargon_data.get('meaning'),
                is_jargon=jargon_data.get('is_jargon'),
                count=jargon_data.get('count', 1),
                last_inference_count=jargon_data.get('last_inference_count', 0),
                is_complete=jargon_data.get('is_complete', False),
                is_global=jargon_data.get('is_global', False),
                chat_id=jargon_data.get('chat_id', ''),
                created_at=normalise(jargon_data.get('created_at')),
                updated_at=normalise(jargon_data.get('updated_at')),
            )

            session.add(record)
            await session.commit()
            # Refresh so the database-generated primary key is loaded.
            await session.refresh(record)

            self._logger.info(
                f"[JargonFacade] 插入黑话成功: id={record.id}, content={record.content}"
            )
            return record.id

    except Exception as e:
        self._logger.error(f"[JargonFacade] 插入黑话失败: {e}", exc_info=True)
        return None
async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool:
    """Update the jargon row identified by ``jargon_data['id']``.

    Only the columns whose keys are present in ``jargon_data`` are
    overwritten. ``updated_at`` is always refreshed: a truthy numeric value
    supplied by the caller is truncated to int, anything else is replaced by
    the current time.

    Args:
        jargon_data: Mapping holding ``id`` plus the fields to change.

    Returns:
        True when the row was found and committed; False when ``id`` is
        missing/falsy, the row does not exist, or an error occurred.
    """
    jargon_id = jargon_data.get('id')
    if not jargon_id:
        self._logger.error("[JargonFacade] 更新黑话失败: 缺少 id")
        return False

    # Columns copied verbatim from the payload, in a fixed order.
    copyable = (
        'content',
        'raw_content',
        'meaning',
        'is_jargon',
        'count',
        'last_inference_count',
        'is_complete',
        'is_global',
    )

    try:
        async with self.get_session() as session:
            from sqlalchemy import select
            from ....models.orm.jargon import Jargon

            lookup = await session.execute(
                select(Jargon).where(Jargon.id == jargon_id)
            )
            record = lookup.scalars().first()

            if not record:
                self._logger.warning(f"[JargonFacade] 更新黑话失败: 未找到 id={jargon_id}")
                return False

            for field in copyable:
                if field in jargon_data:
                    setattr(record, field, jargon_data[field])

            # Keep updated_at as an int epoch regardless of what was passed.
            stamp = jargon_data.get('updated_at')
            if stamp and isinstance(stamp, (int, float)):
                record.updated_at = int(stamp)
            else:
                record.updated_at = int(time.time())

            await session.commit()
            self._logger.debug(f"[JargonFacade] 更新黑话成功: id={jargon_id}")
            return True

    except Exception as e:
        self._logger.error(f"[JargonFacade] 更新黑话失败: {e}", exc_info=True)
        return False
async def get_jargon_statistics(self, group_id: str = None) -> Dict[str, Any]:
    """Aggregate jargon-learning counters, globally or for one group.

    Args:
        group_id: Group to restrict the statistics to. A falsy value
            (None or empty string) means "all groups".

    Returns:
        A dict with ``total_candidates``, ``confirmed_jargon``,
        ``completed_inference``, ``total_occurrences``, ``average_count``
        (rounded to one decimal) and ``active_groups``. For a single group
        ``active_groups`` is 1 when that group has any row, else 0. The
        all-zero dict is returned when the query yields no row or fails.
    """
    default_stats = {
        'total_candidates': 0,
        'confirmed_jargon': 0,
        'completed_inference': 0,
        'total_occurrences': 0,
        'average_count': 0.0,
        'active_groups': 0,
    }
    try:
        async with self.get_session() as session:
            from sqlalchemy import select, func, case
            from ....models.orm.jargon import Jargon

            all_groups = not group_id

            columns = [
                func.count().label('total'),
                func.count(case((Jargon.is_jargon == True, 1))).label('confirmed'),
                func.count(case((Jargon.is_complete == True, 1))).label('completed'),
                func.coalesce(func.sum(Jargon.count), 0).label('total_occurrences'),
                func.coalesce(func.avg(Jargon.count), 0).label('avg_count'),
            ]
            if all_groups:
                # Distinct chat ids are only selected for the global view.
                columns.append(
                    func.count(func.distinct(Jargon.chat_id)).label('active_groups')
                )

            stmt = select(*columns)
            if not all_groups:
                stmt = stmt.where(Jargon.chat_id == group_id)

            row = (await session.execute(stmt)).fetchone()
            if not row:
                return default_stats

            total = int(row.total or 0)
            stats = {
                'total_candidates': total,
                'confirmed_jargon': int(row.confirmed or 0),
                'completed_inference': int(row.completed or 0),
                'total_occurrences': int(row.total_occurrences or 0),
                'average_count': round(float(row.avg_count or 0.0), 1),
            }

            if all_groups:
                stats['active_groups'] = int(row.active_groups or 0)
            else:
                stats['active_groups'] = int(total > 0)

            return stats

    except Exception as e:
        self._logger.error(f"[JargonFacade] 获取黑话统计失败: {e}", exc_info=True)
        return default_stats
async def get_recent_jargon_list(
    self,
    group_id: str = None,
    chat_id: str = None,
    limit: int = 10,
    offset: int = 0,
    only_confirmed: bool = None
) -> List[Dict]:
    """Return the most recently updated jargon rows, newest first.

    Args:
        group_id: Restrict to this group; None means every group.
        chat_id: Backward-compatible alias used only when ``group_id`` is
            None.
        limit: Maximum number of rows to return.
        offset: Rows to skip (applied only when greater than zero).
        only_confirmed: None returns everything, True only rows with
            ``is_jargon`` true, False rows where it is false or NULL.

    Returns:
        A list of plain dicts, or an empty list when the query fails.
    """
    # chat_id is an alias of group_id kept for older callers.
    if group_id is None and chat_id is not None:
        group_id = chat_id

    try:
        async with self.get_session() as session:
            from sqlalchemy import select
            from ....models.orm.jargon import Jargon

            stmt = select(Jargon)

            if group_id is not None:
                stmt = stmt.where(Jargon.chat_id == group_id)

            if only_confirmed is True:
                stmt = stmt.where(Jargon.is_jargon == True)
            elif only_confirmed is False:
                stmt = stmt.where(
                    (Jargon.is_jargon == False) | (Jargon.is_jargon == None)
                )

            # updated_at is whole seconds and is shared by many rows, so the
            # primary key is added as a tie-breaker; without it OFFSET paging
            # can repeat or skip rows between pages.
            stmt = stmt.order_by(Jargon.updated_at.desc(), Jargon.id.desc())
            if offset > 0:
                stmt = stmt.offset(offset)
            stmt = stmt.limit(limit)

            result = await session.execute(stmt)
            jargon_records = result.scalars().all()

            self._logger.debug(
                f"[JargonFacade] 查询最近黑话列表: group_id={group_id}, "
                f"数量={len(jargon_records)}"
            )

            jargon_list = []
            for record in jargon_records:
                # Best effort: a row that cannot be converted is logged and
                # skipped instead of discarding the whole page.
                try:
                    jargon_list.append({
                        'id': record.id,
                        'content': record.content,
                        'raw_content': record.raw_content,
                        'meaning': record.meaning,
                        'is_jargon': record.is_jargon,
                        'count': record.count or 0,
                        'last_inference_count': record.last_inference_count or 0,
                        'is_complete': record.is_complete,
                        'chat_id': record.chat_id,
                        'updated_at': record.updated_at,
                        'is_global': record.is_global or False
                    })
                except Exception as row_error:
                    self._logger.warning(f"处理黑话记录行时出错,跳过: {row_error}")
                    continue

            return jargon_list

    except Exception as e:
        self._logger.error(f"[JargonFacade] 获取最近黑话列表失败: {e}", exc_info=True)
        return []


async def get_jargon_count(
    self,
    chat_id: Optional[str] = None,
    only_confirmed: Optional[bool] = None,
) -> int:
    """Count jargon rows using the same filters as the list query.

    Args:
        chat_id: Restrict to this group; None counts every group.
        only_confirmed: None counts everything, True only rows with
            ``is_jargon`` true, False rows where it is false or NULL.

    Returns:
        The number of matching rows, or 0 when the query fails.
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select, func
            from ....models.orm.jargon import Jargon

            stmt = select(func.count(Jargon.id))

            if chat_id is not None:
                stmt = stmt.where(Jargon.chat_id == chat_id)

            if only_confirmed is True:
                stmt = stmt.where(Jargon.is_jargon == True)
            elif only_confirmed is False:
                stmt = stmt.where(
                    (Jargon.is_jargon == False) | (Jargon.is_jargon == None)
                )

            result = await session.execute(stmt)
            return result.scalar() or 0
    except Exception as e:
        self._logger.error(f"[JargonFacade] 获取黑话总数失败: {e}", exc_info=True)
        return 0


async def search_jargon(
    self,
    keyword: str,
    chat_id: Optional[str] = None,
    confirmed_only: bool = True,
    limit: int = 10
) -> List[Dict]:
    """Search jargon whose content contains ``keyword`` (case-insensitive).

    The keyword is matched literally: ``%`` and ``_`` typed by the user are
    escaped so they are not interpreted as LIKE wildcards.

    Args:
        keyword: Text to look for inside ``content``.
        chat_id: When truthy, search only that group. When falsy and
            ``confirmed_only`` is set, the search is limited to rows
            flagged ``is_global``.
        confirmed_only: Only return rows with ``is_jargon`` true.
        limit: Maximum number of rows to return.

    Returns:
        Matching rows as dicts ordered by ``count`` then ``updated_at``
        (both descending), or an empty list when the query fails.
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select, and_
            from ....models.orm.jargon import Jargon

            # Escape the escape character first, then the two wildcards.
            # '/' is used rather than a backslash because backslashes are
            # themselves string escapes on some backends.
            literal = str(keyword)
            for special in ('/', '%', '_'):
                literal = literal.replace(special, '/' + special)

            conditions = [
                Jargon.content.ilike(f'%{literal}%', escape='/'),
            ]
            if confirmed_only:
                conditions.append(Jargon.is_jargon == True)
            if chat_id:
                conditions.append(Jargon.chat_id == chat_id)
            elif confirmed_only:
                # No group restriction + confirmed only -> global jargon.
                conditions.append(Jargon.is_global == True)

            stmt = (
                select(Jargon)
                .where(and_(*conditions))
                .order_by(Jargon.count.desc(), Jargon.updated_at.desc())
                .limit(limit)
            )
            result = await session.execute(stmt)
            records = result.scalars().all()

            return [
                {
                    'id': r.id,
                    'content': r.content,
                    'raw_content': r.raw_content,
                    'meaning': r.meaning,
                    'is_jargon': r.is_jargon,
                    'count': r.count or 0,
                    'is_complete': r.is_complete,
                    'is_global': r.is_global or False,
                    'chat_id': r.chat_id,
                    'updated_at': r.updated_at,
                }
                for r in records
            ]
    except Exception as e:
        self._logger.error(f"[JargonFacade] 搜索黑话失败: {e}", exc_info=True)
        return []


async def get_jargon_by_id(self, jargon_id: int) -> Optional[Dict]:
    """Fetch a single jargon row by primary key.

    Args:
        jargon_id: Id of the row to load.

    Returns:
        The row's ``to_dict()`` result, or None when the row does not exist
        or the query fails.
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select
            from ....models.orm.jargon import Jargon

            stmt = select(Jargon).where(Jargon.id == jargon_id)
            result = await session.execute(stmt)
            record = result.scalars().first()

            if not record:
                return None

            return record.to_dict()

    except Exception as e:
        self._logger.error(
            f"[JargonFacade] 获取黑话记录失败 (id={jargon_id}): {e}", exc_info=True
        )
        return None


async def delete_jargon_by_id(self, jargon_id: int) -> bool:
    """Delete a single jargon row by primary key.

    Args:
        jargon_id: Id of the row to delete.

    Returns:
        True when the row existed and the delete was committed; False when
        it was not found or an error occurred.
    """
    try:
        async with self.get_session() as session:
            from sqlalchemy import select
            from ....models.orm.jargon import Jargon

            stmt = select(Jargon).where(Jargon.id == jargon_id)
            result = await session.execute(stmt)
            record = result.scalars().first()

            if not record:
                return False

            await session.delete(record)
            await session.commit()
            self._logger.debug(f"[JargonFacade] 删除黑话记录成功, ID: {jargon_id}")
            return True
    except Exception as e:
        self._logger.error(
            f"[JargonFacade] 删除黑话失败 (id={jargon_id}): {e}", exc_info=True
        )
        return False
set_jargon_global + async def set_jargon_global(self, jargon_id: int, is_global: bool) -> bool: + """设置黑话的全局共享状态 + + Args: + jargon_id: 黑话记录 ID + is_global: 是否全局共享 + + Returns: + 是否更新成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.jargon import Jargon + + stmt = select(Jargon).where(Jargon.id == jargon_id) + result = await session.execute(stmt) + record = result.scalars().first() + + if not record: + return False + + record.is_global = is_global + record.updated_at = int(time.time()) + await session.commit() + self._logger.info( + f"[JargonFacade] 黑话全局状态已更新: ID={jargon_id}, is_global={is_global}" + ) + return True + except Exception as e: + self._logger.error( + f"[JargonFacade] 更新黑话全局状态失败 (id={jargon_id}): {e}", exc_info=True + ) + return False + + # 11. sync_global_jargon_to_group + async def sync_global_jargon_to_group(self, target_chat_id: str) -> int: + """将全局黑话同步到指定群组 + + 对全局黑话逐条检查目标群组是否已存在相同内容,不存在则插入。 + + Args: + target_chat_id: 目标群组 ID + + Returns: + 成功同步的数量 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, and_ + from ....models.orm.jargon import Jargon + + # 获取非目标群组的全局黑话 + stmt = select(Jargon).where(and_( + Jargon.is_jargon == True, + Jargon.is_global == True, + Jargon.chat_id != target_chat_id + )) + result = await session.execute(stmt) + global_jargons = result.scalars().all() + + synced_count = 0 + now_ts = int(time.time()) + + for gj in global_jargons: + # 检查目标群组是否已存在 + check_stmt = select(Jargon).where(and_( + Jargon.chat_id == target_chat_id, + Jargon.content == gj.content + )) + check_result = await session.execute(check_stmt) + if check_result.scalars().first(): + continue + + new_jargon = Jargon( + content=gj.content, + raw_content='[]', + meaning=gj.meaning, + is_jargon=True, + count=1, + last_inference_count=0, + is_complete=False, + is_global=False, + chat_id=target_chat_id, + created_at=now_ts, + updated_at=now_ts, + ) + 
session.add(new_jargon) + synced_count += 1 + + await session.commit() + self._logger.info( + f"[JargonFacade] 同步全局黑话到群组 {target_chat_id}: 同步 {synced_count} 条" + ) + return synced_count + except Exception as e: + self._logger.error(f"[JargonFacade] 同步全局黑话失败: {e}", exc_info=True) + return 0 + + # 12. save_or_update_jargon + async def save_or_update_jargon( + self, + chat_id: str, + content: str, + jargon_data: Dict[str, Any] + ) -> Optional[int]: + """保存或更新黑话记录(Upsert) + + 按 chat_id + content 检查是否已存在: + - 存在 → 用 jargon_data 中的字段更新 + - 不存在 → 插入新记录 + + Args: + chat_id: 群组 ID + content: 黑话内容 + jargon_data: 黑话数据字典 + + Returns: + 记录 ID 或 None + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, and_ + from ....models.orm.jargon import Jargon + + stmt = select(Jargon).where(and_( + Jargon.chat_id == chat_id, + Jargon.content == content, + )) + result = await session.execute(stmt) + record = result.scalars().first() + + now_ts = int(time.time()) + + if record: + # 更新已有记录 + if 'meaning' in jargon_data: + record.meaning = jargon_data['meaning'] + if 'raw_content' in jargon_data: + record.raw_content = jargon_data['raw_content'] + if 'is_jargon' in jargon_data: + record.is_jargon = jargon_data['is_jargon'] + if 'count' in jargon_data: + record.count = jargon_data['count'] + if 'last_inference_count' in jargon_data: + record.last_inference_count = jargon_data['last_inference_count'] + if 'is_complete' in jargon_data: + record.is_complete = jargon_data['is_complete'] + if 'is_global' in jargon_data: + record.is_global = jargon_data['is_global'] + record.updated_at = now_ts + + await session.commit() + self._logger.debug( + f"[JargonFacade] 更新黑话: content='{content}', chat_id={chat_id}, " + f"id={record.id}" + ) + return record.id + else: + # 插入新记录 + new_record = Jargon( + content=content, + raw_content=jargon_data.get('raw_content', '[]'), + meaning=jargon_data.get('meaning'), + is_jargon=jargon_data.get('is_jargon', True), + 
count=jargon_data.get('count', 1), + last_inference_count=jargon_data.get('last_inference_count', 0), + is_complete=jargon_data.get('is_complete', False), + is_global=jargon_data.get('is_global', False), + chat_id=chat_id, + created_at=now_ts, + updated_at=now_ts, + ) + session.add(new_record) + await session.commit() + await session.refresh(new_record) + self._logger.debug( + f"[JargonFacade] 插入黑话: content='{content}', chat_id={chat_id}, " + f"id={new_record.id}" + ) + return new_record.id + + except Exception as e: + self._logger.error( + f"[JargonFacade] 保存/更新黑话失败 (content='{content}'): {e}", + exc_info=True, + ) + return None + + # 13. get_global_jargon_list + async def get_global_jargon_list(self, limit: int = 50) -> List[Dict]: + """获取全局共享的黑话列表 + + Args: + limit: 返回数量限制 + + Returns: + 全局黑话列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.jargon import Jargon + + stmt = select(Jargon).where( + Jargon.is_jargon == True, + Jargon.is_global == True + ).order_by( + Jargon.count.desc(), + Jargon.updated_at.desc() + ).limit(limit) + + result = await session.execute(stmt) + jargon_list = result.scalars().all() + + self._logger.debug( + f"[JargonFacade] 查询全局黑话列表: 数量={len(jargon_list)}" + ) + + return [ + { + 'id': jargon.id, + 'content': jargon.content, + 'raw_content': jargon.raw_content, + 'meaning': jargon.meaning, + 'is_jargon': jargon.is_jargon, + 'count': jargon.count, + 'last_inference_count': jargon.last_inference_count, + 'is_complete': jargon.is_complete, + 'is_global': jargon.is_global, + 'chat_id': jargon.chat_id, + 'updated_at': jargon.updated_at + } + for jargon in jargon_list + ] + + except Exception as e: + self._logger.error(f"[JargonFacade] 获取全局黑话列表失败: {e}", exc_info=True) + return [] + + # 14. get_jargon_groups + async def get_jargon_groups(self) -> List[Dict]: + """获取包含黑话的群组列表 + + Returns: + 群组列表 [{chat_id, count}, ...] 
+ """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.jargon import Jargon + + stmt = select( + Jargon.chat_id, + func.count(Jargon.id).label('count') + ).group_by( + Jargon.chat_id + ).order_by( + func.count(Jargon.id).desc() + ) + + result = await session.execute(stmt) + rows = result.all() + + self._logger.debug(f"[JargonFacade] 查询黑话群组列表: 数量={len(rows)}") + + groups = [] + for row in rows: + try: + groups.append({ + 'chat_id': row.chat_id, + 'count': row.count or 0 + }) + except Exception as row_error: + self._logger.warning( + f"处理黑话群组数据行失败: {row_error}, 行数据: {row}" + ) + continue + + return groups + + except Exception as e: + self._logger.error(f"[JargonFacade] 获取黑话群组列表失败: {e}", exc_info=True) + return [] diff --git a/services/database/facades/learning_facade.py b/services/database/facades/learning_facade.py new file mode 100644 index 0000000..2c8565b --- /dev/null +++ b/services/database/facades/learning_facade.py @@ -0,0 +1,1101 @@ +""" +学习 Facade — 人格学习审核、风格学习审核、学习批次/会话、统计的业务入口 +""" +import time +import json +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade + + +class LearningFacade(BaseFacade): + """学习管理 Facade — 包装所有学习相关的数据库方法""" + + # Persona Learning Review methods + + async def add_persona_learning_review(self, review_data: Dict[str, Any]) -> int: + """创建人格学习审核记录 + + Args: + review_data: 审核数据字典 + + Returns: + 新记录的 id,失败返回 0 + """ + try: + async with self.get_session() as session: + from ....models.orm.learning import PersonaLearningReview + + metadata = review_data.get('metadata', {}) + record = PersonaLearningReview( + timestamp=self._to_float_ts( + review_data.get('timestamp'), default=time.time() + ), + group_id=review_data.get('group_id', ''), + update_type=review_data.get('update_type', ''), + original_content=review_data.get('original_content', ''), + new_content=review_data.get('new_content', ''), + 
proposed_content=review_data.get('proposed_content', ''), + confidence_score=review_data.get('confidence_score', 0.0), + reason=review_data.get('reason', ''), + status='pending', + metadata_=json.dumps(metadata, ensure_ascii=False) if metadata else None, + ) + session.add(record) + await session.commit() + await session.refresh(record) + return record.id + except Exception as e: + self._logger.error(f"[LearningFacade] 添加人格学习审核记录失败: {e}") + return 0 + + async def get_pending_persona_update_records(self) -> List[Dict[str, Any]]: + """获取所有待审核的人格更新记录 + + Returns: + 待审核记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import PersonaLearningReview + + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status == 'pending') + .order_by(desc(PersonaLearningReview.timestamp)) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': r.group_id, + 'update_type': r.update_type, + 'original_content': r.original_content, + 'new_content': r.new_content, + 'proposed_content': r.proposed_content, + 'confidence_score': r.confidence_score, + 'reason': r.reason, + 'status': r.status, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + 'metadata': json.loads(r.metadata_) if r.metadata_ else {}, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取待审核人格更新记录失败: {e}") + return [] + + async def save_persona_update_record(self, record: Dict[str, Any]) -> int: + """保存人格更新记录(add_persona_learning_review 的别名) + + Args: + record: 记录数据字典 + + Returns: + 新记录的 id,失败返回 0 + """ + return await self.add_persona_learning_review(record) + + async def update_persona_update_record_status( + self, record_id: int, new_status: str, reviewer_comment: str = '' + ) -> bool: + """更新人格更新记录的状态 + + Args: + record_id: 记录 ID + new_status: 新状态 (approved/rejected) + 
reviewer_comment: 审核评论 + + Returns: + 是否更新成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).where( + PersonaLearningReview.id == record_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + record.status = new_status + record.reviewer_comment = reviewer_comment + record.review_time = time.time() + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 更新人格更新记录状态失败: {e}") + return False + + async def delete_persona_update_record(self, record_id: int) -> bool: + """删除人格更新记录 + + Args: + record_id: 记录 ID + + Returns: + 是否删除成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, delete as sa_delete + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).where( + PersonaLearningReview.id == record_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + await session.execute( + sa_delete(PersonaLearningReview).where( + PersonaLearningReview.id == record_id + ) + ) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 删除人格更新记录失败: {e}") + return False + + async def get_persona_update_record_by_id( + self, record_id: int + ) -> Optional[Dict[str, Any]]: + """根据 ID 获取人格更新记录 + + Args: + record_id: 记录 ID + + Returns: + 记录字典或 None + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).where( + PersonaLearningReview.id == record_id + ) + result = await session.execute(stmt) + r = result.scalar_one_or_none() + if not r: + return None + return { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': 
r.group_id, + 'update_type': r.update_type, + 'original_content': r.original_content, + 'new_content': r.new_content, + 'proposed_content': r.proposed_content, + 'confidence_score': r.confidence_score, + 'reason': r.reason, + 'status': r.status, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + 'metadata': json.loads(r.metadata_) if r.metadata_ else {}, + } + except Exception as e: + self._logger.error(f"[LearningFacade] 获取人格更新记录失败: {e}") + return None + + async def get_reviewed_persona_update_records( + self, limit: int = 50, offset: int = 0, status_filter: str = None + ) -> List[Dict[str, Any]]: + """获取已审核的人格更新记录 + + Args: + limit: 返回数量限制 + offset: 偏移量 + status_filter: 状态过滤 + + Returns: + 已审核记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import PersonaLearningReview + + if status_filter: + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status == status_filter) + .order_by(desc(PersonaLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + else: + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status.in_(['approved', 'rejected'])) + .order_by(desc(PersonaLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': r.group_id, + 'update_type': r.update_type, + 'original_content': r.original_content, + 'new_content': r.new_content, + 'proposed_content': r.proposed_content, + 'confidence_score': r.confidence_score, + 'reason': r.reason, + 'status': r.status, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + 'metadata': json.loads(r.metadata_) if r.metadata_ else {}, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取已审核人格更新记录失败: {e}") + return [] + + async def 
get_pending_persona_learning_reviews( + self, limit: int = None, offset: int = 0 + ) -> List[Dict[str, Any]]: + """获取待审核的人格学习审核记录 + + Args: + limit: 可选的返回数量限制 + offset: 分页偏移量 + + Returns: + 待审核记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import PersonaLearningReview + + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status == 'pending') + .order_by(desc(PersonaLearningReview.timestamp)) + ) + if offset > 0: + stmt = stmt.offset(offset) + if limit is not None: + stmt = stmt.limit(limit) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': r.group_id, + 'update_type': r.update_type, + 'original_content': r.original_content, + 'new_content': r.new_content, + 'proposed_content': r.proposed_content, + 'confidence_score': r.confidence_score, + 'reason': r.reason, + 'status': r.status, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + 'metadata': json.loads(r.metadata_) if r.metadata_ else {}, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取待审核人格学习审核记录失败: {e}") + return [] + + async def get_reviewed_persona_learning_updates( + self, limit=50, offset=0, status_filter=None + ) -> List[Dict]: + """获取已审核的人格学习更新记录 + + Args: + limit: 返回数量限制 + offset: 偏移量 + status_filter: 状态过滤 + + Returns: + 已审核记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import PersonaLearningReview + + if status_filter: + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status == status_filter) + .order_by(desc(PersonaLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + else: + stmt = ( + select(PersonaLearningReview) + .where(PersonaLearningReview.status.in_(['approved', 'rejected'])) + 
.order_by(desc(PersonaLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': r.group_id, + 'update_type': r.update_type, + 'original_content': r.original_content, + 'new_content': r.new_content, + 'proposed_content': r.proposed_content, + 'confidence_score': r.confidence_score, + 'reason': r.reason, + 'status': r.status, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取已审核人格学习更新记录失败: {e}") + return [] + + async def delete_persona_learning_review_by_id(self, review_id: int) -> bool: + """根据 ID 删除人格学习审核记录 + + Args: + review_id: 审核记录 ID + + Returns: + 是否删除成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, delete as sa_delete + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).where( + PersonaLearningReview.id == review_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + await session.execute( + sa_delete(PersonaLearningReview).where( + PersonaLearningReview.id == review_id + ) + ) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 删除人格学习审核记录失败: {e}") + return False + + async def get_persona_learning_review_by_id( + self, review_id: int + ) -> Optional[Dict]: + """根据 ID 获取人格学习审核记录(get_persona_update_record_by_id 的别名) + + Args: + review_id: 审核记录 ID + + Returns: + 记录字典或 None + """ + return await self.get_persona_update_record_by_id(review_id) + + async def update_persona_learning_review_status( + self, review_id, new_status, reviewer_comment='', + modified_content=None, + ) -> bool: + """更新人格学习审核记录状态 + + Args: + review_id: 审核记录 ID + new_status: 新状态 + reviewer_comment: 审核评论 + modified_content: 
用户修改后的内容(可选) + + Returns: + 是否更新成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).where( + PersonaLearningReview.id == review_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + record.status = new_status + record.reviewer_comment = reviewer_comment + record.review_time = time.time() + + if modified_content: + record.proposed_content = modified_content + record.new_content = modified_content + + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 更新人格学习审核记录状态失败: {e}") + return False + + # Style Learning Review methods + + async def create_style_learning_review( + self, review_data: Dict[str, Any] + ) -> int: + """创建风格学习审核记录 + + Args: + review_data: 审核数据字典 + + Returns: + 新记录的 id,失败返回 0 + """ + try: + async with self.get_session() as session: + from ....models.orm.learning import StyleLearningReview + + learned_patterns = review_data.get('learned_patterns', []) + record = StyleLearningReview( + type=review_data.get('type', ''), + group_id=review_data.get('group_id', ''), + timestamp=self._to_float_ts( + review_data.get('timestamp'), default=time.time() + ), + learned_patterns=json.dumps(learned_patterns, ensure_ascii=False) + if isinstance(learned_patterns, (list, dict)) + else learned_patterns, + few_shots_content=review_data.get('few_shots_content', ''), + status='pending', + description=review_data.get('description', ''), + ) + session.add(record) + await session.commit() + await session.refresh(record) + return record.id + except Exception as e: + self._logger.error(f"[LearningFacade] 创建风格学习审核记录失败: {e}") + return 0 + + async def get_pending_style_reviews(self, limit=None, offset=0) -> List[Dict]: + """获取待审核的风格学习记录 + + Args: + limit: 可选的返回数量限制 + offset: 分页偏移量 + + Returns: + 待审核记录列表 + """ + try: + async with 
self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import StyleLearningReview + + stmt = ( + select(StyleLearningReview) + .where(StyleLearningReview.status == 'pending') + .order_by(desc(StyleLearningReview.timestamp)) + ) + if offset > 0: + stmt = stmt.offset(offset) + if limit is not None: + stmt = stmt.limit(limit) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'type': r.type, + 'group_id': r.group_id, + 'timestamp': r.timestamp, + 'learned_patterns': json.loads(r.learned_patterns) + if r.learned_patterns + else [], + 'few_shots_content': r.few_shots_content, + 'status': r.status, + 'description': r.description, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + 'created_at': r.created_at, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取待审核风格学习记录失败: {e}") + return [] + + async def get_approved_few_shots( + self, group_id: str, limit: int = 3 + ) -> List[str]: + """获取指定群组已审批的 few-shot 对话内容 + + Args: + group_id: 群组 ID + limit: 返回条数上限 + + Returns: + few_shots_content 文本列表,按时间倒序 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import StyleLearningReview + + stmt = ( + select(StyleLearningReview.few_shots_content) + .where( + StyleLearningReview.status == 'approved', + StyleLearningReview.group_id == group_id, + StyleLearningReview.few_shots_content.isnot(None), + StyleLearningReview.few_shots_content != '', + ) + .order_by(desc(StyleLearningReview.timestamp)) + .limit(limit) + ) + result = await session.execute(stmt) + return [row[0] for row in result.fetchall()] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取已审批 few-shots 失败: {e}") + return [] + + async def get_reviewed_style_learning_updates( + self, limit=50, offset=0, status_filter=None + ) -> List[Dict]: + """获取已审核的风格学习更新记录 + + Args: + limit: 
返回数量限制 + offset: 偏移量 + status_filter: 状态过滤 + + Returns: + 已审核记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import StyleLearningReview + + if status_filter: + stmt = ( + select(StyleLearningReview) + .where(StyleLearningReview.status == status_filter) + .order_by(desc(StyleLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + else: + stmt = ( + select(StyleLearningReview) + .where(StyleLearningReview.status.in_(['approved', 'rejected'])) + .order_by(desc(StyleLearningReview.review_time)) + .offset(offset) + .limit(limit) + ) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'type': r.type, + 'group_id': r.group_id, + 'timestamp': r.timestamp, + 'learned_patterns': json.loads(r.learned_patterns) + if r.learned_patterns + else [], + 'few_shots_content': r.few_shots_content, + 'status': r.status, + 'description': r.description, + 'reviewer_comment': r.reviewer_comment, + 'review_time': r.review_time, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取已审核风格学习更新记录失败: {e}") + return [] + + async def update_style_review_status( + self, review_id, new_status, reviewer_comment='' + ) -> bool: + """更新风格学习审核记录状态 + + Args: + review_id: 审核记录 ID + new_status: 新状态 (approved/rejected) + reviewer_comment: 审核评论 + + Returns: + 是否更新成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.learning import StyleLearningReview + + stmt = select(StyleLearningReview).where( + StyleLearningReview.id == review_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + record.status = new_status + record.reviewer_comment = reviewer_comment + record.review_time = time.time() + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 
更新风格学习审核记录状态失败: {e}") + return False + + async def delete_style_review_by_id(self, review_id: int) -> bool: + """根据 ID 删除风格学习审核记录 + + Args: + review_id: 审核记录 ID + + Returns: + 是否删除成功 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, delete as sa_delete + from ....models.orm.learning import StyleLearningReview + + stmt = select(StyleLearningReview).where( + StyleLearningReview.id == review_id + ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return False + + await session.execute( + sa_delete(StyleLearningReview).where( + StyleLearningReview.id == review_id + ) + ) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 删除风格学习审核记录失败: {e}") + return False + + # Learning Batch/Session methods + + async def get_learning_batch_history( + self, group_id=None, limit=20 + ) -> List[Dict]: + """获取学习批次历史 + + Args: + group_id: 可选的群组 ID 过滤 + limit: 返回数量限制 + + Returns: + 学习批次记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import LearningBatch + + stmt = ( + select(LearningBatch) + .order_by(desc(LearningBatch.start_time)) + .limit(limit) + ) + if group_id: + stmt = stmt.where(LearningBatch.group_id == group_id) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [self._row_to_dict(r) for r in rows] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取学习批次历史失败: {e}") + return [] + + async def get_recent_learning_batches(self, limit=5) -> List[Dict]: + """获取最近的学习批次 + + Args: + limit: 返回数量限制 + + Returns: + 学习批次记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import LearningBatch + + stmt = ( + select(LearningBatch) + .order_by(desc(LearningBatch.start_time)) + .limit(limit) + ) + result = await session.execute(stmt) + rows = 
result.scalars().all() + return [self._row_to_dict(r) for r in rows] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取最近学习批次失败: {e}") + return [] + + async def get_learning_sessions(self, group_id, limit=5) -> List[Dict]: + """获取指定群组的学习会话 + + Args: + group_id: 群组 ID + limit: 返回数量限制 + + Returns: + 学习会话记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import LearningSession + + stmt = ( + select(LearningSession) + .where(LearningSession.group_id == group_id) + .order_by(desc(LearningSession.start_time)) + .limit(limit) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [self._row_to_dict(r) for r in rows] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取学习会话失败: {e}") + return [] + + async def get_recent_learning_sessions(self, days=7) -> List[Dict]: + """获取最近 N 天的学习会话 + + Args: + days: 天数 + + Returns: + 学习会话记录列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import LearningSession + + cutoff = time.time() - (days * 24 * 3600) + stmt = ( + select(LearningSession) + .where(LearningSession.start_time > cutoff) + .order_by(desc(LearningSession.start_time)) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [self._row_to_dict(r) for r in rows] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取最近学习会话失败: {e}") + return [] + + async def save_learning_session_record( + self, group_id, session_data + ) -> bool: + """保存学习会话记录 + + Args: + group_id: 群组 ID + session_data: 会话数据字典 + + Returns: + 是否保存成功 + """ + try: + async with self.get_session() as session: + from ....models.orm.learning import LearningSession + + record = LearningSession( + session_id=session_data.get('session_id', ''), + group_id=group_id, + batch_id=session_data.get('batch_id'), + start_time=self._to_float_ts( + 
session_data.get('start_time'), default=time.time() + ), + end_time=self._to_float_ts(session_data.get('end_time')), + message_count=session_data.get('message_count', 0), + learning_quality=session_data.get('learning_quality'), + status=session_data.get('status', 'active'), + ) + session.add(record) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 保存学习会话记录失败: {e}") + return False + + async def save_learning_performance_record( + self, group_id, performance_data + ) -> bool: + """保存学习性能记录 + + Args: + group_id: 群组 ID + performance_data: 性能数据字典 + + Returns: + 是否保存成功 + """ + try: + async with self.get_session() as session: + from ....models.orm.performance import LearningPerformanceHistory + + metadata = performance_data.get('metadata', {}) + record = LearningPerformanceHistory( + group_id=group_id, + session_id=performance_data.get('session_id', ''), + timestamp=int( + self._to_float_ts( + performance_data.get('timestamp'), + default=time.time(), + ) + ), + quality_score=performance_data.get('quality_score'), + learning_time=performance_data.get('learning_time'), + success=performance_data.get('success', True), + successful_pattern=json.dumps( + performance_data.get('successful_pattern', []), + ensure_ascii=False, + ) + if isinstance(performance_data.get('successful_pattern'), (list, dict)) + else performance_data.get('successful_pattern'), + failed_pattern=json.dumps( + performance_data.get('failed_pattern', []), + ensure_ascii=False, + ) + if isinstance(performance_data.get('failed_pattern'), (list, dict)) + else performance_data.get('failed_pattern'), + created_at=int(time.time()), + ) + session.add(record) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[LearningFacade] 保存学习性能记录失败: {e}") + return False + + # Statistics methods + + async def count_pending_persona_updates(self) -> int: + """统计待审核的人格更新记录数 + + Returns: + 待审核记录数量 + """ + try: + async with self.get_session() as 
session: + from sqlalchemy import select, func + from ....models.orm.learning import PersonaLearningReview + + stmt = ( + select(func.count()) + .select_from(PersonaLearningReview) + .where(PersonaLearningReview.status == 'pending') + ) + result = await session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + self._logger.error(f"[LearningFacade] 统计待审核人格更新数量失败: {e}") + return 0 + + async def count_style_learning_patterns(self) -> int: + """统计风格学习模式总数 + + Returns: + 风格学习模式数量 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.learning import StyleLearningPattern + + stmt = select(func.count()).select_from(StyleLearningPattern) + result = await session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + self._logger.error(f"[LearningFacade] 统计风格学习模式数量失败: {e}") + return 0 + + async def count_refined_messages(self) -> int: + """统计筛选后消息总数 + + Returns: + 筛选后消息数量 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.message import FilteredMessage + + stmt = select(func.count()).select_from(FilteredMessage) + result = await session.execute(stmt) + return result.scalar() or 0 + except Exception as e: + self._logger.error(f"[LearningFacade] 统计筛选后消息数量失败: {e}") + return 0 + + async def get_style_learning_statistics(self) -> Dict[str, Any]: + """获取风格学习统计信息 + + Returns: + 包含 total_reviews, pending_reviews, approved_reviews 的字典 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.learning import StyleLearningReview + + total_stmt = select(func.count()).select_from(StyleLearningReview) + total_result = await session.execute(total_stmt) + total = total_result.scalar() or 0 + + pending_stmt = ( + select(func.count()) + .select_from(StyleLearningReview) + .where(StyleLearningReview.status == 'pending') + ) + pending_result = await 
session.execute(pending_stmt) + pending = pending_result.scalar() or 0 + + approved_stmt = ( + select(func.count()) + .select_from(StyleLearningReview) + .where(StyleLearningReview.status == 'approved') + ) + approved_result = await session.execute(approved_stmt) + approved = approved_result.scalar() or 0 + + return { + 'total_reviews': total, + 'pending_reviews': pending, + 'approved_reviews': approved, + } + except Exception as e: + self._logger.error(f"[LearningFacade] 获取风格学习统计失败: {e}") + return { + 'total_reviews': 0, + 'pending_reviews': 0, + 'approved_reviews': 0, + } + + async def get_style_progress_data( + self, group_id=None + ) -> List[Dict]: + """获取风格学习进度数据(从 learning_batches 表查询) + + Args: + group_id: 可选的群组 ID 过滤 + + Returns: + 学习批次进度列表 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import LearningBatch + + stmt = ( + select(LearningBatch) + .where( + LearningBatch.quality_score.isnot(None), + LearningBatch.processed_messages > 0, + ) + .order_by(desc(LearningBatch.start_time)) + .limit(30) + ) + if group_id: + stmt = stmt.where(LearningBatch.group_id == group_id) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'group_id': r.group_id, + 'timestamp': r.start_time or 0, + 'quality_score': r.quality_score or 0, + 'success': bool(r.success), + 'processed_messages': r.processed_messages or 0, + 'filtered_count': r.filtered_count or 0, + 'message_count': r.message_count or 0, + 'batch_name': r.batch_name or '', + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[LearningFacade] 获取风格学习进度数据失败: {e}") + return [] + + async def get_learning_patterns_data( + self, group_id=None + ) -> Dict[str, Any]: + """获取学习模式分布数据 + + 按 pattern_type 分组统计 StyleLearningPattern 记录。 + + Args: + group_id: 可选的群组 ID 过滤 + + Returns: + 按模式类型分组的计数字典 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from 
....models.orm.learning import StyleLearningPattern + + stmt = select( + StyleLearningPattern.pattern_type, + func.count().label('count'), + ).group_by(StyleLearningPattern.pattern_type) + if group_id: + stmt = stmt.where(StyleLearningPattern.group_id == group_id) + + result = await session.execute(stmt) + rows = result.all() + pattern_counts = {row[0]: row[1] for row in rows} + return pattern_counts + except Exception as e: + self._logger.error(f"[LearningFacade] 获取学习模式分布数据失败: {e}") + return {} diff --git a/services/database/facades/message_facade.py b/services/database/facades/message_facade.py new file mode 100644 index 0000000..878f6e8 --- /dev/null +++ b/services/database/facades/message_facade.py @@ -0,0 +1,419 @@ +""" +消息 Facade — 原始消息、筛选消息、Bot消息的业务入口 +""" +import time +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade +from ....repositories.raw_message_repository import RawMessageRepository +from ....repositories.filtered_message_repository import FilteredMessageRepository +from ....repositories.bot_message_repository import BotMessageRepository + + +class MessageFacade(BaseFacade): + """消息管理 Facade""" + + # ---- 原始消息 ---- + + async def save_raw_message(self, message_data) -> int: + """保存原始消息 + + Args: + message_data: 消息数据(对象或字典) + + Returns: + int: 消息 ID(失败返回 0) + """ + try: + async with self.get_session() as session: + from ....models.orm.message import RawMessage + + if hasattr(message_data, '__dict__'): + data = message_data.__dict__ + else: + data = message_data + + raw_msg = RawMessage( + sender_id=str(data.get('sender_id', '')), + sender_name=data.get('sender_name', ''), + message=data.get('message', ''), + group_id=data.get('group_id', ''), + timestamp=int(data.get('timestamp', time.time())), + platform=data.get('platform', ''), + message_id=data.get('message_id'), + reply_to=data.get('reply_to'), + created_at=int(time.time()), + processed=False, + ) + session.add(raw_msg) + await 
session.commit() + await session.refresh(raw_msg) + return raw_msg.id + except Exception as e: + self._logger.error(f"[MessageFacade] 保存原始消息失败: {e}") + return 0 + + async def get_recent_raw_messages( + self, group_id: str, limit: int = 200 + ) -> List[Dict[str, Any]]: + """获取最近的原始消息""" + try: + async with self.get_session() as session: + repo = RawMessageRepository(session) + messages = await repo.get_recent(group_id=group_id, limit=limit) + return [ + { + 'id': msg.id, 'sender_id': msg.sender_id, + 'sender_name': msg.sender_name, 'message': msg.message, + 'group_id': msg.group_id, 'timestamp': msg.timestamp, + 'platform': msg.platform, 'message_id': msg.message_id, + 'reply_to': msg.reply_to, 'created_at': msg.created_at, + 'processed': msg.processed, + } + for msg in messages + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取最近原始消息失败: {e}") + raise RuntimeError(f"无法获取群组 {group_id} 的最近原始消息: {e}") from e + + async def get_unprocessed_messages( + self, limit: Optional[int] = None + ) -> List[Dict[str, Any]]: + """获取未处理的原始消息""" + try: + async with self.get_session() as session: + repo = RawMessageRepository(session) + messages = await repo.get_unprocessed(limit=limit or 100) + return [ + { + 'id': msg.id, 'sender_id': msg.sender_id, + 'sender_name': msg.sender_name, 'message': msg.message, + 'group_id': msg.group_id, 'platform': msg.platform, + 'timestamp': msg.timestamp, + } + for msg in messages + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取未处理消息失败: {e}") + raise RuntimeError(f"获取未处理消息失败: {e}") from e + + async def mark_messages_processed(self, message_ids: List[int]) -> bool: + """批量标记消息为已处理""" + if not message_ids: + return True + try: + async with self.get_session() as session: + repo = RawMessageRepository(session) + count = await repo.mark_batch_processed(message_ids) + return count > 0 + except Exception as e: + self._logger.error(f"[MessageFacade] 标记已处理失败: {e}") + return False + + async def 
get_messages_by_timerange( + self, group_id: str, start_time: int, end_time: int, limit: int = 500 + ) -> List[Dict[str, Any]]: + """按时间范围获取消息""" + return await self.get_messages_by_group_and_timerange( + group_id, start_time, end_time, limit + ) + + async def get_messages_by_group_and_timerange( + self, group_id: str, start_time: int, end_time: int, limit: int = 500 + ) -> List[Dict[str, Any]]: + """按群组和时间范围获取消息""" + try: + async with self.get_session() as session: + repo = RawMessageRepository(session) + messages = await repo.get_by_timerange(group_id, start_time, end_time, limit) + return [ + { + 'id': msg.id, 'sender_id': msg.sender_id, + 'sender_name': msg.sender_name, 'message': msg.message, + 'group_id': msg.group_id, 'timestamp': msg.timestamp, + } + for msg in messages + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 按时间范围获取消息失败: {e}") + return [] + + async def get_messages_for_replay( + self, group_id: str, days: int = 30, limit: int = 100 + ) -> List[Dict[str, Any]]: + """获取用于记忆重放的消息""" + try: + async with self.get_session() as session: + from sqlalchemy import select, desc, and_ + from ....models.orm.message import RawMessage + + cutoff_time = time.time() - (days * 24 * 3600) + stmt = ( + select(RawMessage) + .where(and_( + RawMessage.group_id == group_id, + RawMessage.timestamp > cutoff_time, + RawMessage.processed == True, # noqa: E712 + )) + .order_by(desc(RawMessage.timestamp)) + .limit(limit) + ) + result = await session.execute(stmt) + return [ + { + 'message_id': msg.id, 'message': msg.message, + 'sender_id': msg.sender_id, 'group_id': msg.group_id, + 'timestamp': msg.timestamp, + } + for msg in result.scalars().all() + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取记忆重放消息失败: {e}") + return [] + + # ---- 筛选消息 ---- + + async def get_recent_filtered_messages( + self, group_id: str, limit: int = 20 + ) -> List[Dict[str, Any]]: + """获取最近的筛选消息""" + try: + async with self.get_session() as session: + repo = 
FilteredMessageRepository(session) + messages = await repo.get_recent(group_id=group_id, limit=limit) + return [ + { + 'id': msg.id, 'raw_message_id': msg.raw_message_id, + 'message': msg.message, 'sender_id': msg.sender_id, + 'group_id': msg.group_id, 'timestamp': msg.timestamp, + 'confidence': msg.confidence, 'quality_scores': msg.quality_scores, + 'filter_reason': msg.filter_reason, 'created_at': msg.created_at, + 'processed': msg.processed, + } + for msg in messages + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取筛选消息失败: {e}") + raise RuntimeError(f"无法获取群组 {group_id} 的最近筛选消息: {e}") from e + + async def get_filtered_messages_for_learning( + self, limit: int = 20 + ) -> List[Dict[str, Any]]: + """获取待学习的筛选消息""" + try: + async with self.get_session() as session: + repo = FilteredMessageRepository(session) + messages = await repo.get_for_learning(limit=limit) + return [ + { + 'id': msg.id, 'message': msg.message, + 'sender_id': msg.sender_id, 'group_id': msg.group_id, + 'timestamp': msg.timestamp, 'confidence': msg.confidence, + } + for msg in messages + ] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取待学习筛选消息失败: {e}") + return [] + + async def add_filtered_message(self, filtered_data: Dict[str, Any]) -> int: + """添加筛选后的消息""" + try: + async with self.get_session() as session: + repo = FilteredMessageRepository(session) + msg = await repo.add(filtered_data) + return msg.id if msg else 0 + except Exception as e: + self._logger.error(f"[MessageFacade] 添加筛选消息失败: {e}") + return 0 + + # ---- Bot 消息 ---- + + async def save_bot_message( + self, group_id: str, message: str, timestamp: int = None + ) -> bool: + """保存 Bot 消息""" + try: + async with self.get_session() as session: + repo = BotMessageRepository(session) + result = await repo.save({ + 'group_id': group_id, + 'message': message, + 'timestamp': timestamp or int(time.time()), + }) + return result is not None + except Exception as e: + self._logger.error(f"[MessageFacade] 保存 Bot 
消息失败: {e}") + return False + + async def get_recent_bot_responses( + self, group_id: str, limit: int = 10 + ) -> List[str]: + """获取最近的 Bot 回复(仅文本)""" + try: + async with self.get_session() as session: + repo = BotMessageRepository(session) + messages = await repo.get_recent_responses(group_id, limit) + return [msg.message for msg in messages] + except Exception as e: + self._logger.error(f"[MessageFacade] 获取 Bot 回复失败: {e}") + return [] + + # ---- 统计 ---- + + async def get_message_statistics( + self, group_id: str = None + ) -> Dict[str, Any]: + """获取消息统计信息""" + if not group_id: + return await self.get_messages_statistics() + + try: + async with self.get_session() as session: + from sqlalchemy import select, func, and_ + from ....models.orm.message import RawMessage, FilteredMessage + + total_stmt = select(func.count()).select_from(RawMessage).where( + RawMessage.group_id == group_id + ) + total = (await session.execute(total_stmt)).scalar() or 0 + + unprocessed_stmt = select(func.count()).select_from(RawMessage).where( + and_(RawMessage.group_id == group_id, RawMessage.processed == False) # noqa: E712 + ) + unprocessed = (await session.execute(unprocessed_stmt)).scalar() or 0 + + filtered_stmt = select(func.count()).select_from(FilteredMessage).where( + FilteredMessage.group_id == group_id + ) + filtered = (await session.execute(filtered_stmt)).scalar() or 0 + + return { + 'total_messages': total, + 'unprocessed_messages': unprocessed, + 'filtered_messages': filtered, + 'raw_messages': total, + 'group_id': group_id, + } + except Exception as e: + self._logger.error(f"[MessageFacade] 获取消息统计失败: {e}") + return { + 'total_messages': 0, 'unprocessed_messages': 0, + 'filtered_messages': 0, 'raw_messages': 0, 'group_id': group_id, + } + + async def get_messages_statistics(self) -> Dict[str, Any]: + """获取全局消息统计""" + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.message import RawMessage, FilteredMessage, 
BotMessage + + raw_count = (await session.execute( + select(func.count()).select_from(RawMessage) + )).scalar() or 0 + filtered_count = (await session.execute( + select(func.count()).select_from(FilteredMessage) + )).scalar() or 0 + bot_count = (await session.execute( + select(func.count()).select_from(BotMessage) + )).scalar() or 0 + + return { + 'total_messages': raw_count, + 'raw_messages': raw_count, + 'filtered_messages': filtered_count, + 'bot_messages': bot_count, + } + except Exception as e: + self._logger.error(f"[MessageFacade] 获取全局统计失败: {e}") + return { + 'total_messages': 0, 'raw_messages': 0, + 'filtered_messages': 0, 'bot_messages': 0, + } + + async def get_group_messages_statistics( + self, group_id: str + ) -> Dict[str, Any]: + """获取群组消息统计""" + return await self.get_message_statistics(group_id) + + async def get_group_user_statistics( + self, group_id: str + ) -> Dict[str, Dict[str, Any]]: + """获取群组用户消息统计""" + try: + async with self.get_session() as session: + repo = RawMessageRepository(session) + stats = await repo.get_sender_statistics(group_id, limit=50) + return { + s['sender_id']: { + 'message_count': s['count'], + 'sender_name': s.get('sender_name', s['sender_id']), + } + for s in stats + } + except Exception as e: + self._logger.error(f"[MessageFacade] 获取用户统计失败: {e}") + return {} + + async def get_groups_for_social_analysis(self) -> List[Dict[str, Any]]: + """获取有消息记录的群组列表(用于社交分析) + + 返回每个群组的消息数、成员数和社交关系数,供 SocialService 消费。 + 使用两个独立查询避免 LEFT JOIN 子查询的兼容性问题。 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select, func, distinct + from ....models.orm.message import RawMessage + + # 查询 1: 从 RawMessage 获取群组列表、消息数、成员数 + msg_stmt = ( + select( + RawMessage.group_id, + func.count().label('message_count'), + func.count(distinct(RawMessage.sender_id)).label('member_count'), + ) + .group_by(RawMessage.group_id) + .order_by(func.count().desc()) + ) + msg_result = await session.execute(msg_stmt) + groups = [] + for row 
in msg_result.fetchall(): + groups.append({ + 'group_id': row.group_id, + 'message_count': row.message_count, + 'member_count': row.member_count, + 'relation_count': 0, + }) + + # 查询 2: 从 SocialRelation 获取每个群组的关系数(可选) + if groups: + try: + from ....models.orm.social_relation import SocialRelation + rel_stmt = ( + select( + SocialRelation.group_id, + func.count().label('relation_count'), + ) + .group_by(SocialRelation.group_id) + ) + rel_result = await session.execute(rel_stmt) + rel_map = { + row.group_id: row.relation_count + for row in rel_result.fetchall() + } + for g in groups: + g['relation_count'] = rel_map.get(g['group_id'], 0) + except Exception as rel_err: + self._logger.debug( + f"[MessageFacade] 获取社交关系计数失败(不影响群组列表): {rel_err}" + ) + + return groups + except Exception as e: + self._logger.error(f"[MessageFacade] 获取分析群组列表失败: {e}") + return [] diff --git a/services/database/facades/metrics_facade.py b/services/database/facades/metrics_facade.py new file mode 100644 index 0000000..497262f --- /dev/null +++ b/services/database/facades/metrics_facade.py @@ -0,0 +1,148 @@ +""" +指标聚合 Facade — 跨域统计指标的业务入口 +""" +import time +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade + + +class MetricsFacade(BaseFacade): + """跨域指标聚合 Facade""" + + async def get_group_statistics(self, group_id: str = None) -> Dict[str, Any]: + """获取群组综合统计数据""" + try: + async with self.get_session() as session: + from sqlalchemy import select, func, and_ + from ....models.orm.message import RawMessage, FilteredMessage + from ....models.orm.learning import PersonaLearningReview, StyleLearningReview + + # 原始消息数 + raw_stmt = select(func.count()).select_from(RawMessage) + if group_id: + raw_stmt = raw_stmt.where(RawMessage.group_id == group_id) + raw_count = (await session.execute(raw_stmt)).scalar() or 0 + + # 筛选消息数 + filtered_stmt = select(func.count()).select_from(FilteredMessage) + if group_id: + filtered_stmt = 
filtered_stmt.where(FilteredMessage.group_id == group_id) + filtered_count = (await session.execute(filtered_stmt)).scalar() or 0 + + # 人格学习审核数 + persona_stmt = select(func.count()).select_from(PersonaLearningReview) + if group_id: + persona_stmt = persona_stmt.where(PersonaLearningReview.group_id == group_id) + persona_count = (await session.execute(persona_stmt)).scalar() or 0 + + # 风格学习审核数 + style_stmt = select(func.count()).select_from(StyleLearningReview) + if group_id: + style_stmt = style_stmt.where(StyleLearningReview.group_id == group_id) + style_count = (await session.execute(style_stmt)).scalar() or 0 + + return { + 'raw_messages': raw_count, + 'filtered_messages': filtered_count, + 'persona_reviews': persona_count, + 'style_reviews': style_count, + 'group_id': group_id, + } + except Exception as e: + self._logger.error(f"[MetricsFacade] 获取群组统计失败: {e}") + return { + 'raw_messages': 0, 'filtered_messages': 0, + 'persona_reviews': 0, 'style_reviews': 0, + 'group_id': group_id, + } + + async def get_detailed_metrics(self, group_id: str = None) -> Dict[str, Any]: + """获取详细指标""" + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.message import RawMessage, FilteredMessage, BotMessage + from ....models.orm.learning import ( + PersonaLearningReview, StyleLearningReview, + LearningBatch, StyleLearningPattern + ) + + async def _count(model, group_col=None): + stmt = select(func.count()).select_from(model) + if group_id and group_col is not None: + stmt = stmt.where(group_col == group_id) + return (await session.execute(stmt)).scalar() or 0 + + raw = await _count(RawMessage, RawMessage.group_id) + filtered = await _count(FilteredMessage, FilteredMessage.group_id) + bot = await _count(BotMessage, BotMessage.group_id) + persona_reviews = await _count(PersonaLearningReview, PersonaLearningReview.group_id) + style_reviews = await _count(StyleLearningReview, StyleLearningReview.group_id) + batches = await 
_count(LearningBatch, LearningBatch.group_id) + patterns = await _count(StyleLearningPattern, StyleLearningPattern.group_id) + + return { + 'messages': { + 'raw': raw, 'filtered': filtered, 'bot': bot, + }, + 'learning': { + 'persona_reviews': persona_reviews, + 'style_reviews': style_reviews, + 'batches': batches, + 'style_patterns': patterns, + }, + 'group_id': group_id, + } + except Exception as e: + self._logger.error(f"[MetricsFacade] 获取详细指标失败: {e}") + return { + 'messages': {'raw': 0, 'filtered': 0, 'bot': 0}, + 'learning': { + 'persona_reviews': 0, 'style_reviews': 0, + 'batches': 0, 'style_patterns': 0, + }, + 'group_id': group_id, + } + + async def get_trends_data(self) -> Dict[str, Any]: + """获取趋势数据""" + try: + async with self.get_session() as session: + from sqlalchemy import select, func + from ....models.orm.message import RawMessage + from ....models.orm.learning import LearningBatch + + # 过去7天每天的消息数 + cutoff = int(time.time()) - (7 * 24 * 3600) + msg_stmt = ( + select(RawMessage) + .where(RawMessage.timestamp >= cutoff) + .order_by(RawMessage.timestamp) + ) + msg_result = await session.execute(msg_stmt) + messages = msg_result.scalars().all() + + daily: Dict[str, int] = {} + for m in messages: + day = time.strftime('%Y-%m-%d', time.localtime(m.timestamp)) + daily[day] = daily.get(day, 0) + 1 + + # 最近的学习批次 + batch_stmt = ( + select(LearningBatch) + .order_by(LearningBatch.start_time.desc()) + .limit(10) + ) + batch_result = await session.execute(batch_stmt) + batches = [self._row_to_dict(b) for b in batch_result.scalars().all()] + + return { + 'daily_messages': daily, + 'recent_batches': batches, + } + except Exception as e: + self._logger.error(f"[MetricsFacade] 获取趋势数据失败: {e}") + return {'daily_messages': {}, 'recent_batches': []} diff --git a/services/database/facades/persona_facade.py b/services/database/facades/persona_facade.py new file mode 100644 index 0000000..478ceac --- /dev/null +++ b/services/database/facades/persona_facade.py @@ -0,0 
+1,112 @@ +""" +人格备份 Facade — 人格配置备份与恢复的业务入口 +""" +import time +import json +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade +from ....repositories.persona_backup_repository import PersonaBackupRepository + + +class PersonaFacade(BaseFacade): + """人格备份管理 Facade""" + + async def backup_persona(self, backup_data: Dict[str, Any]) -> bool: + """创建人格备份""" + try: + async with self.get_session() as session: + from ....models.orm.psychological import PersonaBackup + + backup = PersonaBackup( + backup_name=backup_data.get('backup_name', f'backup_{int(time.time())}'), + timestamp=time.time(), + reason=backup_data.get('reason', ''), + persona_config=json.dumps(backup_data.get('persona_config', {}), ensure_ascii=False), + original_persona=json.dumps(backup_data.get('original_persona', {}), ensure_ascii=False), + imitation_dialogues=json.dumps(backup_data.get('imitation_dialogues', []), ensure_ascii=False), + backup_reason=backup_data.get('backup_reason', ''), + ) + session.add(backup) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[PersonaFacade] 备份人格失败: {e}") + return False + + async def get_persona_backups(self, limit: int = 10) -> List[Dict[str, Any]]: + """获取人格备份列表""" + try: + async with self.get_session() as session: + repo = PersonaBackupRepository(session) + backups = await repo.list_backups(limit=limit) + return [ + { + 'id': b.id, + 'backup_name': b.backup_name, + 'timestamp': b.timestamp, + 'reason': b.reason, + 'persona_config': json.loads(b.persona_config) if b.persona_config else {}, + 'original_persona': json.loads(b.original_persona) if b.original_persona else {}, + 'imitation_dialogues': json.loads(b.imitation_dialogues) if b.imitation_dialogues else [], + 'backup_reason': b.backup_reason, + } + for b in backups + ] + except Exception as e: + self._logger.error(f"[PersonaFacade] 获取备份列表失败: {e}") + return [] + + async def restore_persona_backup(self, backup_id: 
int) -> Optional[Dict[str, Any]]: + """恢复指定备份""" + try: + async with self.get_session() as session: + repo = PersonaBackupRepository(session) + backup = await repo.get_backup(backup_id) + if not backup: + return None + return { + 'id': backup.id, + 'backup_name': backup.backup_name, + 'timestamp': backup.timestamp, + 'persona_config': json.loads(backup.persona_config) if backup.persona_config else {}, + 'original_persona': json.loads(backup.original_persona) if backup.original_persona else {}, + 'imitation_dialogues': json.loads(backup.imitation_dialogues) if backup.imitation_dialogues else [], + } + except Exception as e: + self._logger.error(f"[PersonaFacade] 恢复备份失败: {e}") + return None + + async def get_persona_update_history( + self, group_id: str = None, limit: int = 50 + ) -> List[Dict[str, Any]]: + """获取人格更新历史""" + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.learning import PersonaLearningReview + + stmt = select(PersonaLearningReview).order_by( + desc(PersonaLearningReview.timestamp) + ).limit(limit) + if group_id: + stmt = stmt.where(PersonaLearningReview.group_id == group_id) + + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'timestamp': r.timestamp, + 'group_id': r.group_id, + 'update_type': r.update_type, + 'status': r.status, + 'confidence_score': r.confidence_score, + } + for r in rows + ] + except Exception as e: + self._logger.error(f"[PersonaFacade] 获取更新历史失败: {e}") + return [] diff --git a/services/database/facades/psychological_facade.py b/services/database/facades/psychological_facade.py new file mode 100644 index 0000000..d2e4464 --- /dev/null +++ b/services/database/facades/psychological_facade.py @@ -0,0 +1,75 @@ +""" +心理状态 Facade — 情绪画像与心理分析的业务入口 +""" +import time +import json +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade +from 
....repositories.emotion_profile_repository import EmotionProfileRepository + + +class PsychologicalFacade(BaseFacade): + """心理状态管理 Facade""" + + async def load_emotion_profile( + self, user_id: str, group_id: str + ) -> Optional[Dict[str, Any]]: + """加载情绪画像""" + try: + async with self.get_session() as session: + repo = EmotionProfileRepository(session) + ep = await repo.load(user_id, group_id) + if not ep: + return None + return { + 'user_id': ep.user_id, + 'group_id': ep.group_id, + 'dominant_emotions': json.loads(ep.dominant_emotions) if ep.dominant_emotions else {}, + 'emotion_patterns': json.loads(ep.emotion_patterns) if ep.emotion_patterns else {}, + 'empathy_level': ep.empathy_level, + 'emotional_stability': ep.emotional_stability, + 'last_updated': ep.last_updated, + } + except Exception as e: + self._logger.error(f"[PsychologicalFacade] 加载情绪画像失败: {e}") + return None + + async def save_emotion_profile( + self, user_id: str, group_id: str, profile: Dict[str, Any] + ) -> bool: + """保存情绪画像(upsert)""" + try: + async with self.get_session() as session: + from sqlalchemy import select, and_ + from ....models.orm.psychological import EmotionProfile + + stmt = select(EmotionProfile).where( + and_(EmotionProfile.user_id == user_id, EmotionProfile.group_id == group_id) + ) + result = await session.execute(stmt) + ep = result.scalar_one_or_none() + now = time.time() + if ep: + ep.dominant_emotions = json.dumps(profile.get('dominant_emotions', {}), ensure_ascii=False) + ep.emotion_patterns = json.dumps(profile.get('emotion_patterns', {}), ensure_ascii=False) + ep.empathy_level = profile.get('empathy_level', 0.5) + ep.emotional_stability = profile.get('emotional_stability', 0.5) + ep.last_updated = now + else: + ep = EmotionProfile( + user_id=user_id, group_id=group_id, + dominant_emotions=json.dumps(profile.get('dominant_emotions', {}), ensure_ascii=False), + emotion_patterns=json.dumps(profile.get('emotion_patterns', {}), ensure_ascii=False), + 
empathy_level=profile.get('empathy_level', 0.5), + emotional_stability=profile.get('emotional_stability', 0.5), + last_updated=now, + ) + session.add(ep) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[PsychologicalFacade] 保存情绪画像失败: {e}") + return False diff --git a/services/database/facades/reinforcement_facade.py b/services/database/facades/reinforcement_facade.py new file mode 100644 index 0000000..75a0753 --- /dev/null +++ b/services/database/facades/reinforcement_facade.py @@ -0,0 +1,128 @@ +""" +强化学习 Facade — 强化学习、人格融合、策略优化的业务入口 +""" +import time +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade +from ....repositories.reinforcement_repository import ( + ReinforcementLearningRepository, + PersonaFusionRepository, + StrategyOptimizationRepository, +) + + +class ReinforcementFacade(BaseFacade): + """强化学习与策略优化 Facade""" + + async def get_learning_history_for_reinforcement( + self, group_id: str, limit: int = 50 + ) -> List[Dict[str, Any]]: + """获取用于强化学习的历史数据""" + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.performance import LearningPerformanceHistory + + stmt = ( + select(LearningPerformanceHistory) + .where(LearningPerformanceHistory.group_id == group_id) + .order_by(desc(LearningPerformanceHistory.timestamp)) + .limit(limit) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'timestamp': row.timestamp, + 'quality_score': row.quality_score or 0.0, + 'success': bool(row.success), + 'successful_pattern': row.successful_pattern or '', + 'failed_pattern': row.failed_pattern or '' + } + for row in rows + ] + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 获取强化学习历史失败: {e}") + return [] + + async def save_reinforcement_learning_result( + self, group_id: str, result_data: Dict[str, Any] + ) -> bool: + """保存强化学习结果""" + try: + async with 
self.get_session() as session: + repo = ReinforcementLearningRepository(session) + return await repo.save_reinforcement_result(group_id, result_data) + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 保存强化学习结果失败: {e}") + return False + + async def get_persona_fusion_history( + self, group_id: str, limit: int = 10 + ) -> List[Dict[str, Any]]: + """获取人格融合历史""" + try: + async with self.get_session() as session: + repo = PersonaFusionRepository(session) + return await repo.get_fusion_history(group_id, limit) + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 获取人格融合历史失败: {e}") + return [] + + async def save_persona_fusion_result( + self, group_id: str, fusion_data: Dict[str, Any] + ) -> bool: + """保存人格融合结果""" + try: + async with self.get_session() as session: + repo = PersonaFusionRepository(session) + return await repo.save_fusion_result(group_id, fusion_data) + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 保存人格融合结果失败: {e}") + return False + + async def get_learning_performance_history( + self, group_id: str, limit: int = 30 + ) -> List[Dict[str, Any]]: + """获取学习性能历史""" + try: + async with self.get_session() as session: + from sqlalchemy import select, desc + from ....models.orm.performance import LearningPerformanceHistory + + stmt = ( + select(LearningPerformanceHistory) + .where(LearningPerformanceHistory.group_id == group_id) + .order_by(desc(LearningPerformanceHistory.timestamp)) + .limit(limit) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + 'session_id': row.session_id, + 'timestamp': row.timestamp, + 'quality_score': row.quality_score or 0.0, + 'learning_time': row.learning_time or 0.0, + 'success': bool(row.success) + } + for row in rows + ] + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 获取学习性能历史失败: {e}") + return [] + + async def save_strategy_optimization_result( + self, group_id: str, optimization_data: Dict[str, Any] + ) -> bool: + 
"""保存策略优化结果""" + try: + async with self.get_session() as session: + repo = StrategyOptimizationRepository(session) + return await repo.save_optimization_result(group_id, optimization_data) + except Exception as e: + self._logger.error(f"[ReinforcementFacade] 保存策略优化结果失败: {e}") + return False diff --git a/services/database/facades/social_facade.py b/services/database/facades/social_facade.py new file mode 100644 index 0000000..b63c9f0 --- /dev/null +++ b/services/database/facades/social_facade.py @@ -0,0 +1,243 @@ +""" +社交关系 Facade — 用户画像、偏好、社交关系网络的业务入口 +""" +import time +import json +from typing import Dict, List, Optional, Any + +from astrbot.api import logger + +from ._base import BaseFacade +from ....repositories.user_profile_repository import UserProfileRepository +from ....repositories.user_preferences_repository import UserPreferencesRepository + + +class SocialFacade(BaseFacade): + """社交关系管理 Facade""" + + # ---- 用户画像 ---- + + async def load_user_profile(self, qq_id: str) -> Optional[Dict[str, Any]]: + """加载用户画像""" + try: + async with self.get_session() as session: + from ....models.orm.social_relation import UserProfile + profile = await session.get(UserProfile, qq_id) + if not profile: + return None + return { + 'qq_id': profile.qq_id, + 'qq_name': profile.qq_name, + 'nicknames': json.loads(profile.nicknames) if profile.nicknames else [], + 'activity_pattern': json.loads(profile.activity_pattern) if profile.activity_pattern else {}, + 'communication_style': json.loads(profile.communication_style) if profile.communication_style else {}, + 'topic_preferences': json.loads(profile.topic_preferences) if profile.topic_preferences else {}, + 'emotional_tendency': json.loads(profile.emotional_tendency) if profile.emotional_tendency else {}, + 'last_active': profile.last_active, + } + except Exception as e: + self._logger.error(f"[SocialFacade] 加载用户画像失败: {e}") + return None + + async def save_user_profile(self, qq_id: str, profile_data: Dict[str, Any]) -> bool: + 
"""保存用户画像(upsert)""" + try: + async with self.get_session() as session: + from ....models.orm.social_relation import UserProfile + profile = await session.get(UserProfile, qq_id) + if profile: + profile.qq_name = profile_data.get('qq_name', profile.qq_name) + profile.nicknames = json.dumps(profile_data.get('nicknames', []), ensure_ascii=False) + profile.activity_pattern = json.dumps(profile_data.get('activity_pattern', {}), ensure_ascii=False) + profile.communication_style = json.dumps(profile_data.get('communication_style', {}), ensure_ascii=False) + profile.topic_preferences = json.dumps(profile_data.get('topic_preferences', {}), ensure_ascii=False) + profile.emotional_tendency = json.dumps(profile_data.get('emotional_tendency', {}), ensure_ascii=False) + profile.last_active = profile_data.get('last_active', time.time()) + else: + profile = UserProfile( + qq_id=qq_id, + qq_name=profile_data.get('qq_name', ''), + nicknames=json.dumps(profile_data.get('nicknames', []), ensure_ascii=False), + activity_pattern=json.dumps(profile_data.get('activity_pattern', {}), ensure_ascii=False), + communication_style=json.dumps(profile_data.get('communication_style', {}), ensure_ascii=False), + topic_preferences=json.dumps(profile_data.get('topic_preferences', {}), ensure_ascii=False), + emotional_tendency=json.dumps(profile_data.get('emotional_tendency', {}), ensure_ascii=False), + last_active=profile_data.get('last_active', time.time()), + ) + session.add(profile) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[SocialFacade] 保存用户画像失败: {e}") + return False + + # ---- 用户偏好 ---- + + async def load_user_preferences( + self, user_id: str, group_id: str + ) -> Optional[Dict[str, Any]]: + """加载用户偏好""" + try: + async with self.get_session() as session: + from sqlalchemy import select, and_ + from ....models.orm.social_relation import UserPreferences + stmt = select(UserPreferences).where( + and_(UserPreferences.user_id == user_id, 
UserPreferences.group_id == group_id) + ) + result = await session.execute(stmt) + pref = result.scalar_one_or_none() + if not pref: + return None + return { + 'user_id': pref.user_id, + 'group_id': pref.group_id, + 'favorite_topics': json.loads(pref.favorite_topics) if pref.favorite_topics else [], + 'interaction_style': json.loads(pref.interaction_style) if pref.interaction_style else {}, + 'learning_preferences': json.loads(pref.learning_preferences) if pref.learning_preferences else {}, + 'adaptive_rate': pref.adaptive_rate, + } + except Exception as e: + self._logger.error(f"[SocialFacade] 加载用户偏好失败: {e}") + return None + + async def save_user_preferences( + self, user_id: str, group_id: str, prefs: Dict[str, Any] + ) -> bool: + """保存用户偏好(upsert)""" + try: + async with self.get_session() as session: + from sqlalchemy import select, and_ + from ....models.orm.social_relation import UserPreferences + stmt = select(UserPreferences).where( + and_(UserPreferences.user_id == user_id, UserPreferences.group_id == group_id) + ) + result = await session.execute(stmt) + pref = result.scalar_one_or_none() + now = time.time() + if pref: + pref.favorite_topics = json.dumps(prefs.get('favorite_topics', []), ensure_ascii=False) + pref.interaction_style = json.dumps(prefs.get('interaction_style', {}), ensure_ascii=False) + pref.learning_preferences = json.dumps(prefs.get('learning_preferences', {}), ensure_ascii=False) + pref.adaptive_rate = prefs.get('adaptive_rate', 0.5) + pref.updated_at = now + else: + pref = UserPreferences( + user_id=user_id, group_id=group_id, + favorite_topics=json.dumps(prefs.get('favorite_topics', []), ensure_ascii=False), + interaction_style=json.dumps(prefs.get('interaction_style', {}), ensure_ascii=False), + learning_preferences=json.dumps(prefs.get('learning_preferences', {}), ensure_ascii=False), + adaptive_rate=prefs.get('adaptive_rate', 0.5), + updated_at=now, + ) + session.add(pref) + await session.commit() + return True + except Exception as 
e: + self._logger.error(f"[SocialFacade] 保存用户偏好失败: {e}") + return False + + # ---- 社交关系 ---- + + async def get_social_relations_by_group(self, group_id: str) -> List[Dict[str, Any]]: + """获取群组的社交关系列表 + + 返回格式兼容 SocialService/SocialRelationAnalyzer 期望的 + from_user/to_user 键名。 + """ + try: + async with self.get_session() as session: + from sqlalchemy import select + from ....models.orm.social_relation import UserSocialRelationComponent + + stmt = select(UserSocialRelationComponent).where( + UserSocialRelationComponent.group_id == group_id + ) + result = await session.execute(stmt) + components = result.scalars().all() + return [ + { + 'from_user': c.from_user_id, + 'to_user': c.to_user_id, + 'relation_type': c.relation_type, + 'strength': c.value, + 'frequency': c.frequency, + 'last_interaction': c.last_interaction, + 'description': c.description, + } + for c in components + ] + except Exception as e: + self._logger.error(f"[SocialFacade] 获取社交关系失败: {e}") + return [] + + async def get_social_relationships(self, group_id: str) -> List[Dict[str, Any]]: + """获取社交关系(别名)""" + return await self.get_social_relations_by_group(group_id) + + async def load_social_graph(self, group_id: str) -> List[Dict[str, Any]]: + """加载社交关系图(别名)""" + return await self.get_social_relations_by_group(group_id) + + async def save_social_relation( + self, group_id: str, relation_data: Dict[str, Any] + ) -> bool: + """保存社交关系 + + 接受 SocialRelationAnalyzer 传入的 from_user/to_user 格式, + 映射到 ORM 模型的 from_user_id/to_user_id 列。 + """ + try: + async with self.get_session() as session: + from ....models.orm.social_relation import UserSocialRelationComponent + import time as _time + + now = int(_time.time()) + component = UserSocialRelationComponent( + profile_id=0, # 无关联 profile 时使用占位值 + from_user_id=relation_data.get('from_user', relation_data.get('from_user_id', '')), + to_user_id=relation_data.get('to_user', relation_data.get('to_user_id', '')), + group_id=group_id, + 
relation_type=relation_data.get('relation_type', 'interaction'), + value=relation_data.get('strength', 0.5), + frequency=relation_data.get('frequency', 1), + last_interaction=relation_data.get('last_interaction', now) if isinstance( + relation_data.get('last_interaction'), (int, float) + ) else now, + description=relation_data.get('relation_name', ''), + created_at=now, + ) + session.add(component) + await session.commit() + return True + except Exception as e: + self._logger.error(f"[SocialFacade] 保存社交关系失败: {e}") + return False + + async def get_user_social_relations( + self, group_id: str, user_id: str + ) -> Dict[str, Any]: + """获取用户的社交关系""" + try: + async with self.get_session() as session: + from ....repositories.social_repository import SocialRelationComponentRepository + repo = SocialRelationComponentRepository(session) + from sqlalchemy import select, or_ + from ....models.orm.social_relation import UserSocialRelationComponent + + stmt = select(UserSocialRelationComponent).where( + UserSocialRelationComponent.group_id == group_id, + or_( + UserSocialRelationComponent.from_user_id == user_id, + UserSocialRelationComponent.to_user_id == user_id, + ), + ) + result = await session.execute(stmt) + relations = result.scalars().all() + return { + 'user_id': user_id, + 'group_id': group_id, + 'relations': [self._row_to_dict(r) for r in relations], + } + except Exception as e: + self._logger.error(f"[SocialFacade] 获取用户社交关系失败: {e}") + return {'user_id': user_id, 'group_id': group_id, 'relations': []} diff --git a/services/manager_factory.py b/services/database/manager_factory.py similarity index 56% rename from services/manager_factory.py rename to services/database/manager_factory.py index 21229c7..7c77117 100644 --- a/services/manager_factory.py +++ b/services/database/manager_factory.py @@ -5,9 +5,9 @@ from typing import Optional, Union from astrbot.api import logger -from ..config import PluginConfig -from ..core.interfaces import IDataStorage -from 
..core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...core.interfaces import IDataStorage +from ...core.framework_llm_adapter import FrameworkLLMAdapter class ManagerFactory: @@ -38,20 +38,9 @@ def __init__(self, config: PluginConfig): config: 插件配置 """ self.config = config + logger.info("[ManagerFactory] initialized") - # 检查是否启用增强型管理器 - self.use_enhanced = getattr(config, 'use_enhanced_managers', False) - self.use_sqlalchemy = getattr(config, 'use_sqlalchemy', False) - - logger.info( - f"[管理器工厂] 初始化完成 " - f"(SQLAlchemy={self.use_sqlalchemy}, " - f"增强型管理器={self.use_enhanced})" - ) - - # ============================================================ # 数据库管理器 - # ============================================================ def create_database_manager(self, context=None): """ @@ -61,20 +50,13 @@ def create_database_manager(self, context=None): context: 上下文对象 Returns: - 数据库管理器实例(原始或增强型) + SQLAlchemy 数据库管理器实例 """ - if self.use_sqlalchemy: - from ..services.database_factory import create_database_manager - logger.info("📦 [工厂] 创建 SQLAlchemy 数据库管理器") - return create_database_manager(self.config, context) - else: - from ..services.database_manager import DatabaseManager - logger.info("📦 [工厂] 创建传统数据库管理器") - return DatabaseManager(self.config, context) - - # ============================================================ + from .sqlalchemy_database_manager import SQLAlchemyDatabaseManager + logger.info("[ManagerFactory] Creating SQLAlchemy database manager") + return SQLAlchemyDatabaseManager(self.config, context) + # 好感度管理器 - # ============================================================ def create_affection_manager( self, @@ -89,20 +71,13 @@ def create_affection_manager( llm_adapter: LLM 适配器 Returns: - 好感度管理器实例(原始或增强型) + 好感度管理器实例 """ - if self.use_enhanced: - from ..services.enhanced_affection_manager import EnhancedAffectionManager - logger.info("📦 [工厂] 创建增强型好感度管理器") - return EnhancedAffectionManager(self.config, database_manager, 
llm_adapter) - else: - from ..services.affection_manager import AffectionManager - logger.info("📦 [工厂] 创建传统好感度管理器") - return AffectionManager(self.config, database_manager, llm_adapter) - - # ============================================================ + from ..state import AffectionManager + logger.info("[ManagerFactory] Creating affection manager") + return AffectionManager(self.config, database_manager, llm_adapter) + # 记忆管理器 - # ============================================================ def create_memory_manager( self, @@ -121,28 +96,16 @@ def create_memory_manager( Returns: 记忆管理器实例(原始或增强型) """ - if self.use_enhanced: - from ..services.enhanced_memory_graph_manager import EnhancedMemoryGraphManager - logger.info("📦 [工厂] 创建增强型记忆图管理器") - return EnhancedMemoryGraphManager.get_instance( - self.config, - database_manager, - llm_adapter, - decay_manager - ) - else: - from ..services.memory_graph_manager import MemoryGraphManager - logger.info("📦 [工厂] 创建传统记忆图管理器") - return MemoryGraphManager.get_instance( - self.config, - database_manager, - llm_adapter, - decay_manager - ) - - # ============================================================ + from ..state import EnhancedMemoryGraphManager + logger.info("[ManagerFactory] Creating memory graph manager") + return EnhancedMemoryGraphManager.get_instance( + self.config, + database_manager, + llm_adapter, + decay_manager + ) + # 心理状态管理器 - # ============================================================ def create_psychological_manager( self, @@ -161,28 +124,16 @@ def create_psychological_manager( Returns: 心理状态管理器实例(原始或增强型) """ - if self.use_enhanced: - from ..services.enhanced_psychological_state_manager import EnhancedPsychologicalStateManager - logger.info("📦 [工厂] 创建增强型心理状态管理器") - return EnhancedPsychologicalStateManager( - self.config, - database_manager, - llm_adapter, - affection_manager - ) - else: - from ..services.psychological_state_manager import PsychologicalStateManager - logger.info("📦 [工厂] 创建传统心理状态管理器") - 
return PsychologicalStateManager( - self.config, - database_manager, - llm_adapter, - affection_manager - ) - - # ============================================================ + from ..state import EnhancedPsychologicalStateManager + logger.info("[ManagerFactory] Creating psychological state manager") + return EnhancedPsychologicalStateManager( + self.config, + database_manager, + llm_adapter, + affection_manager + ) + # 社交关系管理器 - # ============================================================ def create_social_relation_manager( self, @@ -203,8 +154,8 @@ def create_social_relation_manager( """ # 注意: 原始的社交关系管理器已经叫 EnhancedSocialRelationManager # 所以这里不需要区分 - from ..services.enhanced_social_relation_manager import EnhancedSocialRelationManager - logger.info("📦 [工厂] 创建社交关系管理器") + from ..social import EnhancedSocialRelationManager + logger.info(" [工厂] 创建社交关系管理器") return EnhancedSocialRelationManager( self.config, database_manager, @@ -212,9 +163,7 @@ def create_social_relation_manager( psychological_manager ) - # ============================================================ # 其他管理器(可根据需要扩展) - # ============================================================ def create_diversity_manager( self, @@ -222,8 +171,8 @@ def create_diversity_manager( llm_adapter: Optional[FrameworkLLMAdapter] = None ): """创建响应多样性管理器""" - from ..services.response_diversity_manager import ResponseDiversityManager - logger.info("📦 [工厂] 创建响应多样性管理器") + from ..response import ResponseDiversityManager + logger.info(" [工厂] 创建响应多样性管理器") return ResponseDiversityManager(self.config, database_manager, llm_adapter) def create_time_decay_manager( @@ -231,13 +180,11 @@ def create_time_decay_manager( database_manager: IDataStorage ): """创建时间衰减管理器""" - from ..services.time_decay_manager import TimeDecayManager - logger.info("📦 [工厂] 创建时间衰减管理器") + from ..state import TimeDecayManager + logger.info(" [工厂] 创建时间衰减管理器") return TimeDecayManager(self.config, database_manager) - # 
============================================================ # 批量创建 - # ============================================================ def create_all_managers(self, context=None) -> dict: """ @@ -250,7 +197,7 @@ def create_all_managers(self, context=None) -> dict: dict: 包含所有管理器的字典 """ logger.info("=" * 70) - logger.info("🏭 [管理器工厂] 开始创建所有管理器...") + logger.info(" [管理器工厂] 开始创建所有管理器...") logger.info("=" * 70) managers = {} @@ -259,7 +206,7 @@ def create_all_managers(self, context=None) -> dict: managers['database'] = self.create_database_manager(context) # 2. LLM 适配器(从主插件获取) - managers['llm_adapter'] = None # 需要外部传入 + managers['llm_adapter'] = None # 需要外部传入 # 3. 时间衰减管理器 managers['time_decay'] = self.create_time_decay_manager(managers['database']) @@ -298,14 +245,12 @@ def create_all_managers(self, context=None) -> dict: ) logger.info("=" * 70) - logger.info(f"✅ [管理器工厂] 成功创建 {len(managers)} 个管理器") + logger.info(f" [管理器工厂] 成功创建 {len(managers)} 个管理器") logger.info("=" * 70) return managers - # ============================================================ # 工具方法 - # ============================================================ def get_configuration_info(self) -> dict: """ @@ -315,8 +260,6 @@ def get_configuration_info(self) -> dict: dict: 配置信息 """ return { - 'use_sqlalchemy': self.use_sqlalchemy, - 'use_enhanced_managers': self.use_enhanced, 'enable_affection_system': self.config.enable_affection_system, 'enable_memory_graph': self.config.enable_memory_graph, 'enable_maibot_features': self.config.enable_maibot_features, @@ -327,19 +270,17 @@ def print_configuration(self): info = self.get_configuration_info() logger.info("=" * 70) - logger.info("📋 [管理器工厂] 当前配置:") + logger.info(" [管理器工厂] 当前配置:") logger.info("=" * 70) for key, value in info.items(): - status = "✅ 启用" if value else "❌ 禁用" - logger.info(f" {key}: {status}") + status = " 启用" if value else " 禁用" + logger.info(f" {key}: {status}") logger.info("=" * 70) -# ============================================================ # 
全局工厂实例 -# ============================================================ _global_factory = None diff --git a/services/database/sqlalchemy_database_manager.py b/services/database/sqlalchemy_database_manager.py new file mode 100644 index 0000000..9fa6b3b --- /dev/null +++ b/services/database/sqlalchemy_database_manager.py @@ -0,0 +1,736 @@ +""" +DomainRouter — 薄路由层,将所有数据库方法委托给领域 Facade + +前身为 4308 行的单体 SQLAlchemyDatabaseManager,现已拆分为 +11 个领域 Facade,本文件仅保留生命周期管理、会话/连接基础设施 +以及方法路由。 +""" +import os +import asyncio +from typing import Dict, List, Optional, Any +from contextlib import asynccontextmanager + +from astrbot.api import logger + +from ...config import PluginConfig +from ...core.database.engine import DatabaseEngine + + +class SQLAlchemyDatabaseManager: + """DomainRouter — 薄路由层,委托给 11 个领域 Facade。 + + 对外接口(方法签名、返回类型)与旧版完全一致,消费者无需任何改动。 + """ + + # Lifecycle + + def __init__(self, config: PluginConfig, context=None): + self.config = config + self.context = context + self.engine: Optional[DatabaseEngine] = None + self._started = False + self._starting = False + self._start_lock = asyncio.Lock() + + # Facades(在 start() 中初始化) + self._affection = None + self._message = None + self._learning = None + self._jargon = None + self._persona = None + self._social = None + self._expression = None + self._psychological = None + self._reinforcement = None + self._metrics = None + self._admin = None + + async def start(self) -> bool: + """启动数据库管理器(带并发保护)""" + async with self._start_lock: + if self._started: + logger.debug("[DomainRouter] 已启动,跳过") + return True + + if self._starting: + logger.warning("[DomainRouter] 正在启动中,等待…") + for _ in range(50): + await asyncio.sleep(0.1) + if self._started: + return True + logger.error("[DomainRouter] 启动超时") + return False + + try: + self._starting = True + logger.info("[DomainRouter] 开始启动…") + + db_url = self._get_database_url() + + if hasattr(self.config, 'db_type') and self.config.db_type.lower() == 'mysql': + await 
self._ensure_mysql_database_exists() + + self.engine = DatabaseEngine(db_url, echo=False) + logger.info("[DomainRouter] 数据库引擎已创建") + + await self.engine.create_tables() + + if await self.engine.health_check(): + self._init_facades() + self._started = True + self._starting = False + logger.info("[DomainRouter] 数据库启动成功") + return True + + self._started = False + self._starting = False + logger.error("[DomainRouter] 数据库健康检查失败") + return False + + except Exception as e: + self._started = False + self._starting = False + logger.error(f"[DomainRouter] 启动失败: {e}", exc_info=True) + return False + + async def stop(self) -> bool: + """停止数据库管理器""" + if not self._started: + return True + try: + if self.engine: + await self.engine.close() + self._started = False + logger.info("[DomainRouter] 数据库已停止") + return True + except Exception as e: + logger.error(f"[DomainRouter] 停止失败: {e}") + return False + + # Facade initialization + + def _init_facades(self): + """初始化所有领域 Facade""" + from .facades import ( + AffectionFacade, MessageFacade, LearningFacade, + JargonFacade, PersonaFacade, SocialFacade, + ExpressionFacade, PsychologicalFacade, ReinforcementFacade, + MetricsFacade, AdminFacade, + ) + self._affection = AffectionFacade(self.engine, self.config) + self._message = MessageFacade(self.engine, self.config) + self._learning = LearningFacade(self.engine, self.config) + self._jargon = JargonFacade(self.engine, self.config) + self._persona = PersonaFacade(self.engine, self.config) + self._social = SocialFacade(self.engine, self.config) + self._expression = ExpressionFacade(self.engine, self.config) + self._psychological = PsychologicalFacade(self.engine, self.config) + self._reinforcement = ReinforcementFacade(self.engine, self.config) + self._metrics = MetricsFacade(self.engine, self.config) + self._admin = AdminFacade(self.engine, self.config) + logger.info("[DomainRouter] 11 个领域 Facade 已初始化") + + # Infrastructure: database URL + + def _get_database_url(self) -> str: + """获取数据库连接 
URL""" + if hasattr(self.config, 'db_type') and self.config.db_type.lower() == 'mysql': + host = getattr(self.config, 'mysql_host', 'localhost') + port = getattr(self.config, 'mysql_port', 3306) + user = getattr(self.config, 'mysql_user', 'root') + password = getattr(self.config, 'mysql_password', '') + database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') + return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}" + + db_path = getattr(self.config, 'messages_db_path', None) + if not db_path: + db_path = os.path.join(self.config.data_dir, 'messages.db') + if not os.path.isabs(db_path): + db_path = os.path.abspath(db_path) + return f"sqlite:///{db_path}" + + async def _ensure_mysql_database_exists(self): + """确保 MySQL 数据库存在""" + try: + import aiomysql + host = getattr(self.config, 'mysql_host', 'localhost') + port = getattr(self.config, 'mysql_port', 3306) + user = getattr(self.config, 'mysql_user', 'root') + password = getattr(self.config, 'mysql_password', '') + database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') + + conn = await aiomysql.connect( + host=host, port=port, user=user, + password=password, charset='utf8mb4', + ) + try: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s", + (database,), + ) + if not await cursor.fetchone(): + logger.info(f"[DomainRouter] 数据库 {database} 不存在,正在创建…") + await cursor.execute( + f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" + ) + await conn.commit() + logger.info(f"[DomainRouter] 数据库 {database} 创建成功") + finally: + conn.close() + except Exception as e: + logger.error(f"[DomainRouter] 确保 MySQL 数据库存在失败: {e}") + raise + + # Infrastructure: session + + @asynccontextmanager + async def get_session(self): + """获取 ORM 会话(async context manager)""" + if not self.engine: + if self._starting: + logger.debug("[DomainRouter] 等待 engine 创建…") + for _ in range(30): 
+ await asyncio.sleep(0.1) + if self.engine: + break + if not self.engine: + raise RuntimeError("数据库管理器启动超时,engine未创建") + else: + raise RuntimeError("数据库管理器未启动,engine不存在") + + if not self._started: + logger.debug("[DomainRouter] get_session: _started=False 但 engine 存在,继续执行") + + session = self.engine.get_session() + try: + async with session: + yield session + finally: + await session.close() + + # Domain delegates: AffectionFacade + + async def get_user_affection(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: + return await self._affection.get_user_affection(group_id, user_id) + + async def update_user_affection( + self, group_id: str, user_id: str, new_level: int, + change_reason: str = "", bot_mood: str = "", + ) -> bool: + return await self._affection.update_user_affection( + group_id, user_id, new_level, change_reason, bot_mood, + ) + + async def get_all_user_affections(self, group_id: str) -> List[Dict[str, Any]]: + return await self._affection.get_all_user_affections(group_id) + + async def get_total_affection(self, group_id: str) -> int: + return await self._affection.get_total_affection(group_id) + + async def save_bot_mood( + self, group_id: str, mood_type: str, mood_intensity: float, + mood_description: str, duration_hours: int = 24, + ) -> bool: + return await self._affection.save_bot_mood( + group_id, mood_type, mood_intensity, mood_description, duration_hours, + ) + + async def get_current_bot_mood(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self._affection.get_current_bot_mood(group_id) + + # Domain delegates: MessageFacade + + async def save_raw_message(self, message_data) -> int: + return await self._message.save_raw_message(message_data) + + async def get_recent_raw_messages( + self, group_id: str, limit: int = 200, + ) -> List[Dict[str, Any]]: + return await self._message.get_recent_raw_messages(group_id, limit) + + async def get_unprocessed_messages( + self, limit: Optional[int] = None, + ) -> List[Dict[str, 
Any]]: + return await self._message.get_unprocessed_messages(limit) + + async def mark_messages_processed(self, message_ids: List[int]) -> bool: + return await self._message.mark_messages_processed(message_ids) + + async def get_messages_by_timerange( + self, group_id: str, start_time: int, end_time: int, limit: int = 500, + ) -> List[Dict[str, Any]]: + return await self._message.get_messages_by_timerange( + group_id, start_time, end_time, limit, + ) + + async def get_messages_by_group_and_timerange( + self, group_id: str, start_time: int, end_time: int, limit: int = 500, + ) -> List[Dict[str, Any]]: + return await self._message.get_messages_by_group_and_timerange( + group_id, start_time, end_time, limit, + ) + + async def get_messages_for_replay( + self, group_id: str, days: int = 30, limit: int = 100, + ) -> List[Dict[str, Any]]: + return await self._message.get_messages_for_replay(group_id, days, limit) + + async def get_recent_filtered_messages( + self, group_id: str, limit: int = 20, + ) -> List[Dict[str, Any]]: + return await self._message.get_recent_filtered_messages(group_id, limit) + + async def get_filtered_messages_for_learning( + self, limit: int = 20, + ) -> List[Dict[str, Any]]: + return await self._message.get_filtered_messages_for_learning(limit) + + async def add_filtered_message(self, filtered_data: Dict[str, Any]) -> int: + return await self._message.add_filtered_message(filtered_data) + + async def save_bot_message( + self, group_id: str, message: str, timestamp: int = None, + ) -> bool: + return await self._message.save_bot_message(group_id, message, timestamp) + + async def get_recent_bot_responses( + self, group_id: str, limit: int = 10, + ) -> List[str]: + return await self._message.get_recent_bot_responses(group_id, limit) + + async def get_message_statistics( + self, group_id: str = None, + ) -> Dict[str, Any]: + return await self._message.get_message_statistics(group_id) + + async def get_messages_statistics(self) -> Dict[str, Any]: + 
return await self._message.get_messages_statistics() + + async def get_group_messages_statistics(self, group_id: str) -> Dict[str, Any]: + return await self._message.get_group_messages_statistics(group_id) + + async def get_group_user_statistics( + self, group_id: str, + ) -> Dict[str, Dict[str, Any]]: + return await self._message.get_group_user_statistics(group_id) + + async def get_groups_for_social_analysis(self) -> List[Dict[str, Any]]: + return await self._message.get_groups_for_social_analysis() + + # Domain delegates: LearningFacade + + async def add_persona_learning_review( + self, + review_data: Dict[str, Any] = None, + *, + group_id: str = None, + proposed_content: str = None, + learning_source: str = '', + confidence_score: float = 0.5, + raw_analysis: str = '', + metadata: Dict[str, Any] = None, + original_content: str = '', + new_content: str = '', + ) -> int: + """兼容新旧两种调用方式:单 dict 或关键字参数。""" + if review_data is None: + review_data = { + 'group_id': group_id or '', + 'proposed_content': proposed_content or '', + 'update_type': learning_source, + 'confidence_score': confidence_score, + 'reason': raw_analysis, + 'metadata': metadata or {}, + 'original_content': original_content, + 'new_content': new_content, + } + return await self._learning.add_persona_learning_review(review_data) + + async def get_pending_persona_update_records(self) -> List[Dict[str, Any]]: + return await self._learning.get_pending_persona_update_records() + + async def save_persona_update_record(self, record_data: Dict[str, Any]) -> int: + return await self._learning.save_persona_update_record(record_data) + + async def delete_persona_update_record(self, record_id: int) -> bool: + return await self._learning.delete_persona_update_record(record_id) + + async def get_persona_update_record_by_id( + self, record_id: int, + ) -> Optional[Dict[str, Any]]: + return await self._learning.get_persona_update_record_by_id(record_id) + + async def get_reviewed_persona_update_records( + self, 
limit: int = 50, offset: int = 0, status_filter: str = None, + ) -> List[Dict[str, Any]]: + return await self._learning.get_reviewed_persona_update_records( + limit=limit, offset=offset, status_filter=status_filter, + ) + + async def update_persona_update_record_status( + self, record_id: int, status: str, comment: str = None, + ) -> bool: + return await self._learning.update_persona_update_record_status( + record_id, status, comment, + ) + + async def create_style_learning_review( + self, review_data: Dict[str, Any], + ) -> int: + return await self._learning.create_style_learning_review(review_data) + + async def get_pending_style_reviews( + self, limit: int = 50, offset: int = 0, + ) -> List[Dict[str, Any]]: + return await self._learning.get_pending_style_reviews(limit, offset) + + async def get_reviewed_style_learning_updates( + self, limit: int = 50, offset: int = 0, status_filter: str = None, + ) -> List[Dict[str, Any]]: + return await self._learning.get_reviewed_style_learning_updates( + limit=limit, offset=offset, status_filter=status_filter, + ) + + async def update_style_review_status( + self, review_id: int, status: str, reviewer_comment: str = '', + ) -> bool: + return await self._learning.update_style_review_status( + review_id, status, reviewer_comment, + ) + + async def delete_style_review_by_id(self, review_id: int) -> bool: + return await self._learning.delete_style_review_by_id(review_id) + + async def get_approved_few_shots( + self, group_id: str, limit: int = 3, + ) -> List[str]: + return await self._learning.get_approved_few_shots(group_id, limit) + + async def get_pending_persona_learning_reviews( + self, limit: int = 50, offset: int = 0, + ) -> List[Dict[str, Any]]: + return await self._learning.get_pending_persona_learning_reviews(limit, offset) + + async def get_reviewed_persona_learning_updates( + self, limit: int = 50, offset: int = 0, status_filter: str = None, + ) -> List[Dict[str, Any]]: + return await 
self._learning.get_reviewed_persona_learning_updates( + limit=limit, offset=offset, status_filter=status_filter, + ) + + async def delete_persona_learning_review_by_id(self, review_id: int) -> bool: + return await self._learning.delete_persona_learning_review_by_id(review_id) + + async def get_persona_learning_review_by_id( + self, review_id: int, + ) -> Optional[Dict[str, Any]]: + return await self._learning.get_persona_learning_review_by_id(review_id) + + async def update_persona_learning_review_status( + self, review_id: int, status: str, comment: str = None, + modified_content: str = None, + ) -> bool: + return await self._learning.update_persona_learning_review_status( + review_id, status, comment, modified_content, + ) + + async def get_learning_batch_history( + self, group_id: str = None, limit: int = 50, + ) -> List[Dict[str, Any]]: + return await self._learning.get_learning_batch_history(group_id, limit) + + async def get_recent_learning_batches( + self, limit: int = 10, + ) -> List[Dict[str, Any]]: + return await self._learning.get_recent_learning_batches(limit) + + async def get_learning_sessions(self, group_id: str) -> List[Dict[str, Any]]: + return await self._learning.get_learning_sessions(group_id) + + async def get_recent_learning_sessions( + self, days: int = 7, + ) -> List[Dict[str, Any]]: + return await self._learning.get_recent_learning_sessions(days) + + async def save_learning_session_record( + self, group_id: str, session_data: Dict[str, Any], + ) -> bool: + return await self._learning.save_learning_session_record(group_id, session_data) + + async def save_learning_performance_record( + self, group_id: str, performance_data: Dict[str, Any], + ) -> bool: + return await self._learning.save_learning_performance_record( + group_id, performance_data, + ) + + async def count_pending_persona_updates(self) -> int: + return await self._learning.count_pending_persona_updates() + + async def count_style_learning_patterns(self) -> int: + return await 
self._learning.count_style_learning_patterns() + + async def count_refined_messages(self) -> int: + return await self._learning.count_refined_messages() + + async def get_style_learning_statistics(self) -> Dict[str, Any]: + return await self._learning.get_style_learning_statistics() + + async def get_style_progress_data( + self, group_id: str = None, + ) -> List[Dict[str, Any]]: + return await self._learning.get_style_progress_data(group_id) + + async def get_learning_patterns_data( + self, group_id: str = None, + ) -> Dict[str, Any]: + return await self._learning.get_learning_patterns_data(group_id) + + # Domain delegates: JargonFacade + + async def get_jargon(self, chat_id: str, content: str) -> Optional[Dict[str, Any]]: + return await self._jargon.get_jargon(chat_id, content) + + async def insert_jargon(self, jargon_data: Dict[str, Any]) -> Optional[int]: + return await self._jargon.insert_jargon(jargon_data) + + async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool: + return await self._jargon.update_jargon(jargon_data) + + async def get_jargon_statistics(self, group_id: str = None) -> Dict[str, Any]: + return await self._jargon.get_jargon_statistics(group_id) + + async def get_recent_jargon_list( + self, group_id: str = None, chat_id: str = None, + limit: int = 50, offset: int = 0, only_confirmed: bool = False, + ) -> List[Dict[str, Any]]: + return await self._jargon.get_recent_jargon_list( + group_id, chat_id, limit, offset, only_confirmed, + ) + + async def get_jargon_count( + self, chat_id: str = None, only_confirmed: bool = False, + ) -> int: + return await self._jargon.get_jargon_count(chat_id, only_confirmed) + + async def search_jargon( + self, keyword: str, chat_id: str = None, + confirmed_only: bool = False, limit: int = 50, + ) -> List[Dict[str, Any]]: + return await self._jargon.search_jargon( + keyword=keyword, chat_id=chat_id, + confirmed_only=confirmed_only, limit=limit, + ) + + async def get_jargon_by_id(self, jargon_id: int) -> 
Optional[Dict[str, Any]]: + return await self._jargon.get_jargon_by_id(jargon_id) + + async def delete_jargon_by_id(self, jargon_id: int) -> bool: + return await self._jargon.delete_jargon_by_id(jargon_id) + + async def set_jargon_global(self, jargon_id: int, is_global: bool) -> bool: + return await self._jargon.set_jargon_global(jargon_id, is_global) + + async def sync_global_jargon_to_group(self, target_chat_id: str) -> int: + return await self._jargon.sync_global_jargon_to_group(target_chat_id) + + async def save_or_update_jargon( + self, chat_id: str, content: str, jargon_data: Dict[str, Any], + ) -> Optional[int]: + return await self._jargon.save_or_update_jargon( + chat_id, content, jargon_data, + ) + + async def get_global_jargon_list( + self, limit: int = 100, + ) -> List[Dict[str, Any]]: + return await self._jargon.get_global_jargon_list(limit) + + async def get_jargon_groups(self) -> List[Dict[str, Any]]: + return await self._jargon.get_jargon_groups() + + # Domain delegates: PersonaFacade + + async def backup_persona(self, group_id: str, backup_data: Dict[str, Any]) -> bool: + backup_data.setdefault('group_id', group_id) + return await self._persona.backup_persona(backup_data) + + async def get_persona_backups(self, limit: int = 10) -> List[Dict[str, Any]]: + return await self._persona.get_persona_backups(limit) + + async def restore_persona_backup( + self, backup_id: int, + ) -> Optional[Dict[str, Any]]: + return await self._persona.restore_persona_backup(backup_id) + + async def get_persona_update_history( + self, group_id: str = None, limit: int = 50, + ) -> List[Dict[str, Any]]: + return await self._persona.get_persona_update_history(group_id, limit) + + # Domain delegates: SocialFacade + + async def load_user_profile(self, qq_id: str) -> Optional[Dict[str, Any]]: + return await self._social.load_user_profile(qq_id) + + async def save_user_profile( + self, qq_id: str, profile_data: Dict[str, Any], + ) -> bool: + return await 
self._social.save_user_profile(qq_id, profile_data) + + async def load_user_preferences( + self, user_id: str, group_id: str, + ) -> Optional[Dict[str, Any]]: + return await self._social.load_user_preferences(user_id, group_id) + + async def save_user_preferences( + self, user_id: str, group_id: str, prefs: Dict[str, Any], + ) -> bool: + return await self._social.save_user_preferences(user_id, group_id, prefs) + + async def get_social_relations_by_group( + self, group_id: str, + ) -> List[Dict[str, Any]]: + return await self._social.get_social_relations_by_group(group_id) + + async def get_social_relationships( + self, group_id: str, + ) -> List[Dict[str, Any]]: + return await self._social.get_social_relationships(group_id) + + async def load_social_graph(self, group_id: str) -> List[Dict[str, Any]]: + return await self._social.load_social_graph(group_id) + + async def save_social_relation( + self, group_id: str, relation_data: Dict[str, Any], + ) -> bool: + return await self._social.save_social_relation(group_id, relation_data) + + async def get_user_social_relations( + self, group_id: str, user_id: str, + ) -> Dict[str, Any]: + return await self._social.get_user_social_relations(group_id, user_id) + + # Domain delegates: ExpressionFacade + + async def get_all_expression_patterns(self) -> Dict[str, List[Dict[str, Any]]]: + return await self._expression.get_all_expression_patterns() + + async def get_expression_patterns_statistics(self) -> Dict[str, Any]: + return await self._expression.get_expression_patterns_statistics() + + async def get_group_expression_patterns( + self, group_id: str, limit: int = None, + ) -> List[Dict[str, Any]]: + return await self._expression.get_group_expression_patterns(group_id, limit) + + async def get_recent_week_expression_patterns( + self, group_id: str = None, limit: int = 50, + ) -> List[Dict[str, Any]]: + return await self._expression.get_recent_week_expression_patterns( + group_id, limit, + ) + + async def load_style_profile( + 
self, profile_name: str, + ) -> Optional[Dict[str, Any]]: + return await self._expression.load_style_profile(profile_name) + + async def save_style_profile( + self, profile_name: str, profile_data: Dict[str, Any], + ) -> bool: + return await self._expression.save_style_profile(profile_name, profile_data) + + async def save_style_learning_record( + self, record_data: Dict[str, Any], + ) -> bool: + return await self._expression.save_style_learning_record(record_data) + + async def save_language_style_pattern( + self, language_style: str, pattern_data: Dict[str, Any], + ) -> bool: + return await self._expression.save_language_style_pattern( + language_style, pattern_data, + ) + + # Domain delegates: PsychologicalFacade + + async def load_emotion_profile( + self, user_id: str, group_id: str, + ) -> Optional[Dict[str, Any]]: + return await self._psychological.load_emotion_profile(user_id, group_id) + + async def save_emotion_profile( + self, user_id: str, group_id: str, profile: Dict[str, Any], + ) -> bool: + return await self._psychological.save_emotion_profile( + user_id, group_id, profile, + ) + + # Domain delegates: ReinforcementFacade + + async def get_learning_history_for_reinforcement( + self, group_id: str, limit: int = 50, + ) -> List[Dict[str, Any]]: + return await self._reinforcement.get_learning_history_for_reinforcement( + group_id, limit, + ) + + async def save_reinforcement_learning_result( + self, group_id: str, result_data: Dict[str, Any], + ) -> bool: + return await self._reinforcement.save_reinforcement_learning_result( + group_id, result_data, + ) + + async def get_persona_fusion_history( + self, group_id: str, limit: int = 10, + ) -> List[Dict[str, Any]]: + return await self._reinforcement.get_persona_fusion_history(group_id, limit) + + async def save_persona_fusion_result( + self, group_id: str, fusion_data: Dict[str, Any], + ) -> bool: + return await self._reinforcement.save_persona_fusion_result( + group_id, fusion_data, + ) + + async def 
get_learning_performance_history( + self, group_id: str, limit: int = 30, + ) -> List[Dict[str, Any]]: + return await self._reinforcement.get_learning_performance_history( + group_id, limit, + ) + + async def save_strategy_optimization_result( + self, group_id: str, optimization_data: Dict[str, Any], + ) -> bool: + return await self._reinforcement.save_strategy_optimization_result( + group_id, optimization_data, + ) + + # Domain delegates: MetricsFacade + + async def get_group_statistics( + self, group_id: str = None, + ) -> Dict[str, Any]: + return await self._metrics.get_group_statistics(group_id) + + async def get_detailed_metrics( + self, group_id: str = None, + ) -> Dict[str, Any]: + return await self._metrics.get_detailed_metrics(group_id) + + async def get_trends_data(self) -> Dict[str, Any]: + return await self._metrics.get_trends_data() + + # Domain delegates: AdminFacade + + async def clear_all_messages_data(self) -> bool: + return await self._admin.clear_all_messages_data() + + async def export_messages_learning_data( + self, group_id: str = None, + ) -> Dict[str, Any]: + return await self._admin.export_messages_learning_data(group_id) diff --git a/services/database_factory.py b/services/database_factory.py deleted file mode 100644 index d525eb3..0000000 --- a/services/database_factory.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -数据库管理器工厂 -默认使用 SQLAlchemy ORM 数据库管理器(支持自动迁移) -""" -from astrbot.api import logger - -from ..config import PluginConfig -from .sqlalchemy_database_manager import SQLAlchemyDatabaseManager - - -def create_database_manager( - config: PluginConfig, - context=None -) -> SQLAlchemyDatabaseManager: - """ - 创建数据库管理器 - - 默认使用 SQLAlchemy 版本(带自动数据库迁移功能) - - Args: - config: 插件配置 - context: 上下文(可选) - - Returns: - SQLAlchemy 数据库管理器实例 - """ - logger.info("📦 [数据库] 使用 SQLAlchemy 版本的数据库管理器(支持自动迁移)") - return SQLAlchemyDatabaseManager(config, context) - - -__all__ = [ - 'SQLAlchemyDatabaseManager', - 'create_database_manager', -] diff --git 
a/services/database_manager.py b/services/database_manager.py deleted file mode 100644 index b97b813..0000000 --- a/services/database_manager.py +++ /dev/null @@ -1,8132 +0,0 @@ -""" -数据库管理器 - 管理分群数据库和数据持久化 即将弃用 -""" -import os -import json -import aiosqlite -import time -import asyncio -from typing import Dict, List, Optional, Any, Callable -from datetime import datetime - -from astrbot.api import logger - -from ..config import PluginConfig -from ..constants import UPDATE_TYPE_EXPRESSION_LEARNING -from ..exceptions import DataStorageError - -from ..core.patterns import AsyncServiceBase - -# 导入数据库后端 -from ..core.database import ( - DatabaseFactory, - DatabaseConfig, - DatabaseType, - IDatabaseBackend -) - -# ✨ 导入ORM支持 -from ..core.database.engine import DatabaseEngine -from ..repositories.reinforcement_repository import ( - ReinforcementLearningRepository, - PersonaFusionRepository, - StrategyOptimizationRepository -) -from ..repositories.learning_repository import ( - LearningBatchRepository, - LearningSessionRepository, - StyleLearningReviewRepository, - PersonaLearningReviewRepository -) -from ..repositories.message_repository import ( - ConversationContextRepository, - ConversationTopicClusteringRepository, - ConversationQualityMetricsRepository, - ContextSimilarityCacheRepository -) -from ..repositories.jargon_repository import ( - JargonRepository -) - - -class DatabaseConnectionPool: - """数据库连接池""" - - def __init__(self, db_path: str, max_connections: int = 10, min_connections: int = 2): - self.db_path = db_path - self.max_connections = max_connections - self.min_connections = min_connections - self.pool: asyncio.Queue = asyncio.Queue(maxsize=max_connections) - self.active_connections = 0 - self.total_connections = 0 - self._lock = asyncio.Lock() - self._logger = logger - - async def initialize(self): - """初始化连接池""" - async with self._lock: - # 创建最小数量的连接 - for _ in range(self.min_connections): - conn = await self._create_connection() - await 
self.pool.put(conn) - - async def _create_connection(self) -> aiosqlite.Connection: - """创建新的数据库连接""" - # 确保目录存在 - db_dir = os.path.dirname(self.db_path) - os.makedirs(db_dir, exist_ok=True) - - # 检查数据库文件权限 - if os.path.exists(self.db_path): - try: - import stat - os.chmod(self.db_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP) - except OSError as e: - self._logger.warning(f"无法修改数据库文件权限: {e}") - - conn = await aiosqlite.connect(self.db_path) - - # 设置连接参数 - await conn.execute('PRAGMA foreign_keys = ON') - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA synchronous = NORMAL') - await conn.execute('PRAGMA cache_size = 10000') - await conn.execute('PRAGMA temp_store = memory') - await conn.commit() - - self.total_connections += 1 - self._logger.debug(f"创建新数据库连接,总连接数: {self.total_connections}") - return conn - - async def get_connection(self) -> aiosqlite.Connection: - """获取数据库连接""" - try: - # 尝试从池中获取连接(非阻塞) - conn = self.pool.get_nowait() - self.active_connections += 1 - return conn - except asyncio.QueueEmpty: - # 池中无可用连接 - async with self._lock: - if self.total_connections < self.max_connections: - # 可以创建新连接 - conn = await self._create_connection() - self.active_connections += 1 - return conn - else: - # 达到最大连接数,等待连接归还 - self._logger.debug("连接池已满,等待连接归还...") - conn = await self.pool.get() - self.active_connections += 1 - return conn - - async def return_connection(self, conn: aiosqlite.Connection): - """归还数据库连接""" - if conn: - try: - # 检查连接是否仍然有效 - await conn.execute('SELECT 1') - await self.pool.put(conn) - self.active_connections -= 1 - except Exception as e: - # 连接已损坏,关闭并减少计数 - self._logger.warning(f"连接已损坏,关闭连接: {e}") - try: - await conn.close() - except: - pass - self.total_connections -= 1 - self.active_connections -= 1 - - async def close_all(self): - """关闭所有连接""" - self._logger.info("开始关闭数据库连接池...") - - # 关闭池中的所有连接 - while not self.pool.empty(): - try: - conn = self.pool.get_nowait() - await conn.close() - 
self.total_connections -= 1 - except asyncio.QueueEmpty: - break - except Exception as e: - self._logger.error(f"关闭连接时出错: {e}") - - self._logger.info(f"数据库连接池已关闭,剩余连接数: {self.total_connections}") - - async def __aenter__(self): - """异步上下文管理器入口""" - return await self.get_connection() - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器退出""" - # 注意:这里不能直接归还连接,因为我们不知道连接对象 - # 实际使用时需要在调用方手动归还 - pass - - -class DatabaseManager(AsyncServiceBase): - """数据库管理器 - 使用连接池管理数据库连接,支持SQLite和MySQL""" - - def __init__(self, config: PluginConfig, context=None, skip_table_init: bool = False): - super().__init__("database_manager") - self.config = config - self.context = context - self.group_db_connections: Dict[str, aiosqlite.Connection] = {} - self.skip_table_init = skip_table_init # ✨ 新增:跳过表初始化标志 - - # 安全地构建路径 - if not config.data_dir: - raise ValueError("config.data_dir 不能为空") - - self.group_data_dir = os.path.join(config.data_dir, "group_databases") - self.messages_db_path = config.messages_db_path - - # 新增: 数据库后端(支持SQLite和MySQL) - self.db_backend: Optional[IDatabaseBackend] = None - - # ✨ 新增: DatabaseEngine for ORM支持 - self.db_engine: Optional[DatabaseEngine] = None - - # 初始化连接池(保留旧的SQLite连接池,用于group数据库) - self.connection_pool = DatabaseConnectionPool( - db_path=self.messages_db_path, - max_connections=config.max_connections, - min_connections=config.min_connections - ) - - # 确保数据目录存在 - os.makedirs(self.group_data_dir, exist_ok=True) - - self._logger.info(f"数据库管理器初始化完成 (类型: {config.db_type}, 跳过表初始化: {skip_table_init})") - - async def _do_start(self) -> bool: - """启动服务时初始化连接池和数据库""" - try: - self._logger.info(f"🚀 [DatabaseManager] 开始启动 (db_type={self.config.db_type}, skip_table_init={self.skip_table_init})") - - # 1. 创建数据库后端(无论 skip_table_init 是否为 True 都需要初始化后端) - # skip_table_init 只影响表的创建,不影响后端连接的初始化 - self._logger.info(f"📡 [DatabaseManager] 正在初始化 {self.config.db_type} 数据库后端...") - backend_success = await self._initialize_database_backend() - - # 2. 
如果数据库后端初始化失败,直接报错,不回退 - if not backend_success or not self.db_backend: - error_msg = f"❌ {self.config.db_type} 数据库后端初始化失败" - self._logger.error(error_msg) - raise RuntimeError(error_msg) - - self._logger.info(f"✅ [DatabaseManager] {self.config.db_type} 后端初始化成功") - - # 3. 初始化旧的连接池(仅用于group数据库,暂时保留) - await self.connection_pool.initialize() - self._logger.info("✅ [DatabaseManager] 数据库连接池初始化成功") - - # 4. 初始化数据库表结构(如果表不存在则自动创建) - # 如果 skip_table_init=True(由 ORM 管理表),则跳过表创建 - if not self.skip_table_init: - await self._init_messages_database() - self._logger.info("✅ [DatabaseManager] 全局消息数据库初始化成功") - else: - self._logger.info("⏭️ [DatabaseManager] 跳过传统数据库表创建(由 SQLAlchemy ORM 管理)") - - self._logger.info(f"🎉 [DatabaseManager] 数据库管理器启动完成 (使用后端: {self.config.db_type})") - return True - except Exception as e: - self._logger.error(f"❌ [DatabaseManager] 启动数据库管理器失败: {e}", exc_info=True) - return False - - async def _initialize_database_backend(self) -> bool: - """初始化数据库后端""" - try: - # 构建数据库配置 - db_type = DatabaseType(self.config.db_type.lower()) - - if db_type == DatabaseType.SQLITE: - db_config = DatabaseConfig( - db_type=DatabaseType.SQLITE, - sqlite_path=self.messages_db_path, - max_connections=self.config.max_connections, - min_connections=self.config.min_connections - ) - elif db_type == DatabaseType.MYSQL: - db_config = DatabaseConfig( - db_type=DatabaseType.MYSQL, - mysql_host=self.config.mysql_host, - mysql_port=self.config.mysql_port, - mysql_user=self.config.mysql_user, - mysql_password=self.config.mysql_password, - mysql_database=self.config.mysql_database, - max_connections=self.config.max_connections, - min_connections=self.config.min_connections - ) - elif db_type == DatabaseType.POSTGRESQL: - db_config = DatabaseConfig( - db_type=DatabaseType.POSTGRESQL, - postgresql_host=self.config.postgresql_host, - postgresql_port=self.config.postgresql_port, - postgresql_user=self.config.postgresql_user, - postgresql_password=self.config.postgresql_password, - 
postgresql_database=self.config.postgresql_database, - postgresql_schema=self.config.postgresql_schema, - max_connections=self.config.max_connections, - min_connections=self.config.min_connections - ) - else: - raise ValueError(f"不支持的数据库类型: {self.config.db_type}") - - # 使用工厂创建后端 - self.db_backend = DatabaseFactory.create_backend(db_config) - if not self.db_backend: - raise Exception("创建数据库后端失败") - - # 初始化后端 - success = await self.db_backend.initialize() - if not success: - raise Exception("数据库后端初始化失败") - - self._logger.info(f"数据库后端初始化成功: {self.config.db_type}") - return True - - except Exception as e: - self._logger.error(f"初始化数据库后端失败: {e}", exc_info=True) - return False - - async def _do_stop(self) -> bool: - """停止服务时关闭所有数据库连接""" - try: - # 关闭数据库后端 - if self.db_backend: - await self.db_backend.close() - - # 关闭旧的连接池 - await self.close_all_connections() - await self.connection_pool.close_all() - - self._logger.info("所有数据库连接已关闭") - return True - except Exception as e: - self._logger.error(f"关闭数据库管理器失败: {e}", exc_info=True) - return False - - def get_db_connection(self): - """ - 获取数据库连接的上下文管理器 - 根据配置的数据库类型,自动选择SQLite、MySQL或PostgreSQL后端 - """ - db_type = self.config.db_type.lower() - - # 🔍 调试日志:输出数据库类型和后端状态 - self._logger.debug(f"[get_db_connection] 配置的数据库类型: {db_type}") - self._logger.debug(f"[get_db_connection] db_backend 状态: {self.db_backend is not None}") - - # 如果使用MySQL或PostgreSQL且db_backend可用,使用通用后端连接管理器 - if db_type in ('mysql', 'postgresql') and self.db_backend: - self._logger.debug(f"[get_db_connection] ✅ 使用 {db_type.upper()} 后端") - return self._get_backend_connection_manager() - else: - # 使用旧的SQLite连接池 - self._logger.warning(f"[get_db_connection] ⚠️ 回退到 SQLite 连接池 (db_type={db_type}, backend_exists={self.db_backend is not None})") - return self._get_sqlite_connection_manager() - - def _get_sqlite_connection_manager(self): - """获取SQLite连接管理器""" - class SQLiteConnectionManager: - def __init__(self, pool: DatabaseConnectionPool): - self.pool = pool - 
self.connection = None - - async def __aenter__(self): - self.connection = await self.pool.get_connection() - return self.connection - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.connection: - await self.pool.return_connection(self.connection) - - return SQLiteConnectionManager(self.connection_pool) - - def _get_backend_connection_manager(self): - """获取MySQL/PostgreSQL连接管理器 - 适配aiosqlite接口""" - db_backend = self.db_backend - - class BackendConnectionAdapter: - """数据库后端连接适配器 - 模拟aiosqlite接口""" - def __init__(self, backend): - self.backend = backend - self._cursor = None - - async def cursor(self): - """返回游标适配器""" - return BackendCursorAdapter(self.backend) - - async def commit(self): - """提交事务 - 后端在execute中已自动提交""" - pass - - async def rollback(self): - """回滚事务""" - await self.backend.rollback() - - async def execute(self, sql, params=None): - """执行SQL""" - return await self.backend.execute(sql, params) - - async def executemany(self, sql, params_list): - """批量执行SQL""" - return await self.backend.execute_many(sql, params_list) - - async def fetchone(self): - """获取单行""" - return await self._cursor.fetchone() if self._cursor else None - - async def fetchall(self): - """获取所有行""" - return await self._cursor.fetchall() if self._cursor else [] - - class BackendCursorAdapter: - """数据库后端游标适配器""" - def __init__(self, backend): - self.backend = backend - self._last_result = None - self.lastrowid = None - self.rowcount = 0 - - async def execute(self, sql, params=None): - """执行SQL并存储结果""" - import re - - # 检测是SELECT查询还是其他操作 - sql_upper = sql.strip().upper() - - # 获取数据库类型 - db_type = self.backend.db_type - is_mysql = (db_type == DatabaseType.MYSQL) - is_postgresql = (db_type == DatabaseType.POSTGRESQL) - - # 对于 CREATE TABLE 和 ALTER TABLE,需要特殊处理 - if sql_upper.startswith('CREATE TABLE') or sql_upper.startswith('ALTER TABLE'): - # 使用后端的 convert_ddl 进行转换 - converted_sql = self.backend.convert_ddl(sql) - await self.backend.execute(converted_sql, None) - 
self._last_result = [] - self.rowcount = 0 - return self - - # 转换参数占位符 - if is_mysql: - # ✅ MySQL: 转换 INSERT OR REPLACE 为 REPLACE INTO - converted_sql = sql.replace('INSERT OR REPLACE', 'REPLACE') - # 转换参数占位符 ? -> %s - converted_sql = converted_sql.replace('?', '%s') - elif is_postgresql: - # PostgreSQL 使用 $1, $2, ... - # 调用后端的占位符转换方法 - converted_sql = self.backend._convert_placeholders(sql) if hasattr(self.backend, '_convert_placeholders') else sql - else: - converted_sql = sql - - # 确保 params 是 tuple 类型 - if params is not None and not isinstance(params, tuple): - if isinstance(params, list): - params = tuple(params) - else: - params = (params,) - - # 处理 sqlite_master 查询 - if 'SQLITE_MASTER' in sql_upper: - table_match = re.search(r"NAME\s*=\s*['\"]?(\w+)['\"]?", sql_upper) - if table_match: - table_name = table_match.group(1).lower() - if is_mysql: - check_sql = """ - SELECT TABLE_NAME as name - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = DATABASE() AND LOWER(TABLE_NAME) = %s - """ - self._last_result = await self.backend.fetch_all(check_sql, (table_name,)) - elif is_postgresql: - check_sql = """ - SELECT table_name as name - FROM information_schema.tables - WHERE table_schema = $1 AND LOWER(table_name) = $2 - """ - schema = getattr(self.backend.config, 'postgresql_schema', 'public') - self._last_result = await self.backend.fetch_all(check_sql, (schema, table_name)) - self.rowcount = len(self._last_result) if self._last_result else 0 - return self - else: - self._last_result = [] - self.rowcount = 0 - return self - - # 处理 PRAGMA table_info 查询 - if sql_upper.startswith('PRAGMA'): - pragma_match = re.search(r'PRAGMA\s+TABLE_INFO\s*\(\s*(\w+)\s*\)', sql_upper) - if pragma_match: - table_name = pragma_match.group(1) - try: - if is_mysql: - describe_sql = f"DESCRIBE {table_name}" - mysql_result = await self.backend.fetch_all(describe_sql, None) - self._last_result = [] - for idx, row in enumerate(mysql_result or []): - field_name = row[0] - field_type = 
row[1] - is_nullable = 0 if row[2] == 'NO' else 1 - default_value = row[4] - is_pk = 1 if row[3] == 'PRI' else 0 - self._last_result.append((idx, field_name, field_type, 1 - is_nullable, default_value, is_pk)) - elif is_postgresql: - # PostgreSQL 使用 information_schema.columns - schema = getattr(self.backend.config, 'postgresql_schema', 'public') - pg_sql = """ - SELECT - ordinal_position - 1 as cid, - column_name as name, - data_type as type, - CASE WHEN is_nullable = 'NO' THEN 1 ELSE 0 END as notnull, - column_default as dflt_value, - 0 as pk - FROM information_schema.columns - WHERE table_schema = $1 AND table_name = $2 - ORDER BY ordinal_position - """ - self._last_result = await self.backend.fetch_all(pg_sql, (schema, table_name)) - self.rowcount = len(self._last_result) - except Exception: - self._last_result = [] - self.rowcount = 0 - return self - else: - self._last_result = [] - self.rowcount = 0 - return self - - if sql_upper.startswith('SELECT'): - self._last_result = await self.backend.fetch_all(converted_sql, params) - self.rowcount = len(self._last_result) if self._last_result else 0 - else: - # INSERT/UPDATE/DELETE - self.rowcount = await self.backend.execute(converted_sql, params) - # 尝试获取lastrowid(对于INSERT操作) - if sql_upper.startswith('INSERT'): - try: - if is_mysql: - result = await self.backend.fetch_one("SELECT LAST_INSERT_ID()") - elif is_postgresql: - result = await self.backend.fetch_one("SELECT lastval()") - else: - result = None - self.lastrowid = result[0] if result else None - except Exception: - self.lastrowid = None - return self - - async def executemany(self, sql, params_list): - """批量执行SQL""" - db_type = self.backend.db_type - if db_type == DatabaseType.MYSQL: - converted_sql = sql.replace('?', '%s') - elif db_type == DatabaseType.POSTGRESQL: - converted_sql = self.backend._convert_placeholders(sql) if hasattr(self.backend, '_convert_placeholders') else sql - else: - converted_sql = sql - self.rowcount = await 
self.backend.execute_many(converted_sql, params_list) - return self - - async def fetchone(self): - """获取单行结果""" - if self._last_result and len(self._last_result) > 0: - return self._last_result[0] - return None - - async def fetchall(self): - """获取所有结果""" - return self._last_result if self._last_result else [] - - def __aiter__(self): - """支持异步迭代""" - self._iter_index = 0 - return self - - async def __anext__(self): - """异步迭代""" - if not self._last_result or self._iter_index >= len(self._last_result): - raise StopAsyncIteration - result = self._last_result[self._iter_index] - self._iter_index += 1 - return result - - async def close(self): - """关闭游标(后端使用连接池,无需实际关闭)""" - self._last_result = None - self.lastrowid = None - self.rowcount = 0 - - class BackendConnectionManager: - def __init__(self, backend): - self.backend = backend - self.adapter = None - - async def __aenter__(self): - self.adapter = BackendConnectionAdapter(self.backend) - return self.adapter - - async def __aexit__(self, exc_type, exc_val, exc_tb): - # 后端使用连接池,无需手动关闭 - pass - - return BackendConnectionManager(db_backend) - - def get_connection(self): - """ - 获取数据库连接的同步接口,用于兼容旧代码 - 注意:这是一个同步方法,用于兼容使用 'with' 语句的代码 - """ - class SyncConnectionWrapper: - def __init__(self, db_manager): - self.db_manager = db_manager - self.connection = None - - def __enter__(self): - # 同步获取连接,这需要在异步上下文中使用 - import sqlite3 - # 直接创建同步连接到同一个数据库文件 - self.connection = sqlite3.connect(self.db_manager.messages_db_path) - return self.connection - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.connection: - self.connection.close() - - return SyncConnectionWrapper(self) - - async def close_all_connections(self): - """关闭所有数据库连接""" - try: - # 关闭所有群组数据库连接 - for group_id, conn in list(self.group_db_connections.items()): - try: - await conn.close() - self._logger.info(f"群组 {group_id} 数据库连接已关闭") - except Exception as e: - self._logger.error(f"关闭群组 {group_id} 数据库连接失败: {e}") - - self.group_db_connections.clear() - 
self._logger.info("所有群组数据库连接已关闭") - - except Exception as e: - self._logger.error(f"关闭数据库连接过程中发生错误: {e}") - raise - - async def _retry_on_connection_error(self, func, *args, **kwargs): - """在连接错误时重试的通用方法(保留兼容性)""" - try: - return await func(*args, **kwargs) - except Exception as e: - if "no active connection" in str(e).lower(): - self._logger.warning(f"检测到连接问题: {e},尝试重新执行...") - try: - # 连接池会自动处理连接问题,直接重试 - return await func(*args, **kwargs) - except Exception as retry_error: - self._logger.error(f"重试也失败: {retry_error}") - raise retry_error - else: - raise e - - async def _init_messages_database(self): - """ - 初始化全局消息数据库(根据数据库类型选择后端) - - ⚠️ 已废弃:所有表结构由 SQLAlchemy ORM 统一管理 - 此方法保留仅用于向后兼容,不再创建表 - """ - self._logger.info("⏭️ [传统数据库管理器] 表创建已由 SQLAlchemy ORM 接管,跳过传统表初始化") - # 如果使用MySQL后端,使用db_backend初始化表 - # if self.db_backend and self.config.db_type.lower() == 'mysql': - # await self._init_messages_database_mysql() - # self._logger.info("MySQL数据库表初始化完成。") - # else: - # # 使用旧的SQLite连接池 - # async with self.get_db_connection() as conn: - # await self._init_messages_database_tables(conn) - # self._logger.info("全局消息数据库连接池初始化完成并表已初始化。") - - async def _init_messages_database_mysql(self): - """ - 使用MySQL后端初始化数据库表 - - ⚠️ 已废弃:所有表结构由 SQLAlchemy ORM 统一管理 - 此方法保留仅用于参考,不再使用 - """ - self._logger.warning("⚠️ [传统数据库管理器] _init_messages_database_mysql 已废弃,请使用 SQLAlchemy ORM") - return - - # 以下代码已禁用,保留仅供参考 - """ - try: - # 创建原始消息表 - self._logger.info("尝试创建 raw_messages 表 (MySQL)...") - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS raw_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - sender_id VARCHAR(255) NOT NULL, - sender_name VARCHAR(255), - message TEXT NOT NULL, - group_id VARCHAR(255), - platform VARCHAR(50), - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - processed TINYINT(1) DEFAULT 0, - INDEX idx_timestamp (timestamp), - INDEX idx_sender (sender_id), - INDEX idx_processed (processed), - INDEX idx_group (group_id) - ) ENGINE=InnoDB 
DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - self._logger.info("raw_messages 表创建/检查完成。") - - # 创建Bot消息表 - self._logger.info("尝试创建 bot_messages 表 (MySQL)...") - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS bot_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255), - message TEXT NOT NULL, - response_to_message_id INT, - context_type VARCHAR(100), - temperature DOUBLE, - language_style VARCHAR(100), - response_pattern VARCHAR(255), - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - self._logger.info("bot_messages 表创建/检查完成。") - - # 创建筛选后消息表 - self._logger.info("尝试创建 filtered_messages 表 (MySQL)...") - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS filtered_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - raw_message_id INT, - message TEXT NOT NULL, - sender_id VARCHAR(255), - group_id VARCHAR(255), - confidence DOUBLE, - filter_reason TEXT, - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - used_for_learning TINYINT(1) DEFAULT 0, - quality_scores TEXT, - refined TINYINT(1) DEFAULT 0, - INDEX idx_confidence (confidence), - INDEX idx_used (used_for_learning), - INDEX idx_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - self._logger.info("filtered_messages 表创建/检查完成。") - - # 创建学习批次表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS learning_batches ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - start_time DOUBLE NOT NULL, - end_time DOUBLE, - quality_score DOUBLE DEFAULT 0.5, - processed_messages INT DEFAULT 0, - batch_name VARCHAR(255) UNIQUE, - message_count INT, - filtered_count INT, - success TINYINT(1), - error_message TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX 
idx_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建人格更新记录表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_records ( - id INT PRIMARY KEY AUTO_INCREMENT, - timestamp DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(100) NOT NULL, - original_content TEXT, - new_content TEXT NOT NULL, - reason TEXT, - status VARCHAR(50) DEFAULT 'pending', - reviewer_comment TEXT, - review_time DOUBLE, - INDEX idx_status (status), - INDEX idx_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建强化学习结果表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS reinforcement_learning_results ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - replay_analysis TEXT, - optimization_strategy TEXT, - reinforcement_feedback TEXT, - next_action TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建策略优化结果表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS strategy_optimization_results ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - exploration_type VARCHAR(100), - effectiveness_score DOUBLE, - new_strategy TEXT, - rollback_reason TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建学习性能历史表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS learning_performance_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - session_id VARCHAR(255), - timestamp DOUBLE NOT NULL, - quality_score DOUBLE, - learning_time DOUBLE, - success TINYINT(1), - successful_pattern TEXT, - failed_pattern TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX 
idx_group (group_id), - INDEX idx_session (session_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建LLM调用统计表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS llm_call_statistics ( - id INT PRIMARY KEY AUTO_INCREMENT, - call_type VARCHAR(100) NOT NULL, - provider VARCHAR(100), - model VARCHAR(100), - input_tokens INT DEFAULT 0, - output_tokens INT DEFAULT 0, - total_tokens INT DEFAULT 0, - latency_ms DOUBLE, - success TINYINT(1) DEFAULT 1, - error_message TEXT, - group_id VARCHAR(255), - timestamp DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_call_type (call_type), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建黑话表(与 SQLite 版本结构一致) - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS jargon ( - id INT PRIMARY KEY AUTO_INCREMENT, - content TEXT NOT NULL, - raw_content TEXT, - meaning TEXT, - is_jargon TINYINT(1), - count INT DEFAULT 1, - last_inference_count INT DEFAULT 0, - is_complete TINYINT(1) DEFAULT 0, - is_global TINYINT(1) DEFAULT 0, - chat_id VARCHAR(255) NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY uk_chat_content (chat_id, content(255)), - INDEX idx_content (content(255)), - INDEX idx_chat_id (chat_id), - INDEX idx_is_jargon (is_jargon) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建社交关系表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS social_relations ( - id INT PRIMARY KEY AUTO_INCREMENT, - user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - relation_type VARCHAR(100), - affection_score DOUBLE DEFAULT 0, - interaction_count INT DEFAULT 0, - last_interaction DOUBLE, - metadata TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON 
UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY uk_user_group (user_id, group_id), - INDEX idx_group (group_id), - INDEX idx_affection (affection_score) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建表达模式表(与 expression_pattern_learner.py 中的 SQLite 结构一致) - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS expression_patterns ( - id INT PRIMARY KEY AUTO_INCREMENT, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight DOUBLE NOT NULL DEFAULT 1.0, - last_active_time DOUBLE NOT NULL, - create_time DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY uk_situation_expression_group (situation(255), expression(255), group_id), - INDEX idx_group (group_id), - INDEX idx_weight (weight) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建语言风格模式表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS language_style_patterns ( - id INT PRIMARY KEY AUTO_INCREMENT, - style_name VARCHAR(100) NOT NULL, - style_description TEXT, - examples TEXT, - frequency INT DEFAULT 1, - source_group VARCHAR(255), - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - INDEX idx_style (style_name) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建话题摘要表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS topic_summaries ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - topic VARCHAR(255) NOT NULL, - summary TEXT, - message_count INT DEFAULT 0, - start_time DOUBLE, - end_time DOUBLE, - keywords TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id), - INDEX idx_topic (topic) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建风格学习记录表 - await self.db_backend.execute(''' - 
CREATE TABLE IF NOT EXISTS style_learning_records ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - style_type VARCHAR(100), - learned_content TEXT, - confidence DOUBLE DEFAULT 0.5, - source_messages TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group (group_id), - INDEX idx_style (style_type) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建风格学习审核表(与 SQLite 版本的 _ensure_style_review_table_exists 结构一致) - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS style_learning_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - type VARCHAR(100) NOT NULL, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - status VARCHAR(50) DEFAULT 'pending', - description TEXT, - reviewer_comment TEXT, - review_time DOUBLE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - INDEX idx_status (status), - INDEX idx_group (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建人格融合历史表 - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS persona_fusion_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - base_persona_hash BIGINT, - incremental_hash BIGINT, - fusion_result TEXT, - compatibility_score DOUBLE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_group_id (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # 创建人格更新审核表(与 SQLite 版本结构一致) - await self.db_backend.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - timestamp DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(100) NOT NULL, - original_content TEXT, - new_content TEXT, - 
proposed_content TEXT, - confidence_score DOUBLE, - reason TEXT, - status VARCHAR(50) NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time DOUBLE, - metadata TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_status (status), - INDEX idx_group (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - self._logger.info("所有MySQL表创建完成") - - except Exception as e: - self._logger.error(f"MySQL表初始化失败: {e}", exc_info=True) - raise - """ - - async def _init_messages_database_tables(self, conn: aiosqlite.Connection): - """ - 初始化全局消息SQLite数据库的表结构 - - ⚠️ 已废弃:所有表结构由 SQLAlchemy ORM 统一管理 - 此方法保留仅用于向后兼容,不再创建表 - """ - self._logger.warning("⚠️ [传统数据库管理器] _init_messages_database_tables 已废弃,请使用 SQLAlchemy ORM") - return - - # 以下代码已禁用,保留仅供参考 - """ - cursor = await conn.cursor() - - try: - # 设置数据库为WAL模式,提高并发性能并避免锁定问题 - await cursor.execute('PRAGMA journal_mode=WAL') - await cursor.execute('PRAGMA synchronous=NORMAL') - await cursor.execute('PRAGMA cache_size=10000') - await cursor.execute('PRAGMA temp_store=memory') - - # 创建原始消息表 - self._logger.info("尝试创建 raw_messages 表...") - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS raw_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - sender_id TEXT NOT NULL, - sender_name TEXT, - message TEXT NOT NULL, - group_id TEXT, - platform TEXT, - timestamp REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - processed BOOLEAN DEFAULT FALSE - ) - ''') - self._logger.info("raw_messages 表创建/检查完成。") - await conn.commit() # 强制提交,确保表结构写入磁盘 - - # 创建Bot消息表 (用于存储Bot发送的消息,供多样性管理器使用) - self._logger.info("尝试创建 bot_messages 表...") - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS bot_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - user_id TEXT, - message TEXT NOT NULL, - response_to_message_id INTEGER, - context_type TEXT, - temperature REAL, - language_style TEXT, - response_pattern TEXT, - timestamp REAL NOT 
NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (response_to_message_id) REFERENCES raw_messages (id) - ) - ''') - self._logger.info("bot_messages 表创建/检查完成。") - await conn.commit() - - # 创建筛选后消息表 - self._logger.info("尝试创建 filtered_messages 表...") - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS filtered_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - raw_message_id INTEGER, - message TEXT NOT NULL, - sender_id TEXT, - group_id TEXT, - confidence REAL, - filter_reason TEXT, - timestamp REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - used_for_learning BOOLEAN DEFAULT FALSE, - quality_scores TEXT, -- 新增字段,存储JSON字符串 - FOREIGN KEY (raw_message_id) REFERENCES raw_messages (id) - ) - ''') - self._logger.info("filtered_messages 表创建/检查完成。") - - # 检查并添加 quality_scores 列(如果不存在) - await cursor.execute("PRAGMA table_info(filtered_messages)") - columns = [col[1] for col in await cursor.fetchall()] - if 'quality_scores' not in columns: - await cursor.execute("ALTER TABLE filtered_messages ADD COLUMN quality_scores TEXT") - await conn.commit() # 立即提交,确保列添加成功 - logger.info("已为 filtered_messages 表添加 quality_scores 列。") - - # 检查并添加 group_id 列(如果不存在) - # 重新获取列信息,因为前面可能添加了新列 - await cursor.execute("PRAGMA table_info(filtered_messages)") - columns = [col[1] for col in await cursor.fetchall()] - if 'group_id' not in columns: - await cursor.execute("ALTER TABLE filtered_messages ADD COLUMN group_id TEXT") - await conn.commit() - logger.info("已为 filtered_messages 表添加 group_id 列。") - - # 检查并添加 refined 列(如果不存在) - await cursor.execute("PRAGMA table_info(filtered_messages)") - columns = [col[1] for col in await cursor.fetchall()] - if 'refined' not in columns: - await cursor.execute("ALTER TABLE filtered_messages ADD COLUMN refined BOOLEAN DEFAULT 0") - await conn.commit() - logger.info("已为 filtered_messages 表添加 refined 列。") - - # 检查并添加 used_for_learning 列(如果不存在) - await cursor.execute("PRAGMA table_info(filtered_messages)") - columns = 
[col[1] for col in await cursor.fetchall()] - if 'used_for_learning' not in columns: - await cursor.execute("ALTER TABLE filtered_messages ADD COLUMN used_for_learning BOOLEAN DEFAULT 0") - await conn.commit() - logger.info("已为 filtered_messages 表添加 used_for_learning 列。") - - # 创建学习批次表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS learning_batches ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - start_time REAL NOT NULL, - end_time REAL, - quality_score REAL DEFAULT 0.5, - processed_messages INTEGER DEFAULT 0, - batch_name TEXT UNIQUE, - message_count INTEGER, - filtered_count INTEGER, - success BOOLEAN, - error_message TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 创建人格更新记录表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp REAL NOT NULL, - group_id TEXT NOT NULL, - update_type TEXT NOT NULL, - original_content TEXT, - new_content TEXT NOT NULL, - reason TEXT, - status TEXT DEFAULT 'pending', - reviewer_comment TEXT, - review_time REAL - ) - ''') - - # 创建索引(带错误处理,避免列不存在导致失败) - indices = [ - ('idx_raw_messages_timestamp', 'raw_messages', 'timestamp'), - ('idx_raw_messages_sender', 'raw_messages', 'sender_id'), - ('idx_raw_messages_processed', 'raw_messages', 'processed'), - ('idx_filtered_messages_confidence', 'filtered_messages', 'confidence'), - ('idx_filtered_messages_used', 'filtered_messages', 'used_for_learning'), - ('idx_persona_update_records_status', 'persona_update_records', 'status'), - ('idx_persona_update_records_group_id', 'persona_update_records', 'group_id'), - ] - - for index_name, table_name, column_name in indices: - try: - await cursor.execute(f'CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}({column_name})') - except Exception as e: - logger.debug(f"创建索引 {index_name} 失败(可能列不存在): {e}") - - # 新增强化学习相关表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS reinforcement_learning_results ( - id INTEGER 
PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - timestamp REAL NOT NULL, - replay_analysis TEXT, - optimization_strategy TEXT, - reinforcement_feedback TEXT, - next_action TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_fusion_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - timestamp REAL NOT NULL, - base_persona_hash INTEGER, - incremental_hash INTEGER, - fusion_result TEXT, - compatibility_score REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS strategy_optimization_results ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - timestamp REAL NOT NULL, - original_strategy TEXT, - optimization_result TEXT, - expected_improvement TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS learning_performance_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - session_id TEXT, - timestamp REAL NOT NULL, - quality_score REAL, - learning_time REAL, - success BOOLEAN, - successful_pattern TEXT, - failed_pattern TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 为强化学习表创建索引 - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_reinforcement_learning_group ON reinforcement_learning_results(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_persona_fusion_group ON persona_fusion_history(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_strategy_optimization_group ON strategy_optimization_results(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_learning_performance_group ON learning_performance_history(group_id)') - - # 创建LLM调用统计表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS llm_call_statistics ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - provider_type TEXT NOT NULL, -- filter, refine, reinforce - 
model_name TEXT, - total_calls INTEGER DEFAULT 0, - success_calls INTEGER DEFAULT 0, - failed_calls INTEGER DEFAULT 0, - total_response_time_ms INTEGER DEFAULT 0, - avg_response_time_ms REAL DEFAULT 0, - success_rate REAL DEFAULT 0, - last_call_time REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(provider_type, model_name) - ) - ''') - - # 风格学习记录表 (从群组数据库移至消息数据库) - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS style_learning_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - style_type TEXT NOT NULL, - learned_patterns TEXT, -- JSON格式存储学习到的模式 - confidence_score REAL, - sample_count INTEGER, - learning_time REAL NOT NULL, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 语言风格模式表 (从群组数据库移至消息数据库) - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS language_style_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - language_style TEXT NOT NULL, - example_phrases TEXT, -- JSON格式存储示例短语 - usage_frequency INTEGER DEFAULT 0, - context_type TEXT DEFAULT 'general', - confidence_score REAL, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 为新表创建索引 - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_style_learning_group ON style_learning_records(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_style_learning_time ON style_learning_records(learning_time)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_language_style_group ON language_style_patterns(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_language_style_frequency ON language_style_patterns(usage_frequency)') - - # 创建话题总结表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS topic_summaries ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - topic TEXT NOT NULL, - summary TEXT, - participants TEXT, -- JSON格式存储参与者列表 - message_count INTEGER DEFAULT 0, - 
start_timestamp REAL, - end_timestamp REAL, - generated_at REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 为话题总结表创建索引 - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_topic_summaries_group ON topic_summaries(group_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_topic_summaries_time ON topic_summaries(generated_at)') - - # 创建黑话学习表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS jargon ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - content TEXT NOT NULL, - raw_content TEXT DEFAULT '[]', - meaning TEXT, - is_jargon BOOLEAN, - count INTEGER DEFAULT 1, - last_inference_count INTEGER DEFAULT 0, - is_complete BOOLEAN DEFAULT 0, - is_global BOOLEAN DEFAULT 0, - chat_id TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE(chat_id, content) - ) - ''') - - # 为黑话表创建索引 - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_jargon_content ON jargon(content)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_jargon_chat_id ON jargon(chat_id)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_jargon_is_jargon ON jargon(is_jargon)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_jargon_count ON jargon(count)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_jargon_updated_at ON jargon(updated_at)') - - await conn.commit() - logger.info("全局消息数据库初始化完成") - - except aiosqlite.Error as e: - logger.error(f"全局消息数据库初始化失败: {e}", exc_info=True) - # 尝试删除可能损坏的数据库文件,以便下次启动时重新创建 - if os.path.exists(self.messages_db_path): - self._logger.warning(f"数据库初始化失败,尝试删除损坏的数据库文件: {self.messages_db_path}") - try: - os.remove(self.messages_db_path) - except OSError as ose: - self._logger.error(f"删除数据库文件失败: {ose}") - raise DataStorageError(f"全局消息数据库初始化失败: {str(e)}") - """ - - def get_group_db_path(self, group_id: str) -> str: - """获取群数据库文件路径""" - if not group_id: - raise ValueError("group_id 不能为空") - if not self.group_data_dir: - raise 
ValueError("group_data_dir 未初始化") - return os.path.join(self.group_data_dir, f"{group_id}_ID.db") - - async def get_group_connection(self, group_id: str) -> aiosqlite.Connection: - """获取群数据库连接""" - if group_id not in self.group_db_connections: - db_path = self.get_group_db_path(group_id) - - # 确保数据库目录存在 - db_dir = os.path.dirname(db_path) - os.makedirs(db_dir, exist_ok=True) - - # 检查数据库文件权限 - if os.path.exists(db_path): - try: - # 尝试修改文件权限为可写 - import stat - os.chmod(db_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP) - except OSError as e: - logger.warning(f"无法修改群数据库文件权限: {e}") - - conn = await aiosqlite.connect(db_path) - - # 设置连接参数,确保数据库可写 - await conn.execute('PRAGMA foreign_keys = ON') - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA synchronous = NORMAL') - await conn.commit() - - await self._init_group_database(conn) - self.group_db_connections[group_id] = conn - logger.info(f"已创建群 {group_id} 的数据库连接") - - return self.group_db_connections[group_id] - - async def _init_group_database(self, conn: aiosqlite.Connection): - """初始化群数据库表结构""" - cursor = await conn.cursor() - - try: - # 设置数据库为WAL模式,提高并发性能并避免锁定问题 - await cursor.execute('PRAGMA journal_mode=WAL') - await cursor.execute('PRAGMA synchronous=NORMAL') - await cursor.execute('PRAGMA cache_size=10000') - await cursor.execute('PRAGMA temp_store=memory') - - # 原始消息表 (群数据库中不再存储原始消息,由全局消息数据库统一管理) - # 筛选消息表 (群数据库中不再存储筛选消息,由全局消息数据库统一管理) - - # 用户画像表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS user_profiles ( - qq_id TEXT PRIMARY KEY, - qq_name TEXT, - nicknames TEXT, -- JSON格式存储 - activity_pattern TEXT, -- JSON格式存储活动模式 - communication_style TEXT, -- JSON格式存储沟通风格 - topic_preferences TEXT, -- JSON格式存储话题偏好 - emotional_tendency TEXT, -- JSON格式存储情感倾向 - last_active REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 社交关系表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS social_relations ( - 
id INTEGER PRIMARY KEY AUTOINCREMENT, - from_user TEXT NOT NULL, - to_user TEXT NOT NULL, - relation_type TEXT NOT NULL, -- mention, reply, frequent_interaction - strength REAL NOT NULL, - frequency INTEGER NOT NULL, - last_interaction REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(from_user, to_user, relation_type) - ) - ''') - - # 风格档案表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS style_profiles ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - profile_name TEXT NOT NULL, - vocabulary_richness REAL, - sentence_complexity REAL, - emotional_expression REAL, - interaction_tendency REAL, - topic_diversity REAL, - formality_level REAL, - creativity_score REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 人格备份表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_backups ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - backup_name TEXT NOT NULL, - timestamp REAL NOT NULL, - reason TEXT, - persona_config TEXT, -- JSON格式存储人格配置 - original_persona TEXT, -- JSON格式存储 - imitation_dialogues TEXT, -- JSON格式存储模仿对话 - backup_reason TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 风格学习记录表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS style_learning_records ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - style_type TEXT NOT NULL, - learned_patterns TEXT, -- JSON格式存储学习到的模式 - confidence_score REAL, - sample_count INTEGER, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 情感表达模式表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS emotion_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - emotional_pattern TEXT NOT NULL, - confidence_score REAL, - frequency INTEGER DEFAULT 0, - context_type TEXT, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 语言风格模式表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS language_style_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - 
language_style TEXT NOT NULL, - example_phrases TEXT, -- JSON格式存储示例短语 - usage_frequency INTEGER DEFAULT 0, - context_type TEXT DEFAULT 'general', - confidence_score REAL, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 主题偏好表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS topic_preferences ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - topic_category TEXT NOT NULL, - interest_level REAL, - response_style TEXT, - sample_count INTEGER DEFAULT 0, - confidence_score REAL, - last_updated REAL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 人格更新审查表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - update_type TEXT NOT NULL, -- style_update, persona_update, learning_update - original_content TEXT, -- 原始人格内容 - proposed_content TEXT, -- 建议的新内容 - confidence_score REAL, - reason TEXT, -- 更新原因 - sample_messages TEXT, -- JSON格式存储触发更新的示例消息 - review_status TEXT DEFAULT 'pending', -- pending, approved, rejected - reviewer_comment TEXT, - created_at REAL, - reviewed_at REAL, - auto_score REAL, -- 自动评分 - manual_override BOOLEAN DEFAULT FALSE - ) - ''') - - # 学习批次表 (如果不存在) - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS learning_batches ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - batch_name TEXT, - start_time REAL, - end_time REAL, - processed_messages INTEGER DEFAULT 0, - success BOOLEAN DEFAULT FALSE, - error_message TEXT, - learning_type TEXT, -- style_learning, persona_update, etc. 
- created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 学习会话表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS learning_sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT UNIQUE NOT NULL, - start_time REAL NOT NULL, - end_time REAL, - messages_processed INTEGER DEFAULT 0, - filtered_messages INTEGER DEFAULT 0, - style_updates INTEGER DEFAULT 0, - quality_score REAL DEFAULT 0.0, - success BOOLEAN DEFAULT FALSE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 创建索引 - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_social_relations_from_user ON social_relations(from_user)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_social_relations_to_user ON social_relations(to_user)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_user_profiles_active ON user_profiles(last_active)') - await cursor.execute('CREATE INDEX IF NOT EXISTS idx_style_profiles_name ON style_profiles(profile_name)') - - # 创建好感度表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS user_affection ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - affection_level INTEGER DEFAULT 0, - last_interaction REAL NOT NULL, - last_updated REAL NOT NULL, - interaction_count INTEGER DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_id, group_id) - ) - ''') - - # 创建bot情绪表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS bot_mood ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - mood_type TEXT NOT NULL, - mood_intensity REAL DEFAULT 0.5, - mood_description TEXT, - start_time REAL NOT NULL, - end_time REAL, - is_active BOOLEAN DEFAULT TRUE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # 创建好感度变化记录表 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS affection_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - change_amount INTEGER NOT NULL, - previous_level INTEGER NOT NULL, - 
new_level INTEGER NOT NULL, - change_reason TEXT, - bot_mood TEXT, - timestamp REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await conn.commit() - logger.debug("群数据库表结构初始化完成") - - except aiosqlite.Error as e: - logger.error(f"初始化群数据库失败: {e}", exc_info=True) - raise DataStorageError(f"初始化群数据库失败: {str(e)}") - - async def save_style_profile(self, group_id: str, profile_data: Dict[str, Any]): - """保存风格档案到数据库""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT OR REPLACE INTO style_profiles - (profile_name, vocabulary_richness, sentence_complexity, emotional_expression, - interaction_tendency, topic_diversity, formality_level, creativity_score) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - profile_data['profile_name'], - profile_data.get('vocabulary_richness'), - profile_data.get('sentence_complexity'), - profile_data.get('emotional_expression'), - profile_data.get('interaction_tendency'), - profile_data.get('topic_diversity'), - profile_data.get('formality_level'), - profile_data.get('creativity_score') - )) - await conn.commit() - logger.debug(f"风格档案 '{profile_data['profile_name']}' 已保存到群 {group_id} 数据库。") - except aiosqlite.Error as e: - logger.error(f"保存风格档案失败: {e}", exc_info=True) - raise DataStorageError(f"保存风格档案失败: {str(e)}") - - async def load_style_profile(self, group_id: str, profile_name: str) -> Optional[Dict[str, Any]]: - """从数据库加载风格档案""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT profile_name, vocabulary_richness, sentence_complexity, emotional_expression, - interaction_tendency, topic_diversity, formality_level, creativity_score - FROM style_profiles WHERE profile_name = ? 
- ''', (profile_name,)) - row = await cursor.fetchone() - if not row: - return None - return { - 'profile_name': row[0], - 'vocabulary_richness': row[1], - 'sentence_complexity': row[2], - 'emotional_expression': row[3], - 'interaction_tendency': row[4], - 'topic_diversity': row[5], - 'formality_level': row[6], - 'creativity_score': row[7] - } - except aiosqlite.Error as e: - logger.error(f"加载风格档案失败: {e}", exc_info=True) - return None - - async def save_user_profile(self, group_id: str, profile_data: Dict[str, Any]): - """保存用户画像到数据库""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT OR REPLACE INTO user_profiles - (qq_id, qq_name, nicknames, activity_pattern, communication_style, - topic_preferences, emotional_tendency, last_active, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - profile_data['qq_id'], - profile_data.get('qq_name', ''), - json.dumps(profile_data.get('nicknames', []), ensure_ascii=False), - json.dumps(profile_data.get('activity_pattern', {}), ensure_ascii=False), - json.dumps(profile_data.get('communication_style', {}), ensure_ascii=False), - json.dumps(profile_data.get('topic_preferences', {}), ensure_ascii=False), - json.dumps(profile_data.get('emotional_tendency', {}), ensure_ascii=False), - profile_data.get('last_active', time.time()), # 使用profile中的值或当前时间 - datetime.now().isoformat() - )) - - await conn.commit() - - except aiosqlite.Error as e: - logger.error(f"保存用户画像失败: {e}", exc_info=True) - raise DataStorageError(f"保存用户画像失败: {str(e)}") - - async def load_user_profile(self, group_id: str, qq_id: str) -> Optional[Dict[str, Any]]: - """从数据库加载用户画像""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT qq_id, qq_name, nicknames, activity_pattern, communication_style, - topic_preferences, emotional_tendency, last_active - FROM user_profiles WHERE qq_id = ? 
- ''', (qq_id,)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'qq_id': row[0], - 'qq_name': row[1], - 'nicknames': json.loads(row[2]) if row[2] else [], - 'activity_pattern': json.loads(row[3]) if row[3] else {}, - 'communication_style': json.loads(row[4]) if row[4] else {}, - 'topic_preferences': json.loads(row[5]) if row[5] else {}, - 'emotional_tendency': json.loads(row[6]) if row[6] else {}, - 'last_active': row[7] - } - - except aiosqlite.Error as e: - logger.error(f"加载用户画像失败: {e}", exc_info=True) - return None - - async def save_social_relation(self, group_id: str, relation_data: Dict[str, Any]): - """保存社交关系到数据库""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT OR REPLACE INTO social_relations - (from_user, to_user, relation_type, strength, frequency, last_interaction, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?) - ''', ( - relation_data['from_user'], - relation_data['to_user'], - relation_data['relation_type'], - relation_data['strength'], - relation_data['frequency'], - relation_data['last_interaction'], - datetime.now().isoformat() - )) - - await conn.commit() - - except aiosqlite.Error as e: - logger.error(f"保存社交关系失败: {e}", exc_info=True) - raise DataStorageError(f"保存社交关系失败: {str(e)}") - - async def get_social_relations_by_group(self, group_id: str) -> List[Dict[str, Any]]: - """获取指定群组的社交关系""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 添加 WHERE 子句来过滤特定群组的关系 - # 社交关系中的 from_user 和 to_user 格式为 "group_id:user_id" - await cursor.execute(''' - SELECT from_user, to_user, relation_type, strength, frequency, last_interaction - FROM social_relations - WHERE (from_user LIKE ? OR to_user LIKE ?) 
- ORDER BY frequency DESC, strength DESC - ''', (f'{group_id}:%', f'{group_id}:%')) - - rows = await cursor.fetchall() - relations = [] - - for row in rows: - try: - # 添加行数据验证 - if len(row) < 6: - self._logger.warning(f"社交关系数据行不完整 (期望6个字段,实际{len(row)}个),跳过: {row}") - continue - - relations.append({ - 'from_user': row[0], - 'to_user': row[1], - 'relation_type': row[2], - 'strength': float(row[3]) if row[3] else 0.0, - 'frequency': int(row[4]) if row[4] else 0, - 'last_interaction': row[5] - }) - except Exception as row_error: - self._logger.warning(f"处理社交关系数据行时出错,跳过: {row_error}, row: {row}") - - self._logger.info(f"群组 {group_id} 加载了 {len(relations)} 条社交关系") - return relations - - except aiosqlite.Error as e: - logger.error(f"获取社交关系失败: {e}", exc_info=True) - return [] - - async def get_user_social_relations(self, group_id: str, user_id: str) -> Dict[str, Any]: - """ - 获取指定用户在群组中的社交关系 - - Args: - group_id: 群组ID - user_id: 用户ID - - Returns: - 包含用户社交关系的字典,包括: - - outgoing: 该用户发起的关系列表 - - incoming: 指向该用户的关系列表 - - total_relations: 总关系数 - """ - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - user_key = f"{group_id}:{user_id}" - - # 获取该用户发起的关系(outgoing) - await cursor.execute(''' - SELECT from_user, to_user, relation_type, strength, frequency, last_interaction - FROM social_relations - WHERE from_user = ? OR from_user = ? - ORDER BY frequency DESC, strength DESC - LIMIT 10 - ''', (user_key, user_id)) - - outgoing_rows = await cursor.fetchall() - outgoing_relations = [] - - for row in outgoing_rows: - outgoing_relations.append({ - 'from_user': row[0], - 'to_user': row[1], - 'relation_type': row[2], - 'strength': row[3], - 'frequency': row[4], - 'last_interaction': row[5] - }) - - # 获取指向该用户的关系(incoming) - await cursor.execute(''' - SELECT from_user, to_user, relation_type, strength, frequency, last_interaction - FROM social_relations - WHERE to_user = ? OR to_user = ? 
- ORDER BY frequency DESC, strength DESC - LIMIT 10 - ''', (user_key, user_id)) - - incoming_rows = await cursor.fetchall() - incoming_relations = [] - - for row in incoming_rows: - incoming_relations.append({ - 'from_user': row[0], - 'to_user': row[1], - 'relation_type': row[2], - 'strength': row[3], - 'frequency': row[4], - 'last_interaction': row[5] - }) - - return { - 'user_id': user_id, - 'group_id': group_id, - 'outgoing': outgoing_relations, - 'incoming': incoming_relations, - 'total_relations': len(outgoing_relations) + len(incoming_relations) - } - - except aiosqlite.Error as e: - logger.error(f"获取用户社交关系失败: {e}", exc_info=True) - return { - 'user_id': user_id, - 'group_id': group_id, - 'outgoing': [], - 'incoming': [], - 'total_relations': 0 - } - - - async def save_raw_message(self, message_data) -> int: - """ - 将原始消息保存到全局消息数据库。 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 检查message_data是否为字典或对象 - if hasattr(message_data, 'sender_id'): - # 如果是对象,直接访问属性 - await cursor.execute(''' - INSERT INTO raw_messages (sender_id, sender_name, message, group_id, platform, timestamp) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - message_data.sender_id, - message_data.sender_name, - message_data.message, - message_data.group_id, - message_data.platform, - message_data.timestamp - )) - else: - # 如果是字典,使用字典访问 - await cursor.execute(''' - INSERT INTO raw_messages (sender_id, sender_name, message, group_id, platform, timestamp) - VALUES (?, ?, ?, ?, ?, ?) 
- ''', ( - message_data.get('sender_id'), - message_data.get('sender_name'), - message_data.get('message'), - message_data.get('group_id'), - message_data.get('platform'), - message_data.get('timestamp') - )) - - message_id = cursor.lastrowid - await conn.commit() - logger.info(f"💾 数据库写入成功: ID={message_id}, timestamp={message_data.timestamp if hasattr(message_data, 'timestamp') else message_data.get('timestamp')}") - return message_id - - except aiosqlite.Error as e: - logger.error(f"保存原始消息失败: {e}", exc_info=True) - raise DataStorageError(f"保存原始消息失败: {str(e)}") - finally: - await cursor.close() - - async def get_unprocessed_messages(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """ - 获取未处理的原始消息 - - Args: - limit: 限制返回的消息数量 - - Returns: - 未处理的消息列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - if limit: - await cursor.execute(''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp - FROM raw_messages - WHERE processed = FALSE - ORDER BY timestamp ASC - LIMIT ? 
- ''', (limit,)) - else: - await cursor.execute(''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp - FROM raw_messages - WHERE processed = FALSE - ORDER BY timestamp ASC - ''') - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'message': row[3], - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6] - }) - - logger.debug(f"获取到 {len(messages)} 条未处理消息") - return messages - - except aiosqlite.Error as e: - logger.error(f"获取未处理消息失败: {e}", exc_info=True) - raise DataStorageError(f"获取未处理消息失败: {str(e)}") - finally: - await cursor.close() - - async def mark_messages_processed(self, message_ids: List[int]) -> bool: - """ - 标记消息为已处理 - - Args: - message_ids: 消息ID列表 - - Returns: - 是否成功标记 - """ - if not message_ids: - return True - - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 批量更新消息状态 - placeholders = ','.join(['?' for _ in message_ids]) - await cursor.execute(f''' - UPDATE raw_messages - SET processed = TRUE - WHERE id IN ({placeholders}) - ''', message_ids) - - await conn.commit() - logger.debug(f"已标记 {len(message_ids)} 条消息为已处理") - return True - - except aiosqlite.Error as e: - logger.error(f"标记消息处理状态失败: {e}", exc_info=True) - raise DataStorageError(f"标记消息处理状态失败: {str(e)}") - finally: - await cursor.close() - - async def add_filtered_message(self, filtered_data: Dict[str, Any]) -> int: - """ - 添加筛选后的消息 - - Args: - filtered_data: 筛选后的消息数据 - - Returns: - 筛选消息的ID - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - current_time = int(time.time()) - await cursor.execute(''' - INSERT INTO filtered_messages - (raw_message_id, message, sender_id, confidence, filter_reason, timestamp, quality_scores, group_id, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - filtered_data.get('raw_message_id'), - filtered_data.get('message'), - filtered_data.get('sender_id'), - filtered_data.get('confidence', 0.8), - filtered_data.get('filter_reason', ''), - filtered_data.get('timestamp') or current_time, - json.dumps(filtered_data.get('quality_scores', {}), ensure_ascii=False), - filtered_data.get('group_id'), - current_time - )) - - filtered_id = cursor.lastrowid - await conn.commit() - logger.debug(f"筛选消息已保存,ID: {filtered_id}") - return filtered_id - - except aiosqlite.Error as e: - logger.error(f"添加筛选消息失败: {e}", exc_info=True) - raise DataStorageError(f"添加筛选消息失败: {str(e)}") - finally: - await cursor.close() - - async def get_filtered_messages_for_learning(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """ - 获取用于学习的筛选消息 - - Args: - limit: 限制返回的消息数量 - - Returns: - 筛选消息列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - if limit: - await cursor.execute(''' - SELECT id, message, sender_id, confidence, quality_scores, timestamp, group_id - FROM filtered_messages - WHERE used_for_learning = FALSE - ORDER BY timestamp DESC - LIMIT ? 
- ''', (limit,)) - else: - await cursor.execute(''' - SELECT id, message, sender_id, confidence, quality_scores, timestamp, group_id - FROM filtered_messages - WHERE used_for_learning = FALSE - ORDER BY timestamp DESC - ''') - - messages = [] - for row in await cursor.fetchall(): - try: - # 添加行数据验证 - if len(row) < 7: - self._logger.warning(f"筛选消息行数据不完整 (期望7个字段,实际{len(row)}个),跳过: {row}") - continue - - quality_scores = {} - try: - if row[4]: # quality_scores - quality_scores = json.loads(row[4]) - except (json.JSONDecodeError, TypeError): - pass - - messages.append({ - 'id': row[0], - 'message': row[1], - 'sender_id': row[2], - 'confidence': float(row[3]) if row[3] else 0.0, - 'quality_scores': quality_scores, - 'timestamp': float(row[5]) if row[5] else 0, - 'group_id': row[6] - }) - except Exception as row_error: - self._logger.warning(f"处理筛选消息行时出错,跳过: {row_error}, row: {row if len(row) < 20 else 'too long'}") - - return messages - - except aiosqlite.Error as e: - logger.error(f"获取学习消息失败: {e}", exc_info=True) - raise DataStorageError(f"获取学习消息失败: {str(e)}") - finally: - await cursor.close() - - async def get_recent_filtered_messages(self, group_id: str, limit: int = 5) -> List[Dict[str, Any]]: - """ - 获取指定群组最近的筛选消息 - - Args: - group_id: 群组ID - limit: 消息数量限制 - - Returns: - 筛选消息列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, message, sender_id, confidence, quality_scores, timestamp - FROM filtered_messages - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, limit)) - - messages = [] - for row in await cursor.fetchall(): - quality_scores = {} - try: - if row[4]: - quality_scores = json.loads(row[4]) - except json.JSONDecodeError: - pass - - messages.append({ - 'id': row[0], - 'message': row[1], - 'sender_id': row[2], - 'confidence': row[3], - 'quality_scores': quality_scores, - 'timestamp': row[5] - }) - - return messages - - except aiosqlite.Error as e: - logger.error(f"获取最近筛选消息失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_recent_raw_messages(self, group_id: str, limit: int = 25) -> List[Dict[str, Any]]: - """ - 获取指定群组最近的原始消息,用于表达风格学习 - - Args: - group_id: 群组ID - limit: 消息数量限制 - - Returns: - 原始消息列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp - FROM raw_messages - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? - ''', (group_id, limit)) - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'message': row[3], - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6] - }) - - return messages - - except aiosqlite.Error as e: - logger.error(f"获取最近原始消息失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_messages_statistics(self) -> Dict[str, Any]: - """ - 获取消息统计信息 - - Returns: - 统计信息字典 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 获取原始消息统计 - await cursor.execute('SELECT COUNT(*) FROM raw_messages') - result = await cursor.fetchone() - if not result or len(result) == 0: - total_messages = 0 - else: - total_messages = int(result[0]) if result[0] and str(result[0]).isdigit() else 0 - - await cursor.execute('SELECT COUNT(*) FROM raw_messages WHERE processed = FALSE') - result = await cursor.fetchone() - unprocessed_messages = 
int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - # 获取筛选消息统计 - await cursor.execute('SELECT COUNT(*) FROM filtered_messages') - result = await cursor.fetchone() - filtered_messages = int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE used_for_learning = FALSE') - result = await cursor.fetchone() - unused_filtered_messages = int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - stats = { - 'total_messages': total_messages, - 'unprocessed_messages': unprocessed_messages, - 'filtered_messages': filtered_messages, - 'unused_filtered_messages': unused_filtered_messages, - 'raw_messages': total_messages # 兼容旧接口 - } - - # 验证返回的统计数据没有表名 - for key, value in stats.items(): - if isinstance(value, str) and not value.replace('-', '').isdigit(): - self._logger.error(f"get_messages_statistics 返回了非数字字符串: {key}={value},设置为0") - stats[key] = 0 - - return stats - - except aiosqlite.Error as e: - self._logger.error(f"获取消息统计失败: {e}", exc_info=True) - return { - 'total_messages': 0, - 'unprocessed_messages': 0, - 'filtered_messages': 0, - 'unused_filtered_messages': 0, - 'raw_messages': 0 - } - finally: - await cursor.close() - - async def get_pending_style_reviews(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取待审查的风格学习记录""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_pending_style_reviews_orm(limit) - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 确保表存在 - await self._ensure_style_review_table_exists(cursor) - - await cursor.execute(''' - SELECT id, type, group_id, timestamp, learned_patterns, few_shots_content, - status, description, created_at - FROM style_learning_reviews - WHERE status = 'pending' - ORDER BY timestamp DESC - LIMIT ? 
- ''', (limit,)) - - reviews = [] - for row in await cursor.fetchall(): - learned_patterns = [] - try: - if row[4]: # learned_patterns - learned_patterns = json.loads(row[4]) - except json.JSONDecodeError: - pass - - reviews.append({ - 'id': row[0], - 'type': row[1], - 'group_id': row[2], - 'timestamp': row[3], - 'learned_patterns': learned_patterns, - 'few_shots_content': row[5], - 'status': row[6], - 'description': row[7], - 'created_at': row[8] - }) - - return reviews - - except Exception as e: - self._logger.error(f"获取待审查风格学习记录失败: {e}") - return [] - finally: - await cursor.close() - - async def get_reviewed_style_learning_updates(self, limit: int = 50, offset: int = 0, status_filter: str = None) -> List[Dict[str, Any]]: - """获取已审查的风格学习记录""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_reviewed_style_learning_updates_orm(limit, offset, status_filter) - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 确保表存在 - await self._ensure_style_review_table_exists(cursor) - - # 构建查询条件 - where_clause = "WHERE status != 'pending'" - params = [] - - if status_filter: - where_clause += " AND status = ?" - params.append(status_filter) - - params.extend([limit, offset]) - - await cursor.execute(f''' - SELECT id, type, group_id, timestamp, learned_patterns, few_shots_content, - status, description, created_at, updated_at - FROM style_learning_reviews - {where_clause} - ORDER BY updated_at DESC - LIMIT ? OFFSET ? 
- ''', params) - - reviews = [] - for row in await cursor.fetchall(): - learned_patterns = [] - try: - if row[4]: # learned_patterns - learned_patterns = json.loads(row[4]) - except json.JSONDecodeError: - pass - - reviews.append({ - 'id': row[0], - 'type': row[1], - 'group_id': row[2], - 'timestamp': row[3], - 'learned_patterns': learned_patterns, - 'few_shots_content': row[5], - 'status': row[6], - 'description': row[7], - 'created_at': row[8], - 'review_time': row[9] if len(row) > 9 else None - }) - - return reviews - - except Exception as e: - self._logger.error(f"获取已审查风格学习记录失败: {e}") - return [] - finally: - await cursor.close() - - async def get_detailed_metrics(self) -> Dict[str, Any]: - """获取详细监控数据""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - detailed_data = { - 'api_metrics': { - 'hours': list(range(24)), - 'response_times': [100 + i * 10 for i in range(24)] - }, - 'database_metrics': { - 'table_stats': {} - }, - 'system_metrics': { - 'memory_percent': 45.2, - 'cpu_percent': 23.1, - 'disk_percent': 67.8 - }, - 'connection_pool_stats': { - 'total_connections': self.connection_pool.total_connections, - 'active_connections': self.connection_pool.active_connections, - 'max_connections': self.connection_pool.max_connections, - 'pool_usage': round(self.connection_pool.active_connections / self.connection_pool.max_connections * 100, 1) if self.connection_pool.max_connections > 0 else 0 - } - } - - # 获取数据库表统计 - try: - tables = ['raw_messages', 'filtered_messages', 'expression_patterns'] - for table in tables: - try: - await cursor.execute(f'SELECT COUNT(*) FROM {table}') - count = (await cursor.fetchone())[0] - detailed_data['database_metrics']['table_stats'][table] = {'count': count} - except: - detailed_data['database_metrics']['table_stats'][table] = {'count': 0} - - except Exception as e: - self._logger.warning(f"获取数据库表统计失败: {e}") - - return detailed_data - - except Exception as e: - 
self._logger.error(f"获取详细监控数据失败: {e}") - return { - 'api_metrics': {'hours': [], 'response_times': []}, - 'database_metrics': {'table_stats': {}}, - 'system_metrics': {'memory_percent': 0, 'cpu_percent': 0, 'disk_percent': 0}, - 'connection_pool_stats': {'total_connections': 0, 'active_connections': 0, 'max_connections': 0, 'pool_usage': 0} - } - finally: - await cursor.close() - - async def get_message_statistics(self, group_id: str = None) -> Dict[str, Any]: - """获取消息统计信息,兼容 webui.py 的调用""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - if group_id: - # 获取特定群组的统计 - await cursor.execute('SELECT COUNT(*) FROM raw_messages WHERE group_id = ?', (group_id,)) - total_messages = (await cursor.fetchone())[0] - - await cursor.execute('SELECT COUNT(*) FROM raw_messages WHERE group_id = ? AND processed = FALSE', (group_id,)) - unprocessed_messages = (await cursor.fetchone())[0] - - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE group_id = ?', (group_id,)) - filtered_messages = (await cursor.fetchone())[0] - - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE group_id = ? 
AND used_for_learning = FALSE', (group_id,)) - unused_filtered_messages = (await cursor.fetchone())[0] - else: - # 获取全局统计 - return await self.get_messages_statistics() - - return { - 'total_messages': total_messages, - 'unprocessed_messages': unprocessed_messages, - 'filtered_messages': filtered_messages, - 'unused_filtered_messages': unused_filtered_messages, - 'raw_messages': total_messages, - 'group_id': group_id - } - - except aiosqlite.Error as e: - self._logger.error(f"获取消息统计失败: {e}", exc_info=True) - return { - 'total_messages': 0, - 'unprocessed_messages': 0, - 'filtered_messages': 0, - 'unused_filtered_messages': 0, - 'raw_messages': 0, - 'group_id': group_id - } - finally: - await cursor.close() - - async def get_recent_learning_batches(self, limit: int = 10) -> List[Dict[str, Any]]: - """获取最近的学习批次记录""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 确保表存在 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS learning_batches ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - batch_name TEXT NOT NULL, - start_time REAL NOT NULL, - end_time REAL, - quality_score REAL, - processed_messages INTEGER DEFAULT 0, - message_count INTEGER DEFAULT 0, - filtered_count INTEGER DEFAULT 0, - success BOOLEAN DEFAULT FALSE, - error_message TEXT - ) - ''') - - await cursor.execute(''' - SELECT group_id, batch_name, start_time, end_time, quality_score, - processed_messages, message_count, filtered_count, success, error_message - FROM learning_batches - ORDER BY start_time DESC - LIMIT ? 
- ''', (limit,)) - - batches = [] - for row in await cursor.fetchall(): - batches.append({ - 'group_id': row[0], - 'batch_name': row[1], - 'start_time': row[2], - 'end_time': row[3], - 'quality_score': row[4] or 0, - 'processed_messages': row[5] or 0, - 'message_count': row[6] or 0, - 'filtered_count': row[7] or 0, - 'success': bool(row[8]), - 'error_message': row[9] - }) - - return batches - - except Exception as e: - self._logger.error(f"获取最近学习批次失败: {e}") - return [] - finally: - await cursor.close() - - async def get_style_progress_data(self) -> List[Dict[str, Any]]: - """获取风格进度数据""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 首先检查表是否存在 - await cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='learning_batches'") - if not await cursor.fetchone(): - self._logger.info("learning_batches 表不存在,返回空列表") - return [] - - # 从学习批次中获取进度数据,包含消息数量信息 - # ✅ 只显示有实际消息的记录(过滤旧的空数据) - await cursor.execute(''' - SELECT group_id, start_time, quality_score, success, - processed_messages, filtered_count, batch_name - FROM learning_batches - WHERE quality_score IS NOT NULL - AND processed_messages > 0 - ORDER BY start_time DESC - LIMIT 30 - ''') - - progress_data = [] - rows = await cursor.fetchall() - - self._logger.debug(f"get_style_progress_data 获取到 {len(rows)} 行数据") - if rows and len(rows) > 0: - self._logger.debug(f"第一行数据: {rows[0]}, 列数: {len(rows[0])}") - - for row in rows: - try: - # 添加行数据验证(现在有7个字段) - if len(row) < 4: - self._logger.warning(f"学习批次进度数据行不完整 (期望至少4个字段,实际{len(row)}个),跳过: {row}") - continue - - progress_item = { - 'group_id': row[0], - 'timestamp': float(row[1]) if row[1] else 0, - 'quality_score': float(row[2]) if row[2] else 0, - 'success': bool(row[3]) - } - - # 添加消息数量信息(如果存在) - if len(row) > 4: - progress_item['processed_messages'] = int(row[4]) if row[4] else 0 - if len(row) > 5: - progress_item['filtered_count'] = int(row[5]) if row[5] else 0 - if len(row) > 6: - progress_item['batch_name'] 
= row[6] if row[6] else '未命名' - - progress_data.append(progress_item) - except Exception as row_error: - self._logger.warning(f"处理学习批次进度数据行时出错,跳过: {row_error}, row: {row}") - - return progress_data - - except Exception as e: - self._logger.warning(f"从learning_batches表获取进度数据失败: {e}") - return [] - finally: - await cursor.close() - - async def get_style_learning_statistics(self) -> Dict[str, Any]: - """获取风格学习统计数据""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - stats = { - 'unique_styles': 0, - 'avg_confidence': 0, - 'total_samples': 0, - 'latest_update': None - } - - # 从表达模式表获取统计 - try: - await cursor.execute('SELECT COUNT(*) FROM expression_patterns') - stats['total_samples'] = (await cursor.fetchone())[0] or 0 - - await cursor.execute('SELECT AVG(weight), MAX(create_time) FROM expression_patterns') - row = await cursor.fetchone() - if row[0]: - stats['avg_confidence'] = round((row[0] or 0) * 100, 1) - - if row[1]: - stats['latest_update'] = datetime.fromtimestamp(row[1]).strftime('%Y-%m-%d %H:%M') - - # 计算独特风格数量(基于群组) - await cursor.execute('SELECT COUNT(DISTINCT group_id) FROM expression_patterns') - stats['unique_styles'] = (await cursor.fetchone())[0] or 0 - - except Exception as e: - self._logger.warning(f"从expression_patterns表获取统计失败: {e}") - - return stats - - except Exception as e: - self._logger.error(f"获取风格学习统计失败: {e}") - return { - 'unique_styles': 0, - 'avg_confidence': 0, - 'total_samples': 0, - 'latest_update': None - } - finally: - await cursor.close() - - async def get_group_messages_statistics(self, group_id: str) -> Dict[str, Any]: - """ - 获取指定群组的消息统计信息 - - Args: - group_id: 群组ID - - Returns: - 统计信息字典 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 获取原始消息统计 - await cursor.execute('SELECT COUNT(*) FROM raw_messages WHERE group_id = ?', (group_id,)) - result = await cursor.fetchone() - total_messages = int(result[0]) if result and result[0] and 
str(result[0]).replace('-', '').isdigit() else 0 - - await cursor.execute('SELECT COUNT(*) FROM raw_messages WHERE group_id = ? AND processed = FALSE', (group_id,)) - result = await cursor.fetchone() - unprocessed_messages = int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - # 获取筛选消息统计 - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE group_id = ?', (group_id,)) - result = await cursor.fetchone() - filtered_messages = int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - await cursor.execute('SELECT COUNT(*) FROM filtered_messages WHERE group_id = ? AND used_for_learning = FALSE', (group_id,)) - result = await cursor.fetchone() - unused_filtered_messages = int(result[0]) if result and result[0] and str(result[0]).replace('-', '').isdigit() else 0 - - stats = { - 'total_messages': total_messages, - 'unprocessed_messages': unprocessed_messages, - 'filtered_messages': filtered_messages, - 'unused_filtered_messages': unused_filtered_messages, - 'raw_messages': total_messages # 兼容旧接口 - } - - # 验证返回的统计数据没有表名 - for key, value in stats.items(): - if isinstance(value, str) and not value.replace('-', '').isdigit(): - self._logger.error(f"get_group_messages_statistics 返回了非数字字符串: {key}={value},设置为0") - stats[key] = 0 - - return stats - - except aiosqlite.Error as e: - logger.error(f"获取群组消息统计失败: {e}", exc_info=True) - return { - 'total_messages': 0, - 'unprocessed_messages': 0, - 'filtered_messages': 0, - 'unused_filtered_messages': 0, - 'raw_messages': 0 - } - finally: - await cursor.close() - - async def load_social_graph(self, group_id: str) -> List[Dict[str, Any]]: - """加载完整社交图谱""" - self._logger.debug(f"[数据库] 开始加载群组 {group_id} 的社交图谱") - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT from_user, to_user, relation_type, strength, frequency, last_interaction - FROM social_relations ORDER BY 
strength DESC - ''') - - relations = [] - for row in await cursor.fetchall(): - relations.append({ - 'from_user': row[0], - 'to_user': row[1], - 'relation_type': row[2], - 'strength': row[3], - 'frequency': row[4], - 'last_interaction': row[5] - }) - - self._logger.info(f"[数据库] 成功加载群组 {group_id} 的社交图谱: {len(relations)} 条关系记录") - if len(relations) == 0: - self._logger.warning(f"[数据库] 警告: 群组 {group_id} 的social_relations表中没有数据!") - else: - # 输出前3条示例 - self._logger.debug(f"[数据库] 社交关系示例: {relations[:3]}") - - return relations - - except aiosqlite.Error as e: - self._logger.error(f"[数据库] 加载社交图谱失败 (群组: {group_id}): {e}", exc_info=True) - return [] - - async def get_messages_for_replay(self, group_id: str, days: int, limit: int) -> List[Dict[str, Any]]: - """ - 从全局消息数据库获取指定群组在过去一段时间内的原始消息,用于记忆重放。 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - start_timestamp = time.time() - (days * 86400) # 转换为秒 - - await cursor.execute(''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp - FROM raw_messages - WHERE group_id = ? AND timestamp > ? - ORDER BY timestamp DESC - LIMIT ? - ''', (group_id, start_timestamp, limit)) - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'message': row[3], - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6] - }) - - return messages - - except aiosqlite.Error as e: - self._logger.error(f"获取记忆重放消息失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def backup_persona(self, group_id: str, backup_data: Dict[str, Any]) -> int: - """备份人格数据""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 获取当前时间戳 - current_timestamp = time.time() - - await cursor.execute(''' - INSERT INTO persona_backups (backup_name, timestamp, original_persona, imitation_dialogues, backup_reason) - VALUES (?, ?, ?, ?, ?) 
- ''', ( - backup_data['backup_name'], - current_timestamp, - json.dumps(backup_data['original_persona'], ensure_ascii=False), - json.dumps(backup_data.get('imitation_dialogues', []), ensure_ascii=False), - backup_data.get('backup_reason', 'Auto backup before update') - )) - - backup_id = cursor.lastrowid - await conn.commit() - - logger.info(f"人格数据已备份,备份ID: {backup_id}") - return backup_id - - except aiosqlite.Error as e: - logger.error(f"备份人格数据失败: {e}", exc_info=True) - raise DataStorageError(f"备份人格数据失败: {str(e)}") - - async def get_persona_backups(self, group_id: str, limit: int = 10) -> List[Dict[str, Any]]: - """获取最近的人格备份""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, backup_name, created_at FROM persona_backups - ORDER BY created_at DESC LIMIT ? - ''', (limit,)) - - backups = [] - for row in await cursor.fetchall(): - backups.append({ - 'id': row[0], - 'backup_name': row[1], - 'created_at': row[2] - }) - - return backups - - except aiosqlite.Error as e: - logger.error(f"获取人格备份失败: {e}", exc_info=True) - return [] - - async def restore_persona(self, group_id: str, backup_id: int) -> Optional[Dict[str, Any]]: - """从备份恢复人格数据""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT backup_name, original_persona, imitation_dialogues, backup_reason - FROM persona_backups WHERE id = ? 
- ''', (backup_id,)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'backup_name': row[0], - 'original_persona': json.loads(row[1]), - 'imitation_dialogues': json.loads(row[2]), - 'backup_reason': row[3] - } - - except aiosqlite.Error as e: - logger.error(f"恢复人格数据失败: {e}", exc_info=True) - return None - - async def save_persona_update_record(self, record: Dict[str, Any]) -> int: - """保存人格更新记录到数据库""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO persona_update_records (timestamp, group_id, update_type, original_content, new_content, reason, status) - VALUES (?, ?, ?, ?, ?, ?, ?) - ''', ( - record.get('timestamp', time.time()), - record.get('group_id'), - record.get('update_type'), - record.get('original_content'), - record.get('new_content'), - record.get('reason'), - record.get('status', 'pending') - )) - - record_id = cursor.lastrowid - await conn.commit() - logger.debug(f"人格更新记录已保存,ID: {record_id}") - return record_id - - except aiosqlite.Error as e: - logger.error(f"保存人格更新记录失败: {e}", exc_info=True) - raise DataStorageError(f"保存人格更新记录失败: {str(e)}") - finally: - await cursor.close() - - async def get_pending_persona_update_records(self) -> List[Dict[str, Any]]: - """获取所有待审查的人格更新记录""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 首先检查表是否存在以及包含什么数据 - await cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='persona_update_records'") - if not await cursor.fetchone(): - self._logger.info("persona_update_records 表不存在") - return [] - - # 检查表中总共有多少记录 - await cursor.execute('SELECT COUNT(*) FROM persona_update_records') - total_count = (await cursor.fetchone())[0] - self._logger.info(f"persona_update_records 表中总共有 {total_count} 条记录") - - # 检查各种状态的记录数量 - await cursor.execute('SELECT status, COUNT(*) FROM persona_update_records GROUP BY status') - status_counts = await cursor.fetchall() - 
self._logger.info(f"各状态记录数量: {dict(status_counts)}") - - # 优先查询pending状态的记录 - await cursor.execute(''' - SELECT id, timestamp, group_id, update_type, original_content, new_content, reason, status, reviewer_comment, review_time - FROM persona_update_records - WHERE status = 'pending' - ORDER BY timestamp DESC - ''') - - records = [] - pending_rows = await cursor.fetchall() - self._logger.info(f"找到 {len(pending_rows)} 条pending状态的记录") - - for row in pending_rows: - records.append({ - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'reason': row[6], - 'status': row[7], - 'reviewer_comment': row[8], - 'review_time': row[9] - }) - - # 如果没有pending状态的记录,尝试查询所有记录(可能status字段为空或其他值) - if not records and total_count > 0: - self._logger.info("没有pending状态记录,查询所有记录...") - await cursor.execute(''' - SELECT id, timestamp, group_id, update_type, original_content, new_content, reason, - COALESCE(status, 'pending') as status, reviewer_comment, review_time - FROM persona_update_records - WHERE status IS NULL OR status = '' OR status = 'pending' - ORDER BY timestamp DESC - LIMIT 50 - ''') - - all_rows = await cursor.fetchall() - self._logger.info(f"找到 {len(all_rows)} 条可能的待审查记录") - - for row in all_rows: - records.append({ - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'reason': row[6], - 'status': 'pending', # 强制设置为pending - 'reviewer_comment': row[8], - 'review_time': row[9] - }) - - return records - - except aiosqlite.Error as e: - logger.error(f"获取待审查人格更新记录失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def update_persona_update_record_status(self, record_id: int, status: str, reviewer_comment: Optional[str] = None) -> bool: - """更新人格更新记录的状态""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - review_time = time.time() - await 
cursor.execute(''' - UPDATE persona_update_records - SET status = ?, reviewer_comment = ?, review_time = ? - WHERE id = ? - ''', (status, reviewer_comment, review_time, record_id)) - - await conn.commit() - logger.debug(f"人格更新记录 {record_id} 状态已更新为 {status}") - return cursor.rowcount > 0 - - except aiosqlite.Error as e: - logger.error(f"更新人格更新记录状态失败: {e}", exc_info=True) - raise DataStorageError(f"更新人格更新记录状态失败: {str(e)}") - finally: - await cursor.close() - - async def delete_persona_update_record(self, record_id: int) -> bool: - """删除人格更新记录""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - DELETE FROM persona_update_records - WHERE id = ? - ''', (record_id,)) - - await conn.commit() - logger.debug(f"人格更新记录 {record_id} 已删除") - return cursor.rowcount > 0 - - except aiosqlite.Error as e: - logger.error(f"删除人格更新记录失败: {e}", exc_info=True) - raise DataStorageError(f"删除人格更新记录失败: {str(e)}") - finally: - await cursor.close() - - async def get_persona_update_record_by_id(self, record_id: int) -> Optional[Dict[str, Any]]: - """根据ID获取人格更新记录""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, timestamp, group_id, update_type, original_content, new_content, reason, status, reviewer_comment, review_time - FROM persona_update_records - WHERE id = ? 
- ''', (record_id,)) - - row = await cursor.fetchone() - if row: - return { - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'reason': row[6], - 'status': row[7], - 'reviewer_comment': row[8], - 'review_time': row[9] - } - return None - - except aiosqlite.Error as e: - logger.error(f"获取人格更新记录失败: {e}", exc_info=True) - return None - finally: - await cursor.close() - - # ========== 高级功能数据库操作方法 ========== - - async def save_emotion_profile(self, group_id: str, user_id: str, profile_data: Dict[str, Any]) -> bool: - """保存情感档案""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 检查是否已存在表,如果不存在则创建 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS emotion_profiles ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - dominant_emotions TEXT, -- JSON格式 - emotion_patterns TEXT, -- JSON格式 - empathy_level REAL DEFAULT 0.5, - emotional_stability REAL DEFAULT 0.5, - last_updated REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_id, group_id) - ) - ''') - - await cursor.execute(''' - INSERT OR REPLACE INTO emotion_profiles - (user_id, group_id, dominant_emotions, emotion_patterns, empathy_level, emotional_stability, last_updated) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- ''', ( - user_id, - group_id, - json.dumps(profile_data.get('dominant_emotions', {}), ensure_ascii=False), - json.dumps(profile_data.get('emotion_patterns', {}), ensure_ascii=False), - profile_data.get('empathy_level', 0.5), - profile_data.get('emotional_stability', 0.5), - profile_data.get('last_updated', time.time()) - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存情感档案失败: {e}") - return False - - async def load_emotion_profile(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: - """加载情感档案""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT dominant_emotions, emotion_patterns, empathy_level, emotional_stability, last_updated - FROM emotion_profiles WHERE user_id = ? AND group_id = ? - ''', (user_id, group_id)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'user_id': user_id, - 'group_id': group_id, - 'dominant_emotions': json.loads(row[0]) if row[0] else {}, - 'emotion_patterns': json.loads(row[1]) if row[1] else {}, - 'empathy_level': row[2], - 'emotional_stability': row[3], - 'last_updated': row[4] - } - - except Exception as e: - self._logger.error(f"加载情感档案失败: {e}") - return None - - async def save_knowledge_entity(self, group_id: str, entity_data: Dict[str, Any]) -> bool: - """保存知识实体""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 检查是否已存在表,如果不存在则创建 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS knowledge_entities ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - entity_id TEXT UNIQUE NOT NULL, - name TEXT NOT NULL, - entity_type TEXT NOT NULL, - attributes TEXT, -- JSON格式 - relationships TEXT, -- JSON格式 - confidence REAL DEFAULT 0.5, - source_messages TEXT, -- JSON格式 - last_mentioned REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await cursor.execute(''' - INSERT OR REPLACE INTO knowledge_entities - 
(entity_id, name, entity_type, attributes, relationships, confidence, source_messages, last_mentioned) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - entity_data.get('entity_id'), - entity_data.get('name', ''), - entity_data.get('entity_type', 'unknown'), - json.dumps(entity_data.get('attributes', {}), ensure_ascii=False), - json.dumps(entity_data.get('relationships', []), ensure_ascii=False), - entity_data.get('confidence', 0.5), - json.dumps(entity_data.get('source_messages', []), ensure_ascii=False), - entity_data.get('last_mentioned', time.time()) - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存知识实体失败: {e}") - return False - - async def get_knowledge_entities(self, group_id: str, limit: int = 100) -> List[Dict[str, Any]]: - """获取知识实体列表""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT entity_id, name, entity_type, attributes, relationships, confidence, source_messages, last_mentioned - FROM knowledge_entities - ORDER BY last_mentioned DESC - LIMIT ? 
- ''', (limit,)) - - entities = [] - for row in await cursor.fetchall(): - entities.append({ - 'entity_id': row[0], - 'name': row[1], - 'entity_type': row[2], - 'attributes': json.loads(row[3]) if row[3] else {}, - 'relationships': json.loads(row[4]) if row[4] else [], - 'confidence': row[5], - 'source_messages': json.loads(row[6]) if row[6] else [], - 'last_mentioned': row[7] - }) - - return entities - - except Exception as e: - self._logger.error(f"获取知识实体失败: {e}") - return [] - - # 新增强化学习相关方法 - async def save_reinforcement_learning_result(self, group_id: str, result_data: Dict[str, Any]) -> bool: - """保存强化学习结果""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO reinforcement_learning_results - (group_id, timestamp, replay_analysis, optimization_strategy, reinforcement_feedback, next_action) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - group_id, - result_data.get('timestamp', time.time()), - json.dumps(result_data.get('replay_analysis', {}), ensure_ascii=False), - json.dumps(result_data.get('optimization_strategy', {}), ensure_ascii=False), - json.dumps(result_data.get('reinforcement_feedback', {}), ensure_ascii=False), - result_data.get('next_action', '') - )) - - await conn.commit() - return True - - except Exception as e: - logger.error(f"保存强化学习结果失败: {e}") - return False - finally: - await cursor.close() - - async def get_learning_history_for_reinforcement(self, group_id: str, limit: int = 50) -> List[Dict[str, Any]]: - """获取用于强化学习的历史数据""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT timestamp, quality_score, success, successful_pattern, failed_pattern - FROM learning_performance_history - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, limit)) - - history = [] - for row in await cursor.fetchall(): - history.append({ - 'timestamp': row[0], - 'quality_score': row[1], - 'success': bool(row[2]), - 'successful_pattern': row[3] or '', - 'failed_pattern': row[4] or '' - }) - - return history - - except Exception as e: - logger.error(f"获取强化学习历史数据失败: {e}") - return [] - finally: - await cursor.close() - - async def save_persona_fusion_result(self, group_id: str, fusion_data: Dict[str, Any]) -> bool: - """保存人格融合结果""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO persona_fusion_history - (group_id, timestamp, base_persona_hash, incremental_hash, fusion_result, compatibility_score) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - group_id, - fusion_data.get('timestamp', time.time()), - fusion_data.get('base_persona_hash'), - fusion_data.get('incremental_hash'), - json.dumps(fusion_data.get('fusion_result', {}), ensure_ascii=False), - fusion_data.get('compatibility_score', 0.0) - )) - - await conn.commit() - return True - - except Exception as e: - logger.error(f"保存人格融合结果失败: {e}") - return False - finally: - await cursor.close() - - async def get_persona_fusion_history(self, group_id: str, limit: int = 10) -> List[Dict[str, Any]]: - """获取人格融合历史""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT timestamp, base_persona_hash, incremental_hash, fusion_result, compatibility_score - FROM persona_fusion_history - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, limit)) - - history = [] - for row in await cursor.fetchall(): - fusion_result = {} - try: - fusion_result = json.loads(row[3]) if row[3] else {} - except json.JSONDecodeError: - logger.warning(f"解析融合结果JSON失败: {row[3]}") - - history.append({ - 'timestamp': row[0], - 'base_persona_hash': row[1], - 'incremental_hash': row[2], - 'fusion_result': fusion_result, - 'compatibility_score': row[4] - }) - - return history - - except Exception as e: - logger.error(f"获取人格融合历史失败: {e}") - return [] - finally: - await cursor.close() - - async def save_strategy_optimization_result(self, group_id: str, optimization_data: Dict[str, Any]) -> bool: - """保存策略优化结果""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO strategy_optimization_results - (group_id, timestamp, original_strategy, optimization_result, expected_improvement) - VALUES (?, ?, ?, ?, ?) - ''', ( - group_id, - optimization_data.get('timestamp', time.time()), - json.dumps(optimization_data.get('original_strategy', {}), ensure_ascii=False), - json.dumps(optimization_data.get('optimization_result', {}), ensure_ascii=False), - json.dumps(optimization_data.get('expected_improvement', {}), ensure_ascii=False) - )) - - await conn.commit() - return True - - except Exception as e: - logger.error(f"保存策略优化结果失败: {e}") - return False - finally: - await cursor.close() - - async def get_learning_performance_history(self, group_id: str, limit: int = 30) -> List[Dict[str, Any]]: - """获取学习性能历史数据""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT session_id, timestamp, quality_score, learning_time, success - FROM learning_performance_history - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, limit)) - - history = [] - for row in await cursor.fetchall(): - history.append({ - 'session_id': row[0], - 'timestamp': row[1], - 'quality_score': row[2] or 0.0, - 'learning_time': row[3] or 0.0, - 'success': bool(row[4]) - }) - - return history - - except Exception as e: - logger.error(f"获取学习性能历史失败: {e}") - return [] - finally: - await cursor.close() - - async def save_learning_performance_record(self, group_id: str, performance_data: Dict[str, Any]) -> bool: - """保存学习性能记录""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO learning_performance_history - (group_id, session_id, timestamp, quality_score, learning_time, success, successful_pattern, failed_pattern) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - group_id, - performance_data.get('session_id', ''), - performance_data.get('timestamp', time.time()), - performance_data.get('quality_score', 0.0), - performance_data.get('learning_time', 0.0), - performance_data.get('success', False), - performance_data.get('successful_pattern', ''), - performance_data.get('failed_pattern', '') - )) - - await conn.commit() - return True - - except Exception as e: - logger.error(f"保存学习性能记录失败: {e}") - return False - finally: - await cursor.close() - - async def get_messages_for_replay(self, group_id: str, days: int = 30, limit: int = 100) -> List[Dict[str, Any]]: - """获取用于记忆重放的消息""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 获取指定天数内的消息 - cutoff_time = time.time() - (days * 24 * 3600) - - await cursor.execute(''' - SELECT id, message, sender_id, group_id, timestamp - FROM raw_messages - WHERE group_id = ? AND timestamp > ? AND processed = TRUE - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, cutoff_time, limit)) - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'message_id': row[0], - 'message': row[1], - 'sender_id': row[2], - 'group_id': row[3], - 'timestamp': row[4] - }) - - return messages - - except Exception as e: - logger.error(f"获取记忆重放消息失败: {e}") - return [] - finally: - await cursor.close() - - async def save_user_preferences(self, group_id: str, user_id: str, preferences: Dict[str, Any]) -> bool: - """保存用户偏好设置""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 检查是否已存在表,如果不存在则创建 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS user_preferences ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - favorite_topics TEXT, -- JSON格式 - interaction_style TEXT, -- JSON格式 - learning_preferences TEXT, -- JSON格式 - adaptive_rate REAL DEFAULT 0.5, - updated_at REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_id, group_id) - ) - ''') - - await cursor.execute(''' - INSERT OR REPLACE INTO user_preferences - (user_id, group_id, favorite_topics, interaction_style, learning_preferences, adaptive_rate, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- ''', ( - user_id, - group_id, - json.dumps(preferences.get('favorite_topics', []), ensure_ascii=False), - json.dumps(preferences.get('interaction_style', {}), ensure_ascii=False), - json.dumps(preferences.get('learning_preferences', {}), ensure_ascii=False), - preferences.get('adaptive_rate', 0.5), - time.time() - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存用户偏好失败: {e}") - return False - - async def load_user_preferences(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: - """加载用户偏好设置""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT favorite_topics, interaction_style, learning_preferences, adaptive_rate, updated_at - FROM user_preferences WHERE user_id = ? AND group_id = ? - ''', (user_id, group_id)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'favorite_topics': json.loads(row[0]) if row[0] else [], - 'interaction_style': json.loads(row[1]) if row[1] else {}, - 'learning_preferences': json.loads(row[2]) if row[2] else {}, - 'adaptive_rate': row[3], - 'updated_at': row[4] - } - - except Exception as e: - self._logger.error(f"加载用户偏好失败: {e}") - return None - - async def save_conversation_context(self, group_id: str, context_data: Dict[str, Any]) -> bool: - """保存对话上下文""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - # 检查是否已存在表,如果不存在则创建 - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS conversation_contexts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - context_id TEXT UNIQUE NOT NULL, - participants TEXT, -- JSON格式存储参与者列表 - current_topic TEXT, - emotion_state TEXT, -- JSON格式存储情感状态 - context_messages TEXT, -- JSON格式存储上下文消息 - start_time REAL NOT NULL, - last_updated REAL NOT NULL, - is_active BOOLEAN DEFAULT TRUE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - - await cursor.execute(''' - INSERT OR REPLACE 
INTO conversation_contexts - (group_id, context_id, participants, current_topic, emotion_state, context_messages, start_time, last_updated, is_active) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - group_id, - context_data.get('context_id'), - json.dumps(list(context_data.get('participants', set())), ensure_ascii=False), - context_data.get('current_topic'), - json.dumps(context_data.get('emotion_state', {}), ensure_ascii=False), - json.dumps(context_data.get('messages', []), ensure_ascii=False), - context_data.get('start_time', time.time()), - time.time(), - context_data.get('is_active', True) - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存对话上下文失败: {e}") - return False - - async def get_active_conversation_contexts(self, group_id: str) -> List[Dict[str, Any]]: - """获取活跃的对话上下文""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT context_id, participants, current_topic, emotion_state, context_messages, start_time, last_updated - FROM conversation_contexts - WHERE group_id = ? 
AND is_active = TRUE - ORDER BY last_updated DESC - ''', (group_id,)) - - contexts = [] - for row in await cursor.fetchall(): - contexts.append({ - 'context_id': row[0], - 'participants': set(json.loads(row[1])) if row[1] else set(), - 'current_topic': row[2], - 'emotion_state': json.loads(row[3]) if row[3] else {}, - 'messages': json.loads(row[4]) if row[4] else [], - 'start_time': row[5], - 'last_updated': row[6] - }) - - return contexts - - except Exception as e: - self._logger.error(f"获取对话上下文失败: {e}") - return [] - - async def save_learning_session_record(self, group_id: str, session_data: Dict[str, Any]) -> bool: - """保存学习会话记录""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT OR REPLACE INTO learning_sessions - (session_id, start_time, end_time, messages_processed, filtered_messages, - style_updates, quality_score, success) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - session_data.get('session_id'), - session_data.get('start_time'), - session_data.get('end_time'), - session_data.get('messages_processed', 0), - session_data.get('filtered_messages', 0), - session_data.get('style_updates', 0), - session_data.get('quality_score', 0.0), - session_data.get('success', False) - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存学习会话记录失败: {e}") - return False - - async def get_recent_learning_sessions(self, group_id: str, days: int = 7) -> List[Dict[str, Any]]: - """获取最近的学习会话记录""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - start_time = time.time() - (days * 24 * 3600) - - await cursor.execute(''' - SELECT session_id, start_time, end_time, messages_processed, filtered_messages, - style_updates, quality_score, success - FROM learning_sessions - WHERE start_time >= ? 
- ORDER BY start_time DESC - ''', (start_time,)) - - sessions = [] - for row in await cursor.fetchall(): - sessions.append({ - 'session_id': row[0], - 'start_time': row[1], - 'end_time': row[2], - 'messages_processed': row[3], - 'filtered_messages': row[4], - 'style_updates': row[5], - 'quality_score': row[6], - 'success': row[7] - }) - - return sessions - - except Exception as e: - self._logger.error(f"获取学习会话记录失败: {e}") - return [] - - # ========== 好感度系统数据库操作方法 ========== - - async def get_user_affection(self, group_id: str, user_id: str) -> Optional[Dict[str, Any]]: - """获取用户好感度""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT affection_level, last_interaction, last_updated, interaction_count - FROM user_affection WHERE user_id = ? AND group_id = ? - ''', (user_id, group_id)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'user_id': user_id, - 'group_id': group_id, - 'affection_level': row[0], - 'last_interaction': row[1], - 'last_updated': row[2], - 'interaction_count': row[3] - } - - except Exception as e: - self._logger.error(f"获取用户好感度失败: {e}") - return None - - async def update_user_affection(self, group_id: str, user_id: str, - new_level: int, change_reason: str = "", - bot_mood: str = "") -> bool: - """更新用户好感度""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - current_time = time.time() - - # 获取当前好感度 - current_affection = await self.get_user_affection(group_id, user_id) - previous_level = current_affection['affection_level'] if current_affection else 0 - interaction_count = current_affection['interaction_count'] if current_affection else 0 - - # 更新或插入好感度记录 - await cursor.execute(''' - INSERT OR REPLACE INTO user_affection - (user_id, group_id, affection_level, last_interaction, last_updated, interaction_count) - VALUES (?, ?, ?, ?, ?, ?) 
- ''', (user_id, group_id, new_level, current_time, current_time, interaction_count + 1)) - - # 记录好感度变化历史 - change_amount = new_level - previous_level - if change_amount != 0: - await cursor.execute(''' - INSERT INTO affection_history - (user_id, group_id, change_amount, previous_level, new_level, - change_reason, bot_mood, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', (user_id, group_id, change_amount, previous_level, new_level, - change_reason, bot_mood, current_time)) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"更新用户好感度失败: {e}") - return False - - async def get_all_user_affections(self, group_id: str) -> List[Dict[str, Any]]: - """获取群内所有用户好感度""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT user_id, affection_level, last_interaction, last_updated, interaction_count - FROM user_affection - WHERE group_id = ? - ORDER BY affection_level DESC - ''', (group_id,)) - - affections = [] - for row in await cursor.fetchall(): - affections.append({ - 'user_id': row[0], - 'group_id': group_id, - 'affection_level': row[1], - 'last_interaction': row[2], - 'last_updated': row[3], - 'interaction_count': row[4] - }) - - return affections - - except Exception as e: - self._logger.error(f"获取所有用户好感度失败: {e}") - return [] - - async def get_total_affection(self, group_id: str) -> int: - """获取群内总好感度""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT SUM(affection_level) FROM user_affection WHERE group_id = ? 
- ''', (group_id,)) - - result = await cursor.fetchone() - return result[0] if result[0] is not None else 0 - - except Exception as e: - self._logger.error(f"获取总好感度失败: {e}") - return 0 - - async def save_bot_mood(self, group_id: str, mood_type: str, mood_intensity: float, - mood_description: str, duration_hours: int = 24) -> bool: - """保存bot情绪状态""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - current_time = time.time() - end_time = current_time + (duration_hours * 3600) - - # 将之前的情绪设为非活跃状态 - await cursor.execute(''' - UPDATE bot_mood SET is_active = FALSE, end_time = ? WHERE group_id = ? AND is_active = TRUE - ''', (current_time, group_id)) - - # 插入新的情绪状态 - await cursor.execute(''' - INSERT INTO bot_mood - (group_id, mood_type, mood_intensity, mood_description, start_time, end_time, is_active) - VALUES (?, ?, ?, ?, ?, ?, TRUE) - ''', (group_id, mood_type, mood_intensity, mood_description, current_time, end_time)) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存bot情绪失败: {e}") - return False - - async def get_current_bot_mood(self, group_id: str) -> Optional[Dict[str, Any]]: - """获取当前bot情绪""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - current_time = time.time() - - await cursor.execute(''' - SELECT mood_type, mood_intensity, mood_description, start_time, end_time - FROM bot_mood - WHERE group_id = ? AND is_active = TRUE AND start_time <= ? AND (end_time IS NULL OR end_time > ?) 
- ORDER BY start_time DESC - LIMIT 1 - ''', (group_id, current_time, current_time)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'mood_type': row[0], - 'mood_intensity': row[1], - 'mood_description': row[2], - 'start_time': row[3], - 'end_time': row[4] - } - - except Exception as e: - self._logger.error(f"获取当前bot情绪失败: {e}") - return None - - async def get_affection_history(self, group_id: str, user_id: str = None, - days: int = 7) -> List[Dict[str, Any]]: - """获取好感度变化历史""" - conn = await self.get_group_connection(group_id) - cursor = await conn.cursor() - - try: - start_time = time.time() - (days * 24 * 3600) - - if user_id: - await cursor.execute(''' - SELECT user_id, change_amount, previous_level, new_level, - change_reason, bot_mood, timestamp - FROM affection_history - WHERE group_id = ? AND user_id = ? AND timestamp >= ? - ORDER BY timestamp DESC - ''', (group_id, user_id, start_time)) - else: - await cursor.execute(''' - SELECT user_id, change_amount, previous_level, new_level, - change_reason, bot_mood, timestamp - FROM affection_history - WHERE group_id = ? AND timestamp >= ? - ORDER BY timestamp DESC - ''', (group_id, start_time)) - - history = [] - for row in await cursor.fetchall(): - history.append({ - 'user_id': row[0], - 'change_amount': row[1], - 'previous_level': row[2], - 'new_level': row[3], - 'change_reason': row[4], - 'bot_mood': row[5], - 'timestamp': row[6] - }) - - return history - - except Exception as e: - self._logger.error(f"获取好感度历史失败: {e}") - return [] - - async def record_llm_call_statistics(self, provider_type: str, model_name: str, - success: bool, response_time_ms: int) -> bool: - """记录LLM调用统计数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - current_time = time.time() - - # 查询当前统计数据 - await cursor.execute(''' - SELECT total_calls, success_calls, failed_calls, total_response_time_ms - FROM llm_call_statistics - WHERE provider_type = ? AND model_name = ? 
- ''', (provider_type, model_name)) - - row = await cursor.fetchone() - if row: - # 更新现有记录 - total_calls = row[0] + 1 - success_calls = row[1] + (1 if success else 0) - failed_calls = row[2] + (0 if success else 1) - total_response_time = row[3] + response_time_ms - avg_response_time = total_response_time / total_calls - success_rate = success_calls / total_calls - - await cursor.execute(''' - UPDATE llm_call_statistics - SET total_calls = ?, success_calls = ?, failed_calls = ?, - total_response_time_ms = ?, avg_response_time_ms = ?, - success_rate = ?, last_call_time = ?, updated_at = CURRENT_TIMESTAMP - WHERE provider_type = ? AND model_name = ? - ''', (total_calls, success_calls, failed_calls, total_response_time, - avg_response_time, success_rate, current_time, provider_type, model_name)) - else: - # 插入新记录 - success_calls = 1 if success else 0 - failed_calls = 0 if success else 1 - success_rate = 1.0 if success else 0.0 - - await cursor.execute(''' - INSERT INTO llm_call_statistics - (provider_type, model_name, total_calls, success_calls, failed_calls, - total_response_time_ms, avg_response_time_ms, success_rate, last_call_time) - VALUES (?, ?, 1, ?, ?, ?, ?, ?, ?) 
- ''', (provider_type, model_name, success_calls, failed_calls, - response_time_ms, response_time_ms, success_rate, current_time)) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"记录LLM调用统计失败: {e}") - return False - finally: - await cursor.close() - - async def get_llm_call_statistics(self) -> Dict[str, Any]: - """获取LLM调用统计数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT provider_type, model_name, total_calls, success_calls, failed_calls, - avg_response_time_ms, success_rate, last_call_time - FROM llm_call_statistics - ORDER BY provider_type, total_calls DESC - ''') - - statistics = {} - total_calls = 0 - - for row in await cursor.fetchall(): - provider_type = row[0] - model_name = row[1] or f"{provider_type}_model" - - stats = { - "total_calls": row[2], - "success_calls": row[3], - "failed_calls": row[4], - "avg_response_time_ms": row[5] or 0, - "success_rate": row[6] or 0, - "last_call_time": row[7] - } - - statistics[f"{provider_type}_{model_name}"] = stats - total_calls += row[2] - - # 如果没有统计数据,返回默认结构 - if not statistics: - statistics = { - "filter_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0}, - "refine_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0}, - "reinforce_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0} - } - - return { - "statistics": statistics, - "total_calls": total_calls - } - - except Exception as e: - self._logger.error(f"获取LLM调用统计失败: {e}") - return { - "statistics": { - "filter_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0}, - "refine_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0}, - "reinforce_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 0, "error_count": 0} - }, - 
"total_calls": 0 - } - finally: - await cursor.close() - - async def export_messages_learning_data(self) -> Dict[str, Any]: - """导出消息学习数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 导出原始消息 - await cursor.execute(''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp, processed - FROM raw_messages ORDER BY timestamp DESC - ''') - raw_messages = [] - for row in await cursor.fetchall(): - raw_messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'message': row[3], - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6], - 'processed': bool(row[7]) - }) - - # 导出筛选消息 - await cursor.execute(''' - SELECT id, raw_message_id, message, sender_id, group_id, confidence, - filter_reason, timestamp, used_for_learning, quality_scores - FROM filtered_messages ORDER BY timestamp DESC - ''') - filtered_messages = [] - for row in await cursor.fetchall(): - quality_scores = {} - try: - if row[9]: # quality_scores - quality_scores = json.loads(row[9]) - except (json.JSONDecodeError, TypeError): - pass - - filtered_messages.append({ - 'id': row[0], - 'raw_message_id': row[1], - 'message': row[2], - 'sender_id': row[3], - 'group_id': row[4], - 'confidence': row[5], - 'filter_reason': row[6], - 'timestamp': row[7], - 'used_for_learning': bool(row[8]), - 'quality_scores': quality_scores - }) - - # 导出学习批次记录 - await cursor.execute(''' - SELECT id, group_id, start_time, end_time, quality_score, - processed_messages, batch_name, message_count, - filtered_count, success, error_message - FROM learning_batches ORDER BY start_time DESC - ''') - learning_batches = [] - for row in await cursor.fetchall(): - learning_batches.append({ - 'id': row[0], - 'group_id': row[1], - 'start_time': row[2], - 'end_time': row[3], - 'quality_score': row[4], - 'processed_messages': row[5], - 'batch_name': row[6], - 'message_count': row[7], - 'filtered_count': row[8], - 'success': bool(row[9]), - 
'error_message': row[10] - }) - - # 导出人格更新记录 - await cursor.execute(''' - SELECT id, timestamp, group_id, update_type, original_content, - new_content, reason, status, reviewer_comment, review_time - FROM persona_update_records ORDER BY timestamp DESC - ''') - persona_update_records = [] - for row in await cursor.fetchall(): - persona_update_records.append({ - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'reason': row[6], - 'status': row[7], - 'reviewer_comment': row[8], - 'review_time': row[9] - }) - - # 获取统计信息 - statistics = await self.get_messages_statistics() - - export_data = { - 'export_timestamp': time.time(), - 'export_date': datetime.now().isoformat(), - 'statistics': statistics, - 'raw_messages': raw_messages, - 'filtered_messages': filtered_messages, - 'learning_batches': learning_batches, - 'persona_update_records': persona_update_records - } - - self._logger.info(f"成功导出学习数据: {len(raw_messages)} 条原始消息, {len(filtered_messages)} 条筛选消息") - return export_data - - except Exception as e: - self._logger.error(f"导出消息学习数据失败: {e}", exc_info=True) - raise DataStorageError(f"导出消息学习数据失败: {str(e)}") - finally: - await cursor.close() - - async def clear_all_messages_data(self): - """清空所有消息数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 清空所有表的数据 - tables_to_clear = [ - 'raw_messages', - 'filtered_messages', - 'learning_batches', - 'persona_update_records', - 'reinforcement_learning_results', - 'persona_fusion_history', - 'strategy_optimization_results', - 'learning_performance_history' - ] - - for table in tables_to_clear: - await cursor.execute(f'DELETE FROM {table}') - self._logger.debug(f"已清空表: {table}") - - await conn.commit() - self._logger.info("所有消息数据已清空") - - except Exception as e: - self._logger.error(f"清空所有消息数据失败: {e}", exc_info=True) - raise DataStorageError(f"清空所有消息数据失败: {str(e)}") - finally: - await cursor.close() 
- - async def get_learning_patterns_data(self) -> Dict[str, Any]: - """获取学习模式数据""" - try: - # 首先尝试获取表达模式数据(来自expression_patterns表) - expression_patterns = await self.get_expression_patterns_for_webui() - - # 获取其他学习数据 - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 检查是否有原始消息数据 - await cursor.execute('SELECT COUNT(*) FROM raw_messages') - raw_data_count = (await cursor.fetchone())[0] - - # 检查是否有筛选消息数据 - await cursor.execute('SELECT COUNT(*) FROM filtered_messages') - filtered_data_count = (await cursor.fetchone())[0] - - # 如果有表达模式数据,使用它;否则使用默认提示 - if expression_patterns: - emotion_patterns = [] - for pattern in expression_patterns[:10]: # 显示前10个 - situation = pattern.get('situation', '场景描述').strip() - expression = pattern.get('expression', '表达方式').strip() - weight = pattern.get('weight', 0) - - # 确保不显示空的或无意义的数据 - if situation and expression and situation != '未知' and expression != '未知': - pattern_name = f"情感表达-{situation[:10]}" # 截取前10个字符作为模式名 - emotion_patterns.append({ - 'pattern': pattern_name, - 'confidence': round(weight * 20, 2), # 将权重转换为置信度百分比 - 'frequency': max(1, int(weight)) # 确保频率至少为1 - }) - - # 如果没有有效的表达模式,添加一个说明 - if not emotion_patterns: - emotion_patterns.append({ - 'pattern': '正在学习表达模式', - 'confidence': 30.0, - 'frequency': 1 - }) - else: - # 如果没有表达模式,但有原始数据,显示学习中状态 - if raw_data_count > 0: - emotion_patterns = [{ - 'pattern': '正在学习表达模式,请稍候...', - 'confidence': 50.0, - 'frequency': raw_data_count - }] - else: - emotion_patterns = [{ - 'pattern': '暂无对话数据,请先进行对话', - 'confidence': 0.0, - 'frequency': 0 - }] - - # 语言风格分析(基于原始消息长度分布) - await cursor.execute(''' - SELECT - CASE - WHEN LENGTH(message) < 10 THEN '简短表达' - WHEN LENGTH(message) < 30 THEN '适中表达' - WHEN LENGTH(message) < 100 THEN '详细表达' - ELSE '长篇表达' - END as style_type, - COUNT(*) as count - FROM raw_messages - WHERE message IS NOT NULL AND LENGTH(TRIM(message)) > 0 - GROUP BY style_type - ''') - - language_patterns = [] - for row in await cursor.fetchall(): - 
language_patterns.append({ - 'style': row[0], # 改为style字段以匹配前端 - 'type': row[0], # 保留type用于兼容性 - 'count': row[1], - 'frequency': row[1], # 添加frequency字段用于前端显示 - 'context': 'general', - 'environment': 'general' - }) - - # 如果没有语言模式数据 - if not language_patterns: - language_patterns = [{ - 'style': '暂无语言风格数据', - 'type': '暂无语言风格数据', - 'count': 0, - 'frequency': 0, - 'context': 'general', - 'environment': 'general' - }] - - # 话题偏好分析(基于群组活跃度和智能主题识别) - topic_preferences = [] - - # 获取各个群组的消息数据进行主题分析 - await cursor.execute(''' - SELECT - group_id, - COUNT(*) as message_count, - AVG(LENGTH(message)) as avg_length - FROM raw_messages - WHERE group_id IS NOT NULL AND LENGTH(TRIM(message)) > 3 - GROUP BY group_id - HAVING COUNT(*) > 10 - ORDER BY message_count DESC - LIMIT 8 - ''') - - group_data = await cursor.fetchall() - - # 先收集所有group_data,避免嵌套查询 - for row in group_data: - try: - # 添加行数据验证 - if len(row) < 3: - self._logger.warning(f"群组话题数据行不完整 (期望3个字段,实际{len(row)}个),跳过: {row}") - continue - - group_id = row[0] - message_count = int(row[1]) if row[1] else 0 - avg_length = float(row[2]) if row[2] else 0 - - # 创建新的cursor来执行嵌套查询(避免cursor状态冲突) - async with self.get_db_connection() as nested_conn: - nested_cursor = await nested_conn.cursor() - - # 获取该群组的代表性消息进行主题分析 - await nested_cursor.execute(''' - SELECT message - FROM raw_messages - WHERE group_id = ? 
AND LENGTH(TRIM(message)) > 5 AND LENGTH(TRIM(message)) < 200 - ORDER BY LENGTH(message) DESC, timestamp DESC - LIMIT 20 - ''', (group_id,)) - - messages = await nested_cursor.fetchall() - await nested_cursor.close() - - if not messages: - continue - - # 智能主题识别 - topic_analysis = self._analyze_topic_from_messages([msg[0] for msg in messages]) - topic_name = topic_analysis['topic'] - conversation_style = topic_analysis['style'] - - # 根据消息长度和数量推断兴趣度 - interest_level = min(100, max(10, (message_count * avg_length) / 50)) - - topic_preferences.append({ - 'topic': topic_name, - 'style': conversation_style, - 'interest_level': round(interest_level, 1) - }) - except Exception as row_error: - self._logger.warning(f"处理群组话题数据行时出错,跳过: {row_error}, row: {row if 'row' in locals() and len(str(row)) < 100 else 'row too long'}") - continue - - # 去重:确保每个话题只出现一次,保留兴趣度最高的 - seen_topics = {} - for pref in topic_preferences: - try: - topic = pref['topic'] - # 确保 interest_level 是数字类型 - current_interest = float(pref.get('interest_level', 0)) - pref['interest_level'] = current_interest - - if topic not in seen_topics: - seen_topics[topic] = pref - else: - existing_interest = float(seen_topics[topic].get('interest_level', 0)) - if current_interest > existing_interest: - seen_topics[topic] = pref - except (ValueError, TypeError, KeyError) as e: - self._logger.warning(f"处理话题偏好时出错,跳过: {e}, pref: {pref}") - - topic_preferences = list(seen_topics.values()) - - # 如果没有话题偏好数据 - if not topic_preferences: - topic_preferences = [{ - 'topic': '暂无话题数据', - 'style': '等待中', - 'interest_level': 0.0 - }] - - return { - 'emotion_patterns': emotion_patterns, - 'language_patterns': language_patterns, - 'topic_preferences': topic_preferences - } - - except Exception as e: - self._logger.error(f"获取学习模式数据失败: {e}") - return { - 'emotion_patterns': [ - {'pattern': '数据获取失败,请检查系统状态', 'confidence': 0, 'frequency': 0} - ], - 'language_patterns': [ - {'type': '数据获取失败', 'count': 0, 'environment': 'general'} - ], - 
'topic_preferences': [ - {'topic': '数据获取失败', 'style': 'normal', 'interest_level': 0} - ] - } - finally: - if 'cursor' in locals(): - await cursor.close() - - async def get_expression_patterns_for_webui(self, limit: int = 20) -> List[Dict[str, Any]]: - """获取表达模式数据用于WebUI显示""" - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 检查表是否存在 - await cursor.execute(''' - SELECT name FROM sqlite_master - WHERE type='table' AND name='expression_patterns' - ''') - - table_exists = await cursor.fetchone() - if not table_exists: - self._logger.debug("expression_patterns表不存在") - return [] - - # 获取表达模式数据 - await cursor.execute(''' - SELECT situation, expression, weight, last_active_time, group_id - FROM expression_patterns - ORDER BY weight DESC, last_active_time DESC - LIMIT ? - ''', (limit,)) - - patterns = [] - for row in await cursor.fetchall(): - try: - # 添加行数据验证 - if len(row) < 5: - self._logger.warning(f"表达模式行数据不完整 (期望5个字段,实际{len(row)}个),跳过: {row}") - continue - - patterns.append({ - 'situation': row[0], - 'expression': row[1], - 'weight': float(row[2]) if row[2] else 0.0, - 'last_active_time': row[3], - 'group_id': row[4] - }) - except Exception as row_error: - self._logger.warning(f"处理表达模式行时出错,跳过: {row_error}, row: {row}") - continue - - return patterns - - except Exception as e: - self._logger.error(f"获取表达模式失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def create_style_learning_review(self, review_data: Dict[str, Any]) -> int: - """创建对话风格学习审查记录""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 确保审查表存在 - await self._ensure_style_review_table_exists(cursor) - - # 插入审查记录 - await cursor.execute(''' - INSERT INTO style_learning_reviews - (type, group_id, timestamp, learned_patterns, few_shots_content, status, description) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- ''', ( - review_data['type'], - review_data['group_id'], - review_data['timestamp'], - json.dumps(review_data['learned_patterns'], ensure_ascii=False), - review_data['few_shots_content'], - review_data['status'], - review_data['description'] - )) - - review_id = cursor.lastrowid - await conn.commit() - - self._logger.info(f"创建风格学习审查记录成功,ID: {review_id}") - return review_id - - except Exception as e: - self._logger.error(f"创建风格学习审查记录失败: {e}") - raise DataStorageError(f"创建风格学习审查记录失败: {str(e)}") - - async def _ensure_style_review_table_exists(self, cursor): - """确保风格学习审查表存在""" - # 根据数据库类型选择不同的 DDL - if self.config.db_type.lower() == 'mysql': - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS style_learning_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - type VARCHAR(100) NOT NULL, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - status VARCHAR(50) DEFAULT 'pending', - description TEXT, - reviewer_comment TEXT, - review_time DOUBLE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - INDEX idx_status (status), - INDEX idx_group (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - - # ✅ 数据库迁移:添加缺失的字段(如果表已存在但缺少这些字段) - try: - # 检查并添加 reviewer_comment 字段 - await cursor.execute(''' - SELECT COUNT(*) - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = DATABASE() - AND TABLE_NAME = 'style_learning_reviews' - AND COLUMN_NAME = 'reviewer_comment' - ''') - if (await cursor.fetchone())[0] == 0: - await cursor.execute('ALTER TABLE style_learning_reviews ADD COLUMN reviewer_comment TEXT') - self._logger.info("✅ 迁移:已添加 reviewer_comment 字段到 style_learning_reviews 表") - - # 检查并添加 review_time 字段 - await cursor.execute(''' - SELECT COUNT(*) - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = DATABASE() - AND TABLE_NAME = 'style_learning_reviews' - AND 
COLUMN_NAME = 'review_time' - ''') - if (await cursor.fetchone())[0] == 0: - await cursor.execute('ALTER TABLE style_learning_reviews ADD COLUMN review_time DOUBLE') - self._logger.info("✅ 迁移:已添加 review_time 字段到 style_learning_reviews 表") - except Exception as migration_error: - self._logger.warning(f"数据库迁移检查失败(可能是非 MySQL 数据库): {migration_error}") - else: - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS style_learning_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - type TEXT NOT NULL, - group_id TEXT NOT NULL, - timestamp REAL NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - status TEXT DEFAULT 'pending', - description TEXT, - reviewer_comment TEXT, - review_time REAL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - ''') - - # ✅ SQLite 数据库迁移:添加缺失的字段 - try: - # 检查表结构 - await cursor.execute("PRAGMA table_info(style_learning_reviews)") - columns = {row[1] for row in await cursor.fetchall()} - - # 添加 reviewer_comment 字段(如果不存在) - if 'reviewer_comment' not in columns: - await cursor.execute('ALTER TABLE style_learning_reviews ADD COLUMN reviewer_comment TEXT') - self._logger.info("✅ 迁移:已添加 reviewer_comment 字段到 style_learning_reviews 表 (SQLite)") - - # 添加 review_time 字段(如果不存在) - if 'review_time' not in columns: - await cursor.execute('ALTER TABLE style_learning_reviews ADD COLUMN review_time REAL') - self._logger.info("✅ 迁移:已添加 review_time 字段到 style_learning_reviews 表 (SQLite)") - except Exception as migration_error: - self._logger.warning(f"SQLite 数据库迁移失败: {migration_error}") - - # 注意:get_pending_style_reviews 方法已在上面定义(约1456行),这里删除重复定义 - # 第一个版本是正确的,第二个版本有async with缩进bug - - async def get_pending_persona_learning_reviews(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取待审查的人格学习记录(质量不达标的学习结果)""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_pending_persona_learning_reviews_orm(limit) - try: - async with self.get_db_connection() as conn: - cursor = await 
conn.cursor() - - # 确保表存在(根据数据库类型使用不同的DDL) - if self.config.db_type.lower() == 'mysql': - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - timestamp DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(100) NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score DOUBLE, - reason TEXT, - status VARCHAR(50) NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time DOUBLE, - INDEX idx_status (status), - INDEX idx_group_id (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - else: - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp REAL NOT NULL, - group_id TEXT NOT NULL, - update_type TEXT NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, -- 建议的新内容(兼容字段) - confidence_score REAL, -- 置信度得分 - reason TEXT, - status TEXT NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time REAL - ) - ''') - - # 尝试添加metadata列(如果表已存在但没有此列) - try: - await cursor.execute('ALTER TABLE persona_update_reviews ADD COLUMN metadata TEXT') - except: - pass # 列已存在 - - await cursor.execute(''' - SELECT id, timestamp, group_id, update_type, original_content, - new_content, proposed_content, confidence_score, reason, status, - reviewer_comment, review_time, metadata - FROM persona_update_reviews - WHERE status = 'pending' - ORDER BY timestamp DESC - LIMIT ? 
- ''', (limit,)) - - reviews = [] - import json - for row in await cursor.fetchall(): - # 确保有proposed_content字段,如果为空则使用new_content - proposed_content = row[6] if row[6] else row[5] # proposed_content或new_content - confidence_score = row[7] if row[7] is not None else 0.5 # 使用数据库中的置信度 - - # 解析metadata JSON - metadata = {} - if row[12]: # metadata字段 - try: - metadata = json.loads(row[12]) - except: - metadata = {} - - reviews.append({ - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'proposed_content': proposed_content, - 'confidence_score': confidence_score, - 'reason': row[8], - 'status': row[9], - 'reviewer_comment': row[10], - 'review_time': row[11], - 'metadata': metadata # 添加metadata字段 - }) - - return reviews - - except Exception as e: - self._logger.error(f"获取待审查人格学习记录失败: {e}") - return [] - - async def update_persona_learning_review_status(self, review_id: int, status: str, comment: str = None, modified_content: str = None) -> bool: - """更新人格学习审查状态(使用 ORM,支持跨事件循环)""" - try: - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法更新人格学习审查状态") - return False - - from ..models.orm.learning import PersonaLearningReview - - async with self.db_engine.get_session() as session: - review = await session.get(PersonaLearningReview, review_id) - if not review: - self._logger.warning(f"未找到人格学习审查记录,ID: {review_id}") - return False - - review.status = status - review.reviewer_comment = comment - review.review_time = time.time() - - if modified_content: - review.proposed_content = modified_content - review.new_content = modified_content - - await session.commit() - self._logger.info(f"人格学习审查状态已更新,ID: {review_id}, 状态: {status}") - return True - - except Exception as e: - self._logger.error(f"更新人格学习审查状态失败: {e}") - return False - - async def delete_persona_learning_review_by_id(self, review_id: int) -> bool: - """删除指定ID的人格学习审查记录""" - # 优先使用 ORM(支持跨事件循环) - if 
self.db_engine: - return await self.delete_persona_learning_review_by_id_orm(review_id) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 根据数据库类型使用不同的占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' - - # 删除审查记录 - await cursor.execute(f''' - DELETE FROM persona_update_reviews WHERE id = {placeholder} - ''', (review_id,)) - - await conn.commit() - deleted_count = cursor.rowcount - - if deleted_count > 0: - self._logger.info(f"成功删除人格学习审查记录,ID: {review_id}") - return True - else: - self._logger.warning(f"未找到要删除的人格学习审查记录,ID: {review_id}") - return False - - except Exception as e: - self._logger.error(f"删除人格学习审查记录失败: {e}") - return False - - async def delete_all_persona_learning_reviews(self, group_id: Optional[str] = None) -> int: - """ - 批量删除人格学习审查记录 - - Args: - group_id: 群组ID(可选),如果指定则只删除该群组的记录,否则删除所有记录 - - Returns: - int: 删除的记录数量 - """ - try: - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - from ..models.orm.learning import PersonaLearningReview - from sqlalchemy import delete as sa_delete - - async with self.db_engine.get_session() as session: - if group_id: - stmt = sa_delete(PersonaLearningReview).where(PersonaLearningReview.group_id == group_id) - self._logger.info(f"删除群组 {group_id} 的所有人格学习审查记录") - else: - stmt = sa_delete(PersonaLearningReview) - self._logger.info("删除所有人格学习审查记录") - - result = await session.execute(stmt) - await session.commit() - deleted_count = result.rowcount - self._logger.info(f"成功删除 {deleted_count} 条人格学习审查记录") - return deleted_count - - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 根据数据库类型使用不同的占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' 
- - if group_id: - # 删除指定群组的审查记录 - await cursor.execute(f''' - DELETE FROM persona_update_reviews WHERE group_id = {placeholder} - ''', (group_id,)) - self._logger.info(f"删除群组 {group_id} 的所有人格学习审查记录") - else: - # 删除所有审查记录 - await cursor.execute(''' - DELETE FROM persona_update_reviews - ''') - self._logger.info("删除所有人格学习审查记录") - - await conn.commit() - deleted_count = cursor.rowcount - - self._logger.info(f"✅ 成功删除 {deleted_count} 条人格学习审查记录") - return deleted_count - - except Exception as e: - self._logger.error(f"批量删除人格学习审查记录失败: {e}") - return 0 - - async def get_persona_learning_review_by_id(self, review_id: int) -> Optional[Dict[str, Any]]: - """获取指定ID的人格学习审查记录详情""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_persona_learning_review_by_id_orm(review_id) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT id, group_id, original_content, new_content, proposed_content, - confidence_score, reason, status, reviewer_comment, review_time, timestamp - FROM persona_update_reviews - WHERE id = ? 
- ''', (review_id,)) - - row = await cursor.fetchone() - if row: - return { - 'id': row[0], - 'group_id': row[1], - 'original_content': row[2], - 'new_content': row[3], - 'proposed_content': row[4] if row[4] else row[3], # proposed_content或new_content - 'confidence_score': row[5] if row[5] is not None else 0.5, - 'reason': row[6], - 'status': row[7], - 'reviewer_comment': row[8], - 'review_time': row[9], - 'timestamp': row[10] - } - return None - - except Exception as e: - self._logger.error(f"获取人格学习审查记录失败: {e}") - return None - - async def save_style_learning_record(self, record_data: Dict[str, Any]) -> bool: - """保存风格学习记录到数据库""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - INSERT INTO style_learning_records - (style_type, learned_patterns, confidence_score, sample_count, group_id, learning_time) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - record_data.get('style_type'), - record_data.get('learned_patterns'), - record_data.get('confidence_score'), - record_data.get('sample_count'), - record_data.get('group_id'), - record_data.get('learning_time') - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存风格学习记录失败: {e}") - return False - - async def save_language_style_pattern(self, pattern_data: Dict[str, Any]) -> bool: - """保存语言风格模式到数据库""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 先检查是否已存在相同的语言风格 - await cursor.execute(''' - SELECT id FROM language_style_patterns - WHERE language_style = ? AND group_id = ? - ''', (pattern_data.get('language_style'), pattern_data.get('group_id'))) - - existing = await cursor.fetchone() - - if existing: - # 更新现有记录 - await cursor.execute(''' - UPDATE language_style_patterns - SET example_phrases = ?, usage_frequency = ?, context_type = ?, last_updated = ? - WHERE id = ? 
- ''', ( - pattern_data.get('example_phrases'), - pattern_data.get('usage_frequency'), - pattern_data.get('context_type'), - pattern_data.get('last_updated'), - existing[0] - )) - else: - # 插入新记录 - await cursor.execute(''' - INSERT INTO language_style_patterns - (language_style, example_phrases, usage_frequency, context_type, group_id, last_updated) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - pattern_data.get('language_style'), - pattern_data.get('example_phrases'), - pattern_data.get('usage_frequency'), - pattern_data.get('context_type'), - pattern_data.get('group_id'), - pattern_data.get('last_updated') - )) - - await conn.commit() - return True - - except Exception as e: - self._logger.error(f"保存语言风格模式失败: {e}") - return False - - async def get_reviewed_persona_learning_updates(self, limit: int = 50, offset: int = 0, status_filter: str = None) -> List[Dict[str, Any]]: - """获取已审查的人格学习更新记录""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_reviewed_persona_learning_updates_orm(limit, offset, status_filter) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 构建查询条件 - where_clause = "WHERE status != 'pending'" - params = [] - - if status_filter: - where_clause += " AND status = ?" 
- params.append(status_filter) - - # 首先检查表是否存在并获取表结构 - await cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='persona_update_reviews'") - table_exists = await cursor.fetchone() - - if not table_exists: - self._logger.info("persona_update_reviews表不存在,返回空列表") - return [] - - # 检查表结构,确定正确的字段名 - await cursor.execute("PRAGMA table_info(persona_update_reviews)") - columns = await cursor.fetchall() - column_names = [col[1] for col in columns] - - # 根据实际的列名构建查询 - if 'proposed_content' in column_names: - content_field = 'proposed_content' - elif 'new_content' in column_names: - content_field = 'new_content' - else: - # 如果两个字段都不存在,使用原始内容 - content_field = 'original_content' - - # 检查是否有metadata列 - has_metadata = 'metadata' in column_names - - # 使用实际存在的字段进行查询,并处理NULL值 - metadata_field = ', metadata' if has_metadata else '' - await cursor.execute(f''' - SELECT id, group_id, original_content, {content_field}, reason, - status, reviewer_comment, review_time, timestamp{metadata_field} - FROM persona_update_reviews - {where_clause} - ORDER BY COALESCE(review_time, timestamp) DESC - LIMIT ? OFFSET ? 
- ''', params + [limit, offset]) - - rows = await cursor.fetchall() - updates = [] - - import json - for row in rows: - # 解析metadata(如果存在) - metadata = {} - if has_metadata and len(row) > 9 and row[9]: - try: - metadata = json.loads(row[9]) - except: - metadata = {} - - updates.append({ - 'id': f"persona_learning_{row[0]}", - 'group_id': row[1] or 'default', - 'original_content': row[2] or '', - 'proposed_content': row[3] or '', # 使用实际存在的字段 - 'reason': row[4] or '人格学习更新', - 'confidence_score': metadata.get('confidence_score', 0.8), # 从metadata获取或使用默认值 - 'status': row[5], - 'reviewer_comment': row[6] or '', - 'review_time': row[7] if row[7] else 0, - 'timestamp': row[8] if row[8] else 0, - 'update_type': 'persona_learning_review', - # 添加metadata中的关键字段 - 'features_content': metadata.get('features_content', ''), - 'llm_response': metadata.get('llm_response', ''), - 'total_raw_messages': metadata.get('total_raw_messages', 0), - 'messages_analyzed': metadata.get('messages_analyzed', 0), - 'metadata': metadata - }) - - return updates - - except Exception as e: - self._logger.error(f"获取已审查人格学习记录失败: {e}") - # 如果是表或列不存在的错误,返回空列表 - if "no such table" in str(e).lower() or "no such column" in str(e).lower(): - self._logger.info("人格学习审查表或字段不存在,返回空列表") - return [] - return [] - - async def get_reviewed_style_learning_updates(self, limit: int = 50, offset: int = 0, status_filter: str = None) -> List[Dict[str, Any]]: - """获取已审查的风格学习更新记录""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.get_reviewed_style_learning_updates_orm(limit, offset, status_filter) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 构建查询条件 - where_clause = "WHERE status != 'pending'" - params = [] - - if status_filter: - where_clause += " AND status = ?" 
- params.append(status_filter) - - # 使用正确的字段名,没有review_time字段,使用updated_at,并处理NULL值 - await cursor.execute(f''' - SELECT id, type, group_id, timestamp, learned_patterns, status, updated_at, description - FROM style_learning_reviews - {where_clause} - ORDER BY COALESCE(updated_at, timestamp) DESC - LIMIT ? OFFSET ? - ''', params + [limit, offset]) - - rows = await cursor.fetchall() - updates = [] - - for row in rows: - # 添加行数据验证 - try: - if len(row) < 8: - self._logger.warning(f"风格学习记录行数据不完整,跳过: {row}") - continue - - # 尝试解析learned_patterns以获取更多信息 - try: - learned_patterns = json.loads(row[4]) if row[4] else {} - reason = learned_patterns.get('reason', '风格学习更新') - original_content = learned_patterns.get('original_content', '原始风格特征') - proposed_content = learned_patterns.get('proposed_content', row[4]) # 使用完整的learned_patterns作为proposed_content - confidence_score = learned_patterns.get('confidence_score', 0.8) - except (json.JSONDecodeError, AttributeError): - reason = row[7] if len(row) > 7 and row[7] else '风格学习更新' # 使用description字段 - original_content = '原始风格特征' - proposed_content = row[4] if len(row) > 4 and row[4] else '无内容' - confidence_score = 0.8 - - updates.append({ - 'id': row[0], - 'group_id': row[2], - 'original_content': original_content, - 'proposed_content': proposed_content, - 'reason': reason, - 'confidence_score': confidence_score, - 'status': row[5], - 'reviewer_comment': '', # 风格审查没有备注字段 - 'review_time': row[6] if len(row) > 6 else None, # 使用updated_at字段 - 'timestamp': row[3], - 'update_type': f'style_learning_{row[1]}' - }) - except Exception as row_error: - self._logger.warning(f"处理风格学习记录行时出错,跳过: {row_error}, row: {row if len(row) < 20 else 'too long'}") - - return updates - - except Exception as e: - self._logger.error(f"获取已审查风格学习记录失败: {e}") - # 如果表不存在,返回空列表 - if "no such table" in str(e).lower(): - self._logger.info("风格学习审查表不存在,返回空列表") - return [] - return [] - - async def get_reviewed_persona_update_records(self, limit: int = 50, offset: int = 
0, status_filter: str = None) -> List[Dict[str, Any]]: - """获取已审查的传统人格更新记录""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 构建查询条件 - where_clause = "WHERE status != 'pending'" - params = [] - - if status_filter: - where_clause += " AND status = ?" - params.append(status_filter) - - query = f''' - SELECT id, timestamp, group_id, update_type, original_content, new_content, - reason, status, reviewer_comment, review_time - FROM persona_update_records - {where_clause} - ORDER BY COALESCE(review_time, timestamp) DESC - LIMIT ? OFFSET ? - ''' - - self._logger.debug(f"执行人格更新记录查询: params={params + [limit, offset]}") - await cursor.execute(query, params + [limit, offset]) - - rows = await cursor.fetchall() - records = [] - - for row in rows: - # 添加行数据验证 - try: - if len(row) < 10: - self._logger.warning(f"人格更新记录行数据不完整 (期望10个字段,实际{len(row)}个),跳过: {row}") - continue - - records.append({ - 'id': row[0], - 'timestamp': row[1], - 'group_id': row[2], - 'update_type': row[3], - 'original_content': row[4], - 'new_content': row[5], - 'reason': row[6], - 'status': row[7], - 'reviewer_comment': row[8] if row[8] else '', - 'review_time': row[9] - }) - except Exception as row_error: - self._logger.warning(f"处理人格更新记录行时出错,跳过: {row_error}, row: {row if len(row) < 20 else 'too long'}") - - return records - - except Exception as e: - self._logger.error(f"获取已审查传统人格更新记录失败: {e}") - return [] - - async def update_style_review_status(self, review_id: int, status: str, group_id: str = None) -> bool: - """更新风格学习审查状态""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.update_style_review_status_orm(review_id, status, group_id) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - UPDATE style_learning_reviews - SET status = ?, updated_at = ? - WHERE id = ? 
- ''', (status, time.time(), review_id)) - - await conn.commit() - - if cursor.rowcount > 0: - self._logger.info(f"更新风格学习审查状态成功: ID={review_id}, 状态={status}") - return True - else: - self._logger.warning(f"更新风格学习审查状态失败: 未找到ID={review_id}的记录") - return False - - except Exception as e: - self._logger.error(f"更新风格学习审查状态失败: {e}") - return False - - async def delete_style_review_by_id(self, review_id: int) -> bool: - """删除指定ID的风格学习审查记录""" - # 优先使用 ORM(支持跨事件循环) - if self.db_engine: - return await self.delete_style_review_by_id_orm(review_id) - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 删除审查记录 - await cursor.execute(''' - DELETE FROM style_learning_reviews WHERE id = ? - ''', (review_id,)) - - await conn.commit() - deleted_count = cursor.rowcount - - await cursor.close() - - if deleted_count > 0: - self._logger.info(f"成功删除风格学习审查记录,ID: {review_id}") - return True - else: - self._logger.warning(f"未找到要删除的风格学习审查记录,ID: {review_id}") - return False - - except Exception as e: - self._logger.error(f"删除风格学习审查记录失败: {e}") - return False - - async def get_detailed_metrics(self) -> Dict[str, Any]: - """获取详细性能监控数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # API指标(基于学习批次的执行时间) - # ✅ 修复:使用数据库无关的时间格式化方式 - if self.config.db_type == 'sqlite': # ✅ 修正:self.db_type → self.config.db_type - # SQLite语法 - await cursor.execute(''' - SELECT - strftime('%H', datetime(start_time, 'unixepoch')) as hour, - AVG((CASE WHEN end_time IS NOT NULL THEN end_time - start_time ELSE 0 END)) as avg_response_time - FROM learning_batches - WHERE start_time > ? 
AND end_time IS NOT NULL - GROUP BY hour - ORDER BY hour - ''', (time.time() - 86400,)) - else: - # MySQL语法 - await cursor.execute(''' - SELECT - HOUR(FROM_UNIXTIME(start_time)) as hour, - AVG((CASE WHEN end_time IS NOT NULL THEN end_time - start_time ELSE 0 END)) as avg_response_time - FROM learning_batches - WHERE start_time > %s AND end_time IS NOT NULL - GROUP BY hour - ORDER BY hour - ''', (time.time() - 86400,)) - - api_hours = [] - api_response_times = [] - for row in await cursor.fetchall(): - api_hours.append(f"{row[0]}:00") - api_response_times.append(round(row[1] * 1000, 2)) # 转换为毫秒 - - # 数据库表统计 - tables_to_check = ['raw_messages', 'filtered_messages', 'learning_batches', 'persona_update_records'] - table_stats = {} - - for table in tables_to_check: - try: - await cursor.execute(f'SELECT COUNT(*) FROM {table}') - count = await cursor.fetchone() - table_stats[table] = count[0] if count else 0 - except Exception as table_error: - self._logger.debug(f"无法获取表 {table} 统计: {table_error}") - table_stats[table] = 0 - - # 系统指标 - import psutil - try: - memory = psutil.virtual_memory() - # 在Windows上使用主驱动器 - disk_path = 'C:\\' if os.name == 'nt' else '/' - disk = psutil.disk_usage(disk_path) - - system_metrics = { - 'memory_percent': memory.percent, - 'cpu_percent': psutil.cpu_percent(), - 'disk_percent': round(disk.used / disk.total * 100, 2) - } - except Exception as system_error: - self._logger.warning(f"获取系统指标失败: {system_error}") - system_metrics = { - 'memory_percent': 0, - 'cpu_percent': 0, - 'disk_percent': 0 - } - - return { - 'api_metrics': { - 'hours': api_hours, - 'response_times': api_response_times - }, - 'database_metrics': { - 'table_stats': table_stats - }, - 'system_metrics': system_metrics - } - - except Exception as e: - self._logger.error(f"获取详细监控数据失败: {e}") - return { - 'api_metrics': { - 'hours': [], - 'response_times': [] - }, - 'database_metrics': { - 'table_stats': {} - }, - 'system_metrics': { - 'memory_percent': 0, - 'cpu_percent': 0, - 
'disk_percent': 0 - } - } - - async def get_trends_data(self) -> Dict[str, Any]: - """获取指标趋势数据""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 计算7天和30天前的时间戳 - now = time.time() - week_ago = now - (7 * 24 * 3600) - month_ago = now - (30 * 24 * 3600) - - # 消息增长趋势 - await cursor.execute(''' - SELECT - COUNT(CASE WHEN timestamp > ? THEN 1 END) as week_count, - COUNT(CASE WHEN timestamp > ? THEN 1 END) as month_count, - COUNT(*) as total_count - FROM raw_messages - ''', (week_ago, month_ago)) - - message_stats = await cursor.fetchone() - if message_stats and len(message_stats) >= 3: - week_messages = int(message_stats[0]) if message_stats[0] else 0 - month_messages = int(message_stats[1]) if message_stats[1] else 0 - total_messages = int(message_stats[2]) if message_stats[2] else 0 - - # 计算增长率 - if month_messages > week_messages: - message_growth = ((week_messages * 4 - (month_messages - week_messages)) / (month_messages - week_messages) * 100) if (month_messages - week_messages) > 0 else 0 - else: - message_growth = 0 - elif message_stats: - self._logger.warning(f"消息统计数据行不完整 (期望3个字段,实际{len(message_stats)}个): {message_stats}") - message_growth = 0 - week_messages = 0 - month_messages = 0 - total_messages = 0 - else: - message_growth = 0 - week_messages = 0 - month_messages = 0 - total_messages = 0 - - # 筛选消息增长趋势 - await cursor.execute(''' - SELECT - COUNT(CASE WHEN timestamp > ? THEN 1 END) as week_filtered, - COUNT(CASE WHEN timestamp > ? 
THEN 1 END) as month_filtered - FROM filtered_messages - ''', (week_ago, month_ago)) - - filtered_stats = await cursor.fetchone() - if filtered_stats and len(filtered_stats) >= 2: - week_filtered = int(filtered_stats[0]) if filtered_stats[0] else 0 - month_filtered = int(filtered_stats[1]) if filtered_stats[1] else 0 - - # 计算增长率 - if month_filtered > week_filtered: - filtered_growth = ((week_filtered * 4 - (month_filtered - week_filtered)) / (month_filtered - week_filtered) * 100) if (month_filtered - week_filtered) > 0 else 0 - else: - filtered_growth = 0 - elif filtered_stats: - self._logger.warning(f"筛选消息统计数据行不完整 (期望2个字段,实际{len(filtered_stats)}个): {filtered_stats}") - week_filtered = 0 - month_filtered = 0 - filtered_growth = 0 - else: - week_filtered = 0 - month_filtered = 0 - filtered_growth = 0 - - # LLM调用增长(基于学习批次) - await cursor.execute(''' - SELECT - COUNT(CASE WHEN start_time > ? THEN 1 END) as week_sessions, - COUNT(CASE WHEN start_time > ? THEN 1 END) as month_sessions - FROM learning_batches - ''', (week_ago, month_ago)) - - session_stats = await cursor.fetchone() - if session_stats and len(session_stats) >= 2: - week_sessions = int(session_stats[0]) if session_stats[0] else 0 - month_sessions = int(session_stats[1]) if session_stats[1] else 0 - - # 计算增长率 - if month_sessions > week_sessions: - sessions_growth = ((week_sessions * 4 - (month_sessions - week_sessions)) / (month_sessions - week_sessions) * 100) if (month_sessions - week_sessions) > 0 else 0 - else: - sessions_growth = 0 - elif session_stats: - self._logger.warning(f"学习批次统计数据行不完整 (期望2个字段,实际{len(session_stats)}个): {session_stats}") - week_sessions = 0 - month_sessions = 0 - sessions_growth = 0 - else: - week_sessions = 0 - month_sessions = 0 - sessions_growth = 0 - - return { - 'message_growth': round(message_growth, 1), - 'filtered_growth': round(filtered_growth, 1), - 'llm_growth': round(sessions_growth, 1), - 'sessions_growth': round(sessions_growth, 1) - } - - except Exception as e: - 
self._logger.error(f"获取趋势数据失败: {e}") - return { - 'message_growth': 0, - 'filtered_growth': 0, - 'llm_growth': 0, - 'sessions_growth': 0 - } - - def _analyze_topic_from_messages(self, messages: List[str]) -> Dict[str, str]: - """ - 基于消息内容智能分析群聊主题 - - Args: - messages: 消息列表 - - Returns: - 包含topic和style的字典 - """ - try: - if not messages: - return {'topic': '空群聊', 'style': 'unknown'} - - # 合并所有消息文本 - all_text = ' '.join(messages).lower() - - # 定义主题关键词库 - topic_keywords = { - '技术讨论': ['代码', '编程', 'python', 'java', 'javascript', 'bug', '算法', '开发', '前端', '后端', 'api', '数据库', 'sql', 'git', '项目', '需求', '测试', '部署'], - '游戏娱乐': ['游戏', '玩家', '攻略', '装备', '副本', '公会', 'pvp', '角色', '技能', '等级', '经验', '任务', '活动', '充值', '抽卡', '开黑', '上分'], - '学习交流': ['学习', '作业', '考试', '复习', '笔记', '课程', '老师', '同学', '知识', '问题', '答案', '教程', '资料', '书籍', '论文', '研究'], - '工作协作': ['工作', '会议', '项目', '任务', '进度', '汇报', '客户', '合作', '团队', '领导', '同事', '业务', '方案', '文档', '流程', '审批'], - '生活日常': ['吃饭', '睡觉', '天气', '心情', '家人', '朋友', '购物', '电影', '音乐', '旅游', '美食', '健康', '运动', '休息', '周末'], - '兴趣爱好': ['摄影', '绘画', '音乐', '电影', '书籍', '旅行', '美食', '运动', '健身', '瑜伽', '跑步', '骑行', '爬山', '游泳', '篮球'], - '商务合作': ['合作', '商务', '业务', '客户', '项目', '方案', '报价', '合同', '付款', '发票', '产品', '服务', '市场', '销售', '推广'], - '技术支持': ['问题', '故障', '错误', '修复', '解决', '帮助', '支持', '教程', '指导', '操作', '配置', '安装', '更新', '维护', '优化'], - '闲聊灌水': ['哈哈', '嘿嘿', '😂', '😄', '笑死', '有趣', '无聊', '随便', '聊天', '扯淡', '吐槽', '搞笑', '段子', '表情', '发呆'], - '通知公告': ['通知', '公告', '重要', '注意', '提醒', '截止', '时间', '安排', '活动', '报名', '参加', '会议', '培训', '讲座', '活动'] - } - - # 分析主题匹配度 - topic_scores = {} - for topic, keywords in topic_keywords.items(): - score = 0 - for keyword in keywords: - score += all_text.count(keyword) - topic_scores[topic] = score - - # 获取得分最高的主题 - best_topic = max(topic_scores.items(), key=lambda x: x[1]) - - if best_topic[1] == 0: # 没有匹配到任何关键词 - return {'topic': '综合聊天', 'style': '日常对话'} - - # 根据主题确定对话风格 - style_mapping = { - '技术讨论': '技术交流', - '游戏娱乐': '轻松娱乐', - '学习交流': '学术讨论', - 
'工作协作': '工作协调', - '生活日常': '日常闲聊', - '兴趣爱好': '兴趣分享', - '商务合作': '商务沟通', - '技术支持': '技术答疑', - '闲聊灌水': '轻松聊天', - '通知公告': '信息通知' - } - - topic = best_topic[0] - style = style_mapping.get(topic, '日常对话') - - return { - 'topic': topic, - 'style': style - } - - except Exception as e: - self._logger.error(f"主题分析失败: {e}") - return {'topic': '未知主题', 'style': '日常对话'} - - async def get_recent_learning_batches(self, limit: int = 10) -> List[Dict[str, Any]]: - """获取最近的学习批次记录""" - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT id, group_id, start_time, end_time, quality_score, - processed_messages, batch_name, message_count, - filtered_count, success, error_message - FROM learning_batches - ORDER BY start_time DESC - LIMIT ? - ''', (limit,)) - - batches = [] - for row in await cursor.fetchall(): - try: - # 添加行数据验证 - if len(row) < 11: - self._logger.warning(f"学习批次记录行数据不完整 (期望11个字段,实际{len(row)}个),跳过: {row}") - continue - - batches.append({ - 'id': int(row[0]) if row[0] else 0, - 'group_id': row[1], - 'start_time': float(row[2]) if row[2] else 0, - 'end_time': float(row[3]) if row[3] else 0, - 'quality_score': float(row[4]) if row[4] else 0, - 'processed_messages': int(row[5]) if row[5] else 0, - 'batch_name': row[6], - 'message_count': int(row[7]) if row[7] else 0, - 'filtered_count': int(row[8]) if row[8] else 0, - 'success': bool(row[9]) if row[9] is not None else False, - 'error_message': row[10] - }) - except Exception as row_error: - self._logger.warning(f"处理学习批次记录行时出错,跳过: {row_error}, row: {row if len(str(row)) < 100 else 'row too long'}") - continue - - return batches - - except Exception as e: - self._logger.error(f"获取学习批次记录失败: {e}") - return [] - - async def add_persona_learning_review( - self, - group_id: str, - proposed_content: str, - learning_source: str = UPDATE_TYPE_EXPRESSION_LEARNING, # ✅ 使用常量作为默认值 - confidence_score: float = 0.5, - raw_analysis: str = "", - metadata: Dict[str, Any] = None, - 
original_content: str = "", # ✅ 新增:原人格完整文本 - new_content: str = "" # ✅ 新增:新人格完整文本(原人格+增量) - ) -> int: - """添加人格学习审查记录 - - Args: - group_id: 群组ID - proposed_content: 建议的增量人格内容 - learning_source: 学习来源 - confidence_score: 置信度分数 - raw_analysis: 原始分析结果 - metadata: 元数据(包含features_content, llm_response, sample counts等) - original_content: 原人格完整文本(用于前端显示对比) - new_content: 新人格完整文本(原人格+增量,用于前端高亮显示) - - Returns: - 插入记录的ID - """ - try: - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - # 确保表存在并添加metadata列 - # 根据数据库类型使用不同的DDL - if self.config.db_type.lower() == 'mysql': - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - timestamp DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(100) NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score DOUBLE, - reason TEXT, - status VARCHAR(50) NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time DOUBLE, - metadata JSON, - INDEX idx_group_id (group_id), - INDEX idx_status (status), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - ''') - else: - await cursor.execute(''' - CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp REAL NOT NULL, - group_id TEXT NOT NULL, - update_type TEXT NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score REAL, - reason TEXT, - status TEXT NOT NULL DEFAULT 'pending', - reviewer_comment TEXT, - review_time REAL, - metadata TEXT - ) - ''') - - # 尝试添加metadata列(如果表已存在但没有此列) - try: - await cursor.execute('ALTER TABLE persona_update_reviews ADD COLUMN metadata TEXT') - except: - pass # 列已存在 - - # 准备元数据JSON - import json - metadata_json = json.dumps(metadata if metadata else {}, ensure_ascii=False) - - # ✅ 修复:使用传入的 original_content 和 new_content - # 如果 new_content 为空,则使用 
proposed_content(向后兼容) - final_new_content = new_content if new_content else proposed_content - - # 根据数据库类型使用不同的占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' - - # 插入记录 - placeholders = ', '.join([placeholder] * 10) - await cursor.execute(f''' - INSERT INTO persona_update_reviews - (timestamp, group_id, update_type, original_content, new_content, - proposed_content, confidence_score, reason, status, metadata) - VALUES ({placeholders}) - ''', ( - time.time(), - group_id, - learning_source, # update_type就是learning_source - original_content, # ✅ 使用传入的原人格文本 - final_new_content, # ✅ 使用完整的新人格文本 - proposed_content, # proposed_content保持为增量部分 - confidence_score, - raw_analysis, # reason字段存储raw_analysis - 'pending', - metadata_json - )) - - await conn.commit() - record_id = cursor.lastrowid - - self._logger.info(f"添加人格学习审查记录成功,ID: {record_id}, 群组: {group_id}") - return record_id - - except Exception as e: - self._logger.error(f"添加人格学习审查记录失败: {e}") - raise - - async def get_messages_by_group_and_timerange( - self, - group_id: str, - start_time: float = None, - end_time: float = None, - limit: int = 100 - ) -> List[Dict[str, Any]]: - """ - 获取指定群组在指定时间范围内的聊天记录 - - Args: - group_id: 群组ID - start_time: 开始时间戳(秒),None表示不限制 - end_time: 结束时间戳(秒),None表示不限制 - limit: 返回消息数量限制 - - Returns: - 消息记录列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - query = ''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp, processed - FROM raw_messages - WHERE group_id = ? - ''' - params = [group_id] - - if start_time is not None: - query += ' AND timestamp >= ?' - params.append(start_time) - - if end_time is not None: - query += ' AND timestamp <= ?' - params.append(end_time) - - query += ' ORDER BY timestamp DESC LIMIT ?' 
- params.append(limit) - - await cursor.execute(query, params) - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'content': row[3], # 外部API使用 'content' 字段名 - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6], - 'processed': row[7] - }) - - self._logger.info(f"📖 API查询结果: group={group_id}, 返回{len(messages)}条消息, 最新timestamp={messages[0]['timestamp'] if messages else 'N/A'}") - return messages - - except aiosqlite.Error as e: - self._logger.error(f"获取时间范围消息失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_new_messages_since( - self, - group_id: str, - last_message_id: int = None, - last_timestamp: float = None - ) -> List[Dict[str, Any]]: - """ - 获取指定群组的增量消息(自上次获取后的新消息) - - Args: - group_id: 群组ID - last_message_id: 上次获取的最后一条消息ID - last_timestamp: 上次获取的最后一条消息时间戳 - - Returns: - 新消息列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 优先使用message_id,如果没有则使用timestamp - if last_message_id is not None: - query = ''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp, processed - FROM raw_messages - WHERE group_id = ? AND id > ? - ORDER BY timestamp ASC - ''' - params = (group_id, last_message_id) - elif last_timestamp is not None: - query = ''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp, processed - FROM raw_messages - WHERE group_id = ? AND timestamp > ? - ORDER BY timestamp ASC - ''' - params = (group_id, last_timestamp) - else: - # 如果两个参数都没有,返回最近的消息 - query = ''' - SELECT id, sender_id, sender_name, message, group_id, platform, timestamp, processed - FROM raw_messages - WHERE group_id = ? 
- ORDER BY timestamp DESC - LIMIT 20 - ''' - params = (group_id,) - - await cursor.execute(query, params) - - messages = [] - for row in await cursor.fetchall(): - messages.append({ - 'id': row[0], - 'sender_id': row[1], - 'sender_name': row[2], - 'content': row[3], # 外部API使用 'content' 字段名 - 'group_id': row[4], - 'platform': row[5], - 'timestamp': row[6], - 'processed': row[7] - }) - - return messages - - except aiosqlite.Error as e: - self._logger.error(f"获取增量消息失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_current_topic_summary(self, group_id: str, recent_messages_count: int = 20) -> Dict[str, Any]: - """ - 获取指定群组当前的聊天话题总结 - - 优先从数据库中读取最近的话题总结,如果没有或过期(超过30分钟),则分析最近消息生成新的总结 - - Args: - group_id: 群组ID - recent_messages_count: 分析的最近消息数量 - - Returns: - 话题总结信息 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 首先尝试从数据库获取最近30分钟内的话题总结 - thirty_minutes_ago = time.time() - 1800 - await cursor.execute(''' - SELECT topic, summary, participants, message_count, - start_timestamp, end_timestamp, generated_at - FROM topic_summaries - WHERE group_id = ? AND generated_at > ? - ORDER BY generated_at DESC - LIMIT 1 - ''', (group_id, thirty_minutes_ago)) - - cached_summary = await cursor.fetchone() - - if cached_summary: - # 返回缓存的话题总结 - import json - participants = json.loads(cached_summary[2]) if cached_summary[2] else [] - - return { - 'group_id': group_id, - 'topic': cached_summary[0], - 'summary': cached_summary[1], - 'participants': participants, - 'message_count': cached_summary[3], - 'start_timestamp': cached_summary[4], - 'latest_timestamp': cached_summary[5], - 'generated_at': cached_summary[6], - 'from_cache': True - } - - # 如果没有缓存,获取最近的消息生成新总结 - await cursor.execute(''' - SELECT message, sender_name, timestamp - FROM raw_messages - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, recent_messages_count)) - - messages = [] - latest_timestamp = None - earliest_timestamp = None - for row in await cursor.fetchall(): - messages.append({ - 'message': row[0], - 'sender_name': row[1], - 'timestamp': row[2] - }) - if latest_timestamp is None or row[2] > latest_timestamp: - latest_timestamp = row[2] - if earliest_timestamp is None or row[2] < earliest_timestamp: - earliest_timestamp = row[2] - - if not messages: - return { - 'group_id': group_id, - 'topic': '暂无聊天记录', - 'participants': [], - 'message_count': 0, - 'latest_timestamp': 0, - 'summary': '群组暂无聊天活动', - 'from_cache': False - } - - # 统计参与者 - participants = list(set([msg['sender_name'] for msg in messages])) - - # 使用已有的话题分析方法 - messages_text = [msg['message'] for msg in messages] - topic_analysis = self._analyze_topic_from_messages(messages_text) - - topic_result = { - 'group_id': group_id, - 'topic': topic_analysis['topic'], - 'summary': f"最近{len(messages)}条消息讨论了{topic_analysis['topic']},对话风格为{topic_analysis['style']}", - 'participants': participants, - 'message_count': len(messages), - 'start_timestamp': earliest_timestamp, - 'latest_timestamp': latest_timestamp, - 'generated_at': time.time(), - 'recent_messages': messages[:5], # 返回最近5条消息内容供参考 - 'from_cache': False - } - - # 保存到数据库以供后续查询 - # 不等待保存完成,避免阻塞API响应 - asyncio.create_task(self._save_topic_summary(group_id, topic_result)) - - return topic_result - - except aiosqlite.Error as e: - self._logger.error(f"获取话题总结失败: {e}", exc_info=True) - return { - 'group_id': group_id, - 'topic': '获取失败', - 'participants': [], - 'message_count': 0, - 'latest_timestamp': 0, - 'summary': f'获取话题失败: {str(e)}', - 'from_cache': False - } - finally: - await cursor.close() - - async def _save_topic_summary(self, group_id: str, topic_data: Dict[str, Any]): - """ - 保存话题总结到数据库 - - Args: - group_id: 群组ID - topic_data: 话题数据 - """ - try: - import json - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - await 
cursor.execute(''' - INSERT INTO topic_summaries - (group_id, topic, summary, participants, message_count, - start_timestamp, end_timestamp, generated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - group_id, - topic_data.get('topic', ''), - topic_data.get('summary', ''), - json.dumps(topic_data.get('participants', []), ensure_ascii=False), - topic_data.get('message_count', 0), - topic_data.get('start_timestamp'), - topic_data.get('latest_timestamp'), - topic_data.get('generated_at', time.time()) - )) - - await conn.commit() - await cursor.close() - - self._logger.debug(f"已保存群组 {group_id} 的话题总结") - - except Exception as e: - self._logger.error(f"保存话题总结失败: {e}", exc_info=True) - - def _extract_simple_keywords(self, messages: List[str], max_keywords: int = 10) -> List[str]: - """ - 简单的关键词提取(后续可以用LLM优化) - - Args: - messages: 消息列表 - max_keywords: 最大关键词数量 - - Returns: - 关键词列表 - """ - # 合并所有消息 - text = ' '.join(messages) - - # 简单的词频统计(这里可以用jieba等工具优化) - import re - # 移除特殊字符,保留中文、英文、数字 - words = re.findall(r'[\u4e00-\u9fa5]+|[a-zA-Z]+', text) - - # 统计词频 - word_freq = {} - for word in words: - if len(word) >= 2: # 只统计长度>=2的词 - word_freq[word] = word_freq.get(word, 0) + 1 - - # 按频率排序 - sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True) - - return [word for word, freq in sorted_words[:max_keywords]] - - async def get_all_expression_patterns(self, group_id: str) -> List[Dict[str, Any]]: - """ - 获取指定群组的所有表达模式 - - Args: - group_id: 群组ID - - Returns: - 表达模式列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT context, expression, quality_score, last_used_timestamp - FROM expression_patterns - WHERE group_id = ? 
- ORDER BY quality_score DESC, last_used_timestamp DESC - ''', (group_id,)) - - patterns = [] - for row in await cursor.fetchall(): - patterns.append({ - 'context': row[0], - 'expression': row[1], - 'quality_score': row[2], - 'last_used_timestamp': row[3] - }) - - return patterns - - except aiosqlite.Error as e: - self._logger.error(f"获取表达模式失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_all_expression_patterns_by_group(self) -> Dict[str, List[Dict[str, Any]]]: - """ - 获取所有群组的表达模式(按群组分组) - - Returns: - Dict[str, List[Dict[str, Any]]]: 群组ID -> 表达模式列表的映射 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, situation, expression, weight, last_active_time, create_time, group_id - FROM expression_patterns - ORDER BY group_id, last_active_time DESC - ''') - - patterns_by_group = {} - for row in await cursor.fetchall(): - group_id = row[6] - if group_id not in patterns_by_group: - patterns_by_group[group_id] = [] - - patterns_by_group[group_id].append({ - 'id': row[0], - 'situation': row[1], - 'expression': row[2], - 'weight': row[3], - 'last_active_time': row[4], - 'created_time': row[5], - 'group_id': group_id, - 'style_type': 'general' - }) - - return patterns_by_group - - except Exception as e: - self._logger.error(f"获取所有表达模式失败: {e}", exc_info=True) - return {} - finally: - await cursor.close() - - async def get_recent_week_expression_patterns(self, group_id: str = None, limit: int = 20, hours: int = 168) -> List[Dict[str, Any]]: - """ - 获取最近指定小时内学习到的表达模式(按质量分数和时间排序) - - Args: - group_id: 群组ID,如果为None则获取全局所有群组的表达模式 - limit: 获取数量限制 - hours: 时间范围(小时),默认168小时(一周) - - Returns: - 表达模式列表,包含场景(situation)和表达(expression) - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 计算时间阈值 - time_threshold = time.time() - (hours * 3600) - - # 根据group_id是否为None决定查询条件 - if group_id is None: - # 全局查询:从所有群组获取表达模式 - await 
cursor.execute(''' - SELECT situation, expression, weight, last_active_time, create_time, group_id - FROM expression_patterns - WHERE last_active_time > ? - ORDER BY weight DESC, last_active_time DESC - LIMIT ? - ''', (time_threshold, limit)) - else: - # 单群组查询:只获取指定群组的表达模式 - await cursor.execute(''' - SELECT situation, expression, weight, last_active_time, create_time, group_id - FROM expression_patterns - WHERE group_id = ? AND last_active_time > ? - ORDER BY weight DESC, last_active_time DESC - LIMIT ? - ''', (group_id, time_threshold, limit)) - - patterns = [] - for row in await cursor.fetchall(): - patterns.append({ - 'situation': row[0], # 场景描述 - 'expression': row[1], # 表达方式 - 'weight': row[2], # 权重 - 'last_active_time': row[3], # 最后活跃时间 - 'create_time': row[4], # 创建时间 - 'group_id': row[5] if len(row) > 5 else group_id # 群组ID(全局查询时有用) - }) - - return patterns - - except aiosqlite.Error as e: - self._logger.error(f"获取最近一周表达模式失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_recent_bot_responses(self, group_id: str, limit: int = 10) -> List[str]: - """ - 获取Bot最近的回复内容(用于同质化分析)- 从bot_messages表读取 - - Args: - group_id: 群组ID - limit: 获取数量 - - Returns: - 回复内容列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 从bot_messages表读取Bot的回复 - await cursor.execute(''' - SELECT message - FROM bot_messages - WHERE group_id = ? - ORDER BY timestamp DESC - LIMIT ? 
- ''', (group_id, limit)) - - responses = [] - for row in await cursor.fetchall(): - responses.append(row[0]) - - return responses - - except aiosqlite.Error as e: - self._logger.error(f"获取Bot最近回复失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def save_bot_message( - self, - group_id: str, - user_id: str, - message: str, - response_to_message_id: Optional[int] = None, - context_type: str = "normal", - temperature: float = 0.7, - language_style: Optional[str] = None, - response_pattern: Optional[str] = None - ) -> bool: - """ - 保存Bot发送的消息到数据库 - - Args: - group_id: 群组ID - user_id: 回复的用户ID - message: Bot的回复内容 - response_to_message_id: 回复的消息ID (来自raw_messages表) - context_type: 上下文类型 (normal/creative/precise等) - temperature: 使用的temperature参数 - language_style: 使用的语言风格 - response_pattern: 使用的回复模式 - - Returns: - bool: 是否成功保存 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO bot_messages - (group_id, user_id, message, response_to_message_id, context_type, - temperature, language_style, response_pattern, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - group_id, - user_id, - message, - response_to_message_id, - context_type, - temperature, - language_style, - response_pattern, - time.time() - )) - - await conn.commit() - self._logger.debug(f"✅ Bot消息已保存: group={group_id}, msg_preview={message[:50]}...") - return True - - except aiosqlite.Error as e: - self._logger.error(f"保存Bot消息失败: {e}", exc_info=True) - return False - finally: - await cursor.close() - - async def get_bot_message_statistics(self, group_id: str, time_range_hours: int = 24) -> Dict[str, Any]: - """ - 获取Bot消息统计信息 (用于多样性分析) - - Args: - group_id: 群组ID - time_range_hours: 统计时间范围(小时) - - Returns: - 统计信息字典 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - cutoff_time = time.time() - (time_range_hours * 3600) - - # 统计消息总数 - await cursor.execute(''' - SELECT COUNT(*) as total, - AVG(temperature) as avg_temp, - COUNT(DISTINCT language_style) as unique_styles, - COUNT(DISTINCT response_pattern) as unique_patterns - FROM bot_messages - WHERE group_id = ? AND timestamp > ? - ''', (group_id, cutoff_time)) - - row = await cursor.fetchone() - - # 获取最常用的风格和模式 - await cursor.execute(''' - SELECT language_style, COUNT(*) as count - FROM bot_messages - WHERE group_id = ? AND timestamp > ? AND language_style IS NOT NULL - GROUP BY language_style - ORDER BY count DESC - LIMIT 5 - ''', (group_id, cutoff_time)) - - top_styles = [{'style': row[0], 'count': row[1]} for row in await cursor.fetchall()] - - await cursor.execute(''' - SELECT response_pattern, COUNT(*) as count - FROM bot_messages - WHERE group_id = ? AND timestamp > ? 
AND response_pattern IS NOT NULL - GROUP BY response_pattern - ORDER BY count DESC - LIMIT 5 - ''', (group_id, cutoff_time)) - - top_patterns = [{'pattern': row[0], 'count': row[1]} for row in await cursor.fetchall()] - - return { - 'total_messages': row[0] if row else 0, - 'average_temperature': round(row[1], 2) if row and row[1] else 0.7, - 'unique_styles_count': row[2] if row else 0, - 'unique_patterns_count': row[3] if row else 0, - 'top_styles': top_styles, - 'top_patterns': top_patterns, - 'time_range_hours': time_range_hours - } - - except aiosqlite.Error as e: - self._logger.error(f"获取Bot消息统计失败: {e}", exc_info=True) - return {} - finally: - await cursor.close() - - # ========== 黑话学习系统数据库操作方法 ========== - - async def get_jargon(self, chat_id: str, content: str) -> Optional[Dict[str, Any]]: - """ - 查询指定黑话 - - Args: - chat_id: 群组ID - content: 黑话词条 - - Returns: - 黑话记录字典或None - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, content, raw_content, meaning, is_jargon, count, - last_inference_count, is_complete, is_global, chat_id, - created_at, updated_at - FROM jargon - WHERE chat_id = ? AND content = ? 
- ''', (chat_id, content)) - - row = await cursor.fetchone() - if not row: - return None - - return { - 'id': row[0], - 'content': row[1], - 'raw_content': row[2], - 'meaning': row[3], - 'is_jargon': bool(row[4]) if row[4] is not None else None, - 'count': row[5], - 'last_inference_count': row[6], - 'is_complete': bool(row[7]), - 'is_global': bool(row[8]), - 'chat_id': row[9], - 'created_at': row[10], - 'updated_at': row[11] - } - - except aiosqlite.Error as e: - logger.error(f"查询黑话失败: {e}", exc_info=True) - return None - finally: - await cursor.close() - - async def insert_jargon(self, jargon: Dict[str, Any]) -> int: - """ - 插入新的黑话记录 - - Args: - jargon: 黑话数据字典 - - Returns: - 插入记录的ID - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - INSERT INTO jargon - (content, raw_content, meaning, is_jargon, count, last_inference_count, - is_complete, is_global, chat_id, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - jargon.get('content'), - jargon.get('raw_content', '[]'), - jargon.get('meaning'), - jargon.get('is_jargon'), - jargon.get('count', 1), - jargon.get('last_inference_count', 0), - jargon.get('is_complete', False), - jargon.get('is_global', False), - jargon.get('chat_id'), - jargon.get('created_at'), - jargon.get('updated_at') - )) - - jargon_id = cursor.lastrowid - await conn.commit() - logger.debug(f"插入黑话记录成功, ID: {jargon_id}") - return jargon_id - - except aiosqlite.Error as e: - logger.error(f"插入黑话失败: {e}", exc_info=True) - raise - finally: - await cursor.close() - - async def update_jargon(self, jargon: Dict[str, Any]) -> bool: - """ - 更新现有黑话记录 - - Args: - jargon: 黑话数据字典(必须包含id) - - Returns: - 是否成功更新 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - UPDATE jargon - SET content = ?, raw_content = ?, meaning = ?, is_jargon = ?, - count = ?, last_inference_count = ?, is_complete = ?, - is_global = ?, updated_at = ? - WHERE id = ? - ''', ( - jargon.get('content'), - jargon.get('raw_content'), - jargon.get('meaning'), - jargon.get('is_jargon'), - jargon.get('count'), - jargon.get('last_inference_count'), - jargon.get('is_complete'), - jargon.get('is_global'), - jargon.get('updated_at'), - jargon.get('id') - )) - - await conn.commit() - logger.debug(f"更新黑话记录成功, ID: {jargon.get('id')}") - return cursor.rowcount > 0 - - except aiosqlite.Error as e: - logger.error(f"更新黑话失败: {e}", exc_info=True) - return False - finally: - await cursor.close() - - async def search_jargon( - self, - keyword: str, - chat_id: Optional[str] = None, - limit: int = 10 - ) -> List[Dict[str, Any]]: - """ - 搜索黑话(用于LLM工具调用) - - Args: - keyword: 搜索关键词 - chat_id: 群组ID (None表示搜索全局黑话) - limit: 返回结果数量限制 - - Returns: - 黑话记录列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 根据数据库类型选择占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' 
- - if chat_id: - # 搜索指定群组的黑话 - query = f''' - SELECT id, content, meaning, is_jargon, count, is_complete - FROM jargon - WHERE chat_id = {placeholder} AND content LIKE {placeholder} AND is_jargon = 1 - ORDER BY count DESC, updated_at DESC - LIMIT {placeholder} - ''' - await cursor.execute(query, (chat_id, f'%{keyword}%', limit)) - else: - # 搜索全局黑话 - query = f''' - SELECT id, content, meaning, is_jargon, count, is_complete - FROM jargon - WHERE content LIKE {placeholder} AND is_jargon = 1 AND is_global = 1 - ORDER BY count DESC, updated_at DESC - LIMIT {placeholder} - ''' - await cursor.execute(query, (f'%{keyword}%', limit)) - - results = [] - for row in await cursor.fetchall(): - results.append({ - 'id': row[0], - 'content': row[1], - 'meaning': row[2], - 'is_jargon': bool(row[3]), - 'count': row[4], - 'is_complete': bool(row[5]) - }) - - return results - - except Exception as e: - logger.error(f"搜索黑话失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_jargon_statistics(self, chat_id: Optional[str] = None) -> Dict[str, Any]: - """ - 获取黑话学习统计信息 - - Args: - chat_id: 群组ID (None表示获取全局统计) - - Returns: - 统计信息字典 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 根据数据库类型选择占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' 
- - if chat_id: - # 群组统计 - query = f''' - SELECT - COUNT(*) as total, - COUNT(CASE WHEN is_jargon = 1 THEN 1 END) as confirmed_jargon, - COUNT(CASE WHEN is_complete = 1 THEN 1 END) as completed, - SUM(count) as total_occurrences, - AVG(count) as avg_count - FROM jargon - WHERE chat_id = {placeholder} - ''' - await cursor.execute(query, (chat_id,)) - else: - # 全局统计 - await cursor.execute(''' - SELECT - COUNT(*) as total, - COUNT(CASE WHEN is_jargon = 1 THEN 1 END) as confirmed_jargon, - COUNT(CASE WHEN is_complete = 1 THEN 1 END) as completed, - SUM(count) as total_occurrences, - AVG(count) as avg_count, - COUNT(DISTINCT chat_id) as active_groups - FROM jargon - ''') - - row = await cursor.fetchone() - - # 添加行数据验证 - if not row or len(row) < 5: - self._logger.warning(f"黑话统计数据行不完整 (期望至少5个字段,实际{len(row) if row else 0}个),返回默认值") - return { - 'total_candidates': 0, - 'confirmed_jargon': 0, - 'completed_inference': 0, - 'total_occurrences': 0, - 'average_count': 0, - 'active_groups': 0 - } - - stats = { - 'total_candidates': int(row[0]) if row[0] else 0, - 'confirmed_jargon': int(row[1]) if row[1] else 0, - 'completed_inference': int(row[2]) if row[2] else 0, - 'total_occurrences': int(row[3]) if row[3] else 0, - 'average_count': round(float(row[4]), 1) if row[4] else 0 - } - - if not chat_id and len(row) > 5: - stats['active_groups'] = int(row[5]) if row[5] else 0 - - return stats - - except Exception as e: - logger.error(f"获取黑话统计失败: {e}", exc_info=True) - return { - 'total_candidates': 0, - 'confirmed_jargon': 0, - 'completed_inference': 0, - 'total_occurrences': 0, - 'average_count': 0 - } - finally: - await cursor.close() - - async def get_recent_jargon_list( - self, - chat_id: Optional[str] = None, - limit: int = 20, - only_confirmed: bool = True - ) -> List[Dict[str, Any]]: - """ - 获取最近学习到的黑话列表 - - Args: - chat_id: 群组ID (None表示获取所有) - limit: 返回数量限制 - only_confirmed: 是否只返回已确认的黑话 - - Returns: - 黑话列表 - """ - async with self.get_db_connection() as conn: - cursor = await 
conn.cursor() - - try: - # 根据数据库类型选择占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' - - query = ''' - SELECT id, content, meaning, is_jargon, count, - last_inference_count, is_complete, chat_id, updated_at, is_global - FROM jargon - WHERE 1=1 - ''' - params = [] - - if chat_id: - query += f' AND chat_id = {placeholder}' - params.append(chat_id) - - if only_confirmed: - query += ' AND is_jargon = 1' - - query += f' ORDER BY updated_at DESC LIMIT {placeholder}' - params.append(limit) - - await cursor.execute(query, tuple(params)) - - jargon_list = [] - for row in await cursor.fetchall(): - try: - # 添加行数据验证 - if len(row) < 10: - self._logger.warning(f"黑话记录行数据不完整 (期望10个字段,实际{len(row)}个),跳过: {row}") - continue - - jargon_list.append({ - 'id': row[0], - 'content': row[1], - 'meaning': row[2], - 'is_jargon': bool(row[3]) if row[3] is not None else None, - 'count': int(row[4]) if row[4] else 0, - 'last_inference_count': int(row[5]) if row[5] else 0, - 'is_complete': bool(row[6]), - 'chat_id': row[7], - 'updated_at': row[8], - 'is_global': bool(row[9]) if row[9] is not None else False - }) - except Exception as row_error: - self._logger.warning(f"处理黑话记录行时出错,跳过: {row_error}, row: {row}") - continue - - return jargon_list - - except Exception as e: - logger.error(f"获取黑话列表失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def get_jargon_by_id(self, jargon_id: int) -> Optional[Dict[str, Any]]: - """ - 根据ID获取黑话记录 - - Args: - jargon_id: 黑话记录ID - - Returns: - 黑话记录或None - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 根据数据库类型选择占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' 
- - query = f''' - SELECT id, content, meaning, is_jargon, count, - last_inference_count, is_complete, chat_id, updated_at, is_global - FROM jargon - WHERE id = {placeholder} - ''' - await cursor.execute(query, (jargon_id,)) - row = await cursor.fetchone() - - if row: - return { - 'id': row[0], - 'content': row[1], - 'meaning': row[2], - 'is_jargon': bool(row[3]) if row[3] is not None else None, - 'count': row[4], - 'last_inference_count': row[5], - 'is_complete': bool(row[6]), - 'chat_id': row[7], - 'updated_at': row[8], - 'is_global': bool(row[9]) if row[9] is not None else False - } - return None - - except Exception as e: - logger.error(f"获取黑话记录失败: {e}", exc_info=True) - return None - finally: - await cursor.close() - - async def delete_jargon_by_id(self, jargon_id: int) -> bool: - """ - 根据ID删除黑话记录 - - Args: - jargon_id: 黑话记录ID - - Returns: - 是否成功删除 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 根据数据库类型选择占位符 - placeholder = '%s' if self.config.db_type.lower() == 'mysql' else '?' - - query = f'DELETE FROM jargon WHERE id = {placeholder}' - await cursor.execute(query, (jargon_id,)) - await conn.commit() - deleted = cursor.rowcount > 0 - if deleted: - logger.debug(f"删除黑话记录成功, ID: {jargon_id}") - return deleted - - except Exception as e: - logger.error(f"删除黑话失败: {e}", exc_info=True) - return False - finally: - await cursor.close() - - async def get_global_jargon_list(self, limit: int = 50) -> List[Dict[str, Any]]: - """ - 获取全局共享的黑话列表 - - Args: - limit: 返回数量限制 - - Returns: - 全局黑话列表 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - SELECT id, content, meaning, is_jargon, count, - last_inference_count, is_complete, is_global, chat_id, updated_at - FROM jargon - WHERE is_jargon = 1 AND is_global = 1 - ORDER BY count DESC, updated_at DESC - LIMIT ? 
- ''', (limit,)) - - jargon_list = [] - for row in await cursor.fetchall(): - jargon_list.append({ - 'id': row[0], - 'content': row[1], - 'meaning': row[2], - 'is_jargon': bool(row[3]), - 'count': row[4], - 'last_inference_count': row[5], - 'is_complete': bool(row[6]), - 'is_global': bool(row[7]), - 'chat_id': row[8], - 'updated_at': row[9] - }) - - return jargon_list - - except aiosqlite.Error as e: - logger.error(f"获取全局黑话列表失败: {e}", exc_info=True) - return [] - finally: - await cursor.close() - - async def set_jargon_global(self, jargon_id: int, is_global: bool) -> bool: - """ - 设置黑话的全局共享状态 - - Args: - jargon_id: 黑话记录ID - is_global: 是否全局共享 - - Returns: - 是否成功更新 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - await cursor.execute(''' - UPDATE jargon - SET is_global = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (is_global, jargon_id)) - - await conn.commit() - updated = cursor.rowcount > 0 - if updated: - logger.info(f"黑话全局状态已更新: ID={jargon_id}, is_global={is_global}") - return updated - - except aiosqlite.Error as e: - logger.error(f"更新黑话全局状态失败: {e}", exc_info=True) - return False - finally: - await cursor.close() - - async def sync_global_jargon_to_group(self, target_chat_id: str) -> Dict[str, Any]: - """ - 将全局黑话同步到指定群组 - - Args: - target_chat_id: 目标群组ID - - Returns: - 同步结果统计 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - # 获取全局黑话列表 - await cursor.execute(''' - SELECT content, meaning, count - FROM jargon - WHERE is_jargon = 1 AND is_global = 1 AND chat_id != ? - ''', (target_chat_id,)) - - global_jargon = await cursor.fetchall() - - synced_count = 0 - skipped_count = 0 - - for content, meaning, count in global_jargon: - # 检查目标群组是否已存在该黑话 - await cursor.execute(''' - SELECT id FROM jargon - WHERE chat_id = ? AND content = ? 
- ''', (target_chat_id, content)) - - existing = await cursor.fetchone() - - if existing: - # 已存在,跳过 - skipped_count += 1 - else: - # 不存在,同步到目标群组 - await cursor.execute(''' - INSERT INTO jargon - (content, raw_content, meaning, is_jargon, count, last_inference_count, - is_complete, is_global, chat_id, created_at, updated_at) - VALUES (?, '[]', ?, 1, 1, 0, 0, 0, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (content, meaning, target_chat_id)) - synced_count += 1 - - await conn.commit() - - logger.info(f"同步全局黑话到群组 {target_chat_id}: 同步 {synced_count} 条, 跳过 {skipped_count} 条") - - return { - 'success': True, - 'synced_count': synced_count, - 'skipped_count': skipped_count, - 'total_global': len(global_jargon) - } - - except aiosqlite.Error as e: - logger.error(f"同步全局黑话失败: {e}", exc_info=True) - return { - 'success': False, - 'error': str(e), - 'synced_count': 0, - 'skipped_count': 0 - } - finally: - await cursor.close() - - async def batch_set_jargon_global(self, jargon_ids: List[int], is_global: bool) -> Dict[str, Any]: - """ - 批量设置黑话的全局共享状态 - - Args: - jargon_ids: 黑话记录ID列表 - is_global: 是否全局共享 - - Returns: - 操作结果统计 - """ - async with self.get_db_connection() as conn: - cursor = await conn.cursor() - - try: - success_count = 0 - failed_count = 0 - - for jid in jargon_ids: - try: - await cursor.execute(''' - UPDATE jargon - SET is_global = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? 
AND is_jargon = 1 - ''', (is_global, jid)) - if cursor.rowcount > 0: - success_count += 1 - else: - failed_count += 1 - except Exception: - failed_count += 1 - - await conn.commit() - - logger.info(f"批量更新黑话全局状态: 成功 {success_count}, 失败 {failed_count}") - - return { - 'success': True, - 'success_count': success_count, - 'failed_count': failed_count - } - - except aiosqlite.Error as e: - logger.error(f"批量更新黑话全局状态失败: {e}", exc_info=True) - return { - 'success': False, - 'error': str(e), - 'success_count': 0, - 'failed_count': len(jargon_ids) - } - finally: - await cursor.close() - - # ======================================================================== - # ORM Repository 方法(新) - # ======================================================================== - - async def save_learning_batch( - self, - batch_id: str, - batch_name: str, - group_id: str, - start_time: float, - end_time: Optional[float] = None, - quality_score: Optional[float] = None, - processed_messages: int = 0, - message_count: int = 0, - filtered_count: int = 0, - success: bool = True, - error_message: Optional[str] = None, - status: str = 'pending' - ) -> bool: - """ - 保存学习批次(使用 ORM) - - Args: - batch_id: 批次 ID - batch_name: 批次名称 - group_id: 群组 ID - start_time: 开始时间 - end_time: 结束时间 - quality_score: 质量分数 - processed_messages: 已处理消息数 - message_count: 总消息数 - filtered_count: 过滤掉的消息数 - success: 是否成功 - error_message: 错误信息 - status: 状态 - - Returns: - bool: 是否保存成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法保存学习批次") - return False - - try: - async with self.db_engine.get_session() as session: - repo = LearningBatchRepository(session) - batch = await repo.save_learning_batch( - batch_id=batch_id, - batch_name=batch_name, - group_id=group_id, - start_time=start_time, - end_time=end_time, - quality_score=quality_score, - processed_messages=processed_messages, - message_count=message_count, - filtered_count=filtered_count, - success=success, - error_message=error_message, - 
status=status - ) - await session.commit() - return batch is not None - - except Exception as e: - self._logger.error(f"保存学习批次失败: {e}", exc_info=True) - return False - - async def get_learning_batches( - self, - group_id: str, - limit: int = 50, - offset: int = 0, - status_filter: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - 获取学习批次列表(使用 ORM) - - Args: - group_id: 群组 ID - limit: 最大返回数量 - offset: 偏移量 - status_filter: 状态过滤 - - Returns: - List[Dict]: 批次列表 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = LearningBatchRepository(session) - batches = await repo.get_learning_batches( - group_id=group_id, - limit=limit, - offset=offset, - status_filter=status_filter - ) - return [batch.to_dict() for batch in batches] - - except Exception as e: - self._logger.error(f"获取学习批次列表失败: {e}", exc_info=True) - return [] - - async def get_learning_batch_by_id(self, batch_id: str) -> Optional[Dict[str, Any]]: - """ - 根据 batch_id 获取学习批次(使用 ORM) - - Args: - batch_id: 批次 ID - - Returns: - Optional[Dict]: 批次记录 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回 None") - return None - - try: - async with self.db_engine.get_session() as session: - repo = LearningBatchRepository(session) - batch = await repo.get_learning_batch_by_id(batch_id) - return batch.to_dict() if batch else None - - except Exception as e: - self._logger.error(f"获取学习批次失败: {e}", exc_info=True) - return None - - async def save_learning_session( - self, - session_id: str, - group_id: str, - batch_id: Optional[str] = None, - start_time: Optional[float] = None, - end_time: Optional[float] = None, - metrics: Optional[str] = None - ) -> bool: - """ - 保存学习会话(使用 ORM) - - Args: - session_id: 会话 ID - group_id: 群组 ID - batch_id: 批次 ID - start_time: 开始时间 - end_time: 结束时间 - metrics: 指标数据(JSON字符串) - - Returns: - bool: 是否保存成功 - """ - if not self.db_engine: - 
self._logger.warning("DatabaseEngine 未初始化,无法保存学习会话") - return False - - try: - async with self.db_engine.get_session() as session: - repo = LearningSessionRepository(session) - learning_session = await repo.save_learning_session( - session_id=session_id, - group_id=group_id, - batch_id=batch_id, - start_time=start_time, - end_time=end_time, - metrics=metrics - ) - await session.commit() - return learning_session is not None - - except Exception as e: - self._logger.error(f"保存学习会话失败: {e}", exc_info=True) - return False - - async def get_learning_sessions( - self, - group_id: str, - batch_id: Optional[str] = None, - limit: int = 50, - offset: int = 0 - ) -> List[Dict[str, Any]]: - """ - 获取学习会话列表(使用 ORM) - - Args: - group_id: 群组 ID - batch_id: 批次 ID(可选) - limit: 最大返回数量 - offset: 偏移量 - - Returns: - List[Dict]: 会话列表 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = LearningSessionRepository(session) - sessions = await repo.get_learning_sessions( - group_id=group_id, - batch_id=batch_id, - limit=limit, - offset=offset - ) - return [sess.to_dict() for sess in sessions] - - except Exception as e: - self._logger.error(f"获取学习会话列表失败: {e}", exc_info=True) - return [] - - # ==================== 对话与上下文系统 ORM 方法 ==================== - - async def save_conversation_context( - self, - group_id: str, - user_id: str, - context_window: str, - topic: Optional[str] = None, - sentiment: Optional[str] = None, - context_embedding: Optional[bytes] = None, - last_updated: Optional[float] = None - ) -> bool: - """ - 保存对话上下文(使用 ORM) - - Args: - group_id: 群组 ID - user_id: 用户 ID - context_window: 上下文窗口(JSON字符串) - topic: 当前话题 - sentiment: 情感倾向 - context_embedding: 上下文向量嵌入 - last_updated: 最后更新时间戳 - - Returns: - bool: 是否成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法保存对话上下文") - return False - - try: - async with self.db_engine.get_session() as 
session: - repo = ConversationContextRepository(session) - context = await repo.save_context( - group_id=group_id, - user_id=user_id, - context_window=context_window, - topic=topic, - sentiment=sentiment, - context_embedding=context_embedding, - last_updated=last_updated - ) - await session.commit() - return context is not None - - except Exception as e: - self._logger.error(f"保存对话上下文失败: {e}", exc_info=True) - return False - - async def get_latest_conversation_context( - self, - group_id: str, - user_id: str - ) -> Optional[Dict[str, Any]]: - """ - 获取最新的对话上下文(使用 ORM) - - Args: - group_id: 群组 ID - user_id: 用户 ID - - Returns: - Optional[Dict]: 上下文记录 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回 None") - return None - - try: - async with self.db_engine.get_session() as session: - repo = ConversationContextRepository(session) - context = await repo.get_latest_context( - group_id=group_id, - user_id=user_id - ) - return context.to_dict() if context else None - - except Exception as e: - self._logger.error(f"获取最新对话上下文失败: {e}", exc_info=True) - return None - - async def save_topic_cluster( - self, - group_id: str, - cluster_id: str, - topic_keywords: str, - message_count: int = 0, - representative_messages: Optional[str] = None, - cluster_center: Optional[bytes] = None - ) -> bool: - """ - 保存主题聚类(使用 ORM) - - Args: - group_id: 群组 ID - cluster_id: 聚类 ID - topic_keywords: 主题关键词(JSON字符串) - message_count: 消息数量 - representative_messages: 代表性消息(JSON字符串) - cluster_center: 聚类中心向量 - - Returns: - bool: 是否成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法保存主题聚类") - return False - - try: - async with self.db_engine.get_session() as session: - repo = ConversationTopicClusteringRepository(session) - cluster = await repo.save_cluster( - group_id=group_id, - cluster_id=cluster_id, - topic_keywords=topic_keywords, - message_count=message_count, - representative_messages=representative_messages, - cluster_center=cluster_center - 
) - await session.commit() - return cluster is not None - - except Exception as e: - self._logger.error(f"保存主题聚类失败: {e}", exc_info=True) - return False - - async def get_all_topic_clusters( - self, - group_id: str, - order_by_message_count: bool = True, - limit: int = 100 - ) -> List[Dict[str, Any]]: - """ - 获取所有主题聚类(使用 ORM) - - Args: - group_id: 群组 ID - order_by_message_count: 是否按消息数量排序 - limit: 最大返回数量 - - Returns: - List[Dict]: 聚类列表 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = ConversationTopicClusteringRepository(session) - clusters = await repo.get_all_clusters( - group_id=group_id, - order_by_message_count=order_by_message_count, - limit=limit - ) - return [cluster.to_dict() for cluster in clusters] - - except Exception as e: - self._logger.error(f"获取主题聚类列表失败: {e}", exc_info=True) - return [] - - async def save_quality_metrics( - self, - group_id: str, - message_id: int, - coherence_score: Optional[float] = None, - relevance_score: Optional[float] = None, - engagement_score: Optional[float] = None, - sentiment_alignment: Optional[float] = None, - calculated_at: Optional[float] = None - ) -> bool: - """ - 保存对话质量指标(使用 ORM) - - Args: - group_id: 群组 ID - message_id: 消息 ID - coherence_score: 连贯性分数 - relevance_score: 相关性分数 - engagement_score: 互动度分数 - sentiment_alignment: 情感一致性分数 - calculated_at: 计算时间戳 - - Returns: - bool: 是否成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法保存质量指标") - return False - - try: - async with self.db_engine.get_session() as session: - repo = ConversationQualityMetricsRepository(session) - metrics = await repo.save_quality_metrics( - group_id=group_id, - message_id=message_id, - coherence_score=coherence_score, - relevance_score=relevance_score, - engagement_score=engagement_score, - sentiment_alignment=sentiment_alignment, - calculated_at=calculated_at - ) - await session.commit() - return 
metrics is not None - - except Exception as e: - self._logger.error(f"保存质量指标失败: {e}", exc_info=True) - return False - - async def get_average_quality_scores( - self, - group_id: str, - start_time: Optional[float] = None, - end_time: Optional[float] = None - ) -> Dict[str, float]: - """ - 获取平均质量分数(使用 ORM) - - Args: - group_id: 群组 ID - start_time: 开始时间戳(可选) - end_time: 结束时间戳(可选) - - Returns: - Dict[str, float]: 各指标的平均分数 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回默认值") - return { - "avg_coherence_score": 0.0, - "avg_relevance_score": 0.0, - "avg_engagement_score": 0.0, - "avg_sentiment_alignment": 0.0 - } - - try: - async with self.db_engine.get_session() as session: - repo = ConversationQualityMetricsRepository(session) - return await repo.get_average_scores( - group_id=group_id, - start_time=start_time, - end_time=end_time - ) - - except Exception as e: - self._logger.error(f"获取平均质量分数失败: {e}", exc_info=True) - return { - "avg_coherence_score": 0.0, - "avg_relevance_score": 0.0, - "avg_engagement_score": 0.0, - "avg_sentiment_alignment": 0.0 - } - - async def save_context_similarity( - self, - context_hash_1: str, - context_hash_2: str, - similarity_score: float, - calculation_method: Optional[str] = None, - cached_at: Optional[float] = None - ) -> bool: - """ - 保存上下文相似度缓存(使用 ORM) - - Args: - context_hash_1: 上下文1的哈希值 - context_hash_2: 上下文2的哈希值 - similarity_score: 相似度分数 - calculation_method: 计算方法 - cached_at: 缓存时间戳 - - Returns: - bool: 是否成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法保存相似度缓存") - return False - - try: - async with self.db_engine.get_session() as session: - repo = ContextSimilarityCacheRepository(session) - cache = await repo.save_similarity( - context_hash_1=context_hash_1, - context_hash_2=context_hash_2, - similarity_score=similarity_score, - calculation_method=calculation_method, - cached_at=cached_at - ) - await session.commit() - return cache is not None - - except Exception as e: 
- self._logger.error(f"保存相似度缓存失败: {e}", exc_info=True) - return False - - async def get_context_similarity( - self, - context_hash_1: str, - context_hash_2: str - ) -> Optional[float]: - """ - 获取上下文相似度(使用 ORM,支持双向查找) - - Args: - context_hash_1: 上下文1的哈希值 - context_hash_2: 上下文2的哈希值 - - Returns: - Optional[float]: 相似度分数 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回 None") - return None - - try: - async with self.db_engine.get_session() as session: - repo = ContextSimilarityCacheRepository(session) - cache = await repo.get_similarity( - context_hash_1=context_hash_1, - context_hash_2=context_hash_2 - ) - return cache.similarity_score if cache else None - - except Exception as e: - self._logger.error(f"获取相似度缓存失败: {e}", exc_info=True) - return None - - # ==================== 黑话系统 ORM 方法 ==================== - - async def get_recent_jargon_list_orm( - self, - chat_id: Optional[str] = None, - limit: int = 20, - only_confirmed: bool = True - ) -> List[Dict[str, Any]]: - """ - 获取最近学习到的黑话列表(使用 ORM) - - Args: - chat_id: 群组ID (None表示获取所有) - limit: 返回数量限制 - only_confirmed: 是否只返回已确认的黑话 - - Returns: - List[Dict]: 黑话列表 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = JargonRepository(session) - jargons = await repo.get_recent_jargon_list( - chat_id=chat_id, - limit=limit, - only_confirmed=only_confirmed - ) - return [jargon.to_dict() for jargon in jargons] - - except Exception as e: - self._logger.error(f"获取黑话列表失败(ORM): {e}", exc_info=True) - return [] - - async def get_jargon_statistics_orm( - self, - chat_id: Optional[str] = None - ) -> Dict[str, Any]: - """ - 获取黑话学习统计信息(使用 ORM) - - Args: - chat_id: 群组ID (None表示获取全局统计) - - Returns: - Dict[str, Any]: 统计信息字典 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回默认值") - return { - 'total_candidates': 0, - 'confirmed_jargon': 0, - 'completed_inference': 0, - 
'total_occurrences': 0, - 'average_count': 0.0, - 'active_groups': 0 - } - - try: - async with self.db_engine.get_session() as session: - repo = JargonRepository(session) - return await repo.get_jargon_statistics(chat_id=chat_id) - - except Exception as e: - self._logger.error(f"获取黑话统计失败(ORM): {e}", exc_info=True) - return { - 'total_candidates': 0, - 'confirmed_jargon': 0, - 'completed_inference': 0, - 'total_occurrences': 0, - 'average_count': 0.0, - 'active_groups': 0 - } - - async def get_jargon_by_id_orm( - self, - jargon_id: int - ) -> Optional[Dict[str, Any]]: - """ - 根据ID获取黑话记录(使用 ORM) - - Args: - jargon_id: 黑话记录ID - - Returns: - Optional[Dict]: 黑话记录或None - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回 None") - return None - - try: - async with self.db_engine.get_session() as session: - repo = JargonRepository(session) - jargon = await repo.get_by_id(jargon_id) - return jargon.to_dict() if jargon else None - - except Exception as e: - self._logger.error(f"根据ID获取黑话失败(ORM): {e}", exc_info=True) - return None - - async def update_jargon_status_orm( - self, - jargon_id: int, - is_jargon: Optional[bool] = None, - is_complete: Optional[bool] = None, - meaning: Optional[str] = None - ) -> bool: - """ - 更新黑话状态(使用 ORM) - - Args: - jargon_id: 黑话ID - is_jargon: 是否为黑话 - is_complete: 是否完成推理 - meaning: 含义 - - Returns: - bool: 是否成功 - """ - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,无法更新黑话状态") - return False - - try: - async with self.db_engine.get_session() as session: - repo = JargonRepository(session) - success = await repo.update_jargon_status( - jargon_id=jargon_id, - is_jargon=is_jargon, - is_complete=is_complete, - meaning=meaning - ) - await session.commit() - return success - - except Exception as e: - self._logger.error(f"更新黑话状态失败(ORM): {e}", exc_info=True) - return False - - # ==================== 学习系统 ORM 方法 ==================== - - async def get_pending_style_reviews_orm(self, limit: int = 50) -> 
List[Dict[str, Any]]: - """获取待审查的风格学习记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = StyleLearningReviewRepository(session) - reviews = await repo.get_by_status(status='pending', limit=limit) - - # 转换为字典格式,保持与传统方法相同的格式 - result = [] - for review in reviews: - review_dict = review.to_dict() - - # 解析 learned_patterns JSON 字符串 - learned_patterns = [] - try: - if review_dict.get('learned_patterns'): - import json - learned_patterns = json.loads(review_dict['learned_patterns']) - except json.JSONDecodeError: - pass - - result.append({ - 'id': review_dict['id'], - 'type': review_dict['type'], - 'group_id': review_dict['group_id'], - 'timestamp': review_dict['timestamp'], - 'learned_patterns': learned_patterns, - 'few_shots_content': review_dict['few_shots_content'], - 'status': review_dict['status'], - 'description': review_dict['description'], - 'created_at': review_dict['created_at'] - }) - - return result - - except Exception as e: - self._logger.error(f"获取待审查风格学习记录失败(ORM): {e}", exc_info=True) - return [] - - async def update_style_review_status_orm( - self, - review_id: int, - status: str, - group_id: str = None - ) -> bool: - """更新风格学习审查状态(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return False - - try: - async with self.db_engine.get_session() as session: - repo = StyleLearningReviewRepository(session) - - import time - success = await repo.update( - review_id, - status=status, - updated_at=time.time() - ) - - await session.commit() - - if success: - self._logger.info(f"更新风格学习审查状态成功(ORM): ID={review_id}, 状态={status}") - else: - self._logger.warning(f"更新风格学习审查状态失败(ORM): 未找到ID={review_id}的记录") - - return success - - except Exception as e: - self._logger.error(f"更新风格学习审查状态失败(ORM): {e}", exc_info=True) - return False - - async def get_style_progress_data_orm(self) -> List[Dict[str, Any]]: - 
"""获取风格进度数据(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = LearningBatchRepository(session) - - # 获取最近30条有质量分数和消息的学习批次 - from sqlalchemy import select, and_ - from ..models.orm import LearningBatch - - stmt = select(LearningBatch).where( - and_( - LearningBatch.quality_score.isnot(None), - LearningBatch.processed_messages > 0 - ) - ).order_by(LearningBatch.start_time.desc()).limit(30) - - result = await session.execute(stmt) - batches = list(result.scalars().all()) - - self._logger.debug(f"get_style_progress_data_orm 获取到 {len(batches)} 行数据") - - progress_data = [] - for batch in batches: - try: - progress_item = { - 'group_id': batch.group_id, - 'timestamp': float(batch.start_time) if batch.start_time else 0, - 'quality_score': float(batch.quality_score) if batch.quality_score else 0, - 'success': bool(batch.success) - } - - # 添加消息数量信息 - if batch.processed_messages is not None: - progress_item['processed_messages'] = int(batch.processed_messages) - if batch.filtered_count is not None: - progress_item['filtered_count'] = int(batch.filtered_count) - if batch.batch_name: - progress_item['batch_name'] = batch.batch_name - else: - progress_item['batch_name'] = '未命名' - - progress_data.append(progress_item) - - except Exception as row_error: - self._logger.warning(f"处理学习批次进度数据行时出错(ORM),跳过: {row_error}") - - return progress_data - - except Exception as e: - self._logger.error(f"从learning_batches表获取进度数据失败(ORM): {e}", exc_info=True) - return [] - - # ==================== 人格学习审查系统 ORM 方法 ==================== - - async def get_pending_persona_learning_reviews_orm(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取待审查的人格学习记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = PersonaLearningReviewRepository(session) - reviews = await 
repo.get_pending_reviews(limit=limit) - - # 转换为字典格式,保持与传统方法相同的格式 - result = [] - for review in reviews: - review_dict = review.to_dict() - - # 解析 metadata JSON 字符串 - metadata = {} - try: - if review_dict.get('metadata'): - import json - metadata = json.loads(review_dict['metadata']) - except json.JSONDecodeError: - pass - - # 确保有proposed_content字段,如果为空则使用new_content - proposed_content = review_dict.get('proposed_content') or review_dict.get('new_content') - confidence_score = review_dict.get('confidence_score') if review_dict.get('confidence_score') is not None else 0.5 - - result.append({ - 'id': review_dict['id'], - 'timestamp': review_dict['timestamp'], - 'group_id': review_dict['group_id'], - 'update_type': review_dict['update_type'], - 'original_content': review_dict['original_content'], - 'new_content': review_dict['new_content'], - 'proposed_content': proposed_content, - 'confidence_score': confidence_score, - 'reason': review_dict['reason'], - 'status': review_dict['status'], - 'reviewer_comment': review_dict['reviewer_comment'], - 'review_time': review_dict['review_time'], - 'metadata': metadata - }) - - return result - - except Exception as e: - self._logger.error(f"获取待审查人格学习记录失败(ORM): {e}", exc_info=True) - return [] - - async def update_persona_learning_review_status_orm( - self, - review_id: int, - status: str, - comment: str = None, - modified_content: str = None - ) -> bool: - """更新人格学习审查状态(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return False - - try: - async with self.db_engine.get_session() as session: - repo = PersonaLearningReviewRepository(session) - - import time - update_data = { - 'status': status, - 'review_time': time.time() - } - - if comment: - update_data['reviewer_comment'] = comment - - # 如果有修改后的内容,也要更新proposed_content和new_content字段 - if modified_content: - update_data['proposed_content'] = modified_content - update_data['new_content'] = modified_content - - success = await 
repo.update(review_id, **update_data) - await session.commit() - - if success: - self._logger.info(f"更新人格学习审查状态成功(ORM): ID={review_id}, 状态={status}") - else: - self._logger.warning(f"更新人格学习审查状态失败(ORM): 未找到ID={review_id}的记录") - - return success - - except Exception as e: - self._logger.error(f"更新人格学习审查状态失败(ORM): {e}", exc_info=True) - return False - - async def get_persona_learning_review_by_id_orm(self, review_id: int) -> Optional[Dict[str, Any]]: - """根据ID获取人格学习审查记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return None - - try: - async with self.db_engine.get_session() as session: - repo = PersonaLearningReviewRepository(session) - review = await repo.get_by_id(review_id) - - if not review: - return None - - review_dict = review.to_dict() - - # 解析 metadata JSON 字符串 - metadata = {} - try: - if review_dict.get('metadata'): - import json - metadata = json.loads(review_dict['metadata']) - except json.JSONDecodeError: - pass - - # 确保有proposed_content字段 - proposed_content = review_dict.get('proposed_content') or review_dict.get('new_content') - confidence_score = review_dict.get('confidence_score') if review_dict.get('confidence_score') is not None else 0.5 - - return { - 'id': review_dict['id'], - 'timestamp': review_dict['timestamp'], - 'group_id': review_dict['group_id'], - 'update_type': review_dict['update_type'], - 'original_content': review_dict['original_content'], - 'new_content': review_dict['new_content'], - 'proposed_content': proposed_content, - 'confidence_score': confidence_score, - 'reason': review_dict['reason'], - 'status': review_dict['status'], - 'reviewer_comment': review_dict['reviewer_comment'], - 'review_time': review_dict['review_time'], - 'metadata': metadata - } - - except Exception as e: - self._logger.error(f"根据ID获取人格学习审查记录失败(ORM): {e}", exc_info=True) - return None - - async def delete_persona_learning_review_by_id_orm(self, review_id: int) -> bool: - """删除指定ID的人格学习审查记录(使用 ORM)""" - if not 
self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return False - - try: - async with self.db_engine.get_session() as session: - repo = PersonaLearningReviewRepository(session) - success = await repo.delete(review_id) - await session.commit() - - if success: - self._logger.info(f"删除人格学习审查记录成功(ORM): ID={review_id}") - else: - self._logger.warning(f"删除人格学习审查记录失败(ORM): 未找到ID={review_id}的记录") - - return success - - except Exception as e: - self._logger.error(f"删除人格学习审查记录失败(ORM): {e}", exc_info=True) - return False - - async def get_reviewed_persona_learning_updates_orm( - self, - limit: int = 50, - offset: int = 0, - status_filter: str = None - ) -> List[Dict[str, Any]]: - """获取已审查的人格学习更新记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - repo = PersonaLearningReviewRepository(session) - reviews = await repo.get_reviewed_updates( - limit=limit, - offset=offset, - status_filter=status_filter - ) - - # 转换为字典格式 - result = [] - for review in reviews: - review_dict = review.to_dict() - - # 解析 metadata JSON 字符串 - metadata = {} - try: - if review_dict.get('metadata'): - import json - metadata = json.loads(review_dict['metadata']) - except json.JSONDecodeError: - pass - - # 确保有proposed_content字段 - proposed_content = review_dict.get('proposed_content') or review_dict.get('new_content') - confidence_score = review_dict.get('confidence_score') if review_dict.get('confidence_score') is not None else 0.5 - - result.append({ - 'id': review_dict['id'], - 'timestamp': review_dict['timestamp'], - 'group_id': review_dict['group_id'], - 'update_type': review_dict['update_type'], - 'original_content': review_dict['original_content'], - 'new_content': review_dict['new_content'], - 'proposed_content': proposed_content, - 'confidence_score': confidence_score, - 'reason': review_dict['reason'], - 'status': review_dict['status'], - 'reviewer_comment': 
review_dict['reviewer_comment'], - 'review_time': review_dict['review_time'], - 'metadata': metadata - }) - - return result - - except Exception as e: - self._logger.error(f"获取已审查人格学习更新记录失败(ORM): {e}", exc_info=True) - return [] - - async def get_reviewed_style_learning_updates_orm( - self, - limit: int = 50, - offset: int = 0, - status_filter: str = None - ) -> List[Dict[str, Any]]: - """获取已审查的风格学习更新记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - from sqlalchemy import select, or_, func as sql_func, case - from ..models.orm import StyleLearningReview - - # 构建查询 - stmt = select(StyleLearningReview) - - # 状态过滤 - if status_filter: - stmt = stmt.where(StyleLearningReview.status == status_filter) - else: - stmt = stmt.where( - or_( - StyleLearningReview.status == 'approved', - StyleLearningReview.status == 'rejected' - ) - ) - - # 排序:使用updated_at,如果为NULL则使用timestamp - stmt = stmt.order_by( - sql_func.coalesce(StyleLearningReview.updated_at, StyleLearningReview.timestamp).desc() - ).offset(offset).limit(limit) - - result = await session.execute(stmt) - reviews = list(result.scalars().all()) - - # 转换为字典格式 - updates = [] - for review in reviews: - review_dict = review.to_dict() - - # 尝试解析learned_patterns以获取更多信息 - try: - import json - learned_patterns = json.loads(review_dict['learned_patterns']) if review_dict.get('learned_patterns') else {} - reason = learned_patterns.get('reason', '风格学习更新') - original_content = learned_patterns.get('original_content', '原始风格特征') - proposed_content = learned_patterns.get('proposed_content', review_dict.get('learned_patterns', '')) - confidence_score = learned_patterns.get('confidence_score', 0.8) - except (json.JSONDecodeError, AttributeError): - reason = review_dict.get('description', '风格学习更新') - original_content = '原始风格特征' - proposed_content = review_dict.get('learned_patterns', '无内容') - confidence_score = 0.8 - - 
updates.append({ - 'id': review_dict['id'], - 'group_id': review_dict['group_id'], - 'original_content': original_content, - 'proposed_content': proposed_content, - 'confidence_score': confidence_score, - 'reason': reason, - 'update_type': review_dict.get('type', 'style'), - 'timestamp': review_dict['timestamp'], - 'status': review_dict['status'], - 'reviewer_comment': None, - 'review_time': review_dict.get('updated_at', review_dict['timestamp']) - }) - - return updates - - except Exception as e: - self._logger.error(f"获取已审查风格学习更新记录失败(ORM): {e}", exc_info=True) - return [] - - async def delete_style_review_by_id_orm(self, review_id: int) -> bool: - """删除指定ID的风格学习审查记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return False - - try: - async with self.db_engine.get_session() as session: - repo = StyleLearningReviewRepository(session) - success = await repo.delete(review_id) - await session.commit() - - if success: - self._logger.info(f"成功删除风格学习审查记录(ORM),ID: {review_id}") - else: - self._logger.warning(f"未找到要删除的风格学习审查记录(ORM),ID: {review_id}") - - return success - - except Exception as e: - self._logger.error(f"删除风格学习审查记录失败(ORM): {e}", exc_info=True) - return False - - async def search_jargon_orm( - self, - keyword: str, - chat_id: Optional[str] = None, - limit: int = 10 - ) -> List[Dict[str, Any]]: - """搜索黑话(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化,返回空列表") - return [] - - try: - async with self.db_engine.get_session() as session: - from sqlalchemy import select, and_, or_, desc - from ..models.orm import Jargon - - # 构建查询 - stmt = select(Jargon) - - if chat_id: - # 搜索指定群组的黑话 - stmt = stmt.where( - and_( - Jargon.chat_id == chat_id, - Jargon.content.like(f'%{keyword}%'), - Jargon.is_jargon == True - ) - ) - else: - # 搜索全局黑话 - stmt = stmt.where( - and_( - Jargon.content.like(f'%{keyword}%'), - Jargon.is_jargon == True, - Jargon.is_global == True - ) - ) - - stmt = stmt.order_by( - 
desc(Jargon.count), - desc(Jargon.updated_at) - ).limit(limit) - - result = await session.execute(stmt) - jargons = list(result.scalars().all()) - - # 转换为字典格式 - results = [] - for jargon in jargons: - results.append({ - 'id': jargon.id, - 'content': jargon.content, - 'meaning': jargon.meaning, - 'is_jargon': bool(jargon.is_jargon), - 'count': jargon.count, - 'is_complete': bool(jargon.is_complete) - }) - - return results - - except Exception as e: - self._logger.error(f"搜索黑话失败(ORM): {e}", exc_info=True) - return [] - - async def delete_jargon_by_id_orm(self, jargon_id: int) -> bool: - """根据ID删除黑话记录(使用 ORM)""" - if not self.db_engine: - self._logger.warning("DatabaseEngine 未初始化") - return False - - try: - async with self.db_engine.get_session() as session: - repo = JargonRepository(session) - success = await repo.delete(jargon_id) - await session.commit() - - if success: - self._logger.debug(f"删除黑话记录成功(ORM), ID: {jargon_id}") - - return success - - except Exception as e: - self._logger.error(f"删除黑话记录失败(ORM): {e}", exc_info=True) - return False - - diff --git a/services/embedding/__init__.py b/services/embedding/__init__.py new file mode 100644 index 0000000..5c455ba --- /dev/null +++ b/services/embedding/__init__.py @@ -0,0 +1,29 @@ +""" +Embedding provider abstraction layer. + +Provides a plugin-level ``IEmbeddingProvider`` interface that delegates to +AstrBot framework's ``EmbeddingProvider`` via a thin adapter. The factory +resolves providers by their framework-configured ``provider_id``. 
+ +Public API:: + + from services.embedding import ( + IEmbeddingProvider, + EmbeddingResult, + EmbeddingProviderError, + EmbeddingProviderFactory, + FrameworkEmbeddingAdapter, + ) +""" + +from .base import EmbeddingProviderError, EmbeddingResult, IEmbeddingProvider +from .factory import EmbeddingProviderFactory +from .framework_adapter import FrameworkEmbeddingAdapter + +__all__ = [ + "IEmbeddingProvider", + "EmbeddingResult", + "EmbeddingProviderError", + "EmbeddingProviderFactory", + "FrameworkEmbeddingAdapter", +] diff --git a/services/embedding/base.py b/services/embedding/base.py new file mode 100644 index 0000000..0232620 --- /dev/null +++ b/services/embedding/base.py @@ -0,0 +1,86 @@ +""" +Embedding provider interface and value objects. + +Defines the abstract contract that all embedding providers must implement. +Aligned with AstrBot framework's ``EmbeddingProvider`` method signatures +to ensure seamless integration while keeping plugin-level decoupling. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List + + +@dataclass(frozen=True) +class EmbeddingResult: + """Immutable result from an embedding operation. + + Attributes: + embeddings: List of embedding vectors, one per input text. + model: The model identifier used for this embedding. + dimensions: Dimensionality of each embedding vector. + usage: Provider-specific usage metadata (e.g. token counts). + """ + + embeddings: List[List[float]] + model: str + dimensions: int + usage: Dict[str, Any] = field(default_factory=dict) + + +class IEmbeddingProvider(ABC): + """Abstract embedding provider interface. + + Method signatures are deliberately aligned with AstrBot framework's + ``EmbeddingProvider`` base class (``get_embedding``, ``get_embeddings``, + ``get_dim``) so that framework adapters can delegate with zero + transformation. 
+ """ + + @abstractmethod + async def get_embedding(self, text: str) -> List[float]: + """Generate an embedding vector for a single text. + + Args: + text: The string to embed. + + Returns: + A single embedding vector. + + Raises: + EmbeddingProviderError: On provider communication failure. + """ + + @abstractmethod + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a batch of texts. + + Args: + texts: Non-empty list of strings to embed. + + Returns: + One embedding vector per input text, in the same order. + + Raises: + ValueError: If *texts* is empty. + EmbeddingProviderError: On provider communication failure. + """ + + @abstractmethod + def get_dim(self) -> int: + """Return the embedding dimensionality for the current model.""" + + @abstractmethod + def get_model_name(self) -> str: + """Return the model identifier string.""" + + async def close(self) -> None: + """Release any resources held by the provider. + + Default implementation is a no-op. Subclasses that manage + HTTP sessions or other resources should override this method. + """ + + +class EmbeddingProviderError(Exception): + """Raised when an embedding provider encounters an unrecoverable error.""" diff --git a/services/embedding/factory.py b/services/embedding/factory.py new file mode 100644 index 0000000..c0a3f6e --- /dev/null +++ b/services/embedding/factory.py @@ -0,0 +1,98 @@ +""" +Embedding provider factory. + +Creates the appropriate ``IEmbeddingProvider`` implementation by looking up +the AstrBot framework's provider registry using a configured ``provider_id``. + +This follows the same pattern as the plugin's ``FrameworkLLMAdapter``: +``context.get_provider_by_id(provider_id)`` → framework provider instance → +wrapped in a thin adapter. 
+""" + +from typing import Optional + +from astrbot.api import logger +from astrbot.core.provider.provider import EmbeddingProvider + +from .base import IEmbeddingProvider +from .framework_adapter import FrameworkEmbeddingAdapter + + +class EmbeddingProviderFactory: + """Factory for creating embedding provider instances. + + Usage:: + + provider = EmbeddingProviderFactory.create(config, context) + if provider: + vec = await provider.get_embedding("hello") + """ + + @staticmethod + def create(config, context) -> Optional[IEmbeddingProvider]: + """Create an embedding provider from plugin configuration. + + Args: + config: ``PluginConfig`` instance. Expected field: + - ``embedding_provider_id``: AstrBot provider ID string. + context: AstrBot plugin context (provides ``get_provider_by_id``). + + Returns: + An ``IEmbeddingProvider`` instance, or ``None`` if embedding is + not configured. + """ + provider_id = getattr(config, "embedding_provider_id", None) + + if not provider_id: + logger.debug( + "[EmbeddingFactory] No embedding_provider_id configured, " + "embedding features disabled" + ) + return None + + if context is None: + logger.warning( + "[EmbeddingFactory] AstrBot context is None, " + "cannot resolve embedding provider" + ) + return None + + return EmbeddingProviderFactory._resolve_framework_provider( + provider_id, context + ) + + @staticmethod + def _resolve_framework_provider( + provider_id: str, context + ) -> Optional[IEmbeddingProvider]: + """Resolve the framework provider by ID and wrap in adapter.""" + try: + provider = context.get_provider_by_id(provider_id) + except Exception as exc: + logger.warning( + f"[EmbeddingFactory] Failed to look up provider " + f"'{provider_id}': {exc}" + ) + return None + + if provider is None: + logger.warning( + f"[EmbeddingFactory] Provider '{provider_id}' not found " + f"in framework registry" + ) + return None + + if not isinstance(provider, EmbeddingProvider): + logger.warning( + f"[EmbeddingFactory] Provider 
'{provider_id}' is " + f"{type(provider).__name__}, expected EmbeddingProvider" + ) + return None + + adapter = FrameworkEmbeddingAdapter(provider) + logger.info( + f"[EmbeddingFactory] Resolved embedding provider: " + f"id={provider_id}, model={adapter.get_model_name()}, " + f"dim={adapter.get_dim()}" + ) + return adapter diff --git a/services/embedding/framework_adapter.py b/services/embedding/framework_adapter.py new file mode 100644 index 0000000..05e4642 --- /dev/null +++ b/services/embedding/framework_adapter.py @@ -0,0 +1,104 @@ +""" +Framework embedding adapter. + +Thin adapter that wraps AstrBot's ``EmbeddingProvider`` instance behind the +plugin's ``IEmbeddingProvider`` interface. All heavy lifting (HTTP calls, +batching, retries, connection pooling) is delegated to the framework provider. + +Usage:: + + from astrbot.core.provider.provider import EmbeddingProvider + + framework_provider: EmbeddingProvider = context.get_provider_by_id(pid) + adapter = FrameworkEmbeddingAdapter(framework_provider) + vec = await adapter.get_embedding("hello world") +""" + +from typing import List + +from astrbot.api import logger +from astrbot.core.provider.provider import EmbeddingProvider + +from .base import IEmbeddingProvider, EmbeddingProviderError + + +class FrameworkEmbeddingAdapter(IEmbeddingProvider): + """Adapter bridging AstrBot ``EmbeddingProvider`` → plugin ``IEmbeddingProvider``. + + This class owns no HTTP resources; it simply delegates to the framework + provider instance which manages its own lifecycle. + + Args: + provider: A fully-initialised AstrBot ``EmbeddingProvider`` instance. 
+ """ + + def __init__(self, provider: EmbeddingProvider) -> None: + if provider is None: + raise ValueError("provider must not be None") + self._provider = provider + + # IEmbeddingProvider implementation + + async def get_embedding(self, text: str) -> List[float]: + try: + return await self._provider.get_embedding(text) + except Exception as exc: + raise EmbeddingProviderError( + f"Framework embedding call failed: {exc}" + ) from exc + + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + if not texts: + raise ValueError("texts must be a non-empty list") + try: + return await self._provider.get_embeddings(texts) + except Exception as exc: + raise EmbeddingProviderError( + f"Framework batch embedding call failed: {exc}" + ) from exc + + def get_dim(self) -> int: + return self._provider.get_dim() + + def get_model_name(self) -> str: + return self._provider.get_model() + + async def close(self) -> None: + # Framework manages its own provider lifecycle; nothing to release. + pass + + # Extended helpers (delegated to framework) + + async def get_embeddings_batch( + self, + texts: List[str], + batch_size: int = 16, + tasks_limit: int = 3, + max_retries: int = 3, + progress_callback=None, + ) -> List[List[float]]: + """Batch embedding with framework-level retry and progress tracking. + + Delegates to ``EmbeddingProvider.get_embeddings_batch`` which + implements semaphore-controlled concurrency and exponential backoff. 
+ """ + try: + return await self._provider.get_embeddings_batch( + texts, + batch_size=batch_size, + tasks_limit=tasks_limit, + max_retries=max_retries, + progress_callback=progress_callback, + ) + except Exception as exc: + raise EmbeddingProviderError( + f"Framework batch embedding failed: {exc}" + ) from exc + + @property + def provider_id(self) -> str: + """Return the framework provider's unique identifier.""" + try: + return self._provider.meta().id + except (ValueError, KeyError): + return "" diff --git a/services/enhanced_affection_manager.py b/services/enhanced_affection_manager.py deleted file mode 100644 index f91200d..0000000 --- a/services/enhanced_affection_manager.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -增强型好感度管理服务 -使用 CacheManager 和 Repository 模式,与现有接口兼容 -""" -import asyncio -import random -import time -from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta -from enum import Enum - -from astrbot.api import logger - -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage -from ..utils.cache_manager import get_cache_manager, async_cached -from ..utils.task_scheduler import get_task_scheduler - -# 导入 Repository -from ..repositories import ( - AffectionRepository, - InteractionRepository, - ConversationHistoryRepository, - DiversityRepository -) - -# 导入原有的枚举和数据类 -from .affection_manager import ( - MoodType, - InteractionType, - BotMood, - UserAffection as OriginalUserAffection -) - - -class EnhancedAffectionManager(AsyncServiceBase): - """ - 增强型好感度管理服务 - - 改进: - 1. 使用 CacheManager 替代手动字典缓存 - 2. 使用 Repository 访问数据库 - 3. 使用 TaskScheduler 管理定时任务 - 4. 
保持与原有接口的兼容性 - - 用法: - # 在配置中启用 - config.use_enhanced_managers = True - - # 创建管理器 - affection_mgr = EnhancedAffectionManager(config, db_manager, llm_adapter) - await affection_mgr.start() - """ - - def __init__( - self, - config: PluginConfig, - database_manager: IDataStorage, - llm_adapter=None - ): - super().__init__("enhanced_affection_manager") - self.config = config - self.db_manager = database_manager - self.llm_adapter = llm_adapter - - # 使用统一的缓存管理器 - self.cache = get_cache_manager() - - # 使用统一的任务调度器 - self.scheduler = get_task_scheduler() - - # 预定义的情绪描述模板(保持原有逻辑) - self.mood_descriptions = self._init_mood_descriptions() - - # 好感度变化规则(保持原有逻辑) - self.affection_rules = self._init_affection_rules() - - self._logger.info("[增强型好感度] 初始化完成(使用缓存管理器)") - - async def _do_start(self) -> bool: - """启动好感度管理服务""" - try: - # 启动任务调度器 - await self.scheduler.start() - - # 为所有活跃群组设置初始随机情绪(如果启用) - if self.config.enable_startup_random_mood: - await self._initialize_random_moods_for_active_groups() - - # 启动每日情绪更新任务(使用调度器) - if self.config.enable_daily_mood: - self.scheduler.add_cron_job( - self._daily_mood_update_task, - job_id='affection_daily_mood', - hour=0, # 每天凌晨0点 - minute=0 - ) - - self._logger.info("✅ [增强型好感度] 启动成功") - return True - - except Exception as e: - self._logger.error(f"❌ [增强型好感度] 启动失败: {e}") - return False - - async def _do_stop(self) -> bool: - """停止好感度管理服务""" - try: - # 移除定时任务 - self.scheduler.remove_job('affection_daily_mood') - - # 清除缓存 - self.cache.clear('affection') - - self._logger.info("✅ [增强型好感度] 已停止") - return True - - except Exception as e: - self._logger.error(f"❌ [增强型好感度] 停止失败: {e}") - return False - - # ============================================================ - # 使用缓存装饰器的方法 - # ============================================================ - - @async_cached( - cache_name='affection', - key_func=lambda self, group_id, user_id: f"affection:{group_id}:{user_id}" - ) - async def get_user_affection( - self, - group_id: str, - user_id: str - ) -> 
Optional[OriginalUserAffection]: - """ - 获取用户好感度(带缓存) - - Args: - group_id: 群组 ID - user_id: 用户 ID - - Returns: - Optional[UserAffection]: 好感度对象 - """ - try: - # 从数据库获取 - affection_data = await self.db_manager.get_user_affection( - group_id, - user_id - ) - - if affection_data: - return OriginalUserAffection( - user_id=user_id, - group_id=group_id, - affection_level=affection_data['affection_level'], - last_interaction=affection_data.get('updated_at', time.time()), - interaction_count=affection_data.get('interaction_count', 0) - ) - return None - - except Exception as e: - self._logger.error(f"[增强型好感度] 获取好感度失败: {e}") - return None - - async def update_user_affection( - self, - group_id: str, - user_id: str, - affection_delta: int, - interaction_type: str = None - ) -> bool: - """ - 更新用户好感度(自动清除缓存) - - Args: - group_id: 群组 ID - user_id: 用户 ID - affection_delta: 好感度变化量 - interaction_type: 交互类型 - - Returns: - bool: 是否更新成功 - """ - try: - # 更新数据库 - success = await self.db_manager.update_user_affection( - group_id, - user_id, - affection_delta - ) - - if success: - # 清除缓存 - cache_key = f"affection:{group_id}:{user_id}" - self.cache.delete('affection', cache_key) - - self._logger.debug( - f"[增强型好感度] 更新成功: {group_id}:{user_id} " - f"变化={affection_delta}, 已清除缓存" - ) - - return success - - except Exception as e: - self._logger.error(f"[增强型好感度] 更新好感度失败: {e}") - return False - - @async_cached( - cache_name='affection', - key_func=lambda self, group_id: f"mood:{group_id}" - ) - async def get_current_mood(self, group_id: str) -> Optional[BotMood]: - """ - 获取当前情绪(带缓存) - - Args: - group_id: 群组 ID - - Returns: - Optional[BotMood]: 情绪对象 - """ - try: - # 从数据库加载 - mood_data = await self.db_manager.get_current_bot_mood(group_id) - - if mood_data: - mood = BotMood( - mood_type=MoodType(mood_data['mood_type']), - intensity=mood_data['mood_intensity'], - description=mood_data['mood_description'], - start_time=mood_data['created_at'], - duration_hours=mood_data.get('duration_hours', 24) - 
) - - # 检查是否过期 - if mood.is_active(): - return mood - else: - # 过期则清除缓存 - cache_key = f"mood:{group_id}" - self.cache.delete('affection', cache_key) - - return None - - except Exception as e: - self._logger.error(f"[增强型好感度] 获取情绪失败: {e}") - return None - - async def set_daily_mood( - self, - group_id: str, - mood_type: MoodType = None, - intensity: float = None - ) -> BotMood: - """ - 设置每日情绪(自动清除缓存) - - Args: - group_id: 群组 ID - mood_type: 情绪类型(None 则随机) - intensity: 情绪强度(None 则随机) - - Returns: - BotMood: 新的情绪对象 - """ - try: - # 随机选择情绪 - if mood_type is None: - mood_type = random.choice(list(MoodType)) - - if intensity is None: - intensity = random.uniform(0.5, 1.0) - - # 获取情绪描述 - description = self._get_mood_description(mood_type, intensity) - - # 保存到数据库 - await self.db_manager.save_bot_mood( - group_id, - mood_type.value, - intensity, - description, - duration_hours=24 - ) - - # 创建情绪对象 - mood = BotMood( - mood_type=mood_type, - intensity=intensity, - description=description, - start_time=time.time(), - duration_hours=24 - ) - - # 清除缓存 - cache_key = f"mood:{group_id}" - self.cache.delete('affection', cache_key) - - self._logger.info( - f"[增强型好感度] 设置每日情绪: {group_id} -> " - f"{mood_type.value} ({intensity:.2f})" - ) - - return mood - - except Exception as e: - self._logger.error(f"[增强型好感度] 设置情绪失败: {e}") - return None - - # ============================================================ - # 任务调度方法 - # ============================================================ - - async def _daily_mood_update_task(self): - """每日情绪更新任务(由调度器调用)""" - try: - self._logger.info("[增强型好感度] 执行每日情绪更新...") - - # 获取所有活跃群组 - # TODO: 需要从数据库获取活跃群组列表 - # 暂时使用示例实现 - active_groups = [] # await self.db_manager.get_active_groups() - - for group_id in active_groups: - await self.set_daily_mood(group_id) - - self._logger.info( - f"[增强型好感度] 每日情绪更新完成," - f"共更新 {len(active_groups)} 个群组" - ) - - except Exception as e: - self._logger.error(f"[增强型好感度] 每日情绪更新失败: {e}") - - # 
============================================================ - # 辅助方法(保持原有逻辑) - # ============================================================ - - def _init_mood_descriptions(self) -> Dict[MoodType, List[str]]: - """初始化情绪描述模板""" - return { - MoodType.HAPPY: [ - "今天心情特别好~", - "感觉一切都很美好呢", - "今天充满了正能量!" - ], - MoodType.SAD: [ - "今天有点不开心...", - "心情有些低落", - "感觉有点难过" - ], - MoodType.EXCITED: [ - "今天超级兴奋!", - "感觉浑身充满了活力!", - "好激动啊!" - ], - # ... 其他情绪 - } - - def _init_affection_rules(self) -> Dict[str, int]: - """初始化好感度变化规则""" - return { - InteractionType.CHAT.value: 1, - InteractionType.COMPLIMENT.value: 5, - InteractionType.FLIRT.value: 3, - InteractionType.COMFORT.value: 4, - InteractionType.HELP.value: 3, - InteractionType.THANKS.value: 2, - InteractionType.CARE.value: 4, - InteractionType.GIFT.value: 10, - InteractionType.INSULT.value: -10, - InteractionType.HARASSMENT.value: -15, - InteractionType.ABUSE.value: -20, - # ... 其他规则 - } - - def _get_mood_description( - self, - mood_type: MoodType, - intensity: float - ) -> str: - """获取情绪描述""" - descriptions = self.mood_descriptions.get(mood_type, ["心情一般"]) - return random.choice(descriptions) - - async def _initialize_random_moods_for_active_groups(self): - """为活跃群组初始化随机情绪""" - try: - # TODO: 从数据库获取活跃群组 - # active_groups = await self.db_manager.get_active_groups() - # for group_id in active_groups: - # await self.set_daily_mood(group_id) - pass - - except Exception as e: - self._logger.error(f"[增强型好感度] 初始化随机情绪失败: {e}") - - # ============================================================ - # 缓存统计方法 - # ============================================================ - - def get_cache_stats(self) -> dict: - """获取缓存统计信息""" - return self.cache.get_stats('affection') - - def clear_cache(self): - """清除所有缓存""" - self.cache.clear('affection') - self._logger.info("[增强型好感度] 已清除所有缓存") diff --git a/services/hooks/__init__.py b/services/hooks/__init__.py new file mode 100644 index 0000000..44b9960 --- /dev/null +++ 
b/services/hooks/__init__.py @@ -0,0 +1 @@ +"""LLM hook processing — context providers and hook handler.""" \ No newline at end of file diff --git a/services/hooks/llm_hook_handler.py b/services/hooks/llm_hook_handler.py new file mode 100644 index 0000000..5f0de3b --- /dev/null +++ b/services/hooks/llm_hook_handler.py @@ -0,0 +1,375 @@ +"""LLM Hook handler — parallel context retrieval, prompt injection, performance tracking. + +Orchestrates all context providers (social, V2, diversity, jargon, few-shot, session updates) +in parallel, merges results, and injects them into the LLM request via +``extra_user_content_parts`` to preserve system_prompt prefix caching. +""" + +import asyncio +import time +from typing import Any, Dict, List, Optional + +from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent +from astrbot.core.agent.message import TextPart + +from .perf_tracker import PerfTracker + + +class LLMHookHandler: + """Orchestrate LLM Hook context injection. + + Runs all context providers in parallel via ``asyncio.gather``, merges + results in priority order, and records timing data. + + Args: + plugin_config: Plugin configuration object. + diversity_manager: Diversity prompt builder service. + social_context_injector: Social context injector service. + v2_integration: V2 learning integration service. + jargon_query_service: Jargon query service. + temporary_persona_updater: Session-level persona updater. + perf_tracker: ``PerfTracker`` for recording timing samples. + group_id_to_unified_origin: Shared mapping from group_id to UMO. + db_manager: Database manager for approved few-shot retrieval. 
+ """ + + def __init__( + self, + plugin_config: Any, + diversity_manager: Any, + social_context_injector: Any, + v2_integration: Any, + jargon_query_service: Any, + temporary_persona_updater: Any, + perf_tracker: PerfTracker, + group_id_to_unified_origin: Dict[str, str], + db_manager: Any = None, + ) -> None: + self._config = plugin_config + self._diversity_manager = diversity_manager + self._social_context_injector = social_context_injector + self._v2_integration = v2_integration + self._jargon_query_service = jargon_query_service + self._temporary_persona_updater = temporary_persona_updater + self._perf_tracker = perf_tracker + self._group_id_to_unified_origin = group_id_to_unified_origin + self._db_manager = db_manager + + # Public API + + async def handle(self, event: AstrMessageEvent, req: Any) -> None: + """Process an LLM request hook — inject context into *req*.""" + hook_start = time.time() + social_ms = v2_ms = diversity_ms = jargon_ms = few_shots_ms = 0.0 + + try: + if req is None: + logger.warning("[LLM Hook] req 参数为 None,跳过注入") + return + + if not self._diversity_manager: + logger.debug("[LLM Hook] diversity_manager未初始化,跳过多样性注入") + return + + group_id = event.get_group_id() or event.get_sender_id() + user_id = event.get_sender_id() + + # Maintain group_id → unified_msg_origin mapping + if hasattr(event, "unified_msg_origin") and event.unified_msg_origin: + self._group_id_to_unified_origin[group_id] = event.unified_msg_origin + logger.debug(f"[LLM Hook] 更新映射: {group_id} -> {event.unified_msg_origin}") + + if not req.prompt: + logger.debug("[LLM Hook] req.prompt为空,跳过多样性注入") + return + + original_prompt_length = len(req.prompt) + logger.info( + f"[LLM Hook] 开始注入多样性增强 " + f"(group: {group_id}, 原prompt长度: {original_prompt_length})" + ) + + prompt_injections: List[str] = [] + logger.debug("[LLM Hook] 跳过基础人格注入(框架已处理),专注于增量内容") + + # Parallel context retrieval + social_result: Optional[str] = None + v2_result: Optional[Dict[str, Any]] = None + 
diversity_result: Optional[str] = None + jargon_result: Optional[str] = None + few_shots_result: Optional[str] = None + + async def _timed_social() -> None: + nonlocal social_result, social_ms + t0 = time.time() + social_result = await self._fetch_social(group_id, user_id) + social_ms = (time.time() - t0) * 1000 + + async def _timed_v2() -> None: + nonlocal v2_result, v2_ms + t0 = time.time() + v2_result = await self._fetch_v2(req.prompt, group_id) + v2_ms = (time.time() - t0) * 1000 + + async def _timed_diversity() -> None: + nonlocal diversity_result, diversity_ms + t0 = time.time() + diversity_result = await self._fetch_diversity(group_id) + diversity_ms = (time.time() - t0) * 1000 + + async def _timed_jargon() -> None: + nonlocal jargon_result, jargon_ms + t0 = time.time() + jargon_result = await self._fetch_jargon(event, group_id) + jargon_ms = (time.time() - t0) * 1000 + + async def _timed_few_shots() -> None: + nonlocal few_shots_result, few_shots_ms + t0 = time.time() + few_shots_result = await self._fetch_few_shots(group_id) + few_shots_ms = (time.time() - t0) * 1000 + + await asyncio.gather( + _timed_social(), + _timed_v2(), + _timed_diversity(), + _timed_jargon(), + _timed_few_shots(), + ) + + # Merge results in priority order + self._collect_social(social_result, group_id, prompt_injections) + self._collect_v2(v2_result, v2_ms, prompt_injections) + self._collect_diversity(diversity_result, prompt_injections) + self._collect_jargon(jargon_result, prompt_injections) + self._collect_few_shots(few_shots_result, prompt_injections) + self._collect_session_updates(group_id, prompt_injections) + + # Inject into request + if prompt_injections: + self._inject(req, prompt_injections, hook_start) + else: + logger.debug("[LLM Hook] 没有可注入的增量内容") + + # Record perf data + total_ms = (time.time() - hook_start) * 1000 + self._perf_tracker.record( + { + "ts": time.time(), + "total_ms": round(total_ms, 1), + "social_ctx_ms": round(social_ms, 1), + "v2_ctx_ms": round(v2_ms, 
1), + "diversity_ms": round(diversity_ms, 1), + "jargon_ms": round(jargon_ms, 1), + "few_shots_ms": round(few_shots_ms, 1), + "group_id": group_id, + } + ) + + except Exception as e: + logger.error(f"[LLM Hook] 框架层面注入多样性失败: {e}", exc_info=True) + + # Context fetchers + + async def _fetch_social( + self, group_id: str, user_id: str + ) -> Optional[str]: + if not self._social_context_injector: + logger.debug("[LLM Hook] social_context_injector未初始化,跳过社交上下文注入") + return None + try: + return await self._social_context_injector.format_complete_context( + group_id=group_id, + user_id=user_id, + include_social_relations=self._config.include_social_relations, + include_affection=self._config.include_affection_info, + include_mood=False, + include_expression_patterns=True, + include_psychological=True, + include_behavior_guidance=True, + include_conversation_goal=self._config.enable_goal_driven_chat, + enable_protection=True, + ) + except Exception as e: + logger.warning(f"[LLM Hook] 注入社交上下文失败: {e}") + return None + + async def _fetch_v2( + self, prompt: str, group_id: str + ) -> Optional[Dict[str, Any]]: + if not self._v2_integration: + return None + try: + return await self._v2_integration.get_enhanced_context(prompt, group_id) + except Exception as e: + logger.debug(f"[LLM Hook] V2 context retrieval failed: {e}") + return None + + async def _fetch_diversity(self, group_id: str) -> Optional[str]: + try: + content = await self._diversity_manager.build_diversity_prompt_injection( + "", + group_id=group_id, + inject_style=True, + inject_pattern=True, + inject_variation=True, + inject_history=True, + ) + return content.strip() if content else None + except Exception as e: + logger.warning(f"[LLM Hook] 多样性增强失败: {e}") + return None + + async def _fetch_jargon( + self, event: AstrMessageEvent, group_id: str + ) -> Optional[str]: + if not self._jargon_query_service: + logger.debug("[LLM Hook] jargon_query_service未初始化,跳过黑话注入") + return None + try: + user_message = ( + 
event.message_str + if hasattr(event, "message_str") + else str(event.get_message()) + ) + return await self._jargon_query_service.check_and_explain_jargon( + text=user_message, chat_id=group_id + ) + except Exception as e: + logger.warning(f"[LLM Hook] 注入黑话理解失败: {e}") + return None + + async def _fetch_few_shots(self, group_id: str) -> Optional[str]: + """Fetch approved few-shot dialogue content for the given group.""" + if not self._db_manager: + return None + try: + contents = await self._db_manager.get_approved_few_shots(group_id, limit=3) + if contents: + return contents[0] + except Exception as e: + logger.warning(f"[LLM Hook] Failed to fetch approved few-shots: {e}") + return None + + # Result collectors + + @staticmethod + def _collect_social( + result: Optional[str], group_id: str, out: List[str] + ) -> None: + if result: + out.append(result) + logger.info(f"[LLM Hook] 已准备完整社交上下文 (长度: {len(result)})") + else: + logger.debug(f"[LLM Hook] 群组 {group_id} 暂无社交上下文") + + @staticmethod + def _collect_v2( + result: Optional[Dict[str, Any]], ms: float, out: List[str] + ) -> None: + if not result: + return + v2_parts: List[str] = [] + if result.get("knowledge_context"): + v2_parts.append(f"[Related Knowledge]\n{result['knowledge_context']}") + if result.get("related_memories"): + memories_text = "\n".join(result["related_memories"][:5]) + v2_parts.append(f"[Related Memories]\n{memories_text}") + if result.get("few_shot_examples"): + examples_text = "\n".join(result["few_shot_examples"][:3]) + v2_parts.append(f"[Style Examples]\n{examples_text}") + if v2_parts: + out.append("\n\n".join(v2_parts)) + logger.info(f"[LLM Hook] V2 context injected ({len(v2_parts)} sections, {ms:.0f}ms)") + else: + logger.debug(f"[LLM Hook] V2 context empty ({ms:.0f}ms)") + + @staticmethod + def _collect_diversity(result: Optional[str], out: List[str]) -> None: + if result: + out.append(result) + logger.info(f"[LLM Hook] 已准备多样性增强内容 (长度: {len(result)})") + + @staticmethod + def 
_collect_jargon(result: Optional[str], out: List[str]) -> None: + if result: + out.append(result) + logger.info(f"[LLM Hook] 已准备黑话理解内容 (长度: {len(result)})") + else: + logger.debug("[LLM Hook] 用户消息中未检测到已知黑话") + + @staticmethod + def _collect_few_shots(result: Optional[str], out: List[str]) -> None: + if result: + out.append(f"[Few-Shot Dialogue Examples]\n{result}") + logger.info(f"[LLM Hook] Few-shot dialogue injected (len={len(result)})") + else: + logger.debug("[LLM Hook] No approved few-shot dialogues available") + + def _collect_session_updates( + self, group_id: str, out: List[str] + ) -> None: + if not self._temporary_persona_updater: + logger.debug("[LLM Hook] temporary_persona_updater未初始化,跳过会话级更新注入") + return + try: + session_updates = self._temporary_persona_updater.session_updates.get( + group_id, [] + ) + if session_updates: + updates_text = "\n\n".join(session_updates) + out.append(updates_text) + logger.info( + f"[LLM Hook] 已准备会话级更新 " + f"(会话: {group_id}, 更新数: {len(session_updates)}, " + f"长度: {len(updates_text)})" + ) + else: + logger.debug(f"[LLM Hook] 会话 {group_id} 暂无增量更新") + except Exception as e: + logger.warning(f"[LLM Hook] 注入会话级更新失败: {e}") + + # Injection + + def _inject( + self, req: Any, injections: List[str], hook_start: float + ) -> None: + injection_text = "\n\n".join(injections) + + # Use AstrBot's extra_user_content_parts API to inject context. + # This keeps system_prompt stable for LLM API prefix caching, + # while appending dynamic context as extra content blocks after + # the user message. 
"""Ring-buffer performance tracker for LLM hook timing.

Collects per-request timing samples and maintains running-average
statistics.  Designed to be referenced by the WebUI ServiceContainer
as ``perf_collector``.
"""

import time
from collections import deque
from typing import Any, Dict, List


class PerfTracker:
    """Collects LLM hook timing data in a fixed-size ring buffer.

    Only the most recent ``maxlen`` samples are retained, but the averages
    and the maximum are cumulative over *every* recorded request, including
    samples that have already been evicted from the buffer.

    Usage::

        tracker = PerfTracker(maxlen=200)
        tracker.record({"total_ms": 123, "social_ctx_ms": 45, ...})
        data = tracker.get_perf_data(recent_limit=50)
    """

    # Sample keys that get an ``avg_<key>`` statistic.
    _TIMING_KEYS = (
        "total_ms",
        "social_ctx_ms",
        "v2_ctx_ms",
        "diversity_ms",
        "jargon_ms",
        "few_shots_ms",
    )

    def __init__(self, maxlen: int = 200) -> None:
        """Create an empty tracker.

        Args:
            maxlen: Maximum number of samples kept in the ring buffer.
        """
        self._samples: deque = deque(maxlen=maxlen)
        self._stats: Dict[str, Any] = {"total_requests": 0}
        for key in self._TIMING_KEYS:
            self._stats[f"avg_{key}"] = 0
        self._stats["max_total_ms"] = 0
        self._stats["last_updated"] = 0

    def record(self, sample: Dict[str, Any]) -> None:
        """Append a timing sample and update the running statistics.

        Args:
            sample: Mapping of timing keys to milliseconds.  Missing or
                ``None`` timing values count as ``0``.
        """
        self._samples.append(sample)
        self._update_stats(sample)

    def get_perf_data(self, recent_limit: int = 50) -> Dict[str, Any]:
        """Return aggregated stats plus the most recent samples.

        Args:
            recent_limit: Maximum number of newest samples to include.
                Zero or a negative value yields no samples.

        Returns:
            A new dict holding every statistic (floats rounded to one
            decimal) and a ``recent_samples`` list, oldest first.
        """
        if recent_limit > 0:
            samples: List[Dict[str, Any]] = list(self._samples)[-recent_limit:]
        else:
            # ``seq[-0:]`` is the whole sequence, so zero needs its own branch.
            samples = []
        stats = {
            k: round(v, 1) if isinstance(v, float) else v
            for k, v in self._stats.items()
        }
        stats["recent_samples"] = samples
        return stats

    def _update_stats(self, sample: Dict[str, Any]) -> None:
        """Fold *sample* into the cumulative means and the maximum.

        Uses the incremental mean update ``avg += (x - avg) / n`` so no
        per-key running sum has to be stored.
        """
        s = self._stats
        n = s["total_requests"] + 1
        for key in self._TIMING_KEYS:
            value = sample.get(key)
            if value is None:
                value = 0
            avg_key = f"avg_{key}"
            s[avg_key] = s[avg_key] + (value - s[avg_key]) / n
        total_ms = sample.get("total_ms")
        if total_ms is not None and total_ms > s["max_total_ms"]:
            s["max_total_ms"] = total_ms
        s["total_requests"] = n
        s["last_updated"] = time.time()
+from .maibot_enhanced_learning_manager import MaiBotEnhancedLearningManager +from .exemplar_library import ExemplarLibrary +from .knowledge_graph_manager import KnowledgeGraphManager +from .lightrag_knowledge_manager import LightRAGKnowledgeManager +from .mem0_memory_manager import Mem0MemoryManager +from .training_data_exporter import TrainingDataExporter + +__all__ = [ + "MaiBotIntegrationFactory", + "MaiBotStyleAnalyzer", + "MaiBotLearningStrategy", + "MaiBotQualityMonitor", + "MaiBotEnhancedLearningManager", + "ExemplarLibrary", + "KnowledgeGraphManager", + "LightRAGKnowledgeManager", + "Mem0MemoryManager", + "TrainingDataExporter", +] diff --git a/services/integration/exemplar_library.py b/services/integration/exemplar_library.py new file mode 100644 index 0000000..ccae482 --- /dev/null +++ b/services/integration/exemplar_library.py @@ -0,0 +1,363 @@ +""" +Few-shot exemplar library. + +Stores high-quality message examples and retrieves them via cosine +similarity for few-shot style imitation in LLM prompts. + +When an ``IEmbeddingProvider`` is available, exemplars are embedded and +similarity search uses vector cosine distance. Without an embedding +provider the library degrades to recency-weighted random sampling. + +Design notes: + - Embedding vectors stored as JSON text columns for DB portability. + - Cosine similarity computed in Python (numpy) during retrieval. + - Weight field supports feedback-driven quality adjustment. + - Thread-safe for single-event-loop asyncio usage. +""" + +import json +import time +from typing import Any, Dict, List, Optional + +from astrbot.api import logger +from sqlalchemy import case, delete, desc, select, update +from sqlalchemy.sql import func + +from ...models.orm.exemplar import Exemplar + + +# Minimum content length to accept as an exemplar. +_MIN_CONTENT_LENGTH = 10 + +# Maximum exemplars stored per group (FIFO eviction of lowest-weight). 
_MAX_EXEMPLARS_PER_GROUP = 500

# Default number of few-shot examples to retrieve.
_DEFAULT_TOP_K = 5


class ExemplarLibrary:
    """Few-shot style exemplar library.

    Stores message examples per group and retrieves the most relevant ones,
    either by cosine similarity (when an embedding provider is configured)
    or by descending weight.

    Usage::

        library = ExemplarLibrary(db_manager, embedding_provider)
        await library.add_exemplar("nice message", group_id, sender_id)
        examples = await library.get_few_shot_examples("query", group_id)
    """

    _schema_migrated = False  # class-level flag: run migration once per process

    def __init__(self, db_manager, embedding_provider=None) -> None:
        """Initialise the exemplar library.

        Args:
            db_manager: SQLAlchemy database manager with ``get_session()``.
            embedding_provider: Optional ``IEmbeddingProvider`` for vector
                similarity search.  When ``None``, retrieval falls back to
                weight-ordered selection.
        """
        self._db = db_manager
        self._embedding = embedding_provider

    # Public API

    async def add_exemplar(
        self,
        content: str,
        group_id: str,
        sender_id: Optional[str] = None,
    ) -> Optional[int]:
        """Store a high-quality message as a style exemplar.

        Args:
            content: The original message text.
            group_id: Chat group identifier.
            sender_id: Message sender identifier (optional).

        Returns:
            The record ID if saved, or ``None`` if the content was rejected
            (too short) or the insert failed.
        """
        # One-time schema migration for existing MySQL tables (TEXT → MEDIUMTEXT).
        if not ExemplarLibrary._schema_migrated:
            # Flip the flag *before* awaiting so that concurrent first calls
            # on the same event loop do not each issue the ALTER TABLE.
            ExemplarLibrary._schema_migrated = True
            await self._migrate_embedding_column()

        if not content or len(content.strip()) < _MIN_CONTENT_LENGTH:
            return None

        content = content.strip()
        now = int(time.time())

        # Compute embedding if provider is available; a failure only means
        # the exemplar is stored without a vector.
        embedding_json = None
        dimensions = 0
        if self._embedding:
            try:
                vec = await self._embedding.get_embedding(content)
                embedding_json = json.dumps(vec)
                dimensions = len(vec)
            except Exception as exc:
                logger.debug(
                    f"[ExemplarLibrary] Embedding failed for exemplar, "
                    f"storing without vector: {exc}"
                )

        try:
            async with self._db.get_session() as session:
                record = Exemplar(
                    content=content,
                    sender_id=sender_id,
                    group_id=group_id,
                    embedding_json=embedding_json,
                    weight=1.0,
                    dimensions=dimensions,
                    created_at=now,
                    updated_at=now,
                )
                session.add(record)
                # flush() assigns the primary key before the commit.
                await session.flush()
                record_id = record.id
                await session.commit()

                # Evict excess exemplars if over capacity.
                await self._evict_excess(session, group_id)

                return record_id

        except Exception as exc:
            logger.warning(f"[ExemplarLibrary] Failed to save exemplar: {exc}")
            return None

    async def get_few_shot_examples(
        self,
        query: str,
        group_id: str,
        k: int = _DEFAULT_TOP_K,
    ) -> List[str]:
        """Retrieve the top-K most relevant style exemplars.

        When an embedding provider is available, uses cosine similarity
        between the query embedding and stored exemplar vectors.
        Falls back to weight-ordered selection otherwise, or when the
        similarity search raises.

        Args:
            query: The current query or context string.
            group_id: Chat group to search within.
            k: Number of exemplars to return.

        Returns:
            List of exemplar content strings, most relevant first.
        """
        if self._embedding:
            try:
                return await self._similarity_search(query, group_id, k)
            except Exception as exc:
                logger.debug(
                    f"[ExemplarLibrary] Similarity search failed, "
                    f"falling back to weight-based: {exc}"
                )

        return await self._weight_based_search(group_id, k)

    async def adjust_weight(
        self, exemplar_id: int, delta: float
    ) -> bool:
        """Adjust an exemplar's quality weight, never letting it go below 0.

        Args:
            exemplar_id: Record ID.
            delta: Weight adjustment (positive or negative).

        Returns:
            ``True`` if a row was updated, ``False`` if no row matched or
            the statement failed.
        """
        try:
            async with self._db.get_session() as session:
                new_weight = Exemplar.weight + delta
                # Clamp with CASE rather than a two-argument MAX(): scalar
                # MAX(a, b) exists on SQLite only; MySQL and PostgreSQL treat
                # MAX as a one-argument aggregate and reject the statement.
                clamped_weight = case(
                    (new_weight < 0.0, 0.0),
                    else_=new_weight,
                )
                stmt = (
                    update(Exemplar)
                    .where(Exemplar.id == exemplar_id)
                    .values(
                        weight=clamped_weight,
                        updated_at=int(time.time()),
                    )
                )
                result = await session.execute(stmt)
                await session.commit()
                return result.rowcount > 0
        except Exception as exc:
            logger.warning(
                f"[ExemplarLibrary] Weight adjustment failed: {exc}"
            )
            return False

    async def get_group_stats(self, group_id: str) -> Dict[str, Any]:
        """Return summary statistics for a group's exemplar collection.

        Returns:
            Dict with ``total_exemplars``, ``avg_weight`` (rounded to three
            decimals) and ``with_embeddings``.  All zero when the group has
            no exemplars or the query fails.
        """
        try:
            async with self._db.get_session() as session:
                stmt = select(
                    func.count(Exemplar.id),
                    func.avg(Exemplar.weight),
                    func.sum(
                        case(
                            (Exemplar.embedding_json.isnot(None), 1),
                            else_=0,
                        )
                    ),
                ).where(Exemplar.group_id == group_id)
                result = await session.execute(stmt)
                row = result.one_or_none()

                if row:
                    # Aggregates come back as NULL for an empty group.
                    return {
                        "total_exemplars": row[0] or 0,
                        "avg_weight": round(float(row[1] or 0), 3),
                        "with_embeddings": row[2] or 0,
                    }
        except Exception as exc:
            logger.debug(f"[ExemplarLibrary] Stats query failed: {exc}")

        return {"total_exemplars": 0, "avg_weight": 0.0, "with_embeddings": 0}

    async def delete_exemplar(self, exemplar_id: int) -> bool:
        """Delete a specific exemplar by ID.

        Returns:
            ``True`` if a row was deleted, ``False`` otherwise.
        """
        try:
            async with self._db.get_session() as session:
                stmt = delete(Exemplar).where(Exemplar.id == exemplar_id)
                result = await session.execute(stmt)
                await session.commit()
                return result.rowcount > 0
        except Exception as exc:
            logger.warning(f"[ExemplarLibrary] Delete failed: {exc}")
            return False

    # Internal helpers

    async def _migrate_embedding_column(self) -> None:
        """Upgrade ``embedding_json`` from TEXT to MEDIUMTEXT on MySQL.

        TEXT has a 65 KB limit which is too small for high-dimensional
        embeddings (e.g. 3072-dim ≈ 69 KB JSON).  This runs once per
        process; on backends that reject the statement (e.g. SQLite) the
        error is caught and logged at debug level.
        """
        try:
            from sqlalchemy import text
            async with self._db.get_session() as session:
                await session.execute(
                    text("ALTER TABLE exemplar MODIFY COLUMN embedding_json MEDIUMTEXT")
                )
                await session.commit()
                logger.info(
                    "[ExemplarLibrary] Migrated embedding_json column to MEDIUMTEXT"
                )
        except Exception as exc:
            # SQLite doesn't support MODIFY COLUMN, or column already migrated.
            logger.debug(
                f"[ExemplarLibrary] embedding_json migration skipped: {exc}"
            )

    async def _similarity_search(
        self, query: str, group_id: str, k: int
    ) -> List[str]:
        """Rank the group's embedded exemplars against *query*.

        Loads up to ``_MAX_EXEMPLARS_PER_GROUP`` embedded rows, scores each
        with ``_blend_score`` and returns the ``k`` best contents.  Rows
        whose stored vector cannot be decoded are skipped.  Falls back to
        the weight-based search when the group has no embedded exemplars.
        """
        query_vec = await self._embedding.get_embedding(query)

        async with self._db.get_session() as session:
            stmt = (
                select(Exemplar.content, Exemplar.embedding_json, Exemplar.weight)
                .where(
                    Exemplar.group_id == group_id,
                    Exemplar.embedding_json.isnot(None),
                )
                .order_by(desc(Exemplar.weight))
                .limit(_MAX_EXEMPLARS_PER_GROUP)
            )
            result = await session.execute(stmt)
            rows = result.all()

        if not rows:
            return await self._weight_based_search(group_id, k)

        scored = []
        for content, emb_json, weight in rows:
            try:
                stored_vec = json.loads(emb_json)
                sim = self._cosine_similarity(query_vec, stored_vec)
            except (json.JSONDecodeError, TypeError):
                continue
            scored.append((content, self._blend_score(sim, weight)))

        scored.sort(key=lambda item: item[1], reverse=True)
        # A negative k would otherwise slice from the wrong end.
        return [content for content, _ in scored[:max(k, 0)]]

    async def _weight_based_search(
        self, group_id: str, k: int
    ) -> List[str]:
        """Fallback: return the highest-weight (then newest) exemplars.

        Returns an empty list when the query fails.
        """
        try:
            async with self._db.get_session() as session:
                stmt = (
                    select(Exemplar.content)
                    .where(Exemplar.group_id == group_id)
                    .order_by(desc(Exemplar.weight), desc(Exemplar.created_at))
                    .limit(k)
                )
                result = await session.execute(stmt)
                return [row[0] for row in result.all()]
        except Exception as exc:
            logger.debug(f"[ExemplarLibrary] Weight search failed: {exc}")
            return []

    async def _evict_excess(self, session, group_id: str) -> None:
        """Remove lowest-weight exemplars when over capacity.

        Among equal weights the oldest records go first.  Failures are
        logged and swallowed so they never fail the insert that triggered
        the eviction.

        Args:
            session: The open session of the caller; it is committed here
                when rows are deleted.
            group_id: Group whose exemplar count is checked.
        """
        try:
            count_stmt = select(func.count(Exemplar.id)).where(
                Exemplar.group_id == group_id
            )
            result = await session.execute(count_stmt)
            total = result.scalar() or 0

            if total <= _MAX_EXEMPLARS_PER_GROUP:
                return

            excess = total - _MAX_EXEMPLARS_PER_GROUP
            # Find IDs of lowest-weight records.
            ids_stmt = (
                select(Exemplar.id)
                .where(Exemplar.group_id == group_id)
                .order_by(Exemplar.weight, Exemplar.created_at)
                .limit(excess)
            )
            result = await session.execute(ids_stmt)
            ids_to_delete = [row[0] for row in result.all()]

            if ids_to_delete:
                del_stmt = delete(Exemplar).where(Exemplar.id.in_(ids_to_delete))
                await session.execute(del_stmt)
                await session.commit()
                logger.debug(
                    f"[ExemplarLibrary] Evicted {len(ids_to_delete)} "
                    f"excess exemplars from group {group_id}"
                )
        except Exception as exc:
            logger.debug(f"[ExemplarLibrary] Eviction failed: {exc}")

    @staticmethod
    def _blend_score(similarity: float, weight: Optional[float]) -> float:
        """Blend cosine similarity (80 %) with the quality weight (20 %).

        Only a missing (``None``) weight defaults to ``1.0``; an explicit
        ``0.0`` -- the floor ``adjust_weight`` clamps to -- is kept, so a
        fully demoted exemplar is not scored like a fresh one.
        """
        effective_weight = 1.0 if weight is None else weight
        return similarity * 0.8 + effective_weight * 0.2

    @staticmethod
    def _cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
        """Compute cosine similarity between two vectors.

        Uses pure Python to avoid hard numpy dependency.

        Returns:
            The cosine of the angle between the vectors, or ``0.0`` when
            they are empty, differ in length, or either has zero norm.
        """
        if len(vec_a) != len(vec_b) or not vec_a:
            return 0.0

        dot = sum(a * b for a, b in zip(vec_a, vec_b))
        norm_a = sum(a * a for a in vec_a) ** 0.5
        norm_b = sum(b * b for b in vec_b) ** 0.5

        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0

        return dot / (norm_a * norm_b)
"""
LightRAG-based knowledge manager.

Replaces the legacy ``KnowledgeGraphManager`` by using the LightRAG library
for entity/relation extraction, vector-indexed graph storage, and hybrid
retrieval.  When ``knowledge_engine`` is set to ``"lightrag"`` in the plugin
config, this module is activated instead of the SQL-based implementation.

Design notes:
    - One ``LightRAG`` instance per group (data isolation via working_dir).
    - LLM and embedding calls are bridged to the existing framework adapters
      so that no additional API keys are required.
    - Query uses ``only_need_context=True`` to return raw context without
      an internal LLM QA step, reducing latency to pure retrieval time.
    - Graceful import guard: if ``lightrag`` is not installed the class
      raises a clear ``ImportError`` at construction time rather than at
      module import, so the rest of the plugin can still load under the
      ``"legacy"`` engine setting.
    - All public methods mirror the ``KnowledgeGraphManager`` interface to
      allow transparent config-based switching.
"""

import asyncio
import os
import time
from typing import Any, Dict, List, Optional

from astrbot.api import logger

from ...config import PluginConfig
from ...core.interfaces import MessageData, ServiceLifecycle
from ..embedding.base import IEmbeddingProvider

# Lazy import guard -- LightRAG is an optional dependency.
_LIGHTRAG_AVAILABLE = False
try:
    from lightrag import LightRAG, QueryParam
    from lightrag.utils import EmbeddingFunc

    _LIGHTRAG_AVAILABLE = True
except ImportError:
    LightRAG = None  # type: ignore[assignment,misc]
    QueryParam = None  # type: ignore[assignment,misc]
    EmbeddingFunc = None  # type: ignore[assignment,misc]


class LightRAGKnowledgeManager:
    """Knowledge manager backed by the LightRAG library.

    Public interface intentionally mirrors ``KnowledgeGraphManager`` so that
    the learning manager can swap implementations via configuration:

    * ``process_message_for_knowledge_graph(message, group_id)``
    * ``query_knowledge(query, group_id)``
    * ``answer_question_with_knowledge_graph(question, group_id)``
    * ``query_knowledge_graph(query, group_id, limit)``
    * ``get_knowledge_graph_statistics(group_id)``
    * ``start()`` / ``stop()``

    Usage::

        manager = LightRAGKnowledgeManager(config, llm_adapter, embedding)
        await manager.start()
        await manager.process_message_for_knowledge_graph(msg, "group1")
        context = await manager.query_knowledge("topic", "group1")
        await manager.stop()
    """

    def __init__(
        self,
        config: PluginConfig,
        llm_adapter,
        embedding_provider: Optional[IEmbeddingProvider] = None,
    ) -> None:
        """Create the manager; no LightRAG instance is built yet.

        Args:
            config: Plugin configuration; ``data_dir`` is the parent of the
                ``lightrag`` storage directory.
            llm_adapter: Adapter exposing ``generate_response``.
            embedding_provider: Optional provider exposing ``get_dim`` and
                ``get_embeddings``.

        Raises:
            ImportError: If the ``lightrag`` package is not installed.
        """
        if not _LIGHTRAG_AVAILABLE:
            raise ImportError(
                "lightrag-hku is required for the LightRAG knowledge engine. "
                "Install via: pip install lightrag-hku"
            )

        self._config = config
        self._llm = llm_adapter
        self._embedding = embedding_provider
        self._status = ServiceLifecycle.CREATED

        # Per-group LightRAG instances (lazy-initialised).
        self._instances: Dict[str, LightRAG] = {}

        # Per-group initialisation locks to prevent concurrent creation.
        self._init_locks: Dict[str, asyncio.Lock] = {}

        # Base directory for all LightRAG data.
        self._base_dir = os.path.join(config.data_dir, "lightrag")

        # Track processed message counts per group for statistics.
        self._processed_counts: Dict[str, int] = {}

    # Lifecycle

    async def start(self) -> bool:
        """Start the knowledge manager service."""
        self._status = ServiceLifecycle.RUNNING
        logger.info("[LightRAG] Knowledge manager started")
        return True

    async def stop(self) -> bool:
        """Stop the service and release all LightRAG storage handles.

        A failure while finalising one group is logged and does not stop
        the remaining groups from being finalised.
        """
        self._status = ServiceLifecycle.STOPPING

        # Snapshot to avoid RuntimeError from dict mutation during iteration.
        instances_snapshot = list(self._instances.items())
        self._instances.clear()

        for group_id, rag in instances_snapshot:
            try:
                await rag.finalize_storages()
                logger.debug(
                    f"[LightRAG] Finalized storages for group {group_id}"
                )
            except Exception as exc:
                logger.warning(
                    f"[LightRAG] Error finalizing group {group_id}: {exc}"
                )

        self._init_locks.clear()
        self._status = ServiceLifecycle.STOPPED
        logger.info("[LightRAG] Knowledge manager stopped")
        return True

    # Public API

    async def process_message_for_knowledge_graph(
        self, message: MessageData, group_id: str
    ) -> None:
        """Extract entities/relations from a message and insert into the graph.

        This is the primary entry point, matching the legacy
        ``KnowledgeGraphManager.process_message_for_knowledge_graph`` name
        for drop-in compatibility.  Messages shorter than 10 characters
        (after stripping) are ignored; insert failures are logged, not raised.
        """
        if not message.message or len(message.message.strip()) < 10:
            return

        text = f"[{message.sender_name}]: {message.message}"
        try:
            rag = await self._get_rag(group_id)
            await rag.ainsert(text)
            self._processed_counts[group_id] = (
                self._processed_counts.get(group_id, 0) + 1
            )
        except Exception as exc:
            logger.warning(
                f"[LightRAG] Insert failed for group {group_id}: {exc}"
            )

    async def process_message_for_knowledge(
        self, message: MessageData, group_id: str
    ) -> None:
        """Short alias for ``process_message_for_knowledge_graph``."""
        await self.process_message_for_knowledge_graph(message, group_id)

    async def query_knowledge(
        self,
        query: str,
        group_id: str,
        mode: str = "hybrid",
        top_k: int = 10,
    ) -> str:
        """Retrieve knowledge context for a query without LLM QA.

        Args:
            query: The user query or topic.
            group_id: Chat group to search within.
            mode: LightRAG query mode (``naive``, ``local``, ``global``,
                ``hybrid``, ``mix``).
            top_k: Number of top items to retrieve.

        Returns:
            Retrieved context string.  Empty string if nothing relevant
            was found or the query failed.
        """
        try:
            rag = await self._get_rag(group_id)
            result = await rag.aquery(
                query,
                param=QueryParam(
                    mode=mode,
                    only_need_context=True,
                    top_k=top_k,
                ),
            )
            if isinstance(result, dict):
                # When only_need_context=True, LightRAG may return a dict
                # with context sections.  Flatten to a single string.
                parts = []
                for key in ("entities", "relationships", "chunks"):
                    if key in result and result[key]:
                        parts.append(str(result[key]))
                return "\n\n".join(parts) if parts else ""
            return str(result) if result else ""
        except Exception as exc:
            logger.warning(
                f"[LightRAG] Query failed for group {group_id}: {exc}"
            )
            return ""

    async def answer_question_with_knowledge_graph(
        self,
        question: str,
        group_id: str,
    ) -> str:
        """Return retrieved context for the given question.

        Behavioural difference from the legacy ``KnowledgeGraphManager``:
        this method returns an empty string when no relevant context exists,
        rather than a fallback natural-language reply like "I don't know".
        The raw context is intended for inclusion in the main generation
        prompt, saving an LLM round-trip.  Callers must handle the
        empty-string case.
        """
        return await self.query_knowledge(question, group_id)

    async def query_knowledge_graph(
        self,
        query: str,
        group_id: str,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Legacy-compatible structured query.

        Returns:
            An empty list when nothing was retrieved, otherwise a
            single-element list whose dict has ``text``, ``source`` and
            ``relevance`` keys.
        """
        context = await self.query_knowledge(query, group_id, top_k=limit)
        if not context:
            return []
        # Wrap the flat text into the expected list-of-dicts format.
        return [{"text": context, "source": "lightrag", "relevance": 1.0}]

    async def get_knowledge_graph_statistics(
        self, group_id: str
    ) -> Dict[str, Any]:
        """Return summary statistics for a group's knowledge graph.

        Entity/relation counts are read from the group's GraphML file and
        stay ``0`` when the group has no live instance, the file does not
        exist, ``networkx`` is missing, or the file cannot be parsed.
        """
        stats: Dict[str, Any] = {
            "engine": "lightrag",
            "entity_count": 0,
            "relation_count": 0,
            "processed_messages": self._processed_counts.get(group_id, 0),
        }

        if group_id not in self._instances:
            return stats

        # Read basic metrics from the working directory if available.
        graph_file = os.path.join(
            self._working_dir(group_id),
            "graph_chunk_entity_relation.graphml",
        )
        if not os.path.isfile(graph_file):
            return stats

        try:
            import networkx as nx
        except ImportError:
            logger.warning(
                "[LightRAG] networkx is not installed; "
                "entity/relation counts unavailable"
            )
            return stats

        try:
            # Parsing GraphML is synchronous file I/O plus XML parsing; run
            # it in the default executor so the event loop is not blocked.
            loop = asyncio.get_running_loop()
            graph = await loop.run_in_executor(
                None, nx.read_graphml, graph_file
            )
            stats["entity_count"] = graph.number_of_nodes()
            stats["relation_count"] = graph.number_of_edges()
        except Exception as exc:
            logger.warning(f"[LightRAG] Could not read graph stats: {exc}")

        return stats

    # Internal helpers

    def _working_dir(self, group_id: str) -> str:
        """Return the storage directory for *group_id* under the base dir.

        Path separators and NUL bytes are replaced, and the special names
        ``""``, ``"."`` and ``".."`` are wrapped, so an identifier can never
        resolve outside ``self._base_dir``.  Identifiers without those
        characters map to the same directory name as before.
        """
        safe_name = (
            str(group_id)
            .replace("/", "_")
            .replace("\\", "_")
            .replace("\x00", "_")
        )
        if safe_name in ("", ".", ".."):
            safe_name = f"_{safe_name}_"
        return os.path.join(self._base_dir, safe_name)

    async def _get_rag(self, group_id: str) -> LightRAG:
        """Return the LightRAG instance for *group_id*, creating if needed.

        Uses a per-group asyncio lock so that concurrent first calls for
        the same group build only one instance.
        """
        if group_id in self._instances:
            return self._instances[group_id]

        # No await happens between this lookup and the lock acquisition
        # below, so creating the lock cannot interleave with another task.
        lock = self._init_locks.setdefault(group_id, asyncio.Lock())

        async with lock:
            # Re-check after acquiring the lock.
            if group_id in self._instances:
                return self._instances[group_id]

            working_dir = self._working_dir(group_id)
            os.makedirs(working_dir, exist_ok=True)

            rag_kwargs: Dict[str, Any] = {
                "working_dir": working_dir,
                "llm_model_func": self._make_llm_func(),
                "chunk_token_size": 1200,
                "chunk_overlap_token_size": 100,
                "entity_extract_max_gleaning": 1,
            }

            # Attach embedding function if a provider is available.
            if self._embedding:
                rag_kwargs["embedding_func"] = EmbeddingFunc(
                    embedding_dim=self._embedding.get_dim(),
                    max_token_size=8192,
                    func=self._make_embedding_func(),
                )

            rag = LightRAG(**rag_kwargs)
            await rag.initialize_storages()
            # Only call the pipeline-status initialiser when this LightRAG
            # build exposes it.
            if hasattr(rag, "initialize_pipeline_status"):
                await rag.initialize_pipeline_status()

            self._instances[group_id] = rag
            logger.info(
                f"[LightRAG] Initialised instance for group {group_id}"
            )
            return rag

    def _make_llm_func(self):
        """Build an async callable matching LightRAG's LLM function signature.

        LightRAG expects::

            async def func(
                prompt: str,
                system_prompt: str | None = None,
                history_messages: list = [],
                keyword_extraction: bool = False,
                **kwargs,
            ) -> str

        Note: ``history_messages`` is accepted but not forwarded because
        the current ``FrameworkLLMAdapter`` does not support multi-turn
        context.  A debug log is emitted when history is discarded.
        ``keyword_extraction`` and extra keyword arguments are accepted
        and ignored.
        """
        llm = self._llm

        async def _llm_bridge(
            prompt: str,
            system_prompt: Optional[str] = None,
            history_messages: Optional[list] = None,
            keyword_extraction: bool = False,
            **kwargs,
        ) -> str:
            if history_messages is None:
                history_messages = []

            # The adapter takes a single prompt string, so the system prompt
            # is prepended to it.
            full_prompt = prompt
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{prompt}"

            if history_messages:
                logger.debug(
                    "[LightRAG] LLM bridge received %d history messages; "
                    "the current adapter does not forward conversation "
                    "history.",
                    len(history_messages),
                )

            result = await llm.generate_response(
                full_prompt,
                model_type="filter",
            )
            return result or ""

        return _llm_bridge

    def _make_embedding_func(self):
        """Build an async callable matching LightRAG's embedding function.

        LightRAG expects::

            async def func(texts: list[str]) -> list[list[float]]

        NOTE(review): the provider's result is passed through unchanged --
        confirm the installed LightRAG version accepts plain nested lists.
        """
        embedding = self._embedding

        async def _embedding_bridge(texts: list) -> list:
            return await embedding.get_embeddings(texts)

        return _embedding_bridge
b/services/integration/maibot_enhanced_learning_manager.py @@ -9,15 +9,15 @@ from astrbot.api import logger -from ..core.interfaces import MessageData, ServiceLifecycle -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..config import PluginConfig -from ..exceptions import SelfLearningError -from .database_manager import DatabaseManager -from .expression_pattern_learner import ExpressionPatternLearner -from .memory_graph_manager import MemoryGraphManager +from ...core.interfaces import MessageData, ServiceLifecycle +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...exceptions import SelfLearningError +from ..database import DatabaseManager +from ..analysis import ExpressionPatternLearner +from ..state.enhanced_memory_graph_manager import MemoryGraphManager from .knowledge_graph_manager import KnowledgeGraphManager -from .time_decay_manager import TimeDecayManager +from ..state import TimeDecayManager class MaiBotEnhancedLearningManager: @@ -36,8 +36,8 @@ def __new__(cls, *args, **kwargs): return cls._instance def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = None, context=None): - # 防止重复初始化 - if self._initialized: + # Allow re-init when first created without config (e.g. 
via get_instance()) + if self._initialized and self.config is not None: return self.config = config @@ -73,7 +73,21 @@ def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = No self.MIN_MESSAGES_FOR_LEARNING = 25 # 触发学习的最小消息数 self.LEARNING_COOLDOWN = 300 # 学习冷却时间(秒) self.BATCH_LEARNING_SIZE = 50 # 批量学习大小 - + + # V2 integration (conditional on engine config) + self.v2_integration = None + if config and (config.knowledge_engine != "legacy" or config.memory_engine != "legacy"): + try: + from ..core_learning import V2LearningIntegration + self.v2_integration = V2LearningIntegration( + config=config, + llm_adapter=self.llm_adapter, + db_manager=db_manager, + context=context, + ) + except Exception as exc: + logger.warning(f"V2LearningIntegration init failed, using legacy only: {exc}") + self._initialized = True @classmethod @@ -100,7 +114,11 @@ async def start(self) -> bool: if self.time_decay_manager: await self.time_decay_manager.start() - + + # V2 integration + if self.v2_integration: + await self.v2_integration.start() + # 启动定期维护任务 asyncio.create_task(self._periodic_maintenance()) @@ -128,7 +146,11 @@ async def stop(self) -> bool: if self.time_decay_manager: await self.time_decay_manager.stop() - + + # V2 integration + if self.v2_integration: + await self.v2_integration.stop() + logger.info("MaiBotEnhancedLearningManager及所有子服务已停止") return True @@ -211,59 +233,72 @@ async def process_message(self, message: MessageData, group_id: str) -> Dict[str results = { 'expression_learning': False, 'memory_update': False, - 'knowledge_update': False + 'knowledge_update': False, + 'v2_learning': False } - + # 添加到消息缓冲区 if group_id not in self.message_buffers: self.message_buffers[group_id] = [] - + self.message_buffers[group_id].append(message) - + # 限制缓冲区大小 if len(self.message_buffers[group_id]) > self.BATCH_LEARNING_SIZE: self.message_buffers[group_id] = self.message_buffers[group_id][-self.BATCH_LEARNING_SIZE:] - + state = 
self._get_group_learning_state(group_id) state['message_count_since_last_learning'] += 1 state['total_messages_processed'] += 1 - - # 异步处理各个学习任务 - tasks = [] - - # 1. 表达模式学习(批量触发) + + # 构建异步任务列表 (result_key, coroutine) + named_tasks = [] + + # V2 handles memory, knowledge, jargon, social, exemplar + if self.v2_integration: + named_tasks.append(('v2_learning', self._trigger_v2_processing(message, group_id))) + + # Expression learning always via legacy (no v2 replacement) if self.expression_learner and self._should_trigger_expression_learning(group_id, self.message_buffers[group_id]): - tasks.append(self._trigger_expression_learning(group_id)) - - # 2. 记忆图更新(实时) - if self.memory_graph_manager and self._should_trigger_memory_update(group_id): - tasks.append(self._trigger_memory_update(message, group_id)) - - # 3. 知识图谱更新(准实时) - if self.knowledge_graph_manager and self._should_trigger_knowledge_update(group_id): - tasks.append(self._trigger_knowledge_update(message, group_id)) - + named_tasks.append(('expression_learning', self._trigger_expression_learning(group_id))) + + # Legacy memory only when v2 doesn't handle it + if not (self.v2_integration and self.config.memory_engine != "legacy"): + if self.memory_graph_manager and self._should_trigger_memory_update(group_id): + named_tasks.append(('memory_update', self._trigger_memory_update(message, group_id))) + + # Legacy knowledge only when v2 doesn't handle it + if not (self.v2_integration and self.config.knowledge_engine != "legacy"): + if self.knowledge_graph_manager and self._should_trigger_knowledge_update(group_id): + named_tasks.append(('knowledge_update', self._trigger_knowledge_update(message, group_id))) + # 并发执行所有任务 - if tasks: - task_results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, result in enumerate(task_results): + if named_tasks: + keys = [k for k, _ in named_tasks] + coros = [c for _, c in named_tasks] + task_results = await asyncio.gather(*coros, return_exceptions=True) + + for 
key, result in zip(keys, task_results): if isinstance(result, Exception): - logger.error(f"学习任务 {i} 执行失败: {result}") + logger.error(f"学习任务 '{key}' 执行失败: {result}") elif isinstance(result, bool): - if i == 0: # 表达学习 - results['expression_learning'] = result - elif i == 1: # 记忆更新 - results['memory_update'] = result - elif i == 2: # 知识更新 - results['knowledge_update'] = result - + results[key] = result + return results except Exception as e: logger.error(f"处理消息失败: {e}") return {} + async def _trigger_v2_processing(self, message: MessageData, group_id: str) -> bool: + """Trigger V2 tiered learning pipeline.""" + try: + await self.v2_integration.process_message(message, group_id) + return True + except Exception as exc: + logger.error(f"V2 processing failed: {exc}") + return False + async def _trigger_expression_learning(self, group_id: str) -> bool: """触发表达模式学习""" try: @@ -433,23 +468,41 @@ async def get_enhanced_context_for_response(self, query: str, group_id: str) -> 'related_memories': [], 'knowledge_graph_context': '' } - - # 1. 获取表达模式 + + # 1. Expression patterns — always legacy if self.expression_learner: patterns_text = await self.expression_learner.format_expression_patterns_for_prompt(group_id) context['expression_patterns'] = patterns_text - - # 2. 获取相关记忆 - if self.memory_graph_manager: - memories = await self.memory_graph_manager.get_related_memories(query, group_id) - context['related_memories'] = memories - - # 3. 获取知识图谱上下文 - if self.knowledge_graph_manager: - kg_answer = await self.knowledge_graph_manager.answer_question_with_knowledge_graph(query, group_id) - if kg_answer != "我不知道": - context['knowledge_graph_context'] = kg_answer - + + # 2. 
V2 context (knowledge, memory, few-shot, social graph) + v2_context_ok = False + if self.v2_integration: + try: + v2_ctx = await self.v2_integration.get_enhanced_context(query, group_id) + v2_context_ok = True + if 'knowledge_context' in v2_ctx: + context['knowledge_graph_context'] = v2_ctx['knowledge_context'] + if 'related_memories' in v2_ctx: + context['related_memories'] = v2_ctx['related_memories'] + if 'few_shot_examples' in v2_ctx: + context['few_shot_examples'] = v2_ctx['few_shot_examples'] + if 'graph_stats' in v2_ctx: + context['graph_stats'] = v2_ctx['graph_stats'] + except Exception as exc: + logger.warning(f"V2 context retrieval failed, falling through to legacy: {exc}") + + # 3. Legacy fallbacks (when v2 not active, not handling this engine, or v2 failed) + if not (self.v2_integration and v2_context_ok and self.config.memory_engine != "legacy"): + if self.memory_graph_manager: + memories = await self.memory_graph_manager.get_related_memories(query, group_id) + context['related_memories'] = memories + + if not (self.v2_integration and v2_context_ok and self.config.knowledge_engine != "legacy"): + if self.knowledge_graph_manager: + kg_answer = await self.knowledge_graph_manager.answer_question_with_knowledge_graph(query, group_id) + if kg_answer and kg_answer != "我不知道": + context['knowledge_graph_context'] = kg_answer + return context except Exception as e: diff --git a/services/maibot_integration_factory.py b/services/integration/maibot_integration_factory.py similarity index 88% rename from services/maibot_integration_factory.py rename to services/integration/maibot_integration_factory.py index 9260b9c..3b5d623 100644 --- a/services/maibot_integration_factory.py +++ b/services/integration/maibot_integration_factory.py @@ -5,13 +5,13 @@ from typing import Optional, Dict, Any, List from astrbot.api import logger -from ..core.interfaces import MessageData -from ..config import PluginConfig -from .database_manager import DatabaseManager +from 
...core.interfaces import MessageData +from ...config import PluginConfig +from ..database import DatabaseManager from .maibot_enhanced_learning_manager import MaiBotEnhancedLearningManager -from .expression_pattern_learner import ExpressionPatternLearner +from ..analysis import ExpressionPatternLearner from .knowledge_graph_manager import KnowledgeGraphManager -from .time_decay_manager import TimeDecayManager +from ..state import TimeDecayManager class MaiBotIntegrationFactory: @@ -30,7 +30,7 @@ def __new__(cls, *args, **kwargs): return cls._instance def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = None, context=None, llm_adapter=None): - if self._initialized: + if self._initialized and self.config is not None: return self.config = config @@ -41,7 +41,7 @@ def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = No # 初始化子管理器(如果还没有初始化) if config and db_manager: - self.enhanced_manager.__init__(config, db_manager) + self.enhanced_manager.__init__(config, db_manager, context) # 确保子管理器也被正确初始化,传递所有必要参数 ExpressionPatternLearner.get_instance( @@ -51,24 +51,12 @@ def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = No llm_adapter=llm_adapter ) - # 使用管理器工厂创建记忆管理器(根据配置选择实现) - use_enhanced = getattr(config, 'use_enhanced_managers', False) - if use_enhanced: - logger.info("📦 [MaiBot工厂] 使用增强型记忆管理器") - from .manager_factory import get_manager_factory - manager_factory = get_manager_factory(config) - self.memory_manager = manager_factory.create_memory_manager( - db_manager, - llm_adapter, - self.enhanced_manager.time_decay_manager - ) - else: - logger.info("📦 [MaiBot工厂] 使用原始记忆管理器") - from .memory_graph_manager import MemoryGraphManager - self.memory_manager = MemoryGraphManager.get_instance() - self.memory_manager.__init__(config, db_manager, - self.enhanced_manager.llm_adapter, - self.enhanced_manager.time_decay_manager) + # 创建记忆管理器 + from ..state.enhanced_memory_graph_manager import 
EnhancedMemoryGraphManager + self.memory_manager = EnhancedMemoryGraphManager.get_instance( + config, db_manager, llm_adapter, + self.enhanced_manager.time_decay_manager + ) KnowledgeGraphManager.get_instance().__init__(config, db_manager, self.enhanced_manager.llm_adapter) @@ -220,7 +208,7 @@ async def get_related_memories(self, query: str, group_id: str, limit: int = 5) return await self.memory_manager.get_related_memories(query, group_id, limit) else: # 降级方案 - from .memory_graph_manager import MemoryGraphManager + from ..state.enhanced_memory_graph_manager import MemoryGraphManager memory_manager = MemoryGraphManager.get_instance() return await memory_manager.get_related_memories(query, group_id, limit) except Exception as e: @@ -280,7 +268,7 @@ async def get_all_statistics(self, group_id: str) -> Dict[str, Any]: stats['memory_graph'] = await self.memory_manager.get_memory_graph_statistics(group_id) else: # 降级方案 - from .memory_graph_manager import MemoryGraphManager + from ..state.enhanced_memory_graph_manager import MemoryGraphManager memory_manager = MemoryGraphManager.get_instance() stats['memory_graph'] = await memory_manager.get_memory_graph_statistics(group_id) diff --git a/services/integration/mem0_memory_manager.py b/services/integration/mem0_memory_manager.py new file mode 100644 index 0000000..01d8959 --- /dev/null +++ b/services/integration/mem0_memory_manager.py @@ -0,0 +1,349 @@ +""" +mem0-based memory manager. + +Replaces the legacy ``MemoryGraphManager`` by using the mem0 library for +automatic memory extraction, semantic vector search, and contradiction +detection. When ``memory_engine`` is set to ``"mem0"`` in the plugin +config, this module is activated instead of the NetworkX-based +implementation. + +Design notes: + - Uses mem0's built-in LLM fact extraction to distil memories from + chat messages, replacing manual ``jieba`` concept extraction. + - Semantic vector retrieval via Qdrant (local embedded mode, no + external server required). 
+ - Group isolation achieved by using ``agent_id=group_id`` as the + mem0 scoping parameter. + - LLM and embedding credentials are extracted from the AstrBot + framework providers at initialisation time so users only configure + providers once. + - Blocking mem0 calls are offloaded to a thread pool via + ``asyncio.to_thread`` to keep the event loop responsive. + - Graceful import guard: if ``mem0ai`` is not installed the class + raises a clear ``ImportError`` at construction time. +""" + +import asyncio +import os +from typing import Any, Dict, List, Optional + +from astrbot.api import logger + +from ...config import PluginConfig +from ...core.interfaces import MessageData, ServiceLifecycle + +# Lazy import guard -- mem0ai is an optional dependency. +_MEM0_AVAILABLE = False +try: + from mem0 import Memory as Mem0Memory + + _MEM0_AVAILABLE = True +except ImportError: + Mem0Memory = None # type: ignore[assignment,misc] + + +class Mem0MemoryManager: + """Memory manager backed by the mem0 library. + + Public interface mirrors ``MemoryGraphManager`` for transparent + config-based switching: + + * ``add_memory_from_message(message, group_id)`` + * ``get_related_memories(query, group_id, limit)`` + * ``get_memory_graph_statistics(group_id)`` + * ``save_memory_graph(group_id)`` -- no-op (mem0 auto-persists) + * ``load_memory_graph(group_id)`` -- no-op (mem0 auto-loads) + * ``start()`` / ``stop()`` + + Usage:: + + manager = Mem0MemoryManager(config, llm_adapter, embedding_provider) + await manager.start() + await manager.add_memory_from_message(msg, "group1") + memories = await manager.get_related_memories("topic", "group1") + await manager.stop() + """ + + def __init__( + self, + config: PluginConfig, + llm_adapter, + embedding_provider=None, + ) -> None: + if not _MEM0_AVAILABLE: + raise ImportError( + "mem0ai is required for the mem0 memory engine. 
" + "Install via: pip install mem0ai" + ) + + self._config = config + self._llm_adapter = llm_adapter + self._embedding_provider = embedding_provider + self._status = ServiceLifecycle.CREATED + self._memory: Optional[Mem0Memory] = None + + # Provide a dict-like attribute so callers iterating over + # memory_graphs (as with the legacy manager) get an empty dict + # instead of an AttributeError. + self.memory_graphs: Dict[str, Any] = {} + + # Lifecycle + + async def start(self) -> bool: + """Initialise the mem0 Memory instance.""" + try: + mem0_config = self._build_config() + self._memory = await asyncio.to_thread( + Mem0Memory.from_config, mem0_config + ) + self._status = ServiceLifecycle.RUNNING + logger.info("[Mem0] Memory manager started") + return True + except Exception as exc: + logger.error(f"[Mem0] Failed to start: {exc}") + self._status = ServiceLifecycle.ERROR + return False + + async def stop(self) -> bool: + """Release the mem0 instance.""" + self._status = ServiceLifecycle.STOPPING + self._memory = None + self._status = ServiceLifecycle.STOPPED + logger.info("[Mem0] Memory manager stopped") + return True + + # Public API + + async def add_memory_from_message( + self, message: MessageData, group_id: str + ) -> None: + """Extract and store memories from an incoming message. + + mem0 automatically distils facts from the text via its LLM + pipeline, handling deduplication and contradiction resolution. + """ + if not self._memory: + return + + text = self._extract_text(message) + if not text: + return + + try: + await asyncio.to_thread( + self._memory.add, + text, + user_id=message.sender_id, + agent_id=group_id, + metadata={"sender_name": message.sender_name}, + ) + except Exception as exc: + logger.debug(f"[Mem0] add_memory failed: {exc}") + + async def get_related_memories( + self, + query: str, + group_id: str, + limit: int = 5, + ) -> List[str]: + """Retrieve semantically related memories for a group. 
+ + Returns: + List of memory text strings, most relevant first. + """ + if not self._memory: + return [] + + try: + results = await asyncio.to_thread( + self._memory.search, + query, + agent_id=group_id, + limit=limit, + ) + # mem0 v1.1 format: {"results": [{"memory": str, ...}, ...]} + entries = results.get("results", []) if isinstance(results, dict) else results + return [ + entry["memory"] + for entry in entries + if isinstance(entry, dict) and entry.get("memory") + ] + except Exception as exc: + logger.debug(f"[Mem0] search failed: {exc}") + return [] + + async def get_memory_graph_statistics( + self, group_id: str + ) -> Dict[str, Any]: + """Return summary statistics for a group's memory store.""" + stats: Dict[str, Any] = { + "engine": "mem0", + "total_memories": 0, + } + + if not self._memory: + return stats + + try: + all_memories = await asyncio.to_thread( + self._memory.get_all, + agent_id=group_id, + ) + entries = ( + all_memories.get("results", []) + if isinstance(all_memories, dict) + else all_memories + ) + stats["total_memories"] = len(entries) if entries else 0 + except Exception as exc: + logger.debug(f"[Mem0] get_all failed: {exc}") + + return stats + + async def save_memory_graph(self, group_id: str) -> None: + """No-op: mem0 auto-persists to Qdrant.""" + + async def load_memory_graph(self, group_id: str) -> None: + """No-op: mem0 auto-loads from Qdrant.""" + + def get_memory_graph(self, group_id: str) -> None: + """Compatibility stub. Returns ``None`` since mem0 does not + expose an in-memory graph object.""" + return None + + # Internal helpers + + @staticmethod + def _extract_text(message: MessageData) -> str: + """Build a text representation from a MessageData instance.""" + text = getattr(message, "message", "") or "" + text = text.strip() + if len(text) < 5: + return "" + sender = getattr(message, "sender_name", "Unknown") + return f"[{sender}]: {text}" + + def _build_config(self) -> dict: + """Build the mem0 configuration dict. 
+ + Attempts to extract LLM and embedding API credentials from the + AstrBot framework providers. Falls back to env variables if + extraction fails (mem0 reads ``OPENAI_API_KEY`` by default). + """ + config: Dict[str, Any] = {"version": "v1.1"} + + # -- LLM config -- + llm_cfg = self._extract_llm_credentials() + if llm_cfg: + config["llm"] = llm_cfg + + # -- Embedding config -- + emb_cfg = self._extract_embedding_credentials() + if emb_cfg: + config["embedder"] = emb_cfg + + # -- Vector store (local Qdrant, no external server) -- + qdrant_path = os.path.join(self._config.data_dir, "mem0_qdrant") + os.makedirs(qdrant_path, exist_ok=True) + + embedding_dims = 1536 # default for text-embedding-3-small + if self._embedding_provider: + try: + embedding_dims = self._embedding_provider.get_dim() + except Exception: + pass + + config["vector_store"] = { + "provider": "qdrant", + "config": { + "collection_name": "self_learning_memories", + "path": qdrant_path, + "on_disk": True, + "embedding_model_dims": embedding_dims, + }, + } + + return config + + def _extract_llm_credentials(self) -> Optional[Dict[str, Any]]: + """Try to extract LLM API credentials from the framework adapter.""" + try: + provider = ( + self._llm_adapter.filter_provider + or self._llm_adapter.refine_provider + or self._llm_adapter.reinforce_provider + ) + if not provider: + return None + + pc = getattr(provider, "provider_config", {}) + api_key = None + if hasattr(provider, "get_current_key"): + api_key = provider.get_current_key() + if not api_key: + keys = pc.get("key", []) + api_key = keys[0] if keys else None + + base_url = pc.get("api_base") or None + model = provider.get_model() if hasattr(provider, "get_model") else None + + if not api_key: + return None + + llm_config: Dict[str, Any] = { + "model": model or "gpt-4o-mini", + "temperature": 0.1, + "max_tokens": 1500, + "api_key": api_key, + } + if base_url: + llm_config["openai_base_url"] = base_url + + return {"provider": "openai", "config": 
llm_config} + + except Exception as exc: + logger.debug( + f"[Mem0] Could not extract LLM credentials, " + f"using mem0 defaults: {exc}" + ) + return None + + def _extract_embedding_credentials(self) -> Optional[Dict[str, Any]]: + """Try to extract embedding API credentials from the framework.""" + try: + emb = self._embedding_provider + if not emb: + return None + + # Unwrap the FrameworkEmbeddingAdapter to reach the underlying + # AstrBot EmbeddingProvider which holds provider_config. + underlying = getattr(emb, "_provider", None) + if not underlying: + return None + + pc = getattr(underlying, "provider_config", {}) + api_key = pc.get("embedding_api_key") or None + base_url = pc.get("embedding_api_base") or None + model = underlying.get_model() if hasattr(underlying, "get_model") else None + + if not api_key: + return None + + emb_config: Dict[str, Any] = { + "model": model or "text-embedding-3-small", + "api_key": api_key, + } + if base_url: + emb_config["openai_base_url"] = base_url + + dim = emb.get_dim() if hasattr(emb, "get_dim") else 1536 + emb_config["embedding_dims"] = dim + + return {"provider": "openai", "config": emb_config} + + except Exception as exc: + logger.debug( + f"[Mem0] Could not extract embedding credentials, " + f"using mem0 defaults: {exc}" + ) + return None diff --git a/services/integration/training_data_exporter.py b/services/integration/training_data_exporter.py new file mode 100644 index 0000000..41d07f6 --- /dev/null +++ b/services/integration/training_data_exporter.py @@ -0,0 +1,662 @@ +""" +训练数据导出服务 +将对话数据导出为标准的大模型微调格式 (JSONL) + +设计原则: +1. 数据聚合: 关联用户消息和Bot回复,构建完整对话对 +2. 格式标准化: 转换为OpenAI/Claude微调训练格式 +3. 质量筛选: 可选的质量过滤机制 +4. 
批量导出: 支持按时间范围、群组、质量阈值等条件导出 +""" +import json +import time +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime, timedelta +from pathlib import Path + +from astrbot.api import logger +from sqlalchemy import select, and_, or_, func +from sqlalchemy.ext.asyncio import AsyncSession + +from ...core.patterns import AsyncServiceBase +from ...models.orm.message import RawMessage, BotMessage, FilteredMessage +from ...repositories.base_repository import BaseRepository + + +class ConversationPair: + """对话对数据结构""" + + def __init__( + self, + user_message: str, + bot_response: str, + user_id: str, + group_id: str, + user_timestamp: int, + bot_timestamp: int, + quality_score: Optional[float] = None, + metadata: Optional[Dict] = None + ): + self.user_message = user_message + self.bot_response = bot_response + self.user_id = user_id + self.group_id = group_id + self.user_timestamp = user_timestamp + self.bot_timestamp = bot_timestamp + self.quality_score = quality_score + self.metadata = metadata or {} + + def to_training_format( + self, + system_prompt: Optional[str] = None, + include_metadata: bool = False + ) -> Dict[str, Any]: + """ + 转换为训练格式 + + Args: + system_prompt: 系统提示词 (可选) + include_metadata: 是否包含元数据 + + Returns: + 标准训练格式的字典 + """ + messages = [] + + # 添加system角色 (如果提供) + if system_prompt: + messages.append({ + "role": "system", + "content": system_prompt + }) + + # 添加用户消息 + messages.append({ + "role": "user", + "content": self.user_message + }) + + # 添加助手回复 + messages.append({ + "role": "assistant", + "content": self.bot_response + }) + + result = {"messages": messages} + + # 可选: 添加元数据 (用于分析,不用于训练) + if include_metadata: + result["metadata"] = { + "user_id": self.user_id, + "group_id": self.group_id, + "user_timestamp": self.user_timestamp, + "bot_timestamp": self.bot_timestamp, + "quality_score": self.quality_score, + **self.metadata + } + + return result + + +class TrainingDataExporter(AsyncServiceBase): + """ + 训练数据导出服务 + + 功能: + 1. 
从数据库中提取对话对 (用户消息 + Bot回复) + 2. 按时间顺序关联消息 + 3. 可选的质量筛选 + 4. 导出为JSONL格式 + 5. 支持从远程数据库导出 + """ + + def __init__(self, database_manager, is_remote: bool = False): + """ + 初始化训练数据导出器 + + Args: + database_manager: SQLAlchemyDatabaseManager实例 + is_remote: 是否为远程数据库连接 + """ + super().__init__("training_data_exporter") + self.db_manager = database_manager + self.is_remote = is_remote + + # 配置参数 + self.max_time_gap_seconds = 300 # 用户消息和Bot回复的最大时间差 (5分钟) + self.min_message_length = 2 # 最小消息长度 + self.max_message_length = 2000 # 最大消息长度 + + @classmethod + async def create_from_remote_db( + cls, + database_url: str, + echo: bool = False + ) -> 'TrainingDataExporter': + """ + 从远程数据库创建导出器 (工厂方法) + + Args: + database_url: 远程数据库连接URL + - MySQL: "mysql+aiomysql://user:pass@host:port/dbname" + - PostgreSQL: "postgresql+asyncpg://user:pass@host:port/dbname" + echo: 是否打印SQL语句 (调试用) + + Returns: + TrainingDataExporter实例 + + Examples: + # MySQL云端数据库 + exporter = await TrainingDataExporter.create_from_remote_db( + "mysql+aiomysql://user:password@云端IP:3306/database" + ) + await exporter.start() + + # PostgreSQL云端数据库 + exporter = await TrainingDataExporter.create_from_remote_db( + "postgresql+asyncpg://user:password@云端IP:5432/database" + ) + """ + from ...core.database.engine import DatabaseEngine + from ..database import SQLAlchemyDatabaseManager + + # 创建远程数据库引擎 + logger.info(f"连接远程数据库: {cls._mask_database_url(database_url)}") + engine = DatabaseEngine(database_url, echo=echo) + + # 创建数据库管理器 + # 注意: 这里使用临时配置,因为远程数据库不需要完整的PluginConfig + class RemoteDBConfig: + """远程数据库临时配置""" + def __init__(self, db_url): + self.database_url = db_url + self.enable_auto_migration = False # 远程数据库不自动迁移 + + config = RemoteDBConfig(database_url) + db_manager = SQLAlchemyDatabaseManager.__new__(SQLAlchemyDatabaseManager) + db_manager.config = config + db_manager.engine = engine + db_manager._logger = logger + + # 创建导出器 + exporter = cls(db_manager, is_remote=True) + logger.info(" 远程数据库连接成功") + + return exporter + + 
@staticmethod + def _mask_database_url(url: str) -> str: + """隐藏数据库URL中的密码""" + if '@' in url: + parts = url.split('@') + if ':' in parts[0]: + prefix = parts[0].rsplit(':', 1)[0] + return f"{prefix}:****@{parts[1]}" + return url + + async def _do_start(self) -> bool: + """启动服务""" + self._logger.info("训练数据导出服务启动成功") + return True + + async def _do_stop(self) -> bool: + """停止服务""" + return True + + async def extract_conversation_pairs( + self, + group_id: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + min_quality_score: Optional[float] = None, + limit: Optional[int] = None + ) -> List[ConversationPair]: + """ + 提取对话对 + + Args: + group_id: 群组ID (可选,不指定则提取所有群组) + start_time: 开始时间戳 (毫秒,可选) + end_time: 结束时间戳 (毫秒,可选) + min_quality_score: 最小质量分数 (可选,0-1) + limit: 最大返回数量 (可选) + + Returns: + 对话对列表 + """ + try: + async with self.db_manager.get_session() as session: + # 1. 查询用户消息 + user_messages = await self._fetch_user_messages( + session, group_id, start_time, end_time, min_quality_score + ) + + if not user_messages: + self._logger.info("未找到符合条件的用户消息") + return [] + + self._logger.info(f"查询到 {len(user_messages)} 条用户消息") + + # 2. 查询Bot回复 + bot_responses = await self._fetch_bot_responses( + session, group_id, start_time, end_time + ) + + if not bot_responses: + self._logger.info("未找到符合条件的Bot回复") + return [] + + self._logger.info(f"查询到 {len(bot_responses)} 条Bot回复") + + # 3. 关联消息对 + conversation_pairs = self._match_message_pairs( + user_messages, bot_responses + ) + + self._logger.info(f"成功匹配 {len(conversation_pairs)} 个对话对") + + # 4. 
应用限制 + if limit and len(conversation_pairs) > limit: + conversation_pairs = conversation_pairs[:limit] + + return conversation_pairs + + except Exception as e: + self._logger.error(f"提取对话对失败: {e}", exc_info=True) + return [] + + async def _fetch_user_messages( + self, + session: AsyncSession, + group_id: Optional[str], + start_time: Optional[int], + end_time: Optional[int], + min_quality_score: Optional[float] + ) -> List[Tuple]: + """ + 查询用户消息 + + Returns: + (message_id, sender_id, group_id, message, timestamp, quality_score) + """ + # 如果需要质量筛选,使用filtered_messages表 + if min_quality_score is not None: + stmt = select( + FilteredMessage.id, + FilteredMessage.sender_id, + FilteredMessage.group_id, + FilteredMessage.message, + FilteredMessage.timestamp, + FilteredMessage.confidence + ).where( + and_( + FilteredMessage.confidence >= min_quality_score, + func.length(FilteredMessage.message) >= self.min_message_length, + func.length(FilteredMessage.message) <= self.max_message_length + ) + ) + else: + # 否则使用raw_messages表 + stmt = select( + RawMessage.id, + RawMessage.sender_id, + RawMessage.group_id, + RawMessage.message, + RawMessage.timestamp + ).where( + and_( + func.length(RawMessage.message) >= self.min_message_length, + func.length(RawMessage.message) <= self.max_message_length + ) + ) + + # 添加过滤条件 + conditions = [] + + if group_id: + if min_quality_score is not None: + conditions.append(FilteredMessage.group_id == group_id) + else: + conditions.append(RawMessage.group_id == group_id) + + if start_time: + if min_quality_score is not None: + conditions.append(FilteredMessage.timestamp >= start_time) + else: + conditions.append(RawMessage.timestamp >= start_time) + + if end_time: + if min_quality_score is not None: + conditions.append(FilteredMessage.timestamp <= end_time) + else: + conditions.append(RawMessage.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + # 按时间排序 + if min_quality_score is not None: + stmt = 
stmt.order_by(FilteredMessage.timestamp) + else: + stmt = stmt.order_by(RawMessage.timestamp) + + result = await session.execute(stmt) + rows = result.fetchall() + + # 如果使用raw_messages,添加None作为quality_score + if min_quality_score is None: + rows = [(*row, None) for row in rows] + + return rows + + async def _fetch_bot_responses( + self, + session: AsyncSession, + group_id: Optional[str], + start_time: Optional[int], + end_time: Optional[int] + ) -> List[Tuple]: + """ + 查询Bot回复 + + Returns: + (message_id, group_id, message, timestamp) + """ + stmt = select( + BotMessage.id, + BotMessage.group_id, + BotMessage.message, + BotMessage.timestamp + ).where( + and_( + func.length(BotMessage.message) >= self.min_message_length, + func.length(BotMessage.message) <= self.max_message_length + ) + ) + + # 添加过滤条件 + conditions = [] + + if group_id: + conditions.append(BotMessage.group_id == group_id) + + if start_time: + conditions.append(BotMessage.timestamp >= start_time) + + if end_time: + conditions.append(BotMessage.timestamp <= end_time) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + # 按时间排序 + stmt = stmt.order_by(BotMessage.timestamp) + + result = await session.execute(stmt) + return result.fetchall() + + def _match_message_pairs( + self, + user_messages: List[Tuple], + bot_responses: List[Tuple] + ) -> List[ConversationPair]: + """ + 关联用户消息和Bot回复 + + 匹配策略: + 1. 相同群组 + 2. Bot回复时间在用户消息之后 + 3. 时间差在max_time_gap_seconds内 + 4. 
选择时间差最小的Bot回复 + + Args: + user_messages: (id, sender_id, group_id, message, timestamp, quality_score) + bot_responses: (id, group_id, message, timestamp) + + Returns: + 对话对列表 + """ + pairs = [] + used_bot_indices = set() + + # 将Bot回复按群组分组,提高匹配效率 + bot_by_group = {} + for idx, (bot_id, group_id, message, timestamp) in enumerate(bot_responses): + if group_id not in bot_by_group: + bot_by_group[group_id] = [] + bot_by_group[group_id].append((idx, bot_id, message, timestamp)) + + # 遍历用户消息,寻找匹配的Bot回复 + for user_id, sender_id, group_id, user_msg, user_ts, quality_score in user_messages: + if group_id not in bot_by_group: + continue + + # 查找该群组内,时间在用户消息之后的Bot回复 + best_match = None + min_time_gap = float('inf') + best_idx = None + + for idx, bot_id, bot_msg, bot_ts in bot_by_group[group_id]: + # 跳过已使用的Bot回复 + if idx in used_bot_indices: + continue + + # Bot回复必须在用户消息之后 + if bot_ts < user_ts: + continue + + # 计算时间差 (毫秒转秒) + time_gap = (bot_ts - user_ts) / 1000 + + # 时间差必须在允许范围内 + if time_gap > self.max_time_gap_seconds: + break # bot_responses已按时间排序,后续的都不符合 + + # 选择时间差最小的 + if time_gap < min_time_gap: + min_time_gap = time_gap + best_match = (bot_id, bot_msg, bot_ts) + best_idx = idx + + # 找到匹配 + if best_match: + bot_id, bot_msg, bot_ts = best_match + used_bot_indices.add(best_idx) + + pair = ConversationPair( + user_message=user_msg, + bot_response=bot_msg, + user_id=sender_id, + group_id=group_id, + user_timestamp=user_ts, + bot_timestamp=bot_ts, + quality_score=quality_score, + metadata={ + "time_gap_seconds": min_time_gap + } + ) + pairs.append(pair) + + return pairs + + async def export_to_jsonl( + self, + output_path: str, + group_id: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + min_quality_score: Optional[float] = None, + limit: Optional[int] = None, + system_prompt: Optional[str] = None, + include_metadata: bool = False + ) -> Dict[str, Any]: + """ + 导出训练数据为JSONL文件 + + Args: + output_path: 输出文件路径 + group_id: 群组ID (可选) 
+ start_time: 开始时间戳 (毫秒,可选) + end_time: 结束时间戳 (毫秒,可选) + min_quality_score: 最小质量分数 (可选,0-1) + limit: 最大导出数量 (可选) + system_prompt: 系统提示词 (可选) + include_metadata: 是否包含元数据 (可选) + + Returns: + 导出结果统计 + """ + try: + start_export_time = time.time() + + # 1. 提取对话对 + self._logger.info(f"开始提取对话对... (group={group_id}, limit={limit})") + pairs = await self.extract_conversation_pairs( + group_id=group_id, + start_time=start_time, + end_time=end_time, + min_quality_score=min_quality_score, + limit=limit + ) + + if not pairs: + return { + "success": False, + "message": "未找到符合条件的对话对", + "total_pairs": 0, + "output_path": None + } + + # 2. 创建输出目录 + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + + # 3. 写入JSONL文件 + with open(output_file, 'w', encoding='utf-8') as f: + for pair in pairs: + training_data = pair.to_training_format( + system_prompt=system_prompt, + include_metadata=include_metadata + ) + f.write(json.dumps(training_data, ensure_ascii=False) + '\n') + + export_duration = time.time() - start_export_time + + self._logger.info( + f" 导出完成: {len(pairs)} 个对话对, " + f"耗时 {export_duration:.2f}s, " + f"文件: {output_path}" + ) + + return { + "success": True, + "message": "导出成功", + "total_pairs": len(pairs), + "output_path": str(output_file.absolute()), + "duration_seconds": export_duration, + "filters": { + "group_id": group_id, + "start_time": start_time, + "end_time": end_time, + "min_quality_score": min_quality_score, + "limit": limit + } + } + + except Exception as e: + self._logger.error(f"导出训练数据失败: {e}", exc_info=True) + return { + "success": False, + "message": f"导出失败: {str(e)}", + "total_pairs": 0, + "output_path": None + } + + async def export_by_date_range( + self, + output_dir: str, + days_ago: int = 7, + **export_kwargs + ) -> Dict[str, Any]: + """ + 按日期范围导出 (便捷方法) + + Args: + output_dir: 输出目录 + days_ago: 最近N天 (默认7天) + **export_kwargs: 其他导出参数 + + Returns: + 导出结果 + """ + end_time = int(time.time() * 1000) # 当前时间 (毫秒) + start_time 
= end_time - (days_ago * 24 * 60 * 60 * 1000) # N天前 + + # 生成文件名 + timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") + output_filename = f"training_data_{days_ago}days_{timestamp_str}.jsonl" + output_path = str(Path(output_dir) / output_filename) + + return await self.export_to_jsonl( + output_path=output_path, + start_time=start_time, + end_time=end_time, + **export_kwargs + ) + + async def get_export_statistics( + self, + group_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取可导出数据的统计信息 + + Args: + group_id: 群组ID (可选) + + Returns: + 统计信息 + """ + try: + async with self.db_manager.get_session() as session: + # 统计用户消息数 + user_stmt = select(func.count(RawMessage.id)) + if group_id: + user_stmt = user_stmt.where(RawMessage.group_id == group_id) + + user_result = await session.execute(user_stmt) + total_user_messages = user_result.scalar() + + # 统计Bot回复数 + bot_stmt = select(func.count(BotMessage.id)) + if group_id: + bot_stmt = bot_stmt.where(BotMessage.group_id == group_id) + + bot_result = await session.execute(bot_stmt) + total_bot_messages = bot_result.scalar() + + # 统计高质量消息数 + filtered_stmt = select(func.count(FilteredMessage.id)) + if group_id: + filtered_stmt = filtered_stmt.where(FilteredMessage.group_id == group_id) + + filtered_result = await session.execute(filtered_stmt) + total_filtered_messages = filtered_result.scalar() + + return { + "total_user_messages": total_user_messages, + "total_bot_messages": total_bot_messages, + "total_filtered_messages": total_filtered_messages, + "estimated_max_pairs": min(total_user_messages, total_bot_messages), + "group_id": group_id + } + + except Exception as e: + self._logger.error(f"获取统计信息失败: {e}", exc_info=True) + return { + "total_user_messages": 0, + "total_bot_messages": 0, + "total_filtered_messages": 0, + "estimated_max_pairs": 0, + "error": str(e) + } \ No newline at end of file diff --git a/services/jargon/__init__.py b/services/jargon/__init__.py new file mode 100644 index 0000000..23bd214 --- 
/dev/null +++ b/services/jargon/__init__.py @@ -0,0 +1,12 @@ +"""Jargon detection, mining, and query services.""" + +from .jargon_miner import JargonMiner, JargonMinerManager +from .jargon_query import JargonQueryService +from .jargon_statistical_filter import JargonStatisticalFilter + +__all__ = [ + "JargonMiner", + "JargonMinerManager", + "JargonQueryService", + "JargonStatisticalFilter", +] diff --git a/services/jargon_miner.py b/services/jargon/jargon_miner.py similarity index 92% rename from services/jargon_miner.py rename to services/jargon/jargon_miner.py index e79a162..4d3ecce 100644 --- a/services/jargon_miner.py +++ b/services/jargon/jargon_miner.py @@ -11,10 +11,10 @@ from astrbot.api import logger -from ..models.jargon import Jargon -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..core.patterns import AsyncServiceBase -from ..utils.json_utils import safe_parse_llm_json +from ...models.jargon import Jargon +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...core.patterns import AsyncServiceBase +from ...utils.json_utils import safe_parse_llm_json class JargonInferenceEngine: @@ -444,15 +444,42 @@ async def infer_and_update(self, jargon: Jargon): except Exception as e: logger.error(f"推断黑话失败: {e}") - async def run_once(self, chat_messages: str, message_count: int): - """执行一次黑话学习""" + async def run_once( + self, + chat_messages: str, + message_count: int, + statistical_candidates: Optional[List[Dict[str, Any]]] = None, + ): + """Execute a single jargon learning iteration. + + Args: + chat_messages: Formatted chat text for LLM extraction. + message_count: Number of recent messages. + statistical_candidates: Pre-filtered candidates from + ``JargonStatisticalFilter``. When provided, LLM-based + candidate extraction is skipped, saving one LLM call. + """ try: if not self.should_trigger(message_count): return - # 1. 提取候选黑话 - candidates = await self.extract_candidates(chat_messages) + # 1. 
Get candidates — prefer statistical pre-filter over LLM. + if statistical_candidates: + candidates = [ + { + "content": c["term"], + "raw_content": c.get("context_examples", []), + } + for c in statistical_candidates + if c.get("term") + ] + logger.info( + f"[{self.chat_id}] Using {len(candidates)} statistical " + f"candidates (LLM extraction skipped)" + ) + else: + candidates = await self.extract_candidates(chat_messages) if not candidates: return diff --git a/services/jargon_query.py b/services/jargon/jargon_query.py similarity index 96% rename from services/jargon_query.py rename to services/jargon/jargon_query.py index 3f824a3..7bf33a5 100644 --- a/services/jargon_query.py +++ b/services/jargon/jargon_query.py @@ -21,7 +21,7 @@ def __init__(self, db_manager, cache_ttl: int = 60): """ self.db = db_manager - # ⚡ 使用 cachetools.TTLCache - 自动过期管理 + # 使用 cachetools.TTLCache - 自动过期管理 self._cache = TTLCache(maxsize=500, ttl=cache_ttl) logger.info(f"[黑话查询] 使用 TTLCache (maxsize=500, ttl={cache_ttl}s)") @@ -68,7 +68,7 @@ async def query_jargon( if include_global and len(results) < limit: global_results = await self.db.search_jargon( keyword=keyword, - chat_id=None, # 搜索全局黑话 + chat_id=None, # 搜索全局黑话 limit=limit - len(results) ) # 去重 @@ -111,7 +111,7 @@ async def get_jargon_context( Returns: 格式化的黑话列表文本 """ - # ⚡ 先检查缓存 + # 先检查缓存 cache_key = f"jargon_context_{chat_id}_{limit}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -136,7 +136,7 @@ async def get_jargon_context( result = "\n".join(lines) - # ⚡ 缓存结果 + # 缓存结果 self._set_to_cache(cache_key, result) return result @@ -160,7 +160,7 @@ async def check_and_explain_jargon( 如果找到黑话则返回解释文本,否则返回None """ try: - # ⚡ 先从缓存获取该群组的黑话列表 + # 先从缓存获取该群组的黑话列表 cache_key = f"jargon_list_{chat_id}" jargon_list = self._get_from_cache(cache_key) @@ -171,7 +171,7 @@ async def check_and_explain_jargon( limit=100, only_confirmed=True ) - # ⚡ 缓存黑话列表 + # 缓存黑话列表 self._set_to_cache(cache_key, jargon_list) if not jargon_list: diff --git 
# Minimum term length (characters) to consider as a candidate.
_MIN_TERM_LENGTH = 2

# Minimum frequency in a group before a term is considered.
_MIN_FREQUENCY = 3

# Maximum number of context examples to retain per term.
_MAX_CONTEXT_EXAMPLES = 10

# Score component weights.
_WEIGHT_IDF = 0.4
_WEIGHT_BURST = 0.3
_WEIGHT_CONCENTRATION = 0.3


class JargonStatisticalFilter:
    """In-memory statistical pre-filter for jargon candidate detection.

    Call ``update_from_message`` on every incoming message; it only updates
    dictionaries and never calls an LLM. Call ``get_jargon_candidates`` when
    batch analysis triggers to retrieve candidates ranked by a composite
    statistical score.

    All state lives in plain dictionaries on the instance, so it is lost on
    restart and rebuilt from the message stream.

    Usage::

        jfilter = JargonStatisticalFilter()

        # Per-message (no LLM call):
        jfilter.update_from_message(text, group_id, sender_id)

        # Batch trigger:
        candidates = jfilter.get_jargon_candidates(group_id, top_k=20)
    """

    # Common Chinese function words and particles. Built once at class
    # creation instead of on every ``_is_stopword`` call.
    # NOTE(review): every entry is a single character, while ``_tokenize``
    # only keeps tokens of at least ``_MIN_TERM_LENGTH`` (2) characters, so
    # with the current constants this list never rejects a token.
    _STOPWORDS = frozenset({
        "的", "了", "在", "是", "我", "有", "和", "就",
        "不", "人", "都", "一", "个", "上", "也", "很",
        "到", "说", "要", "去", "你", "会", "着", "没",
        "看", "好", "自", "这", "他", "她", "它", "们",
        "吗", "吧", "呢", "啊", "哦", "嗯", "呀", "哈",
        "那", "么", "什", "啦", "噢", "嘛", "哇",
        "来", "对", "把", "让", "被", "给", "从", "还",
        "比", "得", "过", "可", "能", "为", "以", "而",
        "但", "或", "如", "与", "等", "及", "其", "之",
    })

    def __init__(self) -> None:
        # group_id -> {term -> count}
        self._group_term_freq: Dict[str, Dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )

        # term -> total count across all groups
        self._global_term_freq: Dict[str, int] = defaultdict(int)

        # group_id -> {term -> {sender_id -> count}}
        self._user_term_freq: Dict[str, Dict[str, Dict[str, int]]] = defaultdict(
            lambda: defaultdict(lambda: defaultdict(int))
        )

        # group_id -> {term -> first_seen_timestamp}
        self._term_first_seen: Dict[str, Dict[str, float]] = defaultdict(dict)

        # group_id -> {term -> [context_examples]}
        self._term_contexts: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )

        # Groups that have been updated since the last candidate pull.
        self._dirty_groups: Set[str] = set()

        # jieba is configured lazily on first use.
        self._jieba_loaded = False
        # Set once the import has failed, so the warning is logged only once.
        self._jieba_unavailable = False

    # Public API

    def update_from_message(
        self,
        content: str,
        group_id: str,
        sender_id: str,
    ) -> None:
        """Update the term frequency tables from a single message.

        Does nothing when ``content`` or ``group_id`` is empty, or when
        tokenisation yields no usable tokens.

        Args:
            content: The raw message text.
            group_id: Chat group identifier.
            sender_id: Message sender identifier.
        """
        if not content or not group_id:
            return

        tokens = self._tokenize(content)
        if not tokens:
            return

        now = time.time()
        group_freq = self._group_term_freq[group_id]
        user_freq = self._user_term_freq[group_id]
        first_seen = self._term_first_seen[group_id]
        contexts = self._term_contexts[group_id]
        seen_in_message: Set[str] = set()

        for token in tokens:
            # Frequencies count every occurrence.
            group_freq[token] += 1
            self._global_term_freq[token] += 1
            user_freq[token][sender_id] += 1

            # First-seen time and context examples are per message, so a
            # token repeated inside one message must not store that message
            # as several "different" examples.
            if token in seen_in_message:
                continue
            seen_in_message.add(token)

            first_seen.setdefault(token, now)
            ctx_list = contexts[token]
            if len(ctx_list) < _MAX_CONTEXT_EXAMPLES:
                ctx_list.append(content)

        self._dirty_groups.add(group_id)

    def get_jargon_candidates(
        self,
        group_id: str,
        top_k: int = 20,
        exclude_terms: Optional[Set[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Retrieve the top-K jargon candidates ranked by composite score.

        The composite score is the weighted sum of three signals:

        1. **Cross-group IDF** (``_WEIGHT_IDF``): ``log(groups / groups
           containing the term)``; zero when every tracked group uses it.
        2. **Burst frequency** (``_WEIGHT_BURST``): group frequency divided
           by the term's age in days (age is floored at one day).
        3. **User concentration** (``_WEIGHT_CONCENTRATION``):
           ``1 / unique senders`` of the term in this group.

        Terms seen fewer than ``_MIN_FREQUENCY`` times in the group are
        ignored. Pulling candidates clears the group's "dirty" flag.

        Args:
            group_id: The group to analyse.
            top_k: Maximum candidates to return; values <= 0 yield ``[]``.
            exclude_terms: Terms to skip (e.g. jargon that is already known).

        Returns:
            Candidate dicts sorted by score descending, each with keys
            ``term``, ``score``, ``frequency``, ``idf``, ``burst_score``,
            ``unique_users`` and ``context_examples`` (at most 5 examples).
        """
        group_freq = self._group_term_freq.get(group_id)
        if not group_freq:
            return []

        # The group is being pulled now, so it is no longer "dirty".
        self._dirty_groups.discard(group_id)

        if top_k <= 0:
            return []

        exclude = exclude_terms or set()
        num_groups = max(len(self._group_term_freq), 1)
        group_users = self._user_term_freq.get(group_id, {})
        group_contexts = self._term_contexts.get(group_id, {})
        candidates: List[Dict[str, Any]] = []

        for term, freq in group_freq.items():
            if freq < _MIN_FREQUENCY or term in exclude:
                continue

            # Signal 1: cross-group IDF.
            groups_containing = sum(
                1 for gf in self._group_term_freq.values() if term in gf
            )
            idf = math.log(num_groups / max(groups_containing, 1))

            # Signal 2: burst frequency (frequency / age in days).
            burst_score = self._calc_burst_score(term, group_id)

            # Signal 3: user concentration (1 / unique senders).
            unique_users = len(group_users.get(term, {}))
            concentration = 1.0 / max(unique_users, 1)

            score = (
                idf * _WEIGHT_IDF
                + burst_score * _WEIGHT_BURST
                + concentration * _WEIGHT_CONCENTRATION
            )

            candidates.append({
                "term": term,
                "score": round(score, 4),
                "frequency": freq,
                "idf": round(idf, 4),
                "burst_score": round(burst_score, 4),
                "unique_users": unique_users,
                "context_examples": group_contexts.get(term, [])[:5],
            })

        candidates.sort(key=lambda x: x["score"], reverse=True)
        return candidates[:top_k]

    def get_group_stats(self, group_id: str) -> Dict[str, Any]:
        """Return summary counters for a group's term table.

        Returns:
            Dict with ``total_unique_terms``, ``total_occurrences`` and
            ``terms_above_threshold`` (terms seen at least
            ``_MIN_FREQUENCY`` times).
        """
        group_freq = self._group_term_freq.get(group_id, {})
        return {
            "total_unique_terms": len(group_freq),
            "total_occurrences": sum(group_freq.values()),
            "terms_above_threshold": sum(
                1 for f in group_freq.values() if f >= _MIN_FREQUENCY
            ),
        }

    def reset_group(self, group_id: str) -> None:
        """Clear all statistical data for a specific group.

        The group's contribution is also subtracted from the global term
        totals so they stay consistent with the remaining groups.
        """
        removed = self._group_term_freq.pop(group_id, None)
        if removed:
            for term, count in removed.items():
                remaining = self._global_term_freq.get(term, 0) - count
                if remaining > 0:
                    self._global_term_freq[term] = remaining
                else:
                    self._global_term_freq.pop(term, None)

        self._user_term_freq.pop(group_id, None)
        self._term_first_seen.pop(group_id, None)
        self._term_contexts.pop(group_id, None)
        self._dirty_groups.discard(group_id)
        logger.debug(f"[JargonFilter] Reset statistics for group {group_id}")

    # Internal helpers

    def _tokenize(self, text: str) -> List[str]:
        """Segment *text* with jieba and keep the usable tokens.

        Tokens are stripped, must be at least ``_MIN_TERM_LENGTH`` characters
        long and must not be in the stopword list. Returns an empty list when
        jieba is not installed instead of raising ``ImportError``.
        """
        if not self._ensure_jieba():
            return []
        import jieba

        stripped = (word.strip() for word in jieba.cut(text))
        return [
            word
            for word in stripped
            if len(word) >= _MIN_TERM_LENGTH and not self._is_stopword(word)
        ]

    def _ensure_jieba(self) -> bool:
        """Lazily configure jieba on first use.

        Returns:
            ``True`` when jieba is importable, ``False`` otherwise. A missing
            package is logged once, not on every message.
        """
        if self._jieba_loaded:
            return True
        if self._jieba_unavailable:
            return False
        try:
            import jieba
        except ImportError:
            self._jieba_unavailable = True
            logger.warning(
                "[JargonFilter] jieba is not installed. "
                "Install via: pip install jieba"
            )
            return False
        jieba.setLogLevel(20)  # Suppress jieba's verbose logging.
        self._jieba_loaded = True
        return True

    def _calc_burst_score(self, term: str, group_id: str) -> float:
        """Return ``frequency / age_in_days`` for *term* in *group_id*.

        The age is floored at one day, so a brand-new term scores its raw
        frequency. Returns ``0.0`` when no first-seen time is recorded.
        """
        first_seen = self._term_first_seen.get(group_id, {}).get(term, 0)
        if first_seen == 0:
            return 0.0
        age_days = max((time.time() - first_seen) / 86400.0, 1.0)
        freq = self._group_term_freq.get(group_id, {}).get(term, 0)
        return freq / age_days

    @staticmethod
    def _is_stopword(word: str) -> bool:
        """Return ``True`` when *word* is in the built-in stopword list."""
        return word in JargonStatisticalFilter._STOPWORDS
class DialogAnalyzer:
    """Generate few-shot dialog examples and manage style-learning reviews.

    Dependencies are injected via the constructor to keep this class testable
    and decoupled from the plugin instance.

    Args:
        factory_manager: Object exposing ``get_service_factory()``; used to
            obtain the message relationship analyzer. May be ``None``, in
            which case pair validation falls back to a simple inequality
            check.
        db_manager: Database manager providing ``create_style_learning_review``
            and an async ``get_session()`` context manager.
    """

    # At most this many pairs are ever emitted, so validation stops there.
    _MAX_FEW_SHOT_PAIRS = 5
    # Fewer valid pairs than this produce no output.
    _MIN_FEW_SHOT_PAIRS = 3
    # Fewer collected messages than this produce no output.
    _MIN_MESSAGES = 10
    # Messages shorter than this (after stripping) are ignored.
    _MIN_TEXT_LENGTH = 5

    def __init__(self, factory_manager: Any, db_manager: Any) -> None:
        self._factory_manager = factory_manager
        self._db_manager = db_manager

    # Few-shot dialog generation

    async def generate_few_shots_dialog(
        self, group_id: str, message_data_list: List[Any]
    ) -> str:
        """Generate few-shot dialog content from collected messages.

        Messages are sorted by ``timestamp`` and each adjacent pair from two
        different senders is considered. A pair is kept when both texts are at
        least 5 characters long, neither text contains the other, and
        ``is_valid_dialog_pair`` accepts it. Validation stops as soon as the 5
        pairs that can be emitted have been collected.

        Returns:
            The formatted few-shot text, or ``""`` when there are fewer than
            10 messages, fewer than 3 valid pairs, or an error occurs.
        """
        try:
            if len(message_data_list) < self._MIN_MESSAGES:
                logger.debug(
                    f"群组 {group_id} 消息数量不足10条"
                    f"(当前{len(message_data_list)}条),跳过Few Shots生成"
                )
                return ""

            dialog_pairs: List[Dict[str, str]] = []
            sorted_messages = sorted(message_data_list, key=lambda x: x.timestamp)

            for current_msg, next_msg in zip(sorted_messages, sorted_messages[1:]):
                # Skip consecutive messages from the same sender
                if current_msg.sender_id == next_msg.sender_id:
                    continue

                # A missing text is treated as empty instead of aborting the
                # whole batch with an AttributeError.
                user_msg = (current_msg.message or "").strip()
                bot_response = (next_msg.message or "").strip()

                # Length filter; this also rejects trivial texts such as "?"
                # or "...", which are all shorter than the minimum.
                if (
                    len(user_msg) < self._MIN_TEXT_LENGTH
                    or len(bot_response) < self._MIN_TEXT_LENGTH
                ):
                    continue

                # Filter duplicate / contained content (equality is a special
                # case of containment).
                if user_msg in bot_response or bot_response in user_msg:
                    logger.debug(
                        f"过滤重复内容: A='{user_msg[:30]}...' B='{bot_response[:30]}...'"
                    )
                    continue

                if await self.is_valid_dialog_pair(current_msg, next_msg, group_id):
                    dialog_pairs.append({"user": user_msg, "assistant": bot_response})
                    # Only the first pairs are emitted; further validation
                    # calls would be wasted work.
                    if len(dialog_pairs) >= self._MAX_FEW_SHOT_PAIRS:
                        break

            if len(dialog_pairs) >= self._MIN_FEW_SHOT_PAIRS:
                selected_pairs = dialog_pairs[: self._MAX_FEW_SHOT_PAIRS]
                few_shots_lines = [
                    "*Here are few shots of dialogs, you need to imitate "
                    "the tone of 'B' in the following dialogs to respond:"
                ]
                for pair in selected_pairs:
                    few_shots_lines.append(f"A: {pair['user']}")
                    few_shots_lines.append(f"B: {pair['assistant']}")

                logger.info(
                    f"群组 {group_id} 生成了 {len(selected_pairs)} 组Few Shots对话"
                )
                return "\n".join(few_shots_lines)

            logger.debug(
                f"群组 {group_id} 未找到足够的有效对话片段"
                f"(需要至少3组,当前{len(dialog_pairs)}组)"
            )
            return ""

        except Exception as e:
            logger.error(f"生成Few Shots对话失败: {e}")
            return ""

    # Dialog-pair validation

    @staticmethod
    def _to_analyzer_dict(msg: Any) -> Dict[str, Any]:
        """Convert a message object into the dict passed to the analyzer.

        When the message has no ``message_id`` a fallback id is derived from
        its timestamp and sender.
        """
        return {
            "message_id": msg.message_id
            or str(hash(f"{msg.timestamp}{msg.sender_id}")),
            "sender_id": msg.sender_id,
            "message": msg.message,
            "timestamp": msg.timestamp,
        }

    async def is_valid_dialog_pair(
        self, msg1: Any, msg2: Any, group_id: str
    ) -> bool:
        """Determine whether two messages form a genuine dialog pair.

        Uses the relationship analyzer from the service factory when one is
        available. The pair is valid when the analyzer reports a
        ``direct_reply`` or ``topic_continuation`` relationship with
        confidence above 0.5. Without an analyzer the check falls back to
        "the two texts differ".

        Returns:
            ``True`` for a valid pair; ``False`` otherwise or on error.
        """
        try:
            if (
                not self._factory_manager
                or not hasattr(self._factory_manager, "_service_factory")
                or not self._factory_manager._service_factory
            ):
                return msg1.message != msg2.message

            relationship_analyzer = (
                self._factory_manager.get_service_factory()
                .create_message_relationship_analyzer()
            )
            if not relationship_analyzer:
                return msg1.message != msg2.message

            relationship = await relationship_analyzer._analyze_message_pair(
                self._to_analyzer_dict(msg1),
                self._to_analyzer_dict(msg2),
                group_id,
            )

            if relationship:
                is_valid = (
                    relationship.relationship_type
                    in ("direct_reply", "topic_continuation")
                    and relationship.confidence > 0.5
                )
                if is_valid:
                    logger.debug(
                        f"识别对话关系: {relationship.relationship_type} "
                        f"(置信度: {relationship.confidence:.2f})"
                    )
                return is_valid

            return False

        except Exception as e:
            logger.error(f"消息关系判断失败: {e}", exc_info=True)
            return False

    # Style-learning review management

    async def create_style_learning_review_request(
        self,
        group_id: str,
        learned_patterns: List[Any],
        few_shots_content: str,
    ) -> None:
        """Create a review request for learned dialog-style patterns.

        Skips creation when a pending review with identical
        ``few_shots_content`` already exists for the group (de-duplication).
        Each pattern is serialised with its ``to_dict()`` method. Errors are
        logged and swallowed.
        """
        try:
            existing_reviews = await self.get_pending_style_reviews(group_id)
            already_pending = any(
                existing.get("few_shots_content", "") == few_shots_content
                for existing in existing_reviews or ()
            )
            if already_pending:
                logger.info(
                    f"群组 {group_id} 已存在相同的待审查风格学习记录,跳过重复创建"
                )
                return

            review_data = {
                "type": "style_learning",
                "group_id": group_id,
                "timestamp": time.time(),
                "learned_patterns": [p.to_dict() for p in learned_patterns],
                "few_shots_content": few_shots_content,
                "status": "pending",
                "description": (
                    f"群组 {group_id} 的对话风格学习结果"
                    f"(包含 {len(learned_patterns)} 个表达模式)"
                ),
            }

            await self._db_manager.create_style_learning_review(review_data)
            logger.info(f"对话风格学习审查请求已创建: {group_id}")

        except Exception as e:
            logger.error(f"创建对话风格学习审查请求失败: {e}")

    async def get_pending_style_reviews(
        self, group_id: str
    ) -> List[Dict[str, Any]]:
        """Retrieve pending style-learning review records for a group.

        Returns:
            Up to 10 records, newest first, each a dict with ``id``,
            ``group_id``, ``few_shots_content`` and ``timestamp``. Returns
            ``[]`` on error.
        """
        try:
            async with self._db_manager.get_session() as session:
                from sqlalchemy import select, desc
                from ...models.orm.learning import StyleLearningReview

                stmt = (
                    select(StyleLearningReview)
                    .where(
                        StyleLearningReview.group_id == group_id,
                        StyleLearningReview.status == 'pending',
                        StyleLearningReview.type == 'style_learning',
                    )
                    .order_by(desc(StyleLearningReview.timestamp))
                    .limit(10)
                )
                result = await session.execute(stmt)
                reviews = result.scalars().all()
                return [
                    {
                        "id": r.id,
                        "group_id": r.group_id,
                        "few_shots_content": r.few_shots_content,
                        "timestamp": r.timestamp,
                    }
                    for r in reviews
                ]

        except Exception as e:
            logger.error(f"获取待审查风格学习记录失败: {e}")
            return []
+ +Manages per-group learning tasks, throttling, and automatic scheduling. +""" + +import asyncio +import time +from typing import Any, Dict, List + +from astrbot.api import logger + + +class GroupLearningOrchestrator: + """Orchestrate learning tasks across chat groups. + + Owns the ``learning_tasks`` mapping and provides methods to smart-start + learning, discover active groups, and clean up on shutdown. + + Args: + plugin_config: Plugin configuration object. + message_collector: Message collector service. + progressive_learning: Progressive learning service. + service_factory: Service factory from ``FactoryManager``. + qq_filter: QQ group filter with whitelist/blacklist support. + db_manager: Database manager for ORM queries. + """ + + def __init__( + self, + plugin_config: Any, + message_collector: Any, + progressive_learning: Any, + qq_filter: Any, + db_manager: Any, + ) -> None: + self._config = plugin_config + self._message_collector = message_collector + self._progressive_learning = progressive_learning + self._qq_filter = qq_filter + self._db_manager = db_manager + + self.learning_tasks: Dict[str, asyncio.Task] = {} + + # Per-group last-start timestamps (keyed by group_id) + self._last_learning_start: Dict[str, float] = {} + + # Public API + + async def smart_start_learning_for_group(self, group_id: str) -> None: + """Smart-start a learning task for *group_id* with frequency throttling.""" + try: + if group_id in self.learning_tasks: + return + + current_time = time.time() + last_start = self._last_learning_start.get(group_id, 0) + interval_seconds = self._config.learning_interval_hours * 3600 + + if current_time - last_start < interval_seconds: + remaining = interval_seconds - (current_time - last_start) + logger.debug( + f"群组 {group_id} 学习间隔未到,剩余时间: {remaining / 60:.1f}分钟" + ) + return + + stats = await self._message_collector.get_statistics(group_id) + if not isinstance(stats, dict): + logger.warning( + f"get_statistics 返回了非字典类型: {type(stats)}, " + f"值: 
{stats}, 跳过学习启动" + ) + return + + total_messages = self._safe_int( + stats.get("total_messages", 0), "total_messages" + ) + if total_messages is None: + return + + min_messages = self._safe_int( + self._config.min_messages_for_learning, + "min_messages_for_learning", + default=10, + ) + + if total_messages < min_messages: + logger.debug( + f"群组 {group_id} 消息数量未达到学习阈值: " + f"{total_messages}/{min_messages}" + ) + return + + self._last_learning_start[group_id] = current_time + + learning_task = asyncio.create_task( + self._start_group_learning(group_id) + ) + + def _on_complete(task: asyncio.Task) -> None: + self.learning_tasks.pop(group_id, None) + if task.exception(): + logger.error( + f"群组 {group_id} 学习任务异常: {task.exception()}" + ) + else: + logger.info(f"群组 {group_id} 学习任务完成") + + learning_task.add_done_callback(_on_complete) + self.learning_tasks[group_id] = learning_task + logger.info(f"为群组 {group_id} 启动了智能学习任务") + + except Exception as e: + logger.error(f"智能启动学习失败: {e}") + + async def delayed_auto_start_learning(self) -> None: + """Auto-start learning for active groups after a startup delay.""" + try: + await asyncio.sleep(30) + active_groups = await self.get_active_groups() + + for group_id in active_groups: + try: + await self.smart_start_learning_for_group(group_id) + await asyncio.sleep(5) + except Exception as e: + logger.error(f"延迟启动群组 {group_id} 学习失败: {e}") + + except Exception as e: + logger.error(f"延迟自动启动学习失败: {e}") + + async def get_active_groups(self) -> List[str]: + """Discover active groups using ORM queries with whitelist/blacklist.""" + try: + if not self._db_manager: + logger.warning("数据库管理器未初始化,无法获取活跃群组") + return [] + + if hasattr(self._db_manager, "_started") and not self._db_manager._started: + logger.warning("SQLAlchemy 数据库管理器未启动,无法获取活跃群组") + return [] + + allowed_groups = self._qq_filter.get_allowed_group_ids() + blocked_groups = self._qq_filter.get_blocked_group_ids() + + if allowed_groups: + logger.info(f"应用群组白名单过滤,仅查询: 
{allowed_groups}") + if blocked_groups: + logger.info(f"应用群组黑名单过滤,排除: {blocked_groups}") + + async with self._db_manager.get_session() as session: + from sqlalchemy import select, func + from ...models.orm import RawMessage + + def _apply_filter(stmt): + if allowed_groups: + stmt = stmt.where(RawMessage.group_id.in_(allowed_groups)) + if blocked_groups: + stmt = stmt.where(RawMessage.group_id.notin_(blocked_groups)) + return stmt + + # Progressively widen the search window: 24h → 7d → all-time + for label, cutoff in ( + ("24小时", int(time.time() - 86400)), + ("7天", int(time.time() - 86400 * 7)), + ("全部", None), + ): + base = select( + RawMessage.group_id, + func.count(RawMessage.id).label("msg_count"), + ).where( + RawMessage.group_id.isnot(None), + RawMessage.group_id != "", + ) + + if cutoff is not None: + base = base.where(RawMessage.timestamp > cutoff) + + base = _apply_filter(base) + + min_msgs = self._config.min_messages_for_learning + if label == "7天": + min_msgs = max(1, min_msgs // 2) + elif label == "全部": + min_msgs = 1 + + stmt = ( + base.group_by(RawMessage.group_id) + .having(func.count(RawMessage.id) >= min_msgs) + .order_by(func.count(RawMessage.id).desc()) + .limit(10) + ) + + result = await session.execute(stmt) + active_groups = [ + row.group_id for row in result if row.group_id + ] + + if active_groups: + logger.info( + f"在{label}范围内发现 {len(active_groups)} 个活跃群组: " + f"{active_groups}" + ) + return active_groups + + if cutoff is not None: + logger.warning( + f"最近{label}内没有活跃群组,扩大搜索范围..." 
+ ) + + logger.info("未发现任何活跃群组") + return [] + + except Exception as e: + logger.error(f"获取活跃群组失败: {e}") + return [] + + async def cancel_all(self) -> None: + """Cancel all running learning tasks (called during shutdown).""" + for group_id, task in list(self.learning_tasks.items()): + try: + await self._progressive_learning.stop_learning() + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + logger.info(f"群组 {group_id} 学习任务已停止") + except Exception as e: + logger.error(f"停止群组 {group_id} 学习任务失败: {e}") + self.learning_tasks.clear() + + # Internal helpers + + async def _start_group_learning(self, group_id: str) -> None: + """Start the progressive learning session for a single group.""" + try: + success = await self._progressive_learning.start_learning(group_id) + if success: + logger.info(f"群组 {group_id} 学习任务启动成功") + else: + logger.warning(f"群组 {group_id} 学习任务启动失败") + except Exception as e: + logger.error(f"群组 {group_id} 学习任务启动异常: {e}") + + @staticmethod + def _safe_int( + value: Any, name: str, *, default: int | None = None + ) -> int | None: + """Safely convert *value* to ``int`` with detailed logging.""" + try: + if isinstance(value, str) and not value.replace("-", "").isdigit(): + if default is not None: + logger.warning( + f"{name} 是非数字字符串: '{value}', 使用默认值{default}" + ) + return default + logger.warning(f"{name} 是非数字字符串: '{value}', 跳过") + return None + return int(value) if value else 0 + except (ValueError, TypeError) as e: + if default is not None: + logger.warning( + f"{name} 转换失败: 原始值={value}, 错误={e}, " + f"使用默认值{default}" + ) + return default + logger.warning( + f"{name} 转换失败: 原始值={value}, 类型={type(value)}, 错误={e}" + ) + return None diff --git a/services/learning/message_pipeline.py b/services/learning/message_pipeline.py new file mode 100644 index 0000000..4b2582b --- /dev/null +++ b/services/learning/message_pipeline.py @@ -0,0 +1,245 @@ +"""消息处理流水线 — 协调后台学习、黑话挖掘、好感度更新""" +import asyncio +import time +from 
typing import Any, Optional + +from astrbot.api import logger + +from ...core.interfaces import MessageData +from ...statics.messages import LogMessages + + +class MessagePipeline: + """消息处理流水线 — 每条消息的后台处理编排""" + + def __init__( + self, + plugin_config: Any, + message_collector: Any, + enhanced_interaction: Any, + jargon_miner_manager: Optional[Any], + jargon_statistical_filter: Optional[Any], + v2_integration: Optional[Any], + realtime_processor: Any, + group_orchestrator: Any, + conversation_goal_manager: Optional[Any], + affection_manager: Any, + db_manager: Any, + ): + self._config = plugin_config + self._message_collector = message_collector + self._enhanced_interaction = enhanced_interaction + self._jargon_miner_manager = jargon_miner_manager + self._jargon_statistical_filter = jargon_statistical_filter + self._v2_integration = v2_integration + self._realtime_processor = realtime_processor + self._group_orchestrator = group_orchestrator + self._conversation_goal_manager = conversation_goal_manager + self._affection_manager = affection_manager + self._db_manager = db_manager + + # 后台学习流水线(6 步) + + async def process_learning( + self, + group_id: str, + sender_id: str, + message_text: str, + event: Any, + ) -> None: + """后台处理学习相关操作(非阻塞) + + 通过 asyncio.create_task() 在后台运行。 + 为避免 'Future attached to different loop' 错误,数据库操作包装在异常处理中。 + """ + try: + # 1. 消息收集 + try: + await self._message_collector.collect_message( + { + "sender_id": sender_id, + "sender_name": event.get_sender_name(), + "message": message_text, + "group_id": group_id, + "timestamp": time.time(), + "platform": event.get_platform_name(), + } + ) + except RuntimeError as e: + if "attached to a different loop" in str(e): + logger.warning( + f"消息收集遇到事件循环问题(已知 MySQL 限制)," + f"消息将被跳过: {str(e)[:100]}" + ) + else: + raise + except Exception as e: + logger.error(f"消息收集失败: {e}") + + # 2. 
增强交互(多轮对话管理) + try: + await self._enhanced_interaction.update_conversation_context( + group_id, sender_id, message_text + ) + except Exception as e: + logger.error(LogMessages.ENHANCED_INTERACTION_FAILED.format(error=e)) + + # 2.5 黑话统计预筛(<1ms, 零 LLM 成本) + if self._jargon_statistical_filter: + try: + self._jargon_statistical_filter.update_from_message( + message_text, group_id, sender_id + ) + except Exception: + pass # best-effort + + # 3. 黑话挖掘 — 每收集 10 条消息触发一次 + stats = await self._message_collector.get_statistics(group_id) + raw_message_count = stats.get("raw_messages", 0) + if raw_message_count % 10 == 0 and raw_message_count >= 10: + asyncio.create_task(self.mine_jargon(group_id)) + + # 3.5 V2 per-message processing + if self._v2_integration: + try: + msg_data = MessageData( + message=message_text, + sender_id=sender_id, + sender_name=event.get_sender_name() or sender_id, + group_id=group_id, + timestamp=time.time(), + platform=event.get_platform_name() or "unknown", + ) + await self._v2_integration.process_message(msg_data, group_id) + except Exception as e: + logger.debug(f"V2 message processing failed: {e}") + + # 4. 实时学习 + if self._config.enable_realtime_learning: + asyncio.create_task( + self._realtime_processor.process_realtime_background( + group_id, message_text, sender_id + ) + ) + + # 5. 智能启动学习任务 + await self._group_orchestrator.smart_start_learning_for_group(group_id) + + # 6. 
对话目标管理 + if self._config.enable_goal_driven_chat: + try: + if self._conversation_goal_manager: + goal = await self._conversation_goal_manager.get_or_create_conversation_goal( + user_id=sender_id, + group_id=group_id, + user_message=message_text, + ) + if goal: + goal_type = goal["final_goal"].get("type", "unknown") + goal_name = goal["final_goal"].get("name", "未知目标") + topic = goal["final_goal"].get("topic", "未知话题") + current_stage = goal["current_stage"].get("task", "初始化") + logger.info( + f" [对话目标] 会话目标: {goal_name} " + f"(类型: {goal_type}), 话题: {topic}, " + f"当前阶段: {current_stage}" + ) + except Exception as e: + logger.error(f"对话目标处理失败: {e}", exc_info=True) + + except Exception as e: + logger.error(f"后台学习处理失败: {e}", exc_info=True) + + # 黑话挖掘 + + async def mine_jargon(self, group_id: str) -> None: + """后台黑话挖掘 — 完全异步、非阻塞 + + 1. 检查触发条件(频率控制) + 2. 获取统计候选词(零 LLM 成本) + 3. 无统计候选时回退到 LLM 提取 + 4. 保存/更新到数据库并在阈值处触发推理 + """ + try: + if not self._jargon_miner_manager: + logger.debug("[JargonMining] JargonMinerManager not initialised, skip") + return + + jargon_miner = self._jargon_miner_manager.get_or_create_miner(group_id) + + stats = await self._message_collector.get_statistics(group_id) + recent_message_count = stats.get("raw_messages", 0) + + if not jargon_miner.should_trigger(recent_message_count): + logger.debug( + f"[JargonMining] Group {group_id} trigger conditions not met" + ) + return + + recent_messages = await self._db_manager.get_recent_raw_messages( + group_id, limit=30 + ) + + if len(recent_messages) < 10: + logger.debug( + f"[JargonMining] Group {group_id} insufficient messages " + f"({len(recent_messages)}<10)" + ) + return + + logger.info( + f"[JargonMining] Analysing {len(recent_messages)} messages " + f"from group {group_id}" + ) + + chat_messages = "\n".join( + [ + f"{msg.get('sender_id', 'unknown')}: {msg.get('message', '')}" + for msg in recent_messages + ] + ) + + statistical_candidates = None + if self._jargon_statistical_filter: + 
statistical_candidates = ( + self._jargon_statistical_filter.get_jargon_candidates( + group_id, top_k=20 + ) + ) + if not statistical_candidates: + statistical_candidates = None + + await jargon_miner.run_once( + chat_messages, + len(recent_messages), + statistical_candidates=statistical_candidates, + ) + + logger.debug(f"[JargonMining] Group {group_id} learning complete") + + except Exception as e: + logger.error( + f"[JargonMining] Background task failed (group={group_id}): {e}", + exc_info=True, + ) + + # 好感度处理 + + async def process_affection( + self, group_id: str, sender_id: str, message_text: str + ) -> None: + """后台处理好感度更新(非阻塞)""" + try: + affection_result = ( + await self._affection_manager.process_message_interaction( + group_id, sender_id, message_text + ) + ) + if affection_result.get("success"): + logger.debug( + LogMessages.AFFECTION_PROCESSING_SUCCESS.format( + result=affection_result + ) + ) + except Exception as e: + logger.error(LogMessages.AFFECTION_PROCESSING_FAILED.format(error=e)) diff --git a/services/learning/realtime_processor.py b/services/learning/realtime_processor.py new file mode 100644 index 0000000..481b53b --- /dev/null +++ b/services/learning/realtime_processor.py @@ -0,0 +1,338 @@ +"""Realtime message processing — expression-style learning and message filtering. + +Handles the per-message processing pipeline that runs in the background +after each incoming message. +""" + +import re +import time +from typing import Any, Callable, Coroutine, Dict, List, Optional + +from astrbot.api import logger + +from ...core.interfaces import MessageData +from ...statics.messages import StatusMessages +from .dialog_analyzer import DialogAnalyzer + + +class RealtimeProcessor: + """Process incoming messages for realtime learning and filtering. + + Orchestrates expression-style learning, message LLM filtering, and + temporary persona updates. + + Args: + plugin_config: Plugin configuration object. + message_collector: Message collector service. 
+ multidimensional_analyzer: Analyzer for LLM-based message filtering. + persona_manager: Persona manager for current persona retrieval. + temporary_persona_updater: Service for temporary style prompt updates. + dialog_analyzer: ``DialogAnalyzer`` for few-shot generation. + learning_stats: Shared ``LearningStats`` dataclass instance. + factory_manager: ``FactoryManager`` for component creation. + db_manager: Database manager for raw message retrieval. + """ + + def __init__( + self, + plugin_config: Any, + message_collector: Any, + multidimensional_analyzer: Any, + persona_manager: Any, + temporary_persona_updater: Any, + dialog_analyzer: DialogAnalyzer, + learning_stats: Any, + factory_manager: Any, + db_manager: Any, + ) -> None: + self._config = plugin_config + self._message_collector = message_collector + self._multidimensional_analyzer = multidimensional_analyzer + self._persona_manager = persona_manager + self._temporary_persona_updater = temporary_persona_updater + self._dialog_analyzer = dialog_analyzer + self._learning_stats = learning_stats + self._factory_manager = factory_manager + self._db_manager = db_manager + + # Callback set by the plugin to trigger incremental prompt updates + self.update_system_prompt_callback: Optional[ + Callable[[str], Coroutine[Any, Any, None]] + ] = None + + # Public API + + async def process_realtime_background( + self, group_id: str, message_text: str, sender_id: str + ) -> None: + """Background wrapper — fully async, never blocks the main flow.""" + try: + await self.process_message_realtime(group_id, message_text, sender_id) + except Exception as e: + logger.error( + f"实时学习后台处理失败 (group={group_id}): {e}", exc_info=True + ) + + async def process_message_realtime( + self, group_id: str, message_text: str, sender_id: str + ) -> None: + """Process a single message in realtime — filter + expression learning.""" + try: + # Basic guards + if len(message_text.strip()) < self._config.message_min_length: + return + if 
len(message_text) > self._config.message_max_length: + return + if message_text.strip() in ("", "???", "。。。", "...", "嗯", "哦", "额"): + return + + # Expression-style learning (bypasses filtering) + await self._process_expression_style_learning( + group_id, message_text, sender_id + ) + + # Batch mode: skip LLM filtering if disabled + if not self._config.enable_realtime_llm_filter: + await self._message_collector.add_filtered_message( + { + "message": message_text, + "sender_id": sender_id, + "group_id": group_id, + "timestamp": time.time(), + "confidence": 0.6, + } + ) + self._learning_stats.filtered_messages += 1 + if not hasattr(self._config, "filtered_messages"): + self._config.filtered_messages = 0 + self._config.filtered_messages = ( + self._learning_stats.filtered_messages + ) + + # LLM-based filtering + current_persona_description = ( + await self._persona_manager.get_current_persona_description(group_id) + ) + + if await self._multidimensional_analyzer.filter_message_with_llm( + message_text, current_persona_description + ): + await self._message_collector.add_filtered_message( + { + "message": message_text, + "sender_id": sender_id, + "group_id": group_id, + "timestamp": time.time(), + "confidence": 0.8, + } + ) + self._learning_stats.filtered_messages += 1 + if not hasattr(self._config, "filtered_messages"): + self._config.filtered_messages = 0 + self._config.filtered_messages = ( + self._learning_stats.filtered_messages + ) + + except Exception as e: + logger.error( + StatusMessages.REALTIME_PROCESSING_ERROR.format(error=e), + exc_info=True, + ) + + # Expression-style learning + + async def _process_expression_style_learning( + self, group_id: str, message_text: str, sender_id: str + ) -> None: + """Learn expression styles directly from raw messages.""" + try: + stats = await self._message_collector.get_statistics(group_id) + raw_message_count = stats.get("raw_messages", 0) + + if raw_message_count < 5: + logger.debug( + f"群组 {group_id} 
原始消息数量不足,当前:{raw_message_count},需要至少5条" + ) + return + + logger.info( + f"群组 {group_id} 开始表达风格学习,当前消息数:{raw_message_count}" + ) + + recent_raw_messages = await self._db_manager.get_recent_raw_messages( + group_id, limit=25 + ) + if not recent_raw_messages or len(recent_raw_messages) < 3: + logger.debug( + f"群组 {group_id} 原始消息数量不足,数据库中只有 " + f"{len(recent_raw_messages) if recent_raw_messages else 0} 条" + ) + return + + message_data_list = self._build_message_data_list( + recent_raw_messages, group_id, sender_id + ) + + if len(message_data_list) < 3: + logger.debug( + f"群组 {group_id} 有效学习消息不足3条,跳过表达风格学习," + f"当前:{len(message_data_list)}" + ) + return + + logger.info( + f"群组 {group_id} 准备进行表达风格学习," + f"有效消息数:{len(message_data_list)}" + ) + + expression_learner = ( + self._factory_manager.get_component_factory() + .create_expression_pattern_learner() + ) + if not expression_learner: + logger.warning("表达模式学习器未正确初始化") + return + + learning_success = await expression_learner.trigger_learning_for_group( + group_id, message_data_list + ) + if not learning_success: + logger.debug(f"群组 {group_id} 表达风格学习未产生有效结果") + return + + logger.info(f"群组 {group_id} 表达风格学习成功") + + try: + learned_patterns = await expression_learner.get_expression_patterns( + group_id, limit=5 + ) + if learned_patterns: + await self._apply_style_to_prompt_temporarily( + group_id, learned_patterns + ) + few_shots_content = ( + await self._dialog_analyzer.generate_few_shots_dialog( + group_id, message_data_list + ) + ) + if few_shots_content: + await self._dialog_analyzer.create_style_learning_review_request( + group_id, learned_patterns, few_shots_content + ) + logger.info( + f"群组 {group_id} 表达风格学习结果已临时应用到prompt," + "并已提交人格审查" + ) + else: + logger.info( + f"群组 {group_id} 表达风格学习结果已临时应用到prompt" + ) + except Exception as e: + logger.error(f"处理表达风格学习结果失败: {e}") + + self._learning_stats.style_updates += 1 + + if self.update_system_prompt_callback: + await self.update_system_prompt_callback(group_id) + logger.info( 
+ f"群组 {group_id} 表达风格学习结果已应用到system_prompt" + ) + + except Exception as e: + logger.error(f"群组 {group_id} 表达风格学习处理失败: {e}") + + # Temporary style application + + async def _apply_style_to_prompt_temporarily( + self, group_id: str, learned_patterns: List[Any] + ) -> None: + """Apply learned style patterns to the prompt temporarily.""" + try: + if not learned_patterns: + return + + style_descriptions: List[str] = [] + for pattern in learned_patterns[:3]: + situation = ( + pattern.situation + if hasattr(pattern, "situation") + else pattern.get("situation", "") + ) + expression = ( + pattern.expression + if hasattr(pattern, "expression") + else pattern.get("expression", "") + ) + if situation and expression: + style_descriptions.append( + f'当{situation}时,可以使用"{expression}"这样的表达' + ) + + if not style_descriptions: + return + + bullet_list = "\n".join(f"• {desc}" for desc in style_descriptions) + style_prompt = ( + "【临时表达风格特征】(基于最近学习)\n" + "在回复时可以参考以下表达方式:\n" + f"{bullet_list}\n\n" + "注意:这些是临时学习的风格特征,应自然融入回复,不要刻意模仿。" + ) + + success = await self._temporary_persona_updater.apply_temporary_style_update( + group_id, style_prompt + ) + + if success: + logger.info( + f"群组 {group_id} 表达风格已临时应用到prompt," + f"包含 {len(style_descriptions)} 个风格特征" + ) + else: + logger.warning(f"群组 {group_id} 表达风格临时应用失败") + + except Exception as e: + logger.error(f"临时应用风格到prompt失败: {e}") + + # Helpers + + @staticmethod + def _build_message_data_list( + recent_raw_messages: List[Dict[str, Any]], + group_id: str, + sender_id: str, + ) -> List[MessageData]: + """Convert raw DB messages to filtered ``MessageData`` objects.""" + at_pattern = re.compile(r"@[^\s]+\s+") + result: List[MessageData] = [] + + for msg in recent_raw_messages: + if msg.get("sender_id") == sender_id: + continue + + content = msg.get("message", "") + if len(content.strip()) < 5 or len(content) > 500: + continue + if content.strip() in ("", "???", "。。。", "...", "嗯", "哦", "额"): + continue + + processed = content + if "@" in content: 
+ processed = at_pattern.sub("", content).strip() + if len(processed.strip()) < 5: + continue + + result.append( + MessageData( + sender_id=msg.get("sender_id", ""), + sender_name=msg.get("sender_name", ""), + message=processed, + group_id=group_id, + timestamp=msg.get("timestamp", time.time()), + platform=msg.get("platform", "default"), + message_id=msg.get("id"), + reply_to=None, + ) + ) + + return result diff --git a/services/memory_graph_manager.py b/services/memory_graph_manager.py deleted file mode 100644 index 055a76b..0000000 --- a/services/memory_graph_manager.py +++ /dev/null @@ -1,661 +0,0 @@ -""" -记忆图管理器 - 基于MaiBot的记忆图系统设计 -使用NetworkX图结构实现概念关联和智能记忆融合 -""" -import time -import json -import math -import random -from typing import Dict, List, Optional, Tuple, Any, Set -from datetime import datetime -from dataclasses import dataclass, asdict -from collections import Counter - -import networkx as nx - -from astrbot.api import logger - -from ..core.interfaces import MessageData, ServiceLifecycle -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..config import PluginConfig -from ..exceptions import MemoryGraphError, ModelAccessError -from ..utils.json_utils import safe_parse_llm_json -from .database_manager import DatabaseManager -from .time_decay_manager import TimeDecayManager - - -@dataclass -class MemoryNode: - """记忆节点""" - concept: str - memory_items: str - weight: float - created_time: float - last_modified: float - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'MemoryNode': - return cls(**data) - - -@dataclass -class MemoryEdge: - """记忆边""" - concept1: str - concept2: str - strength: float - created_time: float - last_modified: float - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'MemoryEdge': - return cls(**data) - - -class MemoryGraph: - """ - 记忆图 - 完全基于MaiBot的MemoryGraph设计 - 
使用NetworkX实现概念关联和记忆管理 - """ - - def __init__(self): - self.G = nx.Graph() # 使用NetworkX的图结构 - - def connect_concepts(self, concept1: str, concept2: str): - """ - 连接两个概念 - 参考MaiBot的connect_dot方法 - - Args: - concept1: 概念1 - concept2: 概念2 - """ - # 避免自连接 - if concept1 == concept2: - return - - current_time = time.time() - - # 如果边已存在,增加strength - if self.G.has_edge(concept1, concept2): - self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1 - # 更新最后修改时间 - self.G[concept1][concept2]["last_modified"] = current_time - else: - # 如果是新边,初始化strength为1 - self.G.add_edge( - concept1, - concept2, - strength=1, - created_time=current_time, - last_modified=current_time, - ) - - async def add_memory_node(self, concept: str, memory: str, llm_adapter: Optional[FrameworkLLMAdapter] = None): - """ - 添加记忆节点 - 参考MaiBot的add_dot方法 - 支持LLM智能记忆融合 - - Args: - concept: 概念名称 - memory: 记忆内容 - llm_adapter: LLM适配器,用于记忆融合 - """ - current_time = time.time() - - if concept in self.G: - if "memory_items" in self.G.nodes[concept]: - # 获取现有的记忆项 - existing_memory = self.G.nodes[concept]["memory_items"] - - # 如果现有记忆不为空,则使用LLM整合新旧记忆 - if existing_memory and llm_adapter: - try: - integrated_memory = await self._integrate_memories_with_llm( - existing_memory, str(memory), llm_adapter - ) - self.G.nodes[concept]["memory_items"] = integrated_memory - # 整合成功,增加权重 - current_weight = self.G.nodes[concept].get("weight", 0.0) - self.G.nodes[concept]["weight"] = current_weight + 1.0 - logger.debug(f"节点 {concept} 记忆整合成功,权重增加到 {current_weight + 1.0}") - logger.info(f"节点 {concept} 记忆内容已更新:{integrated_memory}") - except Exception as e: - logger.error(f"LLM整合记忆失败: {e}") - # 降级到简单连接 - new_memory_str = f"{existing_memory} | {memory}" - self.G.nodes[concept]["memory_items"] = new_memory_str - logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}") - else: - new_memory_str = str(memory) - self.G.nodes[concept]["memory_items"] = new_memory_str - logger.info(f"节点 {concept} 
记忆内容已直接更新:{new_memory_str}") - else: - self.G.nodes[concept]["memory_items"] = str(memory) - # 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time - if "created_time" not in self.G.nodes[concept]: - self.G.nodes[concept]["created_time"] = current_time - logger.info(f"节点 {concept} 创建新记忆:{str(memory)}") - # 更新最后修改时间 - self.G.nodes[concept]["last_modified"] = current_time - else: - # 如果是新节点,创建新的记忆字符串 - self.G.add_node( - concept, - memory_items=str(memory), - weight=1.0, # 新节点初始权重为1.0 - created_time=current_time, - last_modified=current_time, - ) - logger.info(f"新节点 {concept} 已添加,记忆内容已写入:{str(memory)}") - - async def _integrate_memories_with_llm(self, old_memory: str, new_memory: str, llm_adapter: FrameworkLLMAdapter) -> str: - """ - 使用LLM智能整合记忆 - 参考MaiBot的_integrate_memories_with_llm方法 - - Args: - old_memory: 旧记忆 - new_memory: 新记忆 - llm_adapter: LLM适配器 - - Returns: - 整合后的记忆 - """ - from ..statics.prompts import MEMORY_INTEGRATION_PROMPT - - prompt = MEMORY_INTEGRATION_PROMPT.format( - old_memory=old_memory, - new_memory=new_memory - ) - - response = await llm_adapter.generate_response( - prompt, - temperature=0.3, - model_type="refine" - ) - - return response.strip() - - def get_memory_node(self, concept: str) -> Optional[Tuple[str, Dict[str, Any]]]: - """ - 获取记忆节点 - 参考MaiBot的get_dot方法 - - Args: - concept: 概念名称 - - Returns: - (概念名称, 节点数据) 或 None - """ - return (concept, self.G.nodes[concept]) if concept in self.G else None - - def get_related_concepts(self, topic: str, depth: int = 1) -> Tuple[List[str], List[str]]: - """ - 获取相关概念 - 参考MaiBot的get_related_item方法 - - Args: - topic: 主题概念 - depth: 搜索深度 - - Returns: - (第一层相关概念, 第二层相关概念) - """ - if topic not in self.G: - return [], [] - - first_layer_items = [] - second_layer_items = [] - - # 获取相邻节点 - neighbors = list(self.G.neighbors(topic)) - - # 获取当前节点的记忆项 - node_data = self.get_memory_node(topic) - if node_data: - _, data = node_data - if "memory_items" in data: - # 将主题概念的记忆内容加入第一层 - 
first_layer_items.append(data["memory_items"]) - - # 获取相邻节点的记忆项 - for neighbor in neighbors: - neighbor_data = self.get_memory_node(neighbor) - if neighbor_data: - _, data = neighbor_data - if "memory_items" in data: - first_layer_items.append(data["memory_items"]) - - # 如果需要深度搜索,获取邻居的邻居 - if depth > 1: - second_neighbors = list(self.G.neighbors(neighbor)) - for second_neighbor in second_neighbors: - if second_neighbor != topic and second_neighbor not in neighbors: - second_data = self.get_memory_node(second_neighbor) - if second_data: - _, second_node_data = second_data - if "memory_items" in second_node_data: - second_layer_items.append(second_node_data["memory_items"]) - - return first_layer_items, second_layer_items - - def calculate_information_content(self, text: str) -> float: - """ - 计算文本的信息量(熵) - 参考MaiBot的calculate_information_content方法 - - Args: - text: 文本内容 - - Returns: - 信息熵值 - """ - char_count = Counter(text) - total_chars = len(text) - if total_chars == 0: - return 0 - - entropy = 0 - for count in char_count.values(): - probability = count / total_chars - entropy -= probability * math.log2(probability) - - return entropy - - def get_graph_statistics(self) -> Dict[str, Any]: - """获取图的统计信息""" - return { - "nodes_count": self.G.number_of_nodes(), - "edges_count": self.G.number_of_edges(), - "density": nx.density(self.G), - "connected_components": nx.number_connected_components(self.G), - "average_clustering": nx.average_clustering(self.G) if self.G.number_of_nodes() > 0 else 0, - "average_shortest_path": nx.average_shortest_path_length(self.G) if nx.is_connected(self.G) else 0 - } - - -class MemoryGraphManager: - """ - 记忆图管理器 - 负责记忆图的持久化和管理 - 采用单例模式确保全局唯一实例 - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self, config: PluginConfig = None, db_manager: DatabaseManager = None, - llm_adapter: FrameworkLLMAdapter 
= None, decay_manager: TimeDecayManager = None): - # 防止重复初始化 - if self._initialized: - return - - self.config = config - self.db_manager = db_manager - self.llm_adapter = llm_adapter - self.decay_manager = decay_manager - self._status = ServiceLifecycle.CREATED - - # 为每个群组维护独立的记忆图 - self.memory_graphs: Dict[str, MemoryGraph] = {} - - # 初始化数据库表 - if self.db_manager: - self._init_memory_graph_tables() - - self._initialized = True - - @classmethod - def get_instance(cls, config: PluginConfig = None, db_manager = None, - llm_adapter = None, decay_manager = None) -> 'MemoryGraphManager': - """获取单例实例""" - if cls._instance is None: - cls._instance = cls(config, db_manager, llm_adapter, decay_manager) - else: - # 已有实例但字段为 None 时补充注入 - if llm_adapter is not None and cls._instance.llm_adapter is None: - cls._instance.llm_adapter = llm_adapter - if config is not None and cls._instance.config is None: - cls._instance.config = config - if db_manager is not None and cls._instance.db_manager is None: - cls._instance.db_manager = db_manager - if decay_manager is not None and cls._instance.decay_manager is None: - cls._instance.decay_manager = decay_manager - return cls._instance - - def _init_memory_graph_tables(self): - """初始化记忆图数据库表""" - try: - with self.db_manager.get_connection() as conn: - # 记忆节点表 - conn.execute(''' - CREATE TABLE IF NOT EXISTS memory_nodes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - concept TEXT NOT NULL, - memory_items TEXT NOT NULL, - weight REAL NOT NULL DEFAULT 1.0, - created_time REAL NOT NULL, - last_modified REAL NOT NULL, - group_id TEXT NOT NULL, - UNIQUE(concept, group_id) - ) - ''') - - # 记忆边表 - conn.execute(''' - CREATE TABLE IF NOT EXISTS memory_edges ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - concept1 TEXT NOT NULL, - concept2 TEXT NOT NULL, - strength REAL NOT NULL DEFAULT 1.0, - created_time REAL NOT NULL, - last_modified REAL NOT NULL, - group_id TEXT NOT NULL, - UNIQUE(concept1, concept2, group_id) - ) - ''') - - conn.commit() - 
logger.info("记忆图数据库表初始化完成") - except Exception as e: - logger.error(f"初始化记忆图数据库表失败: {e}") - raise MemoryGraphError(f"数据库初始化失败: {e}") - - async def start(self) -> bool: - """启动服务""" - self._status = ServiceLifecycle.RUNNING - logger.info("MemoryGraphManager服务已启动") - return True - - async def stop(self) -> bool: - """停止服务""" - # 保存所有记忆图 - for group_id in self.memory_graphs: - await self.save_memory_graph(group_id) - - self._status = ServiceLifecycle.STOPPED - logger.info("MemoryGraphManager服务已停止") - return True - - def get_memory_graph(self, group_id: str) -> MemoryGraph: - """获取或创建群组的记忆图""" - if group_id not in self.memory_graphs: - self.memory_graphs[group_id] = MemoryGraph() - # 异步加载记忆图数据 - asyncio.create_task(self.load_memory_graph(group_id)) - - return self.memory_graphs[group_id] - - async def load_memory_graph(self, group_id: str): - """从数据库加载记忆图""" - try: - if not self.db_manager: - logger.debug(f"db_manager 为空,无法加载群组 {group_id} 记忆图") - return - - memory_graph = self.memory_graphs.get(group_id, MemoryGraph()) - - with self.db_manager.get_connection() as conn: - # 加载节点 - cursor = conn.execute( - 'SELECT concept, memory_items, weight, created_time, last_modified FROM memory_nodes WHERE group_id = ?', - (group_id,) - ) - - for concept, memory_items, weight, created_time, last_modified in cursor.fetchall(): - memory_graph.G.add_node( - concept, - memory_items=memory_items, - weight=weight, - created_time=created_time, - last_modified=last_modified - ) - - # 加载边 - cursor = conn.execute( - 'SELECT concept1, concept2, strength, created_time, last_modified FROM memory_edges WHERE group_id = ?', - (group_id,) - ) - - for concept1, concept2, strength, created_time, last_modified in cursor.fetchall(): - memory_graph.G.add_edge( - concept1, - concept2, - strength=strength, - created_time=created_time, - last_modified=last_modified - ) - - self.memory_graphs[group_id] = memory_graph - logger.info(f"群组 {group_id} 记忆图加载完成,节点数: {memory_graph.G.number_of_nodes()},边数: 
{memory_graph.G.number_of_edges()}") - - except Exception as e: - logger.error(f"加载群组 {group_id} 记忆图失败: {e}") - - async def save_memory_graph(self, group_id: str): - """保存记忆图到数据库""" - try: - if group_id not in self.memory_graphs: - return - - if not self.db_manager: - logger.debug(f"db_manager 为空,无法保存群组 {group_id} 记忆图") - return - - memory_graph = self.memory_graphs[group_id] - - with self.db_manager.get_connection() as conn: - # 清除旧数据 - conn.execute('DELETE FROM memory_nodes WHERE group_id = ?', (group_id,)) - conn.execute('DELETE FROM memory_edges WHERE group_id = ?', (group_id,)) - - # 保存节点 - for node, data in memory_graph.G.nodes(data=True): - conn.execute( - 'INSERT INTO memory_nodes (concept, memory_items, weight, created_time, last_modified, group_id) VALUES (?, ?, ?, ?, ?, ?)', - ( - node, - data.get('memory_items', ''), - data.get('weight', 1.0), - data.get('created_time', time.time()), - data.get('last_modified', time.time()), - group_id - ) - ) - - # 保存边 - for u, v, data in memory_graph.G.edges(data=True): - conn.execute( - 'INSERT INTO memory_edges (concept1, concept2, strength, created_time, last_modified, group_id) VALUES (?, ?, ?, ?, ?, ?)', - ( - u, v, - data.get('strength', 1.0), - data.get('created_time', time.time()), - data.get('last_modified', time.time()), - group_id - ) - ) - - conn.commit() - logger.debug(f"群组 {group_id} 记忆图保存完成") - - except Exception as e: - logger.error(f"保存群组 {group_id} 记忆图失败: {e}") - - async def add_memory_from_message(self, message: MessageData, group_id: str): - """ - 从消息中添加记忆 - - Args: - message: 消息数据 - group_id: 群组ID - """ - try: - memory_graph = self.get_memory_graph(group_id) - - # 提取概念和记忆内容 - concepts = await self._extract_concepts_from_message(message) - - for concept in concepts: - # 获取消息文本(兼容 dict 和 MessageData) - msg_text = message.get('message', '') if isinstance(message, dict) else getattr(message, 'message', '') - # 添加记忆节点 - await memory_graph.add_memory_node( - concept=concept, - memory=msg_text, - 
llm_adapter=self.llm_adapter - ) - - # 建立概念间的连接 - for other_concept in concepts: - if concept != other_concept: - memory_graph.connect_concepts(concept, other_concept) - - # 定期保存 - if random.random() < 0.1: # 10% 概率保存 - await self.save_memory_graph(group_id) - - except Exception as e: - logger.error(f"从消息添加记忆失败: {e}") - - async def _extract_concepts_from_message(self, message: MessageData) -> List[str]: - """ - 从消息中提取概念 - - Args: - message: 消息数据 - - Returns: - 提取的概念列表 - """ - try: - from ..statics.prompts import ENTITY_EXTRACTION_PROMPT - - if not self.llm_adapter: - logger.debug("llm_adapter 未初始化,跳过概念提取") - return [] - - # 兼容 dict 和 MessageData 对象 - if isinstance(message, dict): - text = message.get('message', '') or message.get('content', '') - else: - text = getattr(message, 'message', '') or getattr(message, 'content', '') - - if not text: - return [] - - prompt = ENTITY_EXTRACTION_PROMPT.format(text=text) - - # 二次检查:防止并发场景下 llm_adapter 被重置 - adapter = self.llm_adapter - if not adapter: - return [] - - response = await adapter.generate_response( - prompt, - temperature=0.1, - model_type="filter" # 使用过滤模型进行快速提取 - ) - - # 解析JSON响应 - concepts = safe_parse_llm_json(response) - - if isinstance(concepts, list): - return [str(concept).strip() for concept in concepts if concept] - else: - return [] - - except Exception as e: - logger.error(f"提取概念失败: {e}") - return [] - - async def get_related_memories(self, query: str, group_id: str, limit: int = 5) -> List[str]: - """ - 获取与查询相关的记忆 - - Args: - query: 查询内容 - group_id: 群组ID - limit: 返回数量限制 - - Returns: - 相关记忆列表 - """ - try: - memory_graph = self.get_memory_graph(group_id) - - # 提取查询中的概念 - query_concepts = await self._extract_concepts_from_text(query) - - related_memories = [] - - for concept in query_concepts: - if concept in memory_graph.G: - # 获取相关概念 - first_layer, second_layer = memory_graph.get_related_concepts(concept, depth=2) - related_memories.extend(first_layer) - related_memories.extend(second_layer) - - # 
去重并限制数量 - unique_memories = list(dict.fromkeys(related_memories)) - return unique_memories[:limit] - - except Exception as e: - logger.error(f"获取相关记忆失败: {e}") - return [] - - async def _extract_concepts_from_text(self, text: str) -> List[str]: - """从文本中提取概念""" - # 简化版本的概念提取,可以后续优化 - import jieba - - # 使用jieba分词提取关键词 - words = jieba.lcut(text) - - # 过滤停用词和短词 - stopwords = {'的', '是', '在', '了', '和', '有', '我', '你', '他', '她', '它', '这', '那', '一个', '不', '没有'} - concepts = [word for word in words if len(word) > 1 and word not in stopwords] - - return concepts[:5] # 返回前5个概念 - - async def get_memory_graph_statistics(self, group_id: str) -> Dict[str, Any]: - """获取记忆图统计信息""" - try: - memory_graph = self.get_memory_graph(group_id) - stats = memory_graph.get_graph_statistics() - - # 添加更多统计信息 - with self.db_manager.get_connection() as conn: - cursor = conn.execute( - 'SELECT COUNT(*) FROM memory_nodes WHERE group_id = ?', - (group_id,) - ) - db_nodes_count = cursor.fetchone()[0] - - cursor = conn.execute( - 'SELECT COUNT(*) FROM memory_edges WHERE group_id = ?', - (group_id,) - ) - db_edges_count = cursor.fetchone()[0] - - stats.update({ - 'db_nodes_count': db_nodes_count, - 'db_edges_count': db_edges_count, - 'group_id': group_id - }) - - return stats - - except Exception as e: - logger.error(f"获取记忆图统计信息失败: {e}") - return {} - - -# 导入asyncio -import asyncio \ No newline at end of file diff --git a/services/performance_optimizer.py b/services/performance_optimizer.py deleted file mode 100644 index b74d30c..0000000 --- a/services/performance_optimizer.py +++ /dev/null @@ -1,511 +0,0 @@ -""" -并行化和异步优化服务 - 应用MaiBot的高性能架构 - -关键技术: -1. asyncio.gather 并行信息收集 (串行8s+ → 并行3.2s) -2. LLM判定缓存 (30秒TTL) -3. 非阻塞异步学习任务 -4. 
上下文哈希缓存 -""" -import asyncio -import hashlib -import time -from typing import Dict, Any, Callable, Optional, List, Tuple -from functools import wraps -from astrbot.api import logger - - -class LLMResultCache: - """ - LLM判定结果缓存 - - MaiBot的关键优化: 30秒TTL缓存避免重复LLM调用 - 缓存命中率可达60%+, 节省大量时间和API调用 - """ - - def __init__(self, ttl: int = 30, max_size: int = 1000): - """ - 初始化LLM缓存 - - Args: - ttl: 缓存有效期(秒), 默认30秒 - max_size: 最大缓存条目数 - """ - self.cache: Dict[str, Tuple[Any, float]] = {} - self.ttl = ttl - self.max_size = max_size - self.hits = 0 - self.misses = 0 - - def _make_key(self, action_name: str, context: str) -> str: - """ - 生成缓存键 - - Args: - action_name: 操作名称 - context: 上下文内容 - - Returns: - 缓存键 - """ - # 使用上下文的MD5哈希作为键的一部分 - context_hash = hashlib.md5(context.encode()).hexdigest()[:8] - return f"{action_name}_{context_hash}" - - async def get_or_compute( - self, - action_name: str, - context: str, - compute_fn: Callable - ) -> Any: - """ - 获取缓存值或计算新值 - - Args: - action_name: 操作名称 - context: 上下文内容 - compute_fn: 计算函数(异步) - - Returns: - 缓存或计算的结果 - """ - key = self._make_key(action_name, context) - - # 检查缓存 - if key in self.cache: - result, timestamp = self.cache[key] - if time.time() - timestamp < self.ttl: - self.hits += 1 - logger.debug(f"缓存命中: {action_name}") - return result - - # 计算新值 - self.misses += 1 - result = await compute_fn() - self.cache[key] = (result, time.time()) - - # 清理过期缓存 - self._cleanup() - - return result - - def get(self, action_name: str, context: str) -> Optional[Any]: - """ - 仅获取缓存值(不计算) - - Args: - action_name: 操作名称 - context: 上下文内容 - - Returns: - 缓存值或None - """ - key = self._make_key(action_name, context) - if key in self.cache: - result, timestamp = self.cache[key] - if time.time() - timestamp < self.ttl: - return result - return None - - def set(self, action_name: str, context: str, value: Any): - """ - 设置缓存值 - - Args: - action_name: 操作名称 - context: 上下文内容 - value: 要缓存的值 - """ - key = self._make_key(action_name, context) - self.cache[key] = 
(value, time.time()) - self._cleanup() - - def _cleanup(self): - """清理过期缓存""" - now = time.time() - expired_keys = [ - k for k, (_, ts) in self.cache.items() - if now - ts > self.ttl - ] - for k in expired_keys: - del self.cache[k] - - # 如果仍然超过最大大小,删除最旧的条目 - if len(self.cache) > self.max_size: - sorted_items = sorted( - self.cache.items(), - key=lambda x: x[1][1] # 按时间戳排序 - ) - # 删除最旧的20% - to_remove = int(len(self.cache) * 0.2) - for key, _ in sorted_items[:to_remove]: - del self.cache[key] - - def get_stats(self) -> Dict[str, Any]: - """获取缓存统计""" - total = self.hits + self.misses - hit_rate = self.hits / total if total > 0 else 0 - return { - 'hits': self.hits, - 'misses': self.misses, - 'total': total, - 'hit_rate': f"{hit_rate:.1%}", - 'cache_size': len(self.cache) - } - - def clear(self): - """清空缓存""" - self.cache.clear() - self.hits = 0 - self.misses = 0 - - -class ParallelTaskExecutor: - """ - 并行任务执行器 - - MaiBot的关键优化: 使用asyncio.gather并行执行多个独立任务 - 总耗时从串行的8秒+降低到并行的3-4秒 - """ - - def __init__(self, timeout: float = 30.0): - """ - 初始化并行执行器 - - Args: - timeout: 单个任务的超时时间(秒) - """ - self.timeout = timeout - - async def execute_parallel( - self, - tasks: Dict[str, Callable], - return_exceptions: bool = True - ) -> Dict[str, Any]: - """ - 并行执行多个任务 - - Args: - tasks: 任务字典 {任务名: 异步函数} - return_exceptions: 是否返回异常而不是抛出 - - Returns: - 结果字典 {任务名: 结果} - """ - start_time = time.time() - - # 创建任务协程列表 - task_names = list(tasks.keys()) - task_coroutines = [ - asyncio.wait_for(task(), timeout=self.timeout) - for task in tasks.values() - ] - - # 并行执行 - results_list = await asyncio.gather( - *task_coroutines, - return_exceptions=return_exceptions - ) - - # 组装结果 - results = {} - for name, result in zip(task_names, results_list): - if isinstance(result, Exception): - logger.warning(f"任务 {name} 执行失败: {result}") - results[name] = None - else: - results[name] = result - - elapsed = time.time() - start_time - logger.debug(f"并行执行 {len(tasks)} 个任务完成, 耗时: {elapsed:.2f}秒") - - return 
results - - async def execute_with_priority( - self, - high_priority_tasks: Dict[str, Callable], - low_priority_tasks: Dict[str, Callable] - ) -> Tuple[Dict[str, Any], asyncio.Task]: - """ - 执行带优先级的任务 - - 高优先级任务立即执行并等待结果 - 低优先级任务在后台执行,不阻塞 - - Args: - high_priority_tasks: 高优先级任务字典 - low_priority_tasks: 低优先级任务字典 - - Returns: - (高优先级结果, 低优先级任务的Task对象) - """ - # 立即执行高优先级任务 - high_results = await self.execute_parallel(high_priority_tasks) - - # 低优先级任务在后台执行 - async def run_low_priority(): - return await self.execute_parallel(low_priority_tasks) - - low_priority_task = asyncio.create_task(run_low_priority()) - - return high_results, low_priority_task - - -class AsyncLearningScheduler: - """ - 异步学习任务调度器 - - MaiBot的关键优化: 学习任务不阻塞主回复流程 - 使用asyncio.create_task在后台执行学习 - """ - - def __init__(self, max_concurrent: int = 5): - """ - 初始化学习调度器 - - Args: - max_concurrent: 最大并发学习任务数 - """ - self.max_concurrent = max_concurrent - self.running_tasks: List[asyncio.Task] = [] - self.pending_tasks: List[Callable] = [] - self._lock = asyncio.Lock() - - async def schedule_learning( - self, - learning_fn: Callable, - task_name: str = "learning" - ) -> Optional[asyncio.Task]: - """ - 调度一个学习任务(非阻塞) - - Args: - learning_fn: 学习函数(异步) - task_name: 任务名称 - - Returns: - 创建的Task对象或None(如果超过并发限制) - """ - async with self._lock: - # 清理已完成的任务 - self.running_tasks = [ - t for t in self.running_tasks - if not t.done() - ] - - # 检查是否可以启动新任务 - if len(self.running_tasks) >= self.max_concurrent: - logger.debug(f"学习任务队列已满,延迟执行: {task_name}") - self.pending_tasks.append(learning_fn) - return None - - # 创建后台任务 - async def wrapped_task(): - try: - await learning_fn() - logger.debug(f"学习任务完成: {task_name}") - except Exception as e: - logger.error(f"学习任务失败 {task_name}: {e}") - finally: - # 尝试执行待处理的任务 - await self._try_execute_pending() - - task = asyncio.create_task(wrapped_task()) - self.running_tasks.append(task) - logger.debug(f"学习任务已调度: {task_name}") - - return task - - async def _try_execute_pending(self): - 
"""尝试执行待处理的任务""" - async with self._lock: - # 清理已完成的任务 - self.running_tasks = [ - t for t in self.running_tasks - if not t.done() - ] - - # 如果有空位且有待处理任务 - while ( - len(self.running_tasks) < self.max_concurrent - and self.pending_tasks - ): - pending_fn = self.pending_tasks.pop(0) - - async def wrapped(): - try: - await pending_fn() - except Exception as e: - logger.error(f"待处理学习任务失败: {e}") - - task = asyncio.create_task(wrapped()) - self.running_tasks.append(task) - - async def wait_all(self, timeout: float = 60.0) -> bool: - """ - 等待所有学习任务完成 - - Args: - timeout: 超时时间(秒) - - Returns: - 是否全部完成 - """ - if not self.running_tasks: - return True - - try: - await asyncio.wait_for( - asyncio.gather(*self.running_tasks, return_exceptions=True), - timeout=timeout - ) - return True - except asyncio.TimeoutError: - logger.warning("等待学习任务超时") - return False - - def get_status(self) -> Dict[str, Any]: - """获取调度器状态""" - return { - 'running_count': len([t for t in self.running_tasks if not t.done()]), - 'pending_count': len(self.pending_tasks), - 'max_concurrent': self.max_concurrent - } - - -class PerformanceOptimizer: - """ - 性能优化器 - 整合所有优化功能 - - 提供: - 1. 并行信息收集 - 2. LLM结果缓存 - 3. 
异步学习调度 - """ - - def __init__(self, cache_ttl: int = 30): - """初始化性能优化器""" - self.cache = LLMResultCache(ttl=cache_ttl) - self.executor = ParallelTaskExecutor() - self.scheduler = AsyncLearningScheduler() - - async def collect_reply_context( - self, - tasks: Dict[str, Callable] - ) -> Dict[str, Any]: - """ - 并行收集回复所需的上下文信息 - - 这是MaiBot高速回复的核心: 将原本串行的8秒+操作 - 通过并行执行降低到3-4秒 - - Args: - tasks: 上下文收集任务字典 - - Returns: - 收集到的上下文信息 - """ - return await self.executor.execute_parallel(tasks) - - async def cached_llm_call( - self, - action: str, - context: str, - llm_fn: Callable - ) -> Any: - """ - 带缓存的LLM调用 - - Args: - action: 操作名称 - context: 上下文(用于生成缓存键) - llm_fn: LLM调用函数 - - Returns: - LLM调用结果 - """ - return await self.cache.get_or_compute(action, context, llm_fn) - - async def schedule_background_learning( - self, - learning_fn: Callable, - name: str = "learning" - ): - """ - 调度后台学习任务(非阻塞) - - Args: - learning_fn: 学习函数 - name: 任务名称 - """ - await self.scheduler.schedule_learning(learning_fn, name) - - def get_performance_stats(self) -> Dict[str, Any]: - """获取性能统计""" - return { - 'cache': self.cache.get_stats(), - 'scheduler': self.scheduler.get_status() - } - - -# 全局性能优化器实例 -_performance_optimizer: Optional[PerformanceOptimizer] = None - - -def get_performance_optimizer() -> PerformanceOptimizer: - """获取全局性能优化器实例""" - global _performance_optimizer - if _performance_optimizer is None: - _performance_optimizer = PerformanceOptimizer() - return _performance_optimizer - - -# 装饰器: 自动缓存LLM调用结果 -def cached_llm_result(action_name: str, context_key: str = None): - """ - 装饰器: 自动缓存LLM调用结果 - - Args: - action_name: 操作名称 - context_key: 用于缓存键的参数名(默认使用第一个参数) - """ - def decorator(fn): - @wraps(fn) - async def wrapper(*args, **kwargs): - optimizer = get_performance_optimizer() - - # 获取上下文 - if context_key and context_key in kwargs: - context = str(kwargs[context_key]) - elif args: - context = str(args[0]) - else: - context = "" - - return await optimizer.cached_llm_call( - action_name, - 
context[:100], # 只使用前100字符 - lambda: fn(*args, **kwargs) - ) - return wrapper - return decorator - - -# 装饰器: 非阻塞后台执行 -def background_task(name: str = "background"): - """ - 装饰器: 将任务转为后台非阻塞执行 - - Args: - name: 任务名称 - """ - def decorator(fn): - @wraps(fn) - async def wrapper(*args, **kwargs): - optimizer = get_performance_optimizer() - await optimizer.schedule_background_learning( - lambda: fn(*args, **kwargs), - name - ) - return wrapper - return decorator diff --git a/services/persona/__init__.py b/services/persona/__init__.py new file mode 100644 index 0000000..4352083 --- /dev/null +++ b/services/persona/__init__.py @@ -0,0 +1,15 @@ +"""Persona management -- create, update, backup, temporary personas.""" + +from .persona_manager import PersonaManagerService +from .persona_manager_updater import PersonaManagerUpdater +from .persona_updater import PersonaUpdater +from .persona_backup_manager import PersonaBackupManager +from .temporary_persona_updater import TemporaryPersonaUpdater + +__all__ = [ + "PersonaManagerService", + "PersonaManagerUpdater", + "PersonaUpdater", + "PersonaBackupManager", + "TemporaryPersonaUpdater", +] diff --git a/services/persona_backup_manager.py b/services/persona/persona_backup_manager.py similarity index 99% rename from services/persona_backup_manager.py rename to services/persona/persona_backup_manager.py index 5e5c4b8..5720348 100644 --- a/services/persona_backup_manager.py +++ b/services/persona/persona_backup_manager.py @@ -9,9 +9,9 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..config import PluginConfig -from ..exceptions import BackupError -from .database_manager import DatabaseManager +from ...config import PluginConfig +from ...exceptions import BackupError +from ..database import DatabaseManager class PersonaBackupManager: diff --git a/services/persona_manager.py b/services/persona/persona_manager.py similarity index 96% rename from services/persona_manager.py rename to 
services/persona/persona_manager.py index 251b7c0..d98e0ec 100644 --- a/services/persona_manager.py +++ b/services/persona/persona_manager.py @@ -2,11 +2,11 @@ from typing import Dict, Any, Optional, List from astrbot.api.star import Context -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.interfaces import IPersonaManager, IPersonaUpdater, IPersonaBackupManager, ServiceLifecycle, MessageData +from ...core.interfaces import IPersonaManager, IPersonaUpdater, IPersonaBackupManager, ServiceLifecycle, MessageData -from ..exceptions import SelfLearningError # 导入 SelfLearningError +from ...exceptions import SelfLearningError # 导入 SelfLearningError class PersonaManagerService(IPersonaManager): """ diff --git a/services/persona_manager_updater.py b/services/persona/persona_manager_updater.py similarity index 98% rename from services/persona_manager_updater.py rename to services/persona/persona_manager_updater.py index 39a805d..4dfc4ff 100644 --- a/services/persona_manager_updater.py +++ b/services/persona/persona_manager_updater.py @@ -10,9 +10,9 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..core.interfaces import IPersonaManagerUpdater -from ..config import PluginConfig -from ..exceptions import SelfLearningError +from ...core.interfaces import IPersonaManagerUpdater +from ...config import PluginConfig +from ...exceptions import SelfLearningError class PersonaManagerUpdater(IPersonaManagerUpdater): @@ -154,7 +154,7 @@ async def get_or_create_group_persona(self, group_id: str, base_persona_id: str if existing_persona: logger.info(f"使用现有群组persona: {persona_id}") return persona_id - except: + except Exception: # persona不存在,清理映射 del self.group_persona_mapping[group_id] @@ -170,7 +170,7 @@ async def get_or_create_group_persona(self, group_id: str, base_persona_id: str # 获取基础persona try: base_persona = await self.persona_manager.get_persona(base_persona_id) - except: + except Exception: # 如果指定的基础persona不存在,使用默认 
base_persona = await self.persona_manager.get_default_persona_v3(self._resolve_umo(group_id)) diff --git a/services/persona_updater.py b/services/persona/persona_updater.py similarity index 95% rename from services/persona_updater.py rename to services/persona/persona_updater.py index e01bfdb..431ebef 100644 --- a/services/persona_updater.py +++ b/services/persona/persona_updater.py @@ -9,18 +9,18 @@ from astrbot.api.star import Context from astrbot.core.db.po import Personality -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.interfaces import IPersonaUpdater, IPersonaBackupManager, MessageData, AnalysisResult, PersonaUpdateRecord # 导入 PersonaUpdateRecord +from ...core.interfaces import IPersonaUpdater, IPersonaBackupManager, MessageData, AnalysisResult, PersonaUpdateRecord # 导入 PersonaUpdateRecord from .persona_manager_updater import PersonaManagerUpdater -from ..exceptions import PersonaUpdateError, SelfLearningError # 导入 PersonaUpdateError -from .database_manager import DatabaseManager # 导入 DatabaseManager +from ...exceptions import PersonaUpdateError, SelfLearningError # 导入 PersonaUpdateError +from ..database import DatabaseManager # 导入 DatabaseManager # MaiBot功能模块导入 - 结合MaiBot的学习功能 -from .expression_pattern_learner import ExpressionPatternLearner -from .memory_graph_manager import MemoryGraphManager -from .knowledge_graph_manager import KnowledgeGraphManager +from ..analysis import ExpressionPatternLearner +from ..state.enhanced_memory_graph_manager import MemoryGraphManager +from ..integration import KnowledgeGraphManager class PersonaUpdater(IPersonaUpdater): @@ -42,7 +42,7 @@ def __init__(self, config: PluginConfig, context: Context, backup_manager: IPers # 初始化MaiBot组件 - 结合MaiBot功能 # 创建FrameworkLLMAdapter for expression learner - from ..core.framework_llm_adapter import FrameworkLLMAdapter + from ...core.framework_llm_adapter import FrameworkLLMAdapter expression_llm_adapter = FrameworkLLMAdapter(context) 
expression_llm_adapter.initialize_providers(config) @@ -80,7 +80,7 @@ async def update_persona_with_style(self, group_id: str, style_analysis: Dict[st persona_name = current_persona.get('name', 'unknown') if isinstance(current_persona, dict) else current_persona['name'] self._logger.info(f"当前人格: {persona_name} for group {group_id}") - # ===== 创建备份(如果启用) ===== + # 创建备份(如果启用) backup_id = None if self.config.persona_update_backup_enabled: try: @@ -93,7 +93,7 @@ async def update_persona_with_style(self, group_id: str, style_analysis: Dict[st self._logger.error(f"创建备份失败: {backup_error}") # 不阻止更新继续进行 - # ===== 保存更新前的人格状态用于对比 ===== + # 保存更新前的人格状态用于对比 def clone_persona_data(persona_data: Any) -> Dict[str, Any]: """临时克隆人格数据用于对比""" try: @@ -153,7 +153,7 @@ def clone_persona_data(persona_data: Any) -> Dict[str, Any]: if 'style_attributes' in style_analysis: # 从 style_analysis 中获取 style_attributes await self._apply_style_attributes(current_persona, style_analysis['style_attributes']) - # ===== 生成并输出格式化的更新报告 ===== + # 生成并输出格式化的更新报告 after_persona = clone_persona_data(current_persona) update_details = { 'new_features_count': len(style_analysis.get('style_features', [])), @@ -304,13 +304,13 @@ async def _create_approved_persona_backup(self, update_id: int, modified_content if not approved_prompt: self._logger.error(f"✗ 更新记录 {update_id} 中没有新内容(new_content),且未提供modified_content") - self._logger.error(f" update_record keys: {list(update_record.keys())}") + self._logger.error(f" update_record keys: {list(update_record.keys())}") return False self._logger.info(f"开始创建批准更新人格: {approved_persona_id}") - self._logger.info(f" 原人格prompt长度: {len(original_prompt)} 字符") - self._logger.info(f" 新人格prompt长度: {len(approved_prompt)} 字符") - self._logger.debug(f" 新人格prompt前100字: {approved_prompt[:100]}...") + self._logger.info(f" 原人格prompt长度: {len(original_prompt)} 字符") + self._logger.info(f" 新人格prompt长度: {len(approved_prompt)} 字符") + self._logger.debug(f" 新人格prompt前100字: {approved_prompt[:100]}...") 
self._logger.info(f"调用 PersonaManager.create_persona()...") approved_persona = await persona_manager.create_persona( @@ -335,17 +335,17 @@ async def _create_approved_persona_backup(self, update_id: int, modified_content return True else: self._logger.error(f"✗ 验证失败: 批准更新人格创建后无法找到") - self._logger.error(f" 尝试列出所有人格...") + self._logger.error(f" 尝试列出所有人格...") try: all_personas = await persona_manager.get_all_personas() - self._logger.error(f" 当前所有人格: {[p.get('name', 'unknown') for p in all_personas] if all_personas else '无法获取'}") + self._logger.error(f" 当前所有人格: {[p.get('name', 'unknown') for p in all_personas] if all_personas else '无法获取'}") except Exception as list_error: - self._logger.error(f" 列出人格失败: {list_error}") + self._logger.error(f" 列出人格失败: {list_error}") return False else: self._logger.error(f"✗ 创建批准更新人格失败: {approved_persona_id}") - self._logger.error(f" PersonaManager.create_persona() 返回了 None 或 False") - self._logger.error(f" 参数检查: persona_id='{approved_persona_id}', system_prompt长度={len(approved_prompt)}") + self._logger.error(f" PersonaManager.create_persona() 返回了 None 或 False") + self._logger.error(f" 参数检查: persona_id='{approved_persona_id}', system_prompt长度={len(approved_prompt)}") return False else: self._logger.error("PersonaManager不可用,无法创建备份") @@ -370,7 +370,7 @@ async def get_reviewed_persona_updates(self, limit: int = 50, offset: int = 0, s 'original_content': record.get('original_content', ''), 'proposed_content': record.get('new_content', ''), 'reason': record.get('reason', '传统人格更新'), - 'confidence_score': 0.9, # 传统更新默认较高置信度 + 'confidence_score': 0.9, # 传统更新默认较高置信度 'status': record.get('status'), 'reviewer_comment': record.get('reviewer_comment'), 'review_time': record.get('review_time'), @@ -466,7 +466,7 @@ def _merge_prompts(self, original: str, enhancement: str) -> str: return f"{original}\n\n{enhancement}" elif self.config.persona_merge_strategy == "prepend": return f"{enhancement}\n\n{original}" - else: # smart merge + else: # smart merge 
return self._smart_merge_prompts(original, enhancement) def _smart_merge_prompts(self, original: str, enhancement: str) -> str: @@ -483,9 +483,9 @@ def _smart_merge_prompts(self, original: str, enhancement: str) -> str: overlap_ratio = len(words_original.intersection(words_enhancement)) / max(len(words_original), 1) - if overlap_ratio > 0.7: # 高重叠,选择较长的 + if overlap_ratio > 0.7: # 高重叠,选择较长的 return enhancement if len(enhancement) > len(original) else original - else: # 低重叠,合并 + else: # 低重叠,合并 return f"{original}\n\n补充风格特征:{enhancement}" async def _update_mood_imitation_dialogs(self, persona: Personality, filtered_messages: List[Dict[str, Any]]): @@ -495,7 +495,7 @@ async def _update_mood_imitation_dialogs(self, persona: Personality, filtered_me # 从过滤后的消息中提取高质量对话特征(不是原始对话) new_features = [] - for msg in filtered_messages[-10:]: # 取最近10条 + for msg in filtered_messages[-10:]: # 取最近10条 message_text = msg.get('message', '').strip() if message_text and len(message_text) > self.config.message_min_length: if self._is_authentic_message(message_text) and message_text not in current_dialogs: @@ -534,7 +534,7 @@ def _is_authentic_message(self, text: str) -> bool: r'.*:\s*你最近.*', r'开场对话列表', r'情绪模拟对话列表', - r'风格特征:.*', # 避免重复嵌套 + r'风格特征:.*', # 避免重复嵌套 ] import re @@ -661,7 +661,7 @@ async def analyze_persona_compatibility(self, target_style: Dict[str, Any]) -> A target_attributes = target_style.get('style_attributes', {}) # 简单的兼容性评分 - compatibility_score = 0.8 # 基础分数 + compatibility_score = 0.8 # 基础分数 # 检查风格冲突 conflicts = [] @@ -787,7 +787,7 @@ async def format_persona_update_report(self, group_id: str, before_persona: Dict 格式化的人格更新报告 """ try: - from ..statics.messages import CommandMessages + from ...statics.messages import CommandMessages # 生成变化摘要 change_summary = await self._generate_change_summary(before_persona, after_persona, update_details) @@ -809,7 +809,7 @@ async def format_persona_update_report(self, group_id: str, before_persona: Dict except Exception as e: 
self._logger.error(f"格式化人格更新报告失败: {e}") - from ..statics.messages import CommandMessages + from ...statics.messages import CommandMessages return CommandMessages.PERSONA_UPDATE_FAILED.format(error=str(e)) def _format_persona_content(self, persona_data: Dict[str, Any]) -> str: @@ -836,7 +836,7 @@ async def _generate_change_summary(self, before_persona: Dict[str, Any], update_details: Dict[str, Any]) -> str: """生成变化摘要""" try: - from ..statics.messages import CommandMessages + from ...statics.messages import CommandMessages # 计算prompt长度变化 before_prompt = self._get_persona_prompt(before_persona) @@ -942,7 +942,7 @@ async def stop(self): self._logger.error(f"停止人格更新服务失败: {e}") return False - # ===== 人格格式化输出功能 ===== + # 人格格式化输出功能 async def format_current_persona_display(self, group_id: str) -> str: """ @@ -955,12 +955,12 @@ async def format_current_persona_display(self, group_id: str) -> str: 格式化的当前人格信息 """ try: - from ..statics.messages import CommandMessages + from ...statics.messages import CommandMessages # 获取当前人格信息 current_persona = await self.get_current_persona(group_id) if not current_persona: - return "❌ 无法获取当前人格信息" + return " 无法获取当前人格信息" # 获取人格统计信息 stats = await self._get_persona_statistics(group_id) @@ -1001,7 +1001,7 @@ async def format_current_persona_display(self, group_id: str) -> str: except Exception as e: self._logger.error(f"格式化当前人格显示失败: {e}") - return f"❌ 获取人格信息失败: {str(e)}" + return f" 获取人格信息失败: {str(e)}" def _get_persona_name(self, persona_data: Any) -> str: """获取人格名称""" @@ -1073,7 +1073,7 @@ async def _get_learned_style_features(self, group_id: str) -> str: features.append(line) if features: - return '\n'.join(features[-10:]) # 显示最近10个特征 + return '\n'.join(features[-10:]) # 显示最近10个特征 return "暂无学习到的风格特征" @@ -1081,7 +1081,7 @@ async def _get_learned_style_features(self, group_id: str) -> str: self._logger.error(f"获取学习到的风格特征失败: {e}") return "获取风格特征失败" - # ===== 辅助方法 ===== + # 辅助方法 async def _clone_persona_data(self, persona_data: Any) -> Dict[str, 
Any]: """克隆人格数据用于对比""" diff --git a/services/temporary_persona_updater.py b/services/persona/temporary_persona_updater.py similarity index 97% rename from services/temporary_persona_updater.py rename to services/persona/temporary_persona_updater.py index be085c5..6c7b792 100644 --- a/services/temporary_persona_updater.py +++ b/services/persona/temporary_persona_updater.py @@ -11,18 +11,18 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.interfaces import IPersonaUpdater, IPersonaBackupManager +from ...core.interfaces import IPersonaUpdater, IPersonaBackupManager -from ..services.database_manager import DatabaseManager -from ..services.persona_manager_updater import PersonaManagerUpdater +from ..database import DatabaseManager +from .persona_manager_updater import PersonaManagerUpdater -from ..statics.temp_persona_messages import TemporaryPersonaMessages +from ...statics.temp_persona_messages import TemporaryPersonaMessages -from ..statics.prompts import MULTIDIMENSIONAL_ANALYZER_FILTER_MESSAGE_PROMPT +from ...statics.prompts import MULTIDIMENSIONAL_ANALYZER_FILTER_MESSAGE_PROMPT -from ..exceptions import SelfLearningError +from ...exceptions import SelfLearningError class TemporaryPersonaUpdater: @@ -49,8 +49,8 @@ def __init__(self, self.db_manager = db_manager # 临时人格存储 - self.active_temp_personas: Dict[str, Dict] = {} # group_id -> temp_persona_info - self.expiry_tasks: Dict[str, asyncio.Task] = {} # group_id -> expiry_task + self.active_temp_personas: Dict[str, Dict] = {} # group_id -> temp_persona_info + self.expiry_tasks: Dict[str, asyncio.Task] = {} # group_id -> expiry_task # 备份目录设置 self.backup_base_dir = os.path.join(config.data_dir, "persona_backups") @@ -262,7 +262,7 @@ async def _create_enhanced_persona(self, # 添加对话示例 if example_dialogs: - dialog_examples = ['\n\".join([f\"- {dialog}' for dialog in example_dialogs[:5]] # 限制数量 + dialog_examples = 
['\n\".join([f\"- {dialog}' for dialog in example_dialogs[:5]] # 限制数量 enhanced_prompt += f'{dialog_examples}' enhanced_persona.update({ @@ -270,7 +270,7 @@ async def _create_enhanced_persona(self, 'prompt': enhanced_prompt, 'mood_imitation_dialogs': ( original_persona.get('mood_imitation_dialogs', []) + example_dialogs - )[-20:], # 保留最新20条 + )[-20:], # 保留最新20条 'temp_features': new_features, 'temp_created_at': datetime.now().isoformat() }) @@ -292,7 +292,7 @@ async def _apply_persona_to_system(self, group_id: str, persona: Dict[str, Any]) """ 将人格应用到系统中 - 使用会话级存储而不是修改全局provider - ✅ 修复: 不再修改全局provider.curr_personality,避免会话串流 + 修复: 不再修改全局provider.curr_personality,避免会话串流 改为存储到self.session_updates[group_id]中,由LLM Hook注入 """ try: @@ -311,7 +311,7 @@ async def _apply_persona_to_system(self, group_id: str, persona: Dict[str, Any]) incremental_update = enhanced_prompt[update_start:] logger.info(f"提取到增量更新内容: {incremental_update[:100]}...") - # ✅ 存储到会话级映射,不修改全局provider + # 存储到会话级映射,不修改全局provider if group_id not in self.session_updates: self.session_updates[group_id] = [] @@ -608,7 +608,7 @@ async def _apply_incremental_updates(self, current_persona: Dict[str, Any], upda current_prompt = updated_persona.get('prompt', '') # 去除重复的更新内容 - unique_updates = list(dict.fromkeys(updates)) # 保持顺序的去重 + unique_updates = list(dict.fromkeys(updates)) # 保持顺序的去重 logger.info(f"原始更新数量: {len(updates)}, 去重后: {len(unique_updates)}") # 构建增量更新文本 @@ -868,7 +868,7 @@ async def apply_expression_style_learning(self, group_id: str, expression_patter # 构建表达风格描述 style_descriptions = [] - for pattern in expression_patterns[:5]: # 只取前5个最重要的 + for pattern in expression_patterns[:5]: # 只取前5个最重要的 situation = pattern.get('situation', '').strip() expression = pattern.get('expression', '').strip() weight = pattern.get('weight', 1.0) @@ -1275,12 +1275,12 @@ async def _validate_dialog_authenticity(self, dialogs: List[str]) -> List[str]: # 定义虚假对话的特征模式 fake_patterns = [ - r'A:\s*你最近干.*呢.*\?', # "A: 你最近干啥呢?"模式 - 
r'B:\s*', # "B: "开头的模式 - r'用户\d+:\s*', # "用户01: "模式 - r'.*:\s*你最近.*', # 任何包含"你最近"的对话格式 - r'开场对话列表', # 示例文本 - r'情绪模拟对话列表', # 示例文本 + r'A:\s*你最近干.*呢.*\?', # "A: 你最近干啥呢?"模式 + r'B:\s*', # "B: "开头的模式 + r'用户\d+:\s*', # "用户01: "模式 + r'.*:\s*你最近.*', # 任何包含"你最近"的对话格式 + r'开场对话列表', # 示例文本 + r'情绪模拟对话列表', # 示例文本 ] import re @@ -1292,7 +1292,7 @@ async def _validate_dialog_authenticity(self, dialogs: List[str]) -> List[str]: is_fake = True break - if not is_fake and len(dialog.strip()) > 3: # 只保留有效的真实对话 + if not is_fake and len(dialog.strip()) > 3: # 只保留有效的真实对话 validated_dialogs.append(dialog) logger.info(f"对话验证完成: 原始{len(dialogs)}条,验证后{len(validated_dialogs)}条") diff --git a/services/persona_optimization.py b/services/persona_optimization.py deleted file mode 100644 index 5e3fcc3..0000000 --- a/services/persona_optimization.py +++ /dev/null @@ -1,397 +0,0 @@ -""" -人格优化服务 - 应用MaiBot的极简人格和反完美主义技术 - -关键技术: -1. 极简人格核心 - 将人格描述简化为1句话核心 -2. 动态表达规则 - 每次随机选择8条学习到的表达习惯 -3. 反完美主义提示 - 添加让回复更自然的指令 -4. 情绪随机切换 - 20%概率切换情绪状态 -""" -import random -from typing import List, Dict, Any, Optional, Tuple -from datetime import datetime -from astrbot.api import logger - - -class PersonaPromptOptimizer: - """ - 人格Prompt优化器 - 应用MaiBot的五层人格架构 - - 五层架构: - 1. 极简核心 (1句话) - 2. 动态表达规则 (8条随机选择) - 3. 实时情绪状态 - 4. 行为规划思考 (可选) - 5. 反完美主义指令 (关键!) 
- """ - - # 反完美主义提示模板 - ANTI_PERFECTIONISM_PROMPTS = [ - "回复时请注意: 平淡一些,不要太有条理,可以有个性。像真人一样自然对话,不必每次都很完美。", - "说话自然一点,不需要面面俱到,有时候简短随意的回复更真实。", - "回复时可以有自己的小情绪和个性,不必总是很专业很完美。", - "不用每次都回复得很周全,有时候随口说说也挺好的。", - ] - - # 默认情绪状态列表 - DEFAULT_MOODS = [ - '平静', '开心', '好奇', '慵懒', '兴奋', - '困倦', '活泼', '沉思', '放松', '调皮' - ] - - def __init__(self, db_manager=None): - """ - 初始化人格优化器 - - Args: - db_manager: 数据库管理器 (用于获取学习到的表达规则) - """ - self.db = db_manager - self.current_mood = random.choice(self.DEFAULT_MOODS) - self.last_mood_change = datetime.now() - - async def build_optimized_persona_prompt( - self, - base_persona_core: str, - group_id: Optional[str] = None, - include_mood: bool = True, - include_anti_perfectionism: bool = True, - expression_rules_count: int = 8 - ) -> str: - """ - 构建优化后的人格Prompt - - Args: - base_persona_core: 基础人格核心描述 (应该是1句话的简短描述) - group_id: 群组ID (用于获取群组特定的表达规则) - include_mood: 是否包含情绪状态 - include_anti_perfectionism: 是否包含反完美主义提示 - expression_rules_count: 要包含的表达规则数量 - - Returns: - 优化后的完整人格Prompt - """ - prompt_parts = [] - - # 第1层: 极简人格核心 - core = self._simplify_persona_core(base_persona_core) - prompt_parts.append(f"你是{core}") - - # 第2层: 动态表达规则 - if self.db and group_id: - expressions = await self._get_random_expression_rules( - group_id, expression_rules_count - ) - if expressions: - prompt_parts.append("\n你学到的表达习惯:") - for expr in expressions: - prompt_parts.append(f"- {expr}") - - # 第3层: 实时情绪状态 - if include_mood: - # 20%概率切换情绪 - self._maybe_switch_mood() - prompt_parts.append(f"\n当前情绪状态: {self.current_mood}") - - # 第4层: 行为规划 (可选,根据需要添加) - # 这一层通常在具体对话时动态生成 - - # 第5层: 反完美主义指令 (关键!) 
- if include_anti_perfectionism: - anti_perfect = random.choice(self.ANTI_PERFECTIONISM_PROMPTS) - prompt_parts.append(f"\n{anti_perfect}") - - return "\n".join(prompt_parts) - - def _simplify_persona_core(self, persona_description: str) -> str: - """ - 简化人格描述为1句话核心 - - MaiBot的关键洞察: 过度详细的人格描述会约束太多,缺乏灵活性 - 极简核心反而能让LLM发挥更自然 - - Args: - persona_description: 原始人格描述 - - Returns: - 简化后的1句话核心 - """ - if not persona_description: - return "友好的AI助手" - - # 如果已经很短,直接返回 - if len(persona_description) <= 50: - return persona_description - - # 尝试提取第一句话作为核心 - sentences = persona_description.replace('\n', '。').split('。') - if sentences: - first_sentence = sentences[0].strip() - if first_sentence and len(first_sentence) >= 5: - return first_sentence - - # 如果无法提取,截取前50个字符 - return persona_description[:50] + "..." - - async def _get_random_expression_rules( - self, - group_id: str, - count: int = 8 - ) -> List[str]: - """ - 获取随机的表达规则 - - MaiBot的关键洞察: 每次随机选择不同的表达规则,保持新鲜感 - - Args: - group_id: 群组ID - count: 要获取的规则数量 - - Returns: - 表达规则列表 - """ - try: - if not self.db: - return self._get_default_expression_rules(count) - - # 从数据库获取学习到的表达规则 - # 这里需要调用表达学习服务的方法 - # 暂时使用默认规则 - all_rules = await self._fetch_learned_expressions(group_id) - - if not all_rules: - return self._get_default_expression_rules(count) - - # 随机选择指定数量的规则 - if len(all_rules) <= count: - return all_rules - - return random.sample(all_rules, count) - - except Exception as e: - logger.error(f"获取表达规则失败: {e}") - return self._get_default_expression_rules(count) - - async def _fetch_learned_expressions(self, group_id: str) -> List[str]: - """ - 从数据库获取学习到的表达规则 - - Args: - group_id: 群组ID - - Returns: - 表达规则列表 - """ - # TODO: 集成表达学习模块后,从数据库读取 - # 暂时返回空列表,使用默认规则 - return [] - - def _get_default_expression_rules(self, count: int = 8) -> List[str]: - """ - 获取默认的表达规则 - - Args: - count: 规则数量 - - Returns: - 默认表达规则列表 - """ - default_rules = [ - "可以使用口语化的表达方式", - "适当使用语气词让对话更自然", - "回复不必太长,简洁有力也很好", - "可以表达自己的看法和小情绪", - "不必每次都正式严肃", - 
"有时候可以用问句来互动", - "可以适当使用网络用语", - "回复时可以有自己的风格", - "不用总是解释得很详细", - "偶尔可以调皮一下", - ] - return random.sample(default_rules, min(count, len(default_rules))) - - def _maybe_switch_mood(self, probability: float = 0.2): - """ - 概率性切换情绪状态 - - MaiBot的关键洞察: 20%概率随机切换情绪,保持对话的自然变化 - - Args: - probability: 切换概率 (默认20%) - """ - if random.random() < probability: - old_mood = self.current_mood - self.current_mood = random.choice(self.DEFAULT_MOODS) - if self.current_mood != old_mood: - self.last_mood_change = datetime.now() - logger.debug(f"情绪切换: {old_mood} -> {self.current_mood}") - - def get_current_mood(self) -> str: - """获取当前情绪状态""" - return self.current_mood - - def set_mood(self, mood: str): - """手动设置情绪状态""" - self.current_mood = mood - self.last_mood_change = datetime.now() - - @staticmethod - def enhance_reply_with_naturalness(reply: str) -> str: - """ - 增强回复的自然感 - - 应用反完美主义原则,让回复更像真人 - - Args: - reply: 原始回复 - - Returns: - 增强后的回复 - """ - # 如果回复太长太完美,适当简化 - if len(reply) > 300: - # 考虑只保留前几句话 - sentences = reply.split('。') - if len(sentences) > 5: - # 保留前3-4句,然后随机决定是否保留更多 - keep_count = random.randint(3, 5) - reply = '。'.join(sentences[:keep_count]) + '。' - - # 随机决定是否去掉结尾的客套话 - politeness_endings = [ - '如果你还有什么问题', - '希望这能帮到你', - '如果需要更多帮助', - '欢迎随时问我', - ] - for ending in politeness_endings: - if ending in reply and random.random() < 0.5: - reply = reply.split(ending)[0].strip() - - return reply - - -class PersonaOptimizationService: - """ - 人格优化服务 - 整合所有人格优化功能 - - 提供: - 1. 优化人格Prompt构建 - 2. 回复自然感增强 - 3. 情绪状态管理 - 4. 
提示词保护 (元指令包装 + 后处理过滤 + 双重检查) - """ - - def __init__(self, db_manager=None, enable_prompt_protection: bool = True): - """ - 初始化人格优化服务 - - Args: - db_manager: 数据库管理器 - enable_prompt_protection: 是否启用提示词保护 - """ - self.optimizer = PersonaPromptOptimizer(db_manager) - self.enable_prompt_protection = enable_prompt_protection - self._protection_service = None - - def _get_protection_service(self): - """延迟加载提示词保护服务""" - if self._protection_service is None and self.enable_prompt_protection: - from .prompt_sanitizer import PromptProtectionService - self._protection_service = PromptProtectionService() - return self._protection_service - - async def get_optimized_persona( - self, - base_persona: str, - group_id: str = None - ) -> str: - """ - 获取优化后的人格Prompt - - Args: - base_persona: 基础人格描述 - group_id: 群组ID - - Returns: - 优化后的人格Prompt - """ - return await self.optimizer.build_optimized_persona_prompt( - base_persona_core=base_persona, - group_id=group_id, - include_mood=True, - include_anti_perfectionism=True - ) - - def enhance_reply(self, reply: str) -> str: - """ - 增强回复的自然感 - - Args: - reply: 原始回复 - - Returns: - 增强后的回复 - """ - return PersonaPromptOptimizer.enhance_reply_with_naturalness(reply) - - def get_current_mood(self) -> str: - """获取当前情绪""" - return self.optimizer.get_current_mood() - - def wrap_diversity_prompts(self, prompts: List[str]) -> str: - """ - 使用元指令包装多样性提示词 - - Args: - prompts: 多样性提示词列表 - - Returns: - 包装后的提示词 - """ - protection = self._get_protection_service() - if protection: - return protection.wrap_prompts(prompts) - return "\n".join(prompts) - - def sanitize_response(self, response: str) -> Tuple[str, Dict[str, Any]]: - """ - 消毒LLM回复 - 移除泄露的提示词 - - Args: - response: LLM原始回复 - - Returns: - (消毒后的回复, 处理报告) - """ - protection = self._get_protection_service() - if protection: - return protection.sanitize_response(response) - return response, {'sanitized': False} - - def process_with_protection( - self, - diversity_prompts: List[str], - llm_response: str - ) 
-> Tuple[str, str, Dict[str, Any]]: - """ - 完整的保护流程处理 - - Args: - diversity_prompts: 多样性注入提示词 - llm_response: LLM回复 - - Returns: - (包装后的提示词, 消毒后的回复, 处理报告) - """ - protection = self._get_protection_service() - if protection: - return protection.process_llm_interaction(diversity_prompts, llm_response) - return "\n".join(diversity_prompts), llm_response, {'protected': False} - - def get_protection_stats(self) -> Optional[Dict[str, Any]]: - """获取提示词保护统计信息""" - protection = self._get_protection_service() - if protection: - return protection.get_stats() - return None diff --git a/services/psychological_social_context_injector.py b/services/psychological_social_context_injector.py deleted file mode 100644 index fa67709..0000000 --- a/services/psychological_social_context_injector.py +++ /dev/null @@ -1,736 +0,0 @@ -""" -心理状态与社交关系上下文注入器 -将bot的心理状态和用户的社交关系信息整合注入到LLM prompt中 -支持提示词保护,避免注入内容泄露 -""" -import asyncio -from typing import Dict, Any, List, Optional, Tuple - -from astrbot.api import logger - - -class PsychologicalSocialContextInjector: - """ - 心理状态与社交关系上下文注入器 - - 核心功能: - 1. 整合心理状态管理器和社交关系管理器的数据 - 2. 生成结构化的上下文注入内容 - 3. 应用提示词保护机制 - 4. 使用统一缓存管理器优化性能 - 5. 
生成指导bot行为模式的详细提示词 - """ - - def __init__( - self, - database_manager, - psychological_state_manager=None, - social_relation_manager=None, - affection_manager=None, - diversity_manager=None, - llm_adapter=None, - config=None - ): - self.db_manager = database_manager - self.psych_manager = psychological_state_manager - self.social_manager = social_relation_manager - self.affection_manager = affection_manager - self.diversity_manager = diversity_manager - self.llm_adapter = llm_adapter - self.config = config - - # 提示词保护服务(延迟加载) - self._prompt_protection = None - self._enable_protection = True - - # 使用统一缓存管理器 - from ..utils.cache_manager import get_cache_manager - self._cache_manager = get_cache_manager() - - # 为心理社交上下文创建专用缓存(如果不存在) - if not hasattr(self._cache_manager, 'psych_social_cache'): - from cachetools import TTLCache - self._cache_manager.psych_social_cache = TTLCache(maxsize=1000, ttl=300) # 5分钟TTL - # 注册到缓存管理器的映射表 - if hasattr(self._cache_manager, '_get_cache'): - # 动态添加到cache_map - logger.info("✅ [心理社交上下文] 已创建专用缓存 (maxsize=1000, ttl=300s)") - - # 后台任务管理 - 用于异步更新缓存 - self._background_tasks: set = set() - self._llm_generation_lock: Dict[str, asyncio.Lock] = {} # 防止重复LLM调用 - - def _get_prompt_protection(self): - """延迟加载提示词保护服务""" - if self._prompt_protection is None and self._enable_protection: - try: - from .prompt_sanitizer import PromptProtectionService - self._prompt_protection = PromptProtectionService(wrapper_template_index=2) - logger.info("心理社交上下文注入器: 提示词保护服务已加载") - except Exception as e: - logger.warning(f"加载提示词保护服务失败: {e}") - self._enable_protection = False - return self._prompt_protection - - def _get_from_cache(self, key: str) -> Optional[Any]: - """ - 从统一缓存管理器获取数据 - - Args: - key: 缓存键 - - Returns: - 缓存值或None - """ - return self._cache_manager.psych_social_cache.get(key) - - def _set_to_cache(self, key: str, data: Any): - """设置缓存到统一缓存管理器""" - self._cache_manager.psych_social_cache[key] = data - - async def build_complete_context( - self, - 
group_id: str, - user_id: str, - include_psychological: bool = True, - include_social_relation: bool = True, - include_affection: bool = True, - include_diversity: bool = True, - enable_protection: bool = True - ) -> str: - """ - 构建完整的上下文注入内容 - - Args: - group_id: 群组ID - user_id: 用户ID - include_psychological: 是否包含心理状态 - include_social_relation: 是否包含社交关系 - include_affection: 是否包含好感度 - include_diversity: 是否包含多样性指导 - enable_protection: 是否启用提示词保护 - - Returns: - 完整的上下文注入字符串 - """ - try: - context_parts = [] - - # 1. Bot的心理状态 - if include_psychological and self.psych_manager: - psych_context = await self._build_psychological_context(group_id) - if psych_context: - context_parts.append(psych_context) - logger.debug(f"✅ [心理社交上下文] 已准备心理状态 (群组: {group_id})") - - # 2. 用户的社交关系 - if include_social_relation and self.social_manager: - social_context = await self._build_social_relation_context( - user_id, group_id - ) - if social_context: - context_parts.append(social_context) - logger.debug(f"✅ [心理社交上下文] 已准备社交关系 (用户: {user_id[:8]}...)") - - # 3. 好感度信息 - if include_affection and self.affection_manager: - affection_context = await self._build_affection_context( - user_id, group_id - ) - if affection_context: - context_parts.append(affection_context) - logger.debug(f"✅ [心理社交上下文] 已准备好感度信息") - - # 4. 行为模式指导(基于心理状态和社交关系联动) - if include_psychological or include_social_relation: - behavior_guidance = await self._build_behavior_guidance( - group_id, user_id - ) - if behavior_guidance: - context_parts.append(behavior_guidance) - logger.debug(f"✅ [心理社交上下文] 已准备行为模式指导") - - # 5. 
多样性指导(可选) - if include_diversity and self.diversity_manager: - diversity_context = await self._build_diversity_context(group_id) - if diversity_context: - context_parts.append(diversity_context) - logger.debug(f"✅ [心理社交上下文] 已准备多样性指导") - - if not context_parts: - return "" - - # 组合所有上下文 - raw_context = "\n\n".join(context_parts) - - # 应用提示词保护 - if enable_protection and self._enable_protection: - protection = self._get_prompt_protection() - if protection: - protected_context = protection.wrap_prompt(raw_context, register_for_filter=True) - logger.info( - f"✅ [心理社交上下文] 已保护包装 - " - f"原长度: {len(raw_context)}, 新长度: {len(protected_context)}" - ) - return protected_context - else: - logger.warning("⚠️ [心理社交上下文] 提示词保护服务不可用,使用原始文本") - - return raw_context - - except Exception as e: - logger.error(f"构建完整上下文失败: {e}", exc_info=True) - return "" - - async def _build_psychological_context(self, group_id: str) -> str: - """构建心理状态上下文""" - try: - cache_key = f"psych_context_{group_id}" - cached = self._get_from_cache(cache_key) - if cached: - return cached - - # 从心理状态管理器获取当前状态 - state_prompt = await self.psych_manager.get_state_prompt_injection(group_id) - - if state_prompt: - self._set_to_cache(cache_key, state_prompt) - return state_prompt - - return "" - - except Exception as e: - logger.error(f"构建心理状态上下文失败: {e}", exc_info=True) - return "" - - async def _build_social_relation_context( - self, - user_id: str, - group_id: str - ) -> str: - """构建社交关系上下文""" - try: - cache_key = f"social_context_{user_id}_{group_id}" - cached = self._get_from_cache(cache_key) - if cached: - return cached - - # 从社交关系管理器获取关系描述 - relation_prompt = await self.social_manager.get_relation_prompt_injection( - user_id, "bot", group_id - ) - - if relation_prompt: - self._set_to_cache(cache_key, relation_prompt) - return relation_prompt - - return "" - - except Exception as e: - logger.error(f"构建社交关系上下文失败: {e}", exc_info=True) - return "" - - async def _build_affection_context( - self, - user_id: str, - 
group_id: str - ) -> str: - """构建好感度上下文""" - try: - cache_key = f"affection_context_{user_id}_{group_id}" - cached = self._get_from_cache(cache_key) - if cached: - return cached - - # 从好感度管理器获取信息 - affection_data = await self.db_manager.get_user_affection(group_id, user_id) - - if not affection_data: - return "" - - level = affection_data.get('affection_level', 0) - max_level = affection_data.get('max_affection', 100) - - # 生成描述 - if level >= 80: - desc = "非常喜欢这个用户,关系非常亲密" - elif level >= 60: - desc = "比较喜欢这个用户,关系较好" - elif level >= 40: - desc = "对这个用户有一定好感" - elif level >= 20: - desc = "对这个用户略有好感" - elif level >= 0: - desc = "与这个用户初次见面,关系一般" - elif level >= -20: - desc = "对这个用户略有反感" - elif level >= -40: - desc = "比较不喜欢这个用户" - else: - desc = "非常讨厌这个用户" - - context = f"【对该用户的好感度】\n好感度: {level}/{max_level} ({desc})" - - self._set_to_cache(cache_key, context) - return context - - except Exception as e: - logger.error(f"构建好感度上下文失败: {e}", exc_info=True) - return "" - - async def _build_behavior_guidance( - self, - group_id: str, - user_id: str - ) -> str: - """ - 构建行为模式指导(基于心理状态和社交关系的联动分析) - - 这是核心功能:根据当前的心理状态和社交关系, - 使用LLM提炼模型生成对bot行为有强烈指导性但不死板的提示词 - - ⚡ 非阻塞设计: - - 优先返回缓存数据(5分钟TTL) - - 如果缓存不存在,返回空字符串,并在后台异步生成 - - 后台生成完成后更新缓存,下次调用时可用 - """ - try: - cache_key = f"behavior_guidance_{group_id}_{user_id}" - - # 1. 优先返回缓存(TTLCache自动管理过期,5分钟TTL) - cached = self._get_from_cache(cache_key) - if cached: - logger.debug(f"💾 [行为指导] 使用缓存 (group: {group_id[:8]}...)") - return cached - - # 2. 缓存未命中 - 检查是否已有后台生成任务在运行 - if cache_key not in self._llm_generation_lock: - self._llm_generation_lock[cache_key] = asyncio.Lock() - - # 尝试获取锁(非阻塞) - if self._llm_generation_lock[cache_key].locked(): - # 已有任务在生成,直接返回空字符串,不阻塞 - logger.debug(f"⏳ [行为指导] 生成任务进行中,返回空字符串 (group: {group_id[:8]}...)") - return "" - - # 3. 
获取锁后,启动后台生成任务(不等待) - async with self._llm_generation_lock[cache_key]: - # 双重检查:再次查询缓存(可能其他协程已经生成了) - cached = self._get_from_cache(cache_key) - if cached: - return cached - - # 启动后台生成任务 - task = asyncio.create_task(self._background_generate_guidance( - cache_key, group_id, user_id - )) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - - # 立即返回空字符串,不阻塞主流程 - logger.debug(f"🚀 [行为指导] 已启动后台生成任务 (group: {group_id[:8]}...)") - return "" - - except Exception as e: - logger.error(f"构建行为模式指导失败: {e}", exc_info=True) - return "" - - async def _background_generate_guidance( - self, - cache_key: str, - group_id: str, - user_id: str - ): - """ - 后台生成行为指导(异步任务,不阻塞主流程) - - Args: - cache_key: 缓存键 - group_id: 群组ID - user_id: 用户ID - """ - try: - logger.debug(f"🔄 [后台任务] 开始生成行为指导 (group: {group_id[:8]}...)") - - # 获取心理状态 - psych_state = None - if self.psych_manager: - psych_state = await self.psych_manager.get_or_create_state(group_id) - - # 获取社交关系 - social_profile = None - if self.social_manager: - social_profile = await self.social_manager.get_or_create_profile( - user_id, group_id - ) - - # 获取好感度 - affection_level = 0 - if self.affection_manager: - try: - affection_data = await self.db_manager.get_user_affection(group_id, user_id) - if affection_data: - affection_level = affection_data.get('affection_level', 0) - except: - pass - - # 使用LLM提炼模型生成行为指导 - guidance = await self._generate_guidance_by_llm( - psych_state, social_profile, affection_level, group_id, user_id - ) - - if guidance: - # 缓存生成的指导(5分钟TTL) - self._set_to_cache(cache_key, guidance) - logger.info(f"✅ [后台任务] 行为指导生成完成并已缓存 (group: {group_id[:8]}...)") - else: - logger.warning(f"⚠️ [后台任务] LLM生成失败,未缓存 (group: {group_id[:8]}...)") - - except Exception as e: - logger.error(f"❌ [后台任务] 生成行为指导失败: {e}", exc_info=True) - - async def _generate_guidance_by_llm( - self, - psych_state, - social_profile, - affection_level: int, - group_id: str, - user_id: str - ) -> str: - """ - 
使用LLM提炼模型生成行为指导prompt - - Args: - psych_state: 复合心理状态对象 - social_profile: 社交关系profile对象 - affection_level: 好感度等级 - group_id: 群组ID - user_id: 用户ID - - Returns: - LLM生成的行为指导prompt字符串 - """ - try: - # 检查LLM适配器是否可用 - if not self.llm_adapter or not hasattr(self.llm_adapter, 'has_refine_provider') or not self.llm_adapter.has_refine_provider(): - logger.warning("⚠️ [行为指导生成] LLM提炼模型不可用,无法生成指导") - return "" - - # 构建心理状态描述 - psych_desc = "" - active_components = [] - if psych_state: - active_components = psych_state.get_active_components() - if active_components: - psych_parts = [] - for component in active_components[:5]: # 取前5个最显著的状态 - category = component.category - state_name = component.state_type.value if hasattr( - component.state_type, 'value') else str(component.state_type) - intensity = component.value - psych_parts.append(f"- {category}: {state_name} (强度: {intensity:.2f})") - psych_desc = "\n".join(psych_parts) - - # 构建社交关系描述 - social_desc = "" - if social_profile: - significant_relations = social_profile.get_significant_relations() - if significant_relations: - social_parts = [] - for rel in significant_relations[:3]: # 取前3个最显著的关系 - rel_name = rel.relation_type.value if hasattr( - rel.relation_type, 'value') else str(rel.relation_type) - social_parts.append(f"- {rel_name} (强度: {rel.value:.2f})") - social_desc = "\n".join(social_parts) - - # 构建好感度描述 - if affection_level >= 80: - affection_desc = f"非常喜欢 ({affection_level}/100)" - elif affection_level >= 60: - affection_desc = f"比较喜欢 ({affection_level}/100)" - elif affection_level >= 40: - affection_desc = f"有一定好感 ({affection_level}/100)" - elif affection_level >= 20: - affection_desc = f"略有好感 ({affection_level}/100)" - elif affection_level >= 0: - affection_desc = f"初次见面 ({affection_level}/100)" - elif affection_level >= -20: - affection_desc = f"略有反感 ({affection_level}/100)" - elif affection_level >= -40: - affection_desc = f"比较不喜欢 ({affection_level}/100)" - else: - affection_desc = f"非常讨厌 ({affection_level}/100)" 
- - # 构建LLM prompt - prompt = self._build_llm_guidance_prompt( - psych_desc, social_desc, affection_desc - ) - - # 调用LLM生成 - logger.debug(f"📤 [行为指导] 调用LLM提炼模型生成指导 (group: {group_id[:8]}...)") - - response = await self.llm_adapter.refine_chat_completion( - prompt=prompt, - temperature=0.7 # 适度的创造性 - ) - - if response: - # 包装为标准格式 - guidance = f"【行为模式指导】\n{response.strip()}" - logger.info(f"✅ [行为指导] LLM生成成功 (长度: {len(guidance)})") - return guidance - else: - logger.warning("⚠️ [行为指导] LLM返回空响应") - return "" - - except Exception as e: - logger.error(f"❌ [行为指导] LLM生成失败: {e}", exc_info=True) - return "" - - def _build_llm_guidance_prompt( - self, - psych_desc: str, - social_desc: str, - affection_desc: str - ) -> str: - """ - 构建发送给LLM提炼模型的prompt - - Args: - psych_desc: 心理状态描述 - social_desc: 社交关系描述 - affection_desc: 好感度描述 - - Returns: - 完整的prompt字符串 - """ - prompt = f"""你是一个AI对话行为分析专家。根据以下Bot当前的心理状态、社交关系和好感度信息,生成一段简洁但有效的行为指导prompt。 - -【Bot当前心理状态】 -{psych_desc if psych_desc else "无明显心理状态"} - -【与该用户的社交关系】 -{social_desc if social_desc else "初次接触,关系陌生"} - -【对该用户的好感度】 -{affection_desc} - ---- - -请生成一段行为指导,用于指导Bot在对话中的语气、态度和行为模式。要求: - -1. **简洁明确**:用2-4个要点说明,每个要点1-2句话 -2. **自然灵活**:避免死板的规则,给Bot留有发挥空间 -3. **情感真实**:根据心理状态和好感度,体现真实的情感倾向 -4. 
**关系适配**:根据社交关系强度调整亲密度和距离感 - -输出格式示例: -- 语气风格:[描述语气特点,如"轻松活泼"、"平和沉稳"等] -- 态度倾向:[描述对用户的态度,如"友好热情"、"略显冷淡"等] -- 回复特点:[描述回复风格,如"简短直接"、"详细耐心"等] -- 特别提示:[可选的其他建议] - -请直接输出行为指导内容,不要包含额外的解释或标题。""" - - return prompt - - def _generate_emotion_guidance(self, emotion: str, intensity: float) -> str: - """根据情绪生成行为指导""" - intensity_desc = "强烈" if intensity > 0.7 else "明显" if intensity > 0.4 else "轻微" - - # 根据不同情绪生成不同指导 - if "愉悦" in emotion or "快乐" in emotion or "兴奋" in emotion: - return ( - f"- 当前情绪: {intensity_desc}{emotion}\n" - f" 建议: 语气活泼开朗,可以多用积极的表达,适度使用感叹号和表情符号\n" - f" 语言风格: 轻松愉快,对话中可以展现出喜悦的情感" - ) - elif "悲伤" in emotion or "难过" in emotion or "低落" in emotion: - return ( - f"- 当前情绪: {intensity_desc}{emotion}\n" - f" 建议: 语气相对平和或低沉,回复可以简短一些,避免过于活泼\n" - f" 语言风格: 略显沉默,可能需要他人的理解和安慰" - ) - elif "愤怒" in emotion or "烦躁" in emotion: - return ( - f"- 当前情绪: {intensity_desc}{emotion}\n" - f" 建议: 语气可能比较直接或强硬,回复简洁,不太有耐心\n" - f" 语言风格: 直接了当,可能略显不耐烦" - ) - elif "平静" in emotion or "平和" in emotion: - return ( - f"- 当前情绪: {intensity_desc}{emotion}\n" - f" 建议: 语气平稳温和,回复理性客观\n" - f" 语言风格: 平和稳重,给人安全感" - ) - else: - return f"- 当前情绪: {intensity_desc}{emotion}" - - def _generate_cognitive_guidance(self, cognitive: str, intensity: float) -> str: - """根据认知状态生成行为指导""" - if "专注" in cognitive or "集中" in cognitive: - return ( - f"- 当前认知状态: {cognitive}\n" - f" 建议: 回复详细深入,逻辑清晰,可以展开讨论\n" - f" 思维模式: 思路清晰,能够深入分析问题" - ) - elif "涣散" in cognitive or "分心" in cognitive: - return ( - f"- 当前认知状态: {cognitive}\n" - f" 建议: 回复可能相对简短,思路可能跳跃\n" - f" 思维模式: 注意力不太集中,可能需要对方提醒" - ) - else: - return f"- 当前认知状态: {cognitive}" - - def _generate_social_guidance(self, social: str, intensity: float) -> str: - """根据社交状态生成行为指导""" - if "友善" in social or "热情" in social: - return ( - f"- 当前社交状态: {social}\n" - f" 建议: 主动积极回应,语气友好热情,可以主动提问或关心对方\n" - f" 社交模式: 开放友好,乐于交流" - ) - elif "冷漠" in social or "疏离" in social: - return ( - f"- 当前社交状态: {social}\n" - f" 建议: 回复相对冷淡,不太主动,保持一定距离感\n" - f" 社交模式: 不太愿意深入交流,保持距离" - ) - else: - return f"- 当前社交状态: 
{social}" - - def _generate_energy_guidance(self, energy: str, intensity: float) -> str: - """根据精力状态生成行为指导""" - if "充沛" in energy or "活力" in energy: - return ( - f"- 当前精力状态: {energy}\n" - f" 建议: 回复可以更长更详细,充满活力\n" - f" 精神状态: 精力旺盛,可以处理复杂话题" - ) - elif "疲惫" in energy or "困倦" in energy or "瞌睡" in energy: - return ( - f"- 当前精力状态: {energy}\n" - f" 建议: 回复简短一些,可能反应稍慢,语气略显疲惫\n" - f" 精神状态: 比较累,需要休息" - ) - else: - return f"- 当前精力状态: {energy}" - - def _generate_relation_guidance( - self, - relation_type: str, - relation_value: float, - affection_level: int - ) -> str: - """根据社交关系生成行为指导""" - strength = "非常强" if relation_value > 0.7 else "较强" if relation_value > 0.4 else "一般" - - guidance = f"- 与该用户的关系: {relation_type} (强度: {strength})\n" - - # 根据关系类型调整语气和态度 - if "挚友" in relation_type or "知己" in relation_type or "闺蜜" in relation_type: - guidance += ( - " 建议: 语气亲密自然,可以开玩笑,展现真实性格\n" - " 态度: 放松随意,无需过分客套,像对待老朋友一样" - ) - elif "恋人" in relation_type or "情侣" in relation_type: - guidance += ( - " 建议: 语气温柔体贴,关心对方,可以适度撒娇或甜蜜\n" - " 态度: 亲密关爱,重视对方的感受" - ) - elif "同事" in relation_type or "同学" in relation_type: - guidance += ( - " 建议: 语气友好但保持适当专业性\n" - " 态度: 友善合作,但不过分亲密" - ) - elif "陌生" in relation_type or relation_value < 0.2: - guidance += ( - " 建议: 语气礼貌客气,保持一定距离\n" - " 态度: 谨慎友好,慢慢建立信任" - ) - else: - guidance += ( - " 建议: 根据具体情况自然应对\n" - " 态度: 友好适度" - ) - - # 结合好感度调整 - if affection_level >= 70: - guidance += "\n 特别提示: 好感度很高,可以更加亲近和真实" - elif affection_level <= -20: - guidance += "\n 特别提示: 好感度较低,需要谨慎应对,避免冲突" - - return guidance - - async def _build_diversity_context(self, group_id: str) -> str: - """构建多样性指导上下文""" - try: - if not self.diversity_manager: - return "" - - # 获取多样性管理器的当前设置 - current_style = self.diversity_manager.get_current_style() - current_pattern = self.diversity_manager.get_current_pattern() - - if not current_style and not current_pattern: - return "" - - context_parts = ["【回复多样性指导】"] - - if current_style: - context_parts.append(f"当前语言风格: {current_style}") - - if 
current_pattern: - context_parts.append(f"推荐回复模式: {current_pattern}") - - context_parts.append( - "注意: 这些是参考建议,请自然运用,不必严格遵守" - ) - - return "\n".join(context_parts) - - except Exception as e: - logger.error(f"构建多样性上下文失败: {e}") - return "" - - async def inject_to_system_prompt( - self, - original_system_prompt: str, - group_id: str, - user_id: str, - position: str = "end" - ) -> str: - """ - 将完整上下文注入到system prompt - - Args: - original_system_prompt: 原始system prompt - group_id: 群组ID - user_id: 用户ID - position: 注入位置 ('start' 或 'end') - - Returns: - 注入后的system prompt - """ - try: - context = await self.build_complete_context( - group_id, user_id, - include_psychological=True, - include_social_relation=True, - include_affection=True, - include_diversity=False, # 多样性指导通常单独处理 - enable_protection=True - ) - - if not context: - return original_system_prompt - - if position == "start": - return f"{context}\n\n{original_system_prompt}" - else: - return f"{original_system_prompt}\n\n{context}" - - except Exception as e: - logger.error(f"注入上下文到system prompt失败: {e}", exc_info=True) - return original_system_prompt diff --git a/services/psychological_state_manager.py b/services/psychological_state_manager.py deleted file mode 100644 index 6516e99..0000000 --- a/services/psychological_state_manager.py +++ /dev/null @@ -1,867 +0,0 @@ -""" -心理状态管理器 - 管理bot的复合心理状态 -支持多维度心理状态(情绪、认知、意志等)的动态管理和状态转换 -""" -import asyncio -import random -import time -import uuid -import json -from typing import Dict, List, Optional, Any, Tuple -from datetime import datetime, timedelta - -from astrbot.api import logger - -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage -from ..core.framework_llm_adapter import FrameworkLLMAdapter - -from ..models.psychological_state import ( - EmotionPositiveType, EmotionNegativeType, EmotionNeutralType, - AttentionState, ThinkingState, MemoryState, - WillStrengthState, ActionTendencyState, 
GoalOrientationState, - SelfAcceptanceState, PersonalityTendencyState, - SocialAttitudeState, SocialBehaviorState, - EnergyState, InterestMotivationState, - PsychologicalStateComponent, CompositePsychologicalState -) -from ..utils.guardrails_manager import get_guardrails_manager - - -class PsychologicalStateManager(AsyncServiceBase): - """ - 心理状态管理器 - 管理bot的复合心理状态 - - 核心功能: - 1. 维护多维度心理状态(情绪、认知、意志、社交等) - 2. 根据时间、事件、好感度变化等因素动态调整状态 - 3. 当某个状态数值降到阈值以下时,使用LLM智能分析并切换状态 - 4. 生成心理状态的prompt注入内容,指导bot的行为模式 - """ - - def __init__(self, config: PluginConfig, database_manager: IDataStorage, - llm_adapter: Optional[FrameworkLLMAdapter] = None, - affection_manager=None): - super().__init__("psychological_state_manager") - self.config = config - self.db_manager = database_manager - self.llm_adapter = llm_adapter - self.affection_manager = affection_manager - - # 当前活跃的心理状态缓存 {group_id: CompositePsychologicalState} - self.current_states: Dict[str, CompositePsychologicalState] = {} - - # 状态自然衰减速率配置 - self.decay_rates = { - "情绪": 0.02, # 情绪衰减较快 - "认知": 0.01, # 认知状态较稳定 - "意志": 0.015, - "自我认知": 0.005, # 自我认知最稳定 - "社交": 0.015, - "精力": 0.03, # 精力衰减最快 - "兴趣": 0.01 - } - - # 时间段对心理状态的影响规则 - self.time_based_rules = self._init_time_based_rules() - - async def _do_start(self) -> bool: - """启动心理状态管理服务""" - try: - # 加载所有群组的当前心理状态 - await self._load_all_states() - - # 启动状态自动衰减任务 - asyncio.create_task(self._auto_decay_task()) - - # 启动时间驱动的状态变化任务 - asyncio.create_task(self._time_driven_state_change_task()) - - self._logger.info("心理状态管理服务启动成功") - return True - except Exception as e: - self._logger.error(f"心理状态管理服务启动失败: {e}", exc_info=True) - return False - - async def _do_stop(self) -> bool: - """停止心理状态管理服务""" - try: - # 保存所有当前状态到数据库 - await self._save_all_states() - self._logger.info("心理状态管理服务已停止") - return True - except Exception as e: - self._logger.error(f"停止心理状态管理服务失败: {e}") - return False - - def _init_time_based_rules(self) -> List[Dict[str, Any]]: - """初始化基于时间的状态变化规则""" - return [ - { - 
"time_range": (0, 5), # 凌晨0-5点 - "states": [ - ("精力", EnergyState.SLEEPY, 0.7, "凌晨时分非常困倦"), - ("认知", AttentionState.SCATTERED, 0.6, "注意力涣散"), - ("情绪", EmotionNeutralType.CALM, 0.5, "夜深人静心情平静") - ], - "description": "深夜时分,困倦且注意力不集中" - }, - { - "time_range": (6, 8), # 早上6-8点 - "states": [ - ("精力", EnergyState.DROWSY, 0.6, "刚起床还有些困"), - ("情绪", EmotionPositiveType.JOYFUL, 0.4, "新的一天轻松愉悦"), - ("认知", AttentionState.SCATTERED, 0.5, "注意力还没完全集中") - ], - "description": "清晨刚起床,有些困但心情还不错" - }, - { - "time_range": (9, 11), # 上午9-11点 - "states": [ - ("精力", EnergyState.VIGOROUS, 0.7, "精力充沛"), - ("认知", AttentionState.FOCUSED, 0.7, "注意力集中"), - ("情绪", EmotionPositiveType.MOTIVATED, 0.6, "充满干劲") - ], - "description": "上午精力旺盛,状态最佳" - }, - { - "time_range": (12, 13), # 中午12-13点 - "states": [ - ("精力", EnergyState.DROWSY, 0.5, "午饭后有些困"), - ("情绪", EmotionPositiveType.SATISFIED, 0.6, "吃饱了感到满足") - ], - "description": "午饭后有些困倦" - }, - { - "time_range": (14, 17), # 下午14-17点 - "states": [ - ("精力", EnergyState.VIGOROUS, 0.6, "精力恢复"), - ("认知", AttentionState.FOCUSED, 0.6, "注意力不错"), - ("意志", ActionTendencyState.PROACTIVE, 0.5, "比较主动") - ], - "description": "下午精力恢复,工作状态良好" - }, - { - "time_range": (18, 21), # 傍晚18-21点 - "states": [ - ("精力", EnergyState.TIRED, 0.5, "开始感到疲惫"), - ("情绪", EmotionPositiveType.RELAXED, 0.6, "工作结束轻松下来"), - ("社交", SocialAttitudeState.FRIENDLY, 0.6, "友善放松") - ], - "description": "傍晚放松时光,友善但有些疲惫" - }, - { - "time_range": (22, 23), # 晚上22-23点 - "states": [ - ("精力", EnergyState.FATIGUED_ENERGY, 0.6, "比较疲劳"), - ("情绪", EmotionNeutralType.PEACEFUL, 0.5, "平和宁静"), - ("认知", AttentionState.SCATTERED, 0.5, "注意力开始涣散") - ], - "description": "深夜渐晚,疲劳且平和" - }, - ] - - async def get_or_create_state(self, group_id: str) -> CompositePsychologicalState: - """获取或创建群组的心理状态""" - try: - # 先从缓存获取 - if group_id in self.current_states: - return self.current_states[group_id] - - # 从数据库加载 - loaded_state = await self._load_state_from_db(group_id) - if loaded_state: - self.current_states[group_id] = 
loaded_state - return loaded_state - - # 创建新状态 - new_state = await self._create_initial_state(group_id) - self.current_states[group_id] = new_state - await self._save_state_to_db(new_state) - return new_state - - except Exception as e: - self._logger.error(f"获取或创建心理状态失败: {e}", exc_info=True) - # 返回一个空的状态对象,避免程序崩溃 - return CompositePsychologicalState(group_id=group_id, state_id=str(uuid.uuid4())) - - async def _create_initial_state(self, group_id: str) -> CompositePsychologicalState: - """ - 创建初始心理状态(基于当前时间 + 随机积极状态) - - 初始化时会生成相对随机但较为积极的心理状态,包括: - - 随机的积极情绪状态(轻度到中度) - - 随机的认知状态(注意力/思维等) - - 随机的精力状态 - - 随机的社交状态 - 每个状态的强度也是随机的,但保持在合理范围内 - """ - state_id = str(uuid.uuid4()) - state = CompositePsychologicalState( - group_id=group_id, - state_id=state_id - ) - - # 根据当前时间设置基础状态(保持原有逻辑) - current_hour = datetime.now().hour - time_based_applied = False - for rule in self.time_based_rules: - start, end = rule["time_range"] - if start <= current_hour < end: - for category, state_type, value, description in rule["states"]: - component = PsychologicalStateComponent( - category=category, - state_type=state_type, - value=value, - description=description - ) - state.add_component(component) - state.triggering_events.append(f"初始化: {rule['description']}") - self._logger.info(f"群组 {group_id} 基础心理状态: {rule['description']}") - time_based_applied = True - break - - # 添加随机的积极心理状态(增强初始状态的多样性) - # 1. 随机积极情绪 (40%-70%强度) - positive_emotions = [ - EmotionPositiveType.JOYFUL, - EmotionPositiveType.HAPPY, - EmotionPositiveType.SATISFIED, - EmotionPositiveType.RELAXED, - EmotionPositiveType.COMFORTABLE, - EmotionPositiveType.PLEASANT, - EmotionPositiveType.CHEERFUL - ] - selected_emotion = random.choice(positive_emotions) - emotion_intensity = random.uniform(0.4, 0.7) # 中等强度的积极情绪 - state.add_component(PsychologicalStateComponent( - category="情绪", - state_type=selected_emotion, - value=emotion_intensity, - description=f"初始化时的随机积极情绪" - )) - - # 2. 
随机认知状态 (30%-60%强度) - attention_states = [ - AttentionState.FOCUSED, - AttentionState.CONCENTRATED, - AttentionState.ATTENTIVE - ] - selected_attention = random.choice(attention_states) - attention_intensity = random.uniform(0.3, 0.6) - state.add_component(PsychologicalStateComponent( - category="认知", - state_type=selected_attention, - value=attention_intensity, - description=f"初始化时的认知状态" - )) - - # 3. 随机社交状态 (40%-65%强度) - social_states = [ - SocialAttitudeState.FRIENDLY, - SocialAttitudeState.CORDIAL, - SocialAttitudeState.WARM, - SocialAttitudeState.TOLERANT - ] - selected_social = random.choice(social_states) - social_intensity = random.uniform(0.4, 0.65) - state.add_component(PsychologicalStateComponent( - category="社交", - state_type=selected_social, - value=social_intensity, - description=f"初始化时的社交态度" - )) - - # 4. 随机精力状态 (35%-65%强度) - # 根据时间调整精力状态范围 - if 9 <= current_hour < 17: # 白天精力更高 - energy_range = (0.5, 0.75) - energy_states = [EnergyState.VIGOROUS, EnergyState.ENERGETIC_FULL] - elif 22 <= current_hour or current_hour < 6: # 深夜和凌晨精力较低 - energy_range = (0.25, 0.45) - energy_states = [EnergyState.TIRED, EnergyState.DROWSY] - else: # 其他时间中等 - energy_range = (0.35, 0.65) - energy_states = [EnergyState.VIGOROUS, EnergyState.TIRED, EnergyState.DROWSY] - - selected_energy = random.choice(energy_states) - energy_intensity = random.uniform(*energy_range) - state.add_component(PsychologicalStateComponent( - category="精力", - state_type=selected_energy, - value=energy_intensity, - description=f"初始化时的精力状态" - )) - - state.triggering_events.append(f"随机积极状态初始化完成") - self._logger.info( - f"✅ 群组 {group_id} 已初始化随机积极心理状态 - " - f"情绪:{selected_emotion.value}({emotion_intensity:.2f}), " - f"认知:{selected_attention.value}({attention_intensity:.2f}), " - f"社交:{selected_social.value}({social_intensity:.2f}), " - f"精力:{selected_energy.value}({energy_intensity:.2f})" - ) - - return state - - async def update_state_by_event( - self, - group_id: str, - event_type: str, - 
event_context: Dict[str, Any] - ) -> CompositePsychologicalState: - """ - 根据事件更新心理状态 - - Args: - group_id: 群组ID - event_type: 事件类型 (如: "user_compliment", "user_insult", "affection_change"等) - event_context: 事件上下文信息 - """ - try: - state = await self.get_or_create_state(group_id) - - # 根据事件类型应用不同的状态变化规则 - if event_type == "user_compliment": - await self._handle_positive_interaction(state, event_context) - elif event_type == "user_insult": - await self._handle_negative_interaction(state, event_context) - elif event_type == "affection_high": - await self._handle_high_affection_event(state, event_context) - elif event_type == "time_change": - await self._handle_time_change(state, event_context) - else: - self._logger.warning(f"未知的事件类型: {event_type}") - - # 检查是否有状态组件需要转换 - await self._check_and_transition_states(state, event_context) - - # 保存更新后的状态 - await self._save_state_to_db(state) - - return state - - except Exception as e: - self._logger.error(f"根据事件更新心理状态失败: {e}", exc_info=True) - return await self.get_or_create_state(group_id) - - async def _handle_positive_interaction( - self, - state: CompositePsychologicalState, - context: Dict[str, Any] - ): - """处理积极交互事件""" - # 提升情绪状态 - state.update_component_value("情绪", +0.1) - - # 提升社交状态 - state.update_component_value("社交", +0.05) - - state.triggering_events.append(f"积极交互: {context.get('description', '未知')}") - - async def _handle_negative_interaction( - self, - state: CompositePsychologicalState, - context: Dict[str, Any] - ): - """处理消极交互事件""" - # 降低情绪状态 - state.update_component_value("情绪", -0.15) - - # 影响社交状态 - state.update_component_value("社交", -0.1) - - # 降低精力 - state.update_component_value("精力", -0.05) - - state.triggering_events.append(f"消极交互: {context.get('description', '未知')}") - - async def _handle_high_affection_event( - self, - state: CompositePsychologicalState, - context: Dict[str, Any] - ): - """处理高好感度事件""" - # 提升情绪 - state.update_component_value("情绪", +0.08) - - # 提升社交友好度 - state.update_component_value("社交", 
+0.08) - - state.triggering_events.append(f"高好感度: {context.get('user_id', '未知用户')}") - - async def _handle_time_change( - self, - state: CompositePsychologicalState, - context: Dict[str, Any] - ): - """处理时间变化事件""" - current_hour = context.get("hour", datetime.now().hour) - - for rule in self.time_based_rules: - start, end = rule["time_range"] - if start <= current_hour < end: - # 根据时间段调整状态 - for category, state_type, value, description in rule["states"]: - # 查找是否已有该类别的状态 - existing = None - for comp in state.components: - if comp.category == category: - existing = comp - break - - if existing: - # 缓慢过渡到目标状态 - target_value = value - delta = (target_value - existing.value) * 0.3 # 30%的过渡 - existing.update_value(delta) - else: - # 添加新状态 - component = PsychologicalStateComponent( - category=category, - state_type=state_type, - value=value, - description=description - ) - state.add_component(component) - - break - - async def _check_and_transition_states( - self, - state: CompositePsychologicalState, - event_context: Dict[str, Any] - ): - """检查并转换需要改变的状态""" - transitioning = state.get_transitioning_components() - - if not transitioning: - return - - self._logger.info(f"检测到 {len(transitioning)} 个需要转换的心理状态组件") - - for component in transitioning: - try: - # 使用LLM分析应该转换到什么状态 - new_state_type = await self._analyze_state_transition( - state, component, event_context - ) - - if new_state_type: - # 记录状态变化历史 - await self._record_state_history( - state.group_id, - state.state_id, - component.category, - component.state_type, - new_state_type, - component.value, - 0.5, # 新状态初始值 - "自动分析转换" - ) - - # 更新状态 - component.state_type = new_state_type - component.value = 0.5 # 重置为中等强度 - component.start_time = time.time() - - self._logger.info( - f"状态转换: {component.category} " - f"从 {component.state_type} 转换到 {new_state_type}" - ) - - except Exception as e: - self._logger.error(f"状态转换失败: {e}", exc_info=True) - - async def _analyze_state_transition( - self, - state: 
CompositePsychologicalState, - component: PsychologicalStateComponent, - context: Dict[str, Any] - ) -> Optional[Any]: - """使用LLM分析应该转换到什么状态""" - if not self.llm_adapter or not self.llm_adapter.has_refine_provider(): - self._logger.warning("LLM适配器不可用,无法进行智能状态分析") - return self._fallback_state_transition(component) - - try: - # 构建分析prompt - prompt = self._build_transition_analysis_prompt(state, component, context) - - # 调用LLM分析 - response = await self.llm_adapter.refine_chat_completion( - prompt=prompt, - temperature=0.3 - ) - - if response: - # 解析LLM返回的状态类型 - new_state = self._parse_transition_response(response, component.category) - return new_state - - except Exception as e: - self._logger.error(f"LLM状态分析失败: {e}") - - return self._fallback_state_transition(component) - - def _build_transition_analysis_prompt( - self, - state: CompositePsychologicalState, - component: PsychologicalStateComponent, - context: Dict[str, Any] - ) -> str: - """构建状态转换分析的prompt""" - # 获取当前所有活跃状态的描述 - active_states_desc = "\n".join([ - f"- {c.category}: {c.state_type.value if hasattr(c.state_type, 'value') else str(c.state_type)} (强度: {c.value:.2f})" - for c in state.get_active_components() - ]) - - # 获取最近的触发事件 - recent_events = "\n".join([f"- {event}" for event in state.triggering_events[-5:]]) - - # 获取好感度信息(如果有) - affection_info = "" - if "user_id" in context and self.affection_manager: - try: - affection_data = self.affection_manager.db_manager.get_user_affection( - state.group_id, context["user_id"] - ) - if affection_data: - affection_info = f"\n对该用户的好感度: {affection_data.get('affection_level', 0)}" - except: - pass - - category = component.category - current_state = component.state_type.value if hasattr(component.state_type, 'value') else str(component.state_type) - current_value = component.value - - prompt = f""" -你是一个心理状态分析专家。Bot当前的心理状态组件"{category}: {current_state}"的数值已降至{current_value:.2f},低于阈值,需要转换到新的状态。 - -【当前完整心理状态】 -{active_states_desc} - -【最近触发事件】 -{recent_events} 
-{affection_info} - -【时间信息】 -当前时间: {datetime.now().strftime('%H:%M')} -星期: {datetime.now().strftime('%A')} - -请根据以上信息,分析Bot的{category}状态应该转换到什么新状态。 - -可选的{category}状态类型(仅供参考): -{self._get_category_state_options(category)} - -请只返回一个具体的状态名称(中文),不要返回其他内容。 -例如: "疲惫" 或 "轻松" 或 "专注" -""" - return prompt - - def _get_category_state_options(self, category: str) -> str: - """获取某个类别的可选状态列表""" - options_map = { - "情绪": "愉悦、快乐、兴奋、满足、悲伤、难过、愤怒、焦虑、平静、放松", - "认知": "专注、集中、涣散、分心、清晰思维、混乱思维、敏锐感知", - "意志": "坚定、坚持、软弱、放弃、主动、被动", - "精力": "精力充沛、活力满满、疲惫、疲劳、困倦、瞌睡", - "社交": "友善、热情、冷漠、疏离、主动社交、被动社交", - "兴趣": "兴趣浓厚、好奇心强、兴趣索然、缺乏动力" - } - return options_map.get(category, "根据上下文自行判断合适的状态") - - def _parse_transition_response(self, response: str, category: str) -> Optional[Any]: - """解析LLM返回的状态转换结果 - 使用 JSON 清洗工具""" - # 使用 JSON 清洗工具解析状态名称 - state_name = LLMJSONParser.parse_state_analysis(response) - - if not state_name: - self._logger.warning(f"无法解析LLM返回的状态: {response}") - return None - - # 尝试匹配到具体的枚举类型 - category_enum_map = { - "情绪": [EmotionPositiveType, EmotionNegativeType, EmotionNeutralType], - "认知": [AttentionState, ThinkingState, MemoryState], - "意志": [WillStrengthState, ActionTendencyState, GoalOrientationState], - "社交": [SocialAttitudeState, SocialBehaviorState], - "精力": [EnergyState], - "兴趣": [InterestMotivationState] - } - - enums_to_check = category_enum_map.get(category, []) - - for enum_class in enums_to_check: - for enum_val in enum_class: - if enum_val.value in state_name: - self._logger.debug(f"✅ 成功匹配状态: {state_name} -> {enum_val.value}") - return enum_val - - self._logger.warning(f"无法匹配到枚举类型: {state_name} (类别: {category})") - return None - - def _fallback_state_transition(self, component: PsychologicalStateComponent) -> Optional[Any]: - """备用的状态转换逻辑(随机选择)""" - category = component.category - - category_enum_map = { - "情绪": [EmotionPositiveType, EmotionNegativeType, EmotionNeutralType], - "认知": [AttentionState, ThinkingState], - "意志": [WillStrengthState, ActionTendencyState], - "社交": 
[SocialAttitudeState, SocialBehaviorState], - "精力": [EnergyState], - "兴趣": [InterestMotivationState] - } - - enums = category_enum_map.get(category, []) - if enums: - enum_class = random.choice(enums) - return random.choice(list(enum_class)) - - return None - - async def _auto_decay_task(self): - """自动衰减任务 - 定期降低所有状态的数值""" - while True: - try: - await asyncio.sleep(1800) # 每30分钟执行一次 - - for group_id, state in self.current_states.items(): - for component in state.components: - decay_rate = self.decay_rates.get(component.category, 0.01) - component.update_value(-decay_rate) - - # 检查是否有状态需要转换 - await self._check_and_transition_states(state, {"trigger": "auto_decay"}) - - self._logger.debug("心理状态自动衰减完成") - - except Exception as e: - self._logger.error(f"自动衰减任务失败: {e}", exc_info=True) - await asyncio.sleep(1800) - - async def _time_driven_state_change_task(self): - """时间驱动的状态变化任务""" - last_hour = datetime.now().hour - - while True: - try: - await asyncio.sleep(300) # 每5分钟检查一次 - - current_hour = datetime.now().hour - if current_hour != last_hour: - # 小时变化,触发时间驱动的状态变化 - for group_id, state in self.current_states.items(): - await self.update_state_by_event( - group_id, - "time_change", - {"hour": current_hour} - ) - - last_hour = current_hour - self._logger.info(f"时间驱动状态变化完成 (当前: {current_hour}点)") - - except Exception as e: - self._logger.error(f"时间驱动状态变化任务失败: {e}", exc_info=True) - await asyncio.sleep(300) - - async def get_state_prompt_injection(self, group_id: str) -> str: - """获取用于prompt注入的心理状态描述""" - try: - state = await self.get_or_create_state(group_id) - return state.to_prompt_injection() - except Exception as e: - self._logger.error(f"生成状态prompt注入失败: {e}") - return "" - - # ==================== 数据库操作 ==================== - - async def _load_state_from_db(self, group_id: str) -> Optional[CompositePsychologicalState]: - """从数据库加载心理状态""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 查询复合状态元数据 - await 
cursor.execute(''' - SELECT state_id, triggering_events, context, created_at, last_updated - FROM composite_psychological_states - WHERE group_id = ? - ORDER BY last_updated DESC - LIMIT 1 - ''', (group_id,)) - - row = await cursor.fetchone() - if not row: - return None - - state_id, events_json, context_json, created_at, last_updated = row - - state = CompositePsychologicalState( - group_id=group_id, - state_id=state_id, - triggering_events=json.loads(events_json) if events_json else [], - context=json.loads(context_json) if context_json else {}, - created_at=created_at, - last_updated=last_updated - ) - - # 查询所有组件 - await cursor.execute(''' - SELECT category, state_type, value, threshold, description, start_time - FROM psychological_state_components - WHERE group_id = ? AND state_id = ? - ''', (group_id, state_id)) - - for row in await cursor.fetchall(): - category, state_type_str, value, threshold, description, start_time = row - - # 重建枚举类型(简化处理) - component = PsychologicalStateComponent( - category=category, - state_type=state_type_str, # 暂时用字符串,实际应该恢复枚举 - value=value, - threshold=threshold, - description=description, - start_time=start_time - ) - state.components.append(component) - - await cursor.close() - return state - - except Exception as e: - self._logger.error(f"从数据库加载心理状态失败: {e}", exc_info=True) - return None - - async def _save_state_to_db(self, state: CompositePsychologicalState): - """保存心理状态到数据库""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # ✅ 使用数据库无关的语法:DELETE + INSERT 替代 INSERT OR REPLACE - # 先删除旧记录 - await cursor.execute(''' - DELETE FROM composite_psychological_states - WHERE group_id = ? AND state_id = ? - ''', (state.group_id, state.state_id)) - - # 再插入新记录 - await cursor.execute(''' - INSERT INTO composite_psychological_states - (group_id, state_id, triggering_events, context, created_at, last_updated) - VALUES (?, ?, ?, ?, ?, ?) 
- ''', ( - state.group_id, - state.state_id, - json.dumps(state.triggering_events, ensure_ascii=False), - json.dumps(state.context, ensure_ascii=False), - state.created_at, - time.time() - )) - - # 删除旧的组件 - await cursor.execute(''' - DELETE FROM psychological_state_components - WHERE group_id = ? AND state_id = ? - ''', (state.group_id, state.state_id)) - - # 保存所有组件 - for component in state.components: - state_type_str = component.state_type.value if hasattr(component.state_type, 'value') else str(component.state_type) - - await cursor.execute(''' - INSERT INTO psychological_state_components - (group_id, state_id, category, state_type, value, threshold, description, start_time) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - state.group_id, - state.state_id, - component.category, - state_type_str, - component.value, - component.threshold, - component.description, - component.start_time - )) - - await conn.commit() - await cursor.close() - - except Exception as e: - self._logger.error(f"保存心理状态到数据库失败: {e}", exc_info=True) - - async def _record_state_history( - self, - group_id: str, - state_id: str, - category: str, - old_state_type: Any, - new_state_type: Any, - old_value: float, - new_value: float, - reason: str - ): - """记录状态变化历史""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - old_str = old_state_type.value if hasattr(old_state_type, 'value') else str(old_state_type) - new_str = new_state_type.value if hasattr(new_state_type, 'value') else str(new_state_type) - - await cursor.execute(''' - INSERT INTO psychological_state_history - (group_id, state_id, category, old_state_type, new_state_type, - old_value, new_value, change_reason, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - group_id, state_id, category, old_str, new_str, - old_value, new_value, reason, time.time() - )) - - await conn.commit() - await cursor.close() - - except Exception as e: - self._logger.error(f"记录状态历史失败: {e}") - - async def _load_all_states(self): - """加载所有群组的当前状态""" - try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - SELECT DISTINCT group_id FROM composite_psychological_states - WHERE last_updated > ? - ''', (time.time() - 86400 * 7,)) # 最近7天 - - rows = await cursor.fetchall() - await cursor.close() - - for row in rows: - group_id = row[0] - state = await self._load_state_from_db(group_id) - if state: - self.current_states[group_id] = state - - self._logger.info(f"已加载 {len(self.current_states)} 个群组的心理状态") - - except Exception as e: - self._logger.error(f"加载所有状态失败: {e}", exc_info=True) - - async def _save_all_states(self): - """保存所有当前状态""" - try: - for state in self.current_states.values(): - await self._save_state_to_db(state) - - self._logger.info(f"已保存 {len(self.current_states)} 个群组的心理状态") - - except Exception as e: - self._logger.error(f"保存所有状态失败: {e}", exc_info=True) diff --git a/services/quality/__init__.py b/services/quality/__init__.py new file mode 100644 index 0000000..305e090 --- /dev/null +++ b/services/quality/__init__.py @@ -0,0 +1,17 @@ +"""Learning quality control -- goal management, monitoring, triggers.""" + +from .conversation_goal_manager import ConversationGoalManager +from .learning_quality_monitor import LearningQualityMonitor +from .tiered_learning_trigger import ( + TieredLearningTrigger, + BatchTriggerPolicy, + TriggerResult, +) + +__all__ = [ + "ConversationGoalManager", + "LearningQualityMonitor", + "TieredLearningTrigger", + "BatchTriggerPolicy", + "TriggerResult", +] diff --git a/services/conversation_goal_manager.py b/services/quality/conversation_goal_manager.py similarity index 94% rename from services/conversation_goal_manager.py rename to 
services/quality/conversation_goal_manager.py index 38acd58..ff22a05 100644 --- a/services/conversation_goal_manager.py +++ b/services/quality/conversation_goal_manager.py @@ -7,7 +7,7 @@ import hashlib from astrbot.api import logger -from ..repositories.conversation_goal_repository import ConversationGoalRepository +from ...repositories.conversation_goal_repository import ConversationGoalRepository class ConversationGoalManager: @@ -15,7 +15,7 @@ class ConversationGoalManager: # 预定义目标模板 (30+种类型,实际会动态调整) GOAL_TEMPLATES = { - # ===== 情感支持类 ===== + # 情感支持类 "comfort": { "name": "安慰用户", "base_stages": ["初步共情", "弱化负面情绪", "给出轻量安慰"], @@ -41,7 +41,7 @@ class ConversationGoalManager: "min_rounds": 3 }, - # ===== 信息交流类 ===== + # 信息交流类 "qa": { "name": "解答疑问", "base_stages": ["理解问题", "提供答案", "确认满意度"], @@ -73,7 +73,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 娱乐互动类 ===== + # 娱乐互动类 "casual_chat": { "name": "闲聊互动", "base_stages": ["回应话题", "自然互动"], @@ -111,7 +111,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 社交互动类 ===== + # 社交互动类 "greeting": { "name": "问候寒暄", "base_stages": ["回应问候", "关心近况", "自然过渡"], @@ -143,7 +143,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 建议指导类 ===== + # 建议指导类 "advise": { "name": "提供建议", "base_stages": ["理解需求", "分析情况", "给出建议", "补充说明"], @@ -169,7 +169,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 情绪调节类 ===== + # 情绪调节类 "calm_down": { "name": "情绪安抚", "base_stages": ["承认情绪", "理解原因", "引导冷静", "转移注意"], @@ -189,7 +189,7 @@ class ConversationGoalManager: "min_rounds": 3 }, - # ===== 兴趣分享类 ===== + # 兴趣分享类 "recommend": { "name": "推荐分享", "base_stages": ["了解偏好", "推荐内容", "说明亮点", "引发兴趣"], @@ -209,7 +209,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 特殊场景类 ===== + # 特殊场景类 "debate": { "name": "友好辩论", "base_stages": ["阐述观点", "论证立场", "反驳质疑", "求同存异"], @@ -229,7 +229,7 @@ class ConversationGoalManager: "min_rounds": 4 }, - # ===== 冲突场景类 ===== + # 冲突场景类 "argument": { "name": "激烈争论", 
"base_stages": ["理解立场", "冷静回应", "寻找共识", "缓和气氛"], @@ -279,11 +279,11 @@ def __init__(self, database_manager, llm_adapter, config): self.session_timeout_hours = 24 # 初始化提示词保护服务 - from ..services.prompt_sanitizer import PromptProtectionService + from ..response import PromptProtectionService self.prompt_protection = PromptProtectionService(wrapper_template_index=0) # 初始化Guardrails管理器用于JSON验证 - from ..utils.guardrails_manager import get_guardrails_manager, GoalAnalysisResult, ConversationIntentAnalysis + from ...utils.guardrails_manager import get_guardrails_manager, GoalAnalysisResult, ConversationIntentAnalysis self.guardrails = get_guardrails_manager() self.GoalAnalysisResult = GoalAnalysisResult self.ConversationIntentAnalysis = ConversationIntentAnalysis @@ -485,32 +485,32 @@ async def _analyze_initial_goal(self, user_message: str) -> Dict: # 使用提示词保护包装 protected_prompt = self.prompt_protection.wrap_prompt(prompt, register_for_filter=True) - # ✅ Debug日志: 输出发送给LLM的prompt - logger.debug(f"🔍 [对话目标-分析初始目标] LLM Prompt:\n{prompt}") + # Debug日志: 输出发送给LLM的prompt + logger.debug(f" [对话目标-分析初始目标] LLM Prompt:\n{prompt}") - # ✅ 使用提炼模型(refine)进行目标分析 + # 使用提炼模型(refine)进行目标分析 response = await self.llm.refine_chat_completion( prompt=protected_prompt, temperature=0.3, max_tokens=200 ) - logger.debug(f"🔍 [对话目标-分析初始目标] LLM Response: {response}") + logger.debug(f" [对话目标-分析初始目标] LLM Response: {response}") # 消毒响应 try: sanitized_response, report = self.prompt_protection.sanitize_response(response) - logger.debug(f"🔍 [对话目标-分析初始目标] 消毒后响应: {sanitized_response}") + logger.debug(f" [对话目标-分析初始目标] 消毒后响应: {sanitized_response}") except Exception as sanitize_error: logger.error(f"消毒响应失败: {sanitize_error}", exc_info=True) - sanitized_response = response # 使用原始响应 + sanitized_response = response # 使用原始响应 - # ✅ 使用Guardrails Pydantic模型验证和解析JSON + # 使用Guardrails Pydantic模型验证和解析JSON try: # 直接解析已有的响应文本 parsed_result = self.guardrails.parse_json_direct( sanitized_response, - 
model_class=self.GoalAnalysisResult # 使用正确的模型引用 + model_class=self.GoalAnalysisResult # 使用正确的模型引用 ) if parsed_result: @@ -521,7 +521,7 @@ async def _analyze_initial_goal(self, user_message: str) -> Dict: "confidence": parsed_result.confidence, "reasoning": parsed_result.reasoning } - logger.debug(f"✅ [对话目标] Pydantic验证成功: goal_type={result['goal_type']}") + logger.debug(f" [对话目标] Pydantic验证成功: goal_type={result['goal_type']}") else: result = None @@ -594,25 +594,25 @@ async def _plan_dynamic_stages( # 使用提示词保护包装 protected_prompt = self.prompt_protection.wrap_prompt(prompt, register_for_filter=True) - # ✅ Debug日志: 输出发送给LLM的prompt - logger.debug(f"🔍 [对话目标-动态规划阶段] LLM Prompt:\n{prompt}") + # Debug日志: 输出发送给LLM的prompt + logger.debug(f" [对话目标-动态规划阶段] LLM Prompt:\n{prompt}") - # ✅ 使用提炼模型(refine)进行阶段规划 + # 使用提炼模型(refine)进行阶段规划 response = await self.llm.refine_chat_completion( prompt=protected_prompt, temperature=0.5, max_tokens=150 ) - logger.debug(f"🔍 [对话目标-动态规划阶段] LLM Response: {response}") + logger.debug(f" [对话目标-动态规划阶段] LLM Response: {response}") # 消毒响应 try: sanitized_response, report = self.prompt_protection.sanitize_response(response) - logger.debug(f"🔍 [对话目标-动态规划阶段] 消毒后响应: {sanitized_response}") + logger.debug(f" [对话目标-动态规划阶段] 消毒后响应: {sanitized_response}") except Exception as sanitize_error: logger.error(f"消毒响应失败: {sanitize_error}", exc_info=True) - sanitized_response = response # 使用原始响应 + sanitized_response = response # 使用原始响应 # 使用guardrails验证和清理JSON try: @@ -798,8 +798,8 @@ async def _analyze_conversation_intent( # 使用提示词保护包装 protected_prompt = self.prompt_protection.wrap_prompt(prompt, register_for_filter=True) - # ✅ Debug日志: 输出发送给LLM的prompt - logger.debug(f"🔍 [对话目标-意图分析] LLM Prompt:\n{prompt}") + # Debug日志: 输出发送给LLM的prompt + logger.debug(f" [对话目标-意图分析] LLM Prompt:\n{prompt}") response = await self.llm.refine_chat_completion( prompt=protected_prompt, @@ -807,22 +807,22 @@ async def _analyze_conversation_intent( max_tokens=300 ) - logger.debug(f"🔍 [对话目标-意图分析] LLM 
Response: {response}") + logger.debug(f" [对话目标-意图分析] LLM Response: {response}") # 消毒响应 try: sanitized_response, report = self.prompt_protection.sanitize_response(response) - logger.debug(f"🔍 [对话目标-意图分析] 消毒后响应: {sanitized_response}") + logger.debug(f" [对话目标-意图分析] 消毒后响应: {sanitized_response}") except Exception as sanitize_error: logger.error(f"消毒响应失败: {sanitize_error}", exc_info=True) - sanitized_response = response # 使用原始响应 + sanitized_response = response # 使用原始响应 - # ✅ 使用Guardrails Pydantic模型验证和解析JSON + # 使用Guardrails Pydantic模型验证和解析JSON try: # 直接解析已有的响应文本 parsed_result = self.guardrails.parse_json_direct( sanitized_response, - model_class=self.ConversationIntentAnalysis # 使用正确的模型引用 + model_class=self.ConversationIntentAnalysis # 使用正确的模型引用 ) if parsed_result: @@ -839,7 +839,7 @@ async def _analyze_conversation_intent( "user_engagement": parsed_result.user_engagement, "reasoning": parsed_result.reasoning } - logger.debug(f"✅ [对话目标] 意图分析Pydantic验证成功") + logger.debug(f" [对话目标] 意图分析Pydantic验证成功") else: analysis = None diff --git a/services/learning_quality_monitor.py b/services/quality/learning_quality_monitor.py similarity index 92% rename from services/learning_quality_monitor.py rename to services/quality/learning_quality_monitor.py index c4b12ed..da1a709 100644 --- a/services/learning_quality_monitor.py +++ b/services/quality/learning_quality_monitor.py @@ -1,6 +1,7 @@ """ 学习质量监控服务 - 监控学习效果,防止人格崩坏 """ +import asyncio import json import time import re # 移动到文件顶部 @@ -11,30 +12,30 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 -from ..config import PluginConfig +from ...config import PluginConfig -from ..exceptions import StyleAnalysisError +from ...exceptions import StyleAnalysisError -from ..utils.json_utils import safe_parse_llm_json +from ...utils.json_utils import safe_parse_llm_json @dataclass 
class PersonaMetrics: """人格指标""" - consistency_score: float = 0.0 # 一致性得分 - style_stability: float = 0.0 # 风格稳定性 - vocabulary_diversity: float = 0.0 # 词汇多样性 - emotional_balance: float = 0.0 # 情感平衡性 - coherence_score: float = 0.0 # 逻辑连贯性 + consistency_score: float = 0.0 # 一致性得分 + style_stability: float = 0.0 # 风格稳定性 + vocabulary_diversity: float = 0.0 # 词汇多样性 + emotional_balance: float = 0.0 # 情感平衡性 + coherence_score: float = 0.0 # 逻辑连贯性 @dataclass class LearningAlert: """学习警报""" alert_type: str - severity: str # low, medium, high, critical + severity: str # low, medium, high, critical message: str timestamp: str metrics: Dict[str, float] @@ -55,9 +56,9 @@ def __init__(self, config: PluginConfig, context: Context, self.llm_adapter = llm_adapter # 监控阈值 - 调整为更合理的值 - self.consistency_threshold = 0.5 # 一致性阈值 (从0.7降低到0.5) - self.stability_threshold = 0.4 # 稳定性阈值 (从0.6降低到0.4) - self.drift_threshold = 0.4 # 风格偏移阈值 (从0.3提高到0.4) + self.consistency_threshold = 0.5 # 一致性阈值 (从0.7降低到0.5) + self.stability_threshold = 0.4 # 稳定性阈值 (从0.6降低到0.4) + self.drift_threshold = 0.4 # 风格偏移阈值 (从0.3提高到0.4) # 历史指标存储 self.historical_metrics: List[PersonaMetrics] = [] @@ -71,25 +72,19 @@ async def evaluate_learning_batch(self, learning_messages: List[Dict[str, Any]]) -> PersonaMetrics: """评估学习批次质量""" try: - # 计算各项指标 - consistency_score = await self._calculate_consistency( - original_persona, updated_persona - ) - - style_stability = await self._calculate_style_stability( - learning_messages - ) - - vocabulary_diversity = await self._calculate_vocabulary_diversity( - learning_messages - ) - - emotional_balance = await self._calculate_emotional_balance( - learning_messages - ) - - coherence_score = await self._calculate_coherence( - updated_persona + # 并行计算各项指标 + ( + consistency_score, + style_stability, + vocabulary_diversity, + emotional_balance, + coherence_score, + ) = await asyncio.gather( + self._calculate_consistency(original_persona, updated_persona), + 
self._calculate_style_stability(learning_messages), + self._calculate_vocabulary_diversity(learning_messages), + self._calculate_emotional_balance(learning_messages), + self._calculate_coherence(updated_persona), ) metrics = PersonaMetrics( @@ -126,10 +121,10 @@ async def _calculate_consistency(self, # 增强的空值检查和默认值处理 if not original_prompt and not updated_prompt: logger.debug("原始和更新人格都为空,返回中等一致性") - return 0.7 # 提高默认值,因为两者都空可以认为是一致的 + return 0.7 # 提高默认值,因为两者都空可以认为是一致的 elif not original_prompt or not updated_prompt: logger.debug("其中一个人格为空,返回较低一致性") - return 0.6 # 提高默认值,避免因数据问题导致的低分 + return 0.6 # 提高默认值,避免因数据问题导致的低分 # 如果两个prompt完全相同,直接返回高一致性 if original_prompt.strip() == updated_prompt.strip(): @@ -159,7 +154,7 @@ async def _calculate_consistency(self, r'一致性[::]\s*([0-9]*\.?[0-9]+)', r'得分[::]\s*([0-9]*\.?[0-9]+)', r'分数[::]\s*([0-9]*\.?[0-9]+)', - r'([0-9]*\.?[0-9]+)', # 任何数字 + r'([0-9]*\.?[0-9]+)', # 任何数字 ] for pattern in score_patterns: @@ -171,7 +166,7 @@ async def _calculate_consistency(self, if score > 1.0: score = score / 100.0 # 确保分数在合理范围内 - consistency_score = max(0.1, min(score, 1.0)) # 最低0.1,避免0.0 + consistency_score = max(0.1, min(score, 1.0)) # 最低0.1,避免0.0 logger.debug(f"解析得到一致性得分: {consistency_score}") return consistency_score except ValueError: @@ -190,16 +185,16 @@ async def _calculate_consistency(self, return 0.4 else: logger.debug("无法解析一致性评估,返回中等默认值") - return 0.6 # 提高默认值 + return 0.6 # 提高默认值 except (ValueError, IndexError) as e: logger.warning(f"解析一致性得分失败: {e}, 响应: {consistency_text}") - return 0.6 # 提高默认值 + return 0.6 # 提高默认值 else: logger.warning("LLM一致性评估无响应") - return 0.6 # 提高默认值 + return 0.6 # 提高默认值 except Exception as e: logger.error(f"框架适配器计算人格一致性失败: {e}") - return 0.6 # 提高默认值 + return 0.6 # 提高默认值 else: logger.warning("没有可用的Filter Provider,使用简单文本相似度计算") # 简单的文本相似度计算作为后备方案 @@ -207,7 +202,7 @@ async def _calculate_consistency(self, except Exception as e: logger.error(f"计算人格一致性失败: {e}") - return 0.6 # 提高默认值,避免阻塞学习 + return 0.6 # 提高默认值,避免阻塞学习 def 
_calculate_text_similarity(self, text1: str, text2: str) -> float: """计算文本相似度作为后备方案""" @@ -343,7 +338,7 @@ async def _calculate_emotional_balance(self, messages: List[Dict[str, Any]]) -> # 计算情感平衡性:积极情感减去消极情感,再调整到0-1范围 positive_score = emotional_scores.get("积极", 0.5) negative_score = emotional_scores.get("消极", 0.5) - balance_score = (positive_score - negative_score + 1.0) / 2.0 # 转换到0-1范围 + balance_score = (positive_score - negative_score + 1.0) / 2.0 # 转换到0-1范围 return max(0.0, min(balance_score, 1.0)) else: return self._simple_emotional_balance(messages) @@ -369,7 +364,7 @@ def _simple_emotional_balance(self, messages: List[Dict[str, Any]]) -> float: total_emotional = pos_count + neg_count if total_emotional == 0: - return 0.8 # 中性情感 + return 0.8 # 中性情感 # 计算平衡性(越接近0.5越平衡) pos_ratio = pos_count / total_emotional @@ -479,7 +474,7 @@ def _get_punctuation_ratio(self, text: str) -> float: def _count_emoji(self, text: str) -> int: """统计表情符号数量""" # 简单的表情符号检测 - emoji_patterns = ['😀', '😂', '😊', '🤔', '👍', '❤️', '🎉'] + emoji_patterns = ['', '', '', '', '', '', ''] count = 0 for emoji in emoji_patterns: count += text.count(emoji) diff --git a/services/quality/tiered_learning_trigger.py b/services/quality/tiered_learning_trigger.py new file mode 100644 index 0000000..af3a39c --- /dev/null +++ b/services/quality/tiered_learning_trigger.py @@ -0,0 +1,338 @@ +""" +Tiered learning trigger mechanism. + +Replaces the legacy fixed-threshold trigger system with a two-tier +architecture that separates lightweight per-message operations (Tier 1) +from LLM-heavy batch operations (Tier 2). 
+ +Tier 1 (per message, < 5 ms each): + * Statistical jargon filter update + * Memory ingestion (mem0 / legacy) + * Knowledge graph ingestion (LightRAG / legacy) + * Exemplar candidate screening + +Tier 2 (batch, LLM-gated, cooldown-protected): + * Jargon meaning inference on top statistical candidates + * Social sentiment batch analysis + * Expression pattern learning + +Design notes: + - Each Tier 1 operation is executed with individual error isolation + so one failure cannot block the others. + - Tier 2 triggers are gated by *configurable* message-count + thresholds **and** wall-clock cooldowns; either condition can be + satisfied independently to handle both high-traffic and low-traffic + groups. + - An optional event-driven fast-path lets Tier 2 fire early when the + statistical filter detects a strong new-term signal. + - All state is per-group; no cross-group interference. + - Thread-safe for single-event-loop asyncio usage. +""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple + +from astrbot.api import logger + +from ...core.interfaces import MessageData + + +# Type aliases + +# Internal alias: once registered, a callback is always a real callable. +_AsyncCallable = Callable[..., Coroutine[Any, Any, Any]] + +# Public-facing alias: accepts None from callers to allow conditional wiring. +_OptionalAsyncCallback = Optional[_AsyncCallable] + + +# Per-group trigger state + +@dataclass +class _GroupTriggerState: + """Mutable per-group state tracked by the trigger.""" + + # Counters + message_count: int = 0 + total_processed: int = 0 + + # Per-operation last-execution timestamps (keyed by operation name). + last_op_times: Dict[str, float] = field(default_factory=dict) + + # Accumulated interactions for social sentiment batch. + pending_interactions: List[Dict[str, str]] = field(default_factory=list) + + # Consecutive Tier 1 failure count for observability. 
+ consecutive_tier1_errors: int = 0 + + +# Tier 2 trigger policy + +@dataclass(frozen=True) +class BatchTriggerPolicy: + """Configurable policy for gating Tier 2 batch operations. + + A Tier 2 operation is triggered when **either** the message-count + threshold **or** the maximum time interval is reached, whichever + comes first. This ensures both high-traffic groups (hit count + quickly) and low-traffic groups (hit time limit) get timely + processing. + """ + + message_threshold: int = 15 + cooldown_seconds: float = 120.0 + + +# Result container + +@dataclass +class TriggerResult: + """Outcome of a ``process_message`` invocation.""" + + tier1_ok: bool = True + tier1_details: Dict[str, bool] = field(default_factory=dict) + tier2_triggered: bool = False + tier2_details: Dict[str, bool] = field(default_factory=dict) + + +# Main class + +class TieredLearningTrigger: + """Orchestrates tiered learning operations for incoming messages. + + Usage:: + + trigger = TieredLearningTrigger() + trigger.register_tier1("memory", memory_manager.add_memory_from_message) + trigger.register_tier2("jargon", jargon_batch_callback, policy) + result = await trigger.process_message(message, group_id) + """ + + def __init__(self) -> None: + # Per-group mutable state. + self._states: Dict[str, _GroupTriggerState] = {} + + # Registered operations. + # Tier 1: name -> async callable(message, group_id) + self._tier1_ops: Dict[str, _AsyncCallable] = {} + # Tier 2: name -> (async callable(group_id), policy) + self._tier2_ops: Dict[str, Tuple[_AsyncCallable, BatchTriggerPolicy]] = {} + + # Registration + + def register_tier1( + self, + name: str, + callback: _OptionalAsyncCallback, + ) -> None: + """Register a per-message Tier 1 operation. + + The callback signature must be:: + + async def callback(message: MessageData, group_id: str) -> None + + Callbacks are executed concurrently for every incoming message. + Errors in one callback do not affect others. 
+ """ + if callback is None: + return + if not asyncio.iscoroutinefunction(callback): + raise TypeError( + f"Tier 1 callback '{name}' must be an async function, " + f"got {type(callback)!r}" + ) + self._tier1_ops[name] = callback + logger.debug(f"[TieredTrigger] Registered Tier 1 op: {name}") + + def register_tier2( + self, + name: str, + callback: _OptionalAsyncCallback, + policy: Optional[BatchTriggerPolicy] = None, + ) -> None: + """Register a batch Tier 2 operation. + + The callback signature must be:: + + async def callback(group_id: str) -> None + + The operation fires when the group's message count exceeds + ``policy.message_threshold`` **or** ``policy.cooldown_seconds`` + have elapsed since the last execution, whichever comes first. + """ + if callback is None: + return + if not asyncio.iscoroutinefunction(callback): + raise TypeError( + f"Tier 2 callback '{name}' must be an async function, " + f"got {type(callback)!r}" + ) + self._tier2_ops[name] = ( + callback, + policy or BatchTriggerPolicy(), + ) + logger.debug(f"[TieredTrigger] Registered Tier 2 op: {name}") + + # Main entry point + + async def process_message( + self, + message: MessageData, + group_id: str, + ) -> TriggerResult: + """Process an incoming message through all registered tiers. + + Returns a :class:`TriggerResult` summarising what was executed. + """ + state = self._get_state(group_id) + result = TriggerResult() + + # ---- Tier 1: always execute (concurrent, error-isolated) ---- + result.tier1_details = await self._execute_tier1( + message, group_id, state + ) + # tier1_ok is True only when at least one op ran and all succeeded. + result.tier1_ok = ( + bool(result.tier1_details) + and all(result.tier1_details.values()) + ) + + # Update counters. + state.message_count += 1 + state.total_processed += 1 + + # ---- Tier 2: check each registered batch operation ---- + # Each operation has its own counter/cooldown gate. 
When any + # operation fires, the shared message counter resets so that + # all Tier 2 ops start their count window fresh. The time-based + # fallback ensures low-traffic groups still trigger eventually. + now = time.time() + for name, (callback, policy) in self._tier2_ops.items(): + last_time = state.last_op_times.get(name, 0.0) + count_ok = state.message_count >= policy.message_threshold + time_ok = (now - last_time) >= policy.cooldown_seconds + + if count_ok or time_ok: + ok = await self._execute_tier2_op( + name, callback, group_id, state + ) + result.tier2_details[name] = ok + result.tier2_triggered = True + + if result.tier2_triggered: + state.message_count = 0 + + return result + + # Event-driven fast-path + + async def force_tier2( + self, + name: str, + group_id: str, + ) -> bool: + """Force-trigger a specific Tier 2 operation outside the normal + schedule (e.g. when the statistical filter detects a strong + new-term signal). + + Returns ``True`` if the operation executed successfully. + """ + if name not in self._tier2_ops: + return False + + state = self._get_state(group_id) + callback, _ = self._tier2_ops[name] + return await self._execute_tier2_op(name, callback, group_id, state) + + # Inspection / statistics + + def get_group_stats(self, group_id: str) -> Dict[str, Any]: + """Return trigger statistics for a group.""" + state = self._states.get(group_id) + if not state: + return {"active": False} + + return { + "active": True, + "message_count": state.message_count, + "total_processed": state.total_processed, + "last_op_times": dict(state.last_op_times), + "pending_interactions": len(state.pending_interactions), + "consecutive_tier1_errors": state.consecutive_tier1_errors, + } + + # Internals + + def _get_state(self, group_id: str) -> _GroupTriggerState: + if group_id not in self._states: + # Initialise last_op_times to "now" so that Tier 2 operations + # do not fire on the very first message of a new group. 
+ state = _GroupTriggerState() + now = time.time() + for name in self._tier2_ops: + state.last_op_times[name] = now + self._states[group_id] = state + return self._states[group_id] + + async def _execute_tier1( + self, + message: MessageData, + group_id: str, + state: _GroupTriggerState, + ) -> Dict[str, bool]: + """Run all Tier 1 operations concurrently with error isolation.""" + if not self._tier1_ops: + return {} + + names = list(self._tier1_ops.keys()) + callbacks = list(self._tier1_ops.values()) + + async def _safe_run(op_name: str, cb: _AsyncCallable) -> bool: + try: + await cb(message, group_id) + return True + except Exception as exc: + logger.debug( + f"[TieredTrigger] Tier 1 op '{op_name}' failed: {exc}" + ) + return False + + results = await asyncio.gather( + *(_safe_run(n, c) for n, c in zip(names, callbacks)), + return_exceptions=False, + ) + + details = dict(zip(names, results)) + + # Track consecutive failures for observability. + if not all(results): + state.consecutive_tier1_errors += 1 + else: + state.consecutive_tier1_errors = 0 + + return details + + async def _execute_tier2_op( + self, + name: str, + callback: _AsyncCallable, + group_id: str, + state: _GroupTriggerState, + ) -> bool: + """Execute a single Tier 2 operation with error handling.""" + try: + await callback(group_id) + state.last_op_times[name] = time.time() + logger.debug( + f"[TieredTrigger] Tier 2 op '{name}' completed for " + f"group {group_id}" + ) + return True + except Exception as exc: + logger.warning( + f"[TieredTrigger] Tier 2 op '{name}' failed for " + f"group {group_id}: {exc}" + ) + return False diff --git a/services/reranker/__init__.py b/services/reranker/__init__.py new file mode 100644 index 0000000..c84edb5 --- /dev/null +++ b/services/reranker/__init__.py @@ -0,0 +1,28 @@ +""" +Reranker provider abstraction layer. + +Provides a plugin-level ``IRerankProvider`` interface that delegates to +AstrBot framework's ``RerankProvider`` via a thin adapter. 
+ +Public API:: + + from services.reranker import ( + IRerankProvider, + RerankResult, + RerankProviderError, + RerankProviderFactory, + FrameworkRerankAdapter, + ) +""" + +from .base import IRerankProvider, RerankProviderError, RerankResult +from .factory import RerankProviderFactory +from .framework_adapter import FrameworkRerankAdapter + +__all__ = [ + "IRerankProvider", + "RerankResult", + "RerankProviderError", + "RerankProviderFactory", + "FrameworkRerankAdapter", +] diff --git a/services/reranker/base.py b/services/reranker/base.py new file mode 100644 index 0000000..ef86053 --- /dev/null +++ b/services/reranker/base.py @@ -0,0 +1,67 @@ +""" +Reranker provider interface and value objects. + +Defines the abstract contract for document reranking, aligned with +AstrBot framework's ``RerankProvider`` interface. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass(frozen=True) +class RerankResult: + """Single reranking result. + + Attributes: + index: Original index in the candidate document list. + relevance_score: Relevance score assigned by the reranker. + """ + + index: int + relevance_score: float + + +class IRerankProvider(ABC): + """Abstract reranker provider interface. + + Method signatures are aligned with AstrBot framework's + ``RerankProvider`` to allow zero-transformation delegation. + """ + + @abstractmethod + async def rerank( + self, + query: str, + documents: List[str], + top_n: Optional[int] = None, + ) -> List[RerankResult]: + """Rerank documents by relevance to the query. + + Args: + query: The query string. + documents: List of candidate document texts. + top_n: Maximum number of results to return. + If ``None``, returns all documents ranked. + + Returns: + Sorted list of ``RerankResult`` (highest relevance first). + + Raises: + RerankProviderError: On provider communication failure. 
+ """ + + @abstractmethod + def get_model_name(self) -> str: + """Return the model identifier string.""" + + async def close(self) -> None: + """Release any resources held by the provider. + + Default implementation is a no-op. + """ + + +class RerankProviderError(Exception): + """Raised when a reranker provider encounters an unrecoverable error.""" diff --git a/services/reranker/factory.py b/services/reranker/factory.py new file mode 100644 index 0000000..956ff80 --- /dev/null +++ b/services/reranker/factory.py @@ -0,0 +1,82 @@ +""" +Reranker provider factory. + +Creates ``IRerankProvider`` instances by resolving AstrBot framework +providers via ``context.get_provider_by_id(provider_id)``. +""" + +from typing import Optional + +from astrbot.api import logger +from astrbot.core.provider.provider import RerankProvider as FrameworkRerankProvider + +from .base import IRerankProvider +from .framework_adapter import FrameworkRerankAdapter + + +class RerankProviderFactory: + """Factory for creating reranker provider instances. + + Usage:: + + reranker = RerankProviderFactory.create(config, context) + if reranker: + results = await reranker.rerank("query", ["doc1", "doc2"]) + """ + + @staticmethod + def create(config, context) -> Optional[IRerankProvider]: + """Create a reranker provider from plugin configuration. + + Args: + config: ``PluginConfig`` instance with ``rerank_provider_id``. + context: AstrBot plugin context. + + Returns: + An ``IRerankProvider`` instance, or ``None`` if not configured. 
+ """ + provider_id = getattr(config, "rerank_provider_id", None) + + if not provider_id: + logger.debug( + "[RerankFactory] No rerank_provider_id configured, " + "reranking disabled" + ) + return None + + if context is None: + logger.warning( + "[RerankFactory] AstrBot context is None, " + "cannot resolve reranker provider" + ) + return None + + try: + provider = context.get_provider_by_id(provider_id) + except Exception as exc: + logger.warning( + f"[RerankFactory] Failed to look up provider " + f"'{provider_id}': {exc}" + ) + return None + + if provider is None: + logger.warning( + f"[RerankFactory] Provider '{provider_id}' not found " + f"in framework registry" + ) + return None + + if not isinstance(provider, FrameworkRerankProvider): + logger.warning( + f"[RerankFactory] Provider '{provider_id}' is " + f"{type(provider).__name__}, expected RerankProvider" + ) + return None + + adapter = FrameworkRerankAdapter(provider) + logger.info( + f"[RerankFactory] Resolved reranker provider: " + f"id={provider_id}, model={adapter.get_model_name()}" + ) + return adapter diff --git a/services/reranker/framework_adapter.py b/services/reranker/framework_adapter.py new file mode 100644 index 0000000..8079129 --- /dev/null +++ b/services/reranker/framework_adapter.py @@ -0,0 +1,65 @@ +""" +Framework reranker adapter. + +Thin adapter wrapping AstrBot's ``RerankProvider`` behind the plugin's +``IRerankProvider`` interface. Translates framework ``RerankResult`` +to the plugin's own dataclass to avoid tight coupling. +""" + +from typing import List, Optional + +from astrbot.api import logger +from astrbot.core.provider.provider import RerankProvider as FrameworkRerankProvider +from astrbot.core.provider.entities import RerankResult as FrameworkRerankResult + +from .base import IRerankProvider, RerankResult, RerankProviderError + + +class FrameworkRerankAdapter(IRerankProvider): + """Adapter bridging AstrBot ``RerankProvider`` → plugin ``IRerankProvider``. 
+ + Args: + provider: A fully-initialised AstrBot ``RerankProvider`` instance. + """ + + def __init__(self, provider: FrameworkRerankProvider) -> None: + if provider is None: + raise ValueError("provider must not be None") + self._provider = provider + + async def rerank( + self, + query: str, + documents: List[str], + top_n: Optional[int] = None, + ) -> List[RerankResult]: + try: + framework_results: List[FrameworkRerankResult] = ( + await self._provider.rerank(query, documents, top_n) + ) + return [ + RerankResult( + index=r.index, + relevance_score=r.relevance_score, + ) + for r in framework_results + ] + except Exception as exc: + raise RerankProviderError( + f"Framework rerank call failed: {exc}" + ) from exc + + def get_model_name(self) -> str: + return self._provider.get_model() + + async def close(self) -> None: + # Framework manages its own provider lifecycle; nothing to release. + pass + + @property + def provider_id(self) -> str: + """Return the framework provider's unique identifier.""" + try: + return self._provider.meta().id + except (ValueError, KeyError): + return "" diff --git a/services/response/__init__.py b/services/response/__init__.py new file mode 100644 index 0000000..1f00dc7 --- /dev/null +++ b/services/response/__init__.py @@ -0,0 +1,15 @@ +"""Response generation, diversity, and quality control.""" + +from .prompt_sanitizer import PromptProtectionService +from .intelligent_chat_service import IntelligentChatService +from .response_diversity_manager import ResponseDiversityManager +from .style_analyzer import StyleAnalyzerService +from .intelligent_responder import IntelligentResponder + +__all__ = [ + "PromptProtectionService", + "IntelligentChatService", + "ResponseDiversityManager", + "StyleAnalyzerService", + "IntelligentResponder", +] diff --git a/services/intelligent_chat_service.py b/services/response/intelligent_chat_service.py similarity index 100% rename from services/intelligent_chat_service.py rename to 
services/response/intelligent_chat_service.py diff --git a/services/intelligent_responder.py b/services/response/intelligent_responder.py similarity index 85% rename from services/intelligent_responder.py rename to services/response/intelligent_responder.py index 41faad2..7c3446d 100644 --- a/services/intelligent_responder.py +++ b/services/response/intelligent_responder.py @@ -12,11 +12,11 @@ from astrbot.api.event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 -from ..config import PluginConfig +from ...config import PluginConfig -from ..exceptions import ResponseError +from ...exceptions import ResponseError class IntelligentResponder: @@ -29,8 +29,8 @@ class IntelligentResponder: RECENT_MESSAGES_LIMIT = 5 PROMPT_MESSAGE_LENGTH_LIMIT = 50 PROMPT_RESPONSE_WORD_LIMIT = 100 - DAILY_RESPONSE_STATS_PERIOD_SECONDS = 86400 # 24小时 - GROUP_ATMOSPHERE_PERIOD_SECONDS = 3600 # 1小时 + DAILY_RESPONSE_STATS_PERIOD_SECONDS = 86400 # 24小时 + GROUP_ATMOSPHERE_PERIOD_SECONDS = 3600 # 1小时 GROUP_ACTIVITY_HIGH_THRESHOLD = 10 def __init__(self, config: PluginConfig, context: Context, db_manager, @@ -41,16 +41,16 @@ def __init__(self, config: PluginConfig, context: Context, db_manager, self.context = context self.db_manager = db_manager self.prompts = prompts - self.affection_manager = affection_manager # 添加好感度管理器 - self.diversity_manager = diversity_manager # 添加多样性管理器 - self.social_context_injector = social_context_injector # 添加社交上下文注入器 + self.affection_manager = affection_manager # 添加好感度管理器 + self.diversity_manager = diversity_manager # 添加多样性管理器 + self.social_context_injector = social_context_injector # 添加社交上下文注入器 # 使用框架适配器 self.llm_adapter = llm_adapter # 设置默认回复策略 - 不依赖配置文件 - self.enable_intelligent_reply = True # 默认启用智能回复 - self.context_window_size = 5 # 默认上下文窗口大小 + self.enable_intelligent_reply = True # 
默认启用智能回复 + self.context_window_size = 5 # 默认上下文窗口大小 logger.info("智能回复器初始化完成 - 使用默认配置") @@ -116,7 +116,7 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O """生成自学习可能需要用到的的智能回复文本(修改版 - 增量更新在SYSTEM_PROMPT中)""" try: sender_id = event.get_sender_id() - group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID + group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID message_text = event.get_message_str() # 收集上下文信息 @@ -134,11 +134,11 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O logger.info(f"开始注入多样性增强到system_prompt (当前长度: {len(enhanced_system_prompt)})") enhanced_system_prompt = await self.diversity_manager.build_diversity_prompt_injection( enhanced_system_prompt, - group_id=group_id, # ✅ 传入group_id以获取历史消息 + group_id=group_id, # 传入group_id以获取历史消息 inject_style=True, inject_pattern=True, inject_variation=True, - inject_history=True # ✅ 注入历史Bot消息,避免重复 + inject_history=True # 注入历史Bot消息,避免重复 ) logger.info(f"多样性注入后system_prompt长度: {len(enhanced_system_prompt)}") @@ -156,7 +156,7 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O randomize=True ) else: - temperature = 0.7 # 默认值 + temperature = 0.7 # 默认值 # 调用框架的默认LLM provider = self.context.get_using_provider() @@ -167,7 +167,7 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O # 使用框架适配器 if self.llm_adapter and self.llm_adapter.has_refine_provider(): try: - # ✅ 将enhanced_system_prompt合并到prompt参数中,而不是使用system_prompt参数 + # 将enhanced_system_prompt合并到prompt参数中,而不是使用system_prompt参数 # 这样可以确保所有Provider都能看到完整的增强内容 combined_prompt = f"{enhanced_system_prompt}\n\n【当前用户消息】\n{message_text}" @@ -177,16 +177,16 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O logger.debug(f"多样性增强部分长度: {len(enhanced_system_prompt)}, 用户消息长度: {len(message_text)}") response = await self.llm_adapter.refine_chat_completion( - 
prompt=combined_prompt, # 包含增强系统提示词 + 用户消息 - system_prompt=None, # 不使用system_prompt参数,避免Provider兼容性问题 - temperature=temperature, # 动态temperature + prompt=combined_prompt, # 包含增强系统提示词 + 用户消息 + system_prompt=None, # 不使用system_prompt参数,避免Provider兼容性问题 + temperature=temperature, # 动态temperature max_tokens=self.PROMPT_RESPONSE_WORD_LIMIT ) if response: response_text = response.strip() - # ✅ 提示词保护:消毒LLM回复,移除泄露的提示词 + # 提示词保护:消毒LLM回复,移除泄露的提示词 if self.diversity_manager: try: sanitized_response, sanitize_report = self.diversity_manager.sanitize_llm_response(response_text) @@ -197,13 +197,13 @@ async def generate_intelligent_response_text(self, event: AstrMessageEvent) -> O except Exception as sanitize_error: logger.warning(f"回复消毒失败(不影响回复): {sanitize_error}") - # ✅ 保存Bot消息到数据库 (用于多样性分析和避免同质化) + # 保存Bot消息到数据库 (用于多样性分析和避免同质化) try: await self.db_manager.save_bot_message( group_id=group_id, user_id=sender_id, message=response_text, - response_to_message_id=None, # TODO: 可以关联原始消息ID + response_to_message_id=None, # TODO: 可以关联原始消息ID context_type='normal', temperature=temperature, language_style=current_language_style, @@ -233,7 +233,7 @@ async def generate_intelligent_response(self, event: AstrMessageEvent) -> Option """生成智能回复参数,用于传递给框架的request_llm""" try: sender_id = event.get_sender_id() - group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID + group_id = event.get_group_id() or event.get_sender_id() # 私聊时使用 sender_id 作为会话 ID message_text = event.get_message_str() logger.info(f"[生成智能回复] 开始处理: group_id={group_id}, sender_id={sender_id}, message_len={len(message_text)}") @@ -258,11 +258,11 @@ async def generate_intelligent_response(self, event: AstrMessageEvent) -> Option # 参数验证 if not enhanced_prompt or len(enhanced_prompt) == 0: - logger.error(f"[生成智能回复] ❌ 增强提示词为空!") + logger.error(f"[生成智能回复] 增强提示词为空!") return None if not curr_cid: - logger.error(f"[生成智能回复] ❌ 会话ID为空!") + logger.error(f"[生成智能回复] 会话ID为空!") return None # 返回request_llm所需的参数 @@ -272,7 
+272,7 @@ async def generate_intelligent_response(self, event: AstrMessageEvent) -> Option 'conversation': conversation } - logger.info(f"[生成智能回复] ✅ 智能回复参数生成成功: prompt_len={len(enhanced_prompt)}, conversation_len={len(conversation)}, session_id={curr_cid}") + logger.info(f"[生成智能回复] 智能回复参数生成成功: prompt_len={len(enhanced_prompt)}, conversation_len={len(conversation)}, session_id={curr_cid}") return result except Exception as e: @@ -282,8 +282,8 @@ async def generate_intelligent_response(self, event: AstrMessageEvent) -> Option async def _collect_context_info(self, group_id: str, sender_id: str, message: str) -> Dict[str, Any]: """收集上下文信息""" context_info = { - 'group_id': group_id, # 添加group_id字段 - 'sender_id': sender_id, # 添加sender_id字段 + 'group_id': group_id, # 添加group_id字段 + 'sender_id': sender_id, # 添加sender_id字段 'sender_profile': None, 'user_affection': None, 'social_relations': [], @@ -304,7 +304,7 @@ async def _collect_context_info(self, group_id: str, sender_id: str, message: st context_info['social_relations'] = [ rel for rel in all_relations if rel['from_user'] == sender_id or rel['to_user'] == sender_id - ][:5] # 限制前5个最强关系 + ][:5] # 限制前5个最强关系 # 获取最近的筛选消息 context_info['recent_messages'] = await self.db_manager.get_recent_filtered_messages(group_id, 5) @@ -329,7 +329,7 @@ async def _build_enhanced_system_prompt(self, context_info: Dict[str, Any]) -> s """ try: # 1. 
获取基础人格设定(原有的SYSTEM_PROMPT) - base_system_prompt = "你是一个友好、智能的助手。" # 默认 + base_system_prompt = "你是一个友好、智能的助手。" # 默认 try: persona = await self.context.persona_manager.get_default_persona_v3() @@ -391,9 +391,9 @@ async def _build_enhanced_system_prompt(self, context_info: Dict[str, Any]) -> s include_social_relations=getattr(self.config, 'include_social_relations', True), include_affection=getattr(self.config, 'include_affection_info', True), include_mood=getattr(self.config, 'include_mood_info', True), - include_expression_patterns=True # ✅ 启用表达模式注入 + include_expression_patterns=True # 启用表达模式注入 ) - logger.debug("✅ 社交上下文(含表达模式)已成功注入到系统提示词") + logger.debug(" 社交上下文(含表达模式)已成功注入到系统提示词") except Exception as e: logger.warning(f"社交上下文注入失败: {e}", exc_info=True) @@ -509,7 +509,7 @@ async def _build_context_enhancement(self, context_info: Dict[str, Any]) -> str: # 3. 社交关系图谱(增强版) if context_info.get('social_relations'): relations_details = [] - for rel in context_info['social_relations'][:5]: # 显示前5个关系 + for rel in context_info['social_relations'][:5]: # 显示前5个关系 strength_desc = "强" if rel['strength'] > 0.7 else "中" if rel['strength'] > 0.4 else "弱" relations_details.append( f"- 与{rel.get('to_user', '未知用户')}的关系强度: {rel['strength']:.2f}({strength_desc}), " @@ -536,7 +536,7 @@ async def _build_context_enhancement(self, context_info: Dict[str, Any]) -> str: # 5. 最近对话上下文(更详细) if context_info.get('recent_messages'): recent_context = [] - for i, msg in enumerate(context_info['recent_messages'][-5:], 1): # 最近5条 + for i, msg in enumerate(context_info['recent_messages'][-5:], 1): # 最近5条 quality_score = msg.get('quality_scores', {}) msg_quality = "高质量" if isinstance(quality_score, dict) and quality_score.get('overall', 0) > 0.7 else "普通" recent_context.append( @@ -588,7 +588,7 @@ async def _build_enhanced_prompt(self, context_info: Dict[str, Any], message: st prompt_parts.append("你正在参与一个真实的群聊对话,需要基于以下详细上下文信息进行自然、智能的回复:") # 2. 
当前人格状态 - 获取完整的人格信息(包含增量更新) - current_persona = "你是一个友好、智能的助手。" # 默认人格 + current_persona = "你是一个友好、智能的助手。" # 默认人格 persona_updates_info = "" try: @@ -607,7 +607,7 @@ async def _build_enhanced_prompt(self, context_info: Dict[str, Any], message: st update_pattern = r'【增量更新[^】]*】[^【]*' updates = re.findall(update_pattern, current_persona) if updates: - persona_updates_info = f"\n\n【当前活跃的人格增量更新】:\n" + "\n".join(updates[-3:]) # 取最近3个更新 + persona_updates_info = f"\n\n【当前活跃的人格增量更新】:\n" + "\n".join(updates[-3:]) # 取最近3个更新 logger.debug(f"获取到当前人格设定长度: {len(current_persona)} 字符") @@ -664,28 +664,26 @@ async def _get_conversation_context(self, group_id: str, sender_id: str) -> List async def _record_response(self, group_id: str, sender_id: str, original_message: str, response: str): """记录回复信息用于学习""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - + async with self.db_manager.get_session() as session: + from sqlalchemy import select + from ...models.orm.message import FilteredMessage + + now = int(time.time()) # 简化实现:filtered_messages 表用于记录所有经过筛选的消息,包括BOT的回复。 # 实际应用中,可能需要为BOT回复创建单独的表以区分。 - await cursor.execute(''' - INSERT OR IGNORE INTO filtered_messages - (message, sender_id, group_id, confidence, filter_reason, timestamp, used_for_learning) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- ''', ( - f"BOT回复: {response}", - "bot", - group_id, # 添加 group_id 字段 - 1.0, # 假设BOT回复的置信度为1.0 - f"回复{sender_id}: {original_message[:self.PROMPT_MESSAGE_LENGTH_LIMIT]}", # 使用常量 - time.time(), - False # BOT回复不用于学习,避免循环学习 - )) - - await conn.commit() - await cursor.close() - + filtered_msg = FilteredMessage( + message=f"BOT回复: {response}", + sender_id="bot", + group_id=group_id, + confidence=1.0, # 假设BOT回复的置信度为1.0 + filter_reason=f"回复{sender_id}: {original_message[:self.PROMPT_MESSAGE_LENGTH_LIMIT]}", # 使用常量 + timestamp=now, + created_at=now, + processed=False, # BOT回复不用于学习,避免循环学习 + ) + session.add(filtered_msg) + await session.commit() + except Exception as e: logger.error(f"记录回复失败: {e}") @@ -706,10 +704,10 @@ async def send_intelligent_response(self, event: AstrMessageEvent): try: response_params = await self.generate_intelligent_response(event) except ResponseError as re: - logger.error(f"[智能回复] ❌ 生成回复参数时发生ResponseError: {re}") + logger.error(f"[智能回复] 生成回复参数时发生ResponseError: {re}") return None except Exception as gen_error: - logger.error(f"[智能回复] ❌ 生成回复参数时发生未知错误: {gen_error}", exc_info=True) + logger.error(f"[智能回复] 生成回复参数时发生未知错误: {gen_error}", exc_info=True) return None if response_params: @@ -718,15 +716,15 @@ async def send_intelligent_response(self, event: AstrMessageEvent): # 验证关键参数 if not response_params.get('prompt'): - logger.error(f"[智能回复] ❌ prompt参数为空,无法发送回复") + logger.error(f"[智能回复] prompt参数为空,无法发送回复") return None if not response_params.get('session_id'): - logger.error(f"[智能回复] ❌ session_id参数为空,无法发送回复") + logger.error(f"[智能回复] session_id参数为空,无法发送回复") return None - logger.info(f"[智能回复] ✅ 参数验证通过,准备返回给main.py") - return response_params # 返回request_llm参数 + logger.info(f"[智能回复] 参数验证通过,准备返回给main.py") + return response_params # 返回request_llm参数 else: logger.warning(f"[智能回复] generate_intelligent_response 返回None") return None @@ -738,24 +736,27 @@ async def send_intelligent_response(self, event: AstrMessageEvent): async def get_response_statistics(self, 
group_id: str) -> Dict[str, Any]: """获取回复统计""" try: - conn = await self.db_manager.get_group_connection(group_id) - cursor = await conn.cursor() - - # 统计BOT回复次数 - await cursor.execute(''' - SELECT COUNT(*) - FROM filtered_messages - WHERE sender_id = 'bot' AND timestamp > ? - ''', (time.time() - self.DAILY_RESPONSE_STATS_PERIOD_SECONDS,)) # 最近24小时 - - row = await cursor.fetchone() - daily_responses = row[0] if row else 0 - - return { - 'daily_responses': daily_responses, - 'intelligent_reply_enabled': self.enable_intelligent_reply - } - + async with self.db_manager.get_session() as session: + from sqlalchemy import select, func + from ...models.orm.message import FilteredMessage + + # 统计BOT回复次数 + cutoff = time.time() - self.DAILY_RESPONSE_STATS_PERIOD_SECONDS + stmt = ( + select(func.count()) + .select_from(FilteredMessage) + .where( + FilteredMessage.sender_id == 'bot', + FilteredMessage.timestamp > cutoff, + ) + ) + daily_responses = (await session.execute(stmt)).scalar() or 0 + + return { + 'daily_responses': daily_responses, + 'intelligent_reply_enabled': self.enable_intelligent_reply + } + except Exception as e: logger.error(f"获取回复统计失败: {e}") return {} @@ -763,31 +764,29 @@ async def get_response_statistics(self, group_id: str) -> Dict[str, Any]: async def _analyze_group_atmosphere(self, group_id: str) -> Dict[str, Any]: """分析群氛围""" try: - # 从全局消息数据库获取连接 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - + async with self.db_manager.get_session() as session: + from sqlalchemy import select, func + from ...models.orm.message import RawMessage + # 分析最近消息的情感倾向 - await cursor.execute(''' - SELECT COUNT(*) as total_messages, - AVG(LENGTH(message)) as avg_length - FROM raw_messages - WHERE timestamp > ? 
- ''', (time.time() - self.GROUP_ATMOSPHERE_PERIOD_SECONDS,)) # 最近1小时 - - row = await cursor.fetchone() - - await cursor.close() - - total_messages = row[0] if row else 0 - avg_length = row[1] if row else 0.0 - + cutoff = time.time() - self.GROUP_ATMOSPHERE_PERIOD_SECONDS + stmt = select( + func.count().label('total_messages'), + func.avg(func.length(RawMessage.message)).label('avg_length'), + ).select_from(RawMessage).where( + RawMessage.timestamp > cutoff, + ) + row = (await session.execute(stmt)).one() + + total_messages = row.total_messages or 0 + avg_length = row.avg_length or 0.0 + return { 'activity_level': 'high' if total_messages > self.GROUP_ACTIVITY_HIGH_THRESHOLD else 'low', 'avg_message_length': avg_length, 'total_recent_messages': total_messages } - + except Exception as e: logger.error(f"分析群氛围失败: {e}") return {'activity_level': 'unknown'} diff --git a/services/prompt_sanitizer.py b/services/response/prompt_sanitizer.py similarity index 100% rename from services/prompt_sanitizer.py rename to services/response/prompt_sanitizer.py diff --git a/services/response_diversity_manager.py b/services/response/response_diversity_manager.py similarity index 96% rename from services/response_diversity_manager.py rename to services/response/response_diversity_manager.py index 1ab4f4a..f7c6048 100644 --- a/services/response_diversity_manager.py +++ b/services/response/response_diversity_manager.py @@ -27,10 +27,10 @@ def __init__(self, config, db_manager): # Temperature动态范围配置 self.temperature_ranges = { - 'creative': (0.8, 1.2), # 创意型回复 - 'normal': (0.6, 0.9), # 正常对话 - 'precise': (0.3, 0.6), # 精确分析 - 'stable': (0.2, 0.4) # 稳定输出 + 'creative': (0.8, 1.2), # 创意型回复 + 'normal': (0.6, 0.9), # 正常对话 + 'precise': (0.3, 0.6), # 精确分析 + 'stable': (0.2, 0.4) # 稳定输出 } # 语言风格池(定期轮换) @@ -51,7 +51,7 @@ def __init__(self, config, db_manager): # 提示词保护服务(延迟加载) self._prompt_protection = None - self._enable_protection = True # 默认启用保护 + self._enable_protection = True # 默认启用保护 # 当前使用的风格和模式 
(用于保存到数据库) self.current_language_style = None @@ -144,7 +144,7 @@ def get_dynamic_temperature(self, context_type: str = 'normal', randomize: bool except Exception as e: logger.error(f"获取动态Temperature失败: {e}") - return 0.7 # 默认值 + return 0.7 # 默认值 def get_random_language_style(self, avoid_recent: bool = True) -> str: """ @@ -241,12 +241,12 @@ async def build_diversity_prompt_injection(self, base_prompt: str, if inject_style: style = self.get_random_language_style() - self.current_language_style = style # ✅ 保存当前风格 + self.current_language_style = style # 保存当前风格 raw_prompts.append(f"当前语言风格:{style}") if inject_pattern: pattern = self.get_random_response_pattern() - self.current_response_pattern = pattern # ✅ 保存当前模式 + self.current_response_pattern = pattern # 保存当前模式 raw_prompts.append(f"推荐回复模式:{pattern}") if inject_variation: @@ -278,7 +278,7 @@ async def build_diversity_prompt_injection(self, base_prompt: str, history_text += "- 如果观点相似,也要用不同的表达方式,建议用一定的合理的倒装句、省略句等" raw_prompts.append(history_text) - logger.info(f"✅ 已注入 {len(recent_responses)} 条历史Bot消息到多样性提示") + logger.info(f" 已注入 {len(recent_responses)} 条历史Bot消息到多样性提示") else: logger.debug(f"群组 {group_id} 暂无历史Bot消息") except Exception as e: @@ -302,7 +302,7 @@ async def build_diversity_prompt_injection(self, base_prompt: str, # 使用元指令包装器包装所有多样性提示词 wrapped = protection.wrap_prompts(raw_prompts) enhanced_prompt = base_prompt + "\n\n" + wrapped - logger.info(f"✅ 多样性Prompt已保护包装 - 原长度: {len(base_prompt)}, 新长度: {len(enhanced_prompt)}") + logger.info(f" 多样性Prompt已保护包装 - 原长度: {len(base_prompt)}, 新长度: {len(enhanced_prompt)}") else: # 保护服务不可用,使用原始拼接 enhanced_prompt = base_prompt + "\n\n" + "\n\n".join([f"【{i+1}】\n{p}" for i, p in enumerate(raw_prompts)]) @@ -387,7 +387,7 @@ def get_sampling_parameters(self, diversity_level: str = 'medium') -> Dict[str, 'frequency_penalty': 0.8, 'presence_penalty': 0.6 } - else: # medium + else: # medium params = { 'temperature': 0.7, 'top_p': 0.9, diff --git a/services/style_analyzer.py 
b/services/response/style_analyzer.py similarity index 96% rename from services/style_analyzer.py rename to services/response/style_analyzer.py index 8f21530..0188505 100644 --- a/services/style_analyzer.py +++ b/services/response/style_analyzer.py @@ -1,6 +1,7 @@ """ 风格分析服务 - 使用强模型深度分析对话风格并提炼特征 """ +import asyncio import json import time from typing import Dict, List, Optional, Any @@ -10,16 +11,16 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 -from ..core.interfaces import AnalysisResult # 导入 AnalysisResult +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.interfaces import AnalysisResult # 导入 AnalysisResult -from ..config import PluginConfig +from ...config import PluginConfig -from ..exceptions import StyleAnalysisError, ModelAccessError +from ...exceptions import StyleAnalysisError, ModelAccessError -from .database_manager import DatabaseManager +from ..database import DatabaseManager -from ..utils.json_utils import safe_parse_llm_json +from ...utils.json_utils import safe_parse_llm_json @dataclass @@ -125,11 +126,11 @@ async def analyze_conversation_style(self, group_id: str, messages: List[Dict[st message_texts = [msg.get('message', '') for msg in messages] combined_text = '\n'.join(message_texts[:50]) # 限制长度避免token超限 - # 生成风格分析报告 - style_analysis = await self._generate_style_analysis(combined_text) - - # 提取数值化特征 - style_profile = await self._extract_style_profile(combined_text) + # 并行生成风格分析报告和提取数值化特征 + style_analysis, style_profile = await asyncio.gather( + self._generate_style_analysis(combined_text), + self._extract_style_profile(combined_text), + ) # 检测风格变化 style_evolution = None diff --git a/services/social/__init__.py b/services/social/__init__.py new file mode 100644 index 0000000..6579b4a --- /dev/null +++ b/services/social/__init__.py @@ -0,0 +1,15 @@ +"""Social relationship analysis and context injection.""" + +from 
.social_context_injector import SocialContextInjector +from .enhanced_social_relation_manager import EnhancedSocialRelationManager +from .social_relation_analyzer import SocialRelationAnalyzer +from .social_graph_analyzer import SocialGraphAnalyzer +from .message_relationship_analyzer import MessageRelationshipAnalyzer + +__all__ = [ + "SocialContextInjector", + "EnhancedSocialRelationManager", + "SocialRelationAnalyzer", + "SocialGraphAnalyzer", + "MessageRelationshipAnalyzer", +] diff --git a/services/enhanced_social_relation_manager.py b/services/social/enhanced_social_relation_manager.py similarity index 74% rename from services/enhanced_social_relation_manager.py rename to services/social/enhanced_social_relation_manager.py index adea126..40027d3 100644 --- a/services/enhanced_social_relation_manager.py +++ b/services/social/enhanced_social_relation_manager.py @@ -10,19 +10,18 @@ from astrbot.api import logger -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage -from ..core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...core.patterns import AsyncServiceBase +from ...core.interfaces import IDataStorage +from ...core.framework_llm_adapter import FrameworkLLMAdapter -from ..models.social_relation import ( +from ...models.social_relation import ( BloodRelationType, GeographicalRelationType, CareerRelationType, EmotionalRelationType, InterestRelationType, IntimacyLevel, RelationDuration, PowerStructure, SocialRelationComponent, UserSocialProfile, RelationChangeRule, RelationInfluenceOnPsychology ) -from ..utils.json_cleaner import LLMJSONParser class EnhancedSocialRelationManager(AsyncServiceBase): @@ -83,7 +82,7 @@ def _init_relation_difficulty(self) -> Dict[str, float]: "同村村民": 0.45, "同乡": 0.50, "同校": 0.55, - "同车乘客": 0.05, # 临时关系,易变 + "同车乘客": 0.05, # 临时关系,易变 # 业缘关系 - 中等到较高难度 "上下级": 0.65, @@ -123,7 +122,7 @@ def _init_relation_difficulty(self) -> 
Dict[str, float]: "借贷关系": 0.60, "生意伙伴": 0.55, "雇主雇员": 0.50, - "搭子关系": 0.15, # 临时功能关系,易变 + "搭子关系": 0.15, # 临时功能关系,易变 # 亲密度等级相关 "核心亲密": 0.90, @@ -194,7 +193,7 @@ def _init_relation_psych_influence(self) -> List[RelationInfluenceOnPsychology]: relation_value_threshold=0.6, interaction_type="compliment", psychological_impact={ - "情绪": 0.15, # 挚友的称赞让情绪大幅提升 + "情绪": 0.15, # 挚友的称赞让情绪大幅提升 "社交": 0.10, "精力": 0.05 }, @@ -206,7 +205,7 @@ def _init_relation_psych_influence(self) -> List[RelationInfluenceOnPsychology]: relation_value_threshold=0.6, interaction_type="insult", psychological_impact={ - "情绪": -0.25, # 挚友的侮辱伤害更深 + "情绪": -0.25, # 挚友的侮辱伤害更深 "社交": -0.15, "意志": -0.10 }, @@ -242,7 +241,7 @@ def _init_relation_psych_influence(self) -> List[RelationInfluenceOnPsychology]: relation_value_threshold=0.7, interaction_type="compliment", psychological_impact={ - "情绪": 0.20, # 恋人的赞美影响最大 + "情绪": 0.20, # 恋人的赞美影响最大 "社交": 0.12, "精力": 0.08, "兴趣": 0.05 @@ -255,7 +254,7 @@ def _init_relation_psych_influence(self) -> List[RelationInfluenceOnPsychology]: relation_value_threshold=0.7, interaction_type="insult", psychological_impact={ - "情绪": -0.30, # 恋人的伤害最深 + "情绪": -0.30, # 恋人的伤害最深 "社交": -0.20, "意志": -0.15, "精力": -0.10 @@ -365,7 +364,7 @@ async def update_relation( if not relation: relation = SocialRelationComponent( relation_type=relation_type_str, - value=0.5, # 初始中等强度 + value=0.5, # 初始中等强度 description=f"与 {to_user_id} 的{relation_type_str}关系" ) profile.add_relation(relation) @@ -583,65 +582,62 @@ async def get_relation_prompt_injection( self._logger.error(f"生成关系prompt注入失败: {e}") return "" - # ==================== 数据库操作 ==================== + # 数据库操作 async def _load_profile_from_db( self, user_id: str, group_id: str ) -> Optional[UserSocialProfile]: - """从数据库加载用户社交档案""" + """从数据库加载用户社交档案(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 加载档案统计 - await cursor.execute(''' - SELECT total_relations, significant_relations, 
dominant_relation_type, - created_at, last_updated - FROM user_social_profiles - WHERE user_id = ? AND group_id = ? - ''', (user_id, group_id)) - - row = await cursor.fetchone() - if not row: - return None + from sqlalchemy import select + from ...models.orm.social_relation import ( + UserSocialProfile as UserSocialProfileORM, + UserSocialRelationComponent as UserSocialRelationComponentORM, + ) + + async with self.db_manager.get_session() as session: + # 加载档案(带 eager-loaded relation_components) + stmt = select(UserSocialProfileORM).where( + UserSocialProfileORM.user_id == user_id, + UserSocialProfileORM.group_id == group_id, + ) + result = await session.execute(stmt) + profile_orm = result.scalar_one_or_none() - total, significant, dominant, created, updated = row + if not profile_orm: + return None profile = UserSocialProfile( user_id=user_id, group_id=group_id, - total_relations=total, - significant_relations=significant, - dominant_relation_type=dominant, - created_at=created, - last_updated=updated + total_relations=profile_orm.total_relations, + significant_relations=profile_orm.significant_relations, + dominant_relation_type=profile_orm.dominant_relation_type, + created_at=profile_orm.created_at, + last_updated=profile_orm.last_updated, ) # 加载所有关系组件 - await cursor.execute(''' - SELECT relation_type, value, frequency, last_interaction, - description, tags, created_at - FROM user_social_relation_components - WHERE from_user_id = ? AND group_id = ? 
- ''', (user_id, group_id)) - - for row in await cursor.fetchall(): - rel_type, value, freq, last_int, desc, tags_json, created = row + comp_stmt = select(UserSocialRelationComponentORM).where( + UserSocialRelationComponentORM.from_user_id == user_id, + UserSocialRelationComponentORM.group_id == group_id, + ) + comp_result = await session.execute(comp_stmt) + for comp in comp_result.scalars().all(): component = SocialRelationComponent( - relation_type=rel_type, - value=value, - frequency=freq, - last_interaction=last_int, - description=desc, - tags=json.loads(tags_json) if tags_json else [], - created_at=created + relation_type=comp.relation_type, + value=comp.value, + frequency=comp.frequency, + last_interaction=comp.last_interaction, + description=comp.description, + tags=json.loads(comp.tags) if comp.tags else [], + created_at=comp.created_at, ) profile.relations.append(component) - await cursor.close() return profile except Exception as e: @@ -649,56 +645,74 @@ async def _load_profile_from_db( return None async def _save_profile_to_db(self, profile: UserSocialProfile): - """保存用户社交档案到数据库""" + """保存用户社交档案到数据库(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # ✅ 使用数据库无关的语法:DELETE + INSERT 替代 INSERT OR REPLACE - # 先删除旧记录 - await cursor.execute(''' - DELETE FROM user_social_profiles - WHERE user_id = ? AND group_id = ? - ''', (profile.user_id, profile.group_id)) - - # 再插入新记录 - await cursor.execute(''' - INSERT INTO user_social_profiles - (user_id, group_id, total_relations, significant_relations, - dominant_relation_type, created_at, last_updated) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- ''', ( - profile.user_id, profile.group_id, profile.total_relations, - profile.significant_relations, profile.dominant_relation_type, - profile.created_at, time.time() - )) + from sqlalchemy import select, delete + from ...models.orm.social_relation import ( + UserSocialProfile as UserSocialProfileORM, + UserSocialRelationComponent as UserSocialRelationComponentORM, + ) + + async with self.db_manager.get_session() as session: + # 查找现有档案 + stmt = select(UserSocialProfileORM).where( + UserSocialProfileORM.user_id == profile.user_id, + UserSocialProfileORM.group_id == profile.group_id, + ) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + # 更新现有档案 + existing.total_relations = profile.total_relations + existing.significant_relations = profile.significant_relations + existing.dominant_relation_type = profile.dominant_relation_type + existing.last_updated = int(time.time()) + profile_id = existing.id + else: + # 创建新档案 + new_profile = UserSocialProfileORM( + user_id=profile.user_id, + group_id=profile.group_id, + total_relations=profile.total_relations, + significant_relations=profile.significant_relations, + dominant_relation_type=profile.dominant_relation_type, + created_at=profile.created_at or int(time.time()), + last_updated=int(time.time()), + ) + session.add(new_profile) + await session.flush() + profile_id = new_profile.id + + # 删除旧的关系组件 + await session.execute( + delete(UserSocialRelationComponentORM).where( + UserSocialRelationComponentORM.from_user_id == profile.user_id, + UserSocialRelationComponentORM.group_id == profile.group_id, + ) + ) # 保存所有关系组件 for relation in profile.relations: rel_type_str = relation.relation_type.value if hasattr( relation.relation_type, 'value') else str(relation.relation_type) - # ✅ 先删除旧关系记录 - await cursor.execute(''' - DELETE FROM user_social_relation_components - WHERE from_user_id = ? AND to_user_id = ? AND group_id = ? AND relation_type = ? 
- ''', (profile.user_id, "bot", profile.group_id, rel_type_str)) - - # 再插入新关系记录 - await cursor.execute(''' - INSERT INTO user_social_relation_components - (from_user_id, to_user_id, group_id, relation_type, value, - frequency, last_interaction, description, tags, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - profile.user_id, "bot", profile.group_id, rel_type_str, - relation.value, relation.frequency, relation.last_interaction, - relation.description, json.dumps(relation.tags, ensure_ascii=False), - relation.created_at - )) - - await conn.commit() - await cursor.close() + comp = UserSocialRelationComponentORM( + profile_id=profile_id, + from_user_id=profile.user_id, + to_user_id="bot", + group_id=profile.group_id, + relation_type=rel_type_str, + value=relation.value, + frequency=relation.frequency, + last_interaction=relation.last_interaction or int(time.time()), + description=relation.description, + tags=json.dumps(relation.tags, ensure_ascii=False) if relation.tags else None, + created_at=relation.created_at or int(time.time()), + ) + session.add(comp) + + await session.commit() except Exception as e: self._logger.error(f"保存社交档案到数据库失败: {e}", exc_info=True) @@ -713,50 +727,49 @@ async def _record_relation_history( new_value: float, reason: str ): - """记录关系变化历史""" + """记录关系变化历史(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - await cursor.execute(''' - INSERT INTO social_relation_history - (from_user_id, to_user_id, group_id, relation_type, - old_value, new_value, change_reason, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- ''', ( - from_user_id, to_user_id, group_id, relation_type, - old_value, new_value, reason, time.time() - )) - - await conn.commit() - await cursor.close() + from ...models.orm.social_relation import SocialRelationHistory + + async with self.db_manager.get_session() as session: + record = SocialRelationHistory( + from_user_id=from_user_id, + to_user_id=to_user_id, + group_id=group_id, + relation_type=relation_type, + old_value=old_value, + new_value=new_value, + change_reason=reason, + timestamp=int(time.time()), + ) + session.add(record) + await session.commit() except Exception as e: self._logger.error(f"记录关系历史失败: {e}") async def _load_active_profiles(self): - """加载活跃用户的社交档案""" + """加载活跃用户的社交档案(ORM 版本)""" try: - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - - # 获取最近7天有互动的用户 - await cursor.execute(''' - SELECT DISTINCT user_id, group_id - FROM user_social_profiles - WHERE last_updated > ? - LIMIT 100 - ''', (time.time() - 86400 * 7,)) - - rows = await cursor.fetchall() - await cursor.close() - - for user_id, group_id in rows: - profile = await self._load_profile_from_db(user_id, group_id) - if profile: - self.user_profiles[(user_id, group_id)] = profile - - self._logger.info(f"已加载 {len(self.user_profiles)} 个用户的社交档案") + from sqlalchemy import select + from ...models.orm.social_relation import UserSocialProfile as UserSocialProfileORM + + async with self.db_manager.get_session() as session: + cutoff = int(time.time()) - 86400 * 7 + stmt = ( + select(UserSocialProfileORM.user_id, UserSocialProfileORM.group_id) + .where(UserSocialProfileORM.last_updated > cutoff) + .limit(100) + ) + result = await session.execute(stmt) + rows = result.fetchall() + + for user_id, group_id in rows: + profile = await self._load_profile_from_db(user_id, group_id) + if profile: + self.user_profiles[(user_id, group_id)] = profile + + self._logger.info(f"已加载 {len(self.user_profiles)} 个用户的社交档案") except Exception as e: self._logger.error(f"加载活跃档案失败: 
{e}", exc_info=True) diff --git a/services/message_relationship_analyzer.py b/services/social/message_relationship_analyzer.py similarity index 99% rename from services/message_relationship_analyzer.py rename to services/social/message_relationship_analyzer.py index 4b7f923..5d05e53 100644 --- a/services/message_relationship_analyzer.py +++ b/services/social/message_relationship_analyzer.py @@ -11,10 +11,10 @@ from astrbot.api import logger from astrbot.api.star import Context -from ..config import PluginConfig -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..exceptions import MessageAnalysisError -from ..utils.json_utils import safe_parse_llm_json +from ...config import PluginConfig +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...exceptions import MessageAnalysisError +from ...utils.json_utils import safe_parse_llm_json @dataclass diff --git a/services/social_context_injector.py b/services/social/social_context_injector.py similarity index 63% rename from services/social_context_injector.py rename to services/social/social_context_injector.py index 483623b..f275609 100644 --- a/services/social_context_injector.py +++ b/services/social/social_context_injector.py @@ -1,7 +1,13 @@ """ 社交上下文注入器 - 将用户社交关系、好感度、Bot情绪信息注入到LLM prompt中 支持缓存机制以避免频繁查询数据库 + +整合了原 PsychologicalSocialContextInjector 的行为指导生成功能: +- 深度心理状态分析 +- LLM驱动的行为模式指导(非阻塞后台生成) +- 好感度/社交关系联动分析 """ +import asyncio import time from typing import Dict, Any, List, Optional, Tuple from cachetools import TTLCache @@ -26,7 +32,7 @@ def __init__( self.database_manager = database_manager self.affection_manager = affection_manager self.mood_manager = mood_manager - self.config = config # 添加config参数以读取配置 + self.config = config # 添加config参数以读取配置 # 新增:心理状态和社交关系管理器(整合自 PsychologicalSocialContextInjector) self.psych_manager = psychological_state_manager @@ -40,16 +46,20 @@ def __init__( self._prompt_protection = None self._enable_protection = True - # ⚡ 缓存机制 - 
使用cachetools的TTLCache + # 缓存机制 - 使用cachetools的TTLCache # maxsize=1000: 最多缓存1000个条目 # ttl=60: 缓存有效期60秒(1分钟) self._cache = TTLCache(maxsize=1000, ttl=60) + # 行为指导后台生成 (整合自 PsychologicalSocialContextInjector) + self._background_tasks: set = set() + self._llm_generation_lock: Dict[str, asyncio.Lock] = {} + def _get_prompt_protection(self): """延迟加载提示词保护服务""" if self._prompt_protection is None and self._enable_protection: try: - from .prompt_sanitizer import PromptProtectionService + from ..response import PromptProtectionService self._prompt_protection = PromptProtectionService(wrapper_template_index=0) logger.info("社交上下文注入器: 提示词保护服务已加载") except Exception as e: @@ -105,64 +115,64 @@ async def format_complete_context( psych_context = await self._build_psychological_context(group_id) if psych_context: context_parts.append(psych_context) - logger.info(f"✅ [社交上下文] 已准备深度心理状态 (群组: {group_id}, 长度: {len(psych_context)})") + logger.info(f" [社交上下文] 已准备深度心理状态 (群组: {group_id}, 长度: {len(psych_context)})") else: - logger.info(f"⚠️ [社交上下文] 群组 {group_id} 暂无活跃的心理状态") + logger.info(f" [社交上下文] 群组 {group_id} 暂无活跃的心理状态") # 2. Bot当前情绪信息(基础版,可与心理状态共存) if include_mood and self.mood_manager: mood_text = await self._format_mood_context(group_id) if mood_text: context_parts.append(mood_text) - logger.debug(f"✅ [社交上下文] 已准备情绪信息 (群组: {group_id})") + logger.debug(f" [社交上下文] 已准备情绪信息 (群组: {group_id})") # 3. 对该用户的好感度信息 if include_affection and self.affection_manager: affection_text = await self._format_affection_context(group_id, user_id) if affection_text: context_parts.append(affection_text) - logger.debug(f"✅ [社交上下文] 已准备好感度信息 (群组: {group_id}, 用户: {user_id[:8]}...)") + logger.debug(f" [社交上下文] 已准备好感度信息 (群组: {group_id}, 用户: {user_id[:8]}...)") # 4. 
用户社交关系信息(使用 SocialContextInjector 原有实现) if include_social_relations: social_text = await self.format_social_context(group_id, user_id) if social_text: context_parts.append(social_text) - logger.debug(f"✅ [社交上下文] 已准备社交关系 (群组: {group_id}, 用户: {user_id[:8]}...)") + logger.debug(f" [社交上下文] 已准备社交关系 (群组: {group_id}, 用户: {user_id[:8]}...)") # 5. 最近学到的表达模式(风格特征)- SocialContextInjector 独有 # 注意:表达模式内部已经应用了保护,这里获取的是保护后的文本 if include_expression_patterns: expression_text = await self._format_expression_patterns_context( group_id, - enable_protection=enable_protection # 传递保护参数 + enable_protection=enable_protection # 传递保护参数 ) if expression_text: context_parts.append(expression_text) - logger.info(f"✅ [社交上下文] 已准备表达模式 (群组: {group_id}, 长度: {len(expression_text)})") + logger.info(f" [社交上下文] 已准备表达模式 (群组: {group_id}, 长度: {len(expression_text)})") else: - logger.info(f"⚠️ [社交上下文] 群组 {group_id} 暂无表达模式学习记录") + logger.info(f" [社交上下文] 群组 {group_id} 暂无表达模式学习记录") # 6. 行为模式指导(整合自 PsychologicalSocialContextInjector) if include_behavior_guidance and (include_psychological or include_social_relations): behavior_guidance = await self._build_behavior_guidance(group_id, user_id) if behavior_guidance: context_parts.append(behavior_guidance) - logger.info(f"✅ [社交上下文] 已准备行为模式指导 (长度: {len(behavior_guidance)})") + logger.info(f" [社交上下文] 已准备行为模式指导 (长度: {len(behavior_guidance)})") else: - logger.debug(f"⚠️ [社交上下文] 未生成行为模式指导") + logger.debug(f" [社交上下文] 未生成行为模式指导") # 7. 
对话目标上下文(新增) if include_conversation_goal and self.goal_manager: - logger.info(f"🔍 [社交上下文] 尝试获取对话目标上下文 (user={user_id[:8]}..., group={group_id})") + logger.info(f" [社交上下文] 尝试获取对话目标上下文 (user={user_id[:8]}..., group={group_id})") goal_context = await self._format_conversation_goal_context(group_id, user_id) if goal_context: context_parts.append(goal_context) - logger.info(f"✅ [社交上下文] 已准备对话目标 (长度: {len(goal_context)})") + logger.info(f" [社交上下文] 已准备对话目标 (长度: {len(goal_context)})") else: - logger.info(f"ℹ️ [社交上下文] 未找到活跃对话目标 (user={user_id[:8]}..., group={group_id})") + logger.info(f" [社交上下文] 未找到活跃对话目标 (user={user_id[:8]}..., group={group_id})") elif include_conversation_goal and not self.goal_manager: - logger.warning(f"⚠️ [社交上下文] 对话目标功能已启用但goal_manager未初始化") + logger.warning(f" [社交上下文] 对话目标功能已启用但goal_manager未初始化") if not context_parts: return None @@ -190,10 +200,10 @@ async def format_complete_context( protection = self._get_prompt_protection() if protection: protected_other = protection.wrap_prompt(raw_other_context, register_for_filter=True) - logger.info(f"✅ [社交上下文] 已对情绪/好感度/社交关系应用提示词保护") + logger.info(f" [社交上下文] 已对情绪/好感度/社交关系应用提示词保护") else: protected_other = raw_other_context - logger.warning(f"⚠️ [社交上下文] 提示词保护服务不可用,使用原始文本") + logger.warning(f" [社交上下文] 提示词保护服务不可用,使用原始文本") else: protected_other = raw_other_context else: @@ -211,12 +221,12 @@ async def format_complete_context( full_context = "\n\n".join(final_parts) - # 🔍 输出最终上下文的组成部分用于调试 - logger.info(f"📋 [社交上下文] 最终上下文包含 {len(final_parts)} 个部分") + # 输出最终上下文的组成部分用于调试 + logger.info(f" [社交上下文] 最终上下文包含 {len(final_parts)} 个部分") if "对话目标" in full_context or "【当前对话目标状态】" in full_context: - logger.info(f"✅ [社交上下文] 对话目标上下文已成功包含在最终输出中") + logger.info(f" [社交上下文] 对话目标上下文已成功包含在最终输出中") else: - logger.info(f"ℹ️ [社交上下文] 对话目标上下文未包含在最终输出中") + logger.info(f" [社交上下文] 对话目标上下文未包含在最终输出中") return full_context @@ -230,7 +240,7 @@ async def _format_mood_context(self, group_id: str) -> Optional[str]: if not self.mood_manager: return None - # 
⚡ 尝试从缓存获取 + # 尝试从缓存获取 cache_key = f"mood_{group_id}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -295,7 +305,7 @@ def _normalize_mood(record: Any) -> Tuple[Optional[str], Optional[float], str]: connector = " - " if mood_label else "" mood_text += f"{connector}{mood_description}" - # ⚡ 缓存结果 + # 缓存结果 self._set_to_cache(cache_key, mood_text) return mood_text @@ -309,7 +319,7 @@ async def _format_affection_context(self, group_id: str, user_id: str) -> Option if not self.affection_manager: return None - # ⚡ 尝试从缓存获取 + # 尝试从缓存获取 cache_key = f"affection_{group_id}_{user_id}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -353,7 +363,7 @@ async def _format_affection_context(self, group_id: str, user_id: str) -> Option if affection_rank and affection_rank != '未知': affection_text += f"\n好感度排名: {affection_rank}" - # ⚡ 缓存结果 + # 缓存结果 self._set_to_cache(cache_key, affection_text) return affection_text @@ -380,7 +390,7 @@ async def _format_expression_patterns_context( 格式化的表达模式文本(已保护包装) """ try: - # ⚡ 尝试从缓存获取 + # 尝试从缓存获取 cache_key = f"expression_patterns_{group_id}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -391,7 +401,7 @@ async def _format_expression_patterns_context( if self.config and hasattr(self.config, 'expression_patterns_hours'): hours = getattr(self.config, 'expression_patterns_hours', 24) - # 1️⃣ 优先获取当前群组的表达模式 + # 优先获取当前群组的表达模式 patterns = await self.database_manager.get_recent_week_expression_patterns( group_id, limit=10, @@ -400,20 +410,20 @@ async def _format_expression_patterns_context( source_desc = f"群组 {group_id}" - # 2️⃣ 如果当前群组没有表达模式,且启用了全局回退,则获取全局表达模式 + # 如果当前群组没有表达模式,且启用了全局回退,则获取全局表达模式 if not patterns and enable_global_fallback: - logger.info(f"⚠️ [表达模式] 群组 {group_id} 无表达模式,尝试使用全局表达模式") + logger.info(f" [表达模式] 群组 {group_id} 无表达模式,尝试使用全局表达模式") patterns = await self.database_manager.get_recent_week_expression_patterns( - group_id=None, # None = 全局查询 + group_id=None, # None = 全局查询 limit=10, 
hours=hours ) source_desc = "全局所有群组" if not patterns: - # ⚡ 缓存空结果(避免频繁查询空数据) + # 缓存空结果(避免频繁查询空数据) self._set_to_cache(cache_key, None) - logger.info(f"⚠️ [表达模式] {source_desc} 均无表达模式学习记录") + logger.info(f" [表达模式] {source_desc} 均无表达模式学习记录") return None # 构建原始表达模式文本 @@ -421,7 +431,7 @@ async def _format_expression_patterns_context( raw_pattern_text = f"最近{time_desc}学到的表达风格特征(来源: {source_desc}):\n" raw_pattern_text += f"以下是最近{time_desc}学习到的表达模式,参考这些风格进行回复:\n" - for i, pattern in enumerate(patterns[:10], 1): # 最多显示10个 + for i, pattern in enumerate(patterns[:10], 1): # 最多显示10个 situation = pattern.get('situation', '未知场景') expression = pattern.get('expression', '未知表达') @@ -435,15 +445,15 @@ async def _format_expression_patterns_context( protection = self._get_prompt_protection() if protection: protected_text = protection.wrap_prompt(raw_pattern_text, register_for_filter=True) - logger.info(f"✅ [表达模式] 已应用提示词保护 (来源: {source_desc}, 模式数: {len(patterns)})") - # ⚡ 缓存保护后的结果 + logger.info(f" [表达模式] 已应用提示词保护 (来源: {source_desc}, 模式数: {len(patterns)})") + # 缓存保护后的结果 self._set_to_cache(cache_key, protected_text) return protected_text else: - logger.warning(f"⚠️ [表达模式] 提示词保护服务不可用,使用原始文本") + logger.warning(f" [表达模式] 提示词保护服务不可用,使用原始文本") - # ⚡ 缓存原始结果 - logger.info(f"✅ [表达模式] 已准备表达模式(未保护)(来源: {source_desc}, 模式数: {len(patterns)})") + # 缓存原始结果 + logger.info(f" [表达模式] 已准备表达模式(未保护)(来源: {source_desc}, 模式数: {len(patterns)})") self._set_to_cache(cache_key, raw_pattern_text) return raw_pattern_text @@ -463,7 +473,7 @@ async def format_social_context(self, group_id: str, user_id: str) -> Optional[s 格式化的社交关系文本,如果没有关系则返回None """ try: - # ⚡ 先从缓存获取 + # 先从缓存获取 cache_key = f"social_relations_{group_id}_{user_id}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -473,7 +483,7 @@ async def format_social_context(self, group_id: str, user_id: str) -> Optional[s relations_data = await self.database_manager.get_user_social_relations(group_id, user_id) if relations_data['total_relations'] == 0: 
- # ⚡ 缓存空结果 + # 缓存空结果 self._set_to_cache(cache_key, None) return None @@ -484,32 +494,32 @@ async def format_social_context(self, group_id: str, user_id: str) -> Optional[s # 格式化发出的关系 if relations_data['outgoing']: context_lines.append(f"该用户的互动对象(按频率排序):") - for i, relation in enumerate(relations_data['outgoing'][:5], 1): # 只显示前5个 + for i, relation in enumerate(relations_data['outgoing'][:5], 1): # 只显示前5个 target = self._extract_user_id(relation['to_user']) relation_type = self._format_relation_type(relation['relation_type']) strength = relation['strength'] frequency = relation['frequency'] context_lines.append( - f" {i}. 与 {target} - {relation_type},强度: {strength:.1f},互动{frequency}次" + f" {i}. 与 {target} - {relation_type},强度: {strength:.1f},互动{frequency}次" ) # 格式化接收的关系 if relations_data['incoming']: context_lines.append(f"与该用户互动的成员(按频率排序):") - for i, relation in enumerate(relations_data['incoming'][:5], 1): # 只显示前5个 + for i, relation in enumerate(relations_data['incoming'][:5], 1): # 只显示前5个 source = self._extract_user_id(relation['from_user']) relation_type = self._format_relation_type(relation['relation_type']) strength = relation['strength'] frequency = relation['frequency'] context_lines.append( - f" {i}. {source} - {relation_type},强度: {strength:.1f},互动{frequency}次" + f" {i}. 
{source} - {relation_type},强度: {strength:.1f},互动{frequency}次" ) context_text = "\n".join(context_lines) - # ⚡ 缓存结果 + # 缓存结果 self._set_to_cache(cache_key, context_text) return context_text @@ -578,90 +588,248 @@ async def inject_context_to_prompt( if injection_position == "start": return f"{context}\n\n{original_prompt}" - else: # end + else: # end return f"{original_prompt}\n\n{context}" except Exception as e: logger.error(f"注入上下文失败: {e}", exc_info=True) return original_prompt - # ========== 整合自 PsychologicalSocialContextInjector 的方法 ========== + # 行为指导生成 (整合自 PsychologicalSocialContextInjector) - async def _build_psychological_context(self, group_id: str) -> str: - """构建深度心理状态上下文(整合自 PsychologicalSocialContextInjector)""" + async def _build_behavior_guidance(self, group_id: str, user_id: str) -> str: + """ + 构建行为模式指导(基于心理状态和社交关系的联动分析) + + 使用LLM提炼模型生成对bot行为有强烈指导性但不死板的提示词。 + + 非阻塞设计: + - 优先返回缓存数据(TTLCache自动管理过期) + - 如果缓存不存在,返回空字符串,并在后台异步生成 + - 后台生成完成后更新缓存,下次调用时可用 + """ try: - if not self.psych_manager: - return "" + cache_key = f"behavior_guidance_{group_id}_{user_id}" - cache_key = f"psych_context_{group_id}" + # 1. 优先返回缓存 cached = self._get_from_cache(cache_key) if cached: + logger.debug(f"[behavior_guidance] cache hit (group: {group_id[:8]}...)") return cached - # 从心理状态管理器获取当前状态 - state_prompt = await self.psych_manager.get_state_prompt_injection(group_id) + # 2. 缓存未命中 - 检查是否已有后台生成任务在运行 + if cache_key not in self._llm_generation_lock: + self._llm_generation_lock[cache_key] = asyncio.Lock() - if state_prompt: - self._set_to_cache(cache_key, state_prompt) - return state_prompt + if self._llm_generation_lock[cache_key].locked(): + logger.debug(f"[behavior_guidance] generation in progress, skip (group: {group_id[:8]}...)") + return "" - return "" + # 3. 
获取锁后,启动后台生成任务(不等待) + async with self._llm_generation_lock[cache_key]: + # 双重检查 + cached = self._get_from_cache(cache_key) + if cached: + return cached + + task = asyncio.create_task(self._background_generate_guidance( + cache_key, group_id, user_id + )) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + logger.debug(f"[behavior_guidance] bg task started (group: {group_id[:8]}...)") + return "" except Exception as e: - logger.error(f"构建深度心理状态上下文失败: {e}", exc_info=True) + logger.error(f"[behavior_guidance] build failed: {e}", exc_info=True) return "" - async def _build_behavior_guidance(self, group_id: str, user_id: str) -> str: - """ - 构建行为模式指导(复用 PsychologicalSocialContextInjector 的完整实现) - - 基于心理状态和社交关系生成行为指导 - 通过内部调用 PsychologicalSocialContextInjector 来实现完整功能 - """ + async def _background_generate_guidance( + self, + cache_key: str, + group_id: str, + user_id: str + ): + """后台生成行为指导(异步任务,不阻塞主流程)""" try: - # 延迟导入,避免循环依赖 - if not hasattr(self, '_psych_social_injector'): - from .psychological_social_context_injector import PsychologicalSocialContextInjector - - # 创建 PsychologicalSocialContextInjector 实例(复用现有管理器) - self._psych_social_injector = PsychologicalSocialContextInjector( - database_manager=self.database_manager, - psychological_state_manager=self.psych_manager, - social_relation_manager=self.social_manager, - affection_manager=self.affection_manager, - diversity_manager=None, # 不需要多样性管理器 - llm_adapter=self.llm_adapter, - config=self.config + # 获取心理状态 + psych_state = None + if self.psych_manager and hasattr(self.psych_manager, 'get_or_create_state'): + psych_state = await self.psych_manager.get_or_create_state(group_id) + + # 获取社交关系 + social_profile = None + if self.social_manager and hasattr(self.social_manager, 'get_or_create_profile'): + social_profile = await self.social_manager.get_or_create_profile( + user_id, group_id ) - logger.debug("✅ [SocialContextInjector] 已创建内部 PsychologicalSocialContextInjector") - # 
调用 PsychologicalSocialContextInjector 的行为指导方法 - if hasattr(self._psych_social_injector, '_build_behavior_guidance'): - guidance = await self._psych_social_injector._build_behavior_guidance(group_id, user_id) - return guidance + # 获取好感度 + affection_level = 0 + if self.affection_manager: + try: + affection_data = await self.database_manager.get_user_affection(group_id, user_id) + if affection_data: + affection_level = affection_data.get('affection_level', 0) + except Exception: + pass + + # 使用LLM提炼模型生成行为指导 + guidance = await self._generate_guidance_by_llm( + psych_state, social_profile, affection_level, group_id, user_id + ) + + if guidance: + self._set_to_cache(cache_key, guidance) + logger.info(f"[behavior_guidance] bg generation done and cached (group: {group_id[:8]}...)") else: - logger.warning("⚠️ PsychologicalSocialContextInjector 没有 _build_behavior_guidance 方法") + logger.debug(f"[behavior_guidance] LLM returned empty (group: {group_id[:8]}...)") + + except Exception as e: + logger.error(f"[behavior_guidance] bg generation failed: {e}", exc_info=True) + + async def _generate_guidance_by_llm( + self, + psych_state, + social_profile, + affection_level: int, + group_id: str, + user_id: str + ) -> str: + """使用LLM提炼模型生成行为指导prompt""" + try: + if not self.llm_adapter: return "" + if not hasattr(self.llm_adapter, 'has_refine_provider') or not self.llm_adapter.has_refine_provider(): + return "" + + # 构建心理状态描述 + psych_desc = "" + if psych_state and hasattr(psych_state, 'get_active_components'): + active_components = psych_state.get_active_components() + if active_components: + psych_parts = [] + for component in active_components[:5]: + category = component.category + state_name = ( + component.state_type.value + if hasattr(component.state_type, 'value') + else str(component.state_type) + ) + intensity = component.value + psych_parts.append(f"- {category}: {state_name} (intensity: {intensity:.2f})") + psych_desc = "\n".join(psych_parts) + + # 构建社交关系描述 + social_desc = "" 
+ if social_profile and hasattr(social_profile, 'get_significant_relations'): + significant_relations = social_profile.get_significant_relations() + if significant_relations: + social_parts = [] + for rel in significant_relations[:3]: + rel_name = ( + rel.relation_type.value + if hasattr(rel.relation_type, 'value') + else str(rel.relation_type) + ) + social_parts.append(f"- {rel_name} (strength: {rel.value:.2f})") + social_desc = "\n".join(social_parts) + + # 构建好感度描述 + if affection_level >= 80: + affection_desc = f"very fond ({affection_level}/100)" + elif affection_level >= 60: + affection_desc = f"fairly fond ({affection_level}/100)" + elif affection_level >= 40: + affection_desc = f"some affection ({affection_level}/100)" + elif affection_level >= 20: + affection_desc = f"slight affection ({affection_level}/100)" + elif affection_level >= 0: + affection_desc = f"first meeting ({affection_level}/100)" + elif affection_level >= -20: + affection_desc = f"slight dislike ({affection_level}/100)" + elif affection_level >= -40: + affection_desc = f"fairly disliked ({affection_level}/100)" + else: + affection_desc = f"strongly disliked ({affection_level}/100)" + + # 构建LLM prompt + prompt = self._build_llm_guidance_prompt(psych_desc, social_desc, affection_desc) + + response = await self.llm_adapter.refine_chat_completion( + prompt=prompt, + temperature=0.7 + ) + + if response: + return response.strip() + + return "" except Exception as e: - logger.error(f"构建行为模式指导失败: {e}", exc_info=True) + logger.error(f"[behavior_guidance] LLM generation failed: {e}", exc_info=True) return "" - async def _format_conversation_goal_context(self, group_id: str, user_id: str) -> Optional[str]: - """ - 格式化对话目标上下文(带缓存) + @staticmethod + def _build_llm_guidance_prompt( + psych_desc: str, + social_desc: str, + affection_desc: str + ) -> str: + """构建发送给LLM提炼模型的行为指导生成prompt""" + return ( + "You are an AI conversation behavior analyst. 
" + "Based on the following Bot's current psychological state, social relations, " + "and affection level, generate a concise but effective behavior guidance prompt.\n\n" + f"[Bot Current Psychological State]\n" + f"{psych_desc if psych_desc else 'No notable psychological state'}\n\n" + f"[Social Relationship with User]\n" + f"{social_desc if social_desc else 'First contact, stranger relationship'}\n\n" + f"[Affection Level for User]\n" + f"{affection_desc}\n\n" + "---\n\n" + "Please generate behavior guidance with 2-4 bullet points:\n" + "1. Tone & style: describe the tone (e.g. relaxed, calm, direct)\n" + "2. Attitude: describe attitude towards the user (e.g. friendly, slightly cold)\n" + "3. Reply style: describe reply characteristics (e.g. brief, detailed, patient)\n" + "4. Special note: any other relevant suggestion (optional)\n\n" + "Output the guidance directly, no extra explanation or title." + ) + + # 心理状态上下文 - Args: - group_id: 群组ID - user_id: 用户ID + async def _build_psychological_context(self, group_id: str) -> str: + """构建深度心理状态上下文""" + try: + if not self.psych_manager: + return "" - Returns: - 格式化的对话目标文本,如果没有活跃目标则返回None - """ + cache_key = f"psych_context_{group_id}" + cached = self._get_from_cache(cache_key) + if cached: + return cached + + state_prompt = await self.psych_manager.get_state_prompt_injection(group_id) + + if state_prompt: + self._set_to_cache(cache_key, state_prompt) + return state_prompt + + return "" + + except Exception as e: + logger.error(f"[psych_context] build failed: {e}", exc_info=True) + return "" + + # 对话目标上下文 + + async def _format_conversation_goal_context(self, group_id: str, user_id: str) -> Optional[str]: + """格式化对话目标上下文(带缓存)""" try: if not self.goal_manager: return None - # ⚡ 尝试从缓存获取 + # 尝试从缓存获取 cache_key = f"conv_goal_{group_id}_{user_id}" cached = self._get_from_cache(cache_key) if cached is not None: @@ -670,9 +838,9 @@ async def _format_conversation_goal_context(self, group_id: str, user_id: str) - # 获取当前对话目标 goal = 
await self.goal_manager.get_conversation_goal(user_id, group_id) if not goal: - # ⚡ 缓存空结果 + # 缓存空结果 self._set_to_cache(cache_key, None) - logger.debug(f"⚠️ [对话目标上下文] 群组 {group_id} 用户 {user_id[:8]}... 暂无活跃对话目标") + logger.debug(f" [对话目标上下文] 群组 {group_id} 用户 {user_id[:8]}... 暂无活跃对话目标") return None # 提取关键信息 @@ -693,7 +861,7 @@ async def _format_conversation_goal_context(self, group_id: str, user_id: str) - user_engagement = metrics.get('user_engagement', 0.5) progress = metrics.get('goal_progress', 0.0) - logger.info(f"✅ [对话目标上下文] 检测到活跃目标 - 类型: {goal_type}, 名称: {goal_name}, 进度: {progress:.0%}, 阶段: {current_task}") + logger.info(f" [对话目标上下文] 检测到活跃目标 - 类型: {goal_type}, 名称: {goal_name}, 进度: {progress:.0%}, 阶段: {current_task}") # 格式化上下文文本 context_lines = [] @@ -713,29 +881,29 @@ async def _format_conversation_goal_context(self, group_id: str, user_id: str) - context_lines.append("") context_lines.append("【回复指令】") if task_index < len(planned_stages): - context_lines.append(f"✅ 请根据以上对话目标信息,结合用户的最新消息,围绕当前阶段性目标「{current_task}」组织你的回复内容。") - context_lines.append(f"✅ 你的回复应该自然地推进对话朝着「{goal_name}」的方向发展,同时保持对话的连贯性和真实性。") - context_lines.append(f"✅ 注意:不要机械地提及'目标'或'阶段'等元信息,而是通过对话内容本身体现当前阶段的意图。") + context_lines.append(f" 请根据以上对话目标信息,结合用户的最新消息,围绕当前阶段性目标「{current_task}」组织你的回复内容。") + context_lines.append(f" 你的回复应该自然地推进对话朝着「{goal_name}」的方向发展,同时保持对话的连贯性和真实性。") + context_lines.append(f" 注意:不要机械地提及'目标'或'阶段'等元信息,而是通过对话内容本身体现当前阶段的意图。") # 根据进度和参与度调整提示 if progress < 0.3: - context_lines.append(f"💡 对话刚开始,重点是{current_task},建立良好的互动基础。") + context_lines.append(f" 对话刚开始,重点是{current_task},建立良好的互动基础。") elif progress < 0.7: - context_lines.append(f"💡 对话进行中,继续围绕{current_task}深入交流,适时引导话题发展。") + context_lines.append(f" 对话进行中,继续围绕{current_task}深入交流,适时引导话题发展。") else: - context_lines.append(f"💡 对话接近完成,注意把握{current_task}的收尾,为下一阶段做准备。") + context_lines.append(f" 对话接近完成,注意把握{current_task}的收尾,为下一阶段做准备。") if user_engagement < 0.4: - context_lines.append(f"⚠️ 用户参与度较低({user_engagement:.0%}),尝试提出开放性问题或话题,激发用户兴趣。") + 
context_lines.append(f" 用户参与度较低({user_engagement:.0%}),尝试提出开放性问题或话题,激发用户兴趣。") elif user_engagement > 0.7: - context_lines.append(f"✨ 用户参与度很高({user_engagement:.0%}),保持当前互动风格,深化对话内容。") + context_lines.append(f" 用户参与度很高({user_engagement:.0%}),保持当前互动风格,深化对话内容。") else: - context_lines.append(f"✅ 对话目标「{goal_name}」的所有规划阶段已完成,请自然地结束本话题或引导新话题。") - context_lines.append(f"✅ 注意:避免生硬地结束对话,保持自然流畅的互动。") + context_lines.append(f" 对话目标「{goal_name}」的所有规划阶段已完成,请自然地结束本话题或引导新话题。") + context_lines.append(f" 注意:避免生硬地结束对话,保持自然流畅的互动。") context_text = "\n".join(context_lines) - # ⚡ 缓存结果 + # 缓存结果 self._set_to_cache(cache_key, context_text) return context_text diff --git a/services/social/social_graph_analyzer.py b/services/social/social_graph_analyzer.py new file mode 100644 index 0000000..7875f82 --- /dev/null +++ b/services/social/social_graph_analyzer.py @@ -0,0 +1,297 @@ +""" +Social graph analyzer. + +Adds graph-level analytics on top of the existing +``EnhancedSocialRelationManager``: + +* **Sentiment polarity**: LLM-based batch sentiment labelling for + interaction pairs (positive/negative/neutral). +* **Community detection**: Louvain algorithm via ``networkx`` to + identify tightly-knit subgroups within a chat group. +* **Influence ranking**: PageRank to surface the most influential + members of a group. + +All heavy computation is done via ``networkx`` (already a project +dependency). Sentiment labelling uses the framework LLM adapter +(remote API, no local model). + +Design notes: + - Builds an in-memory ``nx.DiGraph`` from the ORM + ``UserSocialRelationComponent`` rows. + - Community detection results are cached per group to avoid + recomputing on every request. + - Thread-safe for single-event-loop asyncio usage. 
+""" + +import time +from typing import Any, Dict, List, Optional, Set, Tuple + +import networkx as nx +from pydantic import BaseModel, Field, field_validator + +from astrbot.api import logger + +from ...core.framework_llm_adapter import FrameworkLLMAdapter + + +# Pydantic models for guardrails-ai structured output validation. + +class _SentimentItem(BaseModel): + """Schema for a single sentiment-labelled interaction pair.""" + + from_user: str = Field(alias="from", description="Source user identifier.") + to_user: str = Field(alias="to", description="Target user identifier.") + sentiment: float = Field( + ge=-1.0, le=1.0, + description="Sentiment polarity from -1.0 (hostile) to +1.0 (friendly).", + ) + label: str = Field( + description="Categorical label: positive, negative, or neutral.", + ) + + model_config = {"populate_by_name": True} + + @field_validator("label") + @classmethod + def normalise_label(cls, v: str) -> str: + v = v.strip().lower() + if v not in ("positive", "negative", "neutral"): + return "neutral" + return v + + +# LLM prompt for batch sentiment labelling of interaction pairs. +_SENTIMENT_BATCH_PROMPT = """Below are interaction summaries between users in a chat group. +For each pair, determine the sentiment polarity of the interaction. + +Interactions: +{interactions} + +Output a JSON array where each element has the format: +{{"from": "", "to": "", "sentiment": , "label": "positive|negative|neutral"}} + +Rules: +- sentiment ranges from -1.0 (hostile) to +1.0 (warm/friendly) +- "neutral" means roughly 0, "positive" means > 0.3, "negative" means < -0.3 +- Only output the JSON array, no extra text.""" + + +class SocialGraphAnalyzer: + """Graph-level social analytics for chat groups. 
+ + Usage:: + + analyzer = SocialGraphAnalyzer(llm_adapter, db_manager) + communities = await analyzer.detect_communities(group_id) + ranking = await analyzer.get_influence_ranking(group_id) + sentiments = await analyzer.analyze_interaction_sentiment( + interactions, group_id + ) + """ + + def __init__( + self, + llm_adapter: Optional[FrameworkLLMAdapter] = None, + db_manager=None, + ) -> None: + self._llm = llm_adapter + self._db = db_manager + + # Per-group community cache: group_id -> (timestamp, communities). + self._community_cache: Dict[str, Tuple[float, List[Set[str]]]] = {} + self._cache_ttl = 600 # 10 minutes + + # Public API + + async def build_social_graph(self, group_id: str) -> nx.DiGraph: + """Build a directed graph from stored social relation components. + + Nodes are user IDs; edges carry ``weight`` (relation value) and + ``relation_type`` attributes. + """ + graph = nx.DiGraph() + + if not self._db or not hasattr(self._db, "get_session"): + return graph + + try: + from ...models.orm.social_relation import UserSocialRelationComponent + from sqlalchemy import select + + async with self._db.get_session() as session: + stmt = select(UserSocialRelationComponent).where( + UserSocialRelationComponent.group_id == group_id + ) + result = await session.execute(stmt) + rows = result.scalars().all() + + for row in rows: + graph.add_edge( + row.from_user_id, + row.to_user_id, + weight=row.value, + relation_type=row.relation_type, + frequency=row.frequency, + ) + + except Exception as exc: + logger.debug(f"[SocialGraph] Failed to build graph: {exc}") + + return graph + + async def detect_communities( + self, group_id: str, resolution: float = 1.0 + ) -> List[Set[str]]: + """Detect communities within a group using the Louvain algorithm. + + Args: + group_id: Chat group to analyse. + resolution: Louvain resolution parameter (higher = smaller + communities). + + Returns: + List of sets, each set containing user IDs that form a + community. + """ + # Check cache. 
+ cached = self._community_cache.get(group_id) + if cached: + ts, communities = cached + if time.time() - ts < self._cache_ttl: + return communities + + graph = await self.build_social_graph(group_id) + if graph.number_of_nodes() < 2: + return [] + + # Louvain requires an undirected graph. + undirected = graph.to_undirected() + try: + communities = list( + nx.community.louvain_communities( + undirected, resolution=resolution, seed=42 + ) + ) + except Exception as exc: + logger.debug(f"[SocialGraph] Community detection failed: {exc}") + communities = [] + + self._community_cache[group_id] = (time.time(), communities) + return communities + + async def get_influence_ranking( + self, group_id: str, top_k: int = 10 + ) -> List[Dict[str, Any]]: + """Rank group members by influence using PageRank. + + Returns: + Sorted list of dicts with ``user_id``, ``pagerank``, + ``degree`` keys. Most influential first. + """ + graph = await self.build_social_graph(group_id) + if graph.number_of_nodes() == 0: + return [] + + try: + pr = nx.pagerank(graph, weight="weight") + except Exception: + pr = {n: 0.0 for n in graph.nodes} + + degree = dict(graph.degree()) + + ranking = [ + { + "user_id": uid, + "pagerank": round(score, 6), + "degree": degree.get(uid, 0), + } + for uid, score in pr.items() + ] + ranking.sort(key=lambda x: x["pagerank"], reverse=True) + return ranking[:top_k] + + async def analyze_interaction_sentiment( + self, + interactions: List[Dict[str, str]], + group_id: str, + ) -> List[Dict[str, Any]]: + """Batch-label sentiment polarity for interaction pairs via LLM. + + Args: + interactions: List of dicts with ``from``, ``to``, and + ``summary`` keys describing each interaction. + group_id: Chat group context. + + Returns: + List of dicts with ``from``, ``to``, ``sentiment`` (float), + and ``label`` keys. + """ + if not self._llm or not interactions: + return [] + + # Format interactions for the prompt. 
+ lines = [] + for i, item in enumerate(interactions[:20], 1): + lines.append( + f"{i}. {item.get('from', '?')} -> {item.get('to', '?')}: " + f"{item.get('summary', 'general interaction')}" + ) + + prompt = _SENTIMENT_BATCH_PROMPT.format( + interactions="\n".join(lines) + ) + + try: + response = await self._llm.generate_response( + prompt, model_type="filter" + ) + if not response: + return [] + + # Validate LLM output via guardrails-ai: parse the raw JSON + # array, then validate each element against the Pydantic schema. + from ...utils.guardrails_manager import get_guardrails_manager + gm = get_guardrails_manager() + parsed = gm.validate_and_clean_json(response, expected_type="array") + if not isinstance(parsed, list): + return [] + + results: List[Dict[str, Any]] = [] + for raw_item in parsed: + if not isinstance(raw_item, dict): + continue + try: + validated = _SentimentItem.model_validate(raw_item) + results.append({ + "from": validated.from_user, + "to": validated.to_user, + "sentiment": validated.sentiment, + "label": validated.label, + }) + except Exception: + # Skip malformed items rather than failing the batch. 
+ continue + return results + + except Exception as exc: + logger.debug(f"[SocialGraph] Sentiment analysis failed: {exc}") + return [] + + async def get_graph_statistics( + self, group_id: str + ) -> Dict[str, Any]: + """Return summary statistics for a group's social graph.""" + graph = await self.build_social_graph(group_id) + stats: Dict[str, Any] = { + "node_count": graph.number_of_nodes(), + "edge_count": graph.number_of_edges(), + "density": 0.0, + "communities": 0, + } + + if graph.number_of_nodes() > 1: + stats["density"] = round(nx.density(graph), 4) + communities = await self.detect_communities(group_id) + stats["communities"] = len(communities) + + return stats diff --git a/services/social_relation_analyzer.py b/services/social/social_relation_analyzer.py similarity index 93% rename from services/social_relation_analyzer.py rename to services/social/social_relation_analyzer.py index 6302f47..7c3bf7c 100644 --- a/services/social_relation_analyzer.py +++ b/services/social/social_relation_analyzer.py @@ -11,22 +11,22 @@ from astrbot.api import logger -from ..config import PluginConfig -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..exceptions import MessageAnalysisError -from ..utils.json_utils import safe_parse_llm_json +from ...config import PluginConfig +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...exceptions import MessageAnalysisError +from ...utils.json_utils import safe_parse_llm_json @dataclass class SocialRelation: """社交关系数据结构""" - from_user: str # 发起方用户ID - to_user: str # 接收方用户ID - relation_type: str # 关系类型: 'frequent_interaction', 'mention', 'reply', 'topic_discussion' - strength: float # 关系强度 0.0-1.0 - frequency: int # 互动频率(消息数量) - last_interaction: str # 最后互动时间 - relation_name: str # 关系名称(中文描述) + from_user: str # 发起方用户ID + to_user: str # 接收方用户ID + relation_type: str # 关系类型: 'frequent_interaction', 'mention', 'reply', 'topic_discussion' + strength: float # 关系强度 0.0-1.0 + frequency: int # 互动频率(消息数量) 
+ last_interaction: str # 最后互动时间 + relation_name: str # 关系名称(中文描述) class SocialRelationAnalyzer: @@ -152,7 +152,7 @@ async def analyze_group_social_relations( async def _get_group_messages(self, group_id: str, limit: int) -> List[Dict[str, Any]]: """获取群组消息记录(使用 ORM 方法,支持跨线程调用)""" try: - # ✅ 使用 ORM 方法获取消息(支持跨线程调用) + # 使用 ORM 方法获取消息(支持跨线程调用) raw_messages = await self.db_manager.get_recent_raw_messages(group_id, limit=limit) # 过滤掉 bot 消息并转换格式 @@ -204,7 +204,7 @@ async def _analyze_relations_with_llm( response = await self.llm_adapter.generate_response( prompt=prompt, temperature=0.7, - model_type="filter" # 使用filter模型进行分析 + model_type="filter" # 使用filter模型进行分析 ) if not response: @@ -280,8 +280,8 @@ def _build_analysis_prompt( "to_user": "用户ID", "relation_type": "关系类型(英文key)", "relation_name": "关系名称(中文)", - "strength": 0.85, // 关系强度 0.0-1.0 - "frequency": 12, // 互动次数 + "strength": 0.85, // 关系强度 0.0-1.0 + "frequency": 12, // 互动次数 "evidence": "识别依据:例如'频繁使用亲密称呼'、'讨论私密话题'、'快速回复'等" }} ] @@ -430,8 +430,8 @@ async def get_user_relations( all_relations = await self.db_manager.get_social_relations_by_group(group_id) # 筛选与该用户相关的关系 - outgoing = [] # 该用户发起的关系 - incoming = [] # 指向该用户的关系 + outgoing = [] # 该用户发起的关系 + incoming = [] # 指向该用户的关系 for rel in all_relations: if rel['from_user'] == user_id: @@ -441,8 +441,8 @@ async def get_user_relations( return { 'user_id': user_id, - 'outgoing_relations': outgoing, # 我关注的人 - 'incoming_relations': incoming, # 关注我的人 + 'outgoing_relations': outgoing, # 我关注的人 + 'incoming_relations': incoming, # 关注我的人 'total_relations': len(outgoing) + len(incoming) } diff --git a/services/sqlalchemy_database_manager.py b/services/sqlalchemy_database_manager.py deleted file mode 100644 index 1e39049..0000000 --- a/services/sqlalchemy_database_manager.py +++ /dev/null @@ -1,3518 +0,0 @@ -""" -增强型数据库管理器 - 使用 SQLAlchemy 和 Repository 模式 -与现有 DatabaseManager 接口兼容,可通过配置切换 -""" -import time -import json -import asyncio - -from typing import Dict, List, Optional, Any 
-from contextlib import asynccontextmanager - -from astrbot.api import logger - -from ..config import PluginConfig -from ..core.database.engine import DatabaseEngine -from ..repositories import ( - # 好感度系统 - AffectionRepository, - InteractionRepository, - ConversationHistoryRepository, - DiversityRepository, - # 记忆系统 - MemoryRepository, - MemoryEmbeddingRepository, - MemorySummaryRepository, - # 心理状态系统 - PsychologicalStateRepository, - PsychologicalComponentRepository, - PsychologicalHistoryRepository, - # 社交关系系统 - SocialProfileRepository, - SocialRelationComponentRepository, - SocialRelationHistoryRepository, -) -from ..repositories.reinforcement_repository import ( - ReinforcementLearningRepository, - PersonaFusionRepository, - StrategyOptimizationRepository, -) - - -class SQLAlchemyDatabaseManager: - """ - 基于 SQLAlchemy 的增强型数据库管理器 - - 特性: - 1. 使用 SQLAlchemy ORM 和 Repository 模式 - 2. 与现有 DatabaseManager 接口兼容 - 3. 支持 SQLite 和 MySQL - 4. 更好的类型安全和错误处理 - 5. 统一的数据访问层 - - 用法: - # 在配置中启用 - config.use_sqlalchemy = True - - # 创建管理器 - db_manager = SQLAlchemyDatabaseManager(config) - await db_manager.start() - - # 使用Repository - async with db_manager.get_session() as session: - affection_repo = AffectionRepository(session) - affection = await affection_repo.get_by_group_and_user(group_id, user_id) - """ - - def __init__(self, config: PluginConfig, context=None): - """ - 初始化数据库管理器 - - Args: - config: 插件配置 - context: 上下文(可选) - """ - self.config = config - self.context = context - self.engine: Optional[DatabaseEngine] = None - self._started = False - self._starting = False - self._start_lock = asyncio.Lock() - - # 创建传统 DatabaseManager 实例用于委托未实现的方法 - from .database_manager import DatabaseManager - self._legacy_db: Optional[DatabaseManager] = None - try: - # ✨ 传入 skip_table_init=True,让传统数据库管理器跳过表初始化 - # 因为 SQLAlchemy ORM 会通过 create_tables() 自动创建和迁移所有表 - self._legacy_db = DatabaseManager(config, context, skip_table_init=True) - logger.info("[SQLAlchemyDBManager] 
初始化完成(包含传统数据库管理器后备,跳过表初始化)") - except Exception as e: - logger.warning(f"[SQLAlchemyDBManager] 初始化传统数据库管理器失败: {e},部分功能可能不可用") - logger.info("[SQLAlchemyDBManager] 初始化完成") - - @property - def db_backend(self): - """ - 提供 db_backend 属性用于向后兼容 - - 返回传统数据库管理器的 db_backend - """ - if self._legacy_db: - return self._legacy_db.db_backend - return None - - async def start(self) -> bool: - """ - 启动数据库管理器(带并发保护) - - Returns: - bool: 是否启动成功 - """ - # 使用锁防止并发启动 - async with self._start_lock: - if self._started: - logger.debug("[SQLAlchemyDBManager] 已经启动,跳过") - return True - - if self._starting: - logger.warning("[SQLAlchemyDBManager] 正在启动中,等待完成...") - # 等待启动完成 - for _ in range(50): # 最多等待5秒 - await asyncio.sleep(0.1) - if self._started: - return True - logger.error("[SQLAlchemyDBManager] 启动超时") - return False - - try: - self._starting = True - logger.info("[SQLAlchemyDBManager] 开始启动数据库管理器...") - - # 启动传统数据库管理器(用于委托未实现的方法) - if self._legacy_db: - legacy_started = await self._legacy_db.start() - if not legacy_started: - logger.warning("[SQLAlchemyDBManager] 传统数据库管理器启动失败,部分功能可能不可用") - - # 获取数据库 URL - db_url = self._get_database_url() - - # 如果是 MySQL,先确保数据库存在 - if hasattr(self.config, 'db_type') and self.config.db_type.lower() == 'mysql': - await self._ensure_mysql_database_exists() - - # 创建数据库引擎 - self.engine = DatabaseEngine(db_url, echo=False) - - logger.info("[SQLAlchemyDBManager] 数据库引擎已创建") - # 创建表结构(如果不存在) - await self.engine.create_tables() - - # 健康检查 - if await self.engine.health_check(): - logger.info("✅ [SQLAlchemyDBManager] 数据库启动成功") - self._started = True - self._starting = False - return True - else: - self._started = False - self._starting = False - logger.error("❌ [SQLAlchemyDBManager] 数据库健康检查失败") - return False - - except Exception as e: - self._started = False - self._starting = False - logger.error(f"❌ [SQLAlchemyDBManager] 启动失败: {e}", exc_info=True) - return False - - async def stop(self) -> bool: - """ - 停止数据库管理器 - - Returns: - bool: 是否停止成功 - """ - if not 
self._started: - return True - - try: - # ⚠️ 不停止传统数据库管理器,因为 Web UI 路由可能随时需要它 - # 传统数据库会在插件卸载时由 AstrBot 框架自动清理 - # if self._legacy_db: - # await self._legacy_db.stop() - - logger.debug("[SQLAlchemyDBManager] 保持传统数据库运行(用于 Web UI 兼容)") - - # 停止 SQLAlchemy 引擎 - if self.engine: - await self.engine.close() - - self._started = False - logger.info("✅ [SQLAlchemyDBManager] 数据库已停止(传统数据库保持运行)") - return True - - except Exception as e: - logger.error(f"❌ [SQLAlchemyDBManager] 停止失败: {e}") - return False - - def _get_database_url(self) -> str: - """ - 获取数据库连接 URL - - Returns: - str: 数据库 URL - """ - import os - - # 检查数据库类型 - if hasattr(self.config, 'db_type') and self.config.db_type.lower() == 'mysql': - # MySQL 数据库 - host = getattr(self.config, 'mysql_host', 'localhost') - port = getattr(self.config, 'mysql_port', 3306) - user = getattr(self.config, 'mysql_user', 'root') - password = getattr(self.config, 'mysql_password', '') - database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') - - return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}" - else: - # SQLite 数据库(默认) - db_path = getattr(self.config, 'messages_db_path', None) - - if not db_path: - # 使用默认路径 - db_path = os.path.join(self.config.data_dir, 'messages.db') - - # 确保路径是绝对路径 - if not os.path.isabs(db_path): - db_path = os.path.abspath(db_path) - - return f"sqlite:///{db_path}" - - async def _ensure_mysql_database_exists(self): - """ - 确保 MySQL 数据库存在,如果不存在则创建 - """ - try: - import aiomysql - - host = getattr(self.config, 'mysql_host', 'localhost') - port = getattr(self.config, 'mysql_port', 3306) - user = getattr(self.config, 'mysql_user', 'root') - password = getattr(self.config, 'mysql_password', '') - database = getattr(self.config, 'mysql_database', 'astrbot_self_learning') - - # 先连接到 MySQL 服务器(不指定数据库) - conn = await aiomysql.connect( - host=host, - port=port, - user=user, - password=password, - charset='utf8mb4' - ) - - try: - async with conn.cursor() as cursor: - # 检查数据库是否存在 - await 
cursor.execute( - "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s", - (database,) - ) - result = await cursor.fetchone() - - if not result: - # 数据库不存在,创建它 - logger.info(f"[SQLAlchemyDBManager] 数据库 {database} 不存在,正在创建...") - await cursor.execute( - f"CREATE DATABASE `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" - ) - await conn.commit() - logger.info(f"✅ [SQLAlchemyDBManager] 数据库 {database} 创建成功") - else: - logger.debug(f"[SQLAlchemyDBManager] 数据库 {database} 已存在") - - finally: - conn.close() - - except Exception as e: - logger.error(f"❌ [SQLAlchemyDBManager] 确保 MySQL 数据库存在失败: {e}") - raise - - @asynccontextmanager - async def get_session(self): - """ - 获取数据库会话(上下文管理器) - - 改进: 更宽松的状态检查,检查 engine 是否可用而不是严格依赖 _started 标志 - 这样可以避免在并发场景下的状态不一致问题 - - 用法: - async with db_manager.get_session() as session: - repo = AffectionRepository(session) - result = await repo.get_by_id(1) - """ - # ✅ 改进:检查 engine 是否存在,而不是仅依赖 _started 标志 - # 这样可以处理启动过程中的并发访问 - if not self.engine: - # 如果正在启动,等待一小段时间 - if self._starting: - logger.debug("[SQLAlchemyDBManager] 数据库正在启动中,等待engine创建...") - for _ in range(30): # 最多等待3秒 - await asyncio.sleep(0.1) - if self.engine: - break - - if not self.engine: - raise RuntimeError("数据库管理器启动超时,engine未创建") - else: - raise RuntimeError("数据库管理器未启动,engine不存在") - - # DatabaseEngine.get_session() 自动适配当前 event loop, - # 跨线程调用时会创建独立引擎,无需手动处理 - if not self._started: - logger.debug("[SQLAlchemyDBManager] get_session: _started=False 但 engine 存在,继续执行") - - session = self.engine.get_session() - try: - async with session: - yield session - finally: - await session.close() - - # ============================================================ - # 兼容现有 DatabaseManager 接口的方法 - # 这些方法使用 Repository 实现,但保持与旧接口相同 - # ============================================================ - - async def get_user_affection( - self, - group_id: str, - user_id: str - ) -> Optional[Dict[str, Any]]: - """ - 获取用户好感度(兼容接口) - - Args: - group_id: 群组 ID - user_id: 
用户 ID - - Returns: - Optional[Dict]: 好感度数据 - """ - try: - async with self.get_session() as session: - repo = AffectionRepository(session) - affection = await repo.get_by_group_and_user(group_id, user_id) - - if affection: - return { - 'group_id': affection.group_id, - 'user_id': affection.user_id, - 'affection_level': affection.affection_level, - 'max_affection': affection.max_affection, - 'created_at': affection.created_at, - 'updated_at': affection.updated_at, - } - return None - - except Exception as e: - logger.error(f"[SQLAlchemyDBManager] 获取好感度失败: {e}") - return None - - async def update_user_affection( - self, - group_id: str, - user_id: str, - new_level: int, - change_reason: str = "", - bot_mood: str = "" - ) -> bool: - """ - 更新用户好感度(兼容接口) - - Args: - group_id: 群组 ID - user_id: 用户 ID - new_level: 新的好感度等级 - change_reason: 变化原因 - bot_mood: 机器人情绪状态 - - Returns: - bool: 是否更新成功 - """ - try: - async with self.get_session() as session: - repo = AffectionRepository(session) - - # 获取当前好感度以计算delta - current = await repo.get_by_group_and_user(group_id, user_id) - previous_level = current.affection_level if current else 0 - affection_delta = new_level - previous_level - - # 使用 Repository 的 update_level 方法 - affection = await repo.update_level( - group_id, - user_id, - affection_delta, - max_affection=100 # 默认最大值 - ) - - # TODO: 如果需要记录 change_reason 和 bot_mood,需要扩展 Repository - # 当前版本忽略这些参数,保持向后兼容 - - return affection is not None - - except Exception as e: - logger.error(f"[SQLAlchemyDBManager] 更新好感度失败: {e}") - return False - - async def get_all_user_affections( - self, - group_id: str - ) -> List[Dict[str, Any]]: - """ - 获取群组所有用户好感度(兼容接口) - - Args: - group_id: 群组 ID - - Returns: - List[Dict]: 好感度列表 - """ - try: - async with self.get_session() as session: - repo = AffectionRepository(session) - affections = await repo.find_many(group_id=group_id) - - return [ - { - 'group_id': a.group_id, - 'user_id': a.user_id, - 'affection_level': a.affection_level, - 
'max_affection': a.max_affection, - 'created_at': a.created_at, - 'updated_at': a.updated_at, - } - for a in affections - ] - - except Exception as e: - logger.error(f"[SQLAlchemyDBManager] 获取所有好感度失败: {e}") - return [] - - async def get_total_affection(self, group_id: str) -> int: - """ - 获取群组总好感度(兼容接口) - - Args: - group_id: 群组 ID - - Returns: - int: 总好感度 - """ - try: - async with self.get_session() as session: - repo = AffectionRepository(session) - return await repo.get_total_affection(group_id) - - except Exception as e: - logger.error(f"[SQLAlchemyDBManager] 获取总好感度失败: {e}") - return 0 - - async def save_bot_mood( - self, - group_id: str, - mood_type: str, - mood_intensity: float, - mood_description: str, - duration_hours: int = 24 - ) -> bool: - """ - 保存bot情绪状态(兼容接口) - - 注意: 这个方法暂时保持原有实现,因为情绪系统 - 还没有对应的ORM模型。后续可以添加BotMood模型。 - - Args: - group_id: 群组 ID - mood_type: 情绪类型 - mood_intensity: 情绪强度 - mood_description: 情绪描述 - duration_hours: 持续时间(小时) - - Returns: - bool: 是否保存成功 - """ - # TODO: 等待 BotMood ORM 模型创建后实现 - logger.debug(f"[SQLAlchemyDBManager] save_bot_mood 暂未实现,使用原有实现") - return True - - # ============================================================ - # Repository 访问方法(新增) - # 直接返回 Repository 实例,供高级用法使用 - # ============================================================ - - def get_affection_repo(self, session) -> AffectionRepository: - """获取好感度 Repository""" - return AffectionRepository(session) - - def get_interaction_repo(self, session) -> InteractionRepository: - """获取互动记录 Repository""" - return InteractionRepository(session) - - def get_conversation_repo(self, session) -> ConversationHistoryRepository: - """获取对话历史 Repository""" - return ConversationHistoryRepository(session) - - def get_diversity_repo(self, session) -> DiversityRepository: - """获取多样性 Repository""" - return DiversityRepository(session) - - def get_memory_repo(self, session) -> MemoryRepository: - """获取记忆 Repository""" - return MemoryRepository(session) - - def get_psychological_repo(self, 
session) -> PsychologicalStateRepository: - """获取心理状态 Repository""" - return PsychologicalStateRepository(session) - - def get_social_profile_repo(self, session) -> SocialProfileRepository: - """获取社交档案 Repository""" - return SocialProfileRepository(session) - - # ============================================================ - # 工具方法 - # ============================================================ - - def is_started(self) -> bool: - """检查是否已启动""" - return self._started - - async def health_check(self) -> bool: - """健康检查""" - if not self.engine: - return False - return await self.engine.health_check() - - def get_engine_info(self) -> dict: - """获取引擎信息""" - if not self.engine: - return {} - return self.engine.get_engine_info() - - # ============================================================ - # 兼容性方法 - 优先使用现代 Repository 实现,失败时降级 - # ============================================================ - - async def get_user_social_relations(self, group_id: str, user_id: str) -> Dict[str, Any]: - """ - 获取用户社交关系 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - """ - try: - # 尝试使用 Repository 实现 - async with self.get_session() as session: - from sqlalchemy import select, and_, or_ - from ..models.orm import UserSocialRelationComponent - - # 构建用户标识(支持两种格式) - user_keys = [user_id, f"{group_id}:{user_id}"] - - # 查询用户发起的关系 - stmt_outgoing = select(UserSocialRelationComponent).where( - and_( - UserSocialRelationComponent.group_id == group_id, - or_(*[UserSocialRelationComponent.from_user_id == key for key in user_keys]) # ✅ 修正字段名 - ) - ).order_by( - UserSocialRelationComponent.frequency.desc(), - UserSocialRelationComponent.value.desc() # ✅ 修正字段名 strength → value - ).limit(self.config.default_social_limit) - - result = await session.execute(stmt_outgoing) - outgoing_relations = result.scalars().all() - - # 查询指向用户的关系 - stmt_incoming = select(UserSocialRelationComponent).where( - and_( - UserSocialRelationComponent.group_id == group_id, - or_(*[UserSocialRelationComponent.to_user_id == 
key for key in user_keys]) # ✅ 修正字段名 - ) - ).order_by( - UserSocialRelationComponent.frequency.desc(), - UserSocialRelationComponent.value.desc() # ✅ 修正字段名 strength → value - ).limit(self.config.default_social_limit) - - result = await session.execute(stmt_incoming) - incoming_relations = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 使用 Repository 查询社交关系: {user_id} in {group_id}") - - return { - 'user_id': user_id, - 'group_id': group_id, - 'outgoing': [ - { - 'from_user': r.from_user_id, # ✅ 修正字段名 - 'to_user': r.to_user_id, # ✅ 修正字段名 - 'relation_type': r.relation_type, - 'strength': r.value, # ✅ 修正字段名 strength → value - 'frequency': r.frequency, - 'last_interaction': r.last_interaction # ✅ 修正字段名 - } - for r in outgoing_relations - ], - 'incoming': [ - { - 'from_user': r.from_user_id, # ✅ 修正字段名 - 'to_user': r.to_user_id, # ✅ 修正字段名 - 'relation_type': r.relation_type, - 'strength': r.value, # ✅ 修正字段名 strength → value - 'frequency': r.frequency, - 'last_interaction': r.last_interaction # ✅ 修正字段名 - } - for r in incoming_relations - ], - 'total_relations': len(outgoing_relations) + len(incoming_relations) - } - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 查询社交关系失败: {e}") - raise RuntimeError(f"无法获取用户社交关系: {e}") from e - - async def get_reviewed_persona_learning_updates( - self, - limit: int = 50, - offset: int = 0, - status_filter: str = None - ) -> List[Dict[str, Any]]: - """ - 获取已审查的人格学习更新 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - """ - try: - async with self.get_session() as session: - from ..repositories.learning_repository import PersonaLearningReviewRepository - - repo = PersonaLearningReviewRepository(session) - reviews = await repo.get_reviewed_updates(limit, offset, status_filter) - - logger.debug(f"[SQLAlchemy] 使用 Repository 查询已审查人格更新: {len(reviews)} 条") - - return [ - { - 'id': review.id, - 'group_id': review.group_id, - 'timestamp': review.timestamp, - 'update_type': review.update_type, - 'original_content': 
review.original_content, - 'new_content': review.new_content, - 'reason': review.reason, - 'confidence': review.confidence_score, - 'status': review.status, - 'reviewer_comment': review.reviewer_comment, - 'review_time': review.review_time - } - for review in reviews - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 查询已审查人格更新失败: {e}") - raise RuntimeError(f"无法获取已审查人格更新: {e}") from e - - async def get_trends_data(self) -> Dict[str, Any]: - """ - 获取趋势数据 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用(NullPool),基于现有数据计算趋势 - """ - try: - # 尝试使用 Repository 计算趋势 - async with self.get_session() as session: - from sqlalchemy import select, func, cast, Date - from ..models.orm import UserAffection, InteractionRecord - from datetime import datetime, timedelta - - # 计算趋势的天数范围(使用配置中的 trend_analysis_days) - days_ago = int((datetime.now() - timedelta(days=self.config.trend_analysis_days)).timestamp()) - - # 根据数据库类型选择日期转换函数 - is_mysql = self.config.db_type.lower() == 'mysql' - - if is_mysql: - # MySQL: 使用 FROM_UNIXTIME 和 DATE - date_func_affection = func.date(func.from_unixtime(UserAffection.updated_at)) - date_func_interaction = func.date(func.from_unixtime(InteractionRecord.timestamp)) - else: - # SQLite: 使用 datetime(timestamp, 'unixepoch') 和 date() - date_func_affection = func.date(UserAffection.updated_at, 'unixepoch') - date_func_interaction = func.date(InteractionRecord.timestamp, 'unixepoch') - - # 好感度趋势(按天统计) - affection_stmt = select( - date_func_affection.label('date'), - func.avg(UserAffection.affection_level).label('avg_affection'), - func.count(UserAffection.id).label('count') - ).where( - UserAffection.updated_at >= days_ago - ).group_by( - date_func_affection - ).order_by('date') - - affection_result = await session.execute(affection_stmt) - affection_trend = [ - { - 'date': str(row.date), - 'avg_affection': float(row.avg_affection) if row.avg_affection else 0.0, - 'count': row.count - } - for row in affection_result - ] - - # 互动趋势(按天统计) - 
interaction_stmt = select( - date_func_interaction.label('date'), - func.count(InteractionRecord.id).label('count') - ).where( - InteractionRecord.timestamp >= days_ago - ).group_by( - date_func_interaction - ).order_by('date') - - interaction_result = await session.execute(interaction_stmt) - interaction_trend = [ - { - 'date': str(row.date), - 'count': row.count - } - for row in interaction_result - ] - - logger.debug("[SQLAlchemy] 使用 Repository 计算趋势数据") - - return { - "affection_trend": affection_trend, - "interaction_trend": interaction_trend, - "learning_trend": [] # 学习趋势需要学习记录表 - } - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 计算趋势数据失败: {e}") - raise RuntimeError(f"无法获取趋势数据: {e}") from e - - async def get_style_learning_statistics(self) -> Dict[str, Any]: - """ - 获取风格学习统计 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用(NullPool) - """ - try: - async with self.get_session() as session: - from ..repositories.learning_repository import StyleLearningReviewRepository - - repo = StyleLearningReviewRepository(session) - statistics = await repo.get_statistics() - - logger.debug("[SQLAlchemy] 使用 Repository 计算风格学习统计") - - return statistics - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取风格学习统计失败: {e}") - raise RuntimeError(f"无法获取风格学习统计: {e}") from e - - async def get_pending_persona_learning_reviews(self, limit: int = None) -> List[Dict[str, Any]]: - """ - 获取待审查的人格学习更新 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用(NullPool) - - Args: - limit: 最大返回数量(None则使用配置中的default_review_limit) - """ - if limit is None: - limit = self.config.default_review_limit - - try: - async with self.get_session() as session: - from ..repositories.learning_repository import PersonaLearningReviewRepository - - repo = PersonaLearningReviewRepository(session) - reviews = await repo.get_pending_reviews(limit) - - logger.debug(f"[SQLAlchemy] 使用 Repository 查询待审查人格更新: {len(reviews)} 条") - - # 解析 metadata JSON 字符串 - import json - result = [] - for review in reviews: - # 解析 metadata 
字段(如果是字符串) - metadata = review.metadata_ - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) if metadata else {} - except json.JSONDecodeError: - metadata = {} - elif metadata is None: - metadata = {} - - result.append({ - 'id': review.id, - 'group_id': review.group_id, - 'timestamp': review.timestamp, - 'update_type': review.update_type, - 'original_content': review.original_content, - 'new_content': review.new_content, - 'proposed_content': review.proposed_content, - 'confidence_score': review.confidence_score, - 'reason': review.reason, - 'status': review.status, - 'reviewer_comment': review.reviewer_comment, - 'review_time': review.review_time, - 'metadata': metadata # 已解析为字典 - }) - - return result - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 查询待审查人格更新失败: {e}") - raise RuntimeError(f"无法获取待审查人格更新: {e}") from e - - async def get_pending_style_reviews(self, limit: int = None) -> List[Dict[str, Any]]: - """ - 获取待审查的风格学习更新 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用(NullPool) - - Args: - limit: 最大返回数量(None则使用配置中的default_review_limit) - """ - if limit is None: - limit = self.config.default_review_limit - - try: - async with self.get_session() as session: - from ..repositories.learning_repository import StyleLearningReviewRepository - - repo = StyleLearningReviewRepository(session) - reviews = await repo.get_pending_reviews(limit) - - logger.debug(f"[SQLAlchemy] 使用 Repository 查询待审查风格更新: {len(reviews)} 条") - - return [ - { - 'id': review.id, - 'type': review.type, # 使用 type 而不是 pattern_type - 'group_id': review.group_id, - 'timestamp': review.timestamp, - 'learned_patterns': review.learned_patterns, # JSON格式 - 'few_shots_content': review.few_shots_content, - 'status': review.status, - 'description': review.description, - 'created_at': review.created_at - } - for review in reviews - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 查询待审查风格更新失败: {e}") - raise RuntimeError(f"无法获取待审查风格更新: {e}") from e - - async 
def get_reviewed_style_learning_updates( - self, - limit: int = None, - offset: int = 0, - status_filter: str = None - ) -> List[Dict[str, Any]]: - """ - 获取已审查的风格学习更新 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用(NullPool) - - Args: - limit: 最大返回数量(None则使用配置中的default_review_limit) - offset: 偏移量 - status_filter: 状态过滤('approved', 'rejected', None表示全部) - - Returns: - List[Dict]: 已审查的风格学习记录列表 - """ - if limit is None: - limit = self.config.default_review_limit - - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm.learning import StyleLearningReview - - # 构建查询 - stmt = select(StyleLearningReview) - - # 状态过滤 - if status_filter: - stmt = stmt.where(StyleLearningReview.status == status_filter) - else: - # 只查询非 pending 状态的记录 - stmt = stmt.where(StyleLearningReview.status != 'pending') - - # 按时间倒序排列 - stmt = stmt.order_by(StyleLearningReview.review_time.desc()) - - # 分页 - stmt = stmt.offset(offset).limit(limit) - - result = await session.execute(stmt) - reviews = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询已审查风格更新: {len(reviews)} 条 (状态={status_filter})") - - return [ - { - 'id': review.id, - 'type': review.type, - 'group_id': review.group_id, - 'timestamp': review.timestamp, - 'learned_patterns': review.learned_patterns, - 'few_shots_content': review.few_shots_content, - 'status': review.status, - 'description': review.description, - 'reviewer_comment': review.reviewer_comment, - 'review_time': review.review_time, - 'created_at': review.created_at - } - for review in reviews - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询已审查风格更新失败: {e}") - raise RuntimeError(f"无法获取已审查风格更新: {e}") from e - - async def update_style_review_status( - self, - review_id: int, - status: str, - reviewer_comment: str = None - ) -> bool: - """ - 更新风格审查状态 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - """ - try: - async with self.get_session() as session: - from ..repositories.learning_repository import 
StyleLearningReviewRepository - - repo = StyleLearningReviewRepository(session) - success = await repo.update_review_status(review_id, status, reviewer_comment) - - if success: - logger.debug(f"[SQLAlchemy] 使用 Repository 更新风格审查状态: {review_id} -> {status}") - - return success - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 更新风格审查状态失败: {e}") - raise RuntimeError(f"无法更新风格审查状态: {e}") from e - - async def delete_persona_learning_review_by_id(self, review_id: int) -> bool: - """ - 删除人格学习审查记录 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - """ - try: - async with self.get_session() as session: - from ..repositories.learning_repository import PersonaLearningReviewRepository - - repo = PersonaLearningReviewRepository(session) - success = await repo.delete_by_id(review_id) - - if success: - logger.debug(f"[SQLAlchemy] 使用 Repository 删除人格学习审查: {review_id}") - - return success - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 删除人格学习审查失败: {e}") - raise RuntimeError(f"无法删除人格学习审查: {e}") from e - - async def add_persona_learning_review( - self, - group_id: str, - proposed_content: str, - learning_source: str = "expression_learning", - confidence_score: float = 0.5, - raw_analysis: str = "", - metadata: Dict[str, Any] = None, - original_content: str = "", - new_content: str = "" - ) -> int: - """ - 添加人格学习审查记录 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - group_id: 群组ID - proposed_content: 建议的增量人格内容 - learning_source: 学习来源 - confidence_score: 置信度分数 - raw_analysis: 原始分析结果 - metadata: 元数据 - original_content: 原人格完整文本 - new_content: 新人格完整文本 - - Returns: - int: 插入记录的ID - """ - try: - async with self.get_session() as session: - from ..models.orm.learning import PersonaLearningReview - import time - import json - - # 创建记录 - review = PersonaLearningReview( - group_id=group_id, - timestamp=time.time(), # ✅ 使用 Float 类型(与 ORM 模型定义一致) - update_type=learning_source, - original_content=original_content, - new_content=new_content, - 
proposed_content=proposed_content, - confidence_score=confidence_score, - reason=raw_analysis, - status='pending', - reviewer_comment=None, - review_time=None, - metadata_=json.dumps(metadata) if metadata else None, - # ❌ 移除 created_at - PersonaLearningReview 模型没有此字段 - ) - - session.add(review) - await session.commit() - await session.refresh(review) - - logger.debug(f"[SQLAlchemy] 已添加人格学习审查记录: ID={review.id}, group={group_id}") - return review.id - - except Exception as e: - logger.error(f"[SQLAlchemy] 添加人格学习审查记录失败: {e}", exc_info=True) - raise RuntimeError(f"无法添加人格学习审查记录: {e}") from e - - async def get_messages_statistics(self) -> Dict[str, Any]: - """ - 获取消息统计信息 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - 统计 raw_messages 和 filtered_messages 表的数据 - - Returns: - Dict[str, Any]: 统计信息 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import RawMessage, FilteredMessage - - # 统计原始消息数量 - total_stmt = select(func.count()).select_from(RawMessage) - total_result = await session.execute(total_stmt) - total_messages = total_result.scalar() or 0 - - # 统计筛选后消息数量 - filtered_stmt = select(func.count()).select_from(FilteredMessage) - filtered_result = await session.execute(filtered_stmt) - filtered_messages = filtered_result.scalar() or 0 - - # 计算筛选率 - filter_rate = (filtered_messages / total_messages * 100) if total_messages > 0 else 0.0 - - return { - "total_messages": total_messages, - "filtered_messages": filtered_messages, - "filter_rate": round(filter_rate, 2) - } - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取消息统计失败: {e}") - raise RuntimeError(f"无法获取消息统计: {e}") from e - - # ============================================================ - # 强化学习 / 人格融合 / 策略优化 / 性能记录(ORM 实现) - # ============================================================ - - async def get_learning_history_for_reinforcement(self, group_id: str, limit: int = 50) -> List[Dict[str, Any]]: - """获取用于强化学习的历史数据(ORM)""" - try: - async with 
self.get_session() as session: - from sqlalchemy import select, desc - from ..models.orm.performance import LearningPerformanceHistory - - stmt = ( - select(LearningPerformanceHistory) - .where(LearningPerformanceHistory.group_id == group_id) - .order_by(desc(LearningPerformanceHistory.timestamp)) - .limit(limit) - ) - result = await session.execute(stmt) - rows = result.scalars().all() - - return [ - { - 'timestamp': row.timestamp, - 'quality_score': row.quality_score or 0.0, - 'success': bool(row.success), - 'successful_pattern': row.successful_pattern or '', - 'failed_pattern': row.failed_pattern or '' - } - for row in rows - ] - except Exception as e: - logger.error(f"[SQLAlchemy] 获取强化学习历史数据失败: {e}") - return [] - - async def save_reinforcement_learning_result(self, group_id: str, result_data: Dict[str, Any]) -> bool: - """保存强化学习结果(ORM)""" - try: - async with self.get_session() as session: - repo = ReinforcementLearningRepository(session) - return await repo.save_reinforcement_result(group_id, result_data) - except Exception as e: - logger.error(f"[SQLAlchemy] 保存强化学习结果失败: {e}") - return False - - async def get_persona_fusion_history(self, group_id: str, limit: int = 10) -> List[Dict[str, Any]]: - """获取人格融合历史(ORM)""" - try: - async with self.get_session() as session: - repo = PersonaFusionRepository(session) - return await repo.get_fusion_history(group_id, limit) - except Exception as e: - logger.error(f"[SQLAlchemy] 获取人格融合历史失败: {e}") - return [] - - async def save_persona_fusion_result(self, group_id: str, fusion_data: Dict[str, Any]) -> bool: - """保存人格融合结果(ORM)""" - try: - async with self.get_session() as session: - repo = PersonaFusionRepository(session) - return await repo.save_fusion_result(group_id, fusion_data) - except Exception as e: - logger.error(f"[SQLAlchemy] 保存人格融合结果失败: {e}") - return False - - async def get_learning_performance_history(self, group_id: str, limit: int = 30) -> List[Dict[str, Any]]: - """获取学习性能历史数据(ORM)""" - try: - async with 
self.get_session() as session: - from sqlalchemy import select, desc - from ..models.orm.performance import LearningPerformanceHistory - - stmt = ( - select(LearningPerformanceHistory) - .where(LearningPerformanceHistory.group_id == group_id) - .order_by(desc(LearningPerformanceHistory.timestamp)) - .limit(limit) - ) - result = await session.execute(stmt) - rows = result.scalars().all() - - return [ - { - 'session_id': row.session_id, - 'timestamp': row.timestamp, - 'quality_score': row.quality_score or 0.0, - 'learning_time': row.learning_time or 0.0, - 'success': bool(row.success) - } - for row in rows - ] - except Exception as e: - logger.error(f"[SQLAlchemy] 获取学习性能历史失败: {e}") - return [] - - async def save_strategy_optimization_result(self, group_id: str, optimization_data: Dict[str, Any]) -> bool: - """保存策略优化结果(ORM)""" - try: - async with self.get_session() as session: - repo = StrategyOptimizationRepository(session) - return await repo.save_optimization_result(group_id, optimization_data) - except Exception as e: - logger.error(f"[SQLAlchemy] 保存策略优化结果失败: {e}") - return False - - async def get_messages_for_replay(self, group_id: str, days: int = 30, limit: int = 100) -> List[Dict[str, Any]]: - """获取用于记忆重放的消息(ORM)""" - try: - async with self.get_session() as session: - from sqlalchemy import select, desc, and_ - from ..models.orm import RawMessage - - cutoff_time = time.time() - (days * 24 * 3600) - - stmt = ( - select(RawMessage) - .where(and_( - RawMessage.group_id == group_id, - RawMessage.timestamp > cutoff_time, - RawMessage.processed == True - )) - .order_by(desc(RawMessage.timestamp)) - .limit(limit) - ) - result = await session.execute(stmt) - messages = result.scalars().all() - - return [ - { - 'message_id': msg.id, - 'message': msg.message, - 'sender_id': msg.sender_id, - 'group_id': msg.group_id, - 'timestamp': msg.timestamp - } - for msg in messages - ] - except Exception as e: - logger.error(f"[SQLAlchemy] 获取记忆重放消息失败: {e}") - return [] - - async 
def get_message_statistics(self, group_id: str = None) -> Dict[str, Any]: - """获取消息统计信息(ORM,兼容 webui.py 的调用)""" - if not group_id: - return await self.get_messages_statistics() - - try: - async with self.get_session() as session: - from sqlalchemy import select, func, and_ - from ..models.orm import RawMessage, FilteredMessage - - # 总消息数 - total_stmt = select(func.count()).select_from(RawMessage).where( - RawMessage.group_id == group_id - ) - total_result = await session.execute(total_stmt) - total_messages = total_result.scalar() or 0 - - # 未处理消息数 - unprocessed_stmt = select(func.count()).select_from(RawMessage).where(and_( - RawMessage.group_id == group_id, - RawMessage.processed == False - )) - unprocessed_result = await session.execute(unprocessed_stmt) - unprocessed_messages = unprocessed_result.scalar() or 0 - - # 筛选消息数 - filtered_stmt = select(func.count()).select_from(FilteredMessage).where( - FilteredMessage.group_id == group_id - ) - filtered_result = await session.execute(filtered_stmt) - filtered_messages = filtered_result.scalar() or 0 - - return { - 'total_messages': total_messages, - 'unprocessed_messages': unprocessed_messages, - 'filtered_messages': filtered_messages, - 'raw_messages': total_messages, - 'group_id': group_id - } - except Exception as e: - logger.error(f"[SQLAlchemy] 获取消息统计失败: {e}") - return { - 'total_messages': 0, - 'unprocessed_messages': 0, - 'filtered_messages': 0, - 'raw_messages': 0, - 'group_id': group_id - } - - async def get_all_expression_patterns(self) -> Dict[str, List[Dict[str, Any]]]: - """ - 获取所有群组的表达模式 - - 使用 SQLAlchemy Repository 实现,支持跨线程调用 - - Returns: - Dict[str, List[Dict[str, Any]]]: 群组ID -> 表达模式列表的映射 - """ - try: - # 直接使用 ORM,引擎已配置支持多线程 - # SQLite: check_same_thread=False - # MySQL: NullPool 每次都创建新连接 - async with self.get_session() as session: - from ..repositories.expression_repository import ExpressionPatternRepository - - repo = ExpressionPatternRepository(session) - patterns_by_group = await 
repo.get_all_patterns() - - logger.debug(f"[SQLAlchemy] 使用 Repository 获取所有表达模式: {len(patterns_by_group)} 个群组") - - # 转换为 WebUI 所需的字典格式 - result = {} - for group_id, patterns in patterns_by_group.items(): - result[group_id] = [ - { - 'situation': pattern.situation, - 'expression': pattern.expression, - 'weight': pattern.weight, - 'last_active_time': pattern.last_active_time, - 'created_time': pattern.create_time, - 'group_id': pattern.group_id, - 'style_type': 'general' # 兼容字段 - } - for pattern in patterns - ] - - return result - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 获取表达模式失败: {e}") - raise RuntimeError(f"无法获取表达模式: {e}") from e - - async def get_expression_patterns_statistics(self) -> Dict[str, Any]: - """ - 获取表达模式统计信息 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - - Returns: - Dict[str, Any]: 统计信息 - """ - try: - async with self.get_session() as session: - from ..repositories.expression_repository import ExpressionPatternRepository - - repo = ExpressionPatternRepository(session) - stats = await repo.get_statistics() - - logger.debug(f"[SQLAlchemy] 使用 Repository 获取表达模式统计: {stats}") - - return stats - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 获取表达模式统计失败: {e}") - raise RuntimeError(f"无法获取表达模式统计: {e}") from e - - async def get_group_expression_patterns(self, group_id: str, limit: int = None) -> List[Dict[str, Any]]: - """ - 获取指定群组的表达模式 - - 优先使用 SQLAlchemy Repository 实现,失败时降级到传统实现 - - Args: - group_id: 群组ID - limit: 最大返回数量(None则使用配置中的default_pattern_limit) - - Returns: - List[Dict[str, Any]]: 表达模式列表(按权重降序) - """ - if limit is None: - limit = self.config.default_pattern_limit - - try: - async with self.get_session() as session: - from ..repositories.expression_repository import ExpressionPatternRepository - - repo = ExpressionPatternRepository(session) - patterns = await repo.get_patterns_by_group(group_id, limit) - - logger.debug(f"[SQLAlchemy] 使用 Repository 获取群组 {group_id} 的表达模式: {len(patterns)} 条") - - return [ - { - 
'situation': pattern.situation, - 'expression': pattern.expression, - 'weight': pattern.weight, - 'last_active_time': pattern.last_active_time, - 'created_time': pattern.create_time, - 'group_id': pattern.group_id, - 'style_type': 'general' # 兼容字段 - } - for pattern in patterns - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] Repository 获取群组表达模式失败: {e}") - raise RuntimeError(f"无法获取群组表达模式: {e}") from e - - # ======================================== - # 社交关系系统方法(使用新ORM表) - # ======================================== - - async def get_social_relations_by_group(self, group_id: str) -> List[Dict[str, Any]]: - """ - 获取指定群组的社交关系(使用新ORM表) - - Args: - group_id: 群组ID - - Returns: - List[Dict[str, Any]]: 社交关系列表 - """ - try: - async with self.get_session() as session: - # 使用新的 user_social_relation_components 表 - from sqlalchemy import select - from ..models.orm.social_relation import UserSocialRelationComponent - - # 查询该群组的所有社交关系组件 - stmt = select(UserSocialRelationComponent).where( - UserSocialRelationComponent.group_id == group_id - ).order_by( - UserSocialRelationComponent.frequency.desc(), - UserSocialRelationComponent.value.desc() - ) - - result = await session.execute(stmt) - components = result.scalars().all() - - # 转换为旧格式的字典(保持向后兼容) - relations = [] - for comp in components: - relations.append({ - 'from_user': f"{comp.group_id}:{comp.from_user_id}", # 兼容旧格式 - 'to_user': f"{comp.group_id}:{comp.to_user_id}", - 'relation_type': comp.relation_type, - 'strength': float(comp.value), # value 对应 strength - 'frequency': int(comp.frequency), - 'last_interaction': comp.last_interaction - }) - - logger.info(f"[SQLAlchemy] 群组 {group_id} 加载了 {len(relations)} 条社交关系") - return relations - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取社交关系失败: {e}", exc_info=True) - return [] - - async def load_social_graph(self, group_id: str) -> List[Dict[str, Any]]: - """ - 加载社交图谱(使用新ORM表) - - Args: - group_id: 群组ID - - Returns: - List[Dict[str, Any]]: 社交关系列表 - """ - # 
load_social_graph 与 get_social_relations_by_group 功能相同 - return await self.get_social_relations_by_group(group_id) - - async def get_user_social_relations(self, group_id: str, user_id: str) -> Dict[str, Any]: - """ - 获取指定用户在群组中的社交关系(使用新ORM表) - - Args: - group_id: 群组ID - user_id: 用户ID - - Returns: - Dict: 包含用户社交关系的字典 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, or_ - from ..models.orm.social_relation import UserSocialRelationComponent - - # 查询该用户发起或接收的所有关系 - stmt = select(UserSocialRelationComponent).where( - UserSocialRelationComponent.group_id == group_id - ).where( - or_( - UserSocialRelationComponent.from_user_id == user_id, - UserSocialRelationComponent.to_user_id == user_id - ) - ).order_by( - UserSocialRelationComponent.frequency.desc(), - UserSocialRelationComponent.value.desc() - ).limit(10) - - result = await session.execute(stmt) - components = result.scalars().all() - - # 分类为发起关系和接收关系 - outgoing_relations = [] - incoming_relations = [] - - for comp in components: - relation_dict = { - 'from_user': f"{comp.group_id}:{comp.from_user_id}", - 'to_user': f"{comp.group_id}:{comp.to_user_id}", - 'relation_type': comp.relation_type, - 'strength': float(comp.value), - 'frequency': int(comp.frequency), - 'last_interaction': comp.last_interaction - } - - if comp.from_user_id == user_id: - outgoing_relations.append(relation_dict) - else: - incoming_relations.append(relation_dict) - - return { - 'outgoing': outgoing_relations, - 'incoming': incoming_relations, - 'total_relations': len(components) - } - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取用户社交关系失败: {e}", exc_info=True) - return {'outgoing': [], 'incoming': [], 'total_relations': 0} - - async def save_social_relation(self, group_id: str, relation_data: Dict[str, Any]): - """ - 保存社交关系(使用新ORM表) - - Args: - group_id: 群组ID - relation_data: 关系数据 - """ - try: - async with self.get_session() as session: - from ..models.orm.social_relation import 
UserSocialRelationComponent, UserSocialProfile - from sqlalchemy import select - import time - from datetime import datetime - - # 解析 from_user 和 to_user(兼容旧格式 "group_id:user_id") - from_user = relation_data.get('from_user', '') - to_user = relation_data.get('to_user', '') - - # 提取用户ID(如果包含 group_id:) - from_user_id = from_user.split(':')[-1] if ':' in from_user else from_user - to_user_id = to_user.split(':')[-1] if ':' in to_user else to_user - - # 处理 last_interaction 时间戳(支持 ISO 格式字符串和数值) - last_interaction_raw = relation_data.get('last_interaction', time.time()) - if isinstance(last_interaction_raw, str): - # ISO 格式字符串 -> Unix 时间戳 - try: - dt = datetime.fromisoformat(last_interaction_raw.replace('Z', '+00:00')) - last_interaction = int(dt.timestamp()) - except (ValueError, AttributeError): - last_interaction = int(time.time()) - elif isinstance(last_interaction_raw, (int, float)): - last_interaction = int(last_interaction_raw) - else: - last_interaction = int(time.time()) - - # 获取或创建 from_user 的社交档案 - stmt = select(UserSocialProfile).where( - UserSocialProfile.user_id == from_user_id, - UserSocialProfile.group_id == group_id - ) - result = await session.execute(stmt) - profile = result.scalars().first() - - if not profile: - # 创建新的用户社交档案 - profile = UserSocialProfile( - user_id=from_user_id, - group_id=group_id, - total_relations=0, - significant_relations=0, - created_at=int(time.time()), - last_updated=int(time.time()) - ) - session.add(profile) - await session.flush() # 确保获得 profile.id - - # 创建新的社交关系组件 - component = UserSocialRelationComponent( - profile_id=profile.id, - from_user_id=from_user_id, - to_user_id=to_user_id, - group_id=group_id, - relation_type=relation_data.get('relation_type', 'unknown'), - value=float(relation_data.get('strength', 0.0)), - frequency=int(relation_data.get('frequency', 0)), - last_interaction=last_interaction, - created_at=int(time.time()) - ) - - session.add(component) - - # 更新用户档案统计信息 - profile.total_relations += 1 - 
profile.last_updated = int(time.time()) - - await session.commit() - - logger.debug(f"[SQLAlchemy] 已保存社交关系: {from_user_id} -> {to_user_id}") - - except Exception as e: - logger.error(f"[SQLAlchemy] 保存社交关系失败: {e}", exc_info=True) - - # ======================================== - # 其他必要方法 - # ======================================== - - def get_db_connection(self): - """ - 获取数据库连接(上下文管理器) - - 用于向后兼容传统代码 - 返回一个模拟传统数据库连接的适配器 - - Returns: - AsyncContextManager: 异步上下文管理器 - """ - @asynccontextmanager - async def _connection_context(): - # 检查数据库管理器是否已启动 - if not self._started or not self.engine: - raise RuntimeError( - "[SQLAlchemy] 数据库引擎未初始化。请确保已调用 start() 方法。" - f"状态: _started={self._started}, engine={'已创建' if self.engine else '未创建'}" - ) - - # 创建一个兼容传统接口的连接适配器 - class SQLAlchemyConnectionAdapter: - """SQLAlchemy 连接适配器 - 模拟传统数据库连接接口""" - def __init__(self, session_factory): - self.session_factory = session_factory - self._session = None - - async def cursor(self): - """返回游标适配器""" - if not self._session: - self._session = self.session_factory() - return SQLAlchemyCursorAdapter(self._session) - - async def commit(self): - """提交事务""" - if self._session: - await self._session.commit() - - async def rollback(self): - """回滚事务""" - if self._session: - await self._session.rollback() - - async def close(self): - """关闭会话""" - if self._session: - await self._session.close() - - class SQLAlchemyCursorAdapter: - """SQLAlchemy 游标适配器""" - def __init__(self, session): - self.session = session - self._result = None - self.lastrowid = None - self.rowcount = 0 - - async def execute(self, sql, params=None): - """执行 SQL 语句""" - from sqlalchemy import text - from sqlalchemy import inspect - - # 检测并转换 SQLite 专用查询 - sql_converted = self._convert_sqlite_queries(sql) - - # 转换参数格式(? → :1, :2...) - if params: - # 将 ? 
占位符转换为命名参数 - param_dict = {} - if isinstance(params, (list, tuple)): - for i, param in enumerate(params): - param_name = f"param_{i}" - sql_converted = sql_converted.replace('?', f":{param_name}", 1) - param_dict[param_name] = param - self._result = await self.session.execute(text(sql_converted), param_dict) - else: - self._result = await self.session.execute(text(sql_converted), params) - else: - self._result = await self.session.execute(text(sql_converted)) - - self.rowcount = self._result.rowcount if hasattr(self._result, 'rowcount') else 0 - return self - - def _convert_sqlite_queries(self, sql: str) -> str: - """ - 转换 SQLite 专用查询为数据库无关查询 - - Args: - sql: 原始 SQL 查询 - - Returns: - str: 转换后的 SQL 查询 - """ - import re - - # 检测数据库类型 - dialect_name = self.session.bind.dialect.name if self.session.bind else 'sqlite' - - # 如果是 SQLite,不需要转换 - if dialect_name == 'sqlite': - return sql - - # MySQL: 转换 sqlite_master 查询 - if 'sqlite_master' in sql.lower(): - if dialect_name == 'mysql': - # 提取表名检查模式 - # 匹配: SELECT name FROM sqlite_master WHERE type='table' AND name='表名' - pattern = r"SELECT\s+name\s+FROM\s+sqlite_master\s+WHERE\s+type\s*=\s*['\"]table['\"]\s+AND\s+name\s*=\s*['\"](\w+)['\"]" - match = re.search(pattern, sql, re.IGNORECASE) - - if match: - table_name = match.group(1) - # MySQL: 查询 INFORMATION_SCHEMA - converted = f""" - SELECT TABLE_NAME as name - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = DATABASE() - AND TABLE_NAME = '{table_name}' - """ - logger.debug(f"[SQLAlchemy] 转换 SQLite 查询为 MySQL 查询: {table_name}") - return converted.strip() - - # 匹配: SELECT name FROM sqlite_master WHERE type='table' - pattern2 = r"SELECT\s+name\s+FROM\s+sqlite_master\s+WHERE\s+type\s*=\s*['\"]table['\"]" - if re.search(pattern2, sql, re.IGNORECASE): - # 列出所有表 - converted = """ - SELECT TABLE_NAME as name - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = DATABASE() - """ - logger.debug("[SQLAlchemy] 转换 SQLite 查询为 MySQL 查询: 列出所有表") - return converted.strip() - - 
return sql - - async def fetchone(self): - """获取一行""" - if self._result: - return self._result.fetchone() - return None - - async def fetchall(self): - """获取所有行""" - if self._result: - return self._result.fetchall() - return [] - - async def close(self): - """关闭游标""" - if self._result: - self._result.close() - - # 创建并返回连接适配器 - adapter = SQLAlchemyConnectionAdapter(self.engine.get_session) - try: - yield adapter - finally: - await adapter.close() - - return _connection_context() - - async def get_group_connection(self, group_id: str): - """ - 获取群组数据库连接(用于向后兼容) - - 注意:此方法已废弃,新代码应使用 get_session() - 为了向后兼容,返回 get_db_connection() 的结果 - - Args: - group_id: 群组ID - - Returns: - Connection: 数据库连接适配器 - """ - # 返回通用连接(不区分群组) - return self.get_db_connection() - - async def mark_messages_processed(self, message_ids: List[int]): - """ - 标记消息为已处理 - - 注意:UserConversationHistory ORM 模型暂无 processed 字段 - 此方法暂时不执行实际操作,仅记录日志 - - Args: - message_ids: 消息ID列表 - """ - if not message_ids: - return - - try: - # TODO: 为 UserConversationHistory 添加 processed 字段后实现 - logger.debug(f"[SQLAlchemy] mark_messages_processed 调用(暂不实现): {len(message_ids)} 条消息") - - except Exception as e: - logger.error(f"[SQLAlchemy] 标记消息处理状态失败: {e}", exc_info=True) - - async def save_learning_performance_record(self, group_id: str, performance_data: Dict[str, Any]) -> bool: - """ - 保存学习性能记录 - - Args: - group_id: 群组ID - performance_data: 性能记录数据 - - Returns: - bool: 是否保存成功 - """ - try: - async with self.get_session() as session: - from ..models.orm import LearningPerformanceHistory - import time - - # 创建学习性能记录 - def _ser(v): - if isinstance(v, (dict, list)): - return json.dumps(v, ensure_ascii=False) - return v - - record = LearningPerformanceHistory( - group_id=group_id, - session_id=performance_data.get('session_id', ''), - timestamp=int(performance_data.get('timestamp', time.time())), - quality_score=float(performance_data.get('quality_score', 0.0)), - learning_time=float(performance_data.get('learning_time', 0.0)), - 
success=bool(performance_data.get('success', False)), - successful_pattern=_ser(performance_data.get('successful_pattern', '')), - failed_pattern=_ser(performance_data.get('failed_pattern', '')), - created_at=int(time.time()) - ) - - session.add(record) - await session.commit() - - logger.debug(f"[SQLAlchemy] 已保存学习性能记录: {group_id}") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 保存学习性能记录失败: {e}", exc_info=True) - return False - - async def get_group_messages_statistics(self, group_id: str) -> Dict[str, Any]: - """ - 获取群组消息统计 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - 使用 RawMessage 表进行统计 - - Args: - group_id: 群组ID - - Returns: - Dict: 消息统计数据 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import RawMessage - - # 统计总消息数 - total_stmt = select(func.count()).select_from(RawMessage).where( - RawMessage.group_id == group_id - ) - total_result = await session.execute(total_stmt) - total_messages = total_result.scalar() or 0 - - # 统计已处理消息数 - processed_stmt = select(func.count()).select_from(RawMessage).where( - RawMessage.group_id == group_id, - RawMessage.processed == True - ) - processed_result = await session.execute(processed_stmt) - processed_messages = processed_result.scalar() or 0 - - # 计算未处理消息数 - unprocessed_messages = total_messages - processed_messages - - return { - 'total_messages': total_messages, - 'unprocessed_messages': unprocessed_messages, - 'processed_messages': processed_messages - } - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取群组消息统计失败: {e}", exc_info=True) - raise RuntimeError(f"无法获取群组 {group_id} 的消息统计: {e}") from e - - # ==================== 黑话 CRUD (ORM) ==================== - - async def get_jargon(self, chat_id: str, content: str) -> Optional[Dict[str, Any]]: - """查询指定黑话(ORM)""" - try: - async with self.get_session() as session: - from sqlalchemy import select, and_ - from ..models.orm.jargon import Jargon - - stmt = 
select(Jargon).where(and_( - Jargon.chat_id == chat_id, - Jargon.content == content - )) - result = await session.execute(stmt) - record = result.scalars().first() - - if not record: - return None - - return record.to_dict() - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询黑话失败: {e}", exc_info=True) - return None - - async def insert_jargon(self, jargon_data: Dict[str, Any]) -> Optional[int]: - """插入新的黑话记录(ORM)""" - try: - async with self.get_session() as session: - from ..models.orm.jargon import Jargon - - now_ts = int(time.time()) - - # 处理 created_at / updated_at - 统一转为 int 时间戳 - created_at = jargon_data.get('created_at') - updated_at = jargon_data.get('updated_at') - if created_at and not isinstance(created_at, (int, float)): - created_at = now_ts - elif created_at: - created_at = int(created_at) - else: - created_at = now_ts - - if updated_at and not isinstance(updated_at, (int, float)): - updated_at = now_ts - elif updated_at: - updated_at = int(updated_at) - else: - updated_at = now_ts - - record = Jargon( - content=jargon_data.get('content', ''), - raw_content=jargon_data.get('raw_content', '[]'), - meaning=jargon_data.get('meaning'), - is_jargon=jargon_data.get('is_jargon'), - count=jargon_data.get('count', 1), - last_inference_count=jargon_data.get('last_inference_count', 0), - is_complete=jargon_data.get('is_complete', False), - is_global=jargon_data.get('is_global', False), - chat_id=jargon_data.get('chat_id', ''), - created_at=created_at, - updated_at=updated_at - ) - - session.add(record) - await session.commit() - await session.refresh(record) - - logger.info(f"[SQLAlchemy] 插入黑话成功: id={record.id}, content={record.content}") - return record.id - - except Exception as e: - logger.error(f"[SQLAlchemy] 插入黑话失败: {e}", exc_info=True) - return None - - async def update_jargon(self, jargon_data: Dict[str, Any]) -> bool: - """更新现有黑话记录(ORM)""" - jargon_id = jargon_data.get('id') - if not jargon_id: - logger.error("[SQLAlchemy] 更新黑话失败: 缺少 id") - return 
False - - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm.jargon import Jargon - - stmt = select(Jargon).where(Jargon.id == jargon_id) - result = await session.execute(stmt) - record = result.scalars().first() - - if not record: - logger.warning(f"[SQLAlchemy] 更新黑话失败: 未找到 id={jargon_id}") - return False - - # 更新字段 - if 'content' in jargon_data: - record.content = jargon_data['content'] - if 'raw_content' in jargon_data: - record.raw_content = jargon_data['raw_content'] - if 'meaning' in jargon_data: - record.meaning = jargon_data['meaning'] - if 'is_jargon' in jargon_data: - record.is_jargon = jargon_data['is_jargon'] - if 'count' in jargon_data: - record.count = jargon_data['count'] - if 'last_inference_count' in jargon_data: - record.last_inference_count = jargon_data['last_inference_count'] - if 'is_complete' in jargon_data: - record.is_complete = jargon_data['is_complete'] - if 'is_global' in jargon_data: - record.is_global = jargon_data['is_global'] - - # updated_at 统一为 int 时间戳 - updated_at = jargon_data.get('updated_at') - if updated_at and not isinstance(updated_at, (int, float)): - record.updated_at = int(time.time()) - elif updated_at: - record.updated_at = int(updated_at) - else: - record.updated_at = int(time.time()) - - await session.commit() - logger.debug(f"[SQLAlchemy] 更新黑话成功: id={jargon_id}") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 更新黑话失败: {e}", exc_info=True) - return False - - async def get_jargon_statistics(self, group_id: str = None) -> Dict[str, Any]: - """获取黑话学习统计信息(ORM 版本) - - Args: - group_id: 群组ID(可选,None 表示全局统计) - - Returns: - 统计数据字典,包含 total_candidates, confirmed_jargon, - completed_inference, total_occurrences, average_count, active_groups - """ - default_stats = { - 'total_candidates': 0, - 'confirmed_jargon': 0, - 'completed_inference': 0, - 'total_occurrences': 0, - 'average_count': 0.0, - 'active_groups': 0, - } - try: - async with self.get_session() 
as session: - from sqlalchemy import select, func, case - from ..models.orm.jargon import Jargon - - columns = [ - func.count().label('total'), - func.count(case((Jargon.is_jargon == True, 1))).label('confirmed'), - func.count(case((Jargon.is_complete == True, 1))).label('completed'), - func.coalesce(func.sum(Jargon.count), 0).label('total_occurrences'), - func.coalesce(func.avg(Jargon.count), 0).label('avg_count'), - ] - - if not group_id: - columns.append( - func.count(func.distinct(Jargon.chat_id)).label('active_groups') - ) - - stmt = select(*columns) - if group_id: - stmt = stmt.where(Jargon.chat_id == group_id) - - result = await session.execute(stmt) - row = result.fetchone() - - if not row: - return default_stats - - stats = { - 'total_candidates': int(row.total) if row.total else 0, - 'confirmed_jargon': int(row.confirmed) if row.confirmed else 0, - 'completed_inference': int(row.completed) if row.completed else 0, - 'total_occurrences': int(row.total_occurrences) if row.total_occurrences else 0, - 'average_count': round(float(row.avg_count), 1) if row.avg_count else 0.0, - } - - if not group_id: - stats['active_groups'] = int(row.active_groups) if row.active_groups else 0 - else: - stats['active_groups'] = 1 if stats['total_candidates'] > 0 else 0 - - return stats - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取黑话统计失败: {e}", exc_info=True) - return default_stats - - async def get_recent_jargon_list( - self, - group_id: str = None, - chat_id: str = None, - limit: int = 10, - only_confirmed: bool = None - ) -> List[Dict[str, Any]]: - """ - 获取最近的黑话列表 - - Args: - group_id: 群组ID(可选,None 表示获取所有群组) - chat_id: 聊天ID(可选,兼容参数) - limit: 返回数量限制 - only_confirmed: 是否只返回已确认的黑话 - - Returns: - List[Dict]: 黑话列表,包含 content, meaning 等字段 - """ - # chat_id 是 group_id 的别名(向后兼容) - if group_id is None and chat_id is not None: - group_id = chat_id - - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import Jargon - - # 
构建查询 - stmt = select(Jargon) - - # 如果指定了 group_id,则只查询该群组 - if group_id is not None: - stmt = stmt.where(Jargon.chat_id == group_id) - - # 如果只返回已确认的黑话 - if only_confirmed: - stmt = stmt.where(Jargon.is_jargon == True) - - # 按更新时间倒序排列,限制数量 - stmt = stmt.order_by(Jargon.updated_at.desc()).limit(limit) - - result = await session.execute(stmt) - jargon_records = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询最近黑话列表: group_id={group_id}, 数量={len(jargon_records)}") - - jargon_list = [] - for record in jargon_records: - try: - jargon_list.append({ - 'id': record.id, - 'content': record.content, - 'raw_content': record.raw_content, - 'meaning': record.meaning, - 'is_jargon': record.is_jargon, - 'count': record.count or 0, - 'last_inference_count': record.last_inference_count or 0, - 'is_complete': record.is_complete, - 'chat_id': record.chat_id, - 'updated_at': record.updated_at, - 'is_global': record.is_global or False - }) - except Exception as row_error: - logger.warning(f"处理黑话记录行时出错,跳过: {row_error}") - continue - - return jargon_list - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取最近黑话列表失败: {e}", exc_info=True) - return [] - - async def search_jargon( - self, - keyword: str, - chat_id: Optional[str] = None, - limit: int = 10 - ) -> List[Dict[str, Any]]: - """搜索黑话(LIKE 匹配,ORM 版本) - - Args: - keyword: 搜索关键词 - chat_id: 群组ID(有值搜本群已确认黑话,无值搜全局已确认黑话) - limit: 返回数量限制 - - Returns: - 匹配的黑话列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, and_ - from ..models.orm.jargon import Jargon - - conditions = [ - Jargon.content.ilike(f'%{keyword}%'), - Jargon.is_jargon == True, - ] - if chat_id: - conditions.append(Jargon.chat_id == chat_id) - else: - conditions.append(Jargon.is_global == True) - - stmt = ( - select(Jargon) - .where(and_(*conditions)) - .order_by(Jargon.count.desc(), Jargon.updated_at.desc()) - .limit(limit) - ) - result = await session.execute(stmt) - records = result.scalars().all() - - return [ - { - 'id': 
r.id, - 'content': r.content, - 'raw_content': r.raw_content, - 'meaning': r.meaning, - 'is_jargon': r.is_jargon, - 'count': r.count or 0, - 'is_complete': r.is_complete, - 'is_global': r.is_global or False, - 'chat_id': r.chat_id, - 'updated_at': r.updated_at, - } - for r in records - ] - except Exception as e: - logger.error(f"[SQLAlchemy] 搜索黑话失败: {e}", exc_info=True) - return [] - - async def get_jargon_by_id(self, jargon_id: int) -> Optional[Dict[str, Any]]: - """根据ID获取黑话记录(ORM 版本) - - Args: - jargon_id: 黑话记录ID - - Returns: - 黑话字典或 None - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm.jargon import Jargon - - stmt = select(Jargon).where(Jargon.id == jargon_id) - result = await session.execute(stmt) - record = result.scalars().first() - - if not record: - return None - - return { - 'id': record.id, - 'content': record.content, - 'raw_content': record.raw_content, - 'meaning': record.meaning, - 'is_jargon': bool(record.is_jargon) if record.is_jargon is not None else None, - 'count': record.count or 0, - 'last_inference_count': record.last_inference_count or 0, - 'is_complete': bool(record.is_complete), - 'is_global': bool(record.is_global) if record.is_global is not None else False, - 'chat_id': record.chat_id, - 'updated_at': record.updated_at, - } - except Exception as e: - logger.error(f"[SQLAlchemy] 获取黑话记录失败 (id={jargon_id}): {e}", exc_info=True) - return None - - async def delete_jargon_by_id(self, jargon_id: int) -> bool: - """根据ID删除黑话记录(ORM 版本) - - Args: - jargon_id: 黑话记录ID - - Returns: - 是否删除成功 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm.jargon import Jargon - - stmt = select(Jargon).where(Jargon.id == jargon_id) - result = await session.execute(stmt) - record = result.scalars().first() - - if not record: - return False - - await session.delete(record) - await session.commit() - logger.debug(f"[SQLAlchemy] 删除黑话记录成功, ID: 
{jargon_id}") - return True - except Exception as e: - logger.error(f"[SQLAlchemy] 删除黑话失败 (id={jargon_id}): {e}", exc_info=True) - return False - - async def set_jargon_global(self, jargon_id: int, is_global: bool) -> bool: - """设置黑话的全局共享状态(ORM 版本) - - Args: - jargon_id: 黑话记录ID - is_global: 是否全局共享 - - Returns: - 是否更新成功 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm.jargon import Jargon - - stmt = select(Jargon).where(Jargon.id == jargon_id) - result = await session.execute(stmt) - record = result.scalars().first() - - if not record: - return False - - record.is_global = is_global - record.updated_at = int(time.time()) - await session.commit() - logger.info(f"[SQLAlchemy] 黑话全局状态已更新: ID={jargon_id}, is_global={is_global}") - return True - except Exception as e: - logger.error(f"[SQLAlchemy] 更新黑话全局状态失败 (id={jargon_id}): {e}", exc_info=True) - return False - - async def sync_global_jargon_to_group(self, target_chat_id: str) -> int: - """将全局黑话同步到指定群组(ORM 版本) - - 对全局黑话逐条检查目标群组是否已存在相同内容,不存在则插入。 - - Args: - target_chat_id: 目标群组ID - - Returns: - 成功同步的数量 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, and_ - from ..models.orm.jargon import Jargon - - # 获取非目标群组的全局黑话 - stmt = select(Jargon).where(and_( - Jargon.is_jargon == True, - Jargon.is_global == True, - Jargon.chat_id != target_chat_id - )) - result = await session.execute(stmt) - global_jargons = result.scalars().all() - - synced_count = 0 - now_ts = int(time.time()) - - for gj in global_jargons: - # 检查目标群组是否已存在 - check_stmt = select(Jargon).where(and_( - Jargon.chat_id == target_chat_id, - Jargon.content == gj.content - )) - check_result = await session.execute(check_stmt) - if check_result.scalars().first(): - continue - - new_jargon = Jargon( - content=gj.content, - raw_content='[]', - meaning=gj.meaning, - is_jargon=True, - count=1, - last_inference_count=0, - is_complete=False, - is_global=False, - 
chat_id=target_chat_id, - created_at=now_ts, - updated_at=now_ts, - ) - session.add(new_jargon) - synced_count += 1 - - await session.commit() - logger.info(f"[SQLAlchemy] 同步全局黑话到群组 {target_chat_id}: 同步 {synced_count} 条") - return synced_count - except Exception as e: - logger.error(f"[SQLAlchemy] 同步全局黑话失败: {e}", exc_info=True) - return 0 - - async def get_learning_patterns_data(self, group_id: str = None) -> Dict[str, Any]: - """ - 获取学习模式数据 - - Args: - group_id: 群组ID(可选) - - Returns: - Dict: 学习模式数据 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..repositories.learning_repository import PersonaLearningReviewRepository, StyleLearningReviewRepository - - persona_repo = PersonaLearningReviewRepository(session) - style_repo = StyleLearningReviewRepository(session) - - # 获取人格学习统计 - persona_stats = await persona_repo.get_statistics() - - # 获取风格学习统计 - style_stats = await style_repo.get_statistics() - - return { - 'persona_learning': persona_stats, - 'style_learning': style_stats, - 'group_id': group_id - } - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取学习模式数据失败: {e}", exc_info=True) - return {'persona_learning': {}, 'style_learning': {}, 'group_id': group_id} - - async def save_learning_session_record(self, group_id: str, session_data: Dict[str, Any]) -> bool: - """ - 保存学习会话记录 - - Args: - group_id: 群组ID - session_data: 会话数据 - - Returns: - bool: 是否保存成功 - """ - try: - # 此方法在新架构中可能不需要,暂时只记录日志 - logger.debug(f"[SQLAlchemy] 学习会话记录(暂不实现): group={group_id}, data={session_data}") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 保存学习会话记录失败: {e}", exc_info=True) - return False - - async def get_detailed_metrics(self, group_id: str = None) -> Dict[str, Any]: - """ - 获取详细指标数据 - - Args: - group_id: 群组ID(可选) - - Returns: - Dict: 详细指标数据 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import UserAffection, UserConversationHistory, 
ExpressionPattern - - metrics = {} - - # 好感度指标 - if group_id: - affection_stmt = select( - func.count(UserAffection.id).label('count'), - func.avg(UserAffection.affection_level).label('avg_level') - ).where(UserAffection.group_id == group_id) - else: - affection_stmt = select( - func.count(UserAffection.id).label('count'), - func.avg(UserAffection.affection_level).label('avg_level') - ) - - affection_result = await session.execute(affection_stmt) - affection_row = affection_result.first() - - metrics['affection'] = { - 'total_users': affection_row.count if affection_row else 0, - 'avg_level': float(affection_row.avg_level) if affection_row and affection_row.avg_level else 0.0 - } - - # 对话历史指标 - if group_id: - conv_stmt = select(func.count(UserConversationHistory.id)).where( - UserConversationHistory.group_id == group_id - ) - else: - conv_stmt = select(func.count(UserConversationHistory.id)) - - conv_result = await session.execute(conv_stmt) - conv_count = conv_result.scalar() or 0 - - metrics['conversations'] = { - 'total_count': conv_count - } - - # 表达模式指标 - if group_id: - expr_stmt = select(func.count(ExpressionPattern.id)).where( - ExpressionPattern.group_id == group_id - ) - else: - expr_stmt = select(func.count(ExpressionPattern.id)) - - expr_result = await session.execute(expr_stmt) - expr_count = expr_result.scalar() or 0 - - metrics['expressions'] = { - 'total_patterns': expr_count - } - - return metrics - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取详细指标失败: {e}", exc_info=True) - return {'affection': {}, 'conversations': {}, 'expressions': {}} - - async def get_style_progress_data(self, group_id: str = None) -> List[Dict[str, Any]]: - """ - 获取风格进度数据(从 learning_batches 表) - - Args: - group_id: 群组ID(可选) - - Returns: - List[Dict]: 风格进度数据列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, desc - from ..models.orm.learning import LearningBatch - - query = select(LearningBatch).where( - 
LearningBatch.quality_score.isnot(None), - LearningBatch.processed_messages > 0 - ).order_by(desc(LearningBatch.start_time)).limit(30) - - if group_id: - query = query.where(LearningBatch.group_id == group_id) - - result = await session.execute(query) - batches = result.scalars().all() - - progress_data = [] - for batch in batches: - progress_data.append({ - 'group_id': batch.group_id, - 'timestamp': batch.start_time or 0, - 'quality_score': batch.quality_score or 0, - 'success': batch.success if batch.success is not None else True, - 'processed_messages': batch.processed_messages or 0, - 'filtered_count': batch.filtered_count or 0, - 'batch_name': batch.batch_name or '', - 'message_count': batch.message_count or 0 - }) - - logger.debug(f"[SQLAlchemy] get_style_progress_data 获取到 {len(progress_data)} 行数据") - return progress_data - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取风格进度数据失败: {e}", exc_info=True) - return [] - - async def save_raw_message(self, message_data) -> int: - """ - 保存原始消息(纯 ORM 实现) - - Args: - message_data: 消息数据(对象或字典) - - Returns: - int: 消息ID - """ - try: - async with self.get_session() as session: - from ..models.orm import RawMessage - import time - - # 兼容对象和字典两种输入 - if hasattr(message_data, '__dict__'): - data = message_data.__dict__ - else: - data = message_data - - # 创建原始消息记录 - raw_msg = RawMessage( - sender_id=str(data.get('sender_id', '')), - sender_name=data.get('sender_name', ''), - message=data.get('message', ''), - group_id=data.get('group_id', ''), - timestamp=int(data.get('timestamp', time.time())), - platform=data.get('platform', ''), - message_id=data.get('message_id'), - reply_to=data.get('reply_to'), - created_at=int(time.time()), - processed=False - ) - - session.add(raw_msg) - await session.commit() - await session.refresh(raw_msg) - - logger.debug(f"[SQLAlchemy] 已保存原始消息: ID={raw_msg.id}, group={data.get('group_id')}") - return raw_msg.id - - except Exception as e: - logger.error(f"[SQLAlchemy] 保存���始消息失败: {e}", 
exc_info=True) - return 0 - - async def get_recent_raw_messages(self, group_id: str, limit: int = 200) -> List[Dict[str, Any]]: - """ - 获取最近的原始消息 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - group_id: 群组ID - limit: 最大返回数量 - - Returns: - List[Dict]: 原始消息列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import RawMessage - - # 构建查询:按时间倒序 - stmt = select(RawMessage).where( - RawMessage.group_id == group_id - ).order_by( - RawMessage.timestamp.desc() - ).limit(limit) - - result = await session.execute(stmt) - messages = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询最近原始消息: 群组={group_id}, 数量={len(messages)}") - - return [ - { - 'id': msg.id, - 'sender_id': msg.sender_id, - 'sender_name': msg.sender_name, - 'message': msg.message, - 'group_id': msg.group_id, - 'timestamp': msg.timestamp, - 'platform': msg.platform, - 'message_id': msg.message_id, - 'reply_to': msg.reply_to, - 'created_at': msg.created_at, - 'processed': msg.processed - } - for msg in messages - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询最近原始消息失败: {e}") - raise RuntimeError(f"无法获取群组 {group_id} 的最近原始消息: {e}") from e - - async def get_recent_filtered_messages(self, group_id: str, limit: int = 20) -> List[Dict[str, Any]]: - """ - 获取最近的筛选后消息 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - group_id: 群组ID - limit: 最大返回数量 - - Returns: - List[Dict]: 筛选后消息列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import FilteredMessage - - # 构建查询:按时间倒序 - stmt = select(FilteredMessage).where( - FilteredMessage.group_id == group_id - ).order_by( - FilteredMessage.timestamp.desc() - ).limit(limit) - - result = await session.execute(stmt) - messages = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询最近筛选消息: 群组={group_id}, 数量={len(messages)}") - - return [ - { - 'id': msg.id, - 'raw_message_id': msg.raw_message_id, - 'message': msg.message, - 
'sender_id': msg.sender_id, - 'group_id': msg.group_id, - 'timestamp': msg.timestamp, - 'confidence': msg.confidence, - 'quality_scores': msg.quality_scores, - 'filter_reason': msg.filter_reason, - 'created_at': msg.created_at, - 'processed': msg.processed - } - for msg in messages - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询最近筛选消息失败: {e}") - raise RuntimeError(f"无法获取群组 {group_id} 的最近筛选消息: {e}") from e - - async def get_unprocessed_messages(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """ - 获取未处理的原始消息(ORM 版本 - 支持跨线程调用) - - Args: - limit: 限制返回的消息数量 - - Returns: - 未处理的消息列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import RawMessage - - # 构建查询 - stmt = select(RawMessage).where( - RawMessage.processed == False - ).order_by( - RawMessage.timestamp.asc() - ) - - # 添加限制 - if limit: - stmt = stmt.limit(limit) - - # 执行查询 - result = await session.execute(stmt) - raw_messages = result.scalars().all() - - # 转换为字典格式 - messages = [] - for msg in raw_messages: - messages.append({ - 'id': msg.id, - 'sender_id': msg.sender_id, - 'sender_name': msg.sender_name, - 'message': msg.message, - 'group_id': msg.group_id, - 'platform': msg.platform, - 'timestamp': msg.timestamp - }) - - logger.debug(f"[SQLAlchemy] 获取到 {len(messages)} 条未处理消息") - return messages - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取未处理消息失败: {e}", exc_info=True) - raise RuntimeError(f"获取未处理消息失败: {str(e)}") from e - - async def mark_messages_processed(self, message_ids: List[int]) -> bool: - """ - 标记消息为已处理(ORM 版本 - 支持跨线程调用) - - Args: - message_ids: 消息ID列表 - - Returns: - 是否成功标记 - """ - if not message_ids: - return True - - try: - async with self.get_session() as session: - from sqlalchemy import update - from ..models.orm import RawMessage - - # 批量更新消息状态 - stmt = update(RawMessage).where( - RawMessage.id.in_(message_ids) - ).values( - processed=True - ) - - result = await session.execute(stmt) - await 
session.commit() - - updated_count = result.rowcount - logger.debug(f"[SQLAlchemy] 已标记 {updated_count} 条消息为已处理") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 标记消息处理状态失败: {e}", exc_info=True) - raise RuntimeError(f"标记消息处理状态失败: {str(e)}") from e - - async def get_filtered_messages_for_learning(self, limit: int = 20) -> List[Dict[str, Any]]: - """ - 获取用于学习的筛选后消息 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - limit: 最大返回数量 - - Returns: - List[Dict]: 筛选后消息列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import FilteredMessage - - # 构建查询:获取未处理的高质量消息 - stmt = select(FilteredMessage).where( - FilteredMessage.processed == False - ).order_by( - FilteredMessage.timestamp.desc() - ).limit(limit) - - result = await session.execute(stmt) - messages = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询用于学习的筛选消息: 数量={len(messages)}") - - return [ - { - 'id': msg.id, - 'raw_message_id': msg.raw_message_id, - 'message': msg.message, - 'sender_id': msg.sender_id, - 'group_id': msg.group_id, - 'timestamp': msg.timestamp, - 'confidence': msg.confidence, - 'quality_scores': msg.quality_scores, - 'filter_reason': msg.filter_reason, - 'created_at': msg.created_at, - 'processed': msg.processed - } - for msg in messages - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询用于学习的筛选消息失败: {e}") - raise RuntimeError(f"无法获取用于学习的筛选消息: {e}") from e - - async def get_recent_learning_batches(self, limit: int = 5) -> List[Dict[str, Any]]: - """ - 获取最近的学习批次 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - limit: 最大返回数量 - - Returns: - List[Dict]: 学习批次列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import LearningPerformanceHistory - - # 构建查询:按时间倒序 - stmt = select(LearningPerformanceHistory).order_by( - LearningPerformanceHistory.timestamp.desc() - ).limit(limit) - - result = await session.execute(stmt) - batches = 
result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询最近学习批次: 数量={len(batches)}") - - return [ - { - 'id': batch.id, - 'group_id': batch.group_id, - 'session_id': batch.session_id, - 'timestamp': batch.timestamp, - 'quality_score': batch.quality_score, - 'learning_time': batch.learning_time, - 'success': batch.success, - 'successful_pattern': batch.successful_pattern, - 'failed_pattern': batch.failed_pattern, - 'created_at': batch.created_at - } - for batch in batches - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询最近学习批次失败: {e}") - raise RuntimeError(f"无法获取最近学习批次: {e}") from e - - async def get_learning_sessions(self, group_id: str, limit: int = 5) -> List[Dict[str, Any]]: - """ - 获取学习会话 - - 使用 SQLAlchemy ORM 实现,支持跨线程调用(NullPool) - - Args: - group_id: 群组ID - limit: 最大返回数量 - - Returns: - List[Dict]: 学习会话列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import LearningPerformanceHistory - - # 构建查询:按时间倒序,过滤群组 - stmt = select(LearningPerformanceHistory).where( - LearningPerformanceHistory.group_id == group_id - ).order_by( - LearningPerformanceHistory.timestamp.desc() - ).limit(limit) - - result = await session.execute(stmt) - sessions = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询学习会话: 群组={group_id}, 数量={len(sessions)}") - - return [ - { - 'id': session.id, - 'group_id': session.group_id, - 'session_id': session.session_id, - 'timestamp': session.timestamp, - 'quality_score': session.quality_score, - 'learning_time': session.learning_time, - 'success': session.success, - 'successful_pattern': session.successful_pattern, - 'failed_pattern': session.failed_pattern, - 'created_at': session.created_at - } - for session in sessions - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 查询学习会话失败: {e}") - raise RuntimeError(f"无法获取群组 {group_id} 的学习会话: {e}") from e - - async def get_pending_persona_update_records(self) -> List[Dict[str, Any]]: - """ - 获取待审核的人格更新记录(ORM 版本) - - 
Returns: - 待审核记录列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import PersonaLearningReview - - stmt = select(PersonaLearningReview).where( - PersonaLearningReview.status == 'pending' - ).order_by( - PersonaLearningReview.timestamp.desc() - ) - - result = await session.execute(stmt) - records = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询待审核人格更新记录: 数量={len(records)}") - - return [ - { - 'id': record.id, - 'timestamp': record.timestamp, - 'group_id': record.group_id, - 'update_type': record.update_type, - 'original_content': record.original_content, - 'new_content': record.new_content, - 'reason': record.reason, - 'status': record.status, - 'reviewer_comment': record.reviewer_comment, - 'review_time': record.review_time - } - for record in records - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取待审核人格更新记录失败: {e}") - raise RuntimeError(f"无法获取待审核人格更新记录: {e}") from e - - async def save_persona_update_record(self, record: Dict[str, Any]) -> int: - """ - 保存人格更新记录(ORM 版本) - - Args: - record: 人格更新记录字典 - - Returns: - int: 新记录 ID - """ - try: - async with self.get_session() as session: - from ..models.orm import PersonaLearningReview - - orm_record = PersonaLearningReview( - timestamp=record.get('timestamp', time.time()), - group_id=record.get('group_id', 'default'), - update_type=record.get('update_type', 'prompt_update'), - original_content=record.get('original_content', ''), - new_content=record.get('new_content', ''), - proposed_content=record.get('new_content', ''), - confidence_score=record.get('confidence_score'), - reason=record.get('reason', ''), - status=record.get('status', 'pending'), - reviewer_comment=record.get('reviewer_comment'), - review_time=record.get('review_time') - ) - - session.add(orm_record) - await session.flush() - record_id = orm_record.id - await session.commit() - - logger.debug(f"[SQLAlchemy] 已保存人格更新记录: id={record_id}") - return record_id - - except 
Exception as e: - logger.error(f"[SQLAlchemy] 保存人格更新记录失败: {e}") - raise RuntimeError(f"无法保存人格更新记录: {e}") from e - - async def update_persona_update_record_status( - self, - record_id: int, - status: str, - reviewer_comment: Optional[str] = None - ) -> bool: - """ - 更新人格更新记录状态(ORM 版本) - - Args: - record_id: 记录 ID - status: 新状态 - reviewer_comment: 审核备注 - - Returns: - bool: 是否更新成功 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import PersonaLearningReview - - stmt = select(PersonaLearningReview).where( - PersonaLearningReview.id == record_id - ) - result = await session.execute(stmt) - record = result.scalar_one_or_none() - - if not record: - logger.warning(f"[SQLAlchemy] 未找到人格更新记录: id={record_id}") - return False - - record.status = status - record.reviewer_comment = reviewer_comment - record.review_time = time.time() - - await session.commit() - logger.debug(f"[SQLAlchemy] 已更新人格记录状态: id={record_id}, status={status}") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 更新人格更新记录状态失败: {e}") - raise RuntimeError(f"无法更新人格更新记录状态: {e}") from e - - async def delete_persona_update_record(self, record_id: int) -> bool: - """ - 删除人格更新记录(ORM 版本) - - Args: - record_id: 记录 ID - - Returns: - bool: 是否删除成功 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import PersonaLearningReview - - stmt = select(PersonaLearningReview).where( - PersonaLearningReview.id == record_id - ) - result = await session.execute(stmt) - record = result.scalar_one_or_none() - - if not record: - logger.warning(f"[SQLAlchemy] 删除失败,记录不存在: id={record_id}") - return False - - await session.delete(record) - await session.commit() - logger.debug(f"[SQLAlchemy] 已删除人格更新记录: id={record_id}") - return True - - except Exception as e: - logger.error(f"[SQLAlchemy] 删除人格更新记录失败: {e}") - raise RuntimeError(f"无法删除人格更新记录: {e}") from e - - async def get_persona_update_record_by_id(self, 
record_id: int) -> Optional[Dict[str, Any]]: - """ - 根据 ID 获取人格更新记录(ORM 版本) - - Args: - record_id: 记录 ID - - Returns: - Optional[Dict]: 记录字典,不存在时返回 None - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import PersonaLearningReview - - stmt = select(PersonaLearningReview).where( - PersonaLearningReview.id == record_id - ) - result = await session.execute(stmt) - record = result.scalar_one_or_none() - - if not record: - return None - - return { - 'id': record.id, - 'timestamp': record.timestamp, - 'group_id': record.group_id, - 'update_type': record.update_type, - 'original_content': record.original_content, - 'new_content': record.new_content, - 'reason': record.reason, - 'status': record.status, - 'reviewer_comment': record.reviewer_comment, - 'review_time': record.review_time - } - - except Exception as e: - logger.error(f"[SQLAlchemy] 根据ID获取人格更新记录失败: {e}") - raise RuntimeError(f"无法获取人格更新记录: {e}") from e - - async def get_reviewed_persona_update_records( - self, - limit: int = 50, - offset: int = 0, - status_filter: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - 获取已审核的人格更新记录(ORM 版本) - - Args: - limit: 返回数量限制 - offset: 偏移量 - status_filter: 筛选状态 ('approved' 或 'rejected'),None 表示返回所有已审核记录 - - Returns: - 已审核记录列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, or_ - from ..models.orm import PersonaLearningReview - - # 构建查询 - if status_filter: - # 筛选特定状态 - stmt = select(PersonaLearningReview).where( - PersonaLearningReview.status == status_filter - ) - else: - # 返回所有已审核记录(approved 或 rejected) - stmt = select(PersonaLearningReview).where( - or_( - PersonaLearningReview.status == 'approved', - PersonaLearningReview.status == 'rejected' - ) - ) - - stmt = stmt.order_by( - PersonaLearningReview.review_time.desc() - ).limit(limit).offset(offset) - - result = await session.execute(stmt) - records = result.scalars().all() - - logger.debug( - f"[SQLAlchemy] 
查询已审核人格更新记录: 状态={status_filter}, 数量={len(records)}" - ) - - return [ - { - 'id': record.id, - 'timestamp': record.timestamp, - 'group_id': record.group_id, - 'update_type': record.update_type, - 'original_content': record.original_content, - 'new_content': record.new_content, - 'reason': record.reason, - 'status': record.status, - 'reviewer_comment': record.reviewer_comment, - 'review_time': record.review_time - } - for record in records - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取已审核人格更新记录失败: {e}") - raise RuntimeError(f"无法获取已审核人格更新记录: {e}") from e - - async def get_global_jargon_list(self, limit: int = 50) -> List[Dict[str, Any]]: - """ - 获取全局共享的黑话列表(ORM 版本) - - Args: - limit: 返回数量限制 - - Returns: - 全局黑话列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select - from ..models.orm import Jargon - - stmt = select(Jargon).where( - Jargon.is_jargon == True, - Jargon.is_global == True - ).order_by( - Jargon.count.desc(), - Jargon.updated_at.desc() - ).limit(limit) - - result = await session.execute(stmt) - jargon_list = result.scalars().all() - - logger.debug(f"[SQLAlchemy] 查询全局黑话列表: 数量={len(jargon_list)}") - - return [ - { - 'id': jargon.id, - 'content': jargon.content, - 'raw_content': jargon.raw_content, - 'meaning': jargon.meaning, - 'is_jargon': jargon.is_jargon, - 'count': jargon.count, - 'last_inference_count': jargon.last_inference_count, - 'is_complete': jargon.is_complete, - 'is_global': jargon.is_global, - 'chat_id': jargon.chat_id, - 'updated_at': jargon.updated_at - } - for jargon in jargon_list - ] - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取全局黑话列表失败: {e}") - raise RuntimeError(f"无法获取全局黑话列表: {e}") from e - - async def get_groups_for_social_analysis(self) -> List[Dict[str, Any]]: - """ - 获取可用于社交关系分析的群组列表(ORM 版本) - - 返回包含消息数、成员数、社交关系数的群组列表 - 仅返回消息数 >= 10 的群组 - - Returns: - 群组统计列表 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from 
..models.orm import RawMessage, SocialRelation - - # 使用 LEFT JOIN 一次性获取群组的消息数、成员数和社交关系数 - # 注意:这里需要处理 MySQL 和 SQLite 的字段差异 - stmt = select( - RawMessage.group_id, - func.count(func.distinct(RawMessage.id)).label('message_count'), - func.count(func.distinct(RawMessage.sender_id)).label('member_count'), - func.count(func.distinct(SocialRelation.id)).label('relation_count') - ).select_from(RawMessage).outerjoin( - SocialRelation, - RawMessage.group_id == SocialRelation.group_id - ).where( - RawMessage.group_id.isnot(None), - RawMessage.group_id != '' - ).group_by( - RawMessage.group_id - ).having( - func.count(func.distinct(RawMessage.id)) >= 10 - ).order_by( - func.count(func.distinct(RawMessage.id)).desc() - ) - - result = await session.execute(stmt) - rows = result.all() - - logger.debug(f"[SQLAlchemy] 查询社交分析群组列表: 数量={len(rows)}") - - groups = [] - for row in rows: - try: - groups.append({ - 'group_id': row.group_id, - 'message_count': row.message_count, - 'member_count': row.member_count, - 'relation_count': row.relation_count - }) - except Exception as e: - logger.warning(f"处理群组数据行失败: {e}, 行数据: {row}") - continue - - return groups - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取社交分析群组列表失败: {e}") - raise RuntimeError(f"无法获取社交分析群组列表: {e}") from e - - async def get_jargon_groups(self) -> List[Dict[str, Any]]: - """ - 获取包含黑话的群组列表(ORM 版本) - - Returns: - 包含黑话的群组列表,包括群组ID、黑话数量、已完成黑话数、全局黑话数 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func, case - from ..models.orm import Jargon - - # 统计每个群组的黑话情况 - stmt = select( - Jargon.chat_id.label('group_id'), - func.count(Jargon.id).label('total_jargon'), - func.sum(case((Jargon.is_complete == True, 1), else_=0)).label('complete_jargon'), - func.sum(case((Jargon.is_global == True, 1), else_=0)).label('global_jargon') - ).where( - Jargon.is_jargon == True - ).group_by( - Jargon.chat_id - ).order_by( - func.count(Jargon.id).desc() - ) - - result = await session.execute(stmt) - 
rows = result.all() - - logger.debug(f"[SQLAlchemy] 查询黑话群组列表: 数量={len(rows)}") - - groups = [] - for row in rows: - try: - groups.append({ - 'group_id': row.group_id, - 'total_jargon': row.total_jargon or 0, - 'complete_jargon': row.complete_jargon or 0, - 'global_jargon': row.global_jargon or 0 - }) - except Exception as e: - logger.warning(f"处理黑话群组数据行失败: {e}, 行数据: {row}") - continue - - return groups - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取黑话群组列表失败: {e}") - raise RuntimeError(f"无法获取黑话群组列表: {e}") from e - - async def get_group_user_statistics(self, group_id: str) -> Dict[str, Dict[str, Any]]: - """ - 获取群组用户消息统计(ORM 版本) - - Args: - group_id: 群组ID - - Returns: - 字典,key 为 user_id,value 包含 sender_name 和 message_count - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import RawMessage - - # 统计每个用户在该群组的消息总数 - stmt = select( - RawMessage.sender_id, - func.max(RawMessage.sender_name).label('sender_name'), - func.count(RawMessage.id).label('message_count') - ).where( - RawMessage.group_id == group_id, - RawMessage.sender_id != 'bot' - ).group_by( - RawMessage.sender_id - ) - - result = await session.execute(stmt) - rows = result.all() - - logger.debug(f"[SQLAlchemy] 查询群组用户统计: group_id={group_id}, 用户数={len(rows)}") - - user_stats = {} - for row in rows: - try: - sender_id = row.sender_id - if sender_id: - user_stats[sender_id] = { - 'sender_name': row.sender_name or sender_id, - 'message_count': row.message_count or 0 - } - except Exception as row_error: - logger.warning(f"处理用户统计数据行失败: {row_error}, row: {row}") - continue - - return user_stats - - except Exception as e: - logger.error(f"[SQLAlchemy] 获取群组用户统计失败: {e}") - raise RuntimeError(f"无法获取群组 {group_id} 的用户统计: {e}") from e - - async def count_refined_messages(self) -> int: - """ - 统计提炼内容数量(ORM 版本 - 支持跨线程调用) - - Returns: - 提炼消息的数量 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from 
..models.orm import FilteredMessage - - # 统计 refined = True 的消息数量 - stmt = select(func.count(FilteredMessage.id)).where( - FilteredMessage.processed == True # refined 字段在某些版本中是 processed - ) - - result = await session.execute(stmt) - count = result.scalar() or 0 - - logger.debug(f"[SQLAlchemy] 统计提炼消息数量: {count}") - return count - - except Exception as e: - logger.error(f"[SQLAlchemy] 统计提炼消息数量失败: {e}") - return 0 - - async def count_style_learning_patterns(self) -> int: - """ - 统计风格学习模式数量(ORM 版本 - 支持跨线程调用) - - Returns: - 风格学习模式的数量 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import StyleLearningPattern - - # 统计所有风格学习模式 - stmt = select(func.count(StyleLearningPattern.id)) - - result = await session.execute(stmt) - count = result.scalar() or 0 - - logger.debug(f"[SQLAlchemy] 统计风格学习模式数量: {count}") - return count - - except Exception as e: - logger.error(f"[SQLAlchemy] 统计风格学习模式数量失败: {e}") - return 0 - - async def count_pending_persona_updates(self) -> int: - """ - 统计待审查的人格更新数量(ORM 版本 - 支持跨线程调用) - - Returns: - 待审查人格更新的数量 - """ - try: - async with self.get_session() as session: - from sqlalchemy import select, func - from ..models.orm import PersonaLearningReview - - # 统计 status = 'pending' 的记录 - stmt = select(func.count(PersonaLearningReview.id)).where( - PersonaLearningReview.status == 'pending' - ) - - result = await session.execute(stmt) - count = result.scalar() or 0 - - logger.debug(f"[SQLAlchemy] 统计待审查人格更新数量: {count}") - return count - - except Exception as e: - logger.error(f"[SQLAlchemy] 统计待审查人格更新数量失败: {e}") - return 0 - - def get_db_connection(self): - """ - 获取数据库连接(兼容性方法) - - ⚠️ 向后兼容策略: - - 如果有传统数据库管理器,返回其连接(支持 cursor() 方法) - - 否则返回 SQLAlchemy 会话工厂(不支持 cursor()) - - Returns: - 传统数据库连接或 AsyncSession 工厂 - """ - if self._legacy_db: - logger.debug("[SQLAlchemy] get_db_connection() 被调用,返回传统数据库连接(兼容 cursor())") - return self._legacy_db.get_db_connection() - else: - logger.debug("[SQLAlchemy] 
get_db_connection() 被调用,返回 SQLAlchemy 会话工厂") - return self.get_session() - - def __getattr__(self, name): - """ - 魔法方法:自动降级未实现的方法到传统数据库管理器 - - 当访问 SQLAlchemyDatabaseManager 中不存在的属性/方法时: - 1. 检查传统数据库管理器是否可用 - 2. 如果可用,返回传统管理器的对应方法 - 3. 如果不可用,抛出 AttributeError - """ - # 避免无限递归 - if name in ('_legacy_db', '_started', 'config', 'context', 'engine'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - # 如果传统数据库管理器可用,尝试从它获取属性 - if self._legacy_db and hasattr(self._legacy_db, name): - attr = getattr(self._legacy_db, name) - logger.debug(f"[SQLAlchemy] 方法 '{name}' 未实现 ORM 版本,降级到传统数据库管理器") - return attr - - # 如果传统数据库管理器也没有这个属性,抛出 AttributeError - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}', " - f"and legacy database manager is {'not available' if not self._legacy_db else 'missing this attribute'}" - ) diff --git a/services/state/__init__.py b/services/state/__init__.py new file mode 100644 index 0000000..99f5794 --- /dev/null +++ b/services/state/__init__.py @@ -0,0 +1,17 @@ +"""Runtime state management -- psychological, interaction, memory, affection.""" + +from .enhanced_psychological_state_manager import EnhancedPsychologicalStateManager +from .enhanced_interaction import EnhancedInteractionService +from .enhanced_memory_graph_manager import EnhancedMemoryGraphManager +from .time_decay_manager import TimeDecayManager +from .affection_manager import AffectionManager, MoodType, BotMood + +__all__ = [ + "EnhancedPsychologicalStateManager", + "EnhancedInteractionService", + "EnhancedMemoryGraphManager", + "TimeDecayManager", + "AffectionManager", + "MoodType", + "BotMood", +] diff --git a/services/affection_manager.py b/services/state/affection_manager.py similarity index 95% rename from services/affection_manager.py rename to services/state/affection_manager.py index dc3b642..b26faf4 100644 --- a/services/affection_manager.py +++ b/services/state/affection_manager.py @@ -11,13 +11,13 @@ from 
astrbot.api import logger -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.patterns import AsyncServiceBase +from ...core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage +from ...core.interfaces import IDataStorage -from ..core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 +from ...core.framework_llm_adapter import FrameworkLLMAdapter # 导入框架适配器 class MoodType(Enum): @@ -436,62 +436,70 @@ async def _initialize_random_moods_for_active_groups(self): async def _get_active_groups(self) -> List[str]: """获取活跃群组列表(从数据库中获取最近有消息的群组)""" try: - # 从数据库获取最近24小时内有消息的群组 - async with self.db_manager.get_db_connection() as conn: - cursor = await conn.cursor() - + from sqlalchemy import select, func, and_ + from ...models.orm.message import RawMessage + + async with self.db_manager.get_session() as session: + active_groups = [] + # 先尝试获取最近24小时内有消息的群组 cutoff_time = time.time() - 86400 # 24小时前 - await cursor.execute(''' - SELECT DISTINCT group_id, COUNT(*) as msg_count - FROM raw_messages - WHERE timestamp > ? AND group_id IS NOT NULL AND group_id != '' - GROUP BY group_id - HAVING msg_count >= 3 - ORDER BY msg_count DESC - LIMIT 20 - ''', (cutoff_time,)) - - active_groups = [] - for row in await cursor.fetchall(): + stmt = ( + select(RawMessage.group_id, func.count().label('msg_count')) + .where(and_( + RawMessage.timestamp > cutoff_time, + RawMessage.group_id.isnot(None), + RawMessage.group_id != '', + )) + .group_by(RawMessage.group_id) + .having(func.count() >= 3) + .order_by(func.count().desc()) + .limit(20) + ) + result = await session.execute(stmt) + for row in result.all(): if row[0]: # 确保group_id不为空 active_groups.append(row[0]) - + # 如果24小时内没有活跃群组,扩大时间范围到7天,降低消息数要求 if not active_groups: cutoff_time = time.time() - 604800 # 7天前 - await cursor.execute(''' - SELECT DISTINCT group_id, COUNT(*) as msg_count - FROM raw_messages - WHERE timestamp > ? 
AND group_id IS NOT NULL AND group_id != '' - GROUP BY group_id - HAVING msg_count >= 1 - ORDER BY msg_count DESC - LIMIT 10 - ''', (cutoff_time,)) - - for row in await cursor.fetchall(): + stmt = ( + select(RawMessage.group_id, func.count().label('msg_count')) + .where(and_( + RawMessage.timestamp > cutoff_time, + RawMessage.group_id.isnot(None), + RawMessage.group_id != '', + )) + .group_by(RawMessage.group_id) + .having(func.count() >= 1) + .order_by(func.count().desc()) + .limit(10) + ) + result = await session.execute(stmt) + for row in result.all(): if row[0]: # 确保group_id不为空 active_groups.append(row[0]) - + # 如果还是没有,获取所有有消息记录的群组 if not active_groups: - await cursor.execute(''' - SELECT DISTINCT group_id - FROM raw_messages - WHERE group_id IS NOT NULL AND group_id != '' - LIMIT 5 - ''') - - for row in await cursor.fetchall(): - if row[0]: # 确保group_id不为空 - active_groups.append(row[0]) - - await cursor.close() - + stmt = ( + select(RawMessage.group_id) + .where(and_( + RawMessage.group_id.isnot(None), + RawMessage.group_id != '', + )) + .distinct() + .limit(5) + ) + result = await session.execute(stmt) + for row in result.scalars().all(): + if row: # 确保group_id不为空 + active_groups.append(row) + self._logger.info(f"找到 {len(active_groups)} 个活跃群组用于情绪初始化") return active_groups - + except Exception as e: self._logger.error(f"获取活跃群组列表失败: {e}") # 返回空列表,让调用者决定如何处理 diff --git a/services/enhanced_interaction.py b/services/state/enhanced_interaction.py similarity index 99% rename from services/enhanced_interaction.py rename to services/state/enhanced_interaction.py index f57ed5c..616a629 100644 --- a/services/enhanced_interaction.py +++ b/services/state/enhanced_interaction.py @@ -15,13 +15,13 @@ from astrbot.api import logger -from ..config import PluginConfig +from ...config import PluginConfig -from ..core.patterns import AsyncServiceBase +from ...core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage +from ...core.interfaces import 
IDataStorage -from ..core.framework_llm_adapter import FrameworkLLMAdapter +from ...core.framework_llm_adapter import FrameworkLLMAdapter @dataclass diff --git a/services/enhanced_memory_graph_manager.py b/services/state/enhanced_memory_graph_manager.py similarity index 66% rename from services/enhanced_memory_graph_manager.py rename to services/state/enhanced_memory_graph_manager.py index 156416e..f9a9881 100644 --- a/services/enhanced_memory_graph_manager.py +++ b/services/state/enhanced_memory_graph_manager.py @@ -1,36 +1,212 @@ """ -增强型记忆图管理器 -使用 CacheManager、Repository 和 TaskScheduler,与现有接口兼容 +记忆图管理器 (增强版) +使用 CacheManager、Repository 和 TaskScheduler,与现有接口兼容。 +基于 NetworkX 图结构实现概念关联和智能记忆融合。 """ import time import json +import math from typing import Dict, List, Optional, Tuple, Any from datetime import datetime +from dataclasses import dataclass, asdict +from collections import Counter import networkx as nx from astrbot.api import logger -from ..core.interfaces import MessageData -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..config import PluginConfig -from ..utils.cache_manager import get_cache_manager, async_cached -from ..utils.task_scheduler import get_task_scheduler +from ...core.interfaces import MessageData +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...config import PluginConfig +from ...utils.cache_manager import get_cache_manager, async_cached +from ...utils.task_scheduler import get_task_scheduler # 导入 Repository -from ..repositories import ( +from ...repositories import ( MemoryRepository, MemoryEmbeddingRepository, MemorySummaryRepository ) -# 导入原有的数据类和图类 -from .memory_graph_manager import ( - MemoryNode, - MemoryEdge, - MemoryGraph, - MemoryGraphManager as OriginalMemoryGraphManager -) + +# 数据类 + +@dataclass +class MemoryNode: + """记忆节点""" + concept: str + memory_items: str + weight: float + created_time: float + last_modified: float + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + 
@classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'MemoryNode': + return cls(**data) + + +@dataclass +class MemoryEdge: + """记忆边""" + concept1: str + concept2: str + strength: float + created_time: float + last_modified: float + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'MemoryEdge': + return cls(**data) + + +class MemoryGraph: + """ + 记忆图 - 使用 NetworkX 实现概念关联和记忆管理 + """ + + def __init__(self): + self.G = nx.Graph() + + def connect_concepts(self, concept1: str, concept2: str): + """连接两个概念""" + if concept1 == concept2: + return + + current_time = time.time() + + if self.G.has_edge(concept1, concept2): + self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1 + self.G[concept1][concept2]["last_modified"] = current_time + else: + self.G.add_edge( + concept1, concept2, + strength=1, + created_time=current_time, + last_modified=current_time, + ) + + async def add_memory_node(self, concept: str, memory: str, llm_adapter: Optional[FrameworkLLMAdapter] = None): + """添加记忆节点,支持 LLM 智能记忆融合""" + current_time = time.time() + + if concept in self.G: + if "memory_items" in self.G.nodes[concept]: + existing_memory = self.G.nodes[concept]["memory_items"] + + if existing_memory and llm_adapter: + try: + integrated_memory = await self._integrate_memories_with_llm( + existing_memory, str(memory), llm_adapter + ) + self.G.nodes[concept]["memory_items"] = integrated_memory + current_weight = self.G.nodes[concept].get("weight", 0.0) + self.G.nodes[concept]["weight"] = current_weight + 1.0 + except Exception as e: + logger.error(f"LLM 整合记忆失败: {e}") + self.G.nodes[concept]["memory_items"] = f"{existing_memory} | {memory}" + else: + self.G.nodes[concept]["memory_items"] = str(memory) + else: + self.G.nodes[concept]["memory_items"] = str(memory) + if "created_time" not in self.G.nodes[concept]: + self.G.nodes[concept]["created_time"] = current_time + + 
self.G.nodes[concept]["last_modified"] = current_time + else: + self.G.add_node( + concept, + memory_items=str(memory), + weight=1.0, + created_time=current_time, + last_modified=current_time, + ) + + async def _integrate_memories_with_llm(self, old_memory: str, new_memory: str, llm_adapter: FrameworkLLMAdapter) -> str: + """使用 LLM 智能整合记忆""" + from ...statics.prompts import MEMORY_INTEGRATION_PROMPT + + prompt = MEMORY_INTEGRATION_PROMPT.format( + old_memory=old_memory, + new_memory=new_memory + ) + + response = await llm_adapter.generate_response( + prompt, temperature=0.3, model_type="refine" + ) + return response.strip() + + def get_memory_node(self, concept: str) -> Optional[Tuple[str, Dict[str, Any]]]: + """获取记忆节点""" + return (concept, self.G.nodes[concept]) if concept in self.G else None + + def get_related_concepts(self, topic: str, depth: int = 1) -> Tuple[List[str], List[str]]: + """获取相关概念""" + if topic not in self.G: + return [], [] + + first_layer_items = [] + second_layer_items = [] + + neighbors = list(self.G.neighbors(topic)) + + node_data = self.get_memory_node(topic) + if node_data: + _, data = node_data + if "memory_items" in data: + first_layer_items.append(data["memory_items"]) + + for neighbor in neighbors: + neighbor_data = self.get_memory_node(neighbor) + if neighbor_data: + _, data = neighbor_data + if "memory_items" in data: + first_layer_items.append(data["memory_items"]) + + if depth > 1: + second_neighbors = list(self.G.neighbors(neighbor)) + for second_neighbor in second_neighbors: + if second_neighbor != topic and second_neighbor not in neighbors: + second_data = self.get_memory_node(second_neighbor) + if second_data: + _, second_node_data = second_data + if "memory_items" in second_node_data: + second_layer_items.append(second_node_data["memory_items"]) + + return first_layer_items, second_layer_items + + def calculate_information_content(self, text: str) -> float: + """计算文本信息熵""" + char_count = Counter(text) + total_chars = len(text) 
+ if total_chars == 0: + return 0 + + entropy = 0 + for count in char_count.values(): + probability = count / total_chars + entropy -= probability * math.log2(probability) + + return entropy + + def get_graph_statistics(self) -> Dict[str, Any]: + """获取图的统计信息""" + return { + "nodes_count": self.G.number_of_nodes(), + "edges_count": self.G.number_of_edges(), + "density": nx.density(self.G), + "connected_components": nx.number_connected_components(self.G), + "average_clustering": nx.average_clustering(self.G) if self.G.number_of_nodes() > 0 else 0, + "average_shortest_path": nx.average_shortest_path_length(self.G) if nx.is_connected(self.G) else 0 + } + + +# 服务类 class EnhancedMemoryGraphManager: @@ -44,9 +220,6 @@ class EnhancedMemoryGraphManager: 4. 保持与原有接口的兼容性 用法: - # 在配置中启用 - config.use_enhanced_managers = True - # 创建管理器 memory_mgr = EnhancedMemoryGraphManager.get_instance(config, db_manager, llm_adapter) await memory_mgr.start() @@ -119,11 +292,11 @@ async def start(self) -> bool: hours=1 ) - logger.info("✅ [增强型记忆图] 启动成功") + logger.info(" [增强型记忆图] 启动成功") return True except Exception as e: - logger.error(f"❌ [增强型记忆图] 启动失败: {e}") + logger.error(f" [增强型记忆图] 启动失败: {e}") return False async def stop(self) -> bool: @@ -140,16 +313,14 @@ async def stop(self) -> bool: # 清除缓存 self.cache.clear('memory') - logger.info("✅ [增强型记忆图] 已停止") + logger.info(" [增强型记忆图] 已停止") return True except Exception as e: - logger.error(f"❌ [增强型记忆图] 停止失败: {e}") + logger.error(f" [增强型记忆图] 停止失败: {e}") return False - # ============================================================ # 核心方法(与原接口兼容) - # ============================================================ def get_memory_graph(self, group_id: str) -> MemoryGraph: """ @@ -239,7 +410,7 @@ async def save_memory_graph(self, group_id: str): # 创建或更新记忆 await memory_repo.create_memory( group_id=group_id, - user_id='', # 群组级别记忆 + user_id='', # 群组级别记忆 content=memory_items, memory_type='concept', importance=node_data.get('weight', 0.5), @@ -368,9 +539,7 @@ 
async def get_memory_graph_statistics(self, group_id: str) -> Dict[str, Any]: logger.error(f"[增强型记忆图] 获取统计信息失败: {e}") return {} - # ============================================================ # 辅助方法 - # ============================================================ async def _extract_concepts_from_message(self, message: MessageData) -> List[str]: """从消息提取概念""" @@ -397,9 +566,7 @@ def _invalidate_related_caches(self, group_id: str): # CacheManager 不支持模式匹配删除,所以这里只是示例 logger.debug(f"[增强型记忆图] 清除群组 {group_id} 的相关缓存") - # ============================================================ # 任务调度方法 - # ============================================================ async def _cleanup_old_memories_task(self): """清理旧记忆任务(由调度器调用)""" @@ -442,9 +609,7 @@ async def _auto_save_memory_graphs_task(self): except Exception as e: logger.error(f"[增强型记忆图] 自动保存失败: {e}") - # ============================================================ # 缓存统计方法 - # ============================================================ def get_cache_stats(self) -> dict: """获取缓存统计信息""" @@ -454,3 +619,7 @@ def clear_cache(self): """清除所有缓存""" self.cache.clear('memory') logger.info("[增强型记忆图] 已清除所有缓存") + + +# 向后兼容别名: 其他模块可以 from .enhanced_memory_graph_manager import MemoryGraphManager +MemoryGraphManager = EnhancedMemoryGraphManager diff --git a/services/enhanced_psychological_state_manager.py b/services/state/enhanced_psychological_state_manager.py similarity index 90% rename from services/enhanced_psychological_state_manager.py rename to services/state/enhanced_psychological_state_manager.py index 0a638a5..7f4ee2b 100644 --- a/services/enhanced_psychological_state_manager.py +++ b/services/state/enhanced_psychological_state_manager.py @@ -10,22 +10,22 @@ from astrbot.api import logger -from ..config import PluginConfig -from ..core.patterns import AsyncServiceBase -from ..core.interfaces import IDataStorage -from ..core.framework_llm_adapter import FrameworkLLMAdapter -from ..utils.cache_manager import get_cache_manager, 
async_cached -from ..utils.task_scheduler import get_task_scheduler +from ...config import PluginConfig +from ...core.patterns import AsyncServiceBase +from ...core.interfaces import IDataStorage +from ...core.framework_llm_adapter import FrameworkLLMAdapter +from ...utils.cache_manager import get_cache_manager, async_cached +from ...utils.task_scheduler import get_task_scheduler # 导入 Repository -from ..repositories import ( +from ...repositories import ( PsychologicalStateRepository, PsychologicalComponentRepository, PsychologicalHistoryRepository ) # 导入原有的模型和枚举 -from ..models.psychological_state import ( +from ...models.psychological_state import ( EmotionPositiveType, EmotionNegativeType, EmotionNeutralType, AttentionState, ThinkingState, MemoryState, WillStrengthState, ActionTendencyState, GoalOrientationState, @@ -35,9 +35,6 @@ PsychologicalStateComponent, CompositePsychologicalState ) -# 导入原有的管理器用于获取time_based_rules等 -from .psychological_state_manager import PsychologicalStateManager as OriginalPsychologicalStateManager - class EnhancedPsychologicalStateManager(AsyncServiceBase): """ @@ -50,9 +47,6 @@ class EnhancedPsychologicalStateManager(AsyncServiceBase): 4. 
保持与原有接口的兼容性 用法: - # 在配置中启用 - config.use_enhanced_managers = True - # 创建管理器 state_mgr = EnhancedPsychologicalStateManager(config, db_manager, llm_adapter) await state_mgr.start() @@ -133,11 +127,11 @@ async def _do_start(self) -> bool: minute=0 ) - self._logger.info("✅ [增强型心理状态] 启动成功") + self._logger.info(" [增强型心理状态] 启动成功") return True except Exception as e: - self._logger.error(f"❌ [增强型心理状态] 启动失败: {e}", exc_info=True) + self._logger.error(f" [增强型心理状态] 启动失败: {e}", exc_info=True) return False async def _do_stop(self) -> bool: @@ -155,16 +149,14 @@ async def _do_stop(self) -> bool: # 清除缓存 self.cache.clear('state') - self._logger.info("✅ [增强型心理状态] 已停止") + self._logger.info(" [增强型心理状态] 已停止") return True except Exception as e: - self._logger.error(f"❌ [增强型心理状态] 停止失败: {e}") + self._logger.error(f" [增强型心理状态] 停止失败: {e}") return False - # ============================================================ # 使用缓存装饰器的方法 - # ============================================================ @async_cached( cache_name='state', @@ -204,7 +196,7 @@ async def get_current_state( for comp in components: state_components[comp.component_name] = PsychologicalStateComponent( dimension=comp.component_name, - state_type=comp.component_name, # TODO: 需要解析类型 + state_type=comp.component_name, # TODO: 需要解析类型 value=comp.value, threshold=comp.threshold ) @@ -333,9 +325,7 @@ async def get_state_prompt_injection( self._logger.error(f"[增强型心理状态] 生成注入内容失败: {e}") return "" - # ============================================================ # 任务调度方法 - # ============================================================ async def _auto_decay_task(self): """状态自动衰减任务(由调度器调用)""" @@ -397,16 +387,14 @@ async def _cleanup_history_task(self): # TODO: 获取所有状态ID并清理30天前的历史 # 示例实现 # for state_id in state_ids: - # deleted = await history_repo.clean_old_history(state_id, days=30) + # deleted = await history_repo.clean_old_history(state_id, days=30) self._logger.info("[增强型心理状态] 历史清理完成") except Exception as e: self._logger.error(f"[增强型心理状态] 
清理历史失败: {e}") - # ============================================================ # 辅助方法(保持原有逻辑) - # ============================================================ def _init_time_based_rules(self) -> List[Dict[str, Any]]: """初始化基于时间的状态变化规则""" @@ -450,9 +438,7 @@ async def _save_all_states(self): except Exception as e: self._logger.error(f"[增强型心理状态] 保存状态失败: {e}") - # ============================================================ # 缓存统计方法 - # ============================================================ def get_cache_stats(self) -> dict: """获取缓存统计信息""" diff --git a/services/state/time_decay_manager.py b/services/state/time_decay_manager.py new file mode 100644 index 0000000..ed7c4d2 --- /dev/null +++ b/services/state/time_decay_manager.py @@ -0,0 +1,365 @@ +""" +时间衰减管理器 - 实现MaiBot的时间衰减机制(ORM 版本) +为现有学习系统添加时间衰减功能,保持学习内容的时效性 + +注意:expression_patterns 的衰减由 ExpressionPatternLearner._apply_time_decay 处理, +本模块处理其余表的衰减。 +""" +import asyncio +import time +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass + +from astrbot.api import logger + +from ...core.interfaces import ServiceLifecycle +from ...config import PluginConfig +from ...exceptions import TimeDecayError +from ..database import DatabaseManager + + +@dataclass +class DecayConfig: + """衰减配置""" + decay_days: int = 15 # MaiBot的15天衰减周期 + decay_min: float = 0.01 # 最小衰减值 + table_key: str = "" # 逻辑表标识(不再直接用于 SQL) + + +class TimeDecayManager: + """ + 时间衰减管理器 - 完全基于MaiBot的衰减机制设计(ORM 版本) + 为各种学习数据提供统一的时间衰减管理 + + 所有数据库操作通过 SQLAlchemy ORM 执行,不使用原始 SQL。 + """ + + def __init__(self, config: PluginConfig, db_manager: DatabaseManager): + self.config = config + self.db_manager = db_manager + self._status = ServiceLifecycle.CREATED + + # 预定义的衰减配置(逻辑名 → 衰减参数) + self.decay_configs = { + 'learning_batches': DecayConfig( + decay_days=7, + table_key='learning_batches', + ), + 'expression_patterns': DecayConfig( + decay_days=15, + table_key='expression_patterns', + ), + } + + async def start(self) -> bool: + 
"""启动服务""" + self._status = ServiceLifecycle.RUNNING + logger.info("TimeDecayManager服务已启动") + return True + + async def stop(self) -> bool: + """停止服务""" + self._status = ServiceLifecycle.STOPPED + logger.info("TimeDecayManager服务已停止") + return True + + def calculate_decay_factor(self, time_diff_days: float, decay_days: int = 15) -> float: + """ + 计算衰减因子 - 完全采用MaiBot的衰减算法 + + Args: + time_diff_days: 时间差(天) + decay_days: 衰减周期天数 + + Returns: + 衰减因子 + """ + if time_diff_days <= 0: + return 0.0 # 刚激活的不衰减 + + if time_diff_days >= decay_days: + return 0.01 # 长时间未活跃的大幅衰减 + + # 使用二次函数插值:在0-decay_days天之间从0衰减到0.01 + a = 0.01 / (decay_days ** 2) + decay = a * (time_diff_days ** 2) + + return min(0.01, decay) + + async def apply_decay_to_table( + self, decay_config: DecayConfig, group_id: Optional[str] = None + ) -> Tuple[int, int]: + """ + 对指定表应用时间衰减(ORM 版本) + + Args: + decay_config: 衰减配置 + group_id: 可选的群组ID筛选 + + Returns: + (更新数量, 删除数量) + """ + table_key = decay_config.table_key + handler = self._TABLE_HANDLERS.get(table_key) + if not handler: + logger.debug(f"表 {table_key} 没有衰减处理器,跳过") + return 0, 0 + + try: + return await handler(self, decay_config, group_id) + except Exception as e: + logger.error(f"对表 {table_key} 应用时间衰减失败: {e}") + raise TimeDecayError(f"时间衰减失败: {e}") + + # ---- Per-table decay handlers ---- + + async def _decay_learning_batches( + self, decay_config: DecayConfig, group_id: Optional[str] = None + ) -> Tuple[int, int]: + """对 learning_batches 表应用衰减""" + from sqlalchemy import select, delete + from ...models.orm.learning import LearningBatch + + current_time = time.time() + updated_count = 0 + deleted_count = 0 + + async with self.db_manager.get_session() as session: + stmt = select(LearningBatch) + if group_id: + stmt = stmt.where(LearningBatch.group_id == group_id) + result = await session.execute(stmt) + batches = result.scalars().all() + + ids_to_delete = [] + for batch in batches: + if batch.start_time is None: + continue + time_diff_days = (current_time 
- batch.start_time) / (24 * 3600) + decay_value = self.calculate_decay_factor(time_diff_days, decay_config.decay_days) + + current_score = batch.quality_score or 1.0 + new_score = max(decay_config.decay_min, current_score - decay_value) + + if new_score <= decay_config.decay_min: + ids_to_delete.append(batch.id) + deleted_count += 1 + else: + batch.quality_score = new_score + updated_count += 1 + + if ids_to_delete: + await session.execute( + delete(LearningBatch).where(LearningBatch.id.in_(ids_to_delete)) + ) + + await session.commit() + + if updated_count > 0 or deleted_count > 0: + group_info = f" (群组: {group_id})" if group_id else "" + logger.info(f"表 learning_batches{group_info} 时间衰减完成:更新 {updated_count},删除 {deleted_count}") + + return updated_count, deleted_count + + async def _decay_expression_patterns( + self, decay_config: DecayConfig, group_id: Optional[str] = None + ) -> Tuple[int, int]: + """对 expression_patterns 表应用衰减""" + from sqlalchemy import select, delete + from ...models.orm.expression import ExpressionPattern + + current_time = time.time() + updated_count = 0 + deleted_count = 0 + + async with self.db_manager.get_session() as session: + stmt = select(ExpressionPattern) + if group_id: + stmt = stmt.where(ExpressionPattern.group_id == group_id) + result = await session.execute(stmt) + patterns = result.scalars().all() + + ids_to_delete = [] + for pattern in patterns: + time_diff_days = (current_time - pattern.last_active_time) / (24 * 3600) + decay_value = self.calculate_decay_factor(time_diff_days, decay_config.decay_days) + new_weight = max(decay_config.decay_min, pattern.weight - decay_value) + + if new_weight <= decay_config.decay_min: + ids_to_delete.append(pattern.id) + deleted_count += 1 + else: + pattern.weight = new_weight + updated_count += 1 + + if ids_to_delete: + await session.execute( + delete(ExpressionPattern).where(ExpressionPattern.id.in_(ids_to_delete)) + ) + + await session.commit() + + if updated_count > 0 or deleted_count > 
0: + group_info = f" (群组: {group_id})" if group_id else "" + logger.info(f"表 expression_patterns{group_info} 时间衰减完成:更新 {updated_count},删除 {deleted_count}") + + return updated_count, deleted_count + + # Handler registry + _TABLE_HANDLERS = { + 'learning_batches': _decay_learning_batches, + 'expression_patterns': _decay_expression_patterns, + } + + async def apply_decay_to_all_tables( + self, group_id: Optional[str] = None + ) -> Dict[str, Tuple[int, int]]: + """ + 对所有配置的表应用时间衰减 + + Args: + group_id: 可选的群组ID筛选 + + Returns: + 每个表的(更新数量, 删除数量)结果 + """ + results = {} + + for table_name, decay_config in self.decay_configs.items(): + try: + updated, deleted = await self.apply_decay_to_table(decay_config, group_id) + results[table_name] = (updated, deleted) + except Exception as e: + logger.error(f"对表 {table_name} 应用衰减失败: {e}") + results[table_name] = (0, 0) + + return results + + async def add_decay_config(self, name: str, config: DecayConfig): + """添加新的衰减配置""" + self.decay_configs[name] = config + logger.info(f"添加衰减配置: {name}") + + async def get_decay_statistics( + self, group_id: Optional[str] = None + ) -> Dict[str, Dict[str, Any]]: + """ + 获取衰减统计信息(ORM 版本) + + Args: + group_id: 可选的群组ID筛选 + + Returns: + 各表的衰减统计信息 + """ + statistics = {} + current_time = time.time() + + # learning_batches 统计 + try: + stats = await self._stats_learning_batches(group_id, current_time) + if stats: + statistics['learning_batches'] = stats + except Exception as e: + logger.error(f"获取 learning_batches 衰减统计失败: {e}") + statistics['learning_batches'] = {'error': str(e)} + + # expression_patterns 统计 + try: + stats = await self._stats_expression_patterns(group_id, current_time) + if stats: + statistics['expression_patterns'] = stats + except Exception as e: + logger.error(f"获取 expression_patterns 衰减统计失败: {e}") + statistics['expression_patterns'] = {'error': str(e)} + + return statistics + + async def _stats_learning_batches( + self, group_id: Optional[str], current_time: float + ) -> 
Optional[Dict[str, Any]]: + from sqlalchemy import select, func + from ...models.orm.learning import LearningBatch + + async with self.db_manager.get_session() as session: + stmt = select( + func.count().label('total_count'), + func.avg(LearningBatch.quality_score).label('avg_weight'), + func.min(LearningBatch.start_time).label('oldest_time'), + func.max(LearningBatch.start_time).label('newest_time'), + ).select_from(LearningBatch) + if group_id: + stmt = stmt.where(LearningBatch.group_id == group_id) + + row = (await session.execute(stmt)).one_or_none() + if not row or not row.total_count: + return None + + cfg = self.decay_configs.get('learning_batches', DecayConfig(decay_days=7)) + oldest_days = (current_time - row.oldest_time) / (24 * 3600) if row.oldest_time else 0 + newest_days = (current_time - row.newest_time) / (24 * 3600) if row.newest_time else 0 + + return { + 'total_count': row.total_count, + 'avg_weight': round(row.avg_weight, 3) if row.avg_weight else 0, + 'oldest_days': round(oldest_days, 1), + 'newest_days': round(newest_days, 1), + 'decay_config': {'decay_days': cfg.decay_days, 'decay_min': cfg.decay_min}, + } + + async def _stats_expression_patterns( + self, group_id: Optional[str], current_time: float + ) -> Optional[Dict[str, Any]]: + from sqlalchemy import select, func + from ...models.orm.expression import ExpressionPattern + + async with self.db_manager.get_session() as session: + stmt = select( + func.count().label('total_count'), + func.avg(ExpressionPattern.weight).label('avg_weight'), + func.min(ExpressionPattern.last_active_time).label('oldest_time'), + func.max(ExpressionPattern.last_active_time).label('newest_time'), + ).select_from(ExpressionPattern) + if group_id: + stmt = stmt.where(ExpressionPattern.group_id == group_id) + + row = (await session.execute(stmt)).one_or_none() + if not row or not row.total_count: + return None + + cfg = self.decay_configs.get('expression_patterns', DecayConfig(decay_days=15)) + oldest_days = 
(current_time - row.oldest_time) / (24 * 3600) if row.oldest_time else 0 + newest_days = (current_time - row.newest_time) / (24 * 3600) if row.newest_time else 0 + + return { + 'total_count': row.total_count, + 'avg_weight': round(row.avg_weight, 3) if row.avg_weight else 0, + 'oldest_days': round(oldest_days, 1), + 'newest_days': round(newest_days, 1), + 'decay_config': {'decay_days': cfg.decay_days, 'decay_min': cfg.decay_min}, + } + + async def schedule_decay_maintenance(self, interval_hours: int = 24): + """ + 定期衰减维护任务 + + Args: + interval_hours: 维护间隔小时数 + """ + logger.info(f"启动定期衰减维护,间隔: {interval_hours}小时") + + while self._status == ServiceLifecycle.RUNNING: + try: + results = await self.apply_decay_to_all_tables() + + total_updated = sum(r[0] for r in results.values()) + total_deleted = sum(r[1] for r in results.values()) + + if total_updated > 0 or total_deleted > 0: + logger.info(f"定期衰减维护完成,总计更新: {total_updated},删除: {total_deleted}") + + await asyncio.sleep(interval_hours * 3600) + + except Exception as e: + logger.error(f"定期衰减维护失败: {e}") + await asyncio.sleep(3600) diff --git a/services/table_schemas.py b/services/table_schemas.py deleted file mode 100644 index 73fe274..0000000 --- a/services/table_schemas.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -数据库表结构定义 - -⚠️ 已废弃:所有表结构由 SQLAlchemy ORM 统一管理 -此文件保留仅供参考,不再使用 - -新的表结构定义位置: -- models/orm/message.py - 消息相关表 -- models/orm/psychological.py - 心理状态表 -- models/orm/social_relation.py - 社交关系表 -- models/orm/affection.py - 好感度表 -- models/orm/memory.py - 记忆表 -- models/orm/learning.py - 学习记录表 -- models/orm/expression.py - 表达模式表 -- models/orm/jargon.py - 黑话表 -- models/orm/social_analysis.py - 社交分析表 -""" -from typing import Dict, Tuple -from ..core.database.backend_interface import DatabaseType - - -class TableSchemas: - """ - 数据库表结构定义 - - ⚠️ 已废弃:所有表结构由 SQLAlchemy ORM 统一管理 - 请使用 models/orm/ 目录下的 ORM 模型定义 - """ - - @staticmethod - def get_all_table_schemas() -> Dict[str, Tuple[str, str]]: - """ - 获取所有表的DDL语句 - - Returns: - 
Dict[table_name, (sqlite_ddl, mysql_ddl)] - """ - return { - # 原始消息表(匹配 ORM 模型 RawMessage) - 'raw_messages': ( - '''CREATE TABLE IF NOT EXISTS raw_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - sender_id TEXT NOT NULL, - sender_name TEXT, - message TEXT NOT NULL, - group_id TEXT, - timestamp INTEGER NOT NULL, - platform TEXT, - message_id TEXT, - reply_to TEXT, - created_at INTEGER NOT NULL, - processed INTEGER DEFAULT 0 - )''', - '''CREATE TABLE IF NOT EXISTS raw_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - sender_id VARCHAR(255) NOT NULL, - sender_name VARCHAR(255), - message TEXT NOT NULL, - group_id VARCHAR(255), - timestamp BIGINT NOT NULL, - platform VARCHAR(100), - message_id VARCHAR(255), - reply_to VARCHAR(255), - created_at BIGINT NOT NULL, - processed TINYINT DEFAULT 0, - INDEX idx_raw_timestamp (timestamp), - INDEX idx_raw_sender (sender_id), - INDEX idx_raw_processed (processed), - INDEX idx_raw_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 筛选后消息表(匹配 ORM 模型 FilteredMessage) - 'filtered_messages': ( - '''CREATE TABLE IF NOT EXISTS filtered_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - raw_message_id INTEGER, - message TEXT NOT NULL, - sender_id TEXT NOT NULL, - group_id TEXT, - timestamp INTEGER NOT NULL, - confidence REAL, - quality_scores TEXT, - filter_reason TEXT, - created_at INTEGER NOT NULL, - processed INTEGER DEFAULT 0 - )''', - '''CREATE TABLE IF NOT EXISTS filtered_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - raw_message_id INT, - message TEXT NOT NULL, - sender_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255), - timestamp BIGINT NOT NULL, - confidence DOUBLE, - quality_scores TEXT, - filter_reason TEXT, - created_at BIGINT NOT NULL, - processed TINYINT DEFAULT 0, - INDEX idx_filtered_timestamp (timestamp), - INDEX idx_filtered_sender (sender_id), - INDEX idx_filtered_processed (processed), - INDEX idx_filtered_group (group_id) - ) ENGINE=InnoDB DEFAULT 
CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 社交关系表 - 'social_relations': ( - '''CREATE TABLE IF NOT EXISTS social_relations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - from_user TEXT NOT NULL, - to_user TEXT NOT NULL, - relation_type TEXT NOT NULL, - strength REAL NOT NULL, - frequency INTEGER NOT NULL, - last_interaction REAL NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - UNIQUE(from_user, to_user, relation_type) - )''', - '''CREATE TABLE IF NOT EXISTS social_relations ( - id INT PRIMARY KEY AUTO_INCREMENT, - from_user VARCHAR(255) NOT NULL, - to_user VARCHAR(255) NOT NULL, - relation_type VARCHAR(100) NOT NULL, - strength DOUBLE NOT NULL, - frequency INT NOT NULL, - last_interaction DOUBLE NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY uk_from_to_type (from_user, to_user, relation_type), - INDEX idx_from_user (from_user), - INDEX idx_to_user (to_user), - INDEX idx_strength (strength) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 用户好感度表(匹配 ORM 模型 UserAffection) - 'user_affections': ( - '''CREATE TABLE IF NOT EXISTS user_affections ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - affection_level INTEGER DEFAULT 0 NOT NULL, - max_affection INTEGER DEFAULT 100 NOT NULL, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL, - UNIQUE(group_id, user_id) - )''', - '''CREATE TABLE IF NOT EXISTS user_affections ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - user_id VARCHAR(255) NOT NULL, - affection_level INT DEFAULT 0 NOT NULL, - max_affection INT DEFAULT 100 NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - UNIQUE KEY idx_group_user_affection (group_id, user_id), - INDEX idx_affection_group (group_id), - INDEX idx_affection_user (user_id) - ) ENGINE=InnoDB 
DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 表达模式表(匹配 ORM 模型 ExpressionPattern) - 'expression_patterns': ( - '''CREATE TABLE IF NOT EXISTS expression_patterns ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight REAL NOT NULL DEFAULT 1.0, - last_active_time REAL NOT NULL, - create_time REAL NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS expression_patterns ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - situation TEXT NOT NULL, - expression TEXT NOT NULL, - weight DOUBLE NOT NULL DEFAULT 1.0, - last_active_time DOUBLE NOT NULL, - create_time DOUBLE NOT NULL, - INDEX idx_group_weight (group_id, weight), - INDEX idx_group_active (group_id, last_active_time) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 黑话表(匹配 ORM 模型 Jargon) - 'jargon': ( - '''CREATE TABLE IF NOT EXISTS jargon ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - content TEXT NOT NULL, - raw_content TEXT, - meaning TEXT, - is_jargon INTEGER, - count INTEGER DEFAULT 1, - last_inference_count INTEGER DEFAULT 0, - is_complete INTEGER DEFAULT 0, - is_global INTEGER DEFAULT 0, - chat_id TEXT NOT NULL, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL, - UNIQUE(content, chat_id) - )''', - '''CREATE TABLE IF NOT EXISTS jargon ( - id INT PRIMARY KEY AUTO_INCREMENT, - content TEXT NOT NULL, - raw_content TEXT, - meaning TEXT, - is_jargon TINYINT, - count INT DEFAULT 1, - last_inference_count INT DEFAULT 0, - is_complete TINYINT DEFAULT 0, - is_global TINYINT DEFAULT 0, - chat_id VARCHAR(255) NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - UNIQUE KEY uk_content_chat (content(255), chat_id), - INDEX idx_jargon_content (content(255)), - INDEX idx_jargon_chat_id (chat_id), - INDEX idx_jargon_is_jargon (is_jargon) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # Bot消息表(匹配 ORM 模型 BotMessage) - 
'bot_messages': ( - '''CREATE TABLE IF NOT EXISTS bot_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - message TEXT NOT NULL, - timestamp INTEGER NOT NULL, - created_at INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS bot_messages ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - message TEXT NOT NULL, - timestamp BIGINT NOT NULL, - created_at BIGINT NOT NULL, - INDEX idx_bot_timestamp (timestamp), - INDEX idx_bot_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 人格学习审核表(匹配 ORM 模型 PersonaLearningReview) - 'persona_update_reviews': ( - '''CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp REAL NOT NULL, - group_id TEXT NOT NULL, - update_type TEXT NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score REAL, - reason TEXT, - status TEXT DEFAULT 'pending' NOT NULL, - reviewer_comment TEXT, - review_time REAL, - metadata TEXT - )''', - '''CREATE TABLE IF NOT EXISTS persona_update_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - timestamp DOUBLE NOT NULL, - group_id VARCHAR(255) NOT NULL, - update_type VARCHAR(255) NOT NULL, - original_content TEXT, - new_content TEXT, - proposed_content TEXT, - confidence_score DOUBLE, - reason TEXT, - status VARCHAR(50) DEFAULT 'pending' NOT NULL, - reviewer_comment TEXT, - review_time DOUBLE, - metadata TEXT, - INDEX idx_group_persona_review (group_id, status), - INDEX idx_persona_review_timestamp (timestamp), - INDEX idx_persona_review_status (status) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 风格学习审查表(匹配 ORM 模型 StyleLearningReview) - 'style_learning_reviews': ( - '''CREATE TABLE IF NOT EXISTS style_learning_reviews ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - type TEXT NOT NULL, - group_id TEXT NOT NULL, - timestamp REAL NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - 
status TEXT DEFAULT 'pending', - description TEXT, - reviewer_comment TEXT, - review_time REAL, - created_at TEXT, - updated_at TEXT - )''', - '''CREATE TABLE IF NOT EXISTS style_learning_reviews ( - id INT PRIMARY KEY AUTO_INCREMENT, - type VARCHAR(100) NOT NULL, - group_id VARCHAR(255) NOT NULL, - timestamp DOUBLE NOT NULL, - learned_patterns TEXT, - few_shots_content TEXT, - status VARCHAR(50) DEFAULT 'pending', - description TEXT, - reviewer_comment TEXT, - review_time DOUBLE, - created_at DATETIME, - updated_at DATETIME, - INDEX idx_status (status), - INDEX idx_group (group_id), - INDEX idx_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # ==================== 心理状态管理表 ==================== - - # 心理状态组件表(匹配 ORM 模型 PsychologicalStateComponent) - 'psychological_state_components': ( - '''CREATE TABLE IF NOT EXISTS psychological_state_components ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - composite_state_id INTEGER, - group_id TEXT NOT NULL, - state_id TEXT NOT NULL, - category TEXT NOT NULL, - state_type TEXT NOT NULL, - value REAL NOT NULL, - threshold REAL DEFAULT 0.3 NOT NULL, - description TEXT, - start_time INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS psychological_state_components ( - id INT PRIMARY KEY AUTO_INCREMENT, - composite_state_id INT, - group_id VARCHAR(255) NOT NULL, - state_id VARCHAR(255) NOT NULL, - category VARCHAR(50) NOT NULL, - state_type VARCHAR(100) NOT NULL, - value DOUBLE NOT NULL, - threshold DOUBLE DEFAULT 0.3 NOT NULL, - description TEXT, - start_time BIGINT NOT NULL, - INDEX idx_psych_component_composite (composite_state_id), - INDEX idx_psych_component_state (state_id), - INDEX idx_psych_component_category (category), - INDEX idx_psych_component_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 复合心理状态表(匹配 ORM 模型 CompositePsychologicalState) - 'composite_psychological_states': ( - '''CREATE TABLE IF NOT EXISTS 
composite_psychological_states ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL UNIQUE, - state_id TEXT NOT NULL UNIQUE, - triggering_events TEXT, - context TEXT, - created_at INTEGER NOT NULL, - last_updated INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS composite_psychological_states ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL UNIQUE, - state_id VARCHAR(255) NOT NULL UNIQUE, - triggering_events TEXT, - context TEXT, - created_at BIGINT NOT NULL, - last_updated BIGINT NOT NULL, - INDEX idx_psych_state_group (group_id), - INDEX idx_last_updated (last_updated) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 心理状态变化历史表(匹配 ORM 模型 PsychologicalStateHistory) - 'psychological_state_history': ( - '''CREATE TABLE IF NOT EXISTS psychological_state_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - group_id TEXT NOT NULL, - state_id TEXT NOT NULL, - category TEXT NOT NULL, - old_state_type TEXT, - new_state_type TEXT NOT NULL, - old_value REAL, - new_value REAL NOT NULL, - change_reason TEXT, - timestamp INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS psychological_state_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - group_id VARCHAR(255) NOT NULL, - state_id VARCHAR(255) NOT NULL, - category VARCHAR(50) NOT NULL, - old_state_type VARCHAR(100), - new_state_type VARCHAR(100) NOT NULL, - old_value DOUBLE, - new_value DOUBLE NOT NULL, - change_reason TEXT, - timestamp BIGINT NOT NULL, - INDEX idx_psych_history_group (group_id), - INDEX idx_psych_history_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # ==================== 增强社交关系管理表 ==================== - - # 用户社交关系组件表(匹配 ORM 模型 UserSocialRelationComponent) - 'user_social_relation_components': ( - '''CREATE TABLE IF NOT EXISTS user_social_relation_components ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - profile_id INTEGER, - from_user_id TEXT NOT NULL, - to_user_id TEXT NOT 
NULL, - group_id TEXT NOT NULL, - relation_type TEXT NOT NULL, - value REAL NOT NULL, - frequency INTEGER DEFAULT 0 NOT NULL, - last_interaction INTEGER NOT NULL, - description TEXT, - tags TEXT, - created_at INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS user_social_relation_components ( - id INT PRIMARY KEY AUTO_INCREMENT, - profile_id INT, - from_user_id VARCHAR(255) NOT NULL, - to_user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - relation_type VARCHAR(100) NOT NULL, - value DOUBLE NOT NULL, - frequency INT DEFAULT 0 NOT NULL, - last_interaction BIGINT NOT NULL, - description TEXT, - tags TEXT, - created_at BIGINT NOT NULL, - INDEX idx_social_relation_profile (profile_id), - INDEX idx_social_relation_from_to (from_user_id, to_user_id, group_id), - INDEX idx_social_relation_type (relation_type) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 用户社交档案统计表(匹配 ORM 模型 UserSocialProfile) - 'user_social_profiles': ( - '''CREATE TABLE IF NOT EXISTS user_social_profiles ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - total_relations INTEGER DEFAULT 0 NOT NULL, - significant_relations INTEGER DEFAULT 0 NOT NULL, - dominant_relation_type TEXT, - created_at INTEGER NOT NULL, - last_updated INTEGER NOT NULL, - UNIQUE(user_id, group_id) - )''', - '''CREATE TABLE IF NOT EXISTS user_social_profiles ( - id INT PRIMARY KEY AUTO_INCREMENT, - user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - total_relations INT DEFAULT 0 NOT NULL, - significant_relations INT DEFAULT 0 NOT NULL, - dominant_relation_type VARCHAR(100), - created_at BIGINT NOT NULL, - last_updated BIGINT NOT NULL, - UNIQUE KEY idx_social_profile_user_group (user_id, group_id), - INDEX idx_social_profile_group (group_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - - # 社交关系变化历史表(匹配 ORM 模型 SocialRelationHistory) - 'social_relation_history': ( - '''CREATE TABLE IF NOT 
EXISTS social_relation_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - from_user_id TEXT NOT NULL, - to_user_id TEXT NOT NULL, - group_id TEXT NOT NULL, - relation_type TEXT NOT NULL, - old_value REAL, - new_value REAL NOT NULL, - change_reason TEXT, - timestamp INTEGER NOT NULL - )''', - '''CREATE TABLE IF NOT EXISTS social_relation_history ( - id INT PRIMARY KEY AUTO_INCREMENT, - from_user_id VARCHAR(255) NOT NULL, - to_user_id VARCHAR(255) NOT NULL, - group_id VARCHAR(255) NOT NULL, - relation_type VARCHAR(100) NOT NULL, - old_value DOUBLE, - new_value DOUBLE NOT NULL, - change_reason TEXT, - timestamp BIGINT NOT NULL, - INDEX idx_social_history_from_to (from_user_id, to_user_id, group_id), - INDEX idx_social_history_timestamp (timestamp) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci''' - ), - } - - @staticmethod - def get_table_ddl(table_name: str, db_type: DatabaseType) -> str: - """ - 获取指定表的DDL语句 - - Args: - table_name: 表名 - db_type: 数据库类型 - - Returns: - DDL语句 - """ - schemas = TableSchemas.get_all_table_schemas() - if table_name not in schemas: - raise ValueError(f"Unknown table: {table_name}") - - sqlite_ddl, mysql_ddl = schemas[table_name] - - if db_type == DatabaseType.SQLITE: - return sqlite_ddl - elif db_type == DatabaseType.MYSQL: - return mysql_ddl - else: - raise ValueError(f"Unsupported database type: {db_type}") - - @staticmethod - def get_all_table_names() -> list: - """获取所有表名""" - return list(TableSchemas.get_all_table_schemas().keys()) diff --git a/services/time_decay_manager.py b/services/time_decay_manager.py deleted file mode 100644 index 29f4bbc..0000000 --- a/services/time_decay_manager.py +++ /dev/null @@ -1,371 +0,0 @@ -""" -时间衰减管理器 - 实现MaiBot的时间衰减机制 -为现有学习系统添加时间衰减功能,保持学习内容的时效性 -""" -import time -import math -from typing import Dict, List, Optional, Tuple, Any -from datetime import datetime -from dataclasses import dataclass - -from astrbot.api import logger - -from ..core.interfaces import ServiceLifecycle 
-from ..config import PluginConfig -from ..exceptions import TimeDecayError -from .database_manager import DatabaseManager - - -@dataclass -class DecayConfig: - """衰减配置""" - decay_days: int = 15 # MaiBot的15天衰减周期 - decay_min: float = 0.01 # 最小衰减值 - decay_table: str = "" # 衰减表名 - weight_column: str = "weight" # 权重列名 - time_column: str = "last_active_time" # 时间列名 - id_column: str = "id" # ID列名 - - -class TimeDecayManager: - """ - 时间衰减管理器 - 完全基于MaiBot的衰减机制设计 - 为各种学习数据提供统一的时间衰减管理 - """ - - def __init__(self, config: PluginConfig, db_manager: DatabaseManager): - self.config = config - self.db_manager = db_manager - self._status = ServiceLifecycle.CREATED - - # 预定义的衰减配置 - self.decay_configs = { - 'style_features': DecayConfig( - decay_days=15, - decay_table='style_features', - weight_column='confidence', - time_column='updated_at' - ), - 'persona_updates': DecayConfig( - decay_days=30, # 人格更新衰减周期更长 - decay_table='persona_updates', - weight_column='confidence', - time_column='timestamp' - ), - 'learning_batches': DecayConfig( - decay_days=7, # 学习批次衰减更快 - decay_table='learning_batches', - weight_column='quality_score', - time_column='created_at' - ), - 'affection_records': DecayConfig( - decay_days=20, - decay_table='affection_records', - weight_column='strength', - time_column='timestamp' - ) - } - - async def start(self) -> bool: - """启动服务""" - self._status = ServiceLifecycle.RUNNING - logger.info("TimeDecayManager服务已启动") - return True - - async def stop(self) -> bool: - """停止服务""" - self._status = ServiceLifecycle.STOPPED - logger.info("TimeDecayManager服务已停止") - return True - - def calculate_decay_factor(self, time_diff_days: float, decay_days: int = 15) -> float: - """ - 计算衰减因子 - 完全采用MaiBot的衰减算法 - - Args: - time_diff_days: 时间差(天) - decay_days: 衰减周期天数 - - Returns: - 衰减因子 - """ - if time_diff_days <= 0: - return 0.0 # 刚激活的不衰减 - - if time_diff_days >= decay_days: - return 0.01 # 长时间未活跃的大幅衰减 - - # 使用二次函数插值:在0-decay_days天之间从0衰减到0.01 - a = 0.01 / (decay_days ** 2) - decay = a 
* (time_diff_days ** 2) - - return min(0.01, decay) - - async def apply_decay_to_table(self, decay_config: DecayConfig, group_id: Optional[str] = None) -> Tuple[int, int]: - """ - 对指定表应用时间衰减 - - Args: - decay_config: 衰减配置 - group_id: 可选的群组ID筛选 - - Returns: - (更新数量, 删除数量) - """ - try: - current_time = time.time() - updated_count = 0 - deleted_count = 0 - - with self.db_manager.get_connection() as conn: - # 构建查询语句 - base_query = f'SELECT {decay_config.id_column}, {decay_config.weight_column}, {decay_config.time_column} FROM {decay_config.decay_table}' - - if group_id: - query = f'{base_query} WHERE group_id = ?' - cursor = conn.execute(query, (group_id,)) - else: - cursor = conn.execute(base_query) - - records = cursor.fetchall() - - for record_id, weight, last_active_time in records: - # 计算时间差(天) - time_diff_days = (current_time - last_active_time) / (24 * 3600) - - # 计算衰减值 - decay_value = self.calculate_decay_factor(time_diff_days, decay_config.decay_days) - new_weight = max(decay_config.decay_min, weight - decay_value) - - if new_weight <= decay_config.decay_min: - # 删除权重过低的记录 - delete_query = f'DELETE FROM {decay_config.decay_table} WHERE {decay_config.id_column} = ?' - conn.execute(delete_query, (record_id,)) - deleted_count += 1 - else: - # 更新权重 - update_query = f'UPDATE {decay_config.decay_table} SET {decay_config.weight_column} = ? WHERE {decay_config.id_column} = ?' 
- conn.execute(update_query, (new_weight, record_id)) - updated_count += 1 - - conn.commit() - - if updated_count > 0 or deleted_count > 0: - table_name = decay_config.decay_table - group_info = f" (群组: {group_id})" if group_id else "" - logger.info(f"表 {table_name}{group_info} 时间衰减完成:更新了 {updated_count} 个,删除了 {deleted_count} 个记录") - - return updated_count, deleted_count - - except Exception as e: - logger.error(f"对表 {decay_config.decay_table} 应用时间衰减失败: {e}") - raise TimeDecayError(f"时间衰减失败: {e}") - - async def apply_decay_to_all_tables(self, group_id: Optional[str] = None) -> Dict[str, Tuple[int, int]]: - """ - 对所有配置的表应用时间衰减 - - Args: - group_id: 可选的群组ID筛选 - - Returns: - 每个表的(更新数量, 删除数量)结果 - """ - results = {} - - for table_name, decay_config in self.decay_configs.items(): - try: - # 检查表是否存在 - if await self._table_exists(decay_config.decay_table): - updated, deleted = await self.apply_decay_to_table(decay_config, group_id) - results[table_name] = (updated, deleted) - else: - logger.debug(f"表 {decay_config.decay_table} 不存在,跳过衰减") - results[table_name] = (0, 0) - except Exception as e: - logger.error(f"对表 {table_name} 应用衰减失败: {e}") - results[table_name] = (0, 0) - - return results - - async def _table_exists(self, table_name: str) -> bool: - """检查表是否存在""" - try: - with self.db_manager.get_connection() as conn: - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (table_name,) - ) - return cursor.fetchone() is not None - except Exception as e: - logger.error(f"检查表 {table_name} 是否存在失败: {e}") - return False - - async def add_decay_config(self, name: str, config: DecayConfig): - """添加新的衰减配置""" - self.decay_configs[name] = config - logger.info(f"添加衰减配置: {name}") - - async def get_decay_statistics(self, group_id: Optional[str] = None) -> Dict[str, Dict[str, Any]]: - """ - 获取衰减统计信息 - - Args: - group_id: 可选的群组ID筛选 - - Returns: - 各表的衰减统计信息 - """ - statistics = {} - current_time = time.time() - - for table_name, decay_config in 
self.decay_configs.items(): - try: - if not await self._table_exists(decay_config.decay_table): - continue - - with self.db_manager.get_connection() as conn: - # 构建查询语句 - base_query = f''' - SELECT - COUNT(*) as total_count, - AVG({decay_config.weight_column}) as avg_weight, - MIN({decay_config.time_column}) as oldest_time, - MAX({decay_config.time_column}) as newest_time - FROM {decay_config.decay_table} - ''' - - if group_id: - query = f'{base_query} WHERE group_id = ?' - cursor = conn.execute(query, (group_id,)) - else: - cursor = conn.execute(base_query) - - result = cursor.fetchone() - - if result and result[0] > 0: - total_count, avg_weight, oldest_time, newest_time = result - - # 计算老化程度 - oldest_days = (current_time - oldest_time) / (24 * 3600) if oldest_time else 0 - newest_days = (current_time - newest_time) / (24 * 3600) if newest_time else 0 - - statistics[table_name] = { - 'total_count': total_count, - 'avg_weight': round(avg_weight, 3) if avg_weight else 0, - 'oldest_days': round(oldest_days, 1), - 'newest_days': round(newest_days, 1), - 'decay_config': { - 'decay_days': decay_config.decay_days, - 'decay_min': decay_config.decay_min - } - } - else: - statistics[table_name] = { - 'total_count': 0, - 'avg_weight': 0, - 'oldest_days': 0, - 'newest_days': 0, - 'decay_config': { - 'decay_days': decay_config.decay_days, - 'decay_min': decay_config.decay_min - } - } - - except Exception as e: - logger.error(f"获取表 {table_name} 衰减统计失败: {e}") - statistics[table_name] = {'error': str(e)} - - return statistics - - async def schedule_decay_maintenance(self, interval_hours: int = 24): - """ - 定期衰减维护任务 - - Args: - interval_hours: 维护间隔小时数 - """ - logger.info(f"启动定期衰减维护,间隔: {interval_hours}小时") - - while self._status == ServiceLifecycle.RUNNING: - try: - # 执行全局衰减 - results = await self.apply_decay_to_all_tables() - - # 记录衰减结果 - total_updated = sum(r[0] for r in results.values()) - total_deleted = sum(r[1] for r in results.values()) - - if total_updated > 0 or 
total_deleted > 0: - logger.info(f"定期衰减维护完成,总计更新: {total_updated},删除: {total_deleted}") - - # 等待下次维护 - await asyncio.sleep(interval_hours * 3600) - - except Exception as e: - logger.error(f"定期衰减维护失败: {e}") - await asyncio.sleep(3600) # 错误后等待1小时再重试 - - -# 衰减工具函数 -def add_time_decay_to_existing_tables(): - """ - 为现有表添加时间衰减支持的工具函数 - 修改现有表结构,添加必要的时间和权重列 - """ - - # 表结构修改SQL - table_modifications = { - 'learning_batches': [ - 'ALTER TABLE learning_batches ADD COLUMN weight REAL DEFAULT 1.0', - 'ALTER TABLE learning_batches ADD COLUMN last_active_time REAL DEFAULT 0' - ], - 'style_features': [ - 'ALTER TABLE style_features ADD COLUMN last_active_time REAL DEFAULT 0' - ], - 'persona_updates': [ - 'ALTER TABLE persona_updates ADD COLUMN weight REAL DEFAULT 1.0', - 'ALTER TABLE persona_updates ADD COLUMN last_active_time REAL DEFAULT 0' - ] - } - - return table_modifications - - -# 使用示例函数 -async def integrate_time_decay_to_existing_services(decay_manager: TimeDecayManager): - """ - 将时间衰减机制集成到现有服务的示例 - """ - - # 1. 在学习服务中集成衰减 - async def enhanced_learning_with_decay(learning_service, group_id: str): - """带衰减的增强学习""" - # 执行正常学习 - learning_result = await learning_service.process_learning(group_id) - - # 应用时间衰减 - if learning_result: - await decay_manager.apply_decay_to_table( - decay_manager.decay_configs['learning_batches'], - group_id - ) - - return learning_result - - # 2. 
在人格更新中集成衰减 - async def enhanced_persona_update_with_decay(persona_service, group_id: str): - """带衰减的人格更新""" - # 执行人格更新 - update_result = await persona_service.update_persona(group_id) - - # 应用衰减 - if update_result: - await decay_manager.apply_decay_to_table( - decay_manager.decay_configs['persona_updates'], - group_id - ) - - return update_result - - return enhanced_learning_with_decay, enhanced_persona_update_with_decay \ No newline at end of file diff --git a/statics/messages.py b/statics/messages.py index d12494f..26acb9a 100644 --- a/statics/messages.py +++ b/statics/messages.py @@ -38,95 +38,95 @@ class StatusMessages: class CommandMessages: """命令响应消息""" - LEARNING_STARTED = "✅ 自动学习已启动 for group {group_id}" - LEARNING_RUNNING = "📚 自动学习已在运行中 for group {group_id}" - LEARNING_STOPPED = "⏹️ 自动学习已停止 for group {group_id}" - FORCE_LEARNING_START = "🔄 开始强制学习周期 for group {group_id}..." - FORCE_LEARNING_COMPLETE = "✅ 强制学习周期完成 for group {group_id}" - DATA_CLEARED = "🗑️ 所有学习数据已清空" - DATA_EXPORTED = "📤 学习数据已导出到: {filepath}" + LEARNING_STARTED = " 自动学习已启动 for group {group_id}" + LEARNING_RUNNING = " 自动学习已在运行中 for group {group_id}" + LEARNING_STOPPED = " 自动学习已停止 for group {group_id}" + FORCE_LEARNING_START = " 开始强制学习周期 for group {group_id}..." 
+ FORCE_LEARNING_COMPLETE = " 强制学习周期完成 for group {group_id}" + DATA_CLEARED = " 所有学习数据已清空" + DATA_EXPORTED = " 学习数据已导出到: {filepath}" # 状态报告模板 - STATUS_REPORT_HEADER = "📚 自学习插件状态报告 (会话ID: {group_id}):" + STATUS_REPORT_HEADER = " 自学习插件状态报告 (会话ID: {group_id}):" STATUS_BASIC_CONFIG = """ -🔧 基础配置: + 基础配置: - 消息抓取: {message_capture} - 自主学习: {auto_learning} - 实时学习: {realtime_learning} - Web界面: {web_interface}""" STATUS_CAPTURE_SETTINGS = """ -👥 抓取设置: + 抓取设置: - 目标QQ: {target_qq} - 当前人格: {current_persona}""" STATUS_MODEL_CONFIG = """ -🤖 模型配置: + 模型配置: - 筛选模型: {filter_model} - 提炼模型: {refine_model}""" STATUS_LEARNING_STATS = """ -📊 学习统计 (当前会话): + 学习统计 (当前会话): - 总收集消息: {total_messages} - 筛选消息: {filtered_messages} - 风格更新次数: {style_updates} - 最后学习时间: {last_learning_time}""" STATUS_STORAGE_STATS = """ -💾 存储统计 (当前会话): + 存储统计 (当前会话): - 原始消息: {raw_messages} 条 - 待处理消息: {unprocessed_messages} 条 - 筛选过的消息: {filtered_messages} 条""" - STATUS_SCHEDULER = "⏰ 调度状态 (当前会话): {status}" + STATUS_SCHEDULER = " 调度状态 (当前会话): {status}" # 好感度系统消息 - AFFECTION_DISABLED = "❌ 好感度系统未启用" - AFFECTION_STATUS_HEADER = "💝 好感度系统状态 (群组: {group_id}):" - AFFECTION_USER_LEVEL = "👤 您的好感度: {user_level}/{max_affection}" - AFFECTION_TOTAL_STATUS = "📊 总好感度: {total_affection}/{max_total_affection}" - AFFECTION_USER_COUNT = "👥 用户数量: {user_count}" - AFFECTION_CURRENT_MOOD = "🎭 当前情绪:" + AFFECTION_DISABLED = " 好感度系统未启用" + AFFECTION_STATUS_HEADER = " 好感度系统状态 (群组: {group_id}):" + AFFECTION_USER_LEVEL = " 您的好感度: {user_level}/{max_affection}" + AFFECTION_TOTAL_STATUS = " 总好感度: {total_affection}/{max_total_affection}" + AFFECTION_USER_COUNT = " 用户数量: {user_count}" + AFFECTION_CURRENT_MOOD = " 当前情绪:" AFFECTION_MOOD_TYPE = "- 类型: {mood_type}" AFFECTION_MOOD_INTENSITY = "- 强度: {intensity:.2f}" AFFECTION_MOOD_DESCRIPTION = "- 描述: {description}" AFFECTION_NO_MOOD = "- 无当前情绪状态" - AFFECTION_TOP_USERS = "🏆 好感度排行榜:" + AFFECTION_TOP_USERS = " 好感度排行榜:" AFFECTION_USER_RANK = "{rank}. 
用户 {user_id}: {affection_level}点" # 设置情绪命令 SET_MOOD_USAGE = "请指定情绪类型,如: /set_mood happy" SET_MOOD_INVALID = "无效的情绪类型。有效选项: {valid_moods}" - SET_MOOD_SUCCESS = "🎭 已设置新的情绪状态:\n类型: {mood_type}\n强度: {intensity:.2f}\n描述: {description}" + SET_MOOD_SUCCESS = " 已设置新的情绪状态:\n类型: {mood_type}\n强度: {intensity:.2f}\n描述: {description}" # 分析报告消息 - ANALYTICS_GENERATING = "📊 正在生成数据分析报告..." - ANALYTICS_REPORT_HEADER = "📈 数据分析报告 (群组: {group_id}):" + ANALYTICS_GENERATING = " 正在生成数据分析报告..." + ANALYTICS_REPORT_HEADER = " 数据分析报告 (群组: {group_id}):" ANALYTICS_LEARNING_STATS = """ -📚 学习统计: + 学习统计: - 处理消息数: {total_messages} - 学习会话数: {learning_sessions} - 平均质量分: {avg_quality:.2f}""" ANALYTICS_USER_BEHAVIOR = """ -👥 用户行为模式: + 用户行为模式: - 活跃用户数: {active_users} - 主要话题: {main_topics} - 情感倾向: {emotion_tendency}""" - ANALYTICS_RECOMMENDATIONS = "💡 建议:\n- {recommendations}" + ANALYTICS_RECOMMENDATIONS = " 建议:\n- {recommendations}" # 人格切换消息 PERSONA_SWITCH_USAGE = "请指定人格名称,如: /persona_switch friendly" - PERSONA_SWITCH_SUCCESS = "✅ 已切换到人格: {persona_name}" - PERSONA_SWITCH_FAILED = "❌ 人格切换失败,请检查人格名称是否正确" + PERSONA_SWITCH_SUCCESS = " 已切换到人格: {persona_name}" + PERSONA_SWITCH_FAILED = " 人格切换失败,请检查人格名称是否正确" # 人格更新和显示消息 - PERSONA_UPDATE_HEADER = "🎭 人格更新报告 (群组: {group_id}):" - PERSONA_UPDATE_SUCCESS = "✅ 人格更新成功完成" - PERSONA_UPDATE_FAILED = "❌ 人格更新失败: {error}" + PERSONA_UPDATE_HEADER = " 人格更新报告 (群组: {group_id}):" + PERSONA_UPDATE_SUCCESS = " 人格更新成功完成" + PERSONA_UPDATE_FAILED = " 人格更新失败: {error}" PERSONA_BEFORE_AFTER = """ -📝 人格变化对比: + 人格变化对比: 【更新前】 {before_content} @@ -134,33 +134,33 @@ class CommandMessages: 【更新后】 {after_content} -📊 变化摘要: + 变化摘要: {change_summary}""" PERSONA_CURRENT_DISPLAY = """ -🎭 当前人格信息: + 当前人格信息: -📛 人格名称: {persona_name} -📝 人格描述: + 人格名称: {persona_name} + 人格描述: {persona_prompt} -📈 学习统计: + 学习统计: - 更新次数: {update_count} - 最后更新: {last_update} - 学习质量: {quality_score:.2f}/10""" PERSONA_BACKUP_STATUS = """ -💾 备份状态: + 备份状态: - 总备份数: {total_backups} - 最新备份: {latest_backup} - 自动备份: {auto_backup_status}""" 
PERSONA_STYLE_FEATURES = """ -🎨 学习到的风格特征: + 学习到的风格特征: {style_features}""" PERSONA_CHANGE_SUMMARY = """ -📊 本次更新内容: + 本次更新内容: - Prompt长度: {prompt_length_before} → {prompt_length_after} ({length_change}) - 新增特征: {new_features_count} 项 - 风格调整: {style_adjustments} @@ -184,10 +184,10 @@ class CommandMessages: STOP_FAILED = "停止失败: {error}" # 状态指示符 - STATUS_ENABLED = "✅ 启用" - STATUS_DISABLED = "❌ 禁用" - STATUS_RUNNING = "🟢 运行中" - STATUS_STOPPED = "🔴 已停止" + STATUS_ENABLED = " 启用" + STATUS_DISABLED = " 禁用" + STATUS_RUNNING = " 运行中" + STATUS_STOPPED = " 已停止" STATUS_ALL_USERS = "全部用户" STATUS_UNKNOWN = "未知" STATUS_NEVER_EXECUTED = "从未执行" @@ -360,9 +360,7 @@ class SQLQueries: ''' -# ============================================================ # 更新类型常量和辅助函数(用于人格审查服务的统一类型标准化) -# ============================================================ UPDATE_TYPE_STYLE_LEARNING = 'style_learning' UPDATE_TYPE_PERSONA_LEARNING = 'persona_learning' diff --git a/statics/prompts.py b/statics/prompts.py index fda5900..7b99791 100644 --- a/statics/prompts.py +++ b/statics/prompts.py @@ -100,13 +100,13 @@ 请返回以下格式的JSON,每个维度给出0-1的评分: {{ - "vocabulary_richness": 0.0, // 词汇丰富度 - "sentence_complexity": 0.0, // 句式复杂度 - "emotional_expression": 0.0, // 情感表达度 - "interaction_tendency": 0.0, // 互动倾向 - "topic_diversity": 0.0, // 话题多样性 - "formality_level": 0.0, // 正式程度 - "creativity_score": 0.0 // 创造性得分 + "vocabulary_richness": 0.0, // 词汇丰富度 + "sentence_complexity": 0.0, // 句式复杂度 + "emotional_expression": 0.0, // 情感表达度 + "interaction_tendency": 0.0, // 互动倾向 + "topic_diversity": 0.0, // 话题多样性 + "formality_level": 0.0, // 正式程度 + "creativity_score": 0.0 // 创造性得分 }} """ @@ -170,11 +170,11 @@ 请评估以下维度并以JSON格式返回结果: {{ - "content_quality": 0.0-1.0, // 消息的深度、信息量、原创性、表达清晰度 - "relevance": 0.0-1.0, // 与当前对话主题或人格的相关性 + "content_quality": 0.0-1.0, // 消息的深度、信息量、原创性、表达清晰度 + "relevance": 0.0-1.0, // 与当前对话主题或人格的相关性 "emotional_positivity": 0.0-1.0, // 消息的情感倾向(积极程度) - "interactivity": 0.0-1.0, // 消息是否引发或回应了互动(如提问、回应、@他人) - 
"learning_value": 0.0-1.0 // 消息对模型学习当前人格对话模式和知识的潜在贡献 + "interactivity": 0.0-1.0, // 消息是否引发或回应了互动(如提问、回应、@他人) + "learning_value": 0.0-1.0 // 消息对模型学习当前人格对话模式和知识的潜在贡献 }} 请确保返回有效的JSON格式,并且只包含JSON对象,不需要其他说明。 @@ -269,11 +269,11 @@ 请返回以下格式的JSON: {{ - "openness": 0.0-1.0, // 开放性 - "conscientiousness": 0.0-1.0, // 尽责性 - "extraversion": 0.0-1.0, // 外向性 - "agreeableness": 0.0-1.0, // 宜人性 - "neuroticism": 0.0-1.0 // 神经质 + "openness": 0.0-1.0, // 开放性 + "conscientiousness": 0.0-1.0, // 尽责性 + "extraversion": 0.0-1.0, // 外向性 + "agreeableness": 0.0-1.0, // 宜人性 + "neuroticism": 0.0-1.0 // 神经质 }} """ @@ -307,11 +307,11 @@ 请以JSON格式返回分析结果: {{ - "emotional_diversity": 0.0-1.0, // 情感多样性得分 - "intensity_balance": 0.0-1.0, // 强度平衡得分 - "emotional_stability": 0.0-1.0, // 情感稳定性得分 - "learning_value": 0.0-1.0, // 学习价值得分 - "overall_balance": 0.0-1.0, // 总体情感平衡得分 + "emotional_diversity": 0.0-1.0, // 情感多样性得分 + "intensity_balance": 0.0-1.0, // 强度平衡得分 + "emotional_stability": 0.0-1.0, // 情感稳定性得分 + "learning_value": 0.0-1.0, // 学习价值得分 + "overall_balance": 0.0-1.0, // 总体情感平衡得分 "analysis_summary": "分析总结" }} """ @@ -405,7 +405,7 @@ {chat_content} 请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格 -1. 只考虑文字,不要考虑表情包和图片 +1. 只考虑文字,不要考虑表情包和图片 2. 不要涉及具体的人名,但是可以涉及具体名词 3. 思考有没有特殊的梗,一并总结成语言风格 4. 例子仅供参考,请严格根据群聊内容总结!!! @@ -415,7 +415,7 @@ 例如: 当"对某件事表示十分惊叹"时,使用"我嘞个xxxx" -当"表示讽刺的赞同,不讲道理"时,使用"对对对" +当"表示讽刺的赞同,不讲道理"时,使用"对对对" 当"想说明某个具体的事实观点,但懒得明说"时,使用"懂的都懂" 当"涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!" 
@@ -564,13 +564,13 @@ **要使用"你应该xxx"、"要xxx"、"记得xxx"、"多用xxx"、"少说xxx"这类直接告诉LLM该怎么做的命令!** 示例(错误): -- "强化幽默与毒舌语言表达的灵活性与协调性" ❌ -- "优化与陌生用户交流方式,保持坦率直接但降低机械感" ❌ +- "强化幽默与毒舌语言表达的灵活性与协调性" +- "优化与陌生用户交流方式,保持坦率直接但降低机械感" 示例(正确): -- "你要多用重庆方言和网络梗,说话带点毒舌和幽默感" ✅ -- "和陌生人聊天时要坦率直接,但别太机械,要自然点" ✅ -- "讨论技术问题时记得保持专业,但也要有点趣味性" ✅ +- "你要多用重庆方言和网络梗,说话带点毒舌和幽默感" +- "和陌生人聊天时要坦率直接,但别太机械,要自然点" +- "讨论技术问题时记得保持专业,但也要有点趣味性" 请以JSON格式返回增量微调结果: {{ diff --git a/tests/conftest.py b/tests/conftest.py index 58293e4..125f434 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,9 +15,7 @@ import time -# ============================================================================ # Async Test Utilities -# ============================================================================ @pytest.fixture(scope="session") def event_loop(): @@ -27,9 +25,7 @@ def event_loop(): loop.close() -# ============================================================================ # Mock ServiceContainer -# ============================================================================ @pytest.fixture def mock_plugin_config(): @@ -229,9 +225,7 @@ def mock_container( return container -# ============================================================================ # Test Data Factories -# ============================================================================ @pytest.fixture def sample_persona_data(): @@ -319,15 +313,13 @@ def sample_chat_message(): } -# ============================================================================ # Authentication Test Helpers -# ============================================================================ @pytest.fixture def sample_password_config(): """Sample password configuration""" return { - 'password_hash': '5f4dcc3b5aa765d61d8327deb882cf99', # MD5 of 'password' + 'password_hash': '5f4dcc3b5aa765d61d8327deb882cf99', # MD5 of 'password' 'salt': 'test_salt', 'algorithm': 'md5', 'created_at': time.time(), @@ -345,9 +337,7 @@ def sample_login_attempt(): } -# 
============================================================================ # Async Helper Functions -# ============================================================================ @pytest.fixture def async_return(): diff --git a/utils/cache_manager.py b/utils/cache_manager.py index 0cdf6dd..845860a 100644 --- a/utils/cache_manager.py +++ b/utils/cache_manager.py @@ -24,10 +24,10 @@ def __init__(self): """初始化缓存管理器""" # 不同用途的缓存实例 # TTL 缓存 - 用于有明确过期时间的数据 - self.affection_cache = TTLCache(maxsize=2000, ttl=300) # 5分钟 - self.memory_cache = TTLCache(maxsize=1000, ttl=600) # 10分钟 - self.state_cache = TTLCache(maxsize=500, ttl=60) # 1分钟 - self.relation_cache = TTLCache(maxsize=1000, ttl=60) # 1分钟 + self.affection_cache = TTLCache(maxsize=2000, ttl=300) # 5分钟 + self.memory_cache = TTLCache(maxsize=1000, ttl=600) # 10分钟 + self.state_cache = TTLCache(maxsize=500, ttl=60) # 1分钟 + self.relation_cache = TTLCache(maxsize=1000, ttl=60) # 1分钟 # LRU 缓存 - 用于需要保持热点数据的场景 self.conversation_cache = LRUCache(maxsize=500) @@ -136,9 +136,7 @@ def get_stats(self, cache_name: str) -> dict: return {'size': len(cache)} -# ============================================================ # 装饰器 -# ============================================================ def cached( cache_name: str = 'general', @@ -229,9 +227,7 @@ async def wrapper(*args, **kwargs): return decorator -# ============================================================ # 全局单例 -# ============================================================ _global_cache_manager: Optional[CacheManager] = None diff --git a/utils/guardrails_manager.py b/utils/guardrails_manager.py index 71bf6e0..7a1bbac 100644 --- a/utils/guardrails_manager.py +++ b/utils/guardrails_manager.py @@ -8,9 +8,7 @@ from astrbot.api import logger -# ============================================================ # Pydantic 模型定义 - 用于心理状态分析 -# ============================================================ class PsychologicalStateTransition(BaseModel): """ @@ -39,9 +37,7 @@ def 
validate_state_name(cls, v: str) -> str: return v.strip() -# ============================================================ # Pydantic 模型定义 - 用于对话目标分析 -# ============================================================ class GoalAnalysisResult(BaseModel): """ @@ -130,9 +126,7 @@ class ConversationIntentAnalysis(BaseModel): ) -# ============================================================ # Pydantic 模型定义 - 用于社交关系分析 -# ============================================================ class RelationChange(BaseModel): """ @@ -184,9 +178,7 @@ def validate_relations_count(cls, v: List[RelationChange]) -> List[RelationChang return v -# ============================================================ # Guardrails 管理器 -# ============================================================ class GuardrailsManager: """ @@ -288,14 +280,14 @@ async def parse_state_transition( result = guard.parse(response_text) if result.validation_passed: - logger.debug(f"✅ [Guardrails] 心理状态解析成功: {result.validated_output.new_state}") + logger.debug(f" [Guardrails] 心理状态解析成功: {result.validated_output.new_state}") return result.validated_output else: - logger.warning(f"⚠️ [Guardrails] 心理状态验证失败: {result.validation_summaries}") + logger.warning(f" [Guardrails] 心理状态验证失败: {result.validation_summaries}") return None except Exception as e: - logger.error(f"❌ [Guardrails] 心理状态解析失败: {e}", exc_info=True) + logger.error(f" [Guardrails] 心理状态解析失败: {e}", exc_info=True) return None async def parse_relation_analysis( @@ -346,14 +338,14 @@ async def parse_relation_analysis( if result.validation_passed: relation_count = len(result.validated_output.relations) - logger.debug(f"✅ [Guardrails] 社交关系解析成功: {relation_count}个关系") + logger.debug(f" [Guardrails] 社交关系解析成功: {relation_count}个关系") return result.validated_output else: - logger.warning(f"⚠️ [Guardrails] 社交关系验证失败: {result.validation_summaries}") + logger.warning(f" [Guardrails] 社交关系验证失败: {result.validation_summaries}") return None except Exception as e: - logger.error(f"❌ [Guardrails] 
社交关系解析失败: {e}", exc_info=True) + logger.error(f" [Guardrails] 社交关系解析失败: {e}", exc_info=True) return None def get_goal_analysis_guard(self) -> Guard: @@ -432,24 +424,24 @@ async def parse_goal_analysis( result = guard.parse(response_text) if result.validation_passed: - # ⚠️ 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 + # 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 validated_data = result.validated_output if isinstance(validated_data, dict): goal_result = GoalAnalysisResult(**validated_data) - logger.debug(f"✅ [Guardrails] 对话目标解析成功: {goal_result.goal_type}") + logger.debug(f" [Guardrails] 对话目标解析成功: {goal_result.goal_type}") return goal_result elif isinstance(validated_data, GoalAnalysisResult): - logger.debug(f"✅ [Guardrails] 对话目标解析成功: {validated_data.goal_type}") + logger.debug(f" [Guardrails] 对话目标解析成功: {validated_data.goal_type}") return validated_data else: - logger.warning(f"⚠️ [Guardrails] 意外的输出类型: {type(validated_data)}") + logger.warning(f" [Guardrails] 意外的输出类型: {type(validated_data)}") return None else: - logger.warning(f"⚠️ [Guardrails] 对话目标验证失败: {result.validation_summaries}") + logger.warning(f" [Guardrails] 对话目标验证失败: {result.validation_summaries}") return None except Exception as e: - logger.error(f"❌ [Guardrails] 对话目标解析失败: {e}", exc_info=True) + logger.error(f" [Guardrails] 对话目标解析失败: {e}", exc_info=True) return None async def parse_intent_analysis( @@ -505,24 +497,24 @@ async def parse_intent_analysis( result = guard.parse(response_text) if result.validation_passed: - # ⚠️ 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 + # 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 validated_data = result.validated_output if isinstance(validated_data, dict): intent_result = ConversationIntentAnalysis(**validated_data) - logger.debug(f"✅ [Guardrails] 对话意图解析成功") + logger.debug(f" [Guardrails] 对话意图解析成功") return intent_result elif isinstance(validated_data, ConversationIntentAnalysis): - logger.debug(f"✅ [Guardrails] 对话意图解析成功") + logger.debug(f" [Guardrails] 
对话意图解析成功") return validated_data else: - logger.warning(f"⚠️ [Guardrails] 意外的输出类型: {type(validated_data)}") + logger.warning(f" [Guardrails] 意外的输出类型: {type(validated_data)}") return None else: - logger.warning(f"⚠️ [Guardrails] 对话意图验证失败: {result.validation_summaries}") + logger.warning(f" [Guardrails] 对话意图验证失败: {result.validation_summaries}") return None except Exception as e: - logger.error(f"❌ [Guardrails] 对话意图解析失败: {e}", exc_info=True) + logger.error(f" [Guardrails] 对话意图解析失败: {e}", exc_info=True) return None def parse_json_direct( @@ -545,7 +537,7 @@ def parse_json_direct( result = guard.parse(response_text) if result.validation_passed: - # ⚠️ 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 + # 修复:validated_output 可能是 dict,需要转换为 Pydantic 模型 validated_data = result.validated_output if isinstance(validated_data, dict): # 将 dict 转换为 Pydantic 模型实例 @@ -554,14 +546,14 @@ def parse_json_direct( # 已经是模型实例,直接返回 return validated_data else: - logger.warning(f"⚠️ [Guardrails] 意外的输出类型: {type(validated_data)}") + logger.warning(f" [Guardrails] 意外的输出类型: {type(validated_data)}") return None else: - logger.warning(f"⚠️ [Guardrails] JSON 验证失败: {result.validation_summaries}") + logger.warning(f" [Guardrails] JSON 验证失败: {result.validation_summaries}") return None except Exception as e: - logger.error(f"❌ [Guardrails] JSON 解析失败: {e}", exc_info=True) + logger.error(f" [Guardrails] JSON 解析失败: {e}", exc_info=True) return None def validate_and_clean_json( @@ -585,14 +577,14 @@ def validate_and_clean_json( try: # 检查输入是否为空 if not response_text: - logger.error(f"❌ [Guardrails] 输入为空,无法解析 JSON") + logger.error(f" [Guardrails] 输入为空,无法解析 JSON") return None # 1. 
移除 Markdown 代码块标记 cleaned_text = response_text.strip() # 记录原始响应长度用于调试 - logger.debug(f"🔍 [Guardrails] 原始响应长度: {len(response_text)}, 清理后长度: {len(cleaned_text)}") + logger.debug(f" [Guardrails] 原始响应长度: {len(response_text)}, 清理后长度: {len(cleaned_text)}") # 移除 ```json 和 ``` 标记 if cleaned_text.startswith("```json"): @@ -607,7 +599,7 @@ def validate_and_clean_json( # 检查清理后是否为空 if not cleaned_text: - logger.warning(f"⚠️ [Guardrails] 清理后的响应为空") + logger.warning(f" [Guardrails] 清理后的响应为空") return None # 2. 尝试提取 JSON 部分(处理 LLM 可能在 JSON 前后加说明的情况) @@ -627,20 +619,20 @@ def validate_and_clean_json( # 再次检查提取后是否为空 if not cleaned_text: - logger.warning(f"⚠️ [Guardrails] 提取JSON后内容为空") + logger.warning(f" [Guardrails] 提取JSON后内容为空") return None # 3. 尝试解析 JSON parsed = json.loads(cleaned_text) - logger.debug(f"✅ [Guardrails] JSON 验证成功,类型: {type(parsed).__name__}") + logger.debug(f" [Guardrails] JSON 验证成功,类型: {type(parsed).__name__}") return parsed except json.JSONDecodeError as e: # 显示响应预览用于调试 preview = cleaned_text[:200] if len(cleaned_text) > 200 else cleaned_text - logger.warning(f"⚠️ [Guardrails] JSON 解析失败: {e},尝试修复...") - logger.debug(f"🔍 [Guardrails] 响应预览: {preview}") + logger.warning(f" [Guardrails] JSON 解析失败: {e},尝试修复...") + logger.debug(f" [Guardrails] 响应预览: {preview}") # 尝试修复常见的 JSON 错误 try: @@ -652,15 +644,15 @@ def validate_and_clean_json( fixed_text = re.sub(r',\s*]', ']', fixed_text) parsed = json.loads(fixed_text) - logger.info(f"✅ [Guardrails] JSON 修复成功") + logger.info(f" [Guardrails] JSON 修复成功") return parsed except Exception as fix_error: - logger.error(f"❌ [Guardrails] JSON 修复失败: {fix_error}") + logger.error(f" [Guardrails] JSON 修复失败: {fix_error}") return None except Exception as e: - logger.error(f"❌ [Guardrails] JSON 验证异常: {e}") + logger.error(f" [Guardrails] JSON 验证异常: {e}") return None async def validate_llm_response( @@ -705,7 +697,7 @@ async def validate_llm_response( response_text = await llm_callable(enhanced_prompt, model=model, **kwargs) if not 
response_text: - logger.warning("⚠️ [Guardrails] LLM 返回为空") + logger.warning(" [Guardrails] LLM 返回为空") return None # 根据期望格式验证 @@ -720,13 +712,11 @@ async def validate_llm_response( return response_text.strip() except Exception as e: - logger.error(f"❌ [Guardrails] LLM 响应验证失败: {e}", exc_info=True) + logger.error(f" [Guardrails] LLM 响应验证失败: {e}", exc_info=True) return None -# ============================================================ # 全局单例 -# ============================================================ # 使用 max_reasks=1 保持高性能 _guardrails_manager: Optional[GuardrailsManager] = None diff --git a/utils/json_cleaner.py b/utils/json_cleaner.py deleted file mode 100644 index 40b6f4a..0000000 --- a/utils/json_cleaner.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -JSON 清洗工具类 -用于清洗和验证 LLM 返回的 JSON 格式内容 -""" -import json -import re -from typing import Any, Dict, List, Optional, Union -from astrbot.api import logger - - -class JSONCleaner: - """ - JSON 清洗工具类 - - 功能: - 1. 清理 LLM 返回中的无效字符和格式 - 2. 提取 JSON 内容(即使被其他文本包围) - 3. 修复常见的 JSON 格式错误 - 4. 验证 JSON 结构 - 5. 提供安全的默认值 - """ - - @staticmethod - def clean_and_parse( - raw_text: str, - expected_type: type = dict, - default_value: Any = None, - strict: bool = False - ) -> Any: - """ - 清洗并解析 JSON 文本 - - Args: - raw_text: LLM 返回的原始文本 - expected_type: 期望的类型 (dict, list, str, int, float, bool) - default_value: 解析失败时的默认值 - strict: 是否严格模式(严格模式下类型不匹配会返回默认值) - - Returns: - 解析后的 Python 对象,失败时返回 default_value - - Examples: - >>> JSONCleaner.clean_and_parse('{"key": "value"}') - {'key': 'value'} - - >>> JSONCleaner.clean_and_parse('```json\\n{"key": "value"}\\n```') - {'key': 'value'} - - >>> JSONCleaner.clean_and_parse('invalid', default_value={}) - {} - """ - if not raw_text or not isinstance(raw_text, str): - logger.warning(f"[JSON清洗] 输入无效: {type(raw_text)}") - return default_value if default_value is not None else {} - - try: - # 1. 预处理: 移除前后空白 - text = raw_text.strip() - - # 2. 
提取 JSON 内容 - json_text = JSONCleaner._extract_json(text) - - if not json_text: - logger.warning(f"[JSON清洗] 无法提取 JSON 内容: {text[:100]}...") - return default_value if default_value is not None else {} - - # 3. 清理 JSON 文本 - cleaned_text = JSONCleaner._clean_json_text(json_text) - - # 4. 解析 JSON - parsed = json.loads(cleaned_text) - - # 5. 类型验证 - if strict and not isinstance(parsed, expected_type): - logger.warning( - f"[JSON清洗] 类型不匹配: 期望 {expected_type}, 实际 {type(parsed)}" - ) - return default_value if default_value is not None else {} - - logger.debug(f"[JSON清洗] 成功解析: {type(parsed)}") - return parsed - - except json.JSONDecodeError as e: - logger.error(f"[JSON清洗] JSON 解析失败: {e}") - logger.debug(f"原始文本: {raw_text[:200]}...") - return default_value if default_value is not None else {} - - except Exception as e: - logger.error(f"[JSON清洗] 未知错误: {e}", exc_info=True) - return default_value if default_value is not None else {} - - @staticmethod - def _extract_json(text: str) -> Optional[str]: - """ - 从文本中提取 JSON 内容 - - 支持的格式: - 1. 纯 JSON: {"key": "value"} - 2. Markdown 代码块: ```json\\n{...}\\n``` - 3. 代码块: ```{...}``` - 4. 
文本包围: Some text {"key": "value"} more text - """ - # 尝试 1: 检查是否是纯 JSON (以 { 或 [ 开头) - if text.startswith('{') or text.startswith('['): - # 找到对应的结束位置 - if text.startswith('{'): - end_idx = JSONCleaner._find_closing_brace(text, 0) - if end_idx != -1: - return text[:end_idx + 1] - elif text.startswith('['): - end_idx = JSONCleaner._find_closing_bracket(text, 0) - if end_idx != -1: - return text[:end_idx + 1] - - # 尝试 2: 提取 markdown 代码块中的 JSON - # ```json\n{...}\n``` - json_code_block_pattern = r'```json\s*\n(.*?)\n```' - match = re.search(json_code_block_pattern, text, re.DOTALL) - if match: - return match.group(1).strip() - - # 尝试 3: 提取普通代码块中的 JSON - # ```{...}``` - code_block_pattern = r'```\s*\n?(.*?)\n?```' - match = re.search(code_block_pattern, text, re.DOTALL) - if match: - content = match.group(1).strip() - if content.startswith('{') or content.startswith('['): - return content - - # 尝试 4: 查找第一个 { 或 [ 并提取到对应的结束符 - for start_char, finder in [('{', JSONCleaner._find_closing_brace), - ('[', JSONCleaner._find_closing_bracket)]: - start_idx = text.find(start_char) - if start_idx != -1: - end_idx = finder(text, start_idx) - if end_idx != -1: - return text[start_idx:end_idx + 1] - - # 无法提取 - return None - - @staticmethod - def _find_closing_brace(text: str, start_idx: int) -> int: - """找到与起始 { 对应的结束 }""" - depth = 0 - in_string = False - escape_next = False - - for i in range(start_idx, len(text)): - char = text[i] - - if escape_next: - escape_next = False - continue - - if char == '\\': - escape_next = True - continue - - if char == '"' and not in_string: - in_string = True - elif char == '"' and in_string: - in_string = False - elif not in_string: - if char == '{': - depth += 1 - elif char == '}': - depth -= 1 - if depth == 0: - return i - - return -1 - - @staticmethod - def _find_closing_bracket(text: str, start_idx: int) -> int: - """找到与起始 [ 对应的结束 ]""" - depth = 0 - in_string = False - escape_next = False - - for i in range(start_idx, len(text)): - char = text[i] 
- - if escape_next: - escape_next = False - continue - - if char == '\\': - escape_next = True - continue - - if char == '"' and not in_string: - in_string = True - elif char == '"' and in_string: - in_string = False - elif not in_string: - if char == '[': - depth += 1 - elif char == ']': - depth -= 1 - if depth == 0: - return i - - return -1 - - @staticmethod - def _clean_json_text(text: str) -> str: - """ - 清理 JSON 文本中的常见问题 - - 修复: - 1. 单引号替换为双引号 - 2. 移除尾随逗号 - 3. 修复布尔值大小写 - 4. 移除注释 - """ - # 1. 移除单行注释 (//...) - text = re.sub(r'//.*?$', '', text, flags=re.MULTILINE) - - # 2. 移除多行注释 (/*...*/) - text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) - - # 3. 修复布尔值 (True -> true, False -> false) - text = re.sub(r'\bTrue\b', 'true', text) - text = re.sub(r'\bFalse\b', 'false', text) - text = re.sub(r'\bNone\b', 'null', text) - - # 4. 移除尾随逗号 (在 } 或 ] 之前的逗号) - text = re.sub(r',(\s*[}\]])', r'\1', text) - - # 5. 尝试修复单引号为双引号 (谨慎处理) - # 只替换键名的单引号: 'key' -> "key" - text = re.sub(r"'([^']*)'(\s*):", r'"\1"\2:', text) - - return text - - @staticmethod - def safe_get( - data: Dict[str, Any], - key: str, - default: Any = None, - expected_type: type = None - ) -> Any: - """ - 安全地从字典获取值 - - Args: - data: 字典 - key: 键名 - default: 默认值 - expected_type: 期望的类型 - - Returns: - 值或默认值 - - Examples: - >>> data = {'key': 'value', 'num': '123'} - >>> JSONCleaner.safe_get(data, 'key') - 'value' - >>> JSONCleaner.safe_get(data, 'missing', default='default') - 'default' - >>> JSONCleaner.safe_get(data, 'num', expected_type=int, default=0) - 0 # 因为 '123' 不是 int 类型 - """ - if not isinstance(data, dict): - return default - - value = data.get(key, default) - - if expected_type is not None and not isinstance(value, expected_type): - logger.debug( - f"[JSON清洗] 类型不匹配: 键 '{key}' 期望 {expected_type}, " - f"实际 {type(value)}, 返回默认值" - ) - return default - - return value - - @staticmethod - def validate_schema( - data: Dict[str, Any], - required_keys: List[str], - optional_keys: List[str] = None - ) -> bool: 
- """ - 验证 JSON 数据的结构 - - Args: - data: 要验证的数据 - required_keys: 必需的键列表 - optional_keys: 可选的键列表 - - Returns: - 是否有效 - - Examples: - >>> data = {'name': 'Alice', 'age': 30} - >>> JSONCleaner.validate_schema(data, ['name', 'age']) - True - >>> JSONCleaner.validate_schema(data, ['name', 'email']) - False - """ - if not isinstance(data, dict): - logger.warning("[JSON清洗] 数据不是字典类型") - return False - - # 检查必需键 - for key in required_keys: - if key not in data: - logger.warning(f"[JSON清洗] 缺少必需键: {key}") - return False - - # 检查是否有未预期的键 (如果提供了 optional_keys) - if optional_keys is not None: - all_allowed_keys = set(required_keys) | set(optional_keys) - extra_keys = set(data.keys()) - all_allowed_keys - if extra_keys: - logger.debug(f"[JSON清洗] 存在额外的键: {extra_keys}") - - logger.debug("[JSON清洗] 结构验证通过") - return True - - -class LLMJSONParser: - """ - LLM JSON 解析器 - 针对 LLM 返回的特定格式进行优化 - """ - - @staticmethod - def parse_state_analysis(raw_text: str) -> Optional[str]: - """ - 解析心理状态分析结果 - - 期望格式: LLM 返回一个状态名称(字符串) - - Returns: - 状态名称字符串,失败返回 None - """ - # 尝试直接作为字符串 - cleaned = raw_text.strip().strip('"\'') - - # 移除可能的前缀 - cleaned = re.sub(r'^(状态[::]|新状态[::])', '', cleaned) - - if cleaned and len(cleaned) < 50: # 状态名称不应太长 - return cleaned - - # 尝试作为 JSON 解析 - result = JSONCleaner.clean_and_parse(raw_text, expected_type=str, default_value=None) - if result: - return result - - return None - - @staticmethod - def parse_relation_analysis(raw_text: str) -> Dict[str, float]: - """ - 解析社交关系分析结果 - - 期望格式: {"关系类型1": 0.03, "关系类型2": 0.05} - - Returns: - 关系类型到数值变化的映射,失败返回空字典 - """ - result = JSONCleaner.clean_and_parse( - raw_text, - expected_type=dict, - default_value={}, - strict=False - ) - - if not result: - return {} - - # 清理和验证值 - cleaned_result = {} - for key, value in result.items(): - # 确保键是字符串 - if not isinstance(key, str): - key = str(key) - - # 确保值是数字 - try: - if isinstance(value, (int, float)): - cleaned_result[key] = float(value) - elif isinstance(value, str): - # 尝试转换字符串为数字 - 
cleaned_result[key] = float(value) - except (ValueError, TypeError): - logger.warning(f"[JSON清洗] 无法转换关系值: {key} = {value}") - continue - - return cleaned_result - - @staticmethod - def parse_event_analysis(raw_text: str) -> Dict[str, Any]: - """ - 解析事件分析结果 - - 期望格式: {"event_type": "...", "intensity": 0.5, "description": "..."} - - Returns: - 事件分析结果字典,失败返回空字典 - """ - result = JSONCleaner.clean_and_parse( - raw_text, - expected_type=dict, - default_value={}, - strict=False - ) - - # 验证必需字段 - if not JSONCleaner.validate_schema( - result, - required_keys=[], # 没有严格必需的键 - optional_keys=['event_type', 'intensity', 'description', 'impact'] - ): - return {} - - return result diff --git a/utils/schema_validator.py b/utils/schema_validator.py index 344d288..cc1e669 100644 --- a/utils/schema_validator.py +++ b/utils/schema_validator.py @@ -33,10 +33,10 @@ class ColumnInfo: class TableDiff: """表结构差异""" table_name: str - missing_columns: List[str] # 缺失的字段 - extra_columns: List[str] # 多余的字段 - type_mismatches: List[Tuple[str, str, str]] # (字段名, 期望类型, 实际类型) - nullable_mismatches: List[Tuple[str, bool, bool]] # (字段名, 期望nullable, 实际nullable) + missing_columns: List[str] # 缺失的字段 + extra_columns: List[str] # 多余的字段 + type_mismatches: List[Tuple[str, str, str]] # (字段名, 期望类型, 实际类型) + nullable_mismatches: List[Tuple[str, bool, bool]] # (字段名, 期望nullable, 实际nullable) class SchemaValidator: @@ -87,7 +87,7 @@ async def validate_all_tables(self, auto_fix: bool = True) -> Dict[str, TableDif Dict[str, TableDiff]: {表名: 差异信息} """ logger.info("=" * 70) - logger.info("🔍 开始数据库表结构验证") + logger.info(" 开始数据库表结构验证") logger.info("=" * 70) all_diffs = {} @@ -112,7 +112,7 @@ async def validate_all_tables(self, auto_fix: bool = True) -> Dict[str, TableDif # 表存在,验证结构 validated_tables.append(table_name) - logger.info(f"\n📋 验证表: {table_name}") + logger.info(f"\n 验证表: {table_name}") # 比较表结构 diff = await self._compare_table_structure(table_name, table_obj) @@ -125,25 +125,25 @@ async def validate_all_tables(self, 
auto_fix: bool = True) -> Dict[str, TableDif if auto_fix: await self._fix_table_structure(table_name, table_obj, diff) else: - logger.info(f" ✅ 表结构一致") + logger.info(f" 表结构一致") logger.info("\n" + "=" * 70) # 总结报告 if created_tables: - logger.info(f"🆕 新建 {len(created_tables)} 个表: {', '.join(created_tables[:5])}" + + logger.info(f" 新建 {len(created_tables)} 个表: {', '.join(created_tables[:5])}" + (f" 等" if len(created_tables) > 5 else "")) if validated_tables: - logger.info(f"✅ 验证 {len(validated_tables)} 个已存在的表") + logger.info(f" 验证 {len(validated_tables)} 个已存在的表") if all_diffs: - logger.info(f"⚠️ 发现 {len(all_diffs)} 个表存在结构差异") + logger.info(f" 发现 {len(all_diffs)} 个表存在结构差异") if auto_fix: - logger.info("✅ 已尝试自动修复") + logger.info(" 已尝试自动修复") else: if validated_tables: - logger.info("✅ 所有表结构验证通过") + logger.info(" 所有表结构验证通过") logger.info("=" * 70) @@ -176,14 +176,14 @@ async def _create_table(self, table_name: str, table_obj): try: async with self.engine.begin() as conn: await conn.run_sync(table_obj.create, checkfirst=True) - logger.info(f" ✅ 表已创建: {table_name}") + logger.info(f" 表已创建: {table_name}") except Exception as e: # 检查是否是索引已存在的错误(这是正常情况,可以忽略) error_msg = str(e).lower() if 'index' in error_msg and 'already exists' in error_msg: - logger.info(f" ✅ 表和索引已存在,跳过创建: {table_name}") + logger.info(f" 表和索引已存在,跳过创建: {table_name}") else: - logger.error(f" ❌ 创建表失败: {e}") + logger.error(f" 创建表失败: {e}") async def _get_table_columns(self, table_name: str) -> Dict[str, ColumnInfo]: """ @@ -318,10 +318,10 @@ def _normalize_type(self, type_str: str) -> str: 'DOUBLE': 'FLOAT', 'VARCHAR': 'STRING', 'CHAR': 'STRING', - 'BIGINT': 'BIGINT', # 保持 BIGINT,因为它常用于时间戳 + 'BIGINT': 'BIGINT', # 保持 BIGINT,因为它常用于时间戳 'TINYINT': 'INT', 'SMALLINT': 'INT', - 'TIMESTAMP': 'DATETIME', # 统一时间类型 + 'TIMESTAMP': 'DATETIME', # 统一时间类型 } return type_map.get(type_str, type_str) @@ -343,7 +343,7 @@ def _types_compatible(self, type1: str, type2: str) -> bool: return True # STRING 类型族 - string_types = {'STRING', 
'TEXT', 'VARCHAR', 'CHAR'} + string_types = {'STRING', 'TEXT', 'VARCHAR', 'CHAR', 'MEDIUMTEXT', 'LONGTEXT'} if type1 in string_types and type2 in string_types: return True @@ -359,18 +359,18 @@ def _types_compatible(self, type1: str, type2: str) -> bool: def _log_table_diff(self, diff: TableDiff): """记录表差异""" if diff.missing_columns: - logger.warning(f" ⚠️ 缺失字段: {', '.join(diff.missing_columns)}") + logger.warning(f" 缺失字段: {', '.join(diff.missing_columns)}") if diff.extra_columns: - logger.info(f" ℹ️ 额外字段(旧版本遗留): {', '.join(diff.extra_columns)}") + logger.info(f" 额外字段(旧版本遗留): {', '.join(diff.extra_columns)}") if diff.type_mismatches: for col, expected, actual in diff.type_mismatches: - logger.warning(f" ⚠️ 字段类型不匹配: {col} (期望: {expected}, 实际: {actual})") + logger.warning(f" 字段类型不匹配: {col} (期望: {expected}, 实际: {actual})") if diff.nullable_mismatches: for col, expected, actual in diff.nullable_mismatches: - logger.warning(f" ⚠️ Nullable属性不匹配: {col} (期望: {expected}, 实际: {actual})") + logger.warning(f" Nullable属性不匹配: {col} (期望: {expected}, 实际: {actual})") async def _fix_table_structure(self, table_name: str, table_obj, diff: TableDiff): """ @@ -381,7 +381,7 @@ async def _fix_table_structure(self, table_name: str, table_obj, diff: TableDiff table_obj: SQLAlchemy Table对象 diff: 差异信息 """ - logger.info(f" 🔧 开始修复表结构...") + logger.info(f" 开始修复表结构...") # 1. 添加缺失字段 if diff.missing_columns: @@ -389,14 +389,14 @@ async def _fix_table_structure(self, table_name: str, table_obj, diff: TableDiff # 2. 类型不匹配和nullable不匹配 - 警告用户 if diff.type_mismatches: - logger.warning(f" ⚠️ 字段类型不匹配需要手动处理,建议重建表或手动ALTER TABLE") + logger.warning(f" 字段类型不匹配需要手动处理,建议重建表或手动ALTER TABLE") if diff.nullable_mismatches: - logger.warning(f" ⚠️ Nullable属性不匹配可能影响数据完整性,请检查") + logger.warning(f" Nullable属性不匹配可能影响数据完整性,请检查") # 3. 
额外字段 - 保留不删除 (向后兼容) if diff.extra_columns: - logger.info(f" ℹ️ 保留额外字段作为历史数据: {', '.join(diff.extra_columns)}") + logger.info(f" 保留额外字段作为历史数据: {', '.join(diff.extra_columns)}") async def _add_missing_columns(self, table_name: str, table_obj, missing_columns: List[str]): """添加缺失字段""" @@ -417,10 +417,10 @@ async def _add_missing_columns(self, table_name: str, table_obj, missing_columns await session.execute(text(alter_sql)) await session.commit() - logger.info(f" ✅ 已添加字段: {col_name}") + logger.info(f" 已添加字段: {col_name}") except Exception as e: - logger.error(f" ❌ 添加字段 {col_name} 失败: {e}") + logger.error(f" 添加字段 {col_name} 失败: {e}") def _get_column_type_sql(self, column) -> str: """获取字段类型的SQL表示""" @@ -487,9 +487,7 @@ async def close(self): await self.engine.dispose() -# ============================================================ # 便捷函数 -# ============================================================ async def validate_and_fix_schema( db_url: str, diff --git a/utils/task_scheduler.py b/utils/task_scheduler.py index a0b66ec..040c7c4 100644 --- a/utils/task_scheduler.py +++ b/utils/task_scheduler.py @@ -28,11 +28,11 @@ class TaskSchedulerManager: def __init__(self): """初始化任务调度器""" self.scheduler = AsyncIOScheduler( - timezone='Asia/Shanghai', # 设置时区 + timezone='Asia/Shanghai', # 设置时区 job_defaults={ - 'coalesce': False, # 不合并多个未执行的任务 - 'max_instances': 1, # 每个任务最多同时运行1个实例 - 'misfire_grace_time': 60 # 错过执行时间后60秒内仍然执行 + 'coalesce': False, # 不合并多个未执行的任务 + 'max_instances': 1, # 每个任务最多同时运行1个实例 + 'misfire_grace_time': 60 # 错过执行时间后60秒内仍然执行 } ) self._started = False @@ -43,14 +43,14 @@ async def start(self): if not self._started: self.scheduler.start() self._started = True - logger.info("✅ [任务调度器] 已启动") + logger.info(" [任务调度器] 已启动") async def stop(self): """停止调度器""" if self._started: self.scheduler.shutdown(wait=True) self._started = False - logger.info("✅ [任务调度器] 已停止") + logger.info(" [任务调度器] 已停止") def add_interval_job( self, @@ -101,10 +101,10 @@ def add_interval_job( 
replace_existing=True, **kwargs ) - logger.info(f"✅ [任务调度器] 已添加周期任务: {job_id}") + logger.info(f" [任务调度器] 已添加周期任务: {job_id}") return job except Exception as e: - logger.error(f"❌ [任务调度器] 添加周期任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 添加周期任务失败 ({job_id}): {e}") return None def add_cron_job( @@ -169,10 +169,10 @@ def add_cron_job( replace_existing=True, **kwargs ) - logger.info(f"✅ [任务调度器] 已添加 cron 任务: {job_id}") + logger.info(f" [任务调度器] 已添加 cron 任务: {job_id}") return job except Exception as e: - logger.error(f"❌ [任务调度器] 添加 cron 任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 添加 cron 任务失败 ({job_id}): {e}") return None def add_date_job( @@ -210,10 +210,10 @@ def add_date_job( replace_existing=True, **kwargs ) - logger.info(f"✅ [任务调度器] 已添加一次性任务: {job_id} (执行时间: {run_date})") + logger.info(f" [任务调度器] 已添加一次性任务: {job_id} (执行时间: {run_date})") return job except Exception as e: - logger.error(f"❌ [任务调度器] 添加一次性任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 添加一次性任务失败 ({job_id}): {e}") return None def remove_job(self, job_id: str) -> bool: @@ -228,10 +228,10 @@ def remove_job(self, job_id: str) -> bool: """ try: self.scheduler.remove_job(job_id) - logger.info(f"✅ [任务调度器] 已删除任务: {job_id}") + logger.info(f" [任务调度器] 已删除任务: {job_id}") return True except Exception as e: - logger.error(f"❌ [任务调度器] 删除任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 删除任务失败 ({job_id}): {e}") return False def pause_job(self, job_id: str) -> bool: @@ -246,10 +246,10 @@ def pause_job(self, job_id: str) -> bool: """ try: self.scheduler.pause_job(job_id) - logger.info(f"⏸️ [任务调度器] 已暂停任务: {job_id}") + logger.info(f" [任务调度器] 已暂停任务: {job_id}") return True except Exception as e: - logger.error(f"❌ [任务调度器] 暂停任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 暂停任务失败 ({job_id}): {e}") return False def resume_job(self, job_id: str) -> bool: @@ -264,10 +264,10 @@ def resume_job(self, job_id: str) -> bool: """ try: self.scheduler.resume_job(job_id) - logger.info(f"▶️ [任务调度器] 已恢复任务: {job_id}") + logger.info(f" 
[任务调度器] 已恢复任务: {job_id}") return True except Exception as e: - logger.error(f"❌ [任务调度器] 恢复任务失败 ({job_id}): {e}") + logger.error(f" [任务调度器] 恢复任务失败 ({job_id}): {e}") return False def get_job(self, job_id: str) -> Optional[Job]: @@ -302,9 +302,7 @@ def get_job_stats(self, job_id: str) -> Optional[dict]: } -# ============================================================ # 全局单例 -# ============================================================ _global_task_scheduler: Optional[TaskSchedulerManager] = None diff --git a/web_res/static/js/macos/apps/Dashboard.js b/web_res/static/js/macos/apps/Dashboard.js index 711bcb3..072d4af 100644 --- a/web_res/static/js/macos/apps/Dashboard.js +++ b/web_res/static/js/macos/apps/Dashboard.js @@ -97,6 +97,12 @@ window.AppDashboard = {
+ +
+

Hook注入耗时分析

+
+
+
@@ -289,6 +295,7 @@ window.AppDashboard = { this.updateResponseTime(); this.updateLearningGauge(); this.updateSystemRadar(); + this.updateHookPerf(); this.updateStyleChart(); this.updateHeatmap(); }, @@ -606,6 +613,115 @@ window.AppDashboard = { ); }, + /* ---------- 5.5 Hook注入耗时分析 - 堆叠柱状图 ---------- */ + updateHookPerf() { + var chart = + this.chartInstances["hookPerfChart"] || this.initChart("hookPerfChart"); + if (!chart) return; + + var perf = this.metrics.hook_performance || {}; + var samples = perf.recent_samples || []; + + if (samples.length === 0) { + chart.setOption(this.emptyOption("暂无Hook耗时数据"), true); + return; + } + + // 取最近30条 + var recent = samples.slice(-30); + var labels = recent.map(function (s, i) { + var d = new Date(s.ts * 1000); + return ( + d.getHours() + + ":" + + ("0" + d.getMinutes()).slice(-2) + + ":" + + ("0" + d.getSeconds()).slice(-2) + ); + }); + var socialData = recent.map(function (s) { + return Math.round(s.social_ctx_ms || 0); + }); + var v2Data = recent.map(function (s) { + return Math.round(s.v2_ctx_ms || 0); + }); + var diversityData = recent.map(function (s) { + return Math.round(s.diversity_ms || 0); + }); + var jargonData = recent.map(function (s) { + return Math.round(s.jargon_ms || 0); + }); + + chart.setOption( + { + tooltip: { + trigger: "axis", + axisPointer: { type: "shadow" }, + formatter: function (params) { + var tip = params[0].axisValue + "
"; + var total = 0; + params.forEach(function (p) { + tip += + p.marker + " " + p.seriesName + ": " + p.value + "ms
"; + total += p.value; + }); + tip += "总计: " + total + "ms"; + return tip; + }, + }, + legend: { + data: ["社交上下文", "V2上下文", "多样性", "黑话"], + bottom: 0, + textStyle: { fontSize: 10 }, + }, + grid: { + left: "3%", + right: "4%", + bottom: "15%", + top: "8%", + containLabel: true, + }, + xAxis: { + type: "category", + data: labels, + axisLabel: { rotate: 45, fontSize: 9 }, + }, + yAxis: { type: "value", name: "ms" }, + series: [ + { + name: "社交上下文", + type: "bar", + stack: "hook", + data: socialData, + itemStyle: { color: "#1976d2" }, + }, + { + name: "V2上下文", + type: "bar", + stack: "hook", + data: v2Data, + itemStyle: { color: "#43a047" }, + }, + { + name: "多样性", + type: "bar", + stack: "hook", + data: diversityData, + itemStyle: { color: "#ff9800" }, + }, + { + name: "黑话", + type: "bar", + stack: "hook", + data: jargonData, + itemStyle: { color: "#7b1fa2" }, + }, + ], + }, + true, + ); + }, + /* ---------- 6. 对话风格学习进度 - 混合柱线图 ---------- */ updateStyleChart() { var echarts = window.echarts; @@ -816,6 +932,7 @@ window.AppDashboard = { self.updateResponseTime(); self.updateLearningGauge(); self.updateSystemRadar(); + self.updateHookPerf(); self.updateStyleChart(); self.updateHeatmap(); diff --git a/webui/app.py b/webui/app.py index a58e06d..4d4c9bd 100644 --- a/webui/app.py +++ b/webui/app.py @@ -46,7 +46,7 @@ def create_app(webui_config: WebUIConfig = None) -> Quart: async def root_redirect(): return redirect("/api/") - logger.info("✅ [WebUI] Quart 应用创建成功") + logger.info(" [WebUI] Quart 应用创建成功") return app @@ -64,6 +64,6 @@ def register_blueprints(app: Quart): for bp in blueprints: app.register_blueprint(bp) - logger.info(f"✅ [WebUI] 已注册蓝图: {bp.name}") + logger.info(f" [WebUI] 已注册蓝图: {bp.name}") - logger.info(f"✅ [WebUI] 共注册 {len(blueprints)} 个蓝图") + logger.info(f" [WebUI] 共注册 {len(blueprints)} 个蓝图") diff --git a/webui/blueprints/__init__.py b/webui/blueprints/__init__.py index 42d2400..2fad6a0 100644 --- a/webui/blueprints/__init__.py +++ 
b/webui/blueprints/__init__.py @@ -49,7 +49,7 @@ def register_blueprints(app): blueprints = get_blueprints() for bp in blueprints: app.register_blueprint(bp) - print(f"✅ [WebUI] 已注册蓝图: {bp.name}") + print(f" [WebUI] 已注册蓝图: {bp.name}") __all__ = [ diff --git a/webui/blueprints/intelligent_chat.py b/webui/blueprints/intelligent_chat.py index ba19eb4..92201d7 100644 --- a/webui/blueprints/intelligent_chat.py +++ b/webui/blueprints/intelligent_chat.py @@ -115,7 +115,7 @@ async def get_goal_statistics(): async def get_goal_templates(): """获取所有可用的目标类型""" try: - from ...services.conversation_goal_manager import ConversationGoalManager + from ...services.quality import ConversationGoalManager templates = { key: { diff --git a/webui/blueprints/learning.py b/webui/blueprints/learning.py index e4e2257..abe67cf 100644 --- a/webui/blueprints/learning.py +++ b/webui/blueprints/learning.py @@ -127,29 +127,82 @@ async def get_style_learning_content_text(): 'history': [] } - if database_manager: + if database_manager and hasattr(database_manager, 'get_session'): + from sqlalchemy import select, desc, func + from ...models.orm import ( + RawMessage, StyleLearningReview, + ExpressionPattern, LearningBatch, + ) + from datetime import datetime + import time as time_module + import json as json_module + try: - # Get recent raw messages for dialogues - if hasattr(database_manager, 'get_session'): - from sqlalchemy import select, desc - from ...models.orm import RawMessage - - async with database_manager.get_session() as session: - stmt = select(RawMessage).order_by(desc(RawMessage.timestamp)).limit(20) - result = await session.execute(stmt) - raw_messages = result.scalars().all() - - for msg in raw_messages: - message_text = msg.message if msg.message else '' - if len(message_text.strip()) < 5: - continue - from datetime import datetime - import time as time_module - content_data['dialogues'].append({ - 'timestamp': datetime.fromtimestamp(msg.timestamp if msg.timestamp else 
time_module.time()).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"{msg.sender_name or msg.sender_id}: {message_text}", - 'metadata': f"群组: {msg.group_id}, 平台: {msg.platform or '未知'}" - }) + async with database_manager.get_session() as session: + # 1. dialogues — 最近的原始消息 + stmt = select(RawMessage).order_by(desc(RawMessage.timestamp)).limit(20) + result = await session.execute(stmt) + for msg in result.scalars().all(): + message_text = msg.message if msg.message else '' + if len(message_text.strip()) < 5: + continue + content_data['dialogues'].append({ + 'timestamp': datetime.fromtimestamp(msg.timestamp if msg.timestamp else time_module.time()).strftime('%Y-%m-%d %H:%M:%S'), + 'text': f"{msg.sender_name or msg.sender_id}: {message_text}", + 'metadata': f"群组: {msg.group_id}, 平台: {msg.platform or '未知'}" + }) + + # 2. analysis — 已审批的风格学习分析结果 + analysis_stmt = ( + select(StyleLearningReview) + .where(StyleLearningReview.status.in_(['approved', 'pending'])) + .order_by(desc(StyleLearningReview.timestamp)) + .limit(20) + ) + analysis_result = await session.execute(analysis_stmt) + for review in analysis_result.scalars().all(): + patterns = [] + if review.learned_patterns: + try: + patterns = json_module.loads(review.learned_patterns) + except (json_module.JSONDecodeError, TypeError): + pass + content_data['analysis'].append({ + 'timestamp': datetime.fromtimestamp(review.timestamp).strftime('%Y-%m-%d %H:%M:%S') if review.timestamp else '', + 'text': review.description or review.few_shots_content or f"风格学习 ({review.type})", + 'metadata': f"群组: {review.group_id}, 状态: {review.status}, 模式数: {len(patterns) if isinstance(patterns, list) else 0}" + }) + + # 3. 
features — 已学习的表达模式 + features_stmt = ( + select(ExpressionPattern) + .order_by(desc(ExpressionPattern.last_active_time)) + .limit(20) + ) + features_result = await session.execute(features_stmt) + for pattern in features_result.scalars().all(): + content_data['features'].append({ + 'timestamp': datetime.fromtimestamp(pattern.last_active_time).strftime('%Y-%m-%d %H:%M:%S') if pattern.last_active_time else '', + 'text': f"场景: {pattern.situation}\n表达: {pattern.expression}", + 'metadata': f"群组: {pattern.group_id}, 权重: {pattern.weight:.2f}" + }) + + # 4. history — 学习批次历史 + history_stmt = ( + select(LearningBatch) + .order_by(desc(LearningBatch.start_time)) + .limit(20) + ) + history_result = await session.execute(history_stmt) + for batch in history_result.scalars().all(): + duration = '' + if batch.start_time and batch.end_time: + duration = f", 耗时: {batch.end_time - batch.start_time:.1f}s" + content_data['history'].append({ + 'timestamp': datetime.fromtimestamp(batch.start_time).strftime('%Y-%m-%d %H:%M:%S') if batch.start_time else '', + 'text': f"批次: {batch.batch_name or batch.batch_id}, 质量: {batch.quality_score or 0:.3f}", + 'metadata': f"群组: {batch.group_id}, 消息: {batch.processed_messages or 0}, 成功: {'是' if batch.success else '否'}{duration}" + }) except Exception as e: logger.warning(f"获取学习内容文本失败: {e}") diff --git a/webui/blueprints/metrics.py b/webui/blueprints/metrics.py index 55072a6..7e7dd65 100644 --- a/webui/blueprints/metrics.py +++ b/webui/blueprints/metrics.py @@ -133,6 +133,15 @@ async def get_metrics(): except Exception: pass + # Hook performance timing + hook_performance = {} + perf_collector = container.perf_collector + if perf_collector and hasattr(perf_collector, 'get_perf_data'): + try: + hook_performance = perf_collector.get_perf_data(recent_limit=50) + except Exception as e: + logger.warning(f"获取Hook性能数据失败: {e}") + import time metrics = { "llm_calls": llm_stats, @@ -140,6 +149,7 @@ async def get_metrics(): "filtered_messages": filtered_messages, 
"system_metrics": system_metrics, "learning_sessions": learning_sessions, + "hook_performance": hook_performance, "last_updated": time.time() } diff --git a/webui/dependencies.py b/webui/dependencies.py index a2b2ba7..ce6d065 100644 --- a/webui/dependencies.py +++ b/webui/dependencies.py @@ -52,6 +52,9 @@ def __init__(self): # 智能指标服务 self.intelligence_metrics_service: Optional[Any] = None + # 性能计时收集器(指向插件实例的 get_perf_data 方法) + self.perf_collector: Optional[Any] = None + self._initialized = True def initialize( @@ -88,7 +91,7 @@ def initialize( # 获取人格更新器 try: self.persona_updater = service_factory.get_persona_updater() - logger.info(f"✅ [WebUI] persona_updater 获取成功: {type(self.persona_updater)}") + logger.info(f" [WebUI] persona_updater 获取成功: {type(self.persona_updater)}") except Exception as e: logger.warning(f"获取 persona_updater 失败: {e}") self.persona_updater = None @@ -99,7 +102,7 @@ def initialize( # 初始化智能指标服务 try: - from ..services.intelligence_metrics import IntelligenceMetricsService + from ..services.analysis import IntelligenceMetricsService self.intelligence_metrics_service = IntelligenceMetricsService( plugin_config, self.database_manager, @@ -115,7 +118,7 @@ def initialize( self.persona_web_manager = PersonaWebManager(astrbot_persona_manager) # 传递 group_id_to_unified_origin 映射引用(多配置文件支持) self.persona_web_manager.group_id_to_unified_origin = self.group_id_to_unified_origin - logger.info("✅ [WebUI] PersonaWebManager 初始化成功") + logger.info(" [WebUI] PersonaWebManager 初始化成功") except Exception as e: logger.warning(f"初始化 PersonaWebManager 失败: {e}") self.persona_web_manager = None @@ -123,7 +126,7 @@ def initialize( logger.warning("astrbot_persona_manager 未提供,无法初始化 PersonaWebManager") self.persona_web_manager = None - logger.info("✅ [WebUI] 服务容器初始化完成") + logger.info(" [WebUI] 服务容器初始化完成") def get_plugin_config(self): """获取插件配置""" @@ -151,9 +154,7 @@ def get_container() -> ServiceContainer: return _container -# 
============================================================ # 兼容原有的 set_plugin_services 接口 -# ============================================================ async def set_plugin_services( plugin_config, @@ -180,4 +181,4 @@ async def set_plugin_services( group_id_to_unified_origin=group_id_to_unified_origin ) - logger.info("✅ [WebUI] 插件服务设置完成") + logger.info(" [WebUI] 插件服务设置完成") diff --git a/webui/manager.py b/webui/manager.py new file mode 100644 index 0000000..22c1fc8 --- /dev/null +++ b/webui/manager.py @@ -0,0 +1,222 @@ +"""WebUI 服务器全生命周期管理 — 创建、启动、停止、服务注册""" +import asyncio +import gc +import sys +from typing import Optional, Any, Dict, TYPE_CHECKING + +from astrbot.api import logger + +from .server import Server +from .dependencies import get_container as _get_webui_container, set_plugin_services + +if TYPE_CHECKING: + from ..config import PluginConfig + from ..core.factory import FactoryManager + +# 模块级服务器实例(原 main.py 中的 global server_instance) +_server_instance: Optional[Server] = None +_server_cleanup_lock = asyncio.Lock() + + +def get_server_instance() -> Optional[Server]: + return _server_instance + + +class WebUIManager: + """WebUI 服务器全生命周期管理""" + + def __init__( + self, + plugin_config: "PluginConfig", + context: Any, + factory_manager: "FactoryManager", + perf_tracker: Any, + group_id_to_unified_origin: Dict[str, str], + ): + self._config = plugin_config + self._context = context + self._factory_manager = factory_manager + self._perf_tracker = perf_tracker + self._group_id_to_unified_origin = group_id_to_unified_origin + + # 创建 + + def create_server(self) -> bool: + """创建 Server 实例(不启动)。返回 True 表示需要立即启动。""" + global _server_instance + + if not self._config.enable_web_interface: + logger.info("WebUI 未启用") + return False + + logger.info(f"准备创建 Server 实例,端口: {self._config.web_interface_port}") + try: + if _server_instance is not None: + logger.warning("检测到已存在的 Web 服务器实例,可能是插件重载") + if ( + hasattr(_server_instance, "server_thread") + and 
_server_instance.server_thread + and _server_instance.server_thread.is_alive() + ): + logger.warning("旧的 Web 服务器仍在运行,将复用该实例") + logger.info( + f"Web 服务器地址: http://{_server_instance.host}:{_server_instance.port}" + ) + return False + else: + logger.info("旧的 Web 服务器已停止,创建新实例") + _server_instance = None + + if _server_instance is None: + _server_instance = Server(port=self._config.web_interface_port) + if _server_instance: + logger.info( + f"Web 服务器实例已创建 " + f"({_server_instance.host}:{_server_instance.port}),将在 on_load 中启动" + ) + return True # 需要立即启动 + else: + logger.error("Web 服务器实例创建失败") + except Exception as e: + logger.error(f"创建 Web 服务器实例失败: {e}", exc_info=True) + + return False + + # 启动 + + async def immediate_start(self, db_manager: Any) -> None: + """__init__ 阶段立即启动 WebUI(通过 asyncio.create_task 调用)""" + await asyncio.sleep(1) # 等待插件完全初始化 + + global _server_instance + if not _server_instance or not self._config.enable_web_interface: + logger.error("server_instance 为空或 web_interface 未启用") + return + + # 启动数据库 + try: + db_started = await db_manager.start() + if not db_started: + raise RuntimeError("数据库管理器启动失败") + except Exception as e: + logger.error(f"启动数据库管理器失败: {e}", exc_info=True) + raise + + # 设置 WebUI 服务 + astrbot_pm = await self._acquire_persona_manager() + try: + await self._setup_services(astrbot_pm) + except Exception as e: + logger.error(f"设置插件服务失败: {e}", exc_info=True) + return + + # 启动服务器 + try: + await _server_instance.start() + logger.info("Web 服务器已成功启动") + except Exception as e: + logger.error(f"Web 服务器启动失败: {e}", exc_info=True) + logger.error("端口可能仍被占用,WebUI 不可用") + _server_instance = None + + async def setup_and_start(self) -> None: + """on_load 阶段设置服务并启动。""" + global _server_instance + + if not self._config.enable_web_interface or not _server_instance: + if not self._config.enable_web_interface: + logger.info("WebUI 未启用,跳过启动") + if not _server_instance: + logger.error("server_instance 为空,无法启动 Web 服务器") + return + + # 设置 WebUI 服务 + astrbot_pm = 
await self._acquire_persona_manager() + try: + await self._setup_services(astrbot_pm) + logger.info("Web 服务器插件服务设置完成") + except Exception as e: + logger.error(f"设置 Web 服务器插件服务失败: {e}", exc_info=True) + return + + # 启动服务器 + try: + logger.info( + f"准备启动 Web 服务器: " + f"http://{_server_instance.host}:{_server_instance.port}" + ) + await _server_instance.start() + logger.info("Web 服务器启动完成") + except Exception as e: + logger.error(f"Web 服务器启动失败: {e}", exc_info=True) + + # 停止 + + async def stop(self) -> None: + """有序停止 WebUI 服务器""" + global _server_instance, _server_cleanup_lock + + async with _server_cleanup_lock: + if not _server_instance: + return + try: + logger.info(f"正在停止 Web 服务器 (端口: {_server_instance.port})...") + await _server_instance.stop() + gc.collect() + + if sys.platform == "win32": + logger.info("Windows 环境:等待端口资源释放...") + await asyncio.sleep(2.0) + + _server_instance = None + logger.info("Web 服务器实例已清理") + except Exception as e: + logger.error(f"停止 Web 服务器失败: {e}", exc_info=True) + _server_instance = None + + # 内部方法 + + async def _acquire_persona_manager(self) -> Any: + """获取 AstrBot 框架 PersonaManager(带延迟重试)""" + astrbot_persona_manager = None + try: + if hasattr(self._context, "persona_manager"): + astrbot_persona_manager = self._context.persona_manager + if astrbot_persona_manager: + logger.info( + f"成功获取 AstrBot 框架 PersonaManager: " + f"{type(astrbot_persona_manager)}" + ) + else: + logger.warning("Context 中 persona_manager 为 None") + else: + logger.warning("Context 中没有 persona_manager 属性") + + if not astrbot_persona_manager: + logger.info("尝试延迟获取 PersonaManager...") + await asyncio.sleep(3) + if ( + hasattr(self._context, "persona_manager") + and self._context.persona_manager + ): + astrbot_persona_manager = self._context.persona_manager + logger.info( + f"延迟获取成功: {type(astrbot_persona_manager)}" + ) + else: + logger.warning("延迟获取 PersonaManager 仍然失败") + except Exception as e: + logger.error(f"获取 AstrBot 框架 PersonaManager 失败: {e}", exc_info=True) + + 
return astrbot_persona_manager + + async def _setup_services(self, astrbot_persona_manager: Any) -> None: + """调用 set_plugin_services 注册服务到 WebUI 容器""" + await set_plugin_services( + self._config, + self._factory_manager, + None, + astrbot_persona_manager, + self._group_id_to_unified_origin, + ) + _get_webui_container().perf_collector = self._perf_tracker diff --git a/webui/services/bug_report_service.py b/webui/services/bug_report_service.py index 876c502..63360ef 100644 --- a/webui/services/bug_report_service.py +++ b/webui/services/bug_report_service.py @@ -30,7 +30,7 @@ def get_bug_report_config(self) -> Dict[str, Any]: """ # Bug报告配置常量 BUG_REPORT_ENABLED = getattr(self.webui_config, 'bug_report_enabled', True) - BUG_REPORT_ATTACHMENT_ENABLED = False # 暂时禁用附件 + BUG_REPORT_ATTACHMENT_ENABLED = False # 暂时禁用附件 BUG_CLOUD_FUNCTION_URL = os.getenv( "ASTRBOT_BUG_CLOUD_URL", "http://zentao-g-submit-rwpsiodjrb.cn-hangzhou.fcapp.run/zentao-bug-submit/submit-bug" @@ -132,7 +132,7 @@ async def submit_bug_report(self, bug_data: Dict[str, Any]) -> Tuple[bool, str, "http://zentao-g-submit-rwpsiodjrb.cn-hangzhou.fcapp.run/zentao-bug-submit/submit-bug" ) - # ✅ 构建完整的重现步骤,包含所有信息 + # 构建完整的重现步骤,包含所有信息 severity_labels = {1: "致命", 2: "严重", 3: "一般", 4: "轻微"} priority_labels = {1: "紧急", 2: "高", 3: "中", 4: "低"} type_labels = { @@ -177,7 +177,7 @@ async def submit_bug_report(self, bug_data: Dict[str, Any]) -> Tuple[bool, str, {bug_data['steps']} """ - # ✅ 构建请求数据,将完整信息放入steps字段 + # 构建请求数据,将完整信息放入steps字段 payload = { "title": bug_data["title"], "steps": formatted_steps, @@ -193,7 +193,7 @@ async def submit_bug_report(self, bug_data: Dict[str, Any]) -> Tuple[bool, str, logger.info(f"准备提交Bug报告: {payload['title']}") logger.debug(f"Bug报告完整数据: {payload}") - # ✅ 实际调用云函数API + # 实际调用云函数API async with aiohttp.ClientSession() as session: async with session.post( cloud_url, diff --git a/webui/services/config_service.py b/webui/services/config_service.py index 877ab59..3f9a03e 100644 --- 
a/webui/services/config_service.py +++ b/webui/services/config_service.py @@ -1,7 +1,6 @@ """ 配置服务 - 处理插件配置相关业务逻辑 """ -from dataclasses import asdict from typing import Dict, Any, Tuple from astrbot.api import logger @@ -27,7 +26,7 @@ async def get_config(self) -> Dict[str, Any]: Dict: 插件配置字典 """ if self.plugin_config: - return asdict(self.plugin_config) + return self.plugin_config.to_dict() else: raise ValueError("Plugin config not initialized") @@ -55,4 +54,4 @@ async def update_config(self, new_config: Dict[str, Any]) -> Tuple[bool, str, Di # TODO: 保存配置到文件 # 需要实现配置持久化逻辑 - return True, "Config updated successfully", asdict(self.plugin_config) + return True, "Config updated successfully", self.plugin_config.to_dict() diff --git a/webui/services/jargon_service.py b/webui/services/jargon_service.py index 7c695eb..dea6b59 100644 --- a/webui/services/jargon_service.py +++ b/webui/services/jargon_service.py @@ -110,19 +110,22 @@ async def get_jargon_list( raise ValueError('数据库管理器未初始化') try: - jargons = await self.database_manager.get_recent_jargon_list( + # 获取真实总数 + total = await self.database_manager.get_jargon_count( chat_id=group_id, - limit=page_size * page, only_confirmed=confirmed, ) - # 手动实现分页 - total = len(jargons) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - page_jargons = jargons[start_idx:end_idx] if start_idx < total else [] + # DB 层分页 + offset = (page - 1) * page_size + jargons = await self.database_manager.get_recent_jargon_list( + chat_id=group_id, + limit=page_size, + offset=offset, + only_confirmed=confirmed, + ) - formatted = [self._format_jargon_for_frontend(j) for j in page_jargons] + formatted = [self._format_jargon_for_frontend(j) for j in jargons] return { 'jargon_list': formatted, @@ -157,12 +160,8 @@ async def search_jargon( try: results = await self.database_manager.search_jargon( - keyword, chat_id=chat_id + keyword, chat_id=chat_id, confirmed_only=confirmed_only ) - # 按 confirmed_only 过滤 - if confirmed_only: - 
results = [r for r in results if r.get('is_jargon')] - return [self._format_jargon_for_frontend(r) for r in results] except Exception as e: logger.error(f"搜索黑话失败: {e}", exc_info=True) diff --git a/webui/services/learning_service.py b/webui/services/learning_service.py index e052cea..ffbb697 100644 --- a/webui/services/learning_service.py +++ b/webui/services/learning_service.py @@ -17,7 +17,7 @@ def __init__(self, container): """ self.container = container self.database_manager = container.database_manager - self.db_manager = container.database_manager # 兼容别名 + self.db_manager = container.database_manager # 兼容别名 self.persona_updater = getattr(container, 'persona_updater', None) async def get_style_learning_results(self) -> Dict[str, Any]: @@ -168,10 +168,10 @@ async def approve_style_learning_review(self, review_id: int) -> Tuple[bool, str logger.info(f"update_persona_with_style返回结果: {success_apply}") if success_apply: - logger.info(f"✅ 风格学习审查 {review_id} 已成功应用到人格(使用框架API方式,包含备份)") + logger.info(f" 风格学习审查 {review_id} 已成功应用到人格(使用框架API方式,包含备份)") return True, f'风格学习审查 {review_id} 已批准并应用到人格' else: - logger.warning(f"❌ 风格学习审查 {review_id} 批准成功但应用失败") + logger.warning(f" 风格学习审查 {review_id} 批准成功但应用失败") return True, f'风格学习审查 {review_id} 已批准,但人格应用失败' except Exception as e: diff --git a/webui/services/persona_review_service.py b/webui/services/persona_review_service.py index 70d1dc2..dab8d77 100644 --- a/webui/services/persona_review_service.py +++ b/webui/services/persona_review_service.py @@ -94,7 +94,7 @@ async def get_pending_persona_updates(self, limit: int = 0, offset: int = 0) -> if self.database_manager: try: logger.info("正在获取人格学习审查...") - persona_learning_reviews = await self.database_manager.get_pending_persona_learning_reviews(limit=999999) + persona_learning_reviews = await self.database_manager.get_pending_persona_learning_reviews() logger.info(f"获取到 {len(persona_learning_reviews)} 个人格学习审查") for review in persona_learning_reviews: @@ -170,7 +170,7 @@ async def 
get_pending_persona_updates(self, limit: int = 0, offset: int = 0) -> if self.database_manager: try: logger.info("正在获取风格学习审查...") - style_reviews = await self.database_manager.get_pending_style_reviews(limit=999999) + style_reviews = await self.database_manager.get_pending_style_reviews() logger.info(f"获取到 {len(style_reviews)} 个风格学习审查") for review in style_reviews: @@ -347,11 +347,11 @@ async def review_persona_update( message += f";{auto_apply_msg}" else: error_msg = create_result.get('error', '未知错误') - logger.warning(f"❌ 人格学习审查 {persona_learning_review_id} 批准成功但创建新人格失败: {error_msg}") + logger.warning(f" 人格学习审查 {persona_learning_review_id} 批准成功但创建新人格失败: {error_msg}") message = f"人格学习审查 {persona_learning_review_id} 已批准,但创建新人格失败: {error_msg}" except Exception as apply_error: - logger.error(f"❌ 创建新人格失败: {apply_error}", exc_info=True) + logger.error(f" 创建新人格失败: {apply_error}", exc_info=True) message = f"人格学习审查 {persona_learning_review_id} 已批准,但创建新人格过程出错: {str(apply_error)}" elif not self.persona_web_manager: logger.warning("PersonaWebManager未初始化,无法创建新人格") @@ -448,11 +448,11 @@ async def _approve_style_learning_review(self, review_id: int) -> Tuple[bool, st return True, msg else: error_msg = create_result.get('error', '未知错误') - logger.warning(f"❌ 风格学习审查 {review_id} 批准成功但创建新人格失败: {error_msg}") + logger.warning(f" 风格学习审查 {review_id} 批准成功但创建新人格失败: {error_msg}") return True, f"风格学习审查 {review_id} 已批准,但创建新人格失败: {error_msg}" except Exception as e: - logger.error(f"❌ 创建新人格失败: {e}", exc_info=True) + logger.error(f" 创建新人格失败: {e}", exc_info=True) return True, f"风格学习审查 {review_id} 已批准,但创建新人格过程出错: {str(e)}" else: logger.warning("PersonaWebManager未初始化,无法创建新人格") diff --git a/webui/services/persona_service.py b/webui/services/persona_service.py index 8cb8abf..88b171d 100644 --- a/webui/services/persona_service.py +++ b/webui/services/persona_service.py @@ -54,24 +54,14 @@ async def get_persona_details(self, persona_id: str) -> Optional[Dict[str, Any]] Returns: Optional[Dict]: 
人格详情,如果不存在返回None """ - if not self.persona_manager: - raise ValueError("PersonaManager未初始化") + if not self.persona_web_mgr: + raise ValueError("PersonaWebManager未初始化") try: - persona = await self.persona_manager.get_persona(persona_id) - - persona_dict = { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs, - "tools": persona.tools, - "created_at": persona.created_at.isoformat() if hasattr(persona, 'created_at') and persona.created_at else None, - "updated_at": persona.updated_at.isoformat() if hasattr(persona, 'updated_at') and persona.updated_at else None, - } - - return persona_dict - - except ValueError: + all_personas = await self.persona_web_mgr.get_all_personas_for_web() + for persona in all_personas: + if persona.get('persona_id') == persona_id: + return persona return None except Exception as e: logger.error(f"获取人格详情失败: {e}") @@ -190,17 +180,19 @@ async def export_persona(self, persona_id: str) -> Dict[str, Any]: Returns: Dict: 导出的人格配置 """ - if not self.persona_manager: - raise ValueError("PersonaManager未初始化") + if not self.persona_web_mgr: + raise ValueError("PersonaWebManager未初始化") try: - persona = await self.persona_manager.get_persona(persona_id) + persona = await self.get_persona_details(persona_id) + if not persona: + raise ValueError(f"人格 {persona_id} 不存在") persona_export = { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs, - "tools": persona.tools, + "persona_id": persona.get("persona_id", ""), + "system_prompt": persona.get("system_prompt", ""), + "begin_dialogs": persona.get("begin_dialogs", []), + "tools": persona.get("tools", []), "export_time": datetime.now().isoformat(), "export_version": "1.0" } @@ -221,8 +213,8 @@ async def import_persona(self, data: Dict[str, Any]) -> Tuple[bool, str, Optiona Returns: Tuple[bool, str, Optional[str]]: (是否成功, 消息, 人格ID) """ - if not self.persona_manager: - raise 
ValueError("PersonaManager未初始化") + if not self.persona_web_mgr: + raise ValueError("PersonaWebManager未初始化") try: # 验证导入数据格式 @@ -238,37 +230,36 @@ async def import_persona(self, data: Dict[str, Any]) -> Tuple[bool, str, Optiona # 检查是否覆盖现有人格 overwrite = data.get("overwrite", False) - try: - existing_persona = await self.persona_manager.get_persona(persona_id) - except ValueError: - existing_persona = None + existing_persona = await self.get_persona_details(persona_id) if existing_persona and not overwrite: return False, "人格已存在,如要覆盖请设置overwrite=true", None # 创建或更新人格 if existing_persona: - success = await self.persona_manager.update_persona( - persona_id=persona_id, - system_prompt=system_prompt, - begin_dialogs=begin_dialogs, - tools=tools + result = await self.persona_web_mgr.update_persona_via_web( + persona_id, + { + "system_prompt": system_prompt, + "begin_dialogs": begin_dialogs, + "tools": tools, + } ) action = "更新" else: - success = await self.persona_manager.create_persona( - persona_id=persona_id, - system_prompt=system_prompt, - begin_dialogs=begin_dialogs, - tools=tools - ) + result = await self.persona_web_mgr.create_persona_via_web({ + "persona_id": persona_id, + "system_prompt": system_prompt, + "begin_dialogs": begin_dialogs, + "tools": tools, + }) action = "创建" - if success: + if result.get('success'): logger.info(f"成功导入人格: {persona_id} ({action})") return True, f"人格{action}成功", persona_id else: - return False, f"人格{action}失败", None + return False, result.get('error', f"人格{action}失败"), None except Exception as e: logger.error(f"导入人格失败: {e}") diff --git a/webui/services/social_service.py b/webui/services/social_service.py index 340c347..3ceee73 100644 --- a/webui/services/social_service.py +++ b/webui/services/social_service.py @@ -229,7 +229,7 @@ async def trigger_analysis(self, group_id: str) -> Tuple[bool, str]: if not factory_manager: return False, "工厂管理器未初始化" - from ...services.social_relation_analyzer import SocialRelationAnalyzer + from 
...services.social import SocialRelationAnalyzer service_factory = factory_manager.get_service_factory() db_manager = service_factory.create_database_manager() diff --git a/webui_legacy.py b/webui_legacy.py deleted file mode 100644 index 3eee899..0000000 --- a/webui_legacy.py +++ /dev/null @@ -1,6273 +0,0 @@ -import os -import asyncio -import json # 导入 json 模块 -import secrets -import time -import base64 -import urllib.request -import urllib.error -import threading -import subprocess -import sys -import gc -import socket -from datetime import datetime, timedelta -from astrbot.api import logger -from typing import Optional, List, Dict, Any -from dataclasses import asdict -from functools import wraps - -from quart import Quart, Blueprint, render_template, request, jsonify, current_app, redirect, url_for, session # 导入 redirect 和 url_for -from quart_cors import cors # 导入 cors -import hypercorn.asyncio -from hypercorn.config import Config as HypercornConfig -try: - from hypercorn.config import Sockets -except ImportError: - class Sockets: - def __init__(self, secure_sockets, insecure_sockets, quic_sockets): - self.secure_sockets = secure_sockets - self.insecure_sockets = insecure_sockets - self.quic_sockets = quic_sockets -import aiohttp -from werkzeug.utils import secure_filename - -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - -from .config import PluginConfig -from .core.factory import FactoryManager -from .persona_web_manager import PersonaWebManager, set_persona_web_manager, get_persona_web_manager -from .services.intelligence_metrics import IntelligenceMetricsService -from .utils.security_utils import ( - PasswordHasher, - login_attempt_tracker, - migrate_password_to_hashed, - verify_password_with_migration, - SecurityValidator -) -from .constants import ( - UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, - UPDATE_TYPE_STYLE_LEARNING, - UPDATE_TYPE_EXPRESSION_LEARNING, - UPDATE_TYPE_TRADITIONAL, - normalize_update_type, - 
get_review_source_from_update_type -) - -# ========== 数据库管理器适配层 ========== -class DatabaseManagerAdapter: - """ - 数据库管理器适配层 - 自动检测使用 SQLAlchemy 数据库管理器还是传统数据库管理器 - 并调用相应的方法 - """ - - def __init__(self, db_manager): - self.db_manager = db_manager - self._is_sqlalchemy = self._detect_sqlalchemy() - - def _detect_sqlalchemy(self) -> bool: - """检测是否为 SQLAlchemy 数据库管理器""" - if not self.db_manager: - return False - # 检查类名或特定方法来判断类型 - class_name = type(self.db_manager).__name__ - logger.debug(f"检测到数据库管理器类型: {class_name}") - return 'SQLAlchemy' in class_name or hasattr(self.db_manager, '_legacy_db') - - async def safe_call(self, method_name: str, *args, **kwargs): - """ - 安全调用数据库方法 - 如果 SQLAlchemy 管理器没有实现该方法,自动降级到传统管理器 - """ - try: - if not self.db_manager: - logger.warning(f"数据库管理器不可用,无法调用 {method_name}") - return None - - # 获取方法 - if hasattr(self.db_manager, method_name): - method = getattr(self.db_manager, method_name) - result = await method(*args, **kwargs) - return result - else: - logger.warning(f"方法 {method_name} 在当前数据库管理器中不存在") - return None - - except Exception as e: - logger.error(f"调用数据库方法 {method_name} 失败: {e}", exc_info=True) - return None - - async def get_db_connection(self): - """获取数据库连接""" - return await self.safe_call('get_db_connection') - - async def get_messages_statistics(self): - """获取消息统计""" - return await self.safe_call('get_messages_statistics') - - async def get_group_messages_statistics(self, group_id: str): - """获取群组消息统计""" - return await self.safe_call('get_group_messages_statistics', group_id) - - async def get_social_relations_by_group(self, group_id: str): - """获取群组社交关系""" - return await self.safe_call('get_social_relations_by_group', group_id) - - async def get_filtered_messages_for_learning(self, limit: int = None): - """获取用于学习的筛选消息""" - return await self.safe_call('get_filtered_messages_for_learning', limit) - - async def get_recent_raw_messages(self, group_id: str, limit: int = 200): - """获取最近的原始消息""" - return await 
self.safe_call('get_recent_raw_messages', group_id, limit) - - async def get_recent_learning_batches(self, limit: int = 5): - """获取最近的学习批次""" - return await self.safe_call('get_recent_learning_batches', limit) - - # 可以继续添加更多方法... - -# 创建全局适配器实例(稍后初始化) -db_adapter: Optional[DatabaseManagerAdapter] = None - -# 获取当前文件所在的目录,然后向上两级到达插件根目录 -PLUGIN_ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '.')) -WEB_STATIC_DIR = os.path.join(PLUGIN_ROOT_DIR, "web_res", "static") -WEB_HTML_DIR = os.path.join(WEB_STATIC_DIR, "html") - -def get_password_file_path() -> str: - """动态获取密码文件路径,优先使用config.data_dir""" - if plugin_config and hasattr(plugin_config, 'data_dir'): - # 使用配置的data_dir路径 - return os.path.join(plugin_config.data_dir, "password.json") - else: - # 后备路径:使用插件根目录下的config文件夹 - return os.path.join(PLUGIN_ROOT_DIR, "config", "password.json") - -# 初始化 Quart 应用 -app = Quart(__name__, static_folder=WEB_STATIC_DIR, static_url_path="/static", template_folder=WEB_HTML_DIR) -app.secret_key = secrets.token_hex(16) # 生成随机密钥用于会话管理 -cors(app) # 启用 CORS - -# 全局变量,用于存储插件实例和服务 -plugin_config: Optional[PluginConfig] = None -persona_manager: Optional[Any] = None -persona_updater: Optional[Any] = None -database_manager: Optional[Any] = None -db_manager: Optional[Any] = None # 添加db_manager别名 -llm_client = None -llm_adapter_instance = None # LLM适配器实例,用于社交关系分析等服务 -progressive_learning: Optional[Any] = None # 添加progressive_learning全局变量 -intelligence_metrics_service: Optional[IntelligenceMetricsService] = None # 智能指标计算服务 - -# 新增的变量 -pending_updates: List[Any] = [] -password_config: Dict[str, Any] = {} # 用于存储密码配置 -group_id_to_unified_origin: Dict[str, str] = {} # group_id到unified_msg_origin映射(多配置文件支持) - - -def _resolve_umo(group_id: str) -> str: - """将group_id解析为unified_msg_origin以支持多配置文件""" - return group_id_to_unified_origin.get(group_id, group_id) - -BUG_REPORT_ENABLED = True -# 暂时禁用附件上传功能 -BUG_REPORT_ATTACHMENT_ENABLED = False # TODO: 附件功能待修复后启用 -BUG_CLOUD_FUNCTION_URL = 
os.getenv( - "ASTRBOT_BUG_CLOUD_URL", - "http://zentao-g-submit-rwpsiodjrb.cn-hangzhou.fcapp.run/zentao-bug-submit/submit-bug" -) # 保持完整URL,不要rstrip -BUG_CLOUD_VERIFY_CODE = os.getenv("ASTRBOT_BUG_CLOUD_VERIFY_CODE", "zentao123") -BUG_REPORT_TIMEOUT_SECONDS = int(os.getenv("ASTRBOT_BUG_REPORT_TIMEOUT", "30")) -BUG_REPORT_DEFAULT_BUILDS = [build.strip() for build in os.getenv("ASTRBOT_BUG_DEFAULT_BUILDS", "v2.0").split(",") if build.strip()] -BUG_REPORT_DEFAULT_SEVERITY = 3 -BUG_REPORT_DEFAULT_PRIORITY = 3 -BUG_REPORT_DEFAULT_TYPE = "codeerror" -BUG_REPORT_MAX_IMAGES = 1 # 云函数只支持单个附件,如需多个文件请打包为压缩包 -BUG_REPORT_MAX_IMAGE_BYTES = 8 * 1024 * 1024 # 8MB per image -BUG_REPORT_MAX_LOG_BYTES = 20_000 -# 安全白名单:允许所有图片、压缩包和文档文件 -BUG_REPORT_ALLOWED_EXTENSIONS = { - # 所有常见图片格式 - '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.svg', '.ico', '.tiff', '.tif', - # 日志和文本 - '.txt', '.log', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', - # 文档格式 - '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.odt', '.ods', '.odp', - # 压缩包(用于多文件场景) - '.zip', '.7z', '.rar', '.tar', '.gz', '.tar.gz', '.tgz', '.bz2', '.xz' -} -BUG_REPORT_ALLOWED_MIMETYPES = { - # 所有图片MIME类型 - 'image/png', 'image/jpeg', 'image/gif', 'image/bmp', 'image/webp', 'image/svg+xml', - 'image/x-icon', 'image/vnd.microsoft.icon', 'image/tiff', - # 文本 - 'text/plain', 'text/markdown', 'text/csv', - 'application/json', 'application/xml', 'text/xml', - 'application/x-yaml', 'text/yaml', - # 文档 - 'application/pdf', - 'application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - 'application/vnd.ms-excel', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - 'application/vnd.ms-powerpoint', 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - 'application/vnd.oasis.opendocument.text', - 'application/vnd.oasis.opendocument.spreadsheet', - 'application/vnd.oasis.opendocument.presentation', - # 压缩包 - 'application/zip', 
'application/x-zip-compressed', - 'application/x-7z-compressed', 'application/x-rar-compressed', 'application/vnd.rar', - 'application/x-tar', 'application/gzip', 'application/x-gzip', - 'application/x-bzip2', 'application/x-xz' -} -BUG_REPORT_SEVERITY_OPTIONS = [ - {"value": 1, "label": "S1 - 阻断故障"}, - {"value": 2, "label": "S2 - 重大问题"}, - {"value": 3, "label": "S3 - 普通问题"}, - {"value": 4, "label": "S4 - 建议优化"} -] -BUG_REPORT_PRIORITY_OPTIONS = [ - {"value": 1, "label": "P1 - 紧急"}, - {"value": 2, "label": "P2 - 高"}, - {"value": 3, "label": "P3 - 中"}, - {"value": 4, "label": "P4 - 低"} -] -BUG_REPORT_TYPE_OPTIONS = [ - {"value": "codeerror", "label": "代码缺陷"}, - {"value": "config", "label": "配置问题"}, - {"value": "performance", "label": "性能问题"}, - {"value": "security", "label": "安全问题"}, - {"value": "others", "label": "其他"} -] -BUG_REPORT_LOG_CANDIDATES = [ - "astrbot.log", - "astrbot_debug.log", - "astrbot_plugin.log", - "self_learning.log" -] - - -def _bug_report_available() -> bool: - return BUG_REPORT_ENABLED and bool(BUG_CLOUD_FUNCTION_URL and BUG_CLOUD_VERIFY_CODE) - - -def _is_safe_attachment(filename: str, mimetype: str) -> tuple[bool, str]: - """ - 检查附件是否安全(文件类型白名单验证) - - Args: - filename: 文件名 - mimetype: MIME类型 - - Returns: - (is_safe, error_message): 是否安全及错误信息 - """ - if not filename: - return False, "文件名为空" - - filename_lower = filename.lower() - - # 处理双扩展名(如 .tar.gz) - ext = None - if filename_lower.endswith('.tar.gz'): - ext = '.tar.gz' - else: - _, ext = os.path.splitext(filename_lower) - - # 检查扩展名 - if ext not in BUG_REPORT_ALLOWED_EXTENSIONS: - allowed_exts = ', '.join(sorted(BUG_REPORT_ALLOWED_EXTENSIONS)) - return False, f"不允许的文件类型 '{ext}'。允许的类型:{allowed_exts}" - - # 检查MIME类型(如果提供) - if mimetype and mimetype not in BUG_REPORT_ALLOWED_MIMETYPES: - # 某些MIME类型可能会有变体,只要扩展名在白名单中也可以接受 - logger.warning(f"MIME类型 '{mimetype}' 不在白名单中,但扩展名 '{ext}' 有效") - - # 检查文件名中是否包含路径遍历字符 - if '..' 
in filename or '/' in filename or '\\' in filename: - return False, "文件名包含非法字符(路径遍历)" - - return True, "" - - -def _load_dashboard_http_config() -> Dict[str, Any]: - try: - data_path = get_astrbot_data_path() - if not data_path: - return {} - config_path = os.path.join(data_path, "cmd_config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config_data = json.load(f) - return config_data.get("dashboard", {}) - except Exception as exc: - logger.debug(f"读取dashboard配置失败: {exc}") - return {} - - -def _fetch_dashboard_log_snapshot() -> Optional[str]: - try: - dashboard_cfg = _load_dashboard_http_config() - if dashboard_cfg and not dashboard_cfg.get("enable", True): - return None - - host = dashboard_cfg.get("host", "127.0.0.1") - port = dashboard_cfg.get("port", 6185) - base_url = f"http://{host}:{port}" - url = f"{base_url}/api/log-history" - - req = urllib.request.Request(url, headers={"Accept": "application/json"}) - with urllib.request.urlopen(req, timeout=3) as resp: - payload = json.loads(resp.read().decode("utf-8")) - logs = payload.get("data", {}).get("logs") or payload.get("logs") - if not logs: - return None - - target_dir = None - if plugin_config and getattr(plugin_config, "data_dir", None): - target_dir = os.path.join(plugin_config.data_dir, "bug_log_snapshots") - if not target_dir: - target_dir = os.path.join(PLUGIN_ROOT_DIR, "bug_log_snapshots") - os.makedirs(target_dir, exist_ok=True) - snapshot_path = os.path.join(target_dir, "dashboard_log_history.txt") - - with open(snapshot_path, "w", encoding="utf-8") as f: - for entry in logs[-200:]: - timestamp = entry.get("time", "") - level = entry.get("level", "") - message = entry.get("data", "") - f.write(f"[{timestamp}] {level}: {message}\n") - - return snapshot_path - except urllib.error.URLError as exc: - logger.debug(f"访问dashboard日志接口失败: {exc}") - except Exception as exc: - logger.debug(f"生成dashboard日志快照失败: {exc}") - return None - - -def _find_log_files() -> 
List[str]: - log_paths: List[str] = [] - - dashboard_snapshot = _fetch_dashboard_log_snapshot() - if dashboard_snapshot: - log_paths.append(dashboard_snapshot) - - candidate_dirs = [] - if plugin_config and getattr(plugin_config, "data_dir", None): - candidate_dirs.append(plugin_config.data_dir) - candidate_dirs.append(os.path.join(plugin_config.data_dir, "logs")) - - astrbot_path = get_astrbot_data_path() - if astrbot_path: - candidate_dirs.append(os.path.join(astrbot_path, "logs")) - candidate_dirs.append(astrbot_path) - - candidate_dirs.append(os.path.join(PLUGIN_ROOT_DIR, "logs")) - candidate_dirs.append(PLUGIN_ROOT_DIR) - - seen = set() - for base in candidate_dirs: - if not base or not os.path.exists(base): - continue - for log_name in BUG_REPORT_LOG_CANDIDATES: - path = os.path.abspath(os.path.join(base, log_name)) - if os.path.exists(path) and path not in seen: - seen.add(path) - log_paths.append(path) - return log_paths - - -def _read_log_snippet(path: str, max_bytes: int = BUG_REPORT_MAX_LOG_BYTES) -> Dict[str, Any]: - try: - size = os.path.getsize(path) - read_bytes = min(size, max_bytes) - with open(path, "rb") as f: - if size > max_bytes: - f.seek(size - max_bytes) - data = f.read(read_bytes) - text = data.decode("utf-8", errors="ignore") - preview_len = min(len(text), 800) - return { - "path": path, - "size": size, - "preview": text[-preview_len:], - "content": text - } - except Exception as exc: - logger.debug(f"读取日志失败 {path}: {exc}") - return {"path": path, "size": 0, "preview": "", "content": ""} - - -def _collect_log_previews(limit: int = 3, include_content: bool = False) -> List[Dict[str, Any]]: - previews = [] - for path in _find_log_files(): - info = _read_log_snippet(path) - if not info["preview"]: - continue - if not include_content and "content" in info: - info.pop("content", None) - previews.append(info) - if len(previews) >= limit: - break - return previews - - -def _collect_recent_logs_text() -> Optional[str]: - cutoff = time.time() - 
86400 # 24 hours - log_entries = [] - for path in _find_log_files(): - try: - if os.path.getmtime(path) < cutoff: - continue - snippet = _read_log_snippet(path, BUG_REPORT_MAX_LOG_BYTES) - preview = snippet.get("content") or snippet.get("preview") - if not preview: - continue - log_entries.append( - f"===== {path} (last {len(preview)} chars) =====\n{preview}\n" - ) - except Exception as exc: - logger.debug(f"收集日志文本失败 {path}: {exc}") - continue - - if not log_entries: - return None - return "\n".join(log_entries) - - -def _encode_attachment_from_bytes(filename: str, file_bytes: bytes, content_type: str) -> Dict[str, Any]: - """ - 从字节数据编码附件(参考测试脚本的 _encode_attachment) - - Args: - filename: 文件名 - file_bytes: 文件字节数据 - content_type: MIME类型 - - Returns: - 编码后的附件字典 - """ - # 如果无法确定MIME类型,根据扩展名手动设置(参考测试脚本) - mime_type = content_type - if not mime_type: - filename_lower = filename.lower() - # 处理 .tar.gz 双扩展名 - if filename_lower.endswith('.tar.gz'): - mime_type = 'application/gzip' - else: - ext = os.path.splitext(filename_lower)[1] - mime_type_map = { - # 图片 - '.png': 'image/png', - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.gif': 'image/gif', - '.bmp': 'image/bmp', - '.webp': 'image/webp', - '.svg': 'image/svg+xml', - '.ico': 'image/x-icon', - '.tiff': 'image/tiff', - '.tif': 'image/tiff', - # 文本 - '.txt': 'text/plain', - '.log': 'text/plain', - '.md': 'text/markdown', - '.json': 'application/json', - '.xml': 'application/xml', - '.yaml': 'application/x-yaml', - '.yml': 'application/x-yaml', - '.csv': 'text/csv', - # 文档 - '.pdf': 'application/pdf', - '.doc': 'application/msword', - '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - '.xls': 'application/vnd.ms-excel', - '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - '.ppt': 'application/vnd.ms-powerpoint', - '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - # 压缩包 - '.zip': 'application/zip', - '.rar': 
'application/x-rar-compressed', - '.7z': 'application/x-7z-compressed', - '.tar': 'application/x-tar', - '.gz': 'application/gzip', - '.tgz': 'application/gzip', - '.bz2': 'application/x-bzip2', - '.xz': 'application/x-xz', - } - mime_type = mime_type_map.get(ext, "application/octet-stream") - - # Base64 编码 - encoded = base64.b64encode(file_bytes).decode("ascii") - - # 返回格式:与测试脚本完全一致 - return { - "name": filename, - "type": mime_type, - "data": f"data:{mime_type};base64,{encoded}", - } - - -async def _send_bug_report( - bug_fields: Dict[str, Any], - attachment_dict: Optional[Dict[str, Any]] -) -> Dict[str, Any]: - """ - 发送Bug报告到服务器(完全参考测试脚本的 send_bug 函数) - - Args: - bug_fields: Bug字段字典 - attachment_dict: 单个附件字典(可选) - - Returns: - 结果字典 {"success": bool, "message": str, "data": dict} - """ - if not BUG_CLOUD_FUNCTION_URL: - return {"success": False, "message": "服务器地址未配置"} - - # 构建payload - 与测试脚本完全一致 - payload: Dict[str, Any] = { - "verifyCode": BUG_CLOUD_VERIFY_CODE, - "bugData": bug_fields, - } - - # 单个附件 - 使用 "attachment" 字段(单数) - if attachment_dict: - payload["attachment"] = attachment_dict - logger.info(f"Payload包含附件: name={attachment_dict.get('name')}, type={attachment_dict.get('type')}") - - logger.info(f"发送Bug到服务器: {BUG_CLOUD_FUNCTION_URL}") - logger.debug(f"Payload keys: {list(payload.keys())}, bugData keys: {list(bug_fields.keys())}") - - timeout = aiohttp.ClientTimeout(total=BUG_REPORT_TIMEOUT_SECONDS) - - try: - # 参考测试脚本:显式设置 Content-Type 并手动序列化 JSON - headers = {"Content-Type": "application/json"} - payload_json = json.dumps(payload, ensure_ascii=False) - - logger.debug(f"发送的JSON长度: {len(payload_json)} 字节") - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(BUG_CLOUD_FUNCTION_URL, data=payload_json, headers=headers) as resp: - text = await resp.text() - logger.info(f"服务器响应: status={resp.status}, text_length={len(text)}") - - if resp.status in (200, 201): - try: - data = await resp.json() - logger.info(f"Bug提交成功: 
{data}") - return {"success": True, "data": data} - except Exception as e: - logger.warning(f"解析响应JSON失败: {e}, 使用原始文本") - return {"success": True, "data": {"raw": text}} - else: - logger.error(f"Bug提交失败: status={resp.status}, response={text[:500]}") - return { - "success": False, - "status": resp.status, - "message": text[:2000] - } - except Exception as e: - logger.error(f"发送Bug请求异常: {e}", exc_info=True) - return {"success": False, "message": f"请求异常: {str(e)}"} - -# 学习内容缓存 -_style_learning_content_cache: Optional[Dict[str, Any]] = None -_style_learning_content_cache_time: Optional[float] = None -_style_learning_content_cache_ttl: int = 300 # 缓存有效期5分钟 - -# 设置日志 -# logger = logging.getLogger(__name__) - -# 性能指标存储 -llm_call_metrics: Dict[str, Dict[str, Any]] = {} - -def load_password_config() -> Dict[str, Any]: - """加载密码配置文件,并自动迁移旧格式""" - password_file_path = get_password_file_path() - if os.path.exists(password_file_path): - with open(password_file_path, 'r', encoding='utf-8') as f: - config = json.load(f) - - # 检查是否需要迁移到新的哈希格式 - if 'password_hash' not in config and 'password' in config: - logger.info("检测到旧格式密码配置,正在迁移到哈希格式...") - config = migrate_password_to_hashed(config) - # 保存迁移后的配置 - save_password_config(config) - logger.info("密码配置迁移完成") - - return config - - # 创建默认配置(使用新的哈希格式) - default_password = "self_learning_pwd" - password_hash, salt = PasswordHasher.hash_password(default_password) - return { - "password_hash": password_hash, - "salt": salt, - "must_change": True, - "version": 2 - } - -def save_password_config(config: Dict[str, Any]): - """保存密码配置文件""" - password_file_path = get_password_file_path() - # 确保目录存在 - os.makedirs(os.path.dirname(password_file_path), exist_ok=True) - with open(password_file_path, 'w', encoding='utf-8') as f: - json.dump(config, f, indent=2) - -def require_auth(f): - """登录验证装饰器""" - @wraps(f) - async def decorated_function(*args, **kwargs): - if not session.get('authenticated'): - if request.is_json: - return jsonify({"error": 
"Authentication required", "redirect": "/api/login"}), 401 - return redirect(url_for('api.login_page')) - return await f(*args, **kwargs) - return decorated_function - -# 创建别名以保持向后兼容 -login_required = require_auth - -def is_authenticated(): - """检查用户是否已认证""" - return session.get('authenticated', False) - -async def set_plugin_services( - config: PluginConfig, - factory_manager: FactoryManager, - llm_c = None, # 不再使用LLMClient - astrbot_persona_manager = None, # 添加AstrBot PersonaManager参数 - group_id_to_unified_origin_map = None # 多配置文件支持 -): - """设置插件服务实例""" - global plugin_config, persona_manager, persona_updater, database_manager, db_manager, llm_client, llm_adapter_instance, pending_updates, intelligence_metrics_service, group_id_to_unified_origin - plugin_config = config - if group_id_to_unified_origin_map is not None: - group_id_to_unified_origin = group_id_to_unified_origin_map - - # 将配置存储到app中,供API认证使用 - app.plugin_config = config - - # 使用工厂管理器获取LLM适配器 - try: - # 从ServiceFactory获取LLM适配器,而不是ComponentFactory - llm_client = factory_manager.get_service_factory().create_framework_llm_adapter() - llm_adapter_instance = llm_client # 设置llm_adapter_instance别名 - logger.info(f"从服务工厂获取LLM适配器: {type(llm_client)}") - except Exception as e: - logger.error(f"获取LLM适配器失败: {e}") - llm_client = llm_c # 回退到传入的客户端 - llm_adapter_instance = llm_client # 同步设置别名 - - # 总是创建PersonaWebManager,无论是否传入AstrBot PersonaManager - try: - if astrbot_persona_manager: - persona_manager = astrbot_persona_manager - logger.info(f"设置AstrBot PersonaManager: {type(astrbot_persona_manager)}") - else: - logger.warning("未传入AstrBot PersonaManager,将创建空的PersonaWebManager") - # 从工厂管理器获取服务实例 - try: - persona_manager = factory_manager.get_service("persona_manager") - except Exception as e: - logger.error(f"获取persona_manager服务失败: {e}") - persona_manager = None - - # 总是初始化人格Web管理器(即使PersonaManager为None) - persona_web_mgr = set_persona_web_manager(astrbot_persona_manager) - # 传递 group_id_to_unified_origin 
映射引用(多配置文件支持) - if group_id_to_unified_origin_map is not None: - persona_web_mgr.group_id_to_unified_origin = group_id_to_unified_origin_map - logger.info(f"创建PersonaWebManager: {persona_web_mgr}") - await persona_web_mgr.initialize() - logger.info("PersonaWebManager初始化成功") - except Exception as e: - logger.error(f"PersonaWebManager初始化失败: {e}", exc_info=True) - # 即使初始化失败,也要创建一个空的PersonaWebManager以避免500错误 - try: - set_persona_web_manager(None) - logger.info("创建了空的PersonaWebManager作为后备方案") - except Exception as fallback_e: - logger.error(f"创建后备PersonaWebManager失败: {fallback_e}") - - # 从工厂管理器获取其他服务实例 - try: - logger.info("开始初始化WebUI服务...") - - # 使用更直接的方法获取服务 - service_factory = factory_manager.get_service_factory() - logger.info("成功获取服务工厂") - - # 获取人格更新器 - logger.info("正在获取人格更新器...") - try: - persona_updater = service_factory.get_persona_updater() - logger.info(f"✅ 成功获取persona_updater: {type(persona_updater)}") - except Exception as e: - logger.error(f"❌ 获取persona_updater失败: {e}", exc_info=True) - persona_updater = None - - # 确保数据库管理器已创建 - logger.info("正在获取数据库管理器...") - try: - # 先尝试直接从factory_manager获取 - database_manager = factory_manager.get_service("database_manager") - if not database_manager: - logger.warning("从factory_manager.get_service获取database_manager为None,尝试创建") - service_factory.create_database_manager() - database_manager = factory_manager.get_service("database_manager") - - db_manager = database_manager # 设置别名 - logger.info(f"✅ 成功获取database_manager: {type(database_manager)}") - except Exception as e: - logger.error(f"❌ 获取database_manager失败: {e}", exc_info=True) - database_manager = None - db_manager = None - - # 获取progressive_learning服务 - logger.info("正在获取progressive_learning服务...") - try: - progressive_learning = factory_manager.get_service("progressive_learning") - logger.info(f"✅ 成功获取progressive_learning: {type(progressive_learning)}") - except Exception as e: - logger.error(f"❌ 获取progressive_learning失败: {e}", exc_info=True) - progressive_learning = 
None - - # 关键修复:设置全局变量! - logger.info("设置全局变量...") - globals()['persona_updater'] = persona_updater - globals()['database_manager'] = database_manager - globals()['db_manager'] = database_manager - globals()['progressive_learning'] = progressive_learning - - # 初始化数据库适配器 - if database_manager: - logger.info("初始化数据库管理器适配层...") - globals()['db_adapter'] = DatabaseManagerAdapter(database_manager) - logger.info(f"✅ 数据库适配器已初始化,类型: {type(database_manager).__name__}") - else: - logger.warning("⚠️ 数据库管理器不可用,适配器未初始化") - - logger.info(f"全局变量设置完成:") - logger.info(f" - persona_updater: {globals().get('persona_updater') is not None}") - logger.info(f" - database_manager: {globals().get('database_manager') is not None}") - logger.info(f" - progressive_learning: {globals().get('progressive_learning') is not None}") - - if not database_manager: - logger.error("⚠️ 警告: database_manager为None,WebUI人格审查功能将不可用!") - - # 初始化智能指标计算服务 - logger.info("正在初始化智能指标计算服务...") - intelligence_metrics_service = IntelligenceMetricsService( - config=config, - db_manager=database_manager - ) - globals()['intelligence_metrics_service'] = intelligence_metrics_service - logger.info("智能指标计算服务初始化成功") - - except Exception as e: - logger.error(f"获取服务实例失败: {e}", exc_info=True) - globals()['persona_updater'] = None - globals()['database_manager'] = None - globals()['db_manager'] = None - globals()['progressive_learning'] = None - - # 加载待审查的人格更新 - if persona_updater: - try: - pending_updates = await persona_updater.get_pending_persona_updates() - except Exception as e: - logger.error(f"加载待审查人格更新失败: {e}") - pending_updates = [] - - # 加载密码配置 - global password_config - password_config = load_password_config() - -# API 蓝图 -api_bp = Blueprint("api", __name__, url_prefix="/api") - -@api_bp.route("/") -async def read_root(): - """根目录重定向""" - global password_config - password_config = load_password_config() # 每次访问根目录时重新加载密码配置,确保最新状态 - - # 如果用户已认证,检查是否需要强制更改密码 - if is_authenticated(): - if 
password_config.get("must_change"): - return redirect("/api/plugin_change_password") - return redirect(url_for("api.read_root_index")) - - # 未认证用户重定向到登录页 - return redirect(url_for("api.login_page")) - -@api_bp.route("/login", methods=["GET"]) -async def login_page(): - """显示登录页面""" - # 如果已登录,重定向到主页 - if is_authenticated(): - return redirect("/api/") - return await render_template("login.html") - -@api_bp.route("/login", methods=["POST"]) -async def login(): - """处理用户登录 - 支持MD5加密和暴力破解防护""" - # 获取客户端IP - client_ip = request.remote_addr or "unknown" - - # 检查IP是否被锁定 - is_locked, remaining_time = login_attempt_tracker.is_locked(client_ip) - if is_locked: - logger.warning(f"IP {client_ip} 被锁定,剩余 {remaining_time} 秒") - return jsonify({ - "error": f"登录尝试次数过多,请在 {remaining_time} 秒后重试", - "locked": True, - "remaining_time": remaining_time - }), 429 - - data = await request.get_json() - password = data.get("password", "") - - # 清理输入 - password = SecurityValidator.sanitize_input(password, max_length=128) - - if not password: - return jsonify({"error": "密码不能为空"}), 400 - - global password_config - password_config = load_password_config() - - # 使用支持迁移的验证函数 - is_valid, updated_config = verify_password_with_migration(password, password_config) - - if is_valid: - # 如果配置被更新(迁移),保存新配置 - if updated_config != password_config: - save_password_config(updated_config) - password_config = updated_config - - # 登录成功,清除失败记录 - login_attempt_tracker.record_attempt(client_ip, success=True) - - # 设置会话认证状态 - session['authenticated'] = True - session.permanent = True - - if password_config.get("must_change"): - return jsonify({ - "message": "Login successful, but password must be changed", - "must_change": True, - "redirect": "/api/plugin_change_password" - }), 200 - return jsonify({ - "message": "Login successful", - "must_change": False, - "redirect": "/api/index" - }), 200 - - # 登录失败,记录尝试 - login_attempt_tracker.record_attempt(client_ip, success=False) - remaining_attempts = 
login_attempt_tracker.get_remaining_attempts(client_ip) - - logger.warning(f"IP {client_ip} 登录失败,剩余尝试次数: {remaining_attempts}") - - error_msg = "密码错误" - if remaining_attempts <= 2: - error_msg = f"密码错误,还剩 {remaining_attempts} 次尝试机会" - - return jsonify({ - "error": error_msg, - "remaining_attempts": remaining_attempts - }), 401 - -@api_bp.route("/index") -@require_auth -async def read_root_index(): - """主页面""" - return await render_template("index.html") - -@api_bp.route("/plugin_change_password", methods=["GET"]) -async def change_password_page(): - """显示修改密码页面""" - # 检查是否已认证或者是强制更改密码状态 - if not is_authenticated(): - return redirect(url_for('api.login_page')) - - # 添加调试信息 - logger.debug(f"Template folder: {WEB_HTML_DIR}") - logger.debug(f"Looking for template: change_password.html") - template_path = os.path.join(WEB_HTML_DIR, "change_password.html") - logger.debug(f"Full template path: {template_path}") - logger.debug(f"Template exists: {os.path.exists(template_path)}") - - return await render_template("change_password.html") - -@api_bp.route("/plugin_change_password", methods=["POST"]) -async def change_password(): - """处理修改密码请求 - 支持MD5加密存储""" - # 检查是否已认证 - if not is_authenticated(): - return jsonify({"error": "Authentication required", "redirect": "/api/login"}), 401 - - data = await request.get_json() - old_password = data.get("old_password", "") - new_password = data.get("new_password", "") - - # 清理输入 - old_password = SecurityValidator.sanitize_input(old_password, max_length=128) - new_password = SecurityValidator.sanitize_input(new_password, max_length=128) - - if not old_password or not new_password: - return jsonify({"error": "旧密码和新密码不能为空"}), 400 - - global password_config - password_config = load_password_config() - - # 验证旧密码 - is_valid, _ = verify_password_with_migration(old_password, password_config) - if not is_valid: - return jsonify({"error": "当前密码错误"}), 401 - - # 检查新密码是否与旧密码相同 - if old_password == new_password: - return jsonify({"error": 
"新密码不能与当前密码相同"}), 400 - - # 验证新密码强度 - strength_result = SecurityValidator.validate_password_strength(new_password) - if not strength_result['valid']: - issues = "、".join(strength_result['issues']) if strength_result['issues'] else "密码强度不足" - return jsonify({"error": issues}), 400 - - # 生成新的哈希密码 - password_hash, salt = PasswordHasher.hash_password(new_password) - - # 更新配置 - password_config = { - "password_hash": password_hash, - "salt": salt, - "must_change": False, - "version": 2, - "last_changed": time.time() - } - save_password_config(password_config) - - logger.info("密码已更新为MD5哈希格式") - return jsonify({"message": "密码修改成功"}), 200 - -@api_bp.route("/logout", methods=["POST"]) -@require_auth -async def logout(): - """处理用户登出""" - session.clear() - return jsonify({"message": "Logged out successfully", "redirect": "/api/login"}), 200 - -@api_bp.route("/config") -@require_auth -async def get_plugin_config(): - """获取插件配置""" - if plugin_config: - return jsonify(asdict(plugin_config)) - return jsonify({"error": "Plugin config not initialized"}), 500 - -@api_bp.route("/config", methods=["POST"]) -@require_auth -async def update_plugin_config(): - """更新插件配置""" - if plugin_config: - new_config = await request.get_json() - for key, value in new_config.items(): - if hasattr(plugin_config, key): - setattr(plugin_config, key, value) - # TODO: 保存配置到文件 - return jsonify({"message": "Config updated successfully", "new_config": asdict(plugin_config)}) - return jsonify({"error": "Plugin config not initialized"}), 500 - - -@api_bp.route("/bug_report/config", methods=["GET"]) -@require_auth -async def get_bug_report_config(): - """获取Bug自助提交配置与日志预览""" - enabled = _bug_report_available() - log_preview = _collect_log_previews() - return jsonify({ - "enabled": enabled, - "cloudFunctionUrl": BUG_CLOUD_FUNCTION_URL, - "severityOptions": BUG_REPORT_SEVERITY_OPTIONS, - "priorityOptions": BUG_REPORT_PRIORITY_OPTIONS, - "typeOptions": BUG_REPORT_TYPE_OPTIONS, - "defaultBuild": 
BUG_REPORT_DEFAULT_BUILDS[0] if BUG_REPORT_DEFAULT_BUILDS else "", - "maxImages": 0 if not BUG_REPORT_ATTACHMENT_ENABLED else BUG_REPORT_MAX_IMAGES, # 禁用附件时为0 - "maxImageBytes": BUG_REPORT_MAX_IMAGE_BYTES, - "allowedExtensions": sorted(list(BUG_REPORT_ALLOWED_EXTENSIONS)) if BUG_REPORT_ATTACHMENT_ENABLED else [], - "attachmentEnabled": BUG_REPORT_ATTACHMENT_ENABLED, # 新增:告诉前端是否启用附件 - "logPreview": log_preview, - "message": "Bug自助提交通过云函数转发(暂不支持附件上传)" if enabled else "Bug自助提交功能暂不可用,请联系管理员" - }) - - -@api_bp.route("/bug_report", methods=["POST"]) -@require_auth -async def submit_bug_report(): - """提交Bug到禅道接口""" - if not _bug_report_available(): - return jsonify({"error": "Bug提交未配置或已禁用"}), 400 - - try: - form = await request.form - files = await request.files - except Exception as exc: - logger.error(f"解析Bug提交数据失败: {exc}") - return jsonify({"error": "提交内容解析失败"}), 400 - - title = (form.get("title") or "").strip() or "未命名问题" - severity = int(form.get("severity") or BUG_REPORT_DEFAULT_SEVERITY) - priority = int(form.get("priority") or BUG_REPORT_DEFAULT_PRIORITY) - bug_type = (form.get("bugType") or BUG_REPORT_DEFAULT_TYPE).strip() - build = (form.get("build") or (BUG_REPORT_DEFAULT_BUILDS[0] if BUG_REPORT_DEFAULT_BUILDS else "unknown")).strip() - steps = (form.get("steps") or "").strip() - description = (form.get("description") or "").strip() - environment = (form.get("environment") or "").strip() - include_logs = (form.get("includeLogs") or "true").lower() in ("1", "true", "yes", "on") - - request_meta = f"IP: {request.remote_addr or 'unknown'}\nUser-Agent: {request.headers.get('User-Agent', 'unknown')}" - full_description = description or "(未提供描述)" - if environment: - full_description += f"\n\n【运行环境】\n{environment}" - full_description += f"\n\n【请求元信息】\n{request_meta}" - - bug_fields = { - "title": title, - "severity": severity, - "pri": priority, - "type": bug_type, - "openedBuild": [build], - "steps": steps or "暂无明确的复现步骤", - "description": full_description, - 
"openedBy": "astrbot_plugin_self_learning" - } - - raw_attachments: List[Dict[str, Any]] = [] - - # 处理上传的文件 - # 检查附件功能是否启用 - if files and files.getlist("attachments") and not BUG_REPORT_ATTACHMENT_ENABLED: - return jsonify({"error": "附件上传功能暂时不可用,请稍后再试"}), 400 - - upload_list = files.getlist("attachments") if files else [] - for file_storage in upload_list: - if not file_storage: - continue - - original_filename = file_storage.filename or f"screenshot_{int(time.time())}.png" - filename = secure_filename(original_filename) - mimetype = file_storage.mimetype or "" - - # 安全检查:验证文件类型 - is_safe, error_msg = _is_safe_attachment(filename, mimetype) - if not is_safe: - logger.warning(f"拒绝不安全的附件上传: {filename}, 原因: {error_msg}") - return jsonify({"error": f"附件安全检查失败: {error_msg}"}), 400 - - file_bytes = await file_storage.read() - if not file_bytes: - continue - if len(file_bytes) > BUG_REPORT_MAX_IMAGE_BYTES: - return jsonify({"error": f"单个附件不能超过 {BUG_REPORT_MAX_IMAGE_BYTES // (1024 * 1024)}MB"}), 400 - raw_attachments.append({ - "filename": filename or "screenshot.png", - "content_type": file_storage.mimetype or "image/png", - "data": file_bytes - }) - if len(raw_attachments) >= BUG_REPORT_MAX_IMAGES: - break - - try: - # 自动附带日志摘要到描述中 - if include_logs: - log_previews = _collect_log_previews(limit=2, include_content=True) - if log_previews: - log_text_sections = ["\n\n【自动附带日志摘要】"] - for log in log_previews: - content = log.get("content", "") - if not content: - continue - tail = content[-BUG_REPORT_MAX_LOG_BYTES:] - log_text_sections.append(f"--- {log['path']} | 最近 {len(tail)} 字节 ---\n{tail}") - if len(log_text_sections) > 1: - full_description += "\n".join(log_text_sections) - - bug_fields["description"] = full_description - - # 使用新的编码函数处理附件(参考测试脚本) - attachment_dict = None - if raw_attachments: - # 只取第一个附件 - first_attachment = raw_attachments[0] - logger.info(f"准备编码附件: filename={first_attachment['filename']}, size={len(first_attachment['data'])} bytes, 
type={first_attachment['content_type']}") - - try: - attachment_dict = _encode_attachment_from_bytes( - filename=first_attachment["filename"], - file_bytes=first_attachment["data"], - content_type=first_attachment["content_type"] - ) - logger.info(f"附件编码成功: name={attachment_dict['name']}, type={attachment_dict['type']}, data_length={len(attachment_dict['data'])}") - except Exception as e: - logger.error(f"附件编码失败: {e}", exc_info=True) - return jsonify({"error": f"附件编码失败: {str(e)}"}), 500 - - # 如果有多个附件,添加警告 - if len(raw_attachments) > 1: - warning_msg = f"\n\n⚠️ 注意:检测到 {len(raw_attachments)} 个附件,但服务器支持单个附件。仅第一个附件 '{first_attachment['filename']}' 将被提交。如需提交多个文件,建议打包为压缩包后上传。" - bug_fields["description"] += warning_msg - logger.warning(f"Bug提交包含多个附件({len(raw_attachments)}个),只会提交第一个: {first_attachment['filename']}") - - # 调用发送函数(完全参考测试脚本) - logger.info(f"准备发送Bug报告: has_attachment={attachment_dict is not None}") - result = await _send_bug_report(bug_fields, attachment_dict) - logger.info(f"Bug提交结果: success={result.get('success')}, status={result.get('status')}, message={result.get('message', '')[:200]}") - if result.get("success"): - data = result.get("data", {}) - bug_id = data.get("id") - return jsonify({ - "success": True, - "bugId": bug_id, - "message": f"Bug提交成功 (ID: {bug_id})" if bug_id else "Bug提交成功", - "response": data - }) - return jsonify({ - "error": result.get("message", "Bug提交失败"), - "status": result.get("status") - }), 502 - except Exception as exc: - logger.error(f"Bug提交异常: {exc}", exc_info=True) - return jsonify({"error": f"Bug提交异常: {exc}"}), 500 - -@api_bp.route("/persona_updates") -@require_auth -async def get_persona_updates(): - """获取需要人工审查的人格更新内容(包括风格学习审查和人格学习审查)- 支持分页""" - # 获取分页参数 - 默认每页50条记录,实现懒加载 - limit = request.args.get('limit', default=50, type=int) - offset = request.args.get('offset', default=0, type=int) - - logger.info(f"开始获取persona_updates数据... limit={limit}, offset={offset}") - all_updates = [] - - # 1. 
获取传统的人格更新审查 - if persona_updater: - try: - logger.info("正在获取传统人格更新...") - traditional_updates = await persona_updater.get_pending_persona_updates() - logger.info(f"获取到 {len(traditional_updates)} 个传统人格更新") - - # 将PersonaUpdateRecord对象转换为字典格式,确保数据完整 - for record in traditional_updates: - # 使用dataclass的asdict或手动转换 - if hasattr(record, '__dict__'): - record_dict = record.__dict__.copy() - else: - # 手动构建字典 - record_dict = { - 'id': getattr(record, 'id', None), - 'timestamp': getattr(record, 'timestamp', 0), - 'group_id': getattr(record, 'group_id', 'default'), - 'update_type': getattr(record, 'update_type', 'unknown'), - 'original_content': getattr(record, 'original_content', ''), - 'new_content': getattr(record, 'new_content', ''), - 'reason': getattr(record, 'reason', ''), - 'status': getattr(record, 'status', 'pending'), - 'reviewer_comment': getattr(record, 'reviewer_comment', None), - 'review_time': getattr(record, 'review_time', None) - } - - # 添加一些前端需要的字段 - record_dict['proposed_content'] = record_dict.get('new_content', '') - record_dict['confidence_score'] = 0.8 # 默认置信度 - record_dict['reviewed'] = record_dict.get('status', 'pending') != 'pending' - record_dict['approved'] = record_dict.get('status', 'pending') == 'approved' - record_dict['review_source'] = 'traditional' # 标记来源 - - all_updates.append(record_dict) - - except Exception as e: - logger.error(f"获取传统人格更新失败: {e}") - else: - logger.warning("persona_updater 不可用") - - # 2. 
获取人格学习审查(包括渐进式学习、表达学习等) - if database_manager: - try: - logger.info("正在获取人格学习审查...") - # ✅ 懒加载优化:计算需要加载多少条记录(考虑分页) - # 保守估计:加载 offset + limit * 1.5 条记录,以应对可能的过滤 - fetch_limit = min(offset + int(limit * 1.5), 1000) # 最多加载1000条 - persona_learning_reviews = await database_manager.get_pending_persona_learning_reviews(limit=fetch_limit) - logger.info(f"获取到 {len(persona_learning_reviews)} 个人格学习审查") - - for review in persona_learning_reviews: - # ✅ 使用新的常量进行类型标准化和分类 - raw_update_type = review.get('update_type', '') - normalized_type = normalize_update_type(raw_update_type) - review_source = get_review_source_from_update_type(raw_update_type) - - # ✅ 修复:只跳过真正的风格学习(精确匹配) - # 渐进式人格学习不再被误判为风格学习 - if normalized_type == UPDATE_TYPE_STYLE_LEARNING: - # Few-shot风格学习在步骤3单独处理,这里跳过 - logger.debug(f"跳过风格学习记录 ID={review['id']},在步骤3处理") - continue - - # ✅ 获取原人格文本(如果数据库中为空,实时获取) - original_content = review['original_content'] - group_id = review['group_id'] - - if not original_content or original_content.strip() == '': - # 数据库中没有原人格,实时获取 - logger.info(f"数据库中没有原人格文本,实时获取群组 {group_id} 的原人格") - try: - if persona_manager: - current_persona = await persona_manager.get_default_persona_v3(_resolve_umo(group_id)) - if current_persona and current_persona.get('prompt'): - original_content = current_persona.get('prompt', '') - logger.info(f"成功获取群组 {group_id} 的原人格文本,长度: {len(original_content)}") - else: - original_content = "[无法获取原人格文本]" - logger.warning(f"无法获取群组 {group_id} 的原人格文本") - else: - original_content = "[PersonaManager未初始化]" - logger.warning("PersonaManager未初始化,无法获取原人格") - except Exception as e: - logger.warning(f"获取群组 {group_id} 原人格失败: {e}") - original_content = f"[获取原人格失败: {str(e)}]" - - # 转换为统一的审查格式 - review_dict = { - # ✅ 根据review_source决定ID前缀 - 'id': f"persona_learning_{review['id']}" if review_source == 'persona_learning' else str(review['id']), - 'timestamp': review['timestamp'], - 'group_id': group_id, - 'update_type': raw_update_type, # 保留原始类型用于显示 - 'normalized_type': 
normalized_type, # 添加标准化类型 - 'original_content': original_content, # ✅ 使用获取到的原人格文本 - 'new_content': review['new_content'], - 'proposed_content': review.get('proposed_content', review['new_content']), - 'reason': review['reason'], - 'status': review['status'], - 'reviewer_comment': review['reviewer_comment'], - 'review_time': review['review_time'], - 'confidence_score': review.get('confidence_score', 0.5), - 'reviewed': False, - 'approved': False, - 'review_source': review_source, - 'persona_learning_review_id': review['id'], # 原始ID用于审批操作 - # 添加metadata中的关键字段到顶层,方便前端访问 - 'features_content': review.get('metadata', {}).get('features_content', ''), - 'llm_response': review.get('metadata', {}).get('llm_response', ''), - 'total_raw_messages': review.get('metadata', {}).get('total_raw_messages', 0), - 'messages_analyzed': review.get('metadata', {}).get('messages_analyzed', 0), - 'metadata': review.get('metadata', {}), # 保留完整的metadata - # ✅ 新增:从metadata提取高亮位置信息 - 'incremental_content': review.get('metadata', {}).get('incremental_content', ''), - 'incremental_start_pos': review.get('metadata', {}).get('incremental_start_pos', 0) - } - - all_updates.append(review_dict) - logger.debug(f"添加审查记录: ID={review_dict['id']}, type={raw_update_type}, source={review_source}") - - except Exception as e: - logger.error(f"获取人格学习审查失败: {e}", exc_info=True) - else: - logger.warning("database_manager 不可用") - - # 3. 
获取风格学习审查(Few-shot样本学习) - if database_manager: - try: - logger.info("正在获取风格学习审查...") - # ✅ 懒加载优化:计算需要加载多少条记录(考虑分页) - fetch_limit = min(offset + int(limit * 1.5), 1000) # 最多加载1000条 - style_reviews = await database_manager.get_pending_style_reviews(limit=fetch_limit) - logger.info(f"获取到 {len(style_reviews)} 个风格学习审查") - - for review in style_reviews: - # ✅ 获取当前群组的原人格文本 - group_id = review['group_id'] - original_persona_text = "" - - try: - # 通过 persona_manager 获取当前人格 - if persona_manager: - current_persona = await persona_manager.get_default_persona_v3(_resolve_umo(group_id)) - if current_persona and current_persona.get('prompt'): - original_persona_text = current_persona.get('prompt', '') - else: - original_persona_text = "[无法获取原人格文本]" - else: - original_persona_text = "[PersonaManager未初始化]" - except Exception as e: - logger.warning(f"获取群组 {group_id} 原人格失败: {e}") - original_persona_text = f"[获取原人格失败: {str(e)}]" - - # ✅ 构建完整的新内容(原人格 + Few-shot内容) - few_shots_content = review['few_shots_content'] - full_new_content = original_persona_text + "\n\n" + few_shots_content if original_persona_text else few_shots_content - - # 转换为统一的审查格式 - review_dict = { - 'id': f"style_{review['id']}", # 添加前缀避免ID冲突 - 'timestamp': review['timestamp'], - 'group_id': group_id, - 'update_type': UPDATE_TYPE_STYLE_LEARNING, # ✅ 使用常量 - 'normalized_type': UPDATE_TYPE_STYLE_LEARNING, - 'original_content': original_persona_text, # ✅ 使用实际的原人格文本 - 'new_content': full_new_content, # ✅ 原人格 + Few-shot内容 - 'proposed_content': few_shots_content, # 保持为增量部分 - 'reason': review['description'], - 'status': review['status'], - 'reviewer_comment': None, - 'review_time': None, - 'confidence_score': 0.9, # 风格学习置信度高一些 - 'reviewed': False, - 'approved': False, - 'review_source': 'style_learning', # 标记来源 - 'learned_patterns': review.get('learned_patterns', []), # 额外信息 - 'style_review_id': review['id'], # 原始ID用于审批操作 - # ✅ 新增:方便前端计算高亮位置 - 'incremental_start_pos': len(original_persona_text) + 2 if original_persona_text 
else 0 # +2 是因为有 \n\n - } - - all_updates.append(review_dict) - - except Exception as e: - logger.error(f"获取风格学习审查失败: {e}") - - # 按时间倒序排列 - all_updates.sort(key=lambda x: x.get('timestamp', 0), reverse=True) - - total_count = len(all_updates) - - # 应用分页 - if limit is not None: - end_index = offset + limit - paginated_updates = all_updates[offset:end_index] - logger.info(f"分页返回 {len(paginated_updates)}/{total_count} 条记录 (offset={offset}, limit={limit})") - else: - paginated_updates = all_updates - logger.info(f"返回全部 {total_count} 条记录(未分页)") - - logger.info(f"返回数据统计 - 传统: {len([u for u in paginated_updates if u['review_source'] == 'traditional'])}, 人格学习: {len([u for u in paginated_updates if u['review_source'] == 'persona_learning'])}, 风格学习: {len([u for u in paginated_updates if u['review_source'] == 'style_learning'])})") - - return jsonify({ - "success": True, - "updates": paginated_updates, - "total": total_count, - "offset": offset, - "limit": limit if limit is not None else total_count - }) - -@api_bp.route("/persona_updates//review", methods=["POST"]) -@require_auth -async def review_persona_update(update_id: str): - """审查人格更新内容 (批准/拒绝) - 包括风格学习审查和人格学习审查""" - try: - # 获取全局服务实例并进行调试检查 - global persona_updater, database_manager - - logger.info(f"=== 开始审查人格更新 {update_id} ===") - logger.info(f"全局persona_updater状态: {persona_updater is not None}") - logger.info(f"全局database_manager状态: {database_manager is not None}") - - if persona_updater: - logger.info(f"PersonaUpdater类型: {type(persona_updater)}") - logger.info(f"PersonaUpdater backup_manager状态: {hasattr(persona_updater, 'backup_manager')}") - if hasattr(persona_updater, 'backup_manager'): - logger.info(f"backup_manager类型: {type(persona_updater.backup_manager)}") - - if database_manager: - logger.info(f"DatabaseManager类型: {type(database_manager)}") - - data = await request.get_json() - action = data.get("action") - comment = data.get("comment", "") - modified_content = data.get("modified_content") # 用户修改后的内容 - - 
logger.info(f"审查操作: {action}, 有修改内容: {modified_content is not None}") - - # 将action转换为合适的status - if action == "approve": - status = "approved" - elif action == "reject": - status = "rejected" - else: - return jsonify({"error": "Invalid action, must be 'approve' or 'reject'"}), 400 - - # 判断审查类型 - if update_id.startswith("style_"): - # 风格学习审查 - style_review_id = int(update_id.replace("style_", "")) - - if action == "approve": - # 批准风格学习审查 - return await approve_style_learning_review(style_review_id) - else: - # 拒绝风格学习审查 - return await reject_style_learning_review(style_review_id) - - elif update_id.startswith("persona_learning_"): - # 人格学习审查(质量不达标的学习结果) - persona_learning_review_id = int(update_id.replace("persona_learning_", "")) - - if not database_manager: - return jsonify({"error": "Database manager not initialized"}), 500 - - # 更新审查状态,并保存修改后的内容和审查备注 - success = await database_manager.update_persona_learning_review_status( - persona_learning_review_id, status, comment, modified_content - ) - - if success: - if action == "approve": - # 批准后应用人格更新并备份 - try: - # 获取人格学习审查详情 - review_data = await database_manager.get_persona_learning_review_by_id(persona_learning_review_id) - if review_data: - # 使用修改后的内容(如果有)或原始proposed_content - content_to_apply = modified_content if modified_content else review_data.get('proposed_content') - group_id = review_data.get('group_id', 'default') - message = f"人格学习审查 {persona_learning_review_id} 已批准" - - # ===== 自动应用到框架默认人格(独立于persona_updater) ===== - auto_apply_enabled = plugin_config and getattr(plugin_config, 'auto_apply_approved_persona', False) - logger.info(f"[自动应用] 检查配置: auto_apply={auto_apply_enabled}, persona_manager={persona_manager is not None}, content={content_to_apply is not None and len(content_to_apply) if content_to_apply else 0}") - if content_to_apply and auto_apply_enabled and persona_manager: - try: - umo = _resolve_umo(group_id) - current_persona = await persona_manager.get_default_persona_v3(umo) - if 
current_persona: - p_name = current_persona.get('name', 'default') - logger.info(f"[自动应用] 准备更新默认人格 [{p_name}],内容长度: {len(content_to_apply)},群组: {group_id}") - await persona_manager.update_persona( - persona_id=p_name, - system_prompt=content_to_apply - ) - logger.info(f"[自动应用] ✅ 已将人格学习审查内容应用到默认人格 [{p_name}]") - message += f",已自动应用到默认人格 [{p_name}]" - else: - logger.warning("[自动应用] 无法获取当前默认人格") - except Exception as auto_err: - logger.error(f"[自动应用] ❌ 应用到默认人格失败: {auto_err}", exc_info=True) - message += f",但自动应用到默认人格失败: {str(auto_err)}" - - # ===== 原有的update_persona_with_style逻辑(备份+内存更新) ===== - if persona_updater and content_to_apply: - try: - logger.info(f"开始应用人格学习审查 {persona_learning_review_id},群组: {group_id}") - style_analysis = { - 'enhanced_prompt': content_to_apply, - 'style_features': [], - 'style_attributes': {}, - 'confidence': 0.8, - 'source': f'人格学习审查{persona_learning_review_id}' - } - success_apply = await persona_updater.update_persona_with_style( - group_id, style_analysis, [] - ) - if success_apply: - logger.info(f"✅ 人格学习审查 {persona_learning_review_id} 备份和内存更新完成") - else: - logger.warning(f"❌ 人格学习审查 {persona_learning_review_id} update_persona_with_style返回False") - except Exception as apply_error: - logger.error(f"❌ update_persona_with_style失败: {apply_error}", exc_info=True) - - else: - logger.error(f"无法获取人格学习审查 {persona_learning_review_id} 的详情") - message = f"人格学习审查 {persona_learning_review_id} 已批准,但无法获取详情" - except Exception as e: - logger.error(f"应用人格学习审查失败: {e}", exc_info=True) - message = f"人格学习审查 {persona_learning_review_id} 已批准,但应用过程出错: {str(e)}" - else: - message = f"人格学习审查 {persona_learning_review_id} 已拒绝" - - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": "Failed to update persona learning review status"}), 500 - - else: - # 传统人格审查 - if persona_updater: - # 传递modified_content参数 - result = await persona_updater.review_persona_update(int(update_id), status, comment, modified_content) - if result: - 
return jsonify({"success": True, "message": f"人格更新 {update_id} 已{action}"}) - else: - return jsonify({"error": "Failed to update persona review status"}), 500 - else: - return jsonify({"error": "Persona updater not initialized"}), 500 - - except ValueError as e: - return jsonify({"error": f"Invalid update_id format: {str(e)}"}), 400 - except Exception as e: - logger.error(f"审查人格更新失败: {e}") - return jsonify({"error": str(e)}), 500 - -@api_bp.route("/persona_updates/reviewed", methods=["GET"]) -@require_auth -async def get_reviewed_persona_updates(): - """获取已审查的人格更新列表""" - try: - limit = request.args.get('limit', 50) - offset = request.args.get('offset', 0) - status_filter = request.args.get('status') # 'approved' 或 'rejected' 或 None - - # 获取已审查的人格更新记录 - reviewed_updates = [] - - # 从传统人格更新审查获取 - if persona_updater: - traditional_updates = await persona_updater.get_reviewed_persona_updates(limit, offset, status_filter) - reviewed_updates.extend(traditional_updates) - - # 从人格学习审查获取 - if database_manager: - persona_learning_updates = await database_manager.get_reviewed_persona_learning_updates(limit, offset, status_filter) - reviewed_updates.extend(persona_learning_updates) - - # 从风格学习审查获取 - if database_manager: - style_updates = await database_manager.get_reviewed_style_learning_updates(limit, offset, status_filter) - # 将风格审查转换为统一格式 - for update in style_updates: - if 'id' in update: - update['id'] = f"style_{update['id']}" - reviewed_updates.extend(style_updates) - - # 按审查时间排序 - reviewed_updates.sort(key=lambda x: x.get('review_time', 0), reverse=True) - - return jsonify({ - "success": True, - "updates": reviewed_updates, - "total": len(reviewed_updates) - }) - - except Exception as e: - logger.error(f"获取已审查人格更新失败: {e}") - return jsonify({"error": str(e)}), 500 - -@api_bp.route("/persona_updates//revert", methods=["POST"]) -@require_auth -async def revert_persona_update(update_id: str): - """撤回人格更新审查""" - try: - data = await request.get_json() - reason = 
data.get("reason", "撤回审查决定") - - # 判断撤回类型 - if update_id.startswith("style_"): - # 风格学习审查撤回 - style_review_id = int(update_id.replace("style_", "")) - - if not database_manager: - return jsonify({"error": "Database manager not initialized"}), 500 - - # 将状态改回pending - success = await database_manager.update_style_review_status( - style_review_id, "pending" - ) - - if success: - message = f"风格学习审查 {style_review_id} 已撤回,重新回到待审查状态" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": "Failed to revert style learning review"}), 500 - - elif update_id.startswith("persona_learning_"): - # 人格学习审查撤回 - persona_learning_review_id = int(update_id.replace("persona_learning_", "")) - - if not database_manager: - return jsonify({"error": "Database manager not initialized"}), 500 - - # 将状态改回pending - success = await database_manager.update_persona_learning_review_status( - persona_learning_review_id, "pending", f"撤回操作: {reason}" - ) - - if success: - message = f"人格学习审查 {persona_learning_review_id} 已撤回,重新回到待审查状态" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": "Failed to revert persona learning review"}), 500 - else: - # 传统人格审查撤回 - if persona_updater: - result = await persona_updater.revert_persona_update_review(int(update_id), reason) - if result: - message = f"人格更新 {update_id} 审查已撤回,重新回到待审查状态" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": "Failed to revert persona update review"}), 500 - else: - return jsonify({"error": "Persona updater not initialized"}), 500 - - except ValueError as e: - return jsonify({"error": f"Invalid update_id format: {str(e)}"}), 400 - except Exception as e: - logger.error(f"撤回人格更新审查失败: {e}") - return jsonify({"error": str(e)}), 500 - -# 删除人格更新审查记录 -@api_bp.route("/persona_updates//delete", methods=["POST"]) -@require_auth -async def delete_persona_update(update_id): - """删除人格更新审查记录""" - try: - # 使用全局变量而不是 
current_app.plugin_instance - global database_manager, persona_updater - if not database_manager: - return jsonify({"error": "Database manager not available"}), 500 - - # 解析update_id,处理前缀(persona_learning_、style_) - if isinstance(update_id, str): - if update_id.startswith("persona_learning_"): - numeric_id = int(update_id.replace("persona_learning_", "")) - # 删除人格学习审查记录 - success = await database_manager.delete_persona_learning_review_by_id(numeric_id) - if success: - message = f"人格学习审查记录 {numeric_id} 已删除" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": f"未找到人格学习审查记录: {numeric_id}"}), 404 - - elif update_id.startswith("style_"): - numeric_id = int(update_id.replace("style_", "")) - # 删除风格学习审查记录 - success = await database_manager.delete_style_review_by_id(numeric_id) - if success: - message = f"风格学习审查记录 {numeric_id} 已删除" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": f"未找到风格学习审查记录: {numeric_id}"}), 404 - - else: - # 尝试作为纯数字ID处理 - try: - numeric_id = int(update_id) - except ValueError: - return jsonify({"error": f"无效的ID格式: {update_id}"}), 400 - else: - numeric_id = int(update_id) - - # 尝试删除人格学习审查记录 - success = await database_manager.delete_persona_learning_review_by_id(numeric_id) - - if success: - message = f"人格学习审查记录 {numeric_id} 已删除" - return jsonify({"success": True, "message": message}) - else: - # 如果人格学习审查记录不存在,尝试删除传统人格审查记录 - if persona_updater: - result = await persona_updater.delete_persona_update_review(numeric_id) - if result: - message = f"人格更新审查记录 {numeric_id} 已删除" - return jsonify({"success": True, "message": message}) - else: - return jsonify({"error": "Record not found"}), 404 - else: - return jsonify({"error": "Record not found"}), 404 - - except Exception as e: - logger.error(f"删除人格更新审查记录失败: {e}") - return jsonify({"error": str(e)}), 500 - -# 批量删除人格更新审查记录 -@api_bp.route("/persona_updates/batch_delete", methods=["POST"]) -@require_auth -async def 
batch_delete_persona_updates(): - """批量删除人格更新审查记录""" - try: - data = await request.get_json() - update_ids = data.get('update_ids', []) - - if not update_ids or not isinstance(update_ids, list): - return jsonify({"error": "update_ids is required and must be a list"}), 400 - - # 使用全局变量而不是 current_app.plugin_instance - global database_manager, persona_updater - if not database_manager: - return jsonify({"error": "Database manager not available"}), 500 - - success_count = 0 - failed_count = 0 - - for update_id in update_ids: - try: - # 解析update_id,处理前缀(persona_learning_、style_) - if isinstance(update_id, str): - if update_id.startswith("persona_learning_"): - numeric_id = int(update_id.replace("persona_learning_", "")) - # 删除人格学习审查记录 - success = await database_manager.delete_persona_learning_review_by_id(numeric_id) - if success: - success_count += 1 - else: - failed_count += 1 - logger.warning(f"未找到人格学习审查记录: {numeric_id}") - elif update_id.startswith("style_"): - numeric_id = int(update_id.replace("style_", "")) - # 删除风格学习审查记录 - success = await database_manager.delete_style_review_by_id(numeric_id) - if success: - success_count += 1 - else: - failed_count += 1 - logger.warning(f"未找到风格学习审查记录: {numeric_id}") - else: - # 纯数字ID,尝试删除传统人格审查记录 - numeric_id = int(update_id) - if persona_updater: - result = await persona_updater.delete_persona_update_review(numeric_id) - if result: - success_count += 1 - else: - failed_count += 1 - logger.warning(f"未找到传统人格审查记录: {numeric_id}") - else: - failed_count += 1 - logger.warning("persona_updater不可用") - else: - # 纯数字ID - numeric_id = int(update_id) - # 先尝试删除人格学习审查记录 - success = await database_manager.delete_persona_learning_review_by_id(numeric_id) - - if success: - success_count += 1 - else: - # 如果人格学习审查记录不存在,尝试删除传统人格审查记录 - if persona_updater: - result = await persona_updater.delete_persona_update_review(numeric_id) - if result: - success_count += 1 - else: - failed_count += 1 - else: - failed_count += 1 - - except Exception as e: - 
logger.error(f"删除人格更新审查记录 {update_id} 失败: {e}") - failed_count += 1 - - return jsonify({ - "success": True, - "message": f"批量删除完成:成功 {success_count} 条,失败 {failed_count} 条", - "details": { - "success_count": success_count, - "failed_count": failed_count, - "total_count": len(update_ids) - } - }) - - except Exception as e: - logger.error(f"批量删除人格更新审查记录失败: {e}") - return jsonify({"error": str(e)}), 500 - -@api_bp.route("/persona_updates/delete_all", methods=["POST"]) -@require_auth -async def delete_all_persona_reviews(): - """删除所有人格学习审查记录(危险操作)""" - try: - data = await request.get_json() - group_id = data.get('group_id') if data else None # 可选:只删除指定群组的记录 - - # 使用全局变量 - global database_manager - if not database_manager: - return jsonify({"error": "Database manager not available"}), 500 - - # 执行批量删除 - deleted_count = await database_manager.delete_all_persona_learning_reviews(group_id=group_id) - - if group_id: - message = f"成功删除群组 {group_id} 的所有人格学习审查记录,共 {deleted_count} 条" - else: - message = f"成功删除所有人格学习审查记录,共 {deleted_count} 条" - - logger.info(message) - - return jsonify({ - "success": True, - "message": message, - "deleted_count": deleted_count - }) - - except Exception as e: - logger.error(f"删除所有人格学习审查记录失败: {e}") - return jsonify({"error": str(e)}), 500 - -# 批量操作人格更新审查记录(批准、拒绝) -@api_bp.route("/persona_updates/batch_review", methods=["POST"]) -@require_auth -async def batch_review_persona_updates(): - """批量审查人格更新记录""" - try: - data = await request.get_json() - update_ids = data.get('update_ids', []) - action = data.get('action') # 'approve' or 'reject' - comment = data.get('comment', '') - - if not update_ids or not isinstance(update_ids, list): - return jsonify({"error": "update_ids is required and must be a list"}), 400 - - if action not in ['approve', 'reject']: - return jsonify({"error": "action must be 'approve' or 'reject'"}), 400 - - # 使用全局变量而不是 current_app.plugin_instance - global database_manager, persona_updater - if not database_manager: - return 
jsonify({"error": "Database manager not available"}), 500 - - success_count = 0 - failed_count = 0 - - for update_id in update_ids: - try: - # 解析update_id,处理前缀(persona_learning_、style_) - if isinstance(update_id, str): - if update_id.startswith("persona_learning_"): - # 人格学习审查记录 - numeric_id = int(update_id.replace("persona_learning_", "")) - review_data = await database_manager.get_persona_learning_review_by_id(numeric_id) - - if review_data: - # ===== 先执行自动应用(不依赖数据库状态更新) ===== - if action == 'approve': - content_to_apply = review_data.get('proposed_content') or review_data.get('new_content') - group_id = review_data.get('group_id', 'default') - - auto_apply_enabled = plugin_config and getattr(plugin_config, 'auto_apply_approved_persona', False) - logger.info(f"[自动应用-批量] 检查配置: auto_apply={auto_apply_enabled}, persona_manager={persona_manager is not None}, content={len(content_to_apply) if content_to_apply else 0}") - if content_to_apply and auto_apply_enabled and persona_manager: - try: - umo = _resolve_umo(group_id) - current_persona = await persona_manager.get_default_persona_v3(umo) - if current_persona: - p_name = current_persona.get('name', 'default') - logger.info(f"[自动应用-批量] 准备更新默认人格 [{p_name}],内容长度: {len(content_to_apply)}") - await persona_manager.update_persona( - persona_id=p_name, - system_prompt=content_to_apply - ) - logger.info(f"[自动应用-批量] ✅ 已将 {update_id} 内容应用到默认人格 [{p_name}]") - except Exception as auto_err: - logger.error(f"[自动应用-批量] ❌ 应用失败: {auto_err}", exc_info=True) - - # 更新数据库审查状态(可能因event loop问题失败) - status = 'approved' if action == 'approve' else 'rejected' - try: - success = await database_manager.update_persona_learning_review_status( - numeric_id, status, comment - ) - except Exception as db_err: - logger.error(f"更新审查状态失败(event loop问题): {db_err}") - success = False - - if success: - success_count += 1 - else: - # 即使数据库更新失败,如果自动应用成功了也算部分成功 - if action == 'approve' and auto_apply_enabled: - success_count += 1 - logger.info(f"批量审查 
{update_id} 数据库状态更新失败,但自动应用已执行") - else: - failed_count += 1 - else: - failed_count += 1 - logger.warning(f"未找到人格学习审查记录: {numeric_id}") - - elif update_id.startswith("style_"): - # 风格学习审查记录 - numeric_id = int(update_id.replace("style_", "")) - status = 'approved' if action == 'approve' else 'rejected' - - if action == 'approve': - # 获取审查详情用于auto-apply - pending_reviews = await database_manager.get_pending_style_reviews() - target_review = None - for rev in pending_reviews: - if rev['id'] == numeric_id: - target_review = rev - break - - success = await database_manager.update_style_review_status(numeric_id, status) - - if success: - success_count += 1 - logger.info(f"风格学习审查 {update_id} 已{status}") - - # ===== 自动应用到框架默认人格 ===== - if action == 'approve' and target_review and target_review.get('few_shots_content'): - auto_apply_enabled = plugin_config and getattr(plugin_config, 'auto_apply_approved_persona', False) - logger.info(f"[自动应用-批量] 风格审查配置: auto_apply={auto_apply_enabled}, persona_manager={persona_manager is not None}") - if auto_apply_enabled and persona_manager: - try: - group_id = target_review.get('group_id', 'default') - umo = _resolve_umo(group_id) - current_persona = await persona_manager.get_default_persona_v3(umo) - if current_persona: - p_name = current_persona.get('name', 'default') - content = target_review['few_shots_content'] - logger.info(f"[自动应用-批量] 准备更新默认人格 [{p_name}],风格内容长度: {len(content)}") - await persona_manager.update_persona( - persona_id=p_name, - system_prompt=content - ) - logger.info(f"[自动应用-批量] ✅ 已将风格 {update_id} 内容应用到默认人格 [{p_name}]") - except Exception as auto_err: - logger.error(f"[自动应用-批量] ❌ 风格应用失败: {auto_err}", exc_info=True) - else: - failed_count += 1 - logger.warning(f"未找到风格学习审查记录: {numeric_id}") - else: - # 尝试作为纯数字ID处理(传统人格审查记录) - numeric_id = int(update_id) - if persona_updater: - status = "approved" if action == 'approve' else "rejected" - result = await persona_updater.review_persona_update(numeric_id, status, comment) - 
if result: - success_count += 1 - else: - failed_count += 1 - else: - failed_count += 1 - else: - # 纯数字ID - 尝试人格学习审查记录 - numeric_id = int(update_id) - review_data = await database_manager.get_persona_learning_review_by_id(numeric_id) - - if review_data: - # 人格学习审查记录 - status = 'approved' if action == 'approve' else 'rejected' - success = await database_manager.update_persona_learning_review_status( - numeric_id, status, comment - ) - - if success and action == 'approve': - # 如果批准,还需要应用人格更新 - content_to_apply = review_data.get('proposed_content') or review_data.get('new_content') - group_id = review_data.get('group_id', 'default') - - # ===== 自动应用到框架默认人格 ===== - auto_apply_enabled = plugin_config and getattr(plugin_config, 'auto_apply_approved_persona', False) - logger.info(f"[自动应用-批量-数字ID] 检查配置: auto_apply={auto_apply_enabled}, content={len(content_to_apply) if content_to_apply else 0}") - if content_to_apply and auto_apply_enabled and persona_manager: - try: - umo = _resolve_umo(group_id) - current_persona = await persona_manager.get_default_persona_v3(umo) - if current_persona: - p_name = current_persona.get('name', 'default') - logger.info(f"[自动应用-批量-数字ID] 准备更新默认人格 [{p_name}]") - await persona_manager.update_persona( - persona_id=p_name, - system_prompt=content_to_apply - ) - logger.info(f"[自动应用-批量-数字ID] ✅ 已应用到默认人格 [{p_name}]") - except Exception as auto_err: - logger.error(f"[自动应用-批量-数字ID] ❌ 失败: {auto_err}", exc_info=True) - - if persona_updater and content_to_apply: - try: - style_analysis = { - 'enhanced_prompt': content_to_apply, - 'style_features': [], - 'style_attributes': {}, - 'confidence': 0.8, - 'source': f'批量审查{update_id}' - } - - success_apply = await persona_updater.update_persona_with_style( - review_data.get('group_id', 'default'), - style_analysis, - [] - ) - - if success_apply: - logger.info(f"批量审查 {update_id} 已成功应用到人格(使用框架API方式)") - else: - logger.warning(f"批量审查 {update_id} 应用失败") - - except Exception as apply_error: - logger.error(f"批量审查 
{update_id} 应用过程出错: {apply_error}") - - if success: - success_count += 1 - else: - failed_count += 1 - else: - # 传统人格审查记录 - if persona_updater: - status = "approved" if action == 'approve' else "rejected" - result = await persona_updater.review_persona_update(numeric_id, status, comment) - if result: - success_count += 1 - else: - failed_count += 1 - else: - failed_count += 1 - - except Exception as e: - logger.error(f"批量审查人格更新记录 {update_id} 失败: {e}") - failed_count += 1 - - action_text = "批准" if action == 'approve' else "拒绝" - return jsonify({ - "success": True, - "message": f"批量{action_text}完成:成功 {success_count} 条,失败 {failed_count} 条", - "details": { - "success_count": success_count, - "failed_count": failed_count, - "total_count": len(update_ids) - } - }) - - except Exception as e: - logger.error(f"批量审查人格更新记录失败: {e}") - return jsonify({"error": str(e)}), 500 - -# 添加一个测试接口,用于创建测试数据 -@api_bp.route("/test/create_persona_update", methods=["POST"]) -@require_auth -async def create_test_persona_update(): - """创建测试人格更新记录(仅用于开发调试)""" - if persona_updater: - try: - import time - from ..core.interfaces import PersonaUpdateRecord - - # 创建一个测试记录 - test_record = PersonaUpdateRecord( - timestamp=time.time(), - group_id="742376823", - update_type="prompt_update", - original_content="You are a helpful assistant.", - new_content="You are a helpful assistant with a friendly and enthusiastic personality. 
You enjoy helping users with their questions and respond in a warm, encouraging manner.", - reason="强化学习生成的prompt过短,采用保守融合策略" - ) - - record_id = await persona_updater.record_persona_update_for_review(test_record) - logger.info(f"创建测试人格更新记录,ID: {record_id}") - - return jsonify({ - "message": "Test persona update record created successfully", - "record_id": record_id - }) - except Exception as e: - logger.error(f"创建测试记录失败: {e}", exc_info=True) - return jsonify({"error": f"创建测试记录失败: {str(e)}"}), 500 - return jsonify({"error": "Persona updater not initialized"}), 500 - -@api_bp.route("/metrics") -@require_auth -async def get_metrics(): - """获取性能指标:API调用返回时间、LLM调用次数""" - try: - # 获取真实的LLM调用统计 - llm_stats = {} - if llm_client and hasattr(llm_client, 'get_call_statistics'): - # 从LLM适配器获取真实调用统计 - real_stats = llm_client.get_call_statistics() - for provider_type, stats in real_stats.items(): - if provider_type != 'overall': - llm_stats[f"{provider_type}_provider"] = { - "total_calls": stats.get('total_calls', 0), - "avg_response_time_ms": stats.get('avg_response_time_ms', 0), - "success_rate": stats.get('success_rate', 1.0), - "error_count": stats.get('error_count', 0) - } - else: - # 后备的模拟数据 - llm_stats = { - "filter_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 1.0, "error_count": 0}, - "refine_provider": {"total_calls": 0, "avg_response_time_ms": 0, "success_rate": 1.0, "error_count": 0} - } - - # 获取真实的消息统计 - total_messages = 0 - filtered_messages = 0 - if database_manager: - try: - # 从数据库获取真实统计 - stats = await database_manager.get_messages_statistics() - - # 验证返回的数据类型 - if not isinstance(stats, dict): - logger.warning(f"get_messages_statistics 返回了非字典类型: {type(stats)}, 值: {stats}") - stats = {} - - # 安全地获取并转换数值 - total_messages_raw = stats.get('total_messages', 0) - filtered_messages_raw = stats.get('filtered_messages', 0) - - # 类型转换带验证 - try: - total_messages = int(total_messages_raw) if total_messages_raw and 
str(total_messages_raw).replace('-', '').isdigit() else 0 - except (ValueError, TypeError) as e: - logger.warning(f"total_messages 转换失败,原始值: {total_messages_raw}, 类型: {type(total_messages_raw)}, 错误: {e}") - total_messages = 0 - - try: - filtered_messages = int(filtered_messages_raw) if filtered_messages_raw and str(filtered_messages_raw).replace('-', '').isdigit() else 0 - except (ValueError, TypeError) as e: - logger.warning(f"filtered_messages 转换失败,原始值: {filtered_messages_raw}, 类型: {type(filtered_messages_raw)}, 错误: {e}") - filtered_messages = 0 - - except Exception as e: - logger.warning(f"获取数据库统计失败: {e}") - # 使用配置中的统计作为后备 - total_messages = plugin_config.total_messages_collected if plugin_config else 0 - filtered_messages = getattr(plugin_config, 'filtered_messages', 0) if plugin_config else 0 - else: - # 使用配置中的统计 - total_messages = plugin_config.total_messages_collected if plugin_config else 0 - filtered_messages = getattr(plugin_config, 'filtered_messages', 0) if plugin_config else 0 - - # 获取系统性能指标 - import psutil - import time - - # CPU和内存使用率(使用非阻塞方式获取CPU使用率) - cpu_percent = psutil.cpu_percent(interval=0) # interval=0 返回上次调用后的平均值,不阻塞 - memory = psutil.virtual_memory() - - # 网络统计 - net_io = psutil.net_io_counters() - - # 磁盘使用率 - disk_usage = psutil.disk_usage('/') - - metrics = { - "llm_calls": llm_stats, - "api_response_times": { - "get_config": {"avg_time_ms": 10, "requests_count": 45}, - "get_persona_updates": {"avg_time_ms": 50, "requests_count": 12}, - "get_metrics": {"avg_time_ms": 25, "requests_count": 30}, - "post_config": {"avg_time_ms": 120, "requests_count": 8} - }, - "total_messages_collected": total_messages, - "filtered_messages": filtered_messages, - "learning_efficiency": 0, # 将被智能计算覆盖 - "system_metrics": { - "cpu_percent": cpu_percent, - "memory_percent": memory.percent, - "memory_used_gb": round(memory.used / (1024**3), 2), - "memory_total_gb": round(memory.total / (1024**3), 2), - "disk_usage_percent": round(disk_usage.used / 
disk_usage.total * 100, 2), - "network_bytes_sent": net_io.bytes_sent, - "network_bytes_recv": net_io.bytes_recv - }, - "database_metrics": { - "total_queries": getattr(database_manager, '_total_queries', 0) if database_manager else 0, - "avg_query_time_ms": getattr(database_manager, '_avg_query_time', 0) if database_manager else 0, - "connection_pool_size": getattr(database_manager, '_pool_size', 5) if database_manager else 5, - "active_connections": getattr(database_manager, '_active_connections', 2) if database_manager else 2 - } - } - - # 获取真实的学习会话统计 - 移到metrics字典之外 - active_sessions_count = 0 - total_sessions_today = 0 - avg_session_duration = 0 - success_rate = 0.0 - - # 从progressive_learning服务获取真实数据 - try: - # 使用当前应用的插件实例 - plugin_instance = current_app.plugin_instance if hasattr(current_app, 'plugin_instance') else None - progressive_learning = getattr(plugin_instance, 'progressive_learning', None) if plugin_instance else None - - if progressive_learning: - # 计算活跃会话数量 - active_sessions_count = sum(1 for active in progressive_learning.learning_active.values() if active) - - # 获取今天的会话统计(如果有的话) - if database_manager: - # 可以从数据库获取今天的会话记录 - import time - from datetime import datetime, timedelta - today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0).timestamp() - - # 这里可以调用数据库方法获取今天的会话数据 - # 暂时使用简单的估算 - total_sessions_today = len(progressive_learning.learning_sessions) if hasattr(progressive_learning, 'learning_sessions') else 0 - - # 计算成功率 - if hasattr(progressive_learning, 'learning_sessions') and progressive_learning.learning_sessions: - successful_sessions = sum(1 for session in progressive_learning.learning_sessions if session.success) - success_rate = successful_sessions / len(progressive_learning.learning_sessions) if progressive_learning.learning_sessions else 0.0 - - # 计算平均会话时长 - completed_sessions = [s for s in progressive_learning.learning_sessions if s.end_time] - if completed_sessions: - durations = [] - for session in 
completed_sessions: - try: - start = datetime.fromisoformat(session.start_time) - end = datetime.fromisoformat(session.end_time) - duration_minutes = (end - start).total_seconds() / 60 - durations.append(duration_minutes) - except: - continue - if durations: - avg_session_duration = sum(durations) / len(durations) - else: - # 后备方案:使用persona_updater状态作为基础指标 - active_sessions_count = 1 if persona_updater else 0 - - except Exception as e: - logger.warning(f"获取学习会话统计失败: {e}") - # 使用默认值 - active_sessions_count = 1 if persona_updater else 0 - - # 更新metrics字典中的learning_sessions部分 - metrics["learning_sessions"] = { - "active_sessions": active_sessions_count, - "total_sessions_today": total_sessions_today, - "avg_session_duration_minutes": round(avg_session_duration, 1), - "success_rate": round(success_rate, 2) - } - metrics["last_updated"] = time.time() - - # 使用智能指标计算服务计算学习效率 - if intelligence_metrics_service: - try: - # 统计额外的学习成果指标 - refined_content_count = 0 - style_patterns_learned = 0 - persona_updates_count = 0 - active_strategies = [] - - # ✅ 使用 ORM 方法获取统计数据(支持跨线程调用) - if database_manager: - try: - # 统计提炼内容数量 - refined_content_count = await database_manager.count_refined_messages() - - # 统计风格学习成果 - style_patterns_learned = await database_manager.count_style_learning_patterns() - - # 统计待审查的人格更新 - persona_updates_count = await database_manager.count_pending_persona_updates() - - logger.debug(f"学习统计: refined={refined_content_count}, style={style_patterns_learned}, persona={persona_updates_count}") - except Exception as db_error: - logger.warning(f"从数据库获取学习统计失败: {db_error}") - - # 统计激活的学习策略 - if plugin_config: - if plugin_config.enable_message_capture: - active_strategies.append("message_filtering") - if plugin_config.enable_auto_learning: - active_strategies.append("content_refinement") - active_strategies.append("persona_evolution") - if plugin_config.enable_expression_patterns: - active_strategies.append("style_learning") - if plugin_config.enable_knowledge_graph: - 
active_strategies.append("context_awareness") - - # 计算智能化学习效率 - efficiency_metrics = await intelligence_metrics_service.calculate_learning_efficiency( - total_messages=total_messages, - filtered_messages=filtered_messages, - refined_content_count=refined_content_count, - style_patterns_learned=style_patterns_learned, - persona_updates_count=persona_updates_count, - active_strategies=active_strategies - ) - - # 更新metrics中的学习效率 - metrics["learning_efficiency"] = efficiency_metrics.overall_efficiency - metrics["learning_efficiency_details"] = { - "message_filter_rate": efficiency_metrics.message_filter_rate, - "content_refine_quality": efficiency_metrics.content_refine_quality, - "style_learning_progress": efficiency_metrics.style_learning_progress, - "persona_update_quality": efficiency_metrics.persona_update_quality, - "active_strategies_count": efficiency_metrics.active_strategies_count, - "active_strategies": active_strategies - } - - logger.info(f"智能学习效率计算完成: {efficiency_metrics.overall_efficiency:.2f}%") - - except Exception as metrics_error: - logger.warning(f"智能学习效率计算失败,使用简单算法: {metrics_error}") - # 回退到简单计算 (确保类型转换,带错误处理) - try: - total_msg = int(total_messages) if total_messages and str(total_messages).isdigit() else 0 - except (ValueError, TypeError): - logger.warning(f"total_messages 类型转换失败,值为: {total_messages}") - total_msg = 0 - - try: - filtered_msg = int(filtered_messages) if filtered_messages and str(filtered_messages).isdigit() else 0 - except (ValueError, TypeError): - logger.warning(f"filtered_messages 类型转换失败,值为: {filtered_messages}") - filtered_msg = 0 - - metrics["learning_efficiency"] = (filtered_msg / total_msg * 100) if total_msg > 0 else 0 - else: - # 如果服务未初始化,使用简单算法 (确保类型转换,带错误处理) - try: - total_msg = int(total_messages) if total_messages and str(total_messages).isdigit() else 0 - except (ValueError, TypeError): - logger.warning(f"total_messages 类型转换失败,值为: {total_messages}") - total_msg = 0 - - try: - filtered_msg = int(filtered_messages) if 
filtered_messages and str(filtered_messages).isdigit() else 0 - except (ValueError, TypeError): - logger.warning(f"filtered_messages 类型转换失败,值为: {filtered_messages}") - filtered_msg = 0 - - metrics["learning_efficiency"] = (filtered_msg / total_msg * 100) if total_msg > 0 else 0 - - return jsonify(metrics) - - except Exception as e: - logger.error(f"获取性能指标失败: {e}", exc_info=True) - return jsonify({"error": f"获取性能指标失败: {str(e)}"}), 500 - -@api_bp.route("/metrics/realtime") -@require_auth -async def get_realtime_metrics(): - """获取实时性能指标""" - try: - import psutil - import time - - # 获取实时系统指标 - cpu_percent = psutil.cpu_percent() - memory = psutil.virtual_memory() - - # 获取最近的消息处理统计 - recent_stats = { - "messages_last_hour": 45, # 可以从数据库查询 - "llm_calls_last_hour": 12, - "avg_response_time_ms": 850, - "error_rate": 0.02 - } - - realtime_data = { - "timestamp": time.time(), - "cpu_percent": cpu_percent, - "memory_percent": memory.percent, - "recent_activity": recent_stats, - "status": { - "message_capture": plugin_config.enable_message_capture if plugin_config else False, - "auto_learning": plugin_config.enable_auto_learning if plugin_config else False, - "realtime_learning": plugin_config.enable_realtime_learning if plugin_config else False - } - } - - return jsonify(realtime_data) - - except Exception as e: - return jsonify({"error": f"获取实时指标失败: {str(e)}"}), 500 - -@api_bp.route("/learning/status") -@require_auth -async def get_learning_status(): - """获取学习状态详情""" - try: - # 获取真实的学习状态 - learning_status = { - "current_session": {"error": "无会话数据"}, - "today_summary": {"error": "无今日统计数据"}, - "recent_activities": [] - } - - if database_manager: - try: - # 获取最新的学习会话 - recent_sessions = await database_manager.get_recent_learning_sessions("default", 1) - if recent_sessions: - latest_session = recent_sessions[0] - learning_status["current_session"] = { - "session_id": latest_session.get('session_id', '未知'), - "start_time": datetime.fromtimestamp(latest_session.get('start_time', 
time.time())).strftime('%Y-%m-%d %H:%M:%S'), - "status": "已完成" if latest_session.get('success') else "失败", - "messages_processed": latest_session.get('messages_processed', 0), - "learning_progress": round(latest_session.get('quality_score', 0) * 100, 1), - "current_task": f"已处理{latest_session.get('filtered_messages', 0)}条筛选消息" - } - - # 获取今日统计 - message_stats = await database_manager.get_messages_statistics() - all_sessions = await database_manager.get_recent_learning_sessions("default", 10) - learning_status["today_summary"] = { - "sessions_completed": len(all_sessions) if all_sessions else 0, - "total_messages_learned": message_stats.get('filtered_messages', 0), - "persona_updates": 0, # TODO: 从数据库获取人格更新次数 - "success_rate": (sum(1 for s in all_sessions if s.get('success', False)) / len(all_sessions)) if all_sessions else 0.0 - } - - # 获取最近活动(基于学习批次) - recent_batches = await database_manager.get_recent_learning_batches(3) - for batch in recent_batches: - learning_status["recent_activities"].append({ - "timestamp": batch.get('start_time', time.time()), - "activity": f"学习批次: {batch.get('batch_name', '未命名')},处理{batch.get('message_count', 0)}条消息", - "result": "成功" if batch.get('success') else "失败" - }) - - if not learning_status["recent_activities"]: - learning_status["recent_activities"] = [{"error": "暂无最近活动数据"}] - - except Exception as e: - logger.warning(f"获取真实学习状态数据失败: {e}") - learning_status = { - "current_session": {"error": f"获取会话数据失败: {str(e)}"}, - "today_summary": {"error": f"获取统计数据失败: {str(e)}"}, - "recent_activities": [{"error": f"获取活动数据失败: {str(e)}"}] - } - - return jsonify(learning_status) - - except Exception as e: - return jsonify({"error": f"获取学习状态失败: {str(e)}"}), 500 - -@api_bp.route("/analytics/trends") -@require_auth -async def get_analytics_trends(): - """获取分析趋势数据""" - try: - import random - from datetime import datetime, timedelta - - # 生成过去24小时的趋势数据 - hours_data = [] - base_time = datetime.now() - timedelta(hours=23) - - for i in range(24): - 
current_time = base_time + timedelta(hours=i) - hours_data.append({ - "time": current_time.strftime("%H:%M"), - "raw_messages": random.randint(10, 60), - "filtered_messages": random.randint(5, 30), - "llm_calls": random.randint(2, 15), - "response_time": random.randint(400, 1500) - }) - - # 生成过去7天的数据 - days_data = [] - base_date = datetime.now() - timedelta(days=6) - - for i in range(7): - current_date = base_date + timedelta(days=i) - days_data.append({ - "date": current_date.strftime("%m-%d"), - "total_messages": random.randint(200, 800), - "learning_sessions": random.randint(5, 20), - "persona_updates": random.randint(0, 5), - "success_rate": round(random.uniform(0.7, 0.95), 2) - }) - - # 用户活跃度热力图数据 - heatmap_data = [] - days = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"] - for day_idx in range(7): - for hour in range(24): - activity_level = random.randint(0, 50) - # 工作时间活跃度更高 - if 9 <= hour <= 18 and day_idx < 5: - activity_level = random.randint(20, 50) - # 晚上和周末活跃度中等 - elif 19 <= hour <= 23 or day_idx >= 5: - activity_level = random.randint(10, 35) - - heatmap_data.append([hour, day_idx, activity_level]) - - trends_data = { - "hourly_trends": hours_data, - "daily_trends": days_data, - "activity_heatmap": { - "data": heatmap_data, - "days": days, - "hours": [f"{i}:00" for i in range(24)] - } - } - - return jsonify(trends_data) - - except Exception as e: - return jsonify({"error": f"获取趋势数据失败: {str(e)}"}), 500 - -# 人格管理相关API端点 - -@api_bp.route("/persona_management/list") -@require_auth -async def get_personas_list(): - """获取所有人格列表""" - try: - logger.info("开始获取人格列表...") - persona_web_mgr = get_persona_web_manager() - logger.info(f"PersonaWebManager实例: {persona_web_mgr}") - - if not persona_web_mgr: - logger.warning("PersonaWebManager未初始化,返回空列表") - return jsonify({"personas": []}) - - logger.info("调用get_all_personas_for_web...") - personas = await persona_web_mgr.get_all_personas_for_web() - logger.info(f"获取到 {len(personas)} 个人格") - - return jsonify({"personas": 
personas}) - - except Exception as e: - logger.error(f"获取人格列表失败: {e}", exc_info=True) - # 返回空列表而不是错误,避免前端显示错误 - return jsonify({"personas": []}) - -@api_bp.route("/persona_management/get/") -@require_auth -async def get_persona_details(persona_id: str): - """获取特定人格详情""" - if not persona_manager: - return jsonify({"error": "PersonaManager未初始化"}), 500 - - try: - persona = await persona_manager.get_persona(persona_id) - if not persona: - return jsonify({"error": "人格不存在"}), 404 - - persona_dict = { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs, - "tools": persona.tools, - "created_at": persona.created_at.isoformat() if persona.created_at else None, - "updated_at": persona.updated_at.isoformat() if persona.updated_at else None, - } - - return jsonify(persona_dict) - - except Exception as e: - logger.error(f"获取人格详情失败: {e}") - return jsonify({"error": f"获取人格详情失败: {str(e)}"}), 500 - -@api_bp.route("/persona_management/create", methods=["POST"]) -@require_auth -async def create_persona(): - """创建新人格""" - persona_web_mgr = get_persona_web_manager() - if not persona_web_mgr: - return jsonify({"error": "人格管理功能暂不可用,请检查AstrBot PersonaManager配置"}), 503 - - try: - data = await request.get_json() - result = await persona_web_mgr.create_persona_via_web(data) - - if result["success"]: - return jsonify({"message": "人格创建成功", "persona_id": result["persona_id"]}) - else: - return jsonify({"error": result["error"]}), 400 - - except Exception as e: - logger.error(f"创建人格失败: {e}", exc_info=True) - return jsonify({"error": f"创建人格失败: {str(e)}"}), 500 - -@api_bp.route("/persona_management/update/", methods=["POST"]) -@require_auth -async def update_persona(persona_id: str): - """更新人格""" - persona_web_mgr = get_persona_web_manager() - if not persona_web_mgr: - return jsonify({"error": "人格管理功能暂不可用,请检查AstrBot PersonaManager配置"}), 503 - - try: - data = await request.get_json() - result = await 
persona_web_mgr.update_persona_via_web(persona_id, data) - - if result["success"]: - return jsonify({"message": "人格更新成功"}) - else: - return jsonify({"error": result["error"]}), 400 - - except Exception as e: - logger.error(f"更新人格失败: {e}", exc_info=True) - return jsonify({"error": f"更新人格失败: {str(e)}"}), 500 - -@api_bp.route("/persona_management/delete/", methods=["POST"]) -@require_auth -async def delete_persona(persona_id: str): - """删除人格""" - persona_web_mgr = get_persona_web_manager() - if not persona_web_mgr: - return jsonify({"error": "人格管理功能暂不可用,请检查AstrBot PersonaManager配置"}), 503 - - try: - result = await persona_web_mgr.delete_persona_via_web(persona_id) - - if result["success"]: - return jsonify({"message": "人格删除成功"}) - else: - return jsonify({"error": result["error"]}), 400 - - except Exception as e: - logger.error(f"删除人格失败: {e}", exc_info=True) - return jsonify({"error": f"删除人格失败: {str(e)}"}), 500 - -@api_bp.route("/persona_management/default") -@require_auth -async def get_default_persona(): - """获取默认人格""" - persona_web_mgr = get_persona_web_manager() - if not persona_web_mgr: - # 返回一个基本的默认人格,而不是错误 - return jsonify({ - "persona_id": "default", - "system_prompt": "You are a helpful assistant.", - "begin_dialogs": [], - "tools": [] - }) - - try: - default_persona = await persona_web_mgr.get_default_persona_for_web() - return jsonify(default_persona) - - except Exception as e: - logger.error(f"获取默认人格失败: {e}", exc_info=True) - # 返回基本默认人格而不是错误 - return jsonify({ - "persona_id": "default", - "system_prompt": "You are a helpful assistant.", - "begin_dialogs": [], - "tools": [] - }) - -@api_bp.route("/persona_management/export/") -@require_auth -async def export_persona(persona_id: str): - """导出人格配置""" - if not persona_manager: - return jsonify({"error": "PersonaManager未初始化"}), 500 - - try: - persona = await persona_manager.get_persona(persona_id) - if not persona: - return jsonify({"error": "人格不存在"}), 404 - - from datetime import datetime - persona_export = { - 
"persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs, - "tools": persona.tools, - "export_time": datetime.now().isoformat(), - "export_version": "1.0" - } - - return jsonify(persona_export) - - except Exception as e: - logger.error(f"导出人格失败: {e}") - return jsonify({"error": f"导出人格失败: {str(e)}"}), 500 - -@api_bp.route("/persona_management/import", methods=["POST"]) -@require_auth -async def import_persona(): - """导入人格配置""" - if not persona_manager: - return jsonify({"error": "PersonaManager未初始化"}), 500 - - try: - data = await request.get_json() - - # 验证导入数据格式 - required_fields = ["persona_id", "system_prompt"] - for field in required_fields: - if field not in data: - return jsonify({"error": f"缺少必需字段: {field}"}), 400 - - persona_id = data["persona_id"] - system_prompt = data["system_prompt"] - begin_dialogs = data.get("begin_dialogs", []) - tools = data.get("tools", []) - - # 检查是否覆盖现有人格 - overwrite = data.get("overwrite", False) - existing_persona = await persona_manager.get_persona(persona_id) - - if existing_persona and not overwrite: - return jsonify({ - "error": "人格已存在,如要覆盖请设置overwrite=true" - }), 400 - - # 创建或更新人格 - if existing_persona: - success = await persona_manager.update_persona( - persona_id=persona_id, - system_prompt=system_prompt, - begin_dialogs=begin_dialogs, - tools=tools - ) - action = "更新" - else: - success = await persona_manager.create_persona( - persona_id=persona_id, - system_prompt=system_prompt, - begin_dialogs=begin_dialogs, - tools=tools - ) - action = "创建" - - if success: - logger.info(f"成功导入人格: {persona_id} ({action})") - return jsonify({"message": f"人格{action}成功", "persona_id": persona_id}) - else: - return jsonify({"error": f"人格{action}失败"}), 500 - - except Exception as e: - logger.error(f"导入人格失败: {e}") - return jsonify({"error": f"导入人格失败: {str(e)}"}), 500 - -@api_bp.route("/style_learning/results", methods=["GET"]) -@require_auth -async def get_style_learning_results(): 
- """获取风格学习结果""" - try: - # 初始化空数据结构 - results_data = { - 'statistics': { - 'unique_styles': 0, - 'avg_confidence': 0, - 'total_samples': 0, - 'latest_update': None - }, - 'style_progress': [] - } - - if db_manager: - try: - # 尝试从数据库获取真实数据 - real_stats = await db_manager.get_style_learning_statistics() - if real_stats: - results_data['statistics'].update(real_stats) - - real_progress = await db_manager.get_style_progress_data() - if real_progress: - results_data['style_progress'] = real_progress - except Exception as e: - logger.warning(f"无法从数据库获取风格学习数据: {e}") - - return jsonify(results_data) - - except Exception as e: - logger.error(f"获取风格学习结果失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/reviews", methods=["GET"]) -@require_auth -async def get_style_learning_reviews(): - """获取对话风格学习审查列表""" - try: - if not database_manager: - return jsonify({'error': '数据库管理器未初始化'}), 500 - - pending_reviews = await database_manager.get_pending_style_reviews(limit=50) - - # 格式化审查数据 - formatted_reviews = [] - for review in pending_reviews: - formatted_review = { - 'id': review['id'], - 'type': '对话风格学习', - 'group_id': review['group_id'], - 'description': review['description'], - 'timestamp': review['timestamp'], - 'created_at': review['created_at'], - 'status': review['status'], - 'learned_patterns': review['learned_patterns'], - 'few_shots_content': review['few_shots_content'] - } - formatted_reviews.append(formatted_review) - - return jsonify({ - 'reviews': formatted_reviews, - 'total': len(formatted_reviews) - }) - - except Exception as e: - logger.error(f"获取风格学习审查列表失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/reviews//approve", methods=["POST"]) -@require_auth -async def approve_style_learning_review(review_id: int): - """批准对话风格学习审查 - 使用与人格学习审查相同的备份逻辑""" - try: - if not database_manager: - return jsonify({'error': '数据库管理器未初始化'}), 500 - - # 获取审查详情 - pending_reviews = await 
database_manager.get_pending_style_reviews() - target_review = None - for review in pending_reviews: - if review['id'] == review_id: - target_review = review - break - - if not target_review: - return jsonify({'error': '审查记录不存在'}), 404 - - # 更新状态为approved - success = await database_manager.update_style_review_status(review_id, 'approved', target_review['group_id']) - - if success: - # 应用到人格(使用与人格学习审查相同的逻辑:备份+应用) - if target_review['few_shots_content']: - persona_update_content = target_review['few_shots_content'] - group_id = target_review.get('group_id', 'default') - message = f'风格学习审查 {review_id} 已批准' - - # ===== 自动应用到框架默认人格(独立于persona_updater) ===== - auto_apply_enabled = plugin_config and getattr(plugin_config, 'auto_apply_approved_persona', False) - logger.info(f"[自动应用] 检查配置: auto_apply={auto_apply_enabled}, persona_manager={persona_manager is not None}, content_len={len(persona_update_content)}") - if auto_apply_enabled and persona_manager: - try: - umo = _resolve_umo(group_id) - current_persona = await persona_manager.get_default_persona_v3(umo) - if current_persona: - p_name = current_persona.get('name', 'default') - logger.info(f"[自动应用] 准备更新默认人格 [{p_name}],内容长度: {len(persona_update_content)},群组: {group_id}") - await persona_manager.update_persona( - persona_id=p_name, - system_prompt=persona_update_content - ) - logger.info(f"[自动应用] ✅ 已将风格学习审查内容应用到默认人格 [{p_name}]") - message += f',已自动应用到默认人格 [{p_name}]' - else: - logger.warning("[自动应用] 无法获取当前默认人格") - except Exception as auto_err: - logger.error(f"[自动应用] ❌ 应用到默认人格失败: {auto_err}", exc_info=True) - message += f',但自动应用到默认人格失败: {str(auto_err)}' - - # ===== 原有的update_persona_with_style逻辑(备份+内存更新) ===== - if persona_updater: - try: - style_analysis = { - 'enhanced_prompt': persona_update_content, - 'style_features': [], - 'style_attributes': {}, - 'confidence': 0.8, - 'source': f'风格学习审查{review_id}' - } - success_apply = await persona_updater.update_persona_with_style( - group_id, style_analysis, [] - ) - if 
success_apply: - logger.info(f"✅ 风格学习审查 {review_id} 备份和内存更新完成") - else: - logger.warning(f"❌ 风格学习审查 {review_id} update_persona_with_style返回False") - except Exception as e: - logger.error(f"update_persona_with_style失败: {e}", exc_info=True) - else: - message = f'风格学习审查 {review_id} 已批准(无内容需要应用)' - - return jsonify({ - 'success': True, - 'message': message - }) - else: - return jsonify({'error': '批准失败,请检查审查记录状态'}), 500 - - except Exception as e: - logger.error(f"批准风格学习审查失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/reviews//reject", methods=["POST"]) -@require_auth -async def reject_style_learning_review(review_id: int): - """拒绝对话风格学习审查""" - try: - if not database_manager: - return jsonify({'error': '数据库管理器未初始化'}), 500 - - # 更新状态为rejected - success = await database_manager.update_style_review_status(review_id, 'rejected') - - if success: - logger.info(f"风格学习审查 {review_id} 已拒绝") - return jsonify({ - 'success': True, - 'message': f'风格学习审查 {review_id} 已拒绝' - }) - else: - return jsonify({'error': '拒绝失败,请检查审查记录状态'}), 500 - - except Exception as e: - logger.error(f"拒绝风格学习审查失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/patterns", methods=["GET"]) -@require_auth -async def get_style_learning_patterns(): - """获取风格学习模式""" - try: - # 初始化空模式数据 - patterns_data = { - 'emotion_patterns': [], - 'language_patterns': [], - 'topic_preferences': [] - } - - if db_manager: - try: - # 尝试从数据库获取真实模式数据 - real_patterns = await db_manager.get_learning_patterns_data() - if real_patterns: - patterns_data.update(real_patterns) - except Exception as e: - logger.warning(f"无法从数据库获取学习模式数据: {e}") - - return jsonify(patterns_data) - - except Exception as e: - logger.error(f"获取风格学习模式失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/metrics/detailed", methods=["GET"]) -@require_auth -async def get_detailed_metrics(): - """获取详细性能监控数据""" - try: - # 初始化空详细数据 - detailed_data = { - 'api_metrics': { - 'hours': [], - 
'response_times': [] - }, - 'database_metrics': { - 'table_stats': {} - }, - 'system_metrics': { - 'memory_percent': 0, - 'cpu_percent': 0, - 'disk_percent': 0 - } - } - - if db_manager: - try: - # 尝试从数据库获取真实详细数据 - real_detailed = await db_manager.get_detailed_metrics() - if real_detailed: - detailed_data.update(real_detailed) - except Exception as e: - logger.warning(f"无法从数据库获取详细监控数据: {e}") - - return jsonify(detailed_data) - - except Exception as e: - logger.error(f"获取详细监控数据失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/metrics/trends", methods=["GET"]) -@require_auth -async def get_metrics_trends(): - """获取指标趋势数据""" - try: - # 初始化空趋势数据 - trends_data = { - 'message_growth': 0, - 'filtered_growth': 0, - 'llm_growth': 0, - 'sessions_growth': 0 - } - - if db_manager: - try: - # 尝试从数据库获取真实趋势数据 - real_trends = await db_manager.get_trends_data() - if real_trends: - trends_data.update(real_trends) - except Exception as e: - logger.warning(f"无法从数据库获取趋势数据: {e}") - - return jsonify(trends_data) - - except Exception as e: - logger.error(f"获取趋势数据失败: {e}") - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/content_text", methods=["GET"]) -@require_auth -async def get_style_learning_content_text(): - """获取对话风格学习的所有内容文本(带缓存)""" - global _style_learning_content_cache, _style_learning_content_cache_time - - # 检查是否强制刷新 - force_refresh = request.args.get('force_refresh', 'false').lower() == 'true' - - # 检查缓存是否有效 - current_time = time.time() - if not force_refresh and _style_learning_content_cache is not None and _style_learning_content_cache_time is not None: - cache_age = current_time - _style_learning_content_cache_time - if cache_age < _style_learning_content_cache_ttl: - logger.info(f"使用缓存的学习内容数据(缓存年龄: {cache_age:.1f}秒)") - return jsonify(_style_learning_content_cache) - - logger.info(f"开始执行get_style_learning_content_text API请求(强制刷新: {force_refresh})") - try: - # 从数据库获取学习相关的文本内容 - content_data = { - 'dialogues': [], - 'analysis': 
[], - 'features': [], - 'history': [] - } - logger.debug("初始化content_data数据结构") - - if db_manager: - logger.info("数据库管理器可用,开始获取学习内容数据") - try: - # 获取对话示例文本 - 从raw_messages表获取最近的原始消息 - logger.debug("开始获取对话示例文本...") - - # 优先使用SQLAlchemy从raw_messages获取 - try: - async with db_manager.get_session() as session: - from sqlalchemy import select, desc, func - from .models.orm import RawMessage - - # 获取最近20条消息,按时间倒序 - stmt = select(RawMessage).order_by(desc(RawMessage.timestamp)).limit(20) - result = await session.execute(stmt) - raw_messages = result.scalars().all() - - logger.info(f"从raw_messages表获取到 {len(raw_messages)} 条原始消息用于对话示例") - - if raw_messages: - for i, msg in enumerate(raw_messages): - # 过滤太短的消息 - message_text = msg.message if msg.message else '' - if len(message_text.strip()) < 5: - continue - - content_data['dialogues'].append({ - 'timestamp': datetime.fromtimestamp(msg.timestamp if msg.timestamp else time.time()).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"{msg.sender_name or msg.sender_id}: {message_text}", - 'metadata': f"群组: {msg.group_id}, 平台: {msg.platform or '未知'}" - }) - if i == 0: - logger.debug(f"第一条对话示例: 群组={msg.group_id}, 时间={msg.timestamp}, 内容长度={len(message_text)}") - logger.info(f"成功添加 {len([d for d in content_data['dialogues']])} 条对话示例") - else: - logger.warning("raw_messages表为空") - raise ValueError("raw_messages表为空") - - except Exception as e: - logger.warning(f"从raw_messages表获取失败: {e}, 尝试降级方法") - # 降级到filtered_messages表 - recent_messages = await db_manager.get_filtered_messages_for_learning(20) - logger.info(f"降级获取到 {len(recent_messages) if recent_messages else 0} 条筛选消息") - - if recent_messages: - for i, msg in enumerate(recent_messages): - content_data['dialogues'].append({ - 'timestamp': datetime.fromtimestamp(msg.get('timestamp', time.time())).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"用户: {msg.get('message', '暂无内容')}", - 'metadata': f"置信度: {msg.get('confidence', 0):.1%}, 群组: {msg.get('group_id', '未知')}" - }) - logger.info(f"成功添加 
{len(recent_messages)} 条对话示例") - - # 如果仍然没有数据,显示提示 - if not content_data['dialogues']: - logger.warning("未找到任何消息,显示默认提示") - content_data['dialogues'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': '暂无对话数据,请先进行一些群聊对话,系统会自动学习和筛选有价值的内容', - 'metadata': '系统提示' - }) - - except Exception as e: - logger.error(f"获取对话示例文本失败: {e}", exc_info=True) - content_data['dialogues'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f'获取对话数据时出错: {str(e)}', - 'metadata': '错误信息' - }) - else: - logger.error("数据库管理器不可用,无法获取学习内容数据") - - if db_manager: - try: - # 获取风格分析结果 - 使用对话风格学习记录 - logger.info("开始获取风格学习分析结果...") - - # 优先从 style_learning_reviews 表获取对话风格学习记录 - try: - async with db_manager.get_session() as session: - from sqlalchemy import select, desc - from .models.orm.learning import StyleLearningReview - - stmt = select(StyleLearningReview).order_by(desc(StyleLearningReview.timestamp)).limit(5) - result = await session.execute(stmt) - style_reviews = result.scalars().all() - - logger.info(f"从数据库获取到 {len(style_reviews)} 个对话风格学习记录") - - if style_reviews: - for i, review in enumerate(style_reviews): - # 解析 learned_patterns 获取消息数量 - try: - patterns = json.loads(review.learned_patterns) if review.learned_patterns else [] - pattern_count = len(patterns) - except: - pattern_count = 0 - - # 从描述中提取消息数量(格式: "处理 X 条消息") - import re - message_count = 0 - if review.description: - match = re.search(r'处理\s*(\d+)\s*条消息', review.description) - if match: - message_count = int(match.group(1)) - - review_time = review.timestamp if review.timestamp else time.time() - - logger.debug(f"处理对话风格学习记录 {i+1}/{len(style_reviews)}: " - f"消息数: {message_count}, 模式数: {pattern_count}") - - content_data['analysis'].append({ - 'timestamp': datetime.fromtimestamp(review_time).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"对话风格学习\n处理消息: {message_count}条\n提取模式: {pattern_count}个", - 'metadata': f"状态: {review.status or '已完成'}" - }) - logger.info(f"成功添加 
{len(style_reviews)} 个对话风格学习记录到分析内容") - else: - logger.warning("未找到任何对话风格学习记录") - content_data['analysis'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': '暂无学习分析数据,系统还未开始学习过程', - 'metadata': '系统提示' - }) - except Exception as e: - logger.error(f"从 style_learning_reviews 表获取数据失败: {e}", exc_info=True) - # 降级到旧的方法 - content_data['analysis'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f'获取学习数据时出错: {str(e)}', - 'metadata': '错误信息' - }) - - except Exception as e: - logger.error(f"获取风格分析结果失败: {e}", exc_info=True) - content_data['analysis'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f'获取分析数据时出错: {str(e)}', - 'metadata': '错误信息' - }) - - if db_manager: - try: - # 获取提炼的风格特征 - 使用工厂模式的方法 - logger.info("开始获取风格特征数据...") - - # 1. 从表达模式数据获取 - 优先使用 SQLAlchemy 数据库管理器 - try: - logger.debug("尝试从 SQLAlchemy 数据库管理器获取表达模式...") - group_patterns = await db_manager.get_all_expression_patterns() - - logger.info(f"[WebUI DEBUG] get_all_expression_patterns返回类型: {type(group_patterns)}") - logger.info(f"[WebUI DEBUG] get_all_expression_patterns返回值: {group_patterns is not None}") - if group_patterns: - logger.info(f"[WebUI DEBUG] 群组数量: {len(group_patterns)}") - for gid, pats in list(group_patterns.items())[:3]: - logger.info(f"[WebUI DEBUG] 群组 {gid}: {len(pats)} 个模式") - - if group_patterns: - logger.info(f"[WebUI] 从 SQLAlchemy 获取到 {len(group_patterns)} 个群组的模式") - - pattern_count = 0 - for group_id, patterns in group_patterns.items(): - logger.info(f"[WebUI DEBUG] 处理群组 {group_id} 的 {len(patterns)} 个表达模式") - for i, pattern in enumerate(patterns[:5]): # 每个群组取前5个 - logger.debug(f"[WebUI DEBUG] 群组 {group_id} 模式 {i}: type={type(pattern)}, is_dict={isinstance(pattern, dict)}") - if isinstance(pattern, dict): - logger.debug(f"[WebUI DEBUG] 模式字典keys: {pattern.keys()}") - # 处理字典格式(SQLAlchemy 返回) - if isinstance(pattern, dict): - if 'situation' in pattern and 'expression' in pattern: - 
content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(pattern.get('last_active_time', time.time())).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"场景: {pattern['situation']}\n表达: {pattern['expression']}", - 'metadata': f"权重: {pattern.get('weight', 0.5):.2f}, 群组: {group_id}" - }) - pattern_count += 1 - logger.debug(f"[WebUI DEBUG] 成功添加模式: {pattern['situation'][:20]}...") - else: - logger.warning(f"[WebUI DEBUG] 模式缺少必要字段,有situation={('situation' in pattern)},有expression={('expression' in pattern)}") - # 处理对象格式(传统方法返回) - elif hasattr(pattern, 'situation') and hasattr(pattern, 'expression'): - content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(getattr(pattern, 'last_active_time', time.time())).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"场景: {pattern.situation}\n表达: {pattern.expression}", - 'metadata': f"权重: {getattr(pattern, 'weight', 0.5):.2f}, 群组: {group_id}" - }) - pattern_count += 1 - logger.debug(f"[WebUI DEBUG] 成功添加对象模式") - else: - logger.warning(f"[WebUI DEBUG] 模式既不是字典也不是对象,或缺少必要属性") - logger.info(f"成功添加 {pattern_count} 个表达模式特征") - else: - logger.warning("[WebUI] SQLAlchemy 返回空数据,降级到表达模式学习器") - raise ValueError("SQLAlchemy 返回空数据") - - except Exception as e: - # 降级到表达模式学习器方法 - logger.warning(f"[WebUI] SQLAlchemy 获取表达模式失败: {e},降级到表达模式学习器") - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - expression_learner = component_factory.create_expression_pattern_learner() - - # 获取所有群组的表达模式 - logger.debug("获取表达模式数据...") - if hasattr(expression_learner, 'get_all_group_patterns'): - group_patterns = await expression_learner.get_all_group_patterns() - logger.info(f"从表达模式学习器获取到 {len(group_patterns)} 个群组的模式") - - pattern_count = 0 - for group_id, patterns in group_patterns.items(): - logger.debug(f"处理群组 {group_id} 的 {len(patterns)} 个表达模式") - for pattern in patterns[:5]: # 每个群组取前5个 - if hasattr(pattern, 'situation') and hasattr(pattern, 
'expression'): - content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(getattr(pattern, 'last_active_time', time.time())).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"场景: {pattern.situation}\n表达: {pattern.expression}", - 'metadata': f"权重: {getattr(pattern, 'weight', 0.5):.2f}, 群组: {group_id}" - }) - pattern_count += 1 - logger.info(f"成功添加 {pattern_count} 个表达模式特征") - else: - # 回退到 ORM 查询 - logger.debug("表达模式学习器不支持get_all_group_patterns方法,使用ORM查询") - from sqlalchemy import select - from .models.orm import ExpressionPattern as ExprPatternModel - - async with db_manager.get_session() as session: - stmt = select(ExprPatternModel).order_by( - ExprPatternModel.last_active_time.desc() - ).limit(10) - result = await session.execute(stmt) - expression_patterns = result.scalars().all() - - if expression_patterns: - logger.info(f"从数据库直接查询到 {len(expression_patterns)} 个表达模式") - for pattern in expression_patterns: - content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(pattern.last_active_time).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"场景: {pattern.situation}\n表达: {pattern.expression}", - 'metadata': f"权重: {pattern.weight:.2f}, 群组: {pattern.group_id}" - }) - else: - logger.warning("数据库中未找到表达模式记录") - - except Exception as e: - logger.warning(f"获取表达模式失败,将尝试其他数据源: {e}") - - # 2. 从风格学习审查中获取特征 - 使用工厂方法 - try: - logger.debug("获取风格学习审查数据...") - # 获取待审查的风格学习内容 - pending_style_reviews = await db_manager.get_pending_style_reviews() - logger.info(f"获取到 {len(pending_style_reviews) if pending_style_reviews else 0} 个待审查的风格学习记录") - - for review in pending_style_reviews: - if review.get('few_shots_content'): - content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(review['timestamp']).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"风格学习内容:\n{review['few_shots_content'][:300]}{'...' 
if len(review['few_shots_content']) > 300 else ''}", - 'metadata': f"状态: 待审查, 描述: {review.get('description', '无')}" - }) - - # 获取已批准的风格学习内容 - approved_style_reviews = await db_manager.get_reviewed_style_learning_updates(limit=10, status_filter='approved') - logger.info(f"获取到 {len(approved_style_reviews) if approved_style_reviews else 0} 个已批准的风格学习记录") - - for review in approved_style_reviews: - if review.get('few_shots_content'): - content_data['features'].append({ - 'timestamp': datetime.fromtimestamp(review.get('review_time', review['timestamp'])).strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"已应用风格特征:\n{review['few_shots_content'][:300]}{'...' if len(review['few_shots_content']) > 300 else ''}", - 'metadata': f"状态: 已批准应用, 描述: {review.get('description', '无')}" - }) - - except Exception as e: - logger.warning(f"从风格学习审查获取特征失败: {e}") - - # 如果所有数据源都没有数据,显示提示 - if not content_data['features']: - logger.warning("未从任何数据源获取到风格特征,显示默认提示") - content_data['features'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': '暂无学习到的表达模式,请耐心等待系统学习', - 'metadata': '系统提示' - }) - else: - logger.info(f"成功获取到 {len(content_data['features'])} 个风格特征") - - except Exception as e: - logger.error(f"获取风格特征失败: {e}", exc_info=True) - content_data['features'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f'获取特征数据时出错: {str(e)}', - 'metadata': '错误信息' - }) - - if db_manager: - try: - # 获取学习历程记录 - 使用现有的方法 - logger.info("开始获取学习历程记录...") - message_stats = await db_manager.get_messages_statistics() - logger.debug(f"获取到消息统计: {message_stats}") - - content_data['history'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f"系统统计:\n总消息数: {message_stats.get('total_messages', 0)}条\n已筛选: {message_stats.get('filtered_messages', 0)}条\n待学习: {message_stats.get('unused_filtered_messages', 0)}条", - 'metadata': '实时统计' - }) - logger.info(f"成功添加学习历程记录") - except Exception as e: - logger.warning(f"获取学习历程记录失败: {e}") - 
content_data['history'].append({ - 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'text': f'获取历程数据时出错: {str(e)}', - 'metadata': '错误信息' - }) - - # 汇总所有获取的数据并记录最终状态 - logger.info("完成所有学习内容数据获取,开始汇总统计...") - total_dialogues = len(content_data['dialogues']) - total_analysis = len(content_data['analysis']) - total_features = len(content_data['features']) - total_history = len(content_data['history']) - - logger.info(f"内容数据汇总: 对话示例={total_dialogues}条, 分析结果={total_analysis}条, " - f"特征数据={total_features}条, 历程记录={total_history}条") - - # 检查数据完整性 - if total_dialogues == 0 and total_analysis == 0 and total_features == 0: - logger.warning("所有主要数据源都为空,可能系统尚未进行学习或数据库存在问题") - else: - logger.info("成功获取学习内容数据,数据完整性良好") - - # 更新缓存 - _style_learning_content_cache = content_data - _style_learning_content_cache_time = current_time - logger.info(f"已更新学习内容缓存(TTL: {_style_learning_content_cache_ttl}秒)") - - logger.info("get_style_learning_content_text API请求处理完成") - return jsonify(content_data) - - except Exception as e: - logger.error(f"get_style_learning_content_text API处理失败: {e}", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/style_learning/clear_cache", methods=["POST"]) -@require_auth -async def clear_style_learning_cache(): - """清除学习内容缓存""" - global _style_learning_content_cache, _style_learning_content_cache_time - try: - _style_learning_content_cache = None - _style_learning_content_cache_time = None - logger.info("已清除学习内容缓存") - return jsonify({'success': True, 'message': '缓存已清除'}) - except Exception as e: - logger.error(f"清除缓存失败: {e}") - return jsonify({'success': False, 'error': str(e)}), 500 - -# 新增的高级功能API端点 - -@api_bp.route("/advanced/data_analytics") -@require_auth -async def get_data_analytics(): - """获取数据分析与可视化""" - try: - from .core.factory import FactoryManager - - # 获取工厂管理器 - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建数据分析服务 - data_analytics_service = 
component_factory.create_data_analytics_service() - - group_id = request.args.get('group_id', 'default') - days = int(request.args.get('days', '30')) - - # 获取真实的分析数据 - learning_trajectory = await data_analytics_service.generate_learning_trajectory_chart(group_id, days) - user_activity_heatmap = await data_analytics_service.generate_user_activity_heatmap(group_id, days) - social_network = await data_analytics_service.generate_social_network_graph(group_id) - - analytics_data = { - "learning_trajectory": learning_trajectory, - "user_activity_heatmap": user_activity_heatmap, - "social_network": social_network - } - - return jsonify(analytics_data) - - except Exception as e: - return jsonify({"error": f"获取数据分析失败: {str(e)}"}), 500 - -@api_bp.route("/advanced/learning_status") -@require_auth -async def get_advanced_learning_status(): - """获取高级学习状态""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建高级学习服务 - advanced_learning_service = component_factory.create_advanced_learning_service() - - group_id = request.args.get('group_id', 'default') - - # 获取真实的高级学习状态 - status = await advanced_learning_service.get_learning_status(group_id) - - return jsonify(status) - - except Exception as e: - return jsonify({"error": f"获取高级学习状态失败: {str(e)}"}), 500 - -@api_bp.route("/advanced/interaction_status") -@require_auth -async def get_interaction_status(): - """获取交互增强状态""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建增强交互服务 - interaction_service = component_factory.create_enhanced_interaction_service() - - group_id = request.args.get('group_id', 'default') - - # 获取真实的交互状态 - status = await interaction_service.get_interaction_status(group_id) - - return jsonify(status) - - except Exception as e: - return jsonify({"error": f"获取交互状态失败: {str(e)}"}), 500 - 
-@api_bp.route("/advanced/intelligence_status") -@require_auth -async def get_intelligence_status(): - """获取智能化状态""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建智能化服务 - intelligence_service = component_factory.create_intelligence_enhancement_service() - - group_id = request.args.get('group_id', 'default') - - # 获取真实的智能化状态 - status = await intelligence_service.get_intelligence_status(group_id) - - return jsonify(status) - - except Exception as e: - return jsonify({"error": f"获取智能化状态失败: {str(e)}"}), 500 - -@api_bp.route("/advanced/trigger_context_switch", methods=["POST"]) -@require_auth -async def trigger_context_switch(): - """手动触发情境切换""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建高级学习服务 - advanced_learning_service = component_factory.create_advanced_learning_service() - - data = await request.get_json() - group_id = data.get('group_id', 'default') - target_context = data.get('target_context', 'casual') - - # 调用实际的情境切换功能 - result = await advanced_learning_service.trigger_context_switch(group_id, target_context) - - return jsonify(result) - - except Exception as e: - return jsonify({"error": f"情境切换失败: {str(e)}"}), 500 - -@api_bp.route("/advanced/generate_recommendations", methods=["POST"]) -@require_auth -async def generate_recommendations(): - """生成个性化推荐""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - - # 创建智能化服务 - intelligence_service = component_factory.create_intelligence_enhancement_service() - - data = await request.get_json() - group_id = data.get('group_id', 'default') - user_id = data.get('user_id', 'user_1') - - # 调用实际的个性化推荐功能 - recommendations = await intelligence_service.generate_personalized_recommendations( - group_id, 
user_id, data - ) - - # 转换为字典格式 - recommendations_dict = [ - { - "type": rec.recommendation_type, - "content": rec.content, - "confidence": rec.confidence, - "reasoning": rec.reasoning - } - for rec in recommendations - ] - - return jsonify({"recommendations": recommendations_dict}) - - except Exception as e: - return jsonify({"error": f"生成推荐失败: {str(e)}"}), 500 - -@api_bp.route("/style_learning/stats", methods=["GET"]) -@require_auth -async def get_style_learning_stats(): - """获取对���风格学习统计数据""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - - # 获取表达模式学习器 - component_factory = factory_manager.get_component_factory() - expression_learner = component_factory.create_expression_pattern_learner() - - # 获取数据库管理器 - db_manager = service_factory.create_database_manager() - - # 获取基本统计信息 - stats = { - 'style_types_count': 0, - 'avg_confidence': 0, - 'total_samples': 0, # 改为统计原始消息总数 - 'latest_update': '--', - 'learning_groups': [], - 'style_features': [] - } - - try: - # 先统计数据库中的原始消息总数(用于前端显示) - from sqlalchemy import select, func - from .models.orm import RawMessage as RawMsgModel - - async with db_manager.get_session() as session: - stmt = select(func.count()).select_from(RawMsgModel).where( - RawMsgModel.sender_id != 'bot' - ) - result = await session.execute(stmt) - total_samples = result.scalar() or 0 - stats['total_samples'] = total_samples - - # 获取所有群组的表达模式(用于其他统计) - # 优先使用 SQLAlchemy 数据库管理器,失败时自动降级到传统实现 - group_patterns = {} - try: - group_patterns = await db_manager.get_all_expression_patterns() - logger.debug(f"[WebUI] 使用 SQLAlchemy 获取表达模式: {len(group_patterns)} 个群组") - except Exception as e: - logger.warning(f"[WebUI] 获取表达模式失败,尝试使用表达模式学习器: {e}") - # 降级到表达模式学习器方法 - if hasattr(expression_learner, 'get_all_group_patterns'): - group_patterns = await expression_learner.get_all_group_patterns() - - if group_patterns: - total_confidence = 0 - pattern_count = 0 - 
style_types = set() - - for group_id, patterns in group_patterns.items(): - for pattern in patterns: - # 处理字典和对象两种格式 - if isinstance(pattern, dict): - style_types.add(pattern.get('style_type', 'general')) - total_confidence += pattern.get('weight', 0.5) - else: - style_types.add(getattr(pattern, 'style_type', 'general')) - total_confidence += getattr(pattern, 'weight', 0.5) - pattern_count += 1 - - stats['style_types_count'] = len(style_types) - stats['avg_confidence'] = round((total_confidence / pattern_count * 100) if pattern_count > 0 else 0, 1) - # 不再覆盖total_samples,保持使用原始消息总数 - - # 获取最新更新时间 - latest_time = 0 - for group_id, patterns in group_patterns.items(): - for pattern in patterns: - if hasattr(pattern, 'created_time'): - latest_time = max(latest_time, pattern.created_time) - - if latest_time > 0: - import time - from datetime import datetime - stats['latest_update'] = datetime.fromtimestamp(latest_time).strftime('%Y-%m-%d %H:%M') - - # 获取学习群组列表 - stats['learning_groups'] = list(group_patterns.keys()) if group_patterns else [] - - # 提取风格特征 - if group_patterns: - style_features = [] - for group_id, patterns in group_patterns.items(): - for pattern in patterns[:5]: # 只取前5个作为展示 - if hasattr(pattern, 'situation') and hasattr(pattern, 'expression'): - style_features.append({ - 'situation': pattern.situation, - 'expression': pattern.expression, - 'weight': getattr(pattern, 'weight', 0.5), - 'group_id': group_id - }) - - stats['style_features'] = style_features[:10] # 最多返回10个特征 - - except Exception as e: - logger.warning(f"获取表达模式统计失败: {e}") - - return jsonify(stats) - - except Exception as e: - logger.error(f"获取风格学习统计失败: {e}") - return jsonify({"error": f"获取统计数据失败: {str(e)}"}), 500 - -@api_bp.route("/style_learning/content", methods=["GET"]) -@require_auth -async def get_style_learning_content(): - """获取对话风格学习内容文本""" - try: - from .core.factory import FactoryManager - import os - - factory_manager = FactoryManager() - - # 获取数据库管理器 - service_factory = 
factory_manager.get_service_factory() - db_manager = service_factory.create_database_manager() - - # 获取消息关系分析器 - relationship_analyzer = service_factory.create_message_relationship_analyzer() - - content = { - 'dialogue_content': '', - 'analysis_content': '', - 'features_content': '', - 'history_content': '' - } - - group_id = request.args.get('group_id', 'default') - - try: - # 1. 获取对话示例文本 - recent_messages = await db_manager.get_recent_filtered_messages(group_id, limit=20) - if recent_messages: - relationships = await relationship_analyzer.analyze_message_relationships(recent_messages, group_id) - conversation_pairs = await relationship_analyzer.get_conversation_pairs(relationships) - - if conversation_pairs: - dialogue_lines = ["*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:"] - for sender_content, reply_content in conversation_pairs[:5]: - dialogue_lines.append(f"A:{sender_content}") - dialogue_lines.append(f"B:{reply_content}") - content['dialogue_content'] = "\n".join(dialogue_lines) - else: - content['dialogue_content'] = "暂无对话示例数据" - else: - content['dialogue_content'] = "暂无消息数据" - - # 2. 获取风格分析结果 - component_factory = factory_manager.get_component_factory() - expression_learner = component_factory.create_expression_pattern_learner() - - try: - patterns = await expression_learner.get_expression_patterns(group_id, limit=10) - if patterns: - analysis_lines = ["*Communication patterns learned from all user interactions:"] - for i, pattern in enumerate(patterns[:4], 1): - situation = getattr(pattern, 'situation', '未知情境') - expression = getattr(pattern, 'expression', '未知表达') - analysis_lines.append(f"{i}. 在{situation}时,群组用户倾向于使用\"{expression}\"这样的表达") - content['analysis_content'] = "\n".join(analysis_lines) - else: - content['analysis_content'] = "*Communication patterns learned from all user interactions:\n1. 保持自然流畅的对话风格\n2. 
根据语境调整回复的正式程度" - except Exception as e: - logger.warning(f"获取表达模式失败: {e}") - content['analysis_content'] = "*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n1. 保持自然流畅的对话风格\n2. 根据语境调整回复的正式程度" - - # 3. 获取提炼的风格特征 - try: - patterns = await expression_learner.get_expression_patterns(group_id, limit=15) - if patterns: - features_lines = ["群组表达风格特征:"] - for i, pattern in enumerate(patterns[:8], 1): - situation = getattr(pattern, 'situation', '通用情境') - expression = getattr(pattern, 'expression', '未知表达') - weight = getattr(pattern, 'weight', 0.5) - features_lines.append(f"{i}. {situation}: \"{expression}\" (置信度: {weight:.2f})") - content['features_content'] = "\n".join(features_lines) - else: - content['features_content'] = "暂无提炼的风格特征" - except Exception as e: - logger.warning(f"获取风格特征失败: {e}") - content['features_content'] = "暂无提炼的风格特征" - - # 4. 获取学习历程记录 - try: - # 从数据库获取学习历史记录 - learning_sessions = await db_manager.get_learning_sessions(group_id, limit=5) - if learning_sessions: - history_lines = ["学习历程记录:"] - for session in learning_sessions: - timestamp = session.get('end_time', session.get('start_time', 0)) - if timestamp: - import time - from datetime import datetime - time_str = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M') - style_updates = session.get('style_updates', 0) - total_messages = session.get('total_messages', 0) - history_lines.append(f"• {time_str}: 处理{total_messages}条消息,更新{style_updates}个风格") - content['history_content'] = "\n".join(history_lines) - else: - content['history_content'] = "暂无学习历程记录" - except Exception as e: - logger.warning(f"获取学习历史失败: {e}") - content['history_content'] = "暂无学习历程记录" - - except Exception as e: - logger.error(f"获取学习内容失败: {e}") - content = { - 'dialogue_content': f"获取对话内容失败: {str(e)}", - 'analysis_content': f"获取分析内容失败: {str(e)}", - 'features_content': f"获取特征内容失败: {str(e)}", - 'history_content': f"获取历程记录失败: {str(e)}" - } - - return jsonify(content) - - 
except Exception as e: - logger.error(f"获取风格学习内容失败: {e}") - return jsonify({"error": f"获取学习内容失败: {str(e)}"}), 500 - -@api_bp.route("/style_learning/trigger", methods=["POST"]) -@require_auth -async def trigger_style_learning(): - """手动触发对话风格学习""" - try: - data = await request.get_json() - group_id = data.get('group_id', 'default') - - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - component_factory = factory_manager.get_component_factory() - service_factory = factory_manager.get_service_factory() - - # 获取表达模式学习器 - expression_learner = component_factory.create_expression_pattern_learner() - db_manager = service_factory.create_database_manager() - - # 获取最近的原始消息 - recent_messages = await db_manager.get_recent_raw_messages(group_id, limit=30) - - if not recent_messages or len(recent_messages) < 3: - return jsonify({ - "success": False, - "message": f"群组 {group_id} 消息数量不足({len(recent_messages) if recent_messages else 0}条),无法进行学习", - "patterns_count": 0 - }) - - # 转换为 MessageData 格式 - from .core.interfaces import MessageData - import time - - message_data_list = [] - for msg in recent_messages: - if msg.get('sender_id') != "bot": # 不学习机器人的消息 - message_data = MessageData( - sender_id=msg.get('sender_id', ''), - sender_name=msg.get('sender_name', ''), - message=msg.get('message', ''), - group_id=group_id, - timestamp=msg.get('timestamp', time.time()), - platform=msg.get('platform', 'default'), - message_id=msg.get('message_id'), - reply_to=msg.get('reply_to') - ) - message_data_list.append(message_data) - - if len(message_data_list) < 3: - return jsonify({ - "success": False, - "message": f"有效用户消息数量不足({len(message_data_list)}条),无法进行学习", - "patterns_count": 0 - }) - - # 启动表达模式学习器 - if hasattr(expression_learner, '_status') and expression_learner._status.value != 'running': - await expression_learner.start() - - # 强制触发学习 - if hasattr(expression_learner, 'last_learning_times'): - expression_learner.last_learning_times[group_id] = 0 # 
重置时间以强制学习 - - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - - if learning_success: - # 获取学习到的模式数量 - patterns = await expression_learner.get_expression_patterns(group_id, limit=20) - patterns_count = len(patterns) if patterns else 0 - - return jsonify({ - "success": True, - "message": f"群组 {group_id} 风格学习成功", - "patterns_count": patterns_count, - "processed_messages": len(message_data_list) - }) - else: - return jsonify({ - "success": False, - "message": "风格学习未产生有效结果", - "patterns_count": 0 - }) - - except Exception as e: - logger.error(f"触发风格学习失败: {e}") - return jsonify({ - "success": False, - "error": f"触发学习失败: {str(e)}", - "patterns_count": 0 - }), 500 - -@api_bp.route("/groups/info", methods=["GET"]) -@require_auth -async def get_groups_info(): - """获取所有群组的详细信息""" - logger.info("开始获取所有群组信息...") - try: - groups_info = { - 'total_groups': 0, - 'groups': [], - 'database_status': {}, - 'recommendations': [] - } - - if not database_manager: - return jsonify({'error': '数据库管理器不可用'}), 500 - - # 使用 ORM 查询(支持跨线程 event loop) - from sqlalchemy import select, func, and_ - from .models.orm import RawMessage as RawMsgORM, FilteredMessage as FilteredMsgORM - - # 1. 检查数据库总体状态 - logger.debug("检查数据库总体状态...") - stats = await database_manager.get_messages_statistics() - total_raw_messages = stats.get('total_messages', 0) - total_filtered_messages = stats.get('filtered_messages', 0) - - groups_info['database_status'] = { - 'total_raw_messages': total_raw_messages, - 'total_filtered_messages': total_filtered_messages, - 'tables_exist': True - } - - logger.info(f"数据库状态: 原始消息 {total_raw_messages} 条, 筛选消息 {total_filtered_messages} 条") - - # 2. 
获取所有群组的详细信息 - if total_raw_messages > 0: - logger.debug("获取所有群组的详细统计...") - async with database_manager.get_session() as session: - # 查询各群组的统计信息 - stmt = select( - RawMsgORM.group_id, - func.count().label('message_count'), - func.min(RawMsgORM.timestamp).label('earliest_message'), - func.max(RawMsgORM.timestamp).label('latest_message'), - func.count(func.distinct(RawMsgORM.sender_id)).label('unique_senders') - ).where( - and_( - RawMsgORM.group_id.isnot(None), - RawMsgORM.group_id != '' - ) - ).group_by( - RawMsgORM.group_id - ).order_by( - func.count().desc() - ) - result = await session.execute(stmt) - group_rows = result.all() - - for row in group_rows: - group_id, message_count, earliest_ts, latest_ts, unique_senders = row - - # 获取该群组的筛选消息统计 - async with database_manager.get_session() as session: - filtered_stmt = select(func.count()).select_from(FilteredMsgORM).where( - FilteredMsgORM.group_id == group_id - ) - filtered_result = await session.execute(filtered_stmt) - filtered_count = filtered_result.scalar() or 0 - - # 计算时间范围 - import datetime - earliest_date = datetime.datetime.fromtimestamp(earliest_ts).strftime('%Y-%m-%d %H:%M:%S') if earliest_ts else 'N/A' - latest_date = datetime.datetime.fromtimestamp(latest_ts).strftime('%Y-%m-%d %H:%M:%S') if latest_ts else 'N/A' - - # 计算活跃度 - days_span = (latest_ts - earliest_ts) / 86400 if earliest_ts and latest_ts else 0 - avg_messages_per_day = message_count / max(1, days_span) if days_span > 0 else 0 - - group_info = { - 'group_id': group_id, - 'message_count': message_count, - 'filtered_count': filtered_count, - 'unique_senders': unique_senders, - 'earliest_message': earliest_date, - 'latest_message': latest_date, - 'days_span': round(days_span, 1), - 'avg_messages_per_day': round(avg_messages_per_day, 1), - 'learning_potential': 'high' if message_count > 100 and filtered_count > 10 else 'medium' if message_count > 20 else 'low' - } - - groups_info['groups'].append(group_info) - logger.debug(f"群组 {group_id}: 
{message_count} 条消息, {filtered_count} 条筛选, {unique_senders} 个用户") - - groups_info['total_groups'] = len(groups_info['groups']) - logger.info(f"找到 {groups_info['total_groups']} 个有消息记录的群组") - - else: - logger.warning("数据库中没有任何原始消息记录") - groups_info['recommendations'] = [ - "数据库中没有消息记录,这可能是因为:", - "1. 插件刚刚安装,还没有收集到消息", - "2. 消息收集功能未启用或配置错误", - "3. 群聊中没有足够的消息活动", - "建议: 在群聊中发送一些消息,然后重新检查" - ] - - # 3. 添加学习建议 - 修改为推荐所有群组都进行分析 - if groups_info['total_groups'] > 0: - groups_info['recommendations'] = [ - f"发现 {groups_info['total_groups']} 个群组,建议对所有群组进行完整的关系分析和风格学习:", - "• 使用 /groups/analyze_all 对所有群组进行关系分析", - "• 使用 /groups/style_learning_all 对所有群组进行表达模式和风格分析", - f"• 总计可分析原始消息: {total_raw_messages} 条" - ] - - # 为每个群组添加分析状态 - for group in groups_info['groups']: - if group['message_count'] > 50: - group['analysis_ready'] = True - group['analysis_recommendation'] = "可进行完整分析" - elif group['message_count'] > 10: - group['analysis_ready'] = True - group['analysis_recommendation'] = "可进行基础分析" - else: - group['analysis_ready'] = False - group['analysis_recommendation'] = "消息数量较少,建议积累更多消息" - - logger.info("群组信息获取完成") - return jsonify(groups_info) - - except Exception as e: - logger.error(f"获取群组信息失败: {e}", exc_info=True) - return jsonify({'error': str(e)}), 500 - -@api_bp.route("/groups/analyze_all", methods=["POST"]) -@require_auth -async def analyze_all_groups(): - """对所有群组进行关系分析和表达模式分析""" - logger.info("开始对所有群组进行关系分析...") - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - component_factory = factory_manager.get_component_factory() - - # 获取关系分析器和表达模式学习器 - relationship_analyzer = service_factory.create_message_relationship_analyzer() - expression_learner = component_factory.create_expression_pattern_learner() - db_manager = service_factory.create_database_manager() - - # 获取所有群组(ORM 查询,支持跨线程 event loop) - from sqlalchemy import select, func, and_ - from .models.orm import RawMessage as 
RawMsgGroupQuery - - async with db_manager.get_session() as session: - stmt = select( - RawMsgGroupQuery.group_id, - func.count().label('message_count') - ).where( - and_( - RawMsgGroupQuery.group_id.isnot(None), - RawMsgGroupQuery.group_id != '' - ) - ).group_by( - RawMsgGroupQuery.group_id - ).having( - func.count() >= 10 - ).order_by( - func.count().desc() - ) - result = await session.execute(stmt) - all_groups = result.all() - - if not all_groups: - return jsonify({ - 'success': False, - 'message': '没有找到足够消息的群组进行分析', - 'analyzed_groups': [] - }) - - analysis_results = [] - - for group_id, message_count in all_groups: - logger.info(f"开始分析群组 {group_id} (消息数: {message_count})") - - try: - # 1. 获取原始消息 - recent_messages = await db_manager.get_recent_raw_messages(group_id, limit=200) - - if not recent_messages or len(recent_messages) < 5: - logger.warning(f"群组 {group_id} 消息数量不足,跳过分析") - continue - - # 2. 过滤和格式化消息 - formatted_messages = [] - for msg in recent_messages: - message_content = msg.get('message', '') - sender_id = msg.get('sender_id', '') - - # 基础过滤 - if len(message_content.strip()) < 5 or len(message_content) > 500: - continue - if sender_id == "bot": - continue - if message_content.strip() in ['', '???', '。。。', '...', '嗯', '哦', '额']: - continue - - # @符号处理 - import re - processed_message = message_content - if '@' in message_content: - at_pattern = r'@[^\s]+\s+' - processed_message = re.sub(at_pattern, '', message_content).strip() - if len(processed_message.strip()) < 5: - continue - - formatted_msg = { - 'id': msg.get('id'), - 'sender_id': sender_id, - 'sender_name': msg.get('sender_name', ''), - 'message': processed_message, - 'group_id': msg.get('group_id'), - 'timestamp': msg.get('timestamp'), - 'platform': msg.get('platform', 'default') - } - formatted_messages.append(formatted_msg) - - logger.info(f"群组 {group_id} 过滤后可用消息数: {len(formatted_messages)}") - - if len(formatted_messages) < 3: - logger.warning(f"群组 {group_id} 过滤后消息数量不足,跳过分析") - continue - - 
# 3. 进行关系分析 - logger.info(f"开始分析群组 {group_id} 的消息关系...") - relationships = await relationship_analyzer.analyze_message_relationships(formatted_messages, group_id) - - # 4. 提取对话对 - conversation_pairs = await relationship_analyzer.get_conversation_pairs(relationships) - - # 5. 转换为MessageData格式进行表达模式学习 - from .core.interfaces import MessageData - message_data_list = [] - for msg in formatted_messages: - message_data = MessageData( - sender_id=msg['sender_id'], - sender_name=msg['sender_name'], - message=msg['message'], - group_id=msg['group_id'], - timestamp=msg['timestamp'], - platform=msg['platform'], - message_id=msg['id'], - reply_to=None - ) - message_data_list.append(message_data) - - # 6. 启动表达模式学习器并触发学习 - if hasattr(expression_learner, '_status') and expression_learner._status.value != 'running': - await expression_learner.start() - - # 强制学习(重置时间限制) - if hasattr(expression_learner, 'last_learning_times'): - expression_learner.last_learning_times[group_id] = 0 - - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - - # 7. 
获取学习结果 - patterns = await expression_learner.get_expression_patterns(group_id, limit=10) - patterns_count = len(patterns) if patterns else 0 - - analysis_result = { - 'group_id': group_id, - 'message_count': message_count, - 'processed_messages': len(formatted_messages), - 'conversation_pairs': len(conversation_pairs) if conversation_pairs else 0, - 'expression_patterns': patterns_count, - 'learning_success': learning_success, - 'analysis_completed': True - } - - analysis_results.append(analysis_result) - logger.info(f"群组 {group_id} 分析完成: 对话对 {analysis_result['conversation_pairs']}, 表达模式 {patterns_count}") - - except Exception as e: - logger.error(f"分析群组 {group_id} 失败: {e}") - analysis_results.append({ - 'group_id': group_id, - 'message_count': message_count, - 'processed_messages': 0, - 'conversation_pairs': 0, - 'expression_patterns': 0, - 'learning_success': False, - 'analysis_completed': False, - 'error': str(e) - }) - - # 统计总结果 - successful_groups = [r for r in analysis_results if r.get('analysis_completed', False)] - total_conversation_pairs = sum(r.get('conversation_pairs', 0) for r in analysis_results) - total_expression_patterns = sum(r.get('expression_patterns', 0) for r in analysis_results) - - return jsonify({ - 'success': True, - 'message': f'所有群组分析完成', - 'summary': { - 'total_groups': len(all_groups), - 'successful_groups': len(successful_groups), - 'total_conversation_pairs': total_conversation_pairs, - 'total_expression_patterns': total_expression_patterns - }, - 'analyzed_groups': analysis_results - }) - - except Exception as e: - logger.error(f"分析所有群组失败: {e}", exc_info=True) - return jsonify({ - 'success': False, - 'error': f'分析失败: {str(e)}', - 'analyzed_groups': [] - }), 500 - -@api_bp.route("/groups/style_learning_all", methods=["POST"]) -@require_auth -async def style_learning_all_groups(): - """对所有群组进行风格学习并提交审查""" - logger.info("开始对所有群组进行风格学习...") - try: - from .core.factory import FactoryManager - import time - - factory_manager = 
FactoryManager() - service_factory = factory_manager.get_service_factory() - component_factory = factory_manager.get_component_factory() - - # 获取必要服务 - relationship_analyzer = service_factory.create_message_relationship_analyzer() - expression_learner = component_factory.create_expression_pattern_learner() - db_manager = service_factory.create_database_manager() - - # 获取所有群组(ORM 查询,支持跨线程 event loop) - from sqlalchemy import select, func, and_ - from .models.orm import RawMessage as RawMsgGroupQuery - - async with db_manager.get_session() as session: - stmt = select( - RawMsgGroupQuery.group_id, - func.count().label('message_count') - ).where( - and_( - RawMsgGroupQuery.group_id.isnot(None), - RawMsgGroupQuery.group_id != '' - ) - ).group_by( - RawMsgGroupQuery.group_id - ).having( - func.count() >= 10 - ).order_by( - func.count().desc() - ) - result = await session.execute(stmt) - all_groups = result.all() - - if not all_groups: - return jsonify({ - 'success': False, - 'message': '没有找到足够消息的群组进行风格学习', - 'style_learning_results': [] - }) - - style_learning_results = [] - - for group_id, message_count in all_groups: - logger.info(f"开始为群组 {group_id} 进行风格学习 (消息数: {message_count})") - - try: - # 1. 获取并处理消息(与analyze_all相同的逻辑) - recent_raw_messages = await db_manager.get_recent_raw_messages(group_id, limit=100) - - if not recent_raw_messages: - logger.warning(f"群组 {group_id} 没有原始消息,跳过风格学习") - continue - - # 2. 
过滤消息 - formatted_messages = [] - for msg in recent_raw_messages: - message_content = msg.get('message', '') - sender_id = msg.get('sender_id', '') - - # 使用相同的过滤逻辑 - if len(message_content.strip()) < 5 or len(message_content) > 500: - continue - if sender_id == "bot": - continue - if message_content.strip() in ['', '???', '。。。', '...', '嗯', '哦', '额']: - continue - - # @符号处理 - import re - processed_message = message_content - if '@' in message_content: - at_pattern = r'@[^\s]+\s+' - processed_message = re.sub(at_pattern, '', message_content).strip() - if len(processed_message.strip()) < 5: - continue - - formatted_msg = { - 'id': msg.get('id'), - 'sender_id': sender_id, - 'sender_name': msg.get('sender_name', ''), - 'message': processed_message, - 'group_id': msg.get('group_id'), - 'timestamp': msg.get('timestamp'), - 'platform': msg.get('platform', 'default') - } - formatted_messages.append(formatted_msg) - - if len(formatted_messages) < 3: - logger.warning(f"群组 {group_id} 过滤后消息数量不足,跳过风格学习") - continue - - # 3. 进行关系分析获取对话对 - relationships = await relationship_analyzer.analyze_message_relationships(formatted_messages, group_id) - conversation_pairs = await relationship_analyzer.get_conversation_pairs(relationships) - - if not conversation_pairs: - logger.warning(f"群组 {group_id} 未找到有效对话关系,跳过风格学习") - continue - - # 4. 生成对话内容(few shots格式) - dialogue_lines = [f"*Here are examples of real conversations between users in group {group_id}:"] - for sender_content, reply_content in conversation_pairs[:6]: # 取前6个对话对 - dialogue_lines.append(f"A:{sender_content}") - dialogue_lines.append(f"B:{reply_content}") - - dialogue_content = "\n".join(dialogue_lines) - - # 5. 进行表达模式学习 - patterns_learned = 0 - analysis_content = "*Communication style patterns observed in group conversations:\n1. 保持自然流畅的对话风格\n2. 根据语境调整回复的正式程度" - features_content = "提炼的风格特征:\n1. 自然对话风格\n2. 
适度的情感表达" - - try: - # 转换为MessageData格式 - from .core.interfaces import MessageData - message_data_list = [] - for msg in formatted_messages: - message_data = MessageData( - sender_id=msg['sender_id'], - sender_name=msg['sender_name'], - message=msg['message'], - group_id=msg['group_id'], - timestamp=msg['timestamp'], - platform=msg['platform'], - message_id=msg['id'], - reply_to=None - ) - message_data_list.append(message_data) - - # 启动并触发学习 - if hasattr(expression_learner, '_status') and expression_learner._status.value != 'running': - await expression_learner.start() - - if hasattr(expression_learner, 'last_learning_times'): - expression_learner.last_learning_times[group_id] = 0 - - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - - if learning_success: - patterns = await expression_learner.get_expression_patterns(group_id, limit=10) - if patterns: - patterns_learned = len(patterns) - - # 生成更详细的分析内容 - analysis_lines = [f"*Communication style patterns observed from all user interactions in {group_id}:"] - for i, pattern in enumerate(patterns[:4], 1): - situation = getattr(pattern, 'situation', '未知情境') - expression = getattr(pattern, 'expression', '未知表达') - analysis_lines.append(f"{i}. 当{situation}时,群组用户使用\"{expression}\"这样的表达") - analysis_content = "\n".join(analysis_lines) - - # 生成特征内容 - features_lines = [f"群组 {group_id} 对话风格特征:"] - for i, pattern in enumerate(patterns[:6], 1): - situation = getattr(pattern, 'situation', '未知情境') - expression = getattr(pattern, 'expression', '未知表达') - features_lines.append(f"{i}. {situation}: {expression}") - features_content = "\n".join(features_lines) - - except Exception as e: - logger.warning(f"群组 {group_id} 表达模式学习失败: {e}") - - # 6. 
生成完整的风格学习内容 - full_style_content = f"""## 真实对话示例 - 群组 {group_id} -{dialogue_content} - -## 群组风格分析 -{analysis_content} - -## {features_content} - -## 学习来源 -全群组风格学习 - 基于{len(conversation_pairs)}个真实用户对话对的深度分析 - -## 数据说明 -- 分析了群组 {group_id} 中任意用户之间的真实对话 -- 提取了用户间的对话关系和表达模式 ({patterns_learned} 个表达模式) -- 学习内容反映群组整体的对话风格特征 -- 处理原始消息: {len(recent_raw_messages)} 条,过滤后: {len(formatted_messages)} 条""" - - # 7. 提交到人格审查系统 - review_submitted = False - try: - # 使用智能置信度计算 - confidence_score = 0.85 # 默认值 - if intelligence_metrics_service: - try: - # 获取当前人格内容 - current_persona_content = "" - try: - persona_web_mgr = get_persona_web_manager() - if persona_web_mgr: - current_persona = await persona_web_mgr.get_default_persona() - current_persona_content = current_persona.get('prompt', '') - except: - pass - - # 计算智能置信度 - confidence_metrics = await intelligence_metrics_service.calculate_persona_confidence( - proposed_content=full_style_content, - original_content=current_persona_content, - learning_source=f"全群组风格学习-{group_id}", - message_count=len(formatted_messages), - llm_adapter=llm_client if llm_client else None - ) - confidence_score = confidence_metrics.overall_confidence - logger.info(f"智能置信度计算: {confidence_score:.3f} (详情: {confidence_metrics.evaluation_basis.get('method', 'unknown')})") - except Exception as conf_error: - logger.warning(f"智能置信度计算失败,使用默认值: {conf_error}") - - # 检查是否有人格学习审查方法 - if hasattr(db_manager, 'add_persona_learning_review'): - await db_manager.add_persona_learning_review( - group_id=group_id, - proposed_content=full_style_content, - learning_source=f"全群组风格学习-{group_id}", - confidence_score=confidence_score, - raw_analysis=f"基于{len(conversation_pairs)}个对话对和{patterns_learned}个表达模式", - metadata={ - "all_groups_learning": True, - "conversation_pairs": len(conversation_pairs), - "patterns_count": patterns_learned, - "messages_analyzed": len(formatted_messages), - "original_messages": len(recent_raw_messages) - } - ) - review_submitted = True - logger.info(f"群组 
{group_id} 风格学习审查已提交") - else: - # 回退方法:保存到通用审查记录 - await db_manager.save_persona_update_record({ - 'timestamp': time.time(), - 'group_id': group_id, - 'update_type': 'all_groups_style_learning', - 'original_content': '群组风格特征', - 'new_content': full_style_content, - 'reason': f'全群组风格学习-基于{len(conversation_pairs)}个对话对的关系分析', - 'status': 'pending' - }) - review_submitted = True - logger.info(f"群组 {group_id} 风格学习审查已保存") - - except Exception as e: - logger.error(f"群组 {group_id} 提交风格学习审查失败: {e}") - - learning_result = { - 'group_id': group_id, - 'message_count': message_count, - 'processed_messages': len(formatted_messages), - 'conversation_pairs': len(conversation_pairs), - 'expression_patterns': patterns_learned, - 'review_submitted': review_submitted, - 'learning_completed': True - } - - style_learning_results.append(learning_result) - logger.info(f"群组 {group_id} 风格学习完成: 对话对 {len(conversation_pairs)}, 模式 {patterns_learned}") - - except Exception as e: - logger.error(f"群组 {group_id} 风格学习失败: {e}") - style_learning_results.append({ - 'group_id': group_id, - 'message_count': message_count, - 'processed_messages': 0, - 'conversation_pairs': 0, - 'expression_patterns': 0, - 'review_submitted': False, - 'learning_completed': False, - 'error': str(e) - }) - - # 统计总结果 - successful_learning = [r for r in style_learning_results if r.get('learning_completed', False)] - total_reviews_submitted = sum(1 for r in style_learning_results if r.get('review_submitted', False)) - total_conversation_pairs = sum(r.get('conversation_pairs', 0) for r in style_learning_results) - total_expression_patterns = sum(r.get('expression_patterns', 0) for r in style_learning_results) - - return jsonify({ - 'success': True, - 'message': f'所有群组风格学习完成', - 'summary': { - 'total_groups': len(all_groups), - 'successful_learning': len(successful_learning), - 'reviews_submitted': total_reviews_submitted, - 'total_conversation_pairs': total_conversation_pairs, - 'total_expression_patterns': 
total_expression_patterns - }, - 'style_learning_results': style_learning_results - }) - - except Exception as e: - logger.error(f"所有群组风格学习失败: {e}", exc_info=True) - return jsonify({ - 'success': False, - 'error': f'风格学习失败: {str(e)}', - 'style_learning_results': [] - }), 500 - -@api_bp.route("/relearn", methods=["POST"]) -@require_auth -async def relearn_all(): - """重新学习按钮 - 包括风格重新学习""" - try: - # 处理空请求体的情况 - data = {} - try: - if request.is_json and await request.get_data(): - data = await request.get_json() - except Exception: - # 如果JSON解析失败,使用默认空字典 - data = {} - - # 获取实际的群组ID,如果没有指定则尝试从数据库中获取第一个有消息的群组 - group_id = data.get('group_id') - include_style_learning = data.get('include_style_learning', True) - - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - component_factory = factory_manager.get_component_factory() - db_manager = service_factory.create_database_manager() - - # 如果没有指定群组ID,自动检测有消息记录的群组 - if not group_id or group_id == 'default': - # 使用 ORM 查询(支持跨线程 event loop) - logger.info("正在检查数据库中的所有消息记录...") - stats = await db_manager.get_messages_statistics() - total_count = stats.get('total_messages', 0) - logger.info(f"raw_messages表中总共有 {total_count} 条记录") - - if total_count > 0: - # 通过 ORM session 查询各群组的消息统计 - from sqlalchemy import select, func, and_ - from .models.orm import RawMessage - - async with db_manager.get_session() as session: - stmt = select( - RawMessage.group_id, - func.count().label('message_count') - ).where( - and_( - RawMessage.group_id.isnot(None), - RawMessage.group_id != '' - ) - ).group_by( - RawMessage.group_id - ).order_by( - func.count().desc() - ) - result = await session.execute(stmt) - all_results = result.all() - - logger.info(f"数据库中发现的所有群组: {[(r[0], r[1]) for r in all_results] if all_results else '无'}") - - # 选择消息数最多的群组 - if all_results: - group_id = all_results[0][0] - message_count = all_results[0][1] - logger.info(f"自动选择群组ID: {group_id} 
(共有{message_count}条原始消息)") - else: - logger.warning("虽然有消息记录,但没有有效的群组ID") - group_id = 'default' # 兜底使用default - else: - # 没有任何消息,检查系统状态 - logger.warning("数据库中没有任何原始消息记录") - - filtered_count = stats.get('filtered_messages', 0) - logger.info(f"filtered_messages表中有 {filtered_count} 条记录") - - # 提供解决建议 - logger.warning("建议解决方案:") - logger.warning("1. 检查消息收集功能是否正常工作") - logger.warning("2. 确认群聊中有足够的消息") - logger.warning("3. 检查插件的消息捕获配置") - - group_id = 'default' # 兜底使用default - - results = { - 'success': True, - 'message': '', - 'group_id': group_id, # 返回实际使用的群组ID - 'progressive_learning': False, - 'style_learning': False, - 'processed_messages': 0, - 'new_patterns': 0, - 'persona_update_submitted': False, - 'errors': [], - 'total_messages': 0 - } - - try: - # 1. 重新执行渐进式学习 - progressive_learning = service_factory.create_progressive_learning() - db_manager = service_factory.create_database_manager() - - logger.info(f"开始重新学习群组 {group_id}...") - - # 检查消息数量(但不强制要求) - 添加连接重试逻辑 - logger.debug(f"开始获取群组 {group_id} 的消息统计...") - try: - stats = await db_manager.get_group_messages_statistics(group_id) - total_messages = stats.get('total_messages', 0) - results['total_messages'] = total_messages - logger.info(f"群组 {group_id} 消息统计: {total_messages} 条总消息") - except Exception as stats_error: - logger.warning(f"获取群组 {group_id} 消息统计失败: {stats_error}") - total_messages = 0 - results['total_messages'] = 0 - results['errors'].append(f"获取消息统计失败: {str(stats_error)}") - - # 执行渐进式学习批次 - try: - # ✅ 重新学习模式:传递 relearn_mode=True 以忽略"已处理"标记 - await progressive_learning._execute_learning_batch(group_id, relearn_mode=True) - results['progressive_learning'] = True - results['processed_messages'] = total_messages - logger.info(f"群组 {group_id} 渐进式学习重新执行完成(重新学习模式)") - except Exception as e: - error_msg = f"渐进式学习失败: {str(e)}" - results['errors'].append(error_msg) - logger.error(error_msg) - - # 2. 
风格重新学习(遵循原有逻辑:关系分析->A,B对话提取->按格式加入人格审查) - if include_style_learning: - try: - import time - logger.info(f"开始为群组 {group_id} 进行风格重新学习...") - - # 获取消息关系分析器 - relationship_analyzer = service_factory.create_message_relationship_analyzer() - - # 获取最近的原始消息用于风格分析(不需要筛选) - logger.info(f"正在为群组 {group_id} 获取原始消息进行风格分析...") - recent_raw_messages = await db_manager.get_recent_raw_messages(group_id, limit=100) - logger.info(f"群组 {group_id} 获取到 {len(recent_raw_messages) if recent_raw_messages else 0} 条原始消息") - - if recent_raw_messages: - # 直接使用原始消息,不进行筛选过滤 - # 将原始消息转换为统一格式用于风格学习 - formatted_messages = [] - for msg in recent_raw_messages: - message_content = msg.get('message', '') - sender_id = msg.get('sender_id', '') - - # 只进行最基本的过滤: 跳过机器人消息和完全空白的消息 - if sender_id == "bot": - continue - if not message_content.strip(): - continue - - # 保持消息原样,不进行任何内容处理和筛选 - formatted_msg = { - 'id': msg.get('id'), - 'sender_id': sender_id, - 'sender_name': msg.get('sender_name', ''), - 'message': message_content, # 保持原始消息内容 - 'group_id': msg.get('group_id'), - 'timestamp': msg.get('timestamp'), - 'platform': msg.get('platform', 'default') - } - formatted_messages.append(formatted_msg) - - logger.info(f"群组 {group_id} 使用未筛选的原始消息数: {len(formatted_messages)}") - - # ========== 功能1: 表达模式学习(风格学习) - 使用所有原始消息 ========== - # 这部分独立运行,不依赖关系分析 - component_factory = factory_manager.get_component_factory() - expression_learner = component_factory.create_expression_pattern_learner() - - # 将原始消息转换为MessageData格式进行风格学习 - from .core.interfaces import MessageData - import time - - message_data_list = [] - for msg in formatted_messages: - message_data = MessageData( - sender_id=msg['sender_id'], - sender_name=msg['sender_name'], - message=msg['message'], # 原始消息内容 - group_id=msg['group_id'], - timestamp=msg['timestamp'], - platform=msg['platform'], - message_id=msg['id'], - reply_to=None - ) - message_data_list.append(message_data) - - logger.info(f"开始为群组 {group_id} 进行表达模式学习(使用未筛选消息),消息数: {len(message_data_list)}") - 
- # 触发表达模式学习 - learning_success = False - if message_data_list and len(message_data_list) >= 5: # 至少5条消息 - try: - # 启动表达模式学习器 - if hasattr(expression_learner, '_status') and expression_learner._status.value != 'running': - await expression_learner.start() - - # 强制重新学习(无时间限制) - if hasattr(expression_learner, 'last_learning_times'): - expression_learner.last_learning_times[group_id] = 0 # 重置时间 - - # 触发学习 - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - logger.info(f"群组 {group_id} 表达模式学习结果: {learning_success}") - results['style_learning'] = True - results['messages_analyzed'] = len(message_data_list) - - except Exception as learning_error: - logger.error(f"表达模式学习失败: {learning_error}", exc_info=True) - learning_success = False - results['errors'].append(f"表达模式学习失败: {str(learning_error)}") - else: - logger.warning(f"群组 {group_id} 消息数不足({len(message_data_list)}条),需要至少5条消息") - - - # ========== 功能2: 消息关系分析 - 用于生成人格审查数据 ========== - # 这部分用于分析A→B对话对,生成人格更新审查申请 - logger.info(f"开始分析群组 {group_id} 的消息关系(用于人格审查)...") - relationships = await relationship_analyzer.analyze_message_relationships(formatted_messages, group_id) - - # 提取A,B对话对 - conversation_pairs = await relationship_analyzer.get_conversation_pairs(relationships) - logger.info(f"群组 {group_id} 提取到 {len(conversation_pairs) if conversation_pairs else 0} 个对话对") - - # 只有当有对话对时,才生成人格审查数据 - if conversation_pairs and len(conversation_pairs) > 0: - # 步骤3: 按照严格格式生成对话内容 - # 说明:这里的A、B代表群组中任意两个用户之间的对话,用于学习真实的对话风格 - dialogue_lines = ["*Here are examples of real conversations between users in this group:"] - for sender_content, reply_content in conversation_pairs[:8]: # 取更多对话对用于重新学习 - dialogue_lines.append(f"A:{sender_content}") - dialogue_lines.append(f"B:{reply_content}") - - dialogue_content = "\n".join(dialogue_lines) - - # 步骤4: 获取已经学习的表达模式(使用之前独立运行的风格学习结果) - analysis_content = "*Communication style patterns observed in group conversations:\n1. 保持自然流畅的对话风格\n2. 
根据语境调整回复的正式程度" - features_content = "提炼的风格特征:\n1. 自然对话风格\n2. 适度的情感表达" - llm_raw_response = "" # 保存LLM原始响应 - - try: - patterns = await expression_learner.get_expression_patterns(group_id, limit=10) - if patterns: - # 生成分析内容 - 基于任何人与任何人之间的对话分析 - analysis_lines = ["*Communication style patterns observed from all user interactions:"] - for i, pattern in enumerate(patterns[:4], 1): - situation = getattr(pattern, 'situation', '未知情境') - expression = getattr(pattern, 'expression', '未知表达') - analysis_lines.append(f"{i}. 当{situation}时,群组用户使用\"{expression}\"这样的表达") - analysis_content = "\n".join(analysis_lines) - - # 生成特征内容 - 反映群组整体的对话风格 - features_lines = ["群组对话风格特征:"] - for i, pattern in enumerate(patterns[:6], 1): - situation = getattr(pattern, 'situation', '未知情境') - expression = getattr(pattern, 'expression', '未知表达') - features_lines.append(f"{i}. {situation}: {expression}") - features_content = "\n".join(features_lines) - - # 构建LLM响应格式(用于前端显示) - llm_response_lines = [] - for pattern in patterns[:10]: - situation = getattr(pattern, 'situation', '') - expression = getattr(pattern, 'expression', '') - if situation and expression: - llm_response_lines.append(f'当"{situation}"时,使用"{expression}"') - llm_raw_response = "\n".join(llm_response_lines) - - results['new_patterns'] = len(patterns) - except Exception as e: - logger.warning(f"获取表达模式失败: {e}") - - # 步骤5: 生成完整的风格学习内容 - full_style_content = f"""## 真实对话示例 -{dialogue_content} - -## 群组风格分析 -{analysis_content} - -## {features_content} - -## 学习来源 -重新学习模式 - 基于{len(conversation_pairs)}个真实用户对话对的深度分析 - -## 数据说明 -- 分析了群组中任意用户之间的真实对话 -- 提取了用户间的对话关系和表达模式 -- 学习内容反映群组整体的对话风格特征""" - - # 步骤6: 提交到人格审查系统 - try: - # 获取原始消息总数(未筛选的) - total_raw_messages = len(recent_raw_messages) - - # 使用智能置信度计算 - confidence_score = 0.85 # 默认值 - if intelligence_metrics_service: - try: - # 获取当前人格内容 - current_persona_content = "" - try: - persona_web_mgr = get_persona_web_manager() - if persona_web_mgr: - current_persona = await 
persona_web_mgr.get_default_persona() - current_persona_content = current_persona.get('prompt', '') - except: - pass - - # 计算智能置信度 - confidence_metrics = await intelligence_metrics_service.calculate_persona_confidence( - proposed_content=full_style_content, - original_content=current_persona_content, - learning_source="重新学习-关系分析", - message_count=len(formatted_messages), - llm_adapter=llm_client if llm_client else None - ) - confidence_score = confidence_metrics.overall_confidence - logger.info(f"重新学习智能置信度: {confidence_score:.3f}") - except Exception as conf_error: - logger.warning(f"智能置信度计算失败,使用默认值: {conf_error}") - - # 检查是否有add_persona_learning_review方法 - if hasattr(db_manager, 'add_persona_learning_review'): - # ✅ 获取当前人格作为 original_content - original_persona_content = "" - try: - persona_web_mgr = get_persona_web_manager() - if persona_web_mgr: - current_persona = await persona_web_mgr.get_default_persona() - original_persona_content = current_persona.get('prompt', '') - except Exception as e: - logger.warning(f"获取原人格失败: {e}") - original_persona_content = "" - - # ✅ 构建完整的新人格内容(原人格 + 风格学习内容) - full_new_persona = original_persona_content + "\n\n" + full_style_content if original_persona_content else full_style_content - - await db_manager.add_persona_learning_review( - group_id=group_id, - proposed_content=full_style_content, # 增量内容 - learning_source=UPDATE_TYPE_STYLE_LEARNING, # ✅ 使用常量 - confidence_score=confidence_score, - raw_analysis=llm_raw_response if llm_raw_response else f"基于{len(conversation_pairs)}个对话对和{results.get('new_patterns', 0)}个表达模式", - metadata={ - "relearn_triggered": True, - "conversation_pairs": len(conversation_pairs), - "patterns_count": results.get('new_patterns', 0), - "total_raw_messages": total_raw_messages, # 原始消息总数 - "messages_analyzed": len(formatted_messages), # 实际分析的消息数 - "llm_response": llm_raw_response, # LLM原始响应 - "features_content": features_content, # 风格特征内容 - "incremental_content": full_style_content, # ✅ 增量内容 - 
"incremental_start_pos": len(original_persona_content) + 2 if original_persona_content else 0 # ✅ 高亮位置 - }, - original_content=original_persona_content, # ✅ 传递原人格 - new_content=full_new_persona # ✅ 传递完整新人格 - ) - else: - # 使用现有的人格更新记录方法 - await db_manager.save_persona_update_record({ - 'timestamp': time.time(), - 'group_id': group_id, - 'update_type': 'style_relearning', - 'original_content': '原有风格特征', - 'new_content': full_style_content, - 'reason': f'重新学习-基于{len(conversation_pairs)}个对话对的关系分析', - 'status': 'pending' - }) - - results['persona_update_submitted'] = True - results['style_learning'] = True - logger.info(f"群组 {group_id} 风格学习审查申请已提交") - - except Exception as e: - logger.error(f"提交风格学习审查失败: {e}", exc_info=True) - results['errors'].append(f"提交审查失败: {str(e)}") - - logger.info(f"群组 {group_id} 风格重新学习完成,分析了 {len(conversation_pairs)} 个对话对") - - else: - # 没有对话对时,使用所有过滤后的消息进行基础风格学习 - logger.warning(f"群组 {group_id} 未找到对话对,将基于所有消息进行基础风格学习(消息数: {len(formatted_messages)})") - - if len(formatted_messages) >= 5: # 至少需要5条消息才能进行学习 - # 步骤3: 进行基础风格分析学习 - 基于所有过滤后的消息 - component_factory = factory_manager.get_component_factory() - expression_learner = component_factory.create_expression_pattern_learner() - - # 将过滤后的消息转换为MessageData格式 - from .core.interfaces import MessageData - import time - - message_data_list = [] - for msg in formatted_messages: - message_data = MessageData( - sender_id=msg['sender_id'], - sender_name=msg['sender_name'], - message=msg['message'], - group_id=msg['group_id'], - timestamp=msg['timestamp'], - platform=msg['platform'], - message_id=msg['id'], - reply_to=None - ) - message_data_list.append(message_data) - - logger.info(f"开始为群组 {group_id} 进行基础表达模式学习,消息数: {len(message_data_list)}") - - # 触发表达模式学习 - if message_data_list: - try: - # 启动表达模式学习器 - if hasattr(expression_learner, '_status') and expression_learner._status.value != 'running': - await expression_learner.start() - - # 强制重新学习 - if hasattr(expression_learner, 'last_learning_times'): - 
expression_learner.last_learning_times[group_id] = 0 - - # 触发学习 - learning_success = await expression_learner.trigger_learning_for_group(group_id, message_data_list) - logger.info(f"群组 {group_id} 基础表达模式学习结果: {learning_success}") - - results['style_learning'] = True - results['messages_analyzed'] = len(message_data_list) - logger.info(f"群组 {group_id} 基础风格学习完成,分析了 {len(message_data_list)} 条消息") - - except Exception as learning_error: - logger.error(f"基础表达模式学习失败: {learning_error}", exc_info=True) - results['errors'].append(f"基础学习失败: {str(learning_error)}") - else: - error_msg = f"群组 {group_id} 消息数不足({len(formatted_messages)}条),需要至少5条消息才能学习" - results['errors'].append(error_msg) - logger.warning(error_msg) - else: - # 当没有找到原始消息时,提供更详细的调试信息 - total_stats = await db_manager.get_messages_statistics() - group_stats = await db_manager.get_group_messages_statistics(group_id) - - # 通过 ORM 查询所有群组的原始消息统计 - from sqlalchemy import select, func, and_ - from .models.orm import RawMessage as RawMessageModel - - async with db_manager.get_session() as session: - stmt = select( - RawMessageModel.group_id, - func.count().label('raw_count') - ).where( - and_( - RawMessageModel.group_id.isnot(None), - RawMessageModel.group_id != '' - ) - ).group_by( - RawMessageModel.group_id - ).order_by( - func.count().desc() - ) - result = await session.execute(stmt) - raw_results = result.all() - - error_msg = f"群组 {group_id} 没有找到原始消息,跳过风格学习。\n" \ - f"全局统计: {total_stats}\n" \ - f"当前群组统计: {group_stats}\n" \ - f"所有群组原始消息: {[(r[0], r[1]) for r in raw_results] if raw_results else '无'}" - results['errors'].append(error_msg) - logger.warning(error_msg) - - except Exception as e: - error_msg = f"风格重新学习失败: {str(e)}" - results['errors'].append(error_msg) - logger.error(error_msg, exc_info=True) - - # 3. 
构建结果消息 - success_parts = [] - if results['progressive_learning']: - success_parts.append(f"渐进式学习已完成(处理{results['processed_messages']}条消息)") - if results['style_learning']: - success_parts.append(f"风格重新学习已完成(学到{results['new_patterns']}个新模式)") - if results['persona_update_submitted']: - success_parts.append("人格更新申请已提交,等待审查") - - if success_parts: - results['message'] = "重新学习完成:" + ",".join(success_parts) - - if results['errors']: - results['message'] += f"。注意:{len(results['errors'])}个警告" - else: - results['success'] = False - results['message'] = "重新学习失败:" + ";".join(results['errors']) if results['errors'] else "未知错误" - - except Exception as e: - results['success'] = False - results['message'] = f"重新学习过程中发生严重错误: {str(e)}" - logger.error(f"重新学习失败: {e}", exc_info=True) - - return jsonify(results) - - except Exception as e: - logger.error(f"重新学习API失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": f"重新学习请求失败: {str(e)}", - "progressive_learning": False, - "style_learning": False, - "processed_messages": 0, - "new_patterns": 0, - "persona_update_submitted": False, - "total_messages": 0 - }), 500 - -async def _generate_persona_update_from_patterns(patterns, group_id: str) -> str: - """基于风格模式生成人格更新内容""" - try: - if not patterns: - return "" - - # 构建风格学习文本 - style_lines = ["*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:"] - - # 提取主要风格特征 - for i, pattern in enumerate(patterns[:4], 1): # 取前4个最重要的模式 - situation = getattr(pattern, 'situation', '通用情境') - expression = getattr(pattern, 'expression', '自然表达') - weight = getattr(pattern, 'weight', 0.5) - - # 生成具体的风格建议 - if weight > 0.7: - style_lines.append(f"{i}. 在{situation}时,要{expression},保持这种高置信度的表达风格") - elif weight > 0.5: - style_lines.append(f"{i}. 当遇到{situation}的情况,适当使用{expression}的方式回应") - else: - style_lines.append(f"{i}. 
参考{situation}场景下的{expression}表达方式,灵活运用") - - # 构建Few Shots对话示例 - few_shots_lines = [ - "", - "*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:" - ] - - # 基于模式生成示例对话 - for i, pattern in enumerate(patterns[:3], 1): # 前3个模式作为对话示例 - situation = getattr(pattern, 'situation', '询问问题') - expression = getattr(pattern, 'expression', '好的,我来帮你') - - # 生成符合模式的示例对话 - few_shots_lines.append(f"A:{situation}") - few_shots_lines.append(f"B:{expression}") - - # 合并所有内容 - full_content = "\n".join(style_lines + few_shots_lines) - - logger.info(f"为群组 {group_id} 生成了基于 {len(patterns)} 个模式的人格更新内容") - return full_content - - except Exception as e: - logger.error(f"生成人格更新内容失败: {e}") - return "" - -# ========== 社交关系分析API ========== - -@api_bp.route("/social_relations/", methods=["GET"]) -@require_auth -async def get_social_relations(group_id: str): - """获取指定群组的社交关系分析数据""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - - # 获取数据库管理器 - db_manager = service_factory.create_database_manager() - - # 从数据库加载已保存的社交关系 - logger.info(f"从数据库加载群组 {group_id} 的社交关系...") - saved_relations = await db_manager.get_social_relations_by_group(group_id) - logger.info(f"从数据库加载到 {len(saved_relations)} 条社交关系记录") - - # 构建用户列表和统计消息数 - 使用 ORM 方法获取用户统计 - user_message_counts = {} - user_names = {} - - # ✅ 使用 ORM 方法统计每个用户的总消息数量(支持跨线程调用) - user_stats = await db_manager.get_group_user_statistics(group_id) - - for sender_id, stats in user_stats.items(): - user_key = f"{group_id}:{sender_id}" - user_message_counts[user_key] = stats['message_count'] - user_names[user_key] = stats['sender_name'] - # 同时存储纯ID格式的映射,以兼容数据库中的社交关系数据 - user_names[sender_id] = stats['sender_name'] - - logger.info(f"群组 {group_id} 从数据库统计到 {len(user_message_counts)} 个用户") - - # 初始化 raw_messages 变量 - raw_messages = [] - - # 如果没有统计到用户,尝试从最近消息获取 - if not user_message_counts: - raw_messages = await 
db_manager.get_recent_raw_messages(group_id, limit=200) - if not raw_messages: - return jsonify({ - "success": False, - "error": f"群组 {group_id} 没有消息记录", - "relations": [], - "members": [] - }) - - for msg in raw_messages: - sender_id = msg.get('sender_id', '') - sender_name = msg.get('sender_name', '') - if sender_id and sender_id != 'bot': - user_key = f"{group_id}:{sender_id}" - if user_key not in user_message_counts: - user_message_counts[user_key] = 0 - user_names[user_key] = sender_name - user_names[sender_id] = sender_name - user_message_counts[user_key] += 1 - - # 构建成员列表 - group_nodes = [] - for user_key, message_count in user_message_counts.items(): - user_id = user_key.split(':')[-1] if ':' in user_key else user_key - group_nodes.append({ - 'user_id': user_id, - 'nickname': user_names.get(user_key, user_id), - 'message_count': message_count, - 'nicknames': [user_names.get(user_key, user_id)], - 'id': user_key - }) - - # 构建关系列表 - group_edges = [] - for relation in saved_relations: - from_key = relation['from_user'] - to_key = relation['to_user'] - - # 提取用户ID(from_key格式可能是 "group_id:user_id") - from_id = from_key.split(':')[-1] if ':' in from_key else from_key - to_id = to_key.split(':')[-1] if ':' in to_key else to_key - - # 获取用户名 - 现在user_names字典同时包含两种格式的key - from_name = user_names.get(from_key, user_names.get(from_id, from_id)) - to_name = user_names.get(to_key, user_names.get(to_id, to_id)) - - logger.debug(f"社交关系映射: {from_key} ({from_id}) -> {to_key} ({to_id}), " - f"名称: {from_name} -> {to_name}") - - # 关系类型映射 - relation_type_map = { - 'mention': '提及(@)', - 'reply': '回复', - 'conversation': '对话', - 'frequent_interaction': '频繁互动', - 'topic_discussion': '话题讨论' - } - relation_type_text = relation_type_map.get(relation.get('relation_type', 'interaction'), '互动') - - group_edges.append({ - 'source': from_id, - 'target': to_id, - 'source_name': from_name, - 'target_name': to_name, - 'strength': relation.get('strength', 0.5), - 'type': 
relation.get('relation_type', 'interaction'), - 'type_text': relation_type_text, - 'frequency': relation.get('frequency', 1), - 'last_interaction': relation.get('last_interaction', '') - }) - - logger.info(f"群组 {group_id} 构建了 {len(group_edges)} 条社交关系") - - # 计算总消息数:优先使用数据库统计,否则使用raw_messages长度 - total_message_count = sum(user_message_counts.values()) if user_message_counts else len(raw_messages) - - return jsonify({ - "success": True, - "group_id": group_id, - "members": group_nodes, - "relations": group_edges, - "message_count": total_message_count, - "member_count": len(group_nodes), - "relation_count": len(group_edges) - }) - - except Exception as e: - logger.error(f"获取社交关系失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e), - "relations": [], - "members": [] - }), 500 - -@api_bp.route("/social_relations/groups", methods=["GET"]) -@require_auth -async def get_available_groups_for_social_analysis(): - """获取可用于社交关系分析的群组列表(使用 ORM 版本)""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - db_manager = service_factory.create_database_manager() - - # ✅ 使用 ORM 方法获取群组统计(支持跨线程调用) - groups_data = await db_manager.get_groups_for_social_analysis() - - groups = [] - for group_data in groups_data: - try: - group_id = group_data['group_id'] - message_count = group_data['message_count'] - member_count = group_data['member_count'] - relation_count = group_data['relation_count'] - - groups.append({ - 'group_id': group_id, - 'message_count': message_count, - 'member_count': member_count, # 修复:使用正确的字段名 - 'user_count': member_count, # 保留旧字段以兼容 - 'relation_count': relation_count # 新增:关系数 - }) - except Exception as row_error: - logger.warning(f"处理群组数据行时出错,跳过: {row_error}, data: {group_data}") - continue - - return jsonify({ - "success": True, - "groups": groups - }) - - except Exception as e: - logger.error(f"获取群组列表失败: {e}", exc_info=True) - return jsonify({ - 
"success": False, - "error": str(e), - "groups": [] - }), 500 - - -@api_bp.route("/social_relations//analyze", methods=["POST"]) -@require_auth -async def trigger_social_relation_analysis(group_id: str): - """触发群组社交关系分析""" - try: - from .core.factory import FactoryManager - from .services.social_relation_analyzer import SocialRelationAnalyzer - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - db_manager = service_factory.create_database_manager() - - # 获取LLM适配器 - global llm_adapter_instance - if not llm_adapter_instance: - return jsonify({ - "success": False, - "error": "LLM适配器未初始化" - }), 500 - - # 创建社交关系分析器 - analyzer = SocialRelationAnalyzer( - config=current_app.plugin_config, - llm_adapter=llm_adapter_instance, - db_manager=db_manager - ) - - # 获取参数 - data = await request.get_json() if request.is_json else {} - message_limit = data.get('message_limit', 200) - force_refresh = data.get('force_refresh', False) - - logger.info(f"开始分析群组 {group_id} 的社交关系 (消息数: {message_limit}, 强制刷新: {force_refresh})") - - # 执行分析 - relations = await analyzer.analyze_group_social_relations( - group_id=group_id, - message_limit=message_limit, - force_refresh=force_refresh - ) - - return jsonify({ - "success": True, - "message": f"成功分析 {len(relations)} 条社交关系", - "relation_count": len(relations) - }) - - except Exception as e: - logger.error(f"触发社交关系分析失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/social_relations//clear", methods=["DELETE"]) -@require_auth -async def clear_group_social_relations(group_id: str): - """清空群组社交关系数据""" - try: - from .core.factory import FactoryManager - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - db_manager = service_factory.create_database_manager() - - logger.info(f"开始清空群组 {group_id} 的社交关系数据") - - # 统计要删除的记录数 - deleted_count = 0 - - # 使用 ORM 查询和删除(支持跨线程 event loop) - from sqlalchemy import 
select, func, delete - from .models.orm import UserSocialRelationComponent - - async with db_manager.get_session() as session: - # 先统计数量 - count_stmt = select(func.count()).select_from(UserSocialRelationComponent).where( - UserSocialRelationComponent.group_id == group_id - ) - count_result = await session.execute(count_stmt) - deleted_count = count_result.scalar() or 0 - - # 执行删除 - delete_stmt = delete(UserSocialRelationComponent).where( - UserSocialRelationComponent.group_id == group_id - ) - await session.execute(delete_stmt) - await session.commit() - - logger.info(f"成功清空群组 {group_id} 的 {deleted_count} 条社交关系数据") - - return jsonify({ - "success": True, - "message": f"成功清空 {deleted_count} 条社交关系数据", - "deleted_count": deleted_count - }) - - except Exception as e: - logger.error(f"清空社交关系数据失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/social_relations//user/", methods=["GET"]) -@require_auth -async def get_user_social_relations(group_id: str, user_id: str): - """获取指定用户的社交关系""" - try: - from .core.factory import FactoryManager - from .services.social_relation_analyzer import SocialRelationAnalyzer - - factory_manager = FactoryManager() - service_factory = factory_manager.get_service_factory() - db_manager = service_factory.create_database_manager() - - # 获取LLM适配器 - global llm_adapter_instance - if not llm_adapter_instance: - return jsonify({ - "success": False, - "error": "LLM适配器未初始化" - }), 500 - - # 创建社交关系分析器 - analyzer = SocialRelationAnalyzer( - config=current_app.plugin_config, - llm_adapter=llm_adapter_instance, - db_manager=db_manager - ) - - # 获取用户关系 - user_relations = await analyzer.get_user_relations(group_id, user_id) - - return jsonify({ - "success": True, - **user_relations - }) - - except Exception as e: - logger.error(f"获取用户社交关系失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -# ========== 外部API接口 (供其他程序调用) ========== - -def 
require_api_key(f): - """API密钥认证装饰器""" - @wraps(f) - async def decorated_function(*args, **kwargs): - # 获取配置 - config = getattr(current_app, 'plugin_config', None) - - # 如果未启用API认证,直接通过 - if not config or not config.enable_api_auth: - return await f(*args, **kwargs) - - # 检查API密钥 - api_key = request.headers.get('X-API-Key') or request.args.get('api_key') - - if not api_key: - return jsonify({ - "success": False, - "error": "缺少API密钥。请在请求头中添加 X-API-Key 或在查询参数中添加 api_key" - }), 401 - - if api_key != config.api_key: - return jsonify({ - "success": False, - "error": "API密钥无效" - }), 403 - - return await f(*args, **kwargs) - return decorated_function - - -@api_bp.route("/external/current_topic", methods=["GET"]) -@require_api_key -async def get_current_topic_api(): - """ - 获取指定群组当前的聊天话题 - - 查询参数: - group_id: 群组ID (必需) - recent_count: 分析的最近消息数量 (可选,默认20) - - 返回: - JSON格式的话题信息 - """ - try: - group_id = request.args.get('group_id') - if not group_id: - return jsonify({ - "success": False, - "error": "缺少必需参数: group_id" - }), 400 - - recent_count = request.args.get('recent_count', 20, type=int) - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 获取话题总结 - topic_data = await database_manager.get_current_topic_summary(group_id, recent_count) - - return jsonify({ - "success": True, - **topic_data - }) - - except Exception as e: - logger.error(f"获取当前话题失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/external/chat_history", methods=["GET"]) -@require_api_key -async def get_chat_history_api(): - """ - 获取指定群组的聊天记录(支持时间段筛选) - - 查询参数: - group_id: 群组ID (必需) - start_time: 开始时间戳(秒) (可选) - end_time: 结束时间戳(秒) (可选) - limit: 返回消息数量限制 (可选,默认100) - - 返回: - JSON格式的聊天记录列表 - """ - try: - group_id = request.args.get('group_id') - if not group_id: - return jsonify({ - "success": False, - "error": "缺少必需参数: group_id" - }), 400 - - start_time = request.args.get('start_time', 
type=float) - end_time = request.args.get('end_time', type=float) - limit = request.args.get('limit', 100, type=int) - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 获取聊天记录 - messages = await database_manager.get_messages_by_group_and_timerange( - group_id=group_id, - start_time=start_time, - end_time=end_time, - limit=limit - ) - - return jsonify({ - "success": True, - "group_id": group_id, - "message_count": len(messages), - "messages": messages, - "filter": { - "start_time": start_time, - "end_time": end_time, - "limit": limit - } - }) - - except Exception as e: - logger.error(f"获取聊天记录失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/external/new_messages", methods=["GET"]) -@require_api_key -async def get_new_messages_api(): - """ - 获取增量消息更新(只返回之前未获取过的新消息) - - 查询参数: - group_id: 群组ID (必需) - last_message_id: 上次获取的最后一条消息ID (可选,优先使用) - last_timestamp: 上次获取的最后一条消息时间戳 (可选) - - 注意: last_message_id 和 last_timestamp 至少需要提供一个,优先使用 last_message_id - - 返回: - JSON格式的新消息列表 - """ - try: - group_id = request.args.get('group_id') - if not group_id: - return jsonify({ - "success": False, - "error": "缺少必需参数: group_id" - }), 400 - - last_message_id = request.args.get('last_message_id', type=int) - last_timestamp = request.args.get('last_timestamp', type=float) - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 获取新消息 - new_messages = await database_manager.get_new_messages_since( - group_id=group_id, - last_message_id=last_message_id, - last_timestamp=last_timestamp - ) - - # 提取新消息的最大ID和最新时间戳,供下次调用使用 - max_id = None - latest_timestamp = None - if new_messages: - max_id = max(msg['id'] for msg in new_messages) - latest_timestamp = max(msg['timestamp'] for msg in new_messages) - - return jsonify({ - "success": True, - "group_id": group_id, - "new_message_count": len(new_messages), - "messages": new_messages, 
- "next_query": { - "last_message_id": max_id, - "last_timestamp": latest_timestamp - } if new_messages else None - }) - - except Exception as e: - logger.error(f"获取增量消息失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -# ========== 黑话学习系统API ========== - -@api_bp.route("/jargon/stats", methods=["GET"]) -@login_required -async def get_jargon_stats(): - """ - 获取黑话学习统计信息 - - 查询参数: - group_id: 群组ID (可选,不传则返回全局统计) - - 返回: - JSON格式的统计信息 - """ - try: - group_id = request.args.get('group_id') - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - stats = await database_manager.get_jargon_statistics(group_id) - - return jsonify({ - "success": True, - "data": stats, - "group_id": group_id - }) - - except Exception as e: - logger.error(f"获取黑话统计失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon/list", methods=["GET"]) -@login_required -async def get_jargon_list(): - """ - 获取黑话学习列表 - - 查询参数: - group_id: 群组ID (可选,不传则返回所有) - limit: 返回数量限制 (默认50) - only_confirmed: 是否只返回已确认的黑话 (默认true) - page: 页码 (默认1) - - 返回: - JSON格式的黑话列表 - """ - try: - group_id = request.args.get('group_id') - limit = request.args.get('limit', 50, type=int) - only_confirmed_str = request.args.get('only_confirmed', 'true') - only_confirmed = only_confirmed_str.lower() in ('true', '1', 'yes') - page = request.args.get('page', 1, type=int) - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 获取黑话列表 - jargon_list = await database_manager.get_recent_jargon_list( - chat_id=group_id, - limit=limit, - only_confirmed=only_confirmed - ) - - return jsonify({ - "success": True, - "data": jargon_list, - "total": len(jargon_list), - "group_id": group_id, - "page": page, - "limit": limit - }) - - except Exception as e: - logger.error(f"获取黑话列表失败: {e}", exc_info=True) - return jsonify({ - "success": False, - 
"error": str(e) - }), 500 - - -@api_bp.route("/jargon/search", methods=["GET"]) -@login_required -async def search_jargon(): - """ - 搜索黑话 - - 查询参数: - keyword: 搜索关键词 (必需) - group_id: 群组ID (可选,不传则搜索全局黑话) - limit: 返回数量限制 (默认10) - - 返回: - JSON格式的搜索结果 - """ - try: - keyword = request.args.get('keyword') - if not keyword: - return jsonify({ - "success": False, - "error": "缺少必需参数: keyword" - }), 400 - - group_id = request.args.get('group_id') - limit = request.args.get('limit', 10, type=int) - - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - results = await database_manager.search_jargon( - keyword=keyword, - chat_id=group_id, - limit=limit - ) - - return jsonify({ - "success": True, - "data": results, - "keyword": keyword, - "group_id": group_id, - "count": len(results) - }) - - except Exception as e: - logger.error(f"搜索黑话失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon/", methods=["DELETE"]) -@login_required -async def delete_jargon(jargon_id: int): - """ - 删除指定黑话记录 - - 路径参数: - jargon_id: 黑话记录ID - - 返回: - JSON格式的删除结果 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 执行删除 - success = await database_manager.delete_jargon_by_id(jargon_id) - - if success: - return jsonify({ - "success": True, - "message": f"黑话记录 {jargon_id} 已删除" - }) - else: - return jsonify({ - "success": False, - "error": f"未找到黑话记录 {jargon_id}" - }), 404 - - except Exception as e: - logger.error(f"删除黑话失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon//toggle_global", methods=["POST"]) -@login_required -async def toggle_jargon_global(jargon_id: int): - """ - 切换黑话的全局状态 - - 路径参数: - jargon_id: 黑话记录ID - - 返回: - JSON格式的操作结果 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # 使用 ORM 
查询和更新(支持跨线程 event loop) - from sqlalchemy import select - from .models.orm import Jargon as JargonModel - import time as _time - - async with database_manager.get_session() as session: - stmt = select(JargonModel).where(JargonModel.id == jargon_id) - result = await session.execute(stmt) - jargon_record = result.scalar_one_or_none() - - if not jargon_record: - return jsonify({ - "success": False, - "error": f"未找到黑话记录 {jargon_id}" - }), 404 - - # 切换状态 - new_status = not bool(jargon_record.is_global) - jargon_record.is_global = new_status - jargon_record.updated_at = int(_time.time()) - await session.commit() - - return jsonify({ - "success": True, - "jargon_id": jargon_id, - "is_global": new_status, - "message": f"黑话记录 {jargon_id} 已{'设为全局' if new_status else '取消全局'}" - }) - - except Exception as e: - logger.error(f"切换黑话全局状态失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon/groups", methods=["GET"]) -@login_required -async def get_jargon_groups(): - """ - 获取所有有黑话记录的群组列表(使用 ORM 版本) - - 返回: - JSON格式的群组列表,每个群组包含黑话统计 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - # ✅ 使用 ORM 方法获取黑话群组列表(支持跨线程调用) - groups_data = await database_manager.get_jargon_groups() - - groups = [] - for group_data in groups_data: - try: - groups.append({ - 'group_id': group_data['group_id'], - 'total_candidates': group_data['total_jargon'], # 总黑话数 - 'confirmed_jargon': group_data['complete_jargon'], # 已完成黑话数 - 'global_jargon': group_data['global_jargon'], # 全局黑话数 - 'last_updated': None # ORM版本暂不提供 last_updated,可后续添加 - }) - except Exception as row_error: - logger.warning(f"处理黑话群组数据行时出错,跳过: {row_error}, data: {group_data}") - continue - - return jsonify({ - "success": True, - "data": groups, - "total_groups": len(groups) - }) - - except Exception as e: - logger.error(f"获取黑话群组列表失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 
- - -@api_bp.route("/jargon/global", methods=["GET"]) -@login_required -async def get_global_jargon_list(): - """ - 获取全局共享的黑话列表 - - 参数: - limit: 返回数量限制 (默认50) - - 返回: - JSON格式的全局黑话列表 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - limit = request.args.get('limit', 50, type=int) - jargon_list = await database_manager.get_global_jargon_list(limit=limit) - - return jsonify({ - "success": True, - "data": jargon_list, - "total": len(jargon_list) - }) - - except Exception as e: - logger.error(f"获取全局黑话列表失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon//set_global", methods=["POST"]) -@login_required -async def set_jargon_global_status(jargon_id: int): - """ - 设置黑话的全局共享状态 - - 参数: - jargon_id: 黑话记录ID - is_global: 是否全局共享 (JSON body) - - 返回: - 操作结果 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - data = await request.get_json() - is_global = data.get('is_global', True) - - result = await database_manager.set_jargon_global(jargon_id, is_global) - - if result: - return jsonify({ - "success": True, - "message": f"黑话已{'设为全局共享' if is_global else '取消全局共享'}" - }) - else: - return jsonify({ - "success": False, - "error": "更新失败,黑话可能不存在" - }), 404 - - except Exception as e: - logger.error(f"设置黑话全局状态失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon/batch_set_global", methods=["POST"]) -@login_required -async def batch_set_jargon_global(): - """ - 批量设置黑话的全局共享状态 - - 参数 (JSON body): - jargon_ids: 黑话ID列表 - is_global: 是否全局共享 - - 返回: - 操作结果统计 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - data = await request.get_json() - jargon_ids = data.get('jargon_ids', []) - is_global = data.get('is_global', True) - - if not jargon_ids: - return jsonify({ - 
"success": False, - "error": "未提供黑话ID列表" - }), 400 - - result = await database_manager.batch_set_jargon_global(jargon_ids, is_global) - - return jsonify({ - "success": result.get('success', False), - "data": result, - "message": f"批量{'设为全局' if is_global else '取消全局'}: 成功 {result.get('success_count', 0)} 条" - }) - - except Exception as e: - logger.error(f"批量设置黑话全局状态失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -@api_bp.route("/jargon/sync_to_group", methods=["POST"]) -@login_required -async def sync_global_jargon_to_group(): - """ - 将全局黑话同步到指定群组 - - 参数 (JSON body): - target_group_id: 目标群组ID - - 返回: - 同步结果统计 - """ - try: - if not database_manager: - return jsonify({ - "success": False, - "error": "数据库管理器未初始化" - }), 500 - - data = await request.get_json() - target_group_id = data.get('target_group_id') - - if not target_group_id: - return jsonify({ - "success": False, - "error": "未提供目标群组ID" - }), 400 - - result = await database_manager.sync_global_jargon_to_group(target_group_id) - - return jsonify({ - "success": result.get('success', False), - "data": result, - "message": f"同步完成: 新增 {result.get('synced_count', 0)} 条, 跳过 {result.get('skipped_count', 0)} 条" - }) - - except Exception as e: - logger.error(f"同步全局黑话失败: {e}", exc_info=True) - return jsonify({ - "success": False, - "error": str(e) - }), 500 - - -app.register_blueprint(api_bp) - -# 添加根路由重定向 -@app.route("/") -async def root(): - """根路由重定向到API根路径""" - return redirect("/api/") - -# ========== Quart 服务器管理类 ========== -# 自定义 Config 类,用于劫持 Socket 创建过程 -# 全局锚点 -GLOBAL_SERVER_KEY = "_astrbot_self_learning_server_v5_fix" - -# [修改1] 自定义 Config 类 -class SecureConfig(HypercornConfig): - def create_sockets(self): - insecure_sockets = [] - secure_sockets = [] - quic_sockets = [] - - for bind in self.bind: - if ":" in bind: - host, port = bind.rsplit(":", 1) - port = int(port) - else: - host = bind - port = 80 - - try: - sock = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if sys.platform != 'win32' and hasattr(socket, 'SO_REUSEPORT'): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - - # [核心] 禁止继承 - sock.set_inheritable(False) - - sock.bind((host, port)) - sock.listen(100) - - logger.info(f"🔒 安全Socket创建成功: {host}:{port}") - insecure_sockets.append(sock) - - except Exception as e: - logger.error(f"Socket 创建失败 {bind}: {e}") - try: sock.close() - except: pass - raise e - - # [修复] 返回对象而非列表 - return Sockets(secure_sockets, insecure_sockets, quic_sockets) - -class Server: - """Quart 服务器管理类 (最终修正版)""" - _instance = None - - def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = super(Server, cls).__new__(cls) - return cls._instance - - def __init__(self, host: str = "0.0.0.0", port: int = 7833, auto_find_port: bool = False): - if hasattr(self, '_initialized') and self._initialized: - return - - self._initialized = True - try: - logger.info(f"🔧 初始化Web服务器 (固定端口: {port})...") - self.host = host - self.port = port - - self.server_thread: Optional[threading.Thread] = None - self._thread_loop = None - self._shutdown_event = None - - bind_host = self.host - #if sys.platform == 'win32' and self.host == '0.0.0.0': - # bind_host = '127.0.0.1' - - # [修改2] 使用 SecureConfig - self.config = SecureConfig() - self.config.bind = [f"{bind_host}:{self.port}"] - self.config.accesslog = None - self.config.errorlog = None - self.config.loglevel = "WARNING" - self.config.workers = 1 - self.config.worker_class = "asyncio" - - except Exception as e: - logger.error(f"❌ Web服务器初始化失败: {e}") - - async def _kill_port_holder(self, port: int): - import sys - import os - try: - if sys.platform == 'win32': - cmd_find = f'netstat -ano | findstr :{port}' - process = await asyncio.create_subprocess_shell( - cmd_find, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - stdout, _ = await process.communicate() - if stdout: - lines = 
stdout.decode('gbk', errors='ignore').strip().split('\n') - for line in lines: - parts = line.strip().split() - if len(parts) > 4 and 'LISTENING' in line: - pid = parts[-1] - if pid and pid != str(os.getpid()): - logger.warning(f"🔫 清理占用进程 PID={pid}") - await asyncio.create_subprocess_shell( - f'taskkill /F /PID {pid}', - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL - ) - await asyncio.sleep(1.0) - except: pass - - def _run_thread(self): - import asyncio - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._thread_loop = loop - self._shutdown_event = asyncio.Event() - - # Hypercorn 会调用 SecureConfig.create_sockets - loop.run_until_complete( - hypercorn.asyncio.serve( - app, - self.config, - shutdown_trigger=self._shutdown_event.wait - ) - ) - loop.close() - logger.info("WebUI 线程已退出") - except Exception as e: - logger.error(f"WebUI 线程异常: {e}") - - async def start(self): - """启动服务器""" - if self.server_thread and self.server_thread.is_alive(): - return - - # 1. 暴力清理 - if not self._is_port_available(self.port): - await self._kill_port_holder(self.port) - - # 2. 启动线程 - try: - self.server_thread = threading.Thread( - target=self._run_thread, - daemon=True, - name="SelfLearning_WebUI" - ) - self.server_thread.start() - - # 3. 
验证 - for _ in range(5): - await asyncio.sleep(1.0) - if await self._verify_tcp(): - logger.info(f"✅ Web服务器启动成功") - logger.info(f"🔗 本地访问: http://127.0.0.1:{self.port}") - return - - logger.warning("⚠️ WebUI 线程已启动但端口无响应") - - except Exception as e: - logger.error(f"❌ 启动失败: {e}") - raise e - - async def stop(self): - """停止服务器""" - if self._thread_loop and self._shutdown_event: - try: - self._thread_loop.call_soon_threadsafe(self._shutdown_event.set) - except: pass - - if self.server_thread: - await asyncio.sleep(1.0) - self.server_thread = None - - import gc - gc.collect() - - async def _verify_tcp(self): - import socket - loop = asyncio.get_event_loop() - def check(): - try: - check_host = "127.0.0.1" if self.host == "0.0.0.0" else self.host - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(1) - return s.connect_ex((check_host, self.port)) == 0 - except: return False - return await loop.run_in_executor(None, check) - - def _is_port_available(self, port): - import socket - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.2) - s.bind(("127.0.0.1", port)) - return True - except: return False - - def _find_available_port(self, p, auto_find_port=False): return p