diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index b73a361..f672672 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -2,7 +2,8 @@ name: Docker Image CI on: push: - branches: [ "main" ] + branches: [ "main", "dev" ] + workflow_dispatch: # 允许手动触发工作流 jobs: @@ -32,7 +33,11 @@ jobs: - name: Determine Image Tags id: tags run: | - echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:latest,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:main-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT + if [ "${{ github.ref_name }}" == "main" ]; then + echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:latest,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:main-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT + elif [ "${{ github.ref_name }}" == "dev" ]; then + echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT + fi - name: Build and Push Docker Image uses: docker/build-push-action@v5 @@ -42,8 +47,8 @@ jobs: platforms: linux/amd64,linux/arm64 tags: ${{ steps.tags.outputs.tags }} push: true - cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache - cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache,mode=max + cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }} + cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }},mode=max labels: | org.opencontainers.image.created=${{ steps.tags.outputs.date_tag }} org.opencontainers.image.revision=${{ github.sha }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6d6652c..f6eb732 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,7 @@ elua.confirmed # C extensions *.so /results - +config_backup/ # Distribution / packaging .Python build/ @@ -39,6 +39,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +dev/ # PyInstaller # Usually these files are written by a python script from a template @@ -64,6 +65,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +dev/ # Translations *.mo @@ -148,6 +150,7 @@ venv/ ENV/ env.bak/ venv.bak/ +uv.lock # Spyder project settings .spyderproject @@ -270,4 +273,10 @@ $RECYCLE.BIN/ *.lnk config.toml -test \ No newline at end of file +config.toml.back +test +data/NapcatAdapter.db +data/NapcatAdapter.db-shm +data/NapcatAdapter.db-wal + +uv.lock \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 7dee666..d50a5f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.13.2-slim +FROM python:3.13.5-slim LABEL authors="infinitycat233" # Copy uv and maim_message diff --git a/README.md b/README.md index 4615f49..266ebac 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ sequenceDiagram - [x] 读取戳一戳的自定义内容 - [ ] 语音解析(?) - [ ] 所有的notice类 - - [ ] 撤回 + - [x] 撤回(已添加相关指令) - [x] 发送消息 - [x] 发送文本 - [x] 发送图片 @@ -78,4 +78,6 @@ sequenceDiagram - [x] 群踢人功能 # 特别鸣谢 - 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 \ No newline at end of file + 特别感谢[@Maple127667](https://github.com/Maple127667)对本项目代码思路的支持 + + 以及[@墨梓柒](https://github.com/DrSmoothl)对部分代码想法的支持 \ No newline at end of file diff --git a/command_args.md b/command_args.md index 8a73651..3f21a5b 100644 --- a/command_args.md +++ b/command_args.md @@ -1,11 +1,31 @@ # Command Arguments + ```python Seg.type = "command" ``` -## 群聊禁言 + +所有命令执行后都会通过自定义消息类型 `command_response` 返回响应,格式如下: + +```python +{ + "command_name": "命令名称", + "success": True/False, # 是否执行成功 + "timestamp": 1234567890.123, # 时间戳 + "data": {...}, # 返回数据(成功时) + "error": "错误信息" # 错误信息(失败时) +} +``` + +插件需要注册 `command_response` 自定义消息处理器来接收命令响应。 + +--- + +## 操作类命令 + +### 群聊禁言 ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_BAN" + "name": "GROUP_BAN", "args": { "qq_id": "用户QQ号", "duration": "禁言时长(秒)" @@ -13,10 +33,13 @@ Seg.data: Dict[str, Any] = { } ``` 其中,群聊ID将会通过Group_Info.group_id自动获取。 -## 群聊全体禁言 + +**当`duration`为 0 时相当于解除禁言。** + +### 群聊全体禁言 ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_WHOLE_BAN" + "name": "GROUP_WHOLE_BAN", "args": { "enable": "是否开启全体禁言(True/False)" }, @@ -25,13 +48,429 @@ Seg.data: Dict[str, Any] = { 其中,群聊ID将会通过Group_Info.group_id自动获取。 `enable`的参数需要为boolean类型,True表示开启全体禁言,False表示关闭全体禁言。 -## 群聊踢人 + +### 群聊踢人 +将指定成员从群聊中踢出,可选拉黑。 + ```python Seg.data: Dict[str, Any] = { - "name": "GROUP_KICK" + "name": "GROUP_KICK", "args": { - "qq_id": "用户QQ号", + "group_id": 123456789, # 可选,如果在群聊上下文中可从 group_info 自动获取 + "user_id": 12345678, # 必需,用户QQ号 + "reject_add_request": False # 可选,是否群拉黑,默认 False + }, +} +``` + +### 批量踢出群成员 +批量将多个成员从群聊中踢出,可选拉黑。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GROUP_KICK_MEMBERS", + "args": { + "group_id": 123456789, # 可选,如果在群聊上下文中可从 group_info 自动获取 + "user_id": [12345678, 87654321], # 必需,用户QQ号数组 + "reject_add_request": False # 可选,是否群拉黑,默认 False + }, +} +``` + +### 戳一戳 +```python +Seg.data: Dict[str, Any] = { + "name": "SEND_POKE", + "args": { + "qq_id": "目标QQ号" + } +} +``` + +### 撤回消息 +```python +Seg.data: Dict[str, Any] = { + "name": "DELETE_MSG", + "args": { + "message_id": "消息所对应的message_id" + } +} +``` +其中message_id是消息的实际qq_id,于新版的mmc中可以从数据库获取(如果工作正常的话) + +### 给消息贴表情 +```python +Seg.data: Dict[str, Any] = { + "name": "SET_MSG_EMOJI_LIKE", + "args": { + "message_id": "消息ID", + "emoji_id": "表情ID" + } +} +``` + +### 设置群名 +设置指定群的群名称。 + +```python +Seg.data: Dict[str, Any] = { + "name": "SET_GROUP_NAME", + "args": { + "group_id": 123456789, # 可选,如果在群聊上下文中可从 group_info 自动获取 + "group_name": "新群名" # 必需,新的群名称 + } +} +``` + +### 设置账号信息 +设置Bot自己的QQ账号资料。 + +```python +Seg.data: Dict[str, Any] = { + "name": "SET_QQ_PROFILE", + "args": { + "nickname": "新昵称", # 必需,昵称 + "personal_note": "个性签名", # 可选,个性签名 + "sex": "male" # 可选,性别:"male" | "female" | "unknown" + } +} +``` + +**返回数据示例:** +```python +{ + "result": 0, # 结果码,0为成功 + "errMsg": "" # 错误信息 +} +``` + +--- + +## 查询类命令 + +### 获取登录号信息 +获取Bot自身的账号信息。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_LOGIN_INFO", + "args": {} +} +``` + +**返回数据示例:** +```python +{ + "user_id": 12345678, + "nickname": "Bot昵称" +} +``` + +### 获取陌生人信息 +```python +Seg.data: Dict[str, Any] = { + "name": "GET_STRANGER_INFO", + "args": { + "user_id": "用户QQ号" + } +} +``` + +**返回数据示例:** +```python +{ + "user_id": 12345678, + "nickname": "用户昵称", + "sex": "male/female/unknown", + "age": 0 +} +``` + +### 获取好友列表 +获取Bot的好友列表。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_FRIEND_LIST", + "args": { + "no_cache": False # 可选,是否不使用缓存,默认 False + } +} +``` + +**返回数据示例:** +```python +[ + { + "user_id": 12345678, + "nickname": "好友昵称", + "remark": "备注名", + "sex": "male", # "male" | "female" | "unknown" + "age": 18, + "qid": "QID字符串", + "level": 64, + "login_days": 365, + "birthday_year": 2000, + "birthday_month": 1, + "birthday_day": 1, + "phone_num": "电话号码", + "email": "邮箱", + "category_id": 0, # 分组ID + "categoryName": "我的好友", # 分组名称 + "categoryId": 0 + }, + ... +] +``` + +### 获取群信息 +获取指定群的详细信息。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_INFO", + "args": { + "group_id": 123456789 # 可选,如果在群聊上下文中可从 group_info 自动获取 + } +} +``` + +**返回数据示例:** +```python +{ + "group_id": "123456789", # 群号(字符串) + "group_name": "群名称", + "group_remark": "群备注", + "group_all_shut": 0, # 群全员禁言状态(0=未禁言) + "member_count": 100, # 当前成员数量 + "max_member_count": 500 # 最大成员数量 +} +``` + +### 获取群详细信息 +获取指定群的详细信息(与 GET_GROUP_INFO 类似,可能提供更实时的数据)。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_DETAIL_INFO", + "args": { + "group_id": 123456789 # 可选,如果在群聊上下文中可从 group_info 自动获取 + } +} +``` + +**返回数据示例:** +```python +{ + "group_id": 123456789, # 群号(数字) + "group_name": "群名称", + "group_remark": "群备注", + "group_all_shut": 0, # 群全员禁言状态(0=未禁言) + "member_count": 100, # 当前成员数量 + "max_member_count": 500 # 最大成员数量 +} +``` + +### 获取群列表 +获取Bot加入的所有群列表。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_LIST", + "args": { + "no_cache": False # 可选,是否不使用缓存,默认 False + } +} +``` + +**返回数据示例:** +```python +[ + { + "group_id": "123456789", # 群号(字符串) + "group_name": "群名称", + "group_remark": "群备注", + "group_all_shut": 0, # 群全员禁言状态 + "member_count": 100, # 当前成员数量 + "max_member_count": 500 # 最大成员数量 + }, + ... +] +``` + +### 获取群@全体成员剩余次数 +查询指定群的@全体成员剩余使用次数。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_AT_ALL_REMAIN", + "args": { + "group_id": 123456789 # 可选,如果在群聊上下文中可从 group_info 自动获取 + } +} +``` + +**返回数据示例:** +```python +{ + "can_at_all": True, # 是否可以@全体成员 + "remain_at_all_count_for_group": 10, # 群剩余@全体成员次数 + "remain_at_all_count_for_uin": 5 # Bot剩余@全体成员次数 +} +``` + +### 获取群成员信息 +获取指定群成员的详细信息。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_MEMBER_INFO", + "args": { + "group_id": 123456789, # 可选,如果在群聊上下文中可从 group_info 自动获取 + "user_id": 12345678, # 必需,用户QQ号 + "no_cache": False # 可选,是否不使用缓存,默认 False + } +} +``` + +**返回数据示例:** +```python +{ + "group_id": 123456789, + "user_id": 12345678, + "nickname": "昵称", + "card": "群名片", + "sex": "male", # "male" | "female" | "unknown" + "age": 18, + "join_time": 1234567890, # 加群时间戳 + "last_sent_time": 1234567890, # 最后发言时间戳 + "level": 1, # 群等级 + "qq_level": 64, # QQ等级 + "role": "member", # "owner" | "admin" | "member" + "title": "专属头衔", + "area": "地区", + "unfriendly": False, # 是否不友好 + "title_expire_time": 1234567890, # 头衔过期时间 + "card_changeable": True, # 名片是否可修改 + "shut_up_timestamp": 0, # 禁言时间戳 + "is_robot": False, # 是否机器人 + "qage": "10年" # Q龄 +} +``` + +### 获取群成员列表 +获取指定群的所有成员列表。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_GROUP_MEMBER_LIST", + "args": { + "group_id": 123456789, # 可选,如果在群聊上下文中可从 group_info 自动获取 + "no_cache": False # 可选,是否不使用缓存,默认 False + } +} +``` + +**返回数据示例:** +```python +[ + { + "group_id": 123456789, + "user_id": 12345678, + "nickname": "昵称", + "card": "群名片", + "sex": "male", # "male" | "female" | "unknown" + "age": 18, + "join_time": 1234567890, + "last_sent_time": 1234567890, + "level": 1, + "qq_level": 64, + "role": "member", # "owner" | "admin" | "member" + "title": "专属头衔", + "area": "地区", + "unfriendly": False, + "title_expire_time": 1234567890, + "card_changeable": True, + "shut_up_timestamp": 0, + "is_robot": False, + "qage": "10年" + }, + ... +] +``` + +### 获取消息详情 +获取指定消息的完整详情信息。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_MSG", + "args": { + "message_id": 123456 # 必需,消息ID + } +} +``` + +**返回数据示例:** +```python +{ + "self_id": 12345678, # Bot自身ID + "user_id": 87654321, # 发送者ID + "time": 1234567890, # 时间戳 + "message_id": 123456, # 消息ID + "message_seq": 123456, # 消息序列号 + "real_id": 123456, # 真实消息ID + "real_seq": "123456", # 真实序列号(字符串) + "message_type": "group", # "private" | "group" + "sub_type": "normal", # 子类型 + "message_format": "array", # 消息格式 + "post_type": "message", # 事件类型 + "group_id": 123456789, # 群号(群消息时存在) + "sender": { + "user_id": 87654321, + "nickname": "昵称", + "sex": "male", # "male" | "female" | "unknown" + "age": 18, + "card": "群名片", # 群消息时存在 + "level": "1", # 群等级(字符串) + "role": "member" # "owner" | "admin" | "member" }, + "message": [...], # 消息段数组 + "raw_message": "消息文本内容", # 原始消息文本 + "font": 0 # 字体 } ``` -其中,群聊ID将会通过Group_Info.group_id自动获取。 \ No newline at end of file + +### 获取合并转发消息 +获取合并转发消息的所有子消息内容。 + +```python +Seg.data: Dict[str, Any] = { + "name": "GET_FORWARD_MSG", + "args": { + "message_id": "7123456789012345678" # 必需,合并转发消息ID(字符串) + } +} +``` + +**返回数据示例:** +```python +{ + "messages": [ + { + "sender": { + "user_id": 87654321, + "nickname": "昵称", + "sex": "male", + "age": 18, + "card": "群名片", + "level": "1", + "role": "member" + }, + "time": 1234567890, + "message": [...] # 消息段数组 + }, + ... + ] +} +``` \ No newline at end of file diff --git a/docs/36bd6c6d15c0fa7ece5b856d0e51ebe5.jpg b/docs/36bd6c6d15c0fa7ece5b856d0e51ebe5.jpg deleted file mode 100644 index a597338..0000000 Binary files a/docs/36bd6c6d15c0fa7ece5b856d0e51ebe5.jpg and /dev/null differ diff --git a/main.py b/main.py index 33af0c5..10c48e6 100644 --- a/main.py +++ b/main.py @@ -1,26 +1,39 @@ import asyncio import sys import json +import http import websockets as Server from src.logger import logger -from src.recv_handler import recv_handler -from src.send_handler import send_handler +from src.recv_handler.message_handler import message_handler +from src.recv_handler.meta_event_handler import meta_event_handler +from src.recv_handler.notice_handler import notice_handler +from src.recv_handler.message_sending import message_send_instance +from src.send_handler.nc_sending import nc_message_sender from src.config import global_config from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router -from src.message_queue import message_queue, put_response, check_timeout_response +from src.response_pool import put_response, check_timeout_response + +message_queue = asyncio.Queue() +websocket_server = None # 保存WebSocket服务器实例以便关闭 async def message_recv(server_connection: Server.ServerConnection): - recv_handler.server_connection = server_connection - send_handler.server_connection = server_connection - async for raw_message in server_connection: - logger.debug(f"{raw_message[:100]}..." if len(raw_message) > 100 else raw_message) - decoded_raw_message: dict = json.loads(raw_message) - post_type = decoded_raw_message.get("post_type") - if post_type in ["meta_event", "message", "notice"]: - await message_queue.put(decoded_raw_message) - elif post_type is None: - await put_response(decoded_raw_message) + try: + await message_handler.set_server_connection(server_connection) + asyncio.create_task(notice_handler.set_server_connection(server_connection)) + await nc_message_sender.set_server_connection(server_connection) + async for raw_message in server_connection: + logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message) + decoded_raw_message: dict = json.loads(raw_message) + post_type = decoded_raw_message.get("post_type") + if post_type in ["meta_event", "message", "notice"]: + await message_queue.put(decoded_raw_message) + elif post_type is None: + await put_response(decoded_raw_message) + except asyncio.CancelledError: + logger.debug("message_recv 收到取消信号,正在关闭连接") + await server_connection.close() + raise async def message_process(): @@ -28,11 +41,11 @@ async def message_process(): message = await message_queue.get() post_type = message.get("post_type") if post_type == "message": - await recv_handler.handle_raw_message(message) + await message_handler.handle_raw_message(message) elif post_type == "meta_event": - await recv_handler.handle_meta_event(message) + await meta_event_handler.handle_meta_event(message) elif post_type == "notice": - await recv_handler.handle_notice(message) + await notice_handler.handle_notice(message) else: logger.warning(f"未知的post_type: {post_type}") message_queue.task_done() @@ -40,27 +53,170 @@ async def message_process(): async def main(): - recv_handler.maibot_router = router - _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response()) + # 启动配置文件监控并注册napcat_server配置变更回调 + from src.config import config_manager + + # 保存napcat_server任务的引用,用于重启 + napcat_task = None + restart_event = asyncio.Event() + + async def on_napcat_config_change(old_value, new_value): + """当napcat_server配置变更时,重启WebSocket服务器""" + nonlocal napcat_task + + logger.warning( + f"NapCat配置已变更:\n" + f" 旧配置: {old_value.host}:{old_value.port}\n" + f" 新配置: {new_value.host}:{new_value.port}" + ) + + # 关闭当前WebSocket服务器 + global websocket_server + if websocket_server: + try: + logger.info("正在关闭旧的WebSocket服务器...") + websocket_server.close() + await websocket_server.wait_closed() + logger.info("旧的WebSocket服务器已关闭") + except Exception as e: + logger.error(f"关闭旧WebSocket服务器失败: {e}") + + # 取消旧任务 + if napcat_task and not napcat_task.done(): + napcat_task.cancel() + try: + await napcat_task + except asyncio.CancelledError: + pass + + # 触发重启 + restart_event.set() + + config_manager.on_config_change("napcat_server", on_napcat_config_change) + + # 启动文件监控 + asyncio.create_task(config_manager.start_watch()) + + # WebSocket服务器重启循环 + async def napcat_with_restart(): + nonlocal napcat_task + while True: + restart_event.clear() + try: + await napcat_server() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"NapCat服务器异常: {e}") + break + + # 等待重启信号 + if not restart_event.is_set(): + break + + logger.info("正在重启WebSocket服务器...") + await asyncio.sleep(1) # 等待1秒后重启 + _ = await asyncio.gather(napcat_with_restart(), mmc_start_com(), message_process(), check_timeout_response()) + +def check_napcat_server_token(conn, request): + token = global_config.napcat_server.token + if not token or token.strip() == "": + return None + auth_header = request.headers.get("Authorization") + if auth_header != f"Bearer {token}": + return Server.Response( + status=http.HTTPStatus.UNAUTHORIZED, + headers=Server.Headers([("Content-Type", "text/plain")]), + body=b"Unauthorized\n" + ) + return None async def napcat_server(): - logger.info("正在启动adapter...") - async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server: - logger.info(f"Adapter已启动,监听地址: ws://{global_config.server_host}:{global_config.server_port}") - await server.serve_forever() + global websocket_server + logger.info("正在启动 MaiBot-Napcat-Adapter...") + logger.debug(f"日志等级: {global_config.debug.level}") + logger.debug("日志文件: logs/adapter_*.log") + try: + async with Server.serve( + message_recv, + global_config.napcat_server.host, + global_config.napcat_server.port, + max_size=2**26, + process_request=check_napcat_server_token + ) as server: + websocket_server = server + logger.success( + f"✅ Adapter 启动成功! 监听: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}" + ) + try: + await server.serve_forever() + except asyncio.CancelledError: + logger.debug("napcat_server 收到取消信号") + raise + except OSError: + # 端口绑定失败时抛出异常让外层处理 + raise -async def graceful_shutdown(): +async def graceful_shutdown(silent: bool = False): + """ + 优雅关闭adapter + Args: + silent: 静默模式,控制台不输出日志,但仍记录到文件 + """ + global websocket_server try: - logger.info("正在关闭adapter...") - await mmc_stop_com() + if not silent: + logger.info("正在关闭adapter...") + else: + logger.debug("正在清理资源...") + + # 先关闭WebSocket服务器 + if websocket_server: + try: + logger.debug("正在关闭WebSocket服务器") + websocket_server.close() + await websocket_server.wait_closed() + logger.debug("WebSocket服务器已关闭") + except Exception as e: + logger.debug(f"关闭WebSocket服务器时出现错误: {e}") + + # 关闭MMC连接 + try: + await asyncio.wait_for(mmc_stop_com(), timeout=3) + except asyncio.TimeoutError: + logger.debug("关闭MMC连接超时") + except Exception as e: + logger.debug(f"关闭MMC连接时出现错误: {e}") + + # 取消所有任务 tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + if tasks: + logger.debug(f"正在取消 {len(tasks)} 个任务") for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + if not task.done(): + task.cancel() + + # 等待任务完成,记录异常到日志文件 + if tasks: + try: + results = await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3) + # 记录任务取消的详细信息到日志文件 + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.debug(f"任务 {i+1} 清理时产生异常: {type(result).__name__}: {result}") + except asyncio.TimeoutError: + logger.debug("任务清理超时") + except Exception as e: + logger.debug(f"任务清理时出现错误: {e}") + + if not silent: + logger.info("Adapter已成功关闭") + else: + logger.debug("资源清理完成") except Exception as e: - logger.error(f"Adapter关闭中出现错误: {e}") + logger.debug(f"graceful_shutdown异常: {e}", exc_info=True) if __name__ == "__main__": @@ -70,11 +226,58 @@ async def graceful_shutdown(): loop.run_until_complete(main()) except KeyboardInterrupt: logger.warning("收到中断信号,正在优雅关闭...") - loop.run_until_complete(graceful_shutdown()) + try: + loop.run_until_complete(graceful_shutdown(silent=False)) + except Exception: + pass + except OSError as e: + # 处理端口占用等网络错误 + if e.errno == 10048 or "address already in use" in str(e).lower(): + logger.error(f"❌ 端口 {global_config.napcat_server.port} 已被占用,请检查:") + logger.error(" 1. 是否有其他 MaiBot-Napcat-Adapter 实例正在运行") + logger.error(" 2. 修改 config.toml 中的 port 配置") + logger.error(f" 3. 使用命令查看占用进程: netstat -ano | findstr {global_config.napcat_server.port}") + else: + logger.error(f"❌ 网络错误: {str(e)}") + + logger.debug("完整错误信息:", exc_info=True) + + # 端口占用时静默清理(控制台不输出,但记录到日志文件) + try: + loop.run_until_complete(graceful_shutdown(silent=True)) + except Exception as e: + logger.debug(f"清理资源时出现错误: {e}", exc_info=True) + sys.exit(1) except Exception as e: - logger.exception(f"主程序异常: {str(e)}") + logger.error(f"❌ 主程序异常: {str(e)}") + logger.debug("详细错误信息:", exc_info=True) + try: + loop.run_until_complete(graceful_shutdown(silent=True)) + except Exception as e: + logger.debug(f"清理资源时出现错误: {e}", exc_info=True) sys.exit(1) finally: - if loop and not loop.is_closed(): - loop.close() + # 清理事件循环 + try: + # 取消所有剩余任务 + pending = asyncio.all_tasks(loop) + if pending: + logger.debug(f"finally块清理 {len(pending)} 个剩余任务") + for task in pending: + task.cancel() + # 给任务一点时间完成取消 + try: + results = loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + # 记录清理结果到日志文件 + for i, result in enumerate(results): + if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError): + logger.debug(f"剩余任务 {i+1} 清理异常: {type(result).__name__}: {result}") + except Exception as e: + logger.debug(f"清理剩余任务时出现错误: {e}") + except Exception as e: + logger.debug(f"finally块清理出现错误: {e}") + finally: + if loop and not loop.is_closed(): + logger.debug("关闭事件循环") + loop.close() sys.exit(0) diff --git a/notify_args.md b/notify_args.md new file mode 100644 index 0000000..8a94fef --- /dev/null +++ b/notify_args.md @@ -0,0 +1,44 @@ +# Notify Args +```python +Seg.type = "notify" +``` +## 群聊成员被禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "ban", + "duration": "对应的禁言时间,单位为秒", + "banned_user_info": "被禁言的用户的信息,为标准UserInfo转换成的字典" +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 + +**注意: `banned_user_info`需要自行调用`UserInfo.from_dict()`函数转换为标准UserInfo对象** +## 群聊开启全体禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_ban", + "duration": -1, + "banned_user_info": None +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 +## 群聊成员被解除禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_lift_ban", + "lifted_user_info": "被解除禁言的用户的信息,为标准UserInfo对象" +} +``` +**对于自然禁言解除的情况,此时`MessageBase.UserInfo`为`None`** + +对于手动解除禁言的情况,此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 + +**注意: `lifted_user_info`需要自行调用`UserInfo.from_dict()`函数转换为标准UserInfo对象** +## 群聊关闭全体禁言 +```python +Seg.data: Dict[str, Any] = { + "sub_type": "whole_lift_ban", + "lifted_user_info": None, +} +``` +此时`MessageBase.UserInfo`,即消息的`UserInfo`为操作者(operator)的信息 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3cf0553..3ecd8b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,21 @@ [project] name = "MaiBotNapcatAdapter" -version = "0.2.5" +version = "0.7.0" description = "A MaiBot adapter for Napcat" +requires-python = ">=3.10" +dependencies = [ + "aiohttp>=3.13.2", + "asyncio>=4.0.0", + "loguru>=0.7.3", + "maim-message>=0.6.2", + "pillow>=12.0.0", + "requests>=2.32.5", + "rich>=14.2.0", + "sqlmodel>=0.0.27", + "tomlkit>=0.13.3", + "websockets>=15.0.1", + "watchdog>=3.0.0", +] [tool.ruff] @@ -21,7 +35,7 @@ select = [ "B", # flake8-bugbear ] -ignore = ["E711","E501"] +ignore = ["E711", "E501"] [tool.ruff.format] docstring-code-format = true diff --git a/requirements.txt b/requirements.txt index 41e8eb1..817dc53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,7 @@ requests maim_message loguru pillow -tomli \ No newline at end of file +tomlkit +rich +sqlmodel +watchdog \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index a35ff0e..81b187b 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,77 +1,42 @@ -from enum import Enum - - -class MetaEventType: - lifecycle = "lifecycle" # 生命周期 - - class Lifecycle: - connect = "connect" # 生命周期 - WebSocket 连接成功 - - heartbeat = "heartbeat" # 心跳 - - -class MessageType: # 接受消息大类 - private = "private" # 私聊消息 - - class Private: - friend = "friend" # 私聊消息 - 好友 - group = "group" # 私聊消息 - 群临时 - group_self = "group_self" # 私聊消息 - 群中自身发送 - other = "other" # 私聊消息 - 其他 - - group = "group" # 群聊消息 - - class Group: - normal = "normal" # 群聊消息 - 普通 - anonymous = "anonymous" # 群聊消息 - 匿名消息 - notice = "notice" # 群聊消息 - 系统提示 - - -class NoticeType: # 通知事件 - friend_recall = "friend_recall" # 私聊消息撤回 - group_recall = "group_recall" # 群聊消息撤回 - notify = "notify" - - class Notify: - poke = "poke" # 戳一戳 - - -class RealMessageType: # 实际消息分类 - text = "text" # 纯文本 - face = "face" # qq表情 - image = "image" # 图片 - record = "record" # 语音 - video = "video" # 视频 - at = "at" # @某人 - rps = "rps" # 猜拳魔法表情 - dice = "dice" # 骰子 - shake = "shake" # 私聊窗口抖动(只收) - poke = "poke" # 群聊戳一戳 - share = "share" # 链接分享(json形式) - reply = "reply" # 回复消息 - forward = "forward" # 转发消息 - node = "node" # 转发消息节点 - - -class MessageSentType: - private = "private" - - class Private: - friend = "friend" - group = "group" - - group = "group" - - class Group: - normal = "normal" - - -class CommandType(Enum): - """命令类型""" - - GROUP_BAN = "set_group_ban" # 禁言用户 - GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 - GROUP_KICK = "set_group_kick" # 踢出群聊 - - def __str__(self) -> str: - return self.value +from enum import Enum +import tomlkit +import os +from .logger import logger + + +class CommandType(Enum): + """命令类型""" + + # 操作类命令 + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + GROUP_KICK_MEMBERS = "set_group_kick_members" # 批量踢出群成员 + SET_GROUP_NAME = "set_group_name" # 设置群名 + SEND_POKE = "send_poke" # 戳一戳 + DELETE_MSG = "delete_msg" # 撤回消息 + AI_VOICE_SEND = "send_group_ai_record" # 发送群AI语音 + SET_MSG_EMOJI_LIKE = "set_msg_emoji_like" # 给消息贴表情 + SET_QQ_PROFILE = "set_qq_profile" # 设置账号信息 + + # 查询类命令 + GET_LOGIN_INFO = "get_login_info" # 获取登录号信息 + GET_STRANGER_INFO = "get_stranger_info" # 获取陌生人信息 + GET_FRIEND_LIST = "get_friend_list" # 获取好友列表 + GET_GROUP_INFO = "get_group_info" # 获取群信息 + GET_GROUP_DETAIL_INFO = "get_group_detail_info" # 获取群详细信息 + GET_GROUP_LIST = "get_group_list" # 获取群列表 + GET_GROUP_AT_ALL_REMAIN = "get_group_at_all_remain" # 获取群@全体成员剩余次数 + GET_GROUP_MEMBER_INFO = "get_group_member_info" # 获取群成员信息 + GET_GROUP_MEMBER_LIST = "get_group_member_list" # 获取群成员列表 + GET_MSG = "get_msg" # 获取消息 + GET_FORWARD_MSG = "get_forward_msg" # 获取合并转发消息 + + def __str__(self) -> str: + return self.value + + +pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml") +toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read()) +version = toml_data["project"]["version"] +logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n") diff --git a/src/config.py b/src/config.py deleted file mode 100644 index ee13c98..0000000 --- a/src/config.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import sys -import tomli -import shutil -from .logger import logger -from typing import Optional - - -class Config: - platform: str = "qq" - nickname: Optional[str] = None - server_host: str = "localhost" - server_port: int = 8095 - napcat_heartbeat_interval: int = 30 - - def __init__(self): - self._get_config_path() - - def _get_config_path(self): - current_file_path = os.path.abspath(__file__) - src_path = os.path.dirname(current_file_path) - self.root_path = os.path.join(src_path, "..") - self.config_path = os.path.join(self.root_path, "config.toml") - - def load_config(self): # sourcery skip: extract-method, move-assign - include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"] - if not os.path.exists(self.config_path): - logger.error("配置文件不存在!") - logger.info("正在创建配置文件...") - shutil.copy( - os.path.join(self.root_path, "template", "template_config.toml"), - os.path.join(self.root_path, "config.toml"), - ) - logger.info("配置文件创建成功,请修改配置文件后重启程序。") - sys.exit(1) - with open(self.config_path, "rb") as f: - try: - raw_config = tomli.load(f) - except tomli.TOMLDecodeError as e: - logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") - sys.exit(1) - for key in include_configs: - if key not in raw_config: - logger.error(f"配置文件中缺少必需的字段: '{key}'") - logger.error("你的配置文件可能过时,请尝试手动更新配置文件。") - sys.exit(1) - - self.server_host = raw_config["Napcat_Server"].get("host", "localhost") - self.server_port = raw_config["Napcat_Server"].get("port", 8095) - self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30) - - self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost") - self.mai_port = raw_config["MaiBot_Server"].get("port", 8000) - self.platform = raw_config["MaiBot_Server"].get("platform_name") - if not self.platform: - logger.critical("请在配置文件中指定平台") - sys.exit(1) - - self.group_list_type: str = raw_config["Chat"].get("group_list_type") - self.group_list: list = raw_config["Chat"].get("group_list", []) - self.private_list_type: str = raw_config["Chat"].get("private_list_type") - self.private_list: list = raw_config["Chat"].get("private_list", []) - self.ban_user_id: list = raw_config["Chat"].get("ban_user_id", []) - self.enable_poke: bool = raw_config["Chat"].get("enable_poke", True) - if self.group_list_type not in ["whitelist", "blacklist"]: - logger.critical("请在配置文件中指定group_list_type或group_list_type填写错误") - sys.exit(1) - if self.private_list_type not in ["whitelist", "blacklist"]: - logger.critical("请在配置文件中指定private_list_type或private_list_type填写错误") - sys.exit(1) - - self.use_tts = raw_config["Voice"].get("use_tts", False) - - self.debug_level = raw_config["Debug"].get("level", "INFO") - if self.debug_level == "DEBUG": - logger.debug("原始配置文件内容:") - logger.debug(raw_config) - logger.debug("读取到的配置内容:") - logger.debug(f"平台: {self.platform}") - logger.debug(f"MaiBot服务器地址: {self.mai_host}:{self.mai_port}") - logger.debug(f"Napcat服务器地址: {self.server_host}:{self.server_port}") - logger.debug(f"心跳间隔: {self.napcat_heartbeat_interval}秒") - logger.debug(f"群聊列表类型: {self.group_list_type}") - logger.debug(f"群聊列表: {self.group_list}") - logger.debug(f"私聊列表类型: {self.private_list_type}") - logger.debug(f"私聊列表: {self.private_list}") - logger.debug(f"禁用用户ID列表: {self.ban_user_id}") - logger.debug(f"是否启用TTS: {self.use_tts}") - logger.debug(f"调试级别: {self.debug_level}") - - -global_config = Config() -global_config.load_config() diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..e6d30db --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,6 @@ +from .config import global_config, _config_manager as config_manager + +__all__ = [ + "global_config", + "config_manager", +] diff --git a/src/config/config.py b/src/config/config.py new file mode 100644 index 0000000..1bf531d --- /dev/null +++ b/src/config/config.py @@ -0,0 +1,158 @@ +import os +from dataclasses import dataclass +from datetime import datetime + +import tomlkit +import shutil + +from tomlkit import TOMLDocument +from tomlkit.items import Table +from ..logger import logger +from rich.traceback import install + +from src.config.config_base import ConfigBase +from src.config.official_configs import ( + ChatConfig, + DebugConfig, + ForwardConfig, + MaiBotServerConfig, + NapcatServerConfig, + NicknameConfig, + VoiceConfig, +) + +install(extra_lines=3) + +TEMPLATE_DIR = "template" + + +def update_config(): + # 定义文件路径 + template_path = f"{TEMPLATE_DIR}/template_config.toml" + old_config_path = "config.toml" + new_config_path = "config.toml" + + # 检查配置文件是否存在 + if not os.path.exists(old_config_path): + logger.info("配置文件不存在,从模板创建新配置") + shutil.copy2(template_path, old_config_path) # 复制模板文件 + logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") + # 如果是新创建的配置文件,直接返回 + quit() + + # 读取旧配置文件和模板文件 + with open(old_config_path, "r", encoding="utf-8") as f: + old_config = tomlkit.load(f) + with open(template_path, "r", encoding="utf-8") as f: + new_config = tomlkit.load(f) + + # 检查version是否相同 + if old_config and "inner" in old_config and "inner" in new_config: + old_version = old_config["inner"].get("version") + new_version = new_config["inner"].get("version") + if old_version and new_version and old_version == new_version: + logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + return + else: + logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") + else: + logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + + # 创建备份文件夹 + backup_dir = "config_backup" + os.makedirs(backup_dir, exist_ok=True) + + # 备份文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + old_backup_path = os.path.join(backup_dir, f"config.toml.bak.{timestamp}") + + # 备份旧配置文件 + shutil.copy2(old_config_path, old_backup_path) + logger.info(f"已备份旧配置文件到: {old_backup_path}") + + # 复制模板文件到配置目录 + shutil.copy2(template_path, new_config_path) + logger.info(f"已创建新配置文件: {new_config_path}") + + def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ + for key, value in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): + update_dict(target[key], value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + # 将旧配置的值更新到新配置中 + logger.info("开始合并新旧配置...") + update_dict(new_config, old_config) + + # 保存更新后的配置(保留注释和格式) + with open(new_config_path, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(new_config)) + logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + quit() + + +@dataclass +class Config(ConfigBase): + """总配置类""" + + nickname: NicknameConfig + napcat_server: NapcatServerConfig + maibot_server: MaiBotServerConfig + chat: ChatConfig + voice: VoiceConfig + forward: ForwardConfig + debug: DebugConfig + + +def load_config(config_path: str) -> Config: + """ + 加载配置文件 + :param config_path: 配置文件路径 + :return: Config对象 + """ + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建Config对象 + try: + return Config.from_dict(config_data) + except Exception as e: + logger.critical("配置文件解析失败") + raise e + + +# 更新配置 +update_config() + +logger.info("正在品鉴配置文件...") + +# 创建配置管理器 +from .config_manager import ConfigManager + +_config_manager = ConfigManager() +_config_manager.load(config_path="config.toml") + +# 向后兼容:global_config 指向配置管理器 +# 所有现有代码可以继续使用 global_config.chat.xxx 访问配置 +global_config = _config_manager + +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/config_base.py b/src/config/config_base.py new file mode 100644 index 0000000..87cb079 --- /dev/null +++ b/src/config/config_base.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, fields, MISSING +from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union + +T = TypeVar("T", bound="ConfigBase") + +TOML_DICT_TYPE = { + int, + float, + str, + bool, + list, + dict, +} + + +@dataclass +class ConfigBase: + """配置类的基类""" + + @classmethod + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: + """从字典加载配置字段""" + if not isinstance(data, dict): + raise TypeError(f"Expected a dictionary, got {type(data).__name__}") + + init_args: Dict[str, Any] = {} + + for f in fields(cls): + field_name = f.name + field_type = f.type + if field_name.startswith("_"): + # 跳过以 _ 开头的字段 + continue + + if field_name not in data: + if f.default is not MISSING or f.default_factory is not MISSING: + # 跳过未提供且有默认值/默认构造方法的字段 + continue + else: + raise ValueError(f"Missing required field: '{field_name}'") + + value = data[field_name] + try: + init_args[field_name] = cls._convert_field(value, field_type) + except TypeError as e: + raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e + except Exception as e: + raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e + + return cls(**init_args) + + @classmethod + def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + """ + 转换字段值为指定类型 + + 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法 + 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素 + 3. 对于基础类型(int, str, float, bool),直接转换 + 4. 对于其他类型,尝试直接转换,如果失败则抛出异常 + """ + # 如果是嵌套的 dataclass,递归调用 from_dict 方法 + if isinstance(field_type, type) and issubclass(field_type, ConfigBase): + return field_type.from_dict(value) + + field_origin_type = get_origin(field_type) + field_args_type = get_args(field_type) + + # 处理泛型集合类型(list, set, tuple) + if field_origin_type in {list, set, tuple}: + # 检查提供的value是否为list + if not isinstance(value, list): + raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") + + if field_origin_type is list: + return [cls._convert_field(item, field_args_type[0]) for item in value] + if field_origin_type is set: + return {cls._convert_field(item, field_args_type[0]) for item in value} + if field_origin_type is tuple: + # 检查提供的value长度是否与类型参数一致 + if len(value) != len(field_args_type): + raise TypeError( + f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}" + ) + return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type)) + + if field_origin_type is dict: + # 检查提供的value是否为dict + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + + # 检查字典的键值类型 + if len(field_args_type) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") + key_type, value_type = field_args_type + + return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + + # 处理Optional类型 + if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union + if value is None: + return None + # 如果有数据,检查实际类型 + if type(value) not in field_args_type: + raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}") + return cls._convert_field(value, field_args_type[0]) + + # 处理int, str, float, bool等基础类型 + if field_origin_type is None: + if isinstance(value, field_type): + return field_type(value) + else: + raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}") + + # 处理Literal类型 + if field_origin_type is Literal: + # 获取Literal的允许值 + allowed_values = get_args(field_type) + if value in allowed_values: + return value + else: + raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type") + + # 处理其他类型 + if field_type is Any: + return value + + # 其他类型直接转换 + try: + return field_type(value) + except (ValueError, TypeError) as e: + raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e + + def __str__(self): + """返回配置类的字符串表示""" + return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})" diff --git a/src/config/config_manager.py b/src/config/config_manager.py new file mode 100644 index 0000000..b888ab7 --- /dev/null +++ b/src/config/config_manager.py @@ -0,0 +1,281 @@ +"""配置管理器 - 支持热重载""" +import asyncio +import os +from typing import Callable, Dict, List, Any, Optional +from datetime import datetime +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler, FileModifiedEvent + +from ..logger import logger +from .config import Config, load_config + + +class ConfigManager: + """配置管理器 - 混合模式(属性代理 + 选择性回调) + + 支持热重载配置文件,使用watchdog实时监控文件变化。 + 需要特殊处理的配置项可以注册回调函数。 + """ + + def __init__(self) -> None: + self._config: Optional[Config] = None + self._config_path: str = "config.toml" + self._lock: asyncio.Lock = asyncio.Lock() + self._callbacks: Dict[str, List[Callable]] = {} + + # Watchdog相关 + self._observer: Optional[Observer] = None + self._event_handler: Optional[FileSystemEventHandler] = None + self._reload_debounce_task: Optional[asyncio.Task] = None + self._debounce_delay: float = 0.5 # 防抖延迟(秒) + self._loop: Optional[asyncio.AbstractEventLoop] = None # 事件循环引用 + self._is_reloading: bool = False # 标记是否正在重载 + self._last_reload_trigger: float = 0.0 # 最后一次触发重载的时间 + + def load(self, config_path: str = "config.toml") -> None: + """加载配置文件 + + Args: + config_path: 配置文件路径 + """ + self._config_path = os.path.abspath(config_path) + self._config = load_config(config_path) + + logger.info(f"配置已加载: {config_path}") + + async def reload(self, config_path: Optional[str] = None) -> bool: + """重载配置文件(热重载) + + Args: + config_path: 配置文件路径,如果为None则使用初始路径 + + Returns: + bool: 是否重载成功 + """ + if config_path is None: + config_path = self._config_path + + async with self._lock: + old_config = self._config + try: + new_config = load_config(config_path) + + if old_config is not None: + await self._notify_changes(old_config, new_config) + + self._config = new_config + logger.info(f"配置重载成功: {config_path}") + return True + + except Exception as e: + logger.error(f"配置重载失败: {e}", exc_info=True) + return False + + def on_config_change( + self, + config_path: str, + callback: Callable[[Any, Any], Any] + ) -> None: + """为特定配置路径注册回调函数 + + Args: + config_path: 配置路径,如 'napcat_server', 'chat.ban_user_id', 'debug.level' + callback: 回调函数,签名为 async def callback(old_value, new_value) + """ + if config_path not in self._callbacks: + self._callbacks[config_path] = [] + self._callbacks[config_path].append(callback) + logger.debug(f"已注册配置变更回调: {config_path}") + + async def _notify_changes(self, old_config: Config, new_config: Config) -> None: + """通知配置变更 + + Args: + old_config: 旧配置对象 + new_config: 新配置对象 + """ + for config_path, callbacks in self._callbacks.items(): + try: + old_value = self._get_value(old_config, config_path) + new_value = self._get_value(new_config, config_path) + + if old_value != new_value: + logger.info(f"检测到配置变更: {config_path}") + for callback in callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback(old_value, new_value) + else: + callback(old_value, new_value) + except Exception as e: + logger.error( + f"配置变更回调执行失败 [{config_path}]: {e}", + exc_info=True + ) + except Exception as e: + logger.error(f"获取配置值失败 [{config_path}]: {e}") + + def _get_value(self, config: Config, path: str) -> Any: + """获取嵌套配置值 + + Args: + config: 配置对象 + path: 配置路径,支持点分隔的嵌套路径 + + Returns: + Any: 配置值 + + Raises: + AttributeError: 配置路径不存在 + """ + parts = path.split('.') + value = config + for part in parts: + value = getattr(value, part) + return value + + def __getattr__(self, name: str) -> Any: + """动态代理配置属性访问 + + 支持直接访问配置对象的属性,如: + - config_manager.napcat_server + - config_manager.chat + - config_manager.debug + + Args: + name: 属性名 + + Returns: + Any: 配置对象的对应属性值 + + Raises: + RuntimeError: 配置尚未加载 + AttributeError: 属性不存在 + """ + # 私有属性不代理 + if name.startswith('_'): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + # 检查配置是否已加载 + if self._config is None: + raise RuntimeError("配置尚未加载,请先调用 load() 方法") + + # 尝试从 _config 获取属性 + try: + return getattr(self._config, name) + except AttributeError as e: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) from e + + async def start_watch(self) -> None: + """启动配置文件监控(需要在事件循环中调用)""" + if self._observer is not None: + logger.warning("配置文件监控已在运行") + return + + # 保存当前事件循环引用 + self._loop = asyncio.get_running_loop() + + # 创建文件监控事件处理器 + config_file_path = self._config_path + + class ConfigFileHandler(FileSystemEventHandler): + def __init__(handler_self, manager: "ConfigManager"): + handler_self.manager = manager + handler_self.config_path = config_file_path + + def on_modified(handler_self, event): + # 检查是否是目标配置文件修改事件 + if isinstance(event, FileModifiedEvent) and os.path.abspath(event.src_path) == handler_self.config_path: + logger.debug(f"检测到配置文件变更: {event.src_path}") + # 使用防抖机制避免重复重载 + # watchdog运行在独立线程,需要使用run_coroutine_threadsafe + if handler_self.manager._loop: + asyncio.run_coroutine_threadsafe( + handler_self.manager._debounced_reload(), + handler_self.manager._loop + ) + + self._event_handler = ConfigFileHandler(self) + + # 创建Observer并监控配置文件所在目录 + self._observer = Observer() + watch_dir = os.path.dirname(self._config_path) or "." + + self._observer.schedule(self._event_handler, watch_dir, recursive=False) + self._observer.start() + + logger.info(f"已启动配置文件实时监控: {self._config_path}") + + async def stop_watch(self) -> None: + """停止配置文件监控""" + if self._observer is None: + return + + logger.debug("正在停止配置文件监控") + + # 取消防抖任务 + if self._reload_debounce_task: + self._reload_debounce_task.cancel() + try: + await self._reload_debounce_task + except asyncio.CancelledError: + pass + + # 停止observer + self._observer.stop() + self._observer.join(timeout=2) + self._observer = None + self._event_handler = None + + logger.info("配置文件监控已停止") + + async def _debounced_reload(self) -> None: + """防抖重载:避免短时间内多次文件修改事件导致重复重载""" + import time + + # 记录当前触发时间 + trigger_time = time.time() + self._last_reload_trigger = trigger_time + + # 等待防抖延迟 + await asyncio.sleep(self._debounce_delay) + + # 检查是否有更新的触发 + if self._last_reload_trigger > trigger_time: + # 有更新的触发,放弃本次重载 + logger.debug("放弃过时的重载请求") + return + + # 检查是否已有重载在进行 + if self._is_reloading: + logger.debug("重载已在进行中,跳过") + return + + # 执行重载 + self._is_reloading = True + try: + modified_time = datetime.fromtimestamp( + os.path.getmtime(self._config_path) + ).strftime("%Y-%m-%d %H:%M:%S") + + logger.info( + f"配置文件已更新 (修改时间: {modified_time}),正在重载..." + ) + + success = await self.reload() + + if not success: + logger.error( + "配置文件重载失败!请检查配置文件格式是否正确。\n" + "当前仍使用旧配置运行,修复配置文件后将自动重试。" + ) + finally: + self._is_reloading = False + + def __repr__(self) -> str: + watching = self._observer is not None and self._observer.is_alive() + return f"" diff --git a/src/config/official_configs.py b/src/config/official_configs.py new file mode 100644 index 0000000..86b8d38 --- /dev/null +++ b/src/config/official_configs.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass, field +from typing import Literal + +from src.config.config_base import ConfigBase + +""" +须知: +1. 本文件中记录了所有的配置项 +2. 所有新增的class都需要继承自ConfigBase +3. 所有新增的class都应在config.py中的Config类中添加字段 +4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +""" + +ADAPTER_PLATFORM = "qq" + + +@dataclass +class NicknameConfig(ConfigBase): + nickname: str + """机器人昵称""" + + +@dataclass +class NapcatServerConfig(ConfigBase): + host: str = "localhost" + """Napcat服务端的主机地址""" + + port: int = 8095 + """Napcat服务端的端口号""" + + token: str = "" + """Napcat服务端的访问令牌,若无则留空""" + + heartbeat_interval: int = 30 + """Napcat心跳间隔时间,单位为秒""" + + +@dataclass +class MaiBotServerConfig(ConfigBase): + platform_name: str = field(default=ADAPTER_PLATFORM, init=False) + """平台名称,“qq”""" + + host: str = "localhost" + """MaiMCore的主机地址""" + + port: int = 8000 + """MaiMCore的端口号""" + + enable_api_server: bool = False + """是否启用API-Server模式连接""" + + base_url: str = "" + """API-Server连接地址 (ws://ipp:port/path)""" + + api_key: str = "" + """API Key (仅在enable_api_server为True时使用)""" + + +@dataclass +class ChatConfig(ConfigBase): + group_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """群聊列表类型 白名单/黑名单""" + + group_list: list[int] = field(default_factory=[]) + """群聊列表""" + + private_list_type: Literal["whitelist", "blacklist"] = "whitelist" + """私聊列表类型 白名单/黑名单""" + + private_list: list[int] = field(default_factory=[]) + """私聊列表""" + + ban_user_id: list[int] = field(default_factory=[]) + """被封禁的用户ID列表,封禁后将无法与其进行交互""" + + ban_qq_bot: bool = False + """是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互""" + + enable_poke: bool = True + """是否启用戳一戳功能""" + + +@dataclass +class VoiceConfig(ConfigBase): + use_tts: bool = False + """是否启用TTS功能""" + + +@dataclass +class ForwardConfig(ConfigBase): + """转发消息相关配置""" + + image_threshold: int = 3 + """图片数量阈值:转发消息中图片数量超过此值时,使用占位符代替base64发送,避免麦麦VLM处理卡死""" + + +@dataclass +class DebugConfig(ConfigBase): + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + """日志级别,默认为INFO""" diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..af193da --- /dev/null +++ b/src/database.py @@ -0,0 +1,162 @@ +import os +from typing import Optional, List +from dataclasses import dataclass +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from src.logger import logger + +""" +表记录的方式: +| group_id | user_id | lift_time | +|----------|---------|-----------| + +其中使用 user_id == 0 表示群全体禁言 +""" + + +@dataclass +class BanUser: + """ + 程序处理使用的实例 + """ + + user_id: int + group_id: int + lift_time: Optional[int] = Field(default=-1) + + +class DB_BanUser(SQLModel, table=True): + """ + 表示数据库中的用户禁言记录。 + 使用双重主键 + """ + + user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID + group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID + lift_time: Optional[int] # 禁言解除的时间(时间戳) + + +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: + """ + 检查两个 BanUser 对象是否相同。 + """ + return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id + + +class DatabaseManager: + """ + 数据库管理类,负责与数据库交互。 + """ + + def __init__(self): + os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 + DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") + self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL + self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 + self._ensure_database() # 确保数据库和表已创建 + + def _ensure_database(self) -> None: + """ + 确保数据库和表已创建。 + """ + logger.info("确保数据库文件和表已创建...") + SQLModel.metadata.create_all(self.engine) + logger.success("数据库和表已创建或已存在") + + def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method + """ + 更新禁言列表到数据库。 + 支持在不存在时创建新记录,对于多余的项目自动删除。 + """ + with Session(self.engine) as session: + all_records = session.exec(select(DB_BanUser)).all() + for ban_user in ban_list: + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id + ) + if existing_record := session.exec(statement).first(): + if existing_record.lift_time == ban_user.lift_time: + logger.debug(f"禁言记录未变更: {existing_record}") + continue + # 更新现有记录的 lift_time + existing_record.lift_time = ban_user.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {existing_record}") + else: + # 创建新记录 + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_user}") + # 删除不在 ban_list 中的记录 + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) + if not any(is_identical(record, ban_user) for ban_user in ban_list): + statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id + ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") + + session.commit() + logger.info("禁言记录已更新") + + def get_ban_records(self) -> List[BanUser]: + """ + 读取所有禁言记录。 + """ + with Session(self.engine) as session: + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] + + def create_ban_record(self, ban_record: BanUser) -> None: + """ + 为特定群组中的用户创建禁言记录。 + 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 + 其同时还是简化版的更新方式。 + """ + with Session(self.engine) as session: + # 检查记录是否已存在 + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + ) + existing_record = session.exec(statement).first() + if existing_record: + # 如果记录已存在,更新 lift_time + existing_record.lift_time = ban_record.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {ban_record}") + else: + # 如果记录不存在,创建新记录 + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_record}") + session.commit() + + def delete_ban_record(self, ban_record: BanUser): + """ + 删除特定用户在特定群组中的禁言记录。 + 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 + """ + user_id = ban_record.user_id + group_id = ban_record.group_id + with Session(self.engine) as session: + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + session.commit() + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + + +db_manager = DatabaseManager() diff --git a/src/logger.py b/src/logger.py index 3acba4f..ab509e9 100644 --- a/src/logger.py +++ b/src/logger.py @@ -1,10 +1,106 @@ from loguru import logger from .config import global_config import sys +from pathlib import Path +from datetime import datetime, timedelta +# 日志目录配置 +LOG_DIR = Path(__file__).parent.parent / "logs" +LOG_DIR.mkdir(exist_ok=True) + +# 日志等级映射(用于显示单字母) +LEVEL_ABBR = { + "TRACE": "T", + "DEBUG": "D", + "INFO": "I", + "SUCCESS": "S", + "WARNING": "W", + "ERROR": "E", + "CRITICAL": "C" +} + +def get_level_abbr(record): + """获取日志等级的缩写""" + return LEVEL_ABBR.get(record["level"].name, record["level"].name[0]) + +def clean_old_logs(days: int = 30): + """清理超过指定天数的日志文件""" + try: + cutoff_date = datetime.now() - timedelta(days=days) + for log_file in LOG_DIR.glob("*.log"): + try: + file_time = datetime.fromtimestamp(log_file.stat().st_mtime) + if file_time < cutoff_date: + log_file.unlink() + print(f"已清理过期日志: {log_file.name}") + except Exception as e: + print(f"清理日志文件 {log_file.name} 失败: {e}") + except Exception as e: + print(f"清理日志目录失败: {e}") + +# 清理过期日志 +clean_old_logs(30) + +# 移除默认处理器 logger.remove() + +# 自定义格式化函数 +def format_log(record): + """格式化日志记录""" + record["extra"]["level_abbr"] = get_level_abbr(record) + if "module_name" not in record["extra"]: + record["extra"]["module_name"] = "Adapter" + return True + +# 控制台输出处理器 - 简洁格式 logger.add( sys.stderr, - level=global_config.debug_level, - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + level=global_config.debug.level, + format="{time:MM-DD HH:mm:ss} | [{extra[level_abbr]}] | {extra[module_name]} | {message}", + filter=lambda record: format_log(record) and record["extra"].get("module_name") != "maim_message", ) + +# maim_message 单独处理 +logger.add( + sys.stderr, + level="INFO", + format="{time:MM-DD HH:mm:ss} | [{extra[level_abbr]}] | {extra[module_name]} | {message}", + filter=lambda record: format_log(record) and record["extra"].get("module_name") == "maim_message", +) + +# 文件输出处理器 - 详细格式,记录所有TRACE级别 +log_file = LOG_DIR / f"adapter_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" +logger.add( + log_file, + level="TRACE", + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | [{level}] | {extra[module_name]} | {name}:{function}:{line} - {message}", + rotation="100 MB", # 单个日志文件最大100MB + retention="30 days", # 保留30天 + encoding="utf-8", + enqueue=True, # 异步写入,避免阻塞 + filter=format_log, # 确保extra字段存在 +) + +def get_logger(module_name: str = "Adapter"): + """ + 获取自定义模块名的logger + + Args: + module_name: 模块名称,用于日志输出中标识来源 + + Returns: + 配置好的logger实例 + + Example: + >>> from src.logger import get_logger + >>> logger = get_logger("MyModule") + >>> logger.info("这是一条日志") + MM-DD HH:mm:ss | [I] | MyModule | 这是一条日志 + """ + return logger.bind(module_name=module_name) + +# 默认logger实例(用于向后兼容) +logger = logger.bind(module_name="Adapter") + +# maim_message的logger +custom_logger = logger.bind(module_name="maim_message") diff --git a/src/mmc_com_layer.py b/src/mmc_com_layer.py index 174ef1f..659b934 100644 --- a/src/mmc_com_layer.py +++ b/src/mmc_com_layer.py @@ -1,24 +1,164 @@ -from maim_message import Router, RouteConfig, TargetConfig +from maim_message import Router, RouteConfig, TargetConfig, MessageBase from .config import global_config -from .logger import logger -from .send_handler import send_handler - -route_config = RouteConfig( - route_config={ - global_config.platform: TargetConfig( - url=f"ws://{global_config.mai_host}:{global_config.mai_port}/ws", - token=None, +from .logger import logger, custom_logger +from .send_handler.main_send_handler import send_handler +from .recv_handler.message_sending import message_send_instance +from maim_message.client import create_client_config, WebSocketClient +from maim_message.message import APIMessageBase +from typing import Dict, Any +import importlib.metadata + +# 检查 maim_message 版本是否支持 MessageConverter (>= 0.6.2) +try: + maim_message_version = importlib.metadata.version("maim_message") + version_int = [int(x) for x in maim_message_version.split(".")] + HAS_MESSAGE_CONVERTER = version_int >= [0, 6, 2] +except (importlib.metadata.PackageNotFoundError, ValueError): + HAS_MESSAGE_CONVERTER = False + +# router = Router(route_config, custom_logger) +# router will be initialized in mmc_start_com +router = None + + +class APIServerWrapper: + """ + Wrapper to make WebSocketClient compatible with legacy Router interface + """ + def __init__(self, client: WebSocketClient): + self.client = client + self.platform = global_config.maibot_server.platform_name + + def register_class_handler(self, handler): + # In API Server mode, we register the on_message callback in config, + # but here we might need to bridge it if the handler structure is different. + # However, WebSocketClient config handles on_message. + # The legacy Router.register_class_handler registers a handler for received messages. + # We need to adapt the callback style. + pass + + async def send_message(self, message: MessageBase) -> bool: + # 使用 MessageConverter 转换 Legacy MessageBase 到 APIMessageBase + # 接收场景:Adapter 收到来自 Napcat 的消息,发送给 MaiMBot + # group_info/user_info 是消息发送者信息,放入 sender_info + from maim_message import MessageConverter + + api_message = MessageConverter.to_api_receive( + message=message, + api_key=global_config.maibot_server.api_key, + platform=message.message_info.platform or self.platform, ) - } -) -router = Router(route_config) + return await self.client.send_message(api_message) + + async def send_custom_message(self, platform: str, message_type_name: str, message: Dict) -> bool: + return await self.client.send_custom_message(message_type_name, message) + + async def run(self): + await self.client.start() + await self.client.connect() + + async def stop(self): + await self.client.stop() + +# Global variable to hold the communication object (Router or Wrapper) +router = None +async def _legacy_message_handler_adapter(message: APIMessageBase, metadata: dict): + # Adapter to call the legacy handler with dict as expected by main_send_handler + # send_handler.handle_message expects a dict. + # We need to convert APIMessageBase back to dict legacy format if possible. + # Or check what handle_message expects. + # main_send_handler.py: handle_message takes raw_message_base_dict: dict + # and does MessageBase.from_dict(raw_message_base_dict). + + # So we need to serialize APIMessageBase to a dict that looks like legacy MessageBase dict. + # This might be tricky if structures diverged. + # Let's try `to_dict()` if available, otherwise construct it. + + # Inspecting APIMessageBase structure from docs: + # APIMessageBase has message_info, message_segment, message_dim. + # Legacy MessageBase has message_info, message_segment. + + # We can try to construct the dict. + data = { + "message_info": { + "id": message.message_info.message_id, + "timestamp": message.message_info.time, + "group_info": {}, # Fill if available + "user_info": {}, # Fill if available + }, + "message_segment": { + "type": message.message_segment.type, + "data": message.message_segment.data + } + } + # Note: This is an approximation. Ideally we should check strict compatibility. + # However, for the adapter -> bot direction (sending to napcat), + # the bot sends messages to adapter? No, Adapter sends to Bot? + # mmc_com_layer seems to be for Adapter talking to MaiBot Core. + # recv_handler/message_sending.py uses this router to send TO MaiBot. + # The `register_class_handler` in `mmc_start_com` suggests MaiBot sends messages TO Adapter? + # Wait, `send_handler.handle_message` seems to be handling messages RECEIVED FROM MaiBot. + # So `router` is bidirectional. + + # If explicit to_dict is needed: + await send_handler.handle_message(data) async def mmc_start_com(): - logger.info("正在连接MaiBot") - router.register_class_handler(send_handler.handle_message) - await router.run() + global router + config = global_config.maibot_server + + if config.enable_api_server and HAS_MESSAGE_CONVERTER: + logger.info("使用 API-Server 模式连接 MaiBot") + + # Create legacy adapter handler + # We need to define the on_message callback here to bridge to send_handler + async def on_message_bridge(message: APIMessageBase, metadata: Dict[str, Any]): + # 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase + # 发送场景:收到来自 MaiMBot 的回复消息,需要发送给 Napcat + # receiver_info 包含消息接收者信息,需要提取到 group_info/user_info + try: + from maim_message import MessageConverter + + legacy_message = MessageConverter.from_api_send(message) + msg_dict = legacy_message.to_dict() + + await send_handler.handle_message(msg_dict) + + except Exception as e: + logger.error(f"消息桥接转换失败: {e}") + import traceback + logger.error(traceback.format_exc()) + + client_config = create_client_config( + url=config.base_url, + api_key=config.api_key, + platform=config.platform_name, + on_message=on_message_bridge, + custom_logger=custom_logger # 传入自定义logger + ) + + client = WebSocketClient(client_config) + router = APIServerWrapper(client) + message_send_instance.maibot_router = router + await router.run() + + else: + logger.info("使用 Legacy WebSocket 模式连接 MaiBot") + route_config = RouteConfig( + route_config={ + config.platform_name: TargetConfig( + url=f"ws://{config.host}:{config.port}/ws", + token=None, + ) + } + ) + router = Router(route_config, custom_logger) + router.register_class_handler(send_handler.handle_message) + message_send_instance.maibot_router = router + await router.run() async def mmc_stop_com(): - await router.stop() + if router: + await router.stop() diff --git a/src/recv_handler.py b/src/recv_handler.py deleted file mode 100644 index 4d5c397..0000000 --- a/src/recv_handler.py +++ /dev/null @@ -1,774 +0,0 @@ -from .logger import logger -from .config import global_config -from .qq_emoji_list import qq_face -import time -import asyncio -import json -import websockets as Server -from typing import List, Tuple, Optional, Dict, Any -import uuid - -from . import MetaEventType, RealMessageType, MessageType, NoticeType -from maim_message import ( - UserInfo, - GroupInfo, - Seg, - BaseMessageInfo, - MessageBase, - TemplateInfo, - FormatInfo, - Router, -) - -from .utils import ( - get_group_info, - get_member_info, - get_image_base64, - get_self_info, - get_stranger_info, - get_message_detail, -) -from .message_queue import get_response - - -class RecvHandler: - maibot_router: Router = None - - def __init__(self): - self.server_connection: Server.ServerConnection = None - self.interval = global_config.napcat_heartbeat_interval - - async def handle_meta_event(self, message: dict) -> None: - event_type = message.get("meta_event_type") - if event_type == MetaEventType.lifecycle: - sub_type = message.get("sub_type") - if sub_type == MetaEventType.Lifecycle.connect: - self_id = message.get("self_id") - self.last_heart_beat = time.time() - logger.info(f"Bot {self_id} 连接成功") - asyncio.create_task(self.check_heartbeat(self_id)) - elif event_type == MetaEventType.heartbeat: - if message["status"].get("online") and message["status"].get("good"): - self.last_heart_beat = time.time() - self.interval = message.get("interval") / 1000 - else: - self_id = message.get("self_id") - logger.warning(f"Bot {self_id} Napcat 端异常!") - - async def check_heartbeat(self, id: int) -> None: - while True: - now_time = time.time() - if now_time - self.last_heart_beat > self.interval + 3: - logger.warning(f"Bot {id} 连接已断开") - break - else: - logger.debug("心跳正常") - await asyncio.sleep(self.interval) - - def check_allow_to_chat(self, user_id: int, group_id: Optional[int]) -> bool: - # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif - """ - 检查是否允许聊天 - Parameters: - user_id: int: 用户ID - group_id: int: 群ID - Returns: - bool: 是否允许聊天 - """ - logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") - if group_id: - if global_config.group_list_type == "whitelist" and group_id not in global_config.group_list: - logger.warning("群聊不在聊天白名单中,消息被丢弃") - return False - elif global_config.group_list_type == "blacklist" and group_id in global_config.group_list: - logger.warning("群聊在聊天黑名单中,消息被丢弃") - return False - else: - if global_config.private_list_type == "whitelist" and user_id not in global_config.private_list: - logger.warning("私聊不在聊天白名单中,消息被丢弃") - return False - elif global_config.private_list_type == "blacklist" and user_id in global_config.private_list: - logger.warning("私聊在聊天黑名单中,消息被丢弃") - return False - if user_id in global_config.ban_user_id: - logger.warning("用户在全局黑名单中,消息被丢弃") - return False - return True - - async def handle_raw_message(self, raw_message: dict) -> None: - # sourcery skip: low-code-quality, remove-unreachable-code - """ - 从Napcat接受的原始消息处理 - - Parameters: - raw_message: dict: 原始消息 - """ - message_type: str = raw_message.get("message_type") - message_id: int = raw_message.get("message_id") - # message_time: int = raw_message.get("time") - message_time: float = time.time() # 应可乐要求,现在是float了 - - template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用 - format_info: FormatInfo = FormatInfo( - content_format=["text", "image", "emoji"], - accept_format=["text", "image", "emoji", "reply", "voice", "command"], - ) # 格式化信息 - if message_type == MessageType.private: - sub_type = raw_message.get("sub_type") - if sub_type == MessageType.Private.friend: - sender_info: dict = raw_message.get("sender") - - if not self.check_allow_to_chat(sender_info.get("user_id"), None): - return None - - # 发送者用户信息 - user_info: UserInfo = UserInfo( - platform=global_config.platform, - user_id=sender_info.get("user_id"), - user_nickname=sender_info.get("nickname"), - user_cardname=sender_info.get("card"), - ) - - # 不存在群信息 - group_info: GroupInfo = None - elif sub_type == MessageType.Private.group: - """ - 本部分暂时不做支持,先放着 - """ - logger.warning("群临时消息类型不支持") - return None - - sender_info: dict = raw_message.get("sender") - - # 由于临时会话中,Napcat默认不发送成员昵称,所以需要单独获取 - fetched_member_info: dict = await get_member_info( - self.server_connection, - raw_message.get("group_id"), - sender_info.get("user_id"), - ) - nickname = fetched_member_info.get("nickname") if fetched_member_info else None - # 发送者用户信息 - user_info: UserInfo = UserInfo( - platform=global_config.platform, - user_id=sender_info.get("user_id"), - user_nickname=nickname, - user_cardname=None, - ) - - # -------------------这里需要群信息吗?------------------- - - # 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有 - fetched_group_info: dict = await get_group_info(self.server_connection, raw_message.get("group_id")) - group_name = "" - if fetched_group_info.get("group_name"): - group_name = fetched_group_info.get("group_name") - - group_info: GroupInfo = GroupInfo( - platform=global_config.platform, - group_id=raw_message.get("group_id"), - group_name=group_name, - ) - - else: - logger.warning(f"私聊消息类型 {sub_type} 不支持") - return None - elif message_type == MessageType.group: - sub_type = raw_message.get("sub_type") - if sub_type == MessageType.Group.normal: - sender_info: dict = raw_message.get("sender") - - if not self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")): - return None - - # 发送者用户信息 - user_info: UserInfo = UserInfo( - platform=global_config.platform, - user_id=sender_info.get("user_id"), - user_nickname=sender_info.get("nickname"), - user_cardname=sender_info.get("card"), - ) - - # 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有 - fetched_group_info = await get_group_info(self.server_connection, raw_message.get("group_id")) - group_name: str = None - if fetched_group_info: - group_name = fetched_group_info.get("group_name") - - group_info: GroupInfo = GroupInfo( - platform=global_config.platform, - group_id=raw_message.get("group_id"), - group_name=group_name, - ) - - else: - logger.warning(f"群聊消息类型 {sub_type} 不支持") - return None - - additional_config: dict = {} - if global_config.use_tts: - additional_config["allow_tts"] = True - - # 消息信息 - message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.platform, - message_id=message_id, - time=message_time, - user_info=user_info, - group_info=group_info, - template_info=template_info, - format_info=format_info, - additional_config=additional_config, - ) - - # 处理实际信息 - if not raw_message.get("message"): - logger.warning("原始消息内容为空") - return None - - # 获取Seg列表 - seg_message: List[Seg] = await self.handle_real_message(raw_message) - if not seg_message: - logger.warning("处理后消息内容为空") - return None - submit_seg: Seg = Seg( - type="seglist", - data=seg_message, - ) - # MessageBase创建 - message_base: MessageBase = MessageBase( - message_info=message_info, - message_segment=submit_seg, - raw_message=raw_message.get("raw_message"), - ) - - logger.info("发送到Maibot处理信息") - await self.message_process(message_base) - - async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None: - # sourcery skip: low-code-quality - """ - 处理实际消息 - Parameters: - real_message: dict: 实际消息 - Returns: - seg_message: list[Seg]: 处理后的消息段列表 - """ - real_message: list = raw_message.get("message") - if not real_message: - return None - seg_message: List[Seg] = [] - for sub_message in real_message: - sub_message: dict - sub_message_type = sub_message.get("type") - match sub_message_type: - case RealMessageType.text: - ret_seg = await self.handle_text_message(sub_message) - if ret_seg: - seg_message.append(ret_seg) - else: - logger.warning("text处理失败") - case RealMessageType.face: - ret_seg = await self.handle_face_message(sub_message) - if ret_seg: - seg_message.append(ret_seg) - else: - logger.warning("face处理失败或不支持") - case RealMessageType.reply: - if not in_reply: - ret_seg = await self.handle_reply_message(sub_message) - if ret_seg: - seg_message += ret_seg - else: - logger.warning("reply处理失败") - case RealMessageType.image: - ret_seg = await self.handle_image_message(sub_message) - if ret_seg: - seg_message.append(ret_seg) - else: - logger.warning("image处理失败") - case RealMessageType.record: - logger.warning("不支持语音解析") - case RealMessageType.video: - logger.warning("不支持视频解析") - case RealMessageType.at: - ret_seg = await self.handle_at_message( - sub_message, - raw_message.get("self_id"), - raw_message.get("group_id"), - ) - if ret_seg: - seg_message.append(ret_seg) - else: - logger.warning("at处理失败") - case RealMessageType.rps: - logger.warning("暂时不支持猜拳魔法表情解析") - case RealMessageType.dice: - logger.warning("暂时不支持骰子表情解析") - case RealMessageType.shake: - # 预计等价于戳一戳 - logger.warning("暂时不支持窗口抖动解析") - case RealMessageType.share: - logger.warning("暂时不支持链接解析") - case RealMessageType.forward: - messages = await self.get_forward_message(sub_message) - if not messages: - logger.warning("转发消息内容为空或获取失败") - return None - ret_seg = await self.handle_forward_message(messages) - if ret_seg: - seg_message.append(ret_seg) - else: - logger.warning("转发消息处理失败") - case RealMessageType.node: - logger.warning("不支持转发消息节点解析") - case _: - logger.warning(f"未知消息类型: {sub_message_type}") - return seg_message - - async def handle_text_message(self, raw_message: dict) -> Seg: - """ - 处理纯文本信息 - Parameters: - raw_message: dict: 原始消息 - Returns: - seg_data: Seg: 处理后的消息段 - """ - message_data: dict = raw_message.get("data") - plain_text: str = message_data.get("text") - return Seg(type=RealMessageType.text, data=plain_text) - - async def handle_face_message(self, raw_message: dict) -> Seg | None: - """ - 处理表情消息 - Parameters: - raw_message: dict: 原始消息 - Returns: - seg_data: Seg: 处理后的消息段 - """ - message_data: dict = raw_message.get("data") - face_raw_id: str = str(message_data.get("id")) - if face_raw_id in qq_face: - face_content: str = qq_face.get(face_raw_id) - return Seg(type="text", data=face_content) - else: - logger.warning(f"不支持的表情:{face_raw_id}") - return None - - async def handle_image_message(self, raw_message: dict) -> Seg | None: - """ - 处理图片消息与表情包消息 - Parameters: - raw_message: dict: 原始消息 - Returns: - seg_data: Seg: 处理后的消息段 - """ - message_data: dict = raw_message.get("data") - image_sub_type = message_data.get("sub_type") - try: - image_base64 = await get_image_base64(message_data.get("url")) - except Exception as e: - logger.error(f"图片消息处理失败: {str(e)}") - return None - if image_sub_type == 0: - """这部分认为是图片""" - return Seg(type="image", data=image_base64) - elif image_sub_type == 1: - """这部分认为是表情包""" - return Seg(type="emoji", data=image_base64) - else: - logger.warning(f"不支持的图片子类型:{image_sub_type}") - return None - - async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None: - # sourcery skip: use-named-expression - """ - 处理at消息 - Parameters: - raw_message: dict: 原始消息 - self_id: int: 机器人QQ号 - group_id: int: 群号 - Returns: - seg_data: Seg: 处理后的消息段 - """ - message_data: dict = raw_message.get("data") - if message_data: - qq_id = message_data.get("qq") - if str(self_id) == str(qq_id): - logger.debug("机器人被at") - self_info: dict = await get_self_info(self.server_connection) - if self_info: - return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>") - else: - return None - else: - member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id) - if member_info: - return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>") - else: - return None - - async def get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: - forward_message_data: Dict = raw_message.get("data") - if not forward_message_data: - logger.warning("转发消息内容为空") - return None - forward_message_id = forward_message_data.get("id") - request_uuid = str(uuid.uuid4()) - payload = json.dumps( - { - "action": "get_forward_msg", - "params": {"message_id": forward_message_id}, - "echo": request_uuid, - } - ) - try: - await self.server_connection.send(payload) - response: dict = await get_response(request_uuid) - except TimeoutError: - logger.error("获取转发消息超时") - return None - except Exception as e: - logger.error(f"获取转发消息失败: {str(e)}") - return None - logger.debug( - f"转发消息原始格式:{json.dumps(response)[:80]}..." - if len(json.dumps(response)) > 80 - else json.dumps(response) - ) - response_data: Dict = response.get("data") - if not response_data: - logger.warning("转发消息内容为空或获取失败") - return None - return response_data.get("messages") - - async def handle_reply_message(self, raw_message: dict) -> Seg | None: - # sourcery skip: move-assign-in-block, use-named-expression - """ - 处理回复消息 - - """ - raw_message_data: dict = raw_message.get("data") - message_id: int = None - if raw_message_data: - message_id = raw_message_data.get("id") - else: - return None - message_detail: dict = await get_message_detail(self.server_connection, message_id) - if not message_detail: - logger.warning("获取被引用的消息详情失败") - return None - reply_message = await self.handle_real_message(message_detail, in_reply=True) - if reply_message is None: - reply_message = "(获取发言内容失败)" - sender_info: dict = message_detail.get("sender") - sender_nickname: str = sender_info.get("nickname") - sender_id: str = sender_info.get("user_id") - seg_message: List[Seg] = [] - if not sender_nickname: - logger.warning("无法获取被引用的人的昵称,返回默认值") - seg_message.append(Seg(type="text", data="[回复 未知用户:")) - else: - seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>:")) - seg_message += reply_message - seg_message.append(Seg(type="text", data="],说:")) - return seg_message - - async def handle_notice(self, raw_message: dict) -> None: - notice_type = raw_message.get("notice_type") - # message_time: int = raw_message.get("time") - message_time: float = time.time() # 应可乐要求,现在是float了 - - group_id = raw_message.get("group_id") - user_id = raw_message.get("user_id") - handled_message: Seg = None - - match notice_type: - case NoticeType.friend_recall: - logger.info("好友撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.group_recall: - logger.info("群内用户撤回一条消息") - logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}") - logger.warning("暂时不支持撤回消息处理") - case NoticeType.notify: - sub_type = raw_message.get("sub_type") - match sub_type: - case NoticeType.Notify.poke: - if global_config.enable_poke: - handled_message: Seg = await self.handle_poke_notify(raw_message) - else: - logger.warning("戳一戳消息被禁用,取消戳一戳处理") - case _: - logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case _: - logger.warning(f"不支持的notice类型: {notice_type}") - return None - if not handled_message: - logger.warning("notice处理失败或不支持") - return None - - source_name: str = None - source_cardname: str = None - if group_id: - member_info: dict = await get_member_info(self.server_connection, group_id, user_id) - if member_info: - source_name = member_info.get("nickname") - source_cardname = member_info.get("card") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - else: - stranger_info = await get_stranger_info(self.server_connection, user_id) - if stranger_info: - source_name = stranger_info.get("nickname") - else: - logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效") - source_name = "QQ用户" - - user_info: UserInfo = UserInfo( - platform=global_config.platform, - user_id=user_id, - user_nickname=source_name, - user_cardname=source_cardname, - ) - - group_info: GroupInfo = None - if group_id: - fetched_group_info = await get_group_info(self.server_connection, group_id) - group_name: str = None - if fetched_group_info: - group_name = fetched_group_info.get("group_name") - else: - logger.warning("无法获取戳一戳消息所在群的名称") - group_info = GroupInfo( - platform=global_config.platform, - group_id=group_id, - group_name=group_name, - ) - - message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.platform, - message_id="notice", - time=message_time, - user_info=user_info, - group_info=group_info, - template_info=None, - format_info=None, - ) - - message_base: MessageBase = MessageBase( - message_info=message_info, - message_segment=handled_message, - raw_message=json.dumps(raw_message), - ) - - logger.info("发送到Maibot处理通知信息") - await self.message_process(message_base) - - async def handle_poke_notify(self, raw_message: dict) -> Seg | None: - self_info: dict = await get_self_info(self.server_connection) - if not self_info: - logger.error("自身信息获取失败") - return None - self_id = raw_message.get("self_id") - target_id = raw_message.get("target_id") - target_name: str = None - raw_info: list = raw_message.get("raw_info") - # 计算Seg - if self_id == target_id: - target_name = self_info.get("nickname") - else: - return None - try: - first_txt = raw_info[2].get("txt", "戳了戳") - second_txt = raw_info[4].get("txt", "") - except Exception as e: - logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") - first_txt = "戳了戳" - second_txt = "" - """ - # 不启用戳其他人的处理 - else: - # 由于Napcat不支持获取昵称,所以需要单独获取 - group_id = raw_message.get("group_id") - fetched_member_info: dict = await get_member_info( - self.server_connection, group_id, target_id - ) - if fetched_member_info: - target_name = fetched_member_info.get("nickname") - """ - seg_data: Seg = Seg( - type="text", - data=f"{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", - ) - return seg_data - - async def handle_forward_message(self, message_list: list) -> Seg | None: - """ - 递归处理转发消息,并按照动态方式确定图片处理方式 - Parameters: - message_list: list: 转发消息列表 - """ - handled_message, image_count = await self._handle_forward_message(message_list, 0) - handled_message: Seg - image_count: int - if not handled_message: - return None - if image_count < 5 and image_count > 0: - # 处理图片数量小于5的情况,此时解析图片为base64 - logger.trace("图片数量小于5,开始解析图片为base64") - return await self._recursive_parse_image_seg(handled_message, True) - elif image_count > 0: - logger.trace("图片数量大于等于5,开始解析图片为占位符") - # 处理图片数量大于等于5的情况,此时解析图片为占位符 - return await self._recursive_parse_image_seg(handled_message, False) - else: - # 处理没有图片的情况,此时直接返回 - logger.trace("没有图片,直接返回") - return handled_message - - async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg: - # sourcery skip: merge-else-if-into-elif - if to_image: - if seg_data.type == "seglist": - new_seg_list = [] - for i_seg in seg_data.data: - parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) - new_seg_list.append(parsed_seg) - return Seg(type="seglist", data=new_seg_list) - elif seg_data.type == "image": - image_url = seg_data.data - try: - encoded_image = await get_image_base64(image_url) - except Exception as e: - logger.error(f"图片处理失败: {str(e)}") - return Seg(type="text", data="[图片]") - return Seg(type="image", data=encoded_image) - elif seg_data.type == "emoji": - image_url = seg_data.data - try: - encoded_image = await get_image_base64(image_url) - except Exception as e: - logger.error(f"图片处理失败: {str(e)}") - return Seg(type="text", data="[表情包]") - return Seg(type="emoji", data=encoded_image) - else: - logger.trace(f"不处理类型: {seg_data.type}") - return seg_data - else: - if seg_data.type == "seglist": - new_seg_list = [] - for i_seg in seg_data.data: - parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) - new_seg_list.append(parsed_seg) - return Seg(type="seglist", data=new_seg_list) - elif seg_data.type == "image": - return Seg(type="text", data="[图片]") - elif seg_data.type == "emoji": - return Seg(type="text", data="[动画表情]") - else: - logger.trace(f"不处理类型: {seg_data.type}") - return seg_data - - async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]: - # sourcery skip: low-code-quality - """ - 递归处理实际转发消息 - Parameters: - message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段 - layer: int: 当前层级 - Returns: - seg_data: Seg: 处理后的消息段 - image_count: int: 图片数量 - """ - seg_list: List[Seg] = [] - image_count = 0 - if message_list is None: - return None, 0 - for sub_message in message_list: - sub_message: dict - sender_info: dict = sub_message.get("sender") - user_nickname: str = sender_info.get("nickname", "QQ用户") - user_nickname_str = f"【{user_nickname}】:" - break_seg = Seg(type="text", data="\n") - message_of_sub_message_list: dict = sub_message.get("message") - if not message_of_sub_message_list: - logger.warning("转发消息内容为空") - continue - message_of_sub_message = message_of_sub_message_list[0] - if message_of_sub_message.get("type") == RealMessageType.forward: - if layer >= 3: - full_seg_data = Seg( - type="text", - data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n", - ) - else: - sub_message_data = message_of_sub_message.get("data") - if not sub_message_data: - continue - contents = sub_message_data.get("content") - seg_data, count = await self._handle_forward_message(contents, layer + 1) - image_count += count - head_tip = Seg( - type="text", - data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n", - ) - full_seg_data = Seg(type="seglist", data=[head_tip, seg_data]) - seg_list.append(full_seg_data) - elif message_of_sub_message.get("type") == RealMessageType.text: - sub_message_data = message_of_sub_message.get("data") - if not sub_message_data: - continue - text_message = sub_message_data.get("text") - seg_data = Seg(type="text", data=text_message) - data_list: List[Any] = [] - if layer > 0: - data_list = [ - Seg(type="text", data=("--" * layer) + user_nickname_str), - seg_data, - break_seg, - ] - else: - data_list = [ - Seg(type="text", data=user_nickname_str), - seg_data, - break_seg, - ] - seg_list.append(Seg(type="seglist", data=data_list)) - elif message_of_sub_message.get("type") == RealMessageType.image: - image_count += 1 - image_data = message_of_sub_message.get("data") - sub_type = image_data.get("sub_type") - image_url = image_data.get("url") - data_list: List[Any] = [] - if sub_type == 0: - seg_data = Seg(type="image", data=image_url) - else: - seg_data = Seg(type="emoji", data=image_url) - if layer > 0: - data_list = [ - Seg(type="text", data=("--" * layer) + user_nickname_str), - seg_data, - break_seg, - ] - else: - data_list = [ - Seg(type="text", data=user_nickname_str), - seg_data, - break_seg, - ] - full_seg_data = Seg(type="seglist", data=data_list) - seg_list.append(full_seg_data) - return Seg(type="seglist", data=seg_list), image_count - - async def message_process(self, message_base: MessageBase) -> None: - try: - await self.maibot_router.send_message(message_base) - except Exception as e: - logger.error(f"发送消息失败: {str(e)}") - logger.error("请检查与MaiBot之间的连接") - return None - - -recv_handler = RecvHandler() diff --git a/src/recv_handler/__init__.py b/src/recv_handler/__init__.py new file mode 100644 index 0000000..e4c9744 --- /dev/null +++ b/src/recv_handler/__init__.py @@ -0,0 +1,127 @@ +from enum import Enum + + +class MetaEventType: + lifecycle = "lifecycle" # 生命周期 + + class Lifecycle: + connect = "connect" # 生命周期 - WebSocket 连接成功 + + heartbeat = "heartbeat" # 心跳 + + +class MessageType: # 接受消息大类 + private = "private" # 私聊消息 + + class Private: + friend = "friend" # 私聊消息 - 好友 + group = "group" # 私聊消息 - 群临时 + group_self = "group_self" # 私聊消息 - 群中自身发送 + other = "other" # 私聊消息 - 其他 + + group = "group" # 群聊消息 + + class Group: + normal = "normal" # 群聊消息 - 普通 + anonymous = "anonymous" # 群聊消息 - 匿名消息 + notice = "notice" # 群聊消息 - 系统提示 + + +class NoticeType: # 通知事件 + friend_recall = "friend_recall" # 私聊消息撤回 + group_recall = "group_recall" # 群聊消息撤回 + notify = "notify" + group_ban = "group_ban" # 群禁言 + group_msg_emoji_like = "group_msg_emoji_like" # 群消息表情回应 + group_upload = "group_upload" # 群文件上传 + group_increase = "group_increase" # 群成员增加 + group_decrease = "group_decrease" # 群成员减少 + group_admin = "group_admin" # 群管理员变动 + essence = "essence" # 精华消息 + + class Notify: + poke = "poke" # 戳一戳 + group_name = "group_name" # 群名称变更 + + class GroupBan: + ban = "ban" # 禁言 + lift_ban = "lift_ban" # 解除禁言 + + class GroupIncrease: + approve = "approve" # 管理员同意入群 + invite = "invite" # 被邀请入群 + + class GroupDecrease: + leave = "leave" # 主动退群 + kick = "kick" # 被踢出群 + kick_me = "kick_me" # 机器人被踢 + + class GroupAdmin: + set = "set" # 设置管理员 + unset = "unset" # 取消管理员 + + class Essence: + add = "add" # 添加精华消息 + delete = "delete" # 移除精华消息 + + +class RealMessageType: # 实际消息分类 + text = "text" # 纯文本 + face = "face" # qq表情 + image = "image" # 图片 + record = "record" # 语音 + video = "video" # 视频 + at = "at" # @某人 + rps = "rps" # 猜拳魔法表情 + dice = "dice" # 骰子 + shake = "shake" # 私聊窗口抖动(只收) + poke = "poke" # 群聊戳一戳 + share = "share" # 链接分享(json形式) + reply = "reply" # 回复消息 + forward = "forward" # 转发消息 + node = "node" # 转发消息节点 + json = "json" # JSON卡片消息 + file = "file" # 文件消息 + + +class MessageSentType: + private = "private" + + class Private: + friend = "friend" + group = "group" + + group = "group" + + class Group: + normal = "normal" + + +class CommandType(Enum): + """命令类型""" + + GROUP_BAN = "set_group_ban" # 禁言用户 + GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言 + GROUP_KICK = "set_group_kick" # 踢出群聊 + SEND_POKE = "send_poke" # 戳一戳 + DELETE_MSG = "delete_msg" # 撤回消息 + + def __str__(self) -> str: + return self.value + + +ACCEPT_FORMAT = [ + "text", + "image", + "emoji", + "reply", + "voice", + "command", + "voiceurl", + "music", + "videourl", + "file", + "imageurl", + "forward", + "video", +] diff --git a/src/recv_handler/message_handler.py b/src/recv_handler/message_handler.py new file mode 100644 index 0000000..54a5b4b --- /dev/null +++ b/src/recv_handler/message_handler.py @@ -0,0 +1,1001 @@ +from src.logger import logger +from src.config import global_config +from src.utils import ( + get_group_info, + get_member_info, + get_image_base64, + get_record_detail, + get_self_info, + get_message_detail, +) +import base64 +from .qq_emoji_list import qq_face +from .message_sending import message_send_instance +from . import RealMessageType, MessageType, ACCEPT_FORMAT + +import time +import json +import websockets as Server +from typing import List, Tuple, Optional, Dict, Any +import uuid + +from maim_message import ( + UserInfo, + GroupInfo, + Seg, + BaseMessageInfo, + MessageBase, + TemplateInfo, + FormatInfo, +) + + +from src.response_pool import get_response + + +class MessageHandler: + def __init__(self): + self.server_connection: Server.ServerConnection = None + self.bot_id_list: Dict[int, bool] = {} + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + async def check_allow_to_chat( + self, + user_id: int, + group_id: Optional[int] = None, + ignore_bot: Optional[bool] = False, + ignore_global_list: Optional[bool] = False, + ) -> bool: + # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif + """ + 检查是否允许聊天 + Parameters: + user_id: int: 用户ID + group_id: int: 群ID + ignore_bot: bool: 是否忽略机器人检查 + ignore_global_list: bool: 是否忽略全局黑名单检查 + Returns: + bool: 是否允许聊天 + """ + logger.debug(f"群聊id: {group_id}, 用户id: {user_id}") + logger.debug("开始检查聊天白名单/黑名单") + if group_id: + if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list: + logger.warning("群聊不在聊天白名单中,消息被丢弃") + return False + elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list: + logger.warning("群聊在聊天黑名单中,消息被丢弃") + return False + else: + if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list: + logger.warning("私聊不在聊天白名单中,消息被丢弃") + return False + elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list: + logger.warning("私聊在聊天黑名单中,消息被丢弃") + return False + if user_id in global_config.chat.ban_user_id and not ignore_global_list: + logger.warning("用户在全局黑名单中,消息被丢弃") + return False + + if global_config.chat.ban_qq_bot and group_id and not ignore_bot: + logger.debug("开始判断是否为机器人") + member_info = await get_member_info(self.server_connection, group_id, user_id) + if member_info: + is_bot = member_info.get("is_robot") + if is_bot is None: + logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新") + else: + if is_bot: + logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单") + self.bot_id_list[user_id] = True + return False + else: + self.bot_id_list[user_id] = False + + return True + + async def handle_raw_message(self, raw_message: dict) -> None: + # sourcery skip: low-code-quality, remove-unreachable-code + """ + 从Napcat接受的原始消息处理 + + Parameters: + raw_message: dict: 原始消息 + """ + message_type: str = raw_message.get("message_type") + message_id: int = raw_message.get("message_id") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用 + format_info: FormatInfo = FormatInfo( + content_format=["text", "image", "emoji", "voice"], + accept_format=ACCEPT_FORMAT, + ) # 格式化信息 + if message_type == MessageType.private: + sub_type = raw_message.get("sub_type") + if sub_type == MessageType.Private.friend: + sender_info: dict = raw_message.get("sender") + + if not await self.check_allow_to_chat(sender_info.get("user_id"), None): + return None + + # 发送者用户信息 + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=sender_info.get("user_id"), + user_nickname=sender_info.get("nickname"), + user_cardname=sender_info.get("card"), + ) + + # 不存在群信息 + group_info: GroupInfo = None + elif sub_type == MessageType.Private.group: + """ + 本部分暂时不做支持,先放着 + """ + logger.warning("群临时消息类型不支持") + return None + + sender_info: dict = raw_message.get("sender") + + # 由于临时会话中,Napcat默认不发送成员昵称,所以需要单独获取 + fetched_member_info: dict = await get_member_info( + self.server_connection, + raw_message.get("group_id"), + sender_info.get("user_id"), + ) + nickname = fetched_member_info.get("nickname") if fetched_member_info else None + # 发送者用户信息 + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=sender_info.get("user_id"), + user_nickname=nickname, + user_cardname=None, + ) + + # -------------------这里需要群信息吗?------------------- + + # 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有 + fetched_group_info: dict = await get_group_info(self.server_connection, raw_message.get("group_id")) + group_name = "" + if fetched_group_info.get("group_name"): + group_name = fetched_group_info.get("group_name") + + group_info: GroupInfo = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=raw_message.get("group_id"), + group_name=group_name, + ) + + else: + logger.warning(f"私聊消息类型 {sub_type} 不支持") + return None + elif message_type == MessageType.group: + sub_type = raw_message.get("sub_type") + if sub_type == MessageType.Group.normal: + sender_info: dict = raw_message.get("sender") + + if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")): + return None + + # 发送者用户信息 + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=sender_info.get("user_id"), + user_nickname=sender_info.get("nickname"), + user_cardname=sender_info.get("card"), + ) + + # 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有 + fetched_group_info = await get_group_info(self.server_connection, raw_message.get("group_id")) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + + group_info: GroupInfo = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=raw_message.get("group_id"), + group_name=group_name, + ) + + else: + logger.warning(f"群聊消息类型 {sub_type} 不支持") + return None + + # 处理实际信息 + if not raw_message.get("message"): + logger.warning("原始消息内容为空") + return None + + # 获取Seg列表 + seg_message, additional_config = await self.handle_real_message(raw_message) + if global_config.voice.use_tts: + additional_config["allow_tts"] = True + + if not seg_message: + logger.warning("处理后消息内容为空") + return None + submit_seg: Seg = Seg( + type="seglist", + data=seg_message, + ) + + # 消息信息 + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id=message_id, + time=message_time, + user_info=user_info, + group_info=group_info, + template_info=template_info, + format_info=format_info, + additional_config=additional_config, + ) + + # MessageBase创建 + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=submit_seg, + raw_message=raw_message.get("raw_message"), + ) + + logger.info("发送到Maibot处理信息") + await message_send_instance.message_send(message_base) + + async def handle_real_message( + self, raw_message: dict, in_reply: bool = False + ) -> Tuple[List[Seg] | None, Dict[str, Any]]: + # sourcery skip: low-code-quality + """ + 处理实际消息 + Parameters: + real_message: dict: 实际消息 + Returns: + seg_message: list[Seg]: 处理后的消息段列表 + """ + additional_config: dict = {} + real_message: list = raw_message.get("message") + if not real_message: + logger.warning("实际消息内容为空") + return None, {} + seg_message: List[Seg] = [] + for sub_message in real_message: + sub_message: dict + sub_message_type = sub_message.get("type") + match sub_message_type: + case RealMessageType.text: + ret_seg = await self.handle_text_message(sub_message) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("text处理失败") + case RealMessageType.face: + ret_seg = await self.handle_face_message(sub_message) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("face处理失败或不支持") + case RealMessageType.reply: + if not in_reply: + ret_seg, additional_config = await self.handle_reply_message(sub_message, additional_config) + if ret_seg: + seg_message += ret_seg + else: + logger.warning("reply处理失败") + case RealMessageType.image: + ret_seg = await self.handle_image_message(sub_message) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("image处理失败") + case RealMessageType.record: + ret_seg = await self.handle_record_message(sub_message) + if ret_seg: + seg_message.clear() + seg_message.append(ret_seg) + break # 使得消息只有record消息 + else: + logger.warning("record处理失败或不支持") + case RealMessageType.video: + ret_seg = await self.handle_video_message(sub_message) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("video处理失败") + case RealMessageType.json: + ret_segs = await self.handle_json_message(sub_message) + if ret_segs: + seg_message.extend(ret_segs) + else: + logger.warning("json处理失败") + case RealMessageType.file: + ret_seg = await self.handle_file_message(sub_message) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("file处理失败") + case RealMessageType.at: + ret_seg = await self.handle_at_message( + sub_message, + raw_message.get("self_id"), + raw_message.get("group_id"), + ) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("at处理失败") + case RealMessageType.rps: + logger.warning("暂时不支持猜拳魔法表情解析") + case RealMessageType.dice: + logger.warning("暂时不支持骰子表情解析") + case RealMessageType.shake: + # 预计等价于戳一戳 + logger.warning("暂时不支持窗口抖动解析") + case RealMessageType.share: + logger.warning("暂时不支持链接解析") + case RealMessageType.forward: + messages = await self._get_forward_message(sub_message) + if not messages: + logger.warning("转发消息内容为空或获取失败") + return None, {} + ret_seg = await self.handle_forward_message(messages) + if ret_seg: + seg_message.append(ret_seg) + else: + logger.warning("转发消息处理失败") + case RealMessageType.node: + logger.warning("不支持转发消息节点解析") + case _: + logger.warning(f"未知消息类型: {sub_message_type}") + return seg_message, additional_config + + async def handle_text_message(self, raw_message: dict) -> Seg: + """ + 处理纯文本信息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + plain_text: str = message_data.get("text") + return Seg(type="text", data=plain_text) + + async def handle_face_message(self, raw_message: dict) -> Seg | None: + """ + 处理表情消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + face_raw_id: str = str(message_data.get("id")) + if face_raw_id in qq_face: + face_content: str = qq_face.get(face_raw_id) + return Seg(type="text", data=face_content) + else: + logger.warning(f"不支持的表情:{face_raw_id}") + return None + + async def handle_image_message(self, raw_message: dict) -> Seg | None: + """ + 处理图片消息与表情包消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + image_sub_type = message_data.get("sub_type") + try: + image_base64 = await get_image_base64(message_data.get("url")) + except Exception as e: + logger.error(f"图片消息处理失败: {str(e)}") + return None + if image_sub_type == 0: + """这部分认为是图片""" + return Seg(type="image", data=image_base64) + elif image_sub_type not in [4, 9]: + """这部分认为是表情包""" + return Seg(type="emoji", data=image_base64) + else: + logger.warning(f"不支持的图片子类型:{image_sub_type}") + return None + + async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None: + # sourcery skip: use-named-expression + """ + 处理at消息 + Parameters: + raw_message: dict: 原始消息 + self_id: int: 机器人QQ号 + group_id: int: 群号 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + if message_data: + qq_id = message_data.get("qq") + if str(self_id) == str(qq_id): + logger.debug("机器人被at") + self_info: dict = await get_self_info(self.server_connection) + if self_info: + return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>") + else: + return None + else: + member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id) + if member_info: + return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>") + else: + return None + + async def handle_record_message(self, raw_message: dict) -> Seg | None: + """ + 处理语音消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + file: str = message_data.get("file") + if not file: + logger.warning("语音消息缺少文件信息") + return None + try: + record_detail = await get_record_detail(self.server_connection, file) + if not record_detail: + logger.warning("获取语音消息详情失败") + return None + audio_base64: str = record_detail.get("base64") + except Exception as e: + logger.error(f"语音消息处理失败: {str(e)}") + return None + if not audio_base64: + logger.error("语音消息处理失败,未获取到音频数据") + return None + return Seg(type="voice", data=audio_base64) + + async def handle_video_message(self, raw_message: dict) -> Seg | None: + """ + 处理视频消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段(video_card类型) + """ + message_data: dict = raw_message.get("data") + file: str = message_data.get("file", "") + url: str = message_data.get("url", "") + file_size: str = message_data.get("file_size", "") + + if not file: + logger.warning("视频消息缺少文件信息") + return None + + # 返回结构化的视频卡片数据 + return Seg(type="video_card", data={ + "file": file, + "file_size": file_size, + "url": url + }) + + async def handle_json_message(self, raw_message: dict) -> List[Seg] | None: + """ + 处理JSON卡片消息(小程序、分享、群公告等) + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: List[Seg]: 处理后的消息段列表(可能包含文本和图片) + """ + message_data: dict = raw_message.get("data") + json_data: str = message_data.get("data") + + if not json_data: + logger.warning("JSON消息缺少数据") + return None + + try: + # 尝试解析JSON获取详细信息 + parsed_json = json.loads(json_data) + app = parsed_json.get("app", "") + meta = parsed_json.get("meta", {}) + + # 群公告(由于图片URL是加密的,因此无法读取) + if app == "com.tencent.mannounce": + mannounce = meta.get("mannounce", {}) + title = mannounce.get("title", "") + text = mannounce.get("text", "") + encode_flag = mannounce.get("encode", 0) + if encode_flag == 1: + try: + if title: + title = base64.b64decode(title).decode("utf-8", errors="ignore") + if text: + text = base64.b64decode(text).decode("utf-8", errors="ignore") + except Exception as e: + logger.warning(f"群公告Base64解码失败: {e}") + if title and text: + content = f"[{title}]:{text}" + elif title: + content = f"[{title}]" + elif text: + content = f"{text}" + else: + content = "[群公告]" + return [Seg(type="text", data=content)] + + # 音乐卡片 + if app in ("com.tencent.music.lua", "com.tencent.structmsg"): + music = meta.get("music", {}) + if music: + title = music.get("title", "") + singer = music.get("desc", "") or music.get("singer", "") + jump_url = music.get("jumpUrl", "") or music.get("jump_url", "") + music_url = music.get("musicUrl", "") or music.get("music_url", "") + tag = music.get("tag", "") + preview = music.get("preview", "") + + return [Seg(type="music_card", data={ + "title": title, + "singer": singer, + "jump_url": jump_url, + "music_url": music_url, + "tag": tag, + "preview": preview + })] + + # QQ小程序分享(含预览图) + if app == "com.tencent.miniapp_01": + detail = meta.get("detail_1", {}) + if detail: + title = detail.get("title", "") + desc = detail.get("desc", "") + url = detail.get("url", "") + qqdocurl = detail.get("qqdocurl", "") + preview_url = detail.get("preview", "") + icon = detail.get("icon", "") + + seg_list = [Seg(type="miniapp_card", data={ + "title": title, + "desc": desc, + "url": url, + "source_url": qqdocurl, + "preview": preview_url, + "icon": icon + })] + + # 下载预览图 + if preview_url: + try: + image_base64 = await get_image_base64(preview_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"QQ小程序预览图下载失败: {e}") + + return seg_list + + # 礼物消息 + if app == "com.tencent.giftmall.giftark": + giftark = meta.get("giftark", {}) + if giftark: + gift_name = giftark.get("title", "礼物") + desc = giftark.get("desc", "") + gift_text = f"[赠送礼物: {gift_name}]" + if desc: + gift_text += f"\n{desc}" + return [Seg(type="text", data=gift_text)] + + # 推荐联系人 + if app == "com.tencent.contact.lua": + contact_info = meta.get("contact", {}) + name = contact_info.get("nickname", "未知联系人") + tag = contact_info.get("tag", "推荐联系人") + return [Seg(type="text", data=f"[{tag}] {name}")] + + # 推荐群聊 + if app == "com.tencent.troopsharecard": + contact_info = meta.get("contact", {}) + name = contact_info.get("nickname", "未知群聊") + tag = contact_info.get("tag", "推荐群聊") + return [Seg(type="text", data=f"[{tag}] {name}")] + + # 图文分享(如 哔哩哔哩HD、网页、群精华等) + if app == "com.tencent.tuwen.lua": + news = meta.get("news", {}) + title = news.get("title", "未知标题") + desc = (news.get("desc", "") or "").replace("[图片]", "").strip() + tag = news.get("tag", "图文分享") + preview_url = news.get("preview", "") + if tag and title and tag in title: + title = title.replace(tag, "", 1).strip(":: -— ") + text_content = f"[{tag}] {title}:{desc}" + seg_list = [Seg(type="text", data=text_content)] + + # 下载预览图 + if preview_url: + try: + image_base64 = await get_image_base64(preview_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"图文预览图下载失败: {e}") + + return seg_list + + # 群相册(含预览图) + if app == "com.tencent.feed.lua": + feed = meta.get("feed", {}) + title = feed.get("title", "群相册") + tag = feed.get("tagName", "群相册") + desc = feed.get("forwardMessage", "") + cover_url = feed.get("cover", "") + if tag and title and tag in title: + title = title.replace(tag, "", 1).strip(":: -— ") + text_content = f"[{tag}] {title}:{desc}" + seg_list = [Seg(type="text", data=text_content)] + + # 下载封面图 + if cover_url: + try: + image_base64 = await get_image_base64(cover_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"群相册封面下载失败: {e}") + + return seg_list + + # QQ收藏分享(含预览图) + if app == "com.tencent.template.qqfavorite.share": + news = meta.get("news", {}) + desc = news.get("desc", "").replace("[图片]", "").strip() + tag = news.get("tag", "QQ收藏") + preview_url = news.get("preview", "") + seg_list = [Seg(type="text", data=f"[{tag}] {desc}")] + + # 下载预览图 + if preview_url: + try: + image_base64 = await get_image_base64(preview_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"QQ收藏预览图下载失败: {e}") + + return seg_list + + # QQ空间分享(含预览图) + if app == "com.tencent.miniapp.lua": + miniapp = meta.get("miniapp", {}) + title = miniapp.get("title", "未知标题") + tag = miniapp.get("tag", "QQ空间") + preview_url = miniapp.get("preview", "") + seg_list = [Seg(type="text", data=f"[{tag}] {title}")] + + # 下载预览图 + if preview_url: + try: + image_base64 = await get_image_base64(preview_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"QQ空间预览图下载失败: {e}") + + return seg_list + + # QQ频道分享(含预览图) + if app == "com.tencent.forum": + detail = meta.get("detail") if isinstance(meta, dict) else None + if detail: + feed = detail.get("feed", {}) + poster = detail.get("poster", {}) + channel_info = detail.get("channel_info", {}) + guild_name = channel_info.get("guild_name", "") + nick = poster.get("nick", "QQ用户") + title = feed.get("title", {}).get("contents", [{}])[0].get("text_content", {}).get("text", "帖子") + face_content = "" + for item in feed.get("contents", {}).get("contents", []): + emoji = item.get("emoji_content") + if emoji: + eid = emoji.get("id") + if eid in qq_face: + face_content += qq_face.get(eid, "") + + seg_list = [Seg(type="text", data=f"[频道帖子] [{guild_name}]{nick}:{title}{face_content}")] + + # 下载帖子中的图片 + pic_urls = [img.get("pic_url") for img in feed.get("images", []) if img.get("pic_url")] + for pic_url in pic_urls: + try: + image_base64 = await get_image_base64(pic_url) + seg_list.append(Seg(type="image", data=image_base64)) + except Exception as e: + logger.error(f"QQ频道图片下载失败: {e}") + + return seg_list + + # QQ地图位置分享 + if app == "com.tencent.map": + location = meta.get("Location.Search", {}) + name = location.get("name", "未知地点") + address = location.get("address", "") + return [Seg(type="text", data=f"[位置] {address} · {name}")] + + # QQ一起听歌 + if app == "com.tencent.together": + invite = (meta or {}).get("invite", {}) + title = invite.get("title") or "一起听歌" + summary = invite.get("summary") or "" + return [Seg(type="text", data=f"[{title}] {summary}")] + + # 其他卡片消息使用prompt字段 + prompt = parsed_json.get("prompt", "[卡片消息]") + return [Seg(type="text", data=prompt)] + except json.JSONDecodeError: + logger.warning("JSON消息解析失败") + return [Seg(type="text", data="[卡片消息]")] + except Exception as e: + logger.error(f"JSON消息处理异常: {e}") + return [Seg(type="text", data="[卡片消息]")] + + async def handle_file_message(self, raw_message: dict) -> Seg | None: + """ + 处理文件消息 + Parameters: + raw_message: dict: 原始消息 + Returns: + seg_data: Seg: 处理后的消息段 + """ + message_data: dict = raw_message.get("data") + file_name: str = message_data.get("file") + file_size: str = message_data.get("file_size", "未知大小") + file_url: str = message_data.get("url") + + if not file_name: + logger.warning("文件消息缺少文件名") + return None + + file_text = f"[文件: {file_name}, 大小: {file_size}字节]" + if file_url: + file_text += f"\n文件链接: {file_url}" + + return Seg(type="text", data=file_text) + + async def handle_reply_message(self, raw_message: dict, additional_config: dict) -> Tuple[List[Seg] | None, dict]: + # sourcery skip: move-assign-in-block, use-named-expression + """ + 处理回复消息 + + """ + raw_message_data: dict = raw_message.get("data") + message_id: int = None + if raw_message_data: + message_id = raw_message_data.get("id") + else: + return None, {} + additional_config["reply_message_id"] = message_id + message_detail: dict = await get_message_detail(self.server_connection, message_id) + if not message_detail: + logger.warning("获取被引用的消息详情失败") + return None, {} + reply_message, _ = await self.handle_real_message(message_detail, in_reply=True) + if reply_message is None: + reply_message = [Seg(type="text", data="(获取发言内容失败)")] + sender_info: dict = message_detail.get("sender") + sender_nickname: str = sender_info.get("nickname") + sender_id: str = sender_info.get("user_id") + seg_message: List[Seg] = [] + if not sender_nickname: + logger.warning("无法获取被引用的人的昵称,返回默认值") + seg_message.append(Seg(type="text", data="[回复 未知用户:")) + else: + seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>:")) + seg_message += reply_message + seg_message.append(Seg(type="text", data="],说:")) + return seg_message, additional_config + + async def handle_forward_message(self, message_list: list) -> Seg | None: + """ + 递归处理转发消息,并按照动态方式确定图片处理方式 + Parameters: + message_list: list: 转发消息列表 + """ + handled_message, image_count = await self._handle_forward_message(message_list, 0) + handled_message: Seg + image_count: int + if not handled_message: + return None + + # 添加转发消息的标题和结束标识 + forward_header = Seg(type="text", data="========== 转发消息开始 ==========\n") + forward_footer = Seg(type="text", data="========== 转发消息结束 ==========") + + # 图片阈值:超过此数量使用占位符避免麦麦VLM处理卡死 + image_threshold = global_config.forward.image_threshold + + if image_count < image_threshold and image_count > 0: + # 处理图片数量小于阈值的情况,此时解析图片为base64 + logger.trace(f"图片数量({image_count})小于{image_threshold},开始解析图片为base64") + parsed_message = await self._recursive_parse_image_seg(handled_message, True) + return Seg(type="seglist", data=[forward_header, parsed_message, forward_footer]) + elif image_count > 0: + logger.trace(f"图片数量({image_count})大于等于{image_threshold},开始解析图片为占位符") + # 处理图片数量大于等于阈值的情况,此时解析图片为占位符 + parsed_message = await self._recursive_parse_image_seg(handled_message, False) + return Seg(type="seglist", data=[forward_header, parsed_message, forward_footer]) + else: + # 处理没有图片的情况,此时直接返回 + logger.trace("没有图片,直接返回") + return Seg(type="seglist", data=[forward_header, handled_message, forward_footer]) + + async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg: + # sourcery skip: merge-else-if-into-elif + if to_image: + if seg_data.type == "seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + image_url = seg_data.data + try: + encoded_image = await get_image_base64(image_url) + except Exception as e: + logger.error(f"图片处理失败: {str(e)}") + return Seg(type="text", data="[图片]") + return Seg(type="image", data=encoded_image) + elif seg_data.type == "emoji": + image_url = seg_data.data + try: + encoded_image = await get_image_base64(image_url) + except Exception as e: + logger.error(f"图片处理失败: {str(e)}") + return Seg(type="text", data="[表情包]") + return Seg(type="emoji", data=encoded_image) + else: + logger.trace(f"不处理类型: {seg_data.type}") + return seg_data + else: + if seg_data.type == "seglist": + new_seg_list = [] + for i_seg in seg_data.data: + parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image) + new_seg_list.append(parsed_seg) + return Seg(type="seglist", data=new_seg_list) + elif seg_data.type == "image": + return Seg(type="text", data="[图片]") + elif seg_data.type == "emoji": + return Seg(type="text", data="[动画表情]") + else: + logger.trace(f"不处理类型: {seg_data.type}") + return seg_data + + async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]: + # sourcery skip: low-code-quality + """ + 递归处理实际转发消息 + Parameters: + message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段 + layer: int: 当前层级 + Returns: + seg_data: Seg: 处理后的消息段 + image_count: int: 图片数量 + """ + seg_list: List[Seg] = [] + image_count = 0 + if message_list is None: + return None, 0 + # 统一在最前加入【转发消息】标识(带层级缩进) + seg_list.append(Seg(type="text", data=("--" * layer) + "\n【转发消息】\n")) + for sub_message in message_list: + sub_message: dict + sender_info: dict = sub_message.get("sender") + user_nickname: str = sender_info.get("nickname", "QQ用户") + user_nickname_str = f"【{user_nickname}】:" + break_seg = Seg(type="text", data="\n") + message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message") + if not message_of_sub_message_list: + logger.warning("转发消息内容为空") + continue + message_of_sub_message = message_of_sub_message_list[0] + if message_of_sub_message.get("type") == RealMessageType.forward: + sub_message_data = message_of_sub_message.get("data") + if not sub_message_data: + continue + contents = sub_message_data.get("content") + seg_data, count = await self._handle_forward_message(contents, layer + 1) + image_count += count + head_tip = Seg( + type="text", + data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n", + ) + full_seg_data = Seg(type="seglist", data=[head_tip, seg_data]) + seg_list.append(full_seg_data) + elif message_of_sub_message.get("type") == RealMessageType.text: + sub_message_data = message_of_sub_message.get("data") + if not sub_message_data: + continue + text_message = sub_message_data.get("text") + seg_data = Seg(type="text", data=text_message) + data_list: List[Any] = [] + if layer > 0: + data_list = [ + Seg(type="text", data=("--" * layer) + user_nickname_str), + seg_data, + break_seg, + ] + else: + data_list = [ + Seg(type="text", data=user_nickname_str), + seg_data, + break_seg, + ] + seg_list.append(Seg(type="seglist", data=data_list)) + elif message_of_sub_message.get("type") == RealMessageType.image: + image_count += 1 + image_data = message_of_sub_message.get("data") + sub_type = image_data.get("sub_type") + image_url = image_data.get("url") + data_list: List[Any] = [] + if sub_type == 0: + seg_data = Seg(type="image", data=image_url) + else: + seg_data = Seg(type="emoji", data=image_url) + if layer > 0: + data_list = [ + Seg(type="text", data=("--" * layer) + user_nickname_str), + seg_data, + break_seg, + ] + else: + data_list = [ + Seg(type="text", data=user_nickname_str), + seg_data, + break_seg, + ] + full_seg_data = Seg(type="seglist", data=data_list) + seg_list.append(full_seg_data) + # 在结尾追加标识 + seg_list.append(Seg(type="text", data=("--" * layer) + "【转发消息结束】")) + return Seg(type="seglist", data=seg_list), image_count + + async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None: + forward_message_data: Dict = raw_message.get("data") + if not forward_message_data: + logger.warning("转发消息内容为空") + return None + forward_message_id = forward_message_data.get("id") + request_uuid = str(uuid.uuid4()) + payload = json.dumps( + { + "action": "get_forward_msg", + "params": {"message_id": forward_message_id}, + "echo": request_uuid, + } + ) + try: + await self.server_connection.send(payload) + response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error("获取转发消息超时") + return None + except Exception as e: + logger.error(f"获取转发消息失败: {str(e)}") + return None + logger.debug( + f"转发消息原始格式:{json.dumps(response)[:80]}..." + if len(json.dumps(response)) > 80 + else json.dumps(response) + ) + response_data: Dict = response.get("data") + if not response_data: + logger.warning("转发消息内容为空或获取失败") + return None + return response_data.get("messages") + + +message_handler = MessageHandler() diff --git a/src/recv_handler/message_sending.py b/src/recv_handler/message_sending.py new file mode 100644 index 0000000..2d92f02 --- /dev/null +++ b/src/recv_handler/message_sending.py @@ -0,0 +1,79 @@ +from typing import Dict +import json +from src.logger import logger +from maim_message import MessageBase, Router + + +# 消息大小限制 (字节) +# WebSocket 服务端限制为 100MB,这里设置 95MB 留一点余量 +MAX_MESSAGE_SIZE_BYTES = 95 * 1024 * 1024 # 95MB +MAX_MESSAGE_SIZE_KB = MAX_MESSAGE_SIZE_BYTES / 1024 +MAX_MESSAGE_SIZE_MB = MAX_MESSAGE_SIZE_KB / 1024 + + +class MessageSending: + """ + 负责把消息发送到麦麦 + """ + + maibot_router: Router = None + + def __init__(self): + pass + + async def message_send(self, message_base: MessageBase) -> bool: + """ + 发送消息 + Parameters: + message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息 + """ + try: + # 计算消息大小用于调试 + msg_dict = message_base.to_dict() + msg_json = json.dumps(msg_dict, ensure_ascii=False) + msg_size_bytes = len(msg_json.encode('utf-8')) + msg_size_kb = msg_size_bytes / 1024 + msg_size_mb = msg_size_kb / 1024 + + logger.debug(f"发送消息大小: {msg_size_kb:.2f} KB") + + # 检查消息是否超过大小限制 + if msg_size_bytes > MAX_MESSAGE_SIZE_BYTES: + logger.error( + f"消息大小 ({msg_size_mb:.2f} MB) 超过限制 ({MAX_MESSAGE_SIZE_MB:.0f} MB)," + f"消息已被丢弃以避免连接断开" + ) + logger.warning( + f"被丢弃的消息来源: platform={message_base.message_info.platform}, " + f"group_id={message_base.message_info.group_info.group_id if message_base.message_info.group_info else 'N/A'}, " + f"user_id={message_base.message_info.user_info.user_id if message_base.message_info.user_info else 'N/A'}" + ) + return False + + if msg_size_kb > 1024: # 超过 1MB 时警告 + logger.warning(f"发送的消息较大 ({msg_size_mb:.2f} MB),可能导致传输延迟") + + send_status = await self.maibot_router.send_message(message_base) + if not send_status: + raise RuntimeError("可能是路由未正确配置或连接异常") + logger.debug("消息发送成功") + return send_status + except Exception as e: + logger.error(f"发送消息失败: {str(e)}") + logger.error("请检查与MaiBot之间的连接") + return False + + async def send_custom_message(self, custom_message: Dict, platform: str, message_type: str) -> bool: + """ + 发送自定义消息 + """ + try: + await self.maibot_router.send_custom_message(platform=platform, message_type_name=message_type, message=custom_message) + return True + except Exception as e: + logger.error(f"发送自定义消息失败: {str(e)}") + logger.error("请检查与MaiBot之间的连接") + return False + + +message_send_instance = MessageSending() diff --git a/src/recv_handler/meta_event_handler.py b/src/recv_handler/meta_event_handler.py new file mode 100644 index 0000000..40f5a1a --- /dev/null +++ b/src/recv_handler/meta_event_handler.py @@ -0,0 +1,61 @@ +from src.logger import logger +from src.config import global_config +import time +import asyncio + +from . import MetaEventType + + +class MetaEventHandler: + """ + 处理Meta事件 + """ + + def __init__(self): + self.interval = global_config.napcat_server.heartbeat_interval + self._interval_checking = False + + async def handle_meta_event(self, message: dict) -> None: + event_type = message.get("meta_event_type") + if event_type == MetaEventType.lifecycle: + sub_type = message.get("sub_type") + if sub_type == MetaEventType.Lifecycle.connect: + self_id = message.get("self_id") + self.last_heart_beat = time.time() + logger.success(f"Bot {self_id} 连接成功") + asyncio.create_task(self.check_heartbeat(self_id)) + elif event_type == MetaEventType.heartbeat: + self_id = message.get("self_id") + status = message.get("status", {}) + is_online = status.get("online", False) + is_good = status.get("good", False) + + if is_online and is_good: + # 正常心跳 + if not self._interval_checking: + asyncio.create_task(self.check_heartbeat(self_id)) + self.last_heart_beat = time.time() + self.interval = message.get("interval", 30000) / 1000 + else: + # Bot 离线或状态异常 + if not is_online: + logger.error(f"🔴 Bot {self_id} 已下线 (online=false)") + logger.warning("Bot 可能被踢下线、网络断开或主动退出登录") + elif not is_good: + logger.warning(f"⚠️ Bot {self_id} 状态异常 (good=false)") + else: + logger.warning(f"Bot {self_id} Napcat 端异常!") + + async def check_heartbeat(self, id: int) -> None: + self._interval_checking = True + while True: + now_time = time.time() + if now_time - self.last_heart_beat > self.interval * 2: + logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!") + break + else: + logger.debug("心跳正常") + await asyncio.sleep(self.interval) + + +meta_event_handler = MetaEventHandler() diff --git a/src/recv_handler/notice_handler.py b/src/recv_handler/notice_handler.py new file mode 100644 index 0000000..add8913 --- /dev/null +++ b/src/recv_handler/notice_handler.py @@ -0,0 +1,1000 @@ +import time +import json +import asyncio +import websockets as Server +from typing import Tuple, Optional + +from src.logger import logger +from src.config import global_config +from src.database import BanUser, db_manager, is_identical +from . import NoticeType, ACCEPT_FORMAT +from .message_sending import message_send_instance +from .message_handler import message_handler +from .qq_emoji_list import qq_face +from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase + +from src.utils import ( + get_group_info, + get_member_info, + get_self_info, + get_stranger_info, + read_ban_list, +) + +notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100) +unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3) + + +class NoticeHandler: + banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表 + lifted_list: list[BanUser] = [] # 已经自然解除禁言 + + def __init__(self): + self.server_connection: Server.ServerConnection = None + + async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: + """设置Napcat连接""" + self.server_connection = server_connection + + while self.server_connection.state != Server.State.OPEN: + await asyncio.sleep(0.5) + self.banned_list, self.lifted_list = await read_ban_list(self.server_connection) + + asyncio.create_task(self.auto_lift_detect()) + asyncio.create_task(self.send_notice()) + asyncio.create_task(self.handle_natural_lift()) + + def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + """ + 将用户禁言记录添加到self.banned_list中 + 如果是全体禁言,则user_id为0 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + lift_time = -1 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) + for record in self.banned_list: + if is_identical(record, ban_record): + self.banned_list.remove(record) + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 作为更新 + return + self.banned_list.append(ban_record) + db_manager.create_ban_record(ban_record) # 添加到数据库 + + def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + """ + 从self.lifted_group_list中移除已经解除全体禁言的群 + """ + if user_id is None: + user_id = 0 # 使用0表示全体禁言 + ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) + self.lifted_list.append(ban_record) + db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + + async def handle_notice(self, raw_message: dict) -> None: + notice_type = raw_message.get("notice_type") + # message_time: int = raw_message.get("time") + message_time: float = time.time() # 应可乐要求,现在是float了 + + group_id = raw_message.get("group_id") + user_id = raw_message.get("user_id") + target_id = raw_message.get("target_id") + + handled_message: Seg = None + user_info: UserInfo = None + system_notice: bool = False + + match notice_type: + case NoticeType.friend_recall: + logger.info("好友撤回一条消息") + handled_message, user_info = await self.handle_friend_recall_notify(raw_message) + case NoticeType.group_recall: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("群内用户撤回一条消息") + handled_message, user_info = await self.handle_group_recall_notify(raw_message, group_id, user_id) + system_notice = True + case NoticeType.notify: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.Notify.poke: + if global_config.chat.enable_poke and await message_handler.check_allow_to_chat( + user_id, group_id, False, False + ): + logger.info("处理戳一戳消息") + handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) + else: + logger.warning("戳一戳消息被禁用,取消戳一戳处理") + case NoticeType.Notify.group_name: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群名称变更") + handled_message, user_info = await self.handle_group_name_notify(raw_message, group_id, user_id) + system_notice = True + case _: + logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") + case NoticeType.group_ban: + sub_type = raw_message.get("sub_type") + match sub_type: + case NoticeType.GroupBan.ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群禁言") + handled_message, user_info = await self.handle_ban_notify(raw_message, group_id) + system_notice = True + case NoticeType.GroupBan.lift_ban: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理解除群禁言") + handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id) + system_notice = True + case _: + logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}") + case NoticeType.group_msg_emoji_like: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群消息表情回应") + handled_message, user_info = await self.handle_emoji_like_notify(raw_message, group_id, user_id) + case NoticeType.group_upload: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + logger.info("处理群文件上传") + handled_message, user_info = await self.handle_group_upload_notify(raw_message, group_id, user_id) + system_notice = True + case NoticeType.group_increase: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + sub_type = raw_message.get("sub_type") + logger.info(f"处理群成员增加: {sub_type}") + handled_message, user_info = await self.handle_group_increase_notify(raw_message, group_id, user_id) + system_notice = True + case NoticeType.group_decrease: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + sub_type = raw_message.get("sub_type") + logger.info(f"处理群成员减少: {sub_type}") + handled_message, user_info = await self.handle_group_decrease_notify(raw_message, group_id, user_id) + system_notice = True + case NoticeType.group_admin: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + sub_type = raw_message.get("sub_type") + logger.info(f"处理群管理员变动: {sub_type}") + handled_message, user_info = await self.handle_group_admin_notify(raw_message, group_id, user_id) + system_notice = True + case NoticeType.essence: + if not await message_handler.check_allow_to_chat(user_id, group_id, True, False): + return None + sub_type = raw_message.get("sub_type") + logger.info(f"处理精华消息: {sub_type}") + handled_message, user_info = await self.handle_essence_notify(raw_message, group_id) + system_notice = True + case _: + logger.warning(f"不支持的notice类型: {notice_type}") + return None + if not handled_message or not user_info: + logger.warning("notice处理失败或不支持") + return None + + group_info: GroupInfo = None + if group_id: + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=message_time, + user_info=user_info, + group_info=group_info, + template_info=None, + format_info=FormatInfo( + content_format=["text", "notify"], + accept_format=ACCEPT_FORMAT, + ), + additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁 + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=handled_message, + raw_message=json.dumps(raw_message), + ) + + if system_notice: + await self.put_notice(message_base) + else: + logger.info("发送到Maibot处理通知信息") + await message_send_instance.message_send(message_base) + + async def handle_poke_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + # sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches + self_info: dict = await get_self_info(self.server_connection) + + if not self_info: + logger.error("自身信息获取失败") + return None, None + + self_id = raw_message.get("self_id") + target_id = raw_message.get("target_id") + target_name: str = None + raw_info: list = raw_message.get("raw_info") + + if group_id: + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + else: + user_qq_info: dict = await get_stranger_info(self.server_connection, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + user_name = "QQ用户" + user_cardname = "QQ用户" + logger.info("无法获取戳一戳对方的用户昵称") + + # 计算Seg + if self_id == target_id: + display_name = "" + target_name = self_info.get("nickname") + + elif self_id == user_id: + # 让ada不发送麦麦戳别人的消息 + return None, None + + else: + # 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境 + if group_id: + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, target_id) + if fetched_member_info: + target_name = fetched_member_info.get("nickname") + else: + target_name = "QQ用户" + logger.info("无法获取被戳一戳方的用户昵称") + display_name = user_name + else: + return None, None + + first_txt: str = "戳了戳" + second_txt: str = "" + try: + first_txt = raw_info[2].get("txt", "戳了戳") + second_txt = raw_info[4].get("txt", "") + except Exception as e: + logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") + + user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + seg_data: Seg = Seg( + type="text", + data=f"{display_name}{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)", + ) + return seg_data, user_info + + async def handle_friend_recall_notify(self, raw_message: dict) -> Tuple[Seg | None, UserInfo | None]: + """处理好友消息撤回""" + user_id = raw_message.get("user_id") + message_id = raw_message.get("message_id") + + if not user_id: + logger.error("用户ID不能为空,无法处理好友撤回通知") + return None, None + + # 获取好友信息 + user_qq_info: dict = await get_stranger_info(self.server_connection, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + else: + user_name = "QQ用户" + logger.warning("无法获取撤回消息好友的昵称") + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=None, + ) + + seg_data = Seg( + type="notify", + data={ + "sub_type": "friend_recall", + "message_id": message_id, + }, + ) + + return seg_data, user_info + + async def handle_group_recall_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """处理群消息撤回""" + if not group_id: + logger.error("群ID不能为空,无法处理群撤回通知") + return None, None + + message_id = raw_message.get("message_id") + operator_id = raw_message.get("operator_id") + + # 获取撤回操作者信息 + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取撤回操作者的昵称") + operator_nickname = "QQ用户" + + operator_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 获取被撤回消息发送者信息(如果不是自己撤回的话) + recalled_user_info: UserInfo | None = None + if user_id != operator_id: + user_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_member_info: + user_nickname = user_member_info.get("nickname") + user_cardname = user_member_info.get("card") + else: + user_nickname = "QQ用户" + user_cardname = None + logger.warning("无法获取被撤回消息发送者的昵称") + + recalled_user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + + seg_data = Seg( + type="notify", + data={ + "sub_type": "group_recall", + "message_id": message_id, + "recalled_user_info": recalled_user_info.to_dict() if recalled_user_info else None, + }, + ) + + return seg_data, operator_info + + async def handle_emoji_like_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """处理群消息表情回应""" + if not group_id: + logger.error("群ID不能为空,无法处理表情回应通知") + return None, None + + # 获取用户信息 + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + user_name = "QQ用户" + user_cardname = "QQ用户" + logger.warning("无法获取表情回应用户的昵称") + + # 解析表情列表 + likes = raw_message.get("likes", []) + message_id = raw_message.get("message_id") + + # 构建表情文本,直接使用 qq_face 映射 + emoji_texts = [] + for like in likes: + emoji_id = str(like.get("emoji_id", "")) + count = like.get("count", 1) + # 使用 qq_face 字典获取表情描述 + emoji = qq_face.get(emoji_id, f"[表情:未知{emoji_id}]") + if count > 1: + emoji_texts.append(f"{emoji}x{count}") + else: + emoji_texts.append(emoji) + + emoji_str = "、".join(emoji_texts) if emoji_texts else "未知表情" + display_name = user_cardname if user_cardname and user_cardname != "QQ用户" else user_name + + # 构建消息文本 + message_text = f"{display_name} 对消息(ID:{message_id})表达了 {emoji_str}" + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + seg_data = Seg(type="text", data=message_text) + return seg_data, user_info + + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + user_id = raw_message.get("user_id") + banned_user_info: UserInfo = None + user_nickname: str = "QQ用户" + user_cardname: str = None + sub_type: str = None + + duration = raw_message.get("duration") + if duration is None: + logger.error("禁言时长不能为空,无法处理禁言通知") + return None, None + + if user_id == 0: # 为全体禁言 + sub_type: str = "whole_ban" + self._ban_operation(group_id) + else: # 为单人禁言 + # 获取被禁言人的信息 + sub_type: str = "ban" + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + banned_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._ban_operation(group_id, user_id, int(time.time() + duration)) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "duration": duration, + "banned_user_info": banned_user_info.to_dict() if banned_user_info else None, + }, + ) + + return seg_data, operator_info + + async def handle_lift_ban_notify( + self, raw_message: dict, group_id: int + ) -> Tuple[Seg, UserInfo] | Tuple[None, None]: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None, None + + # 计算user_info + operator_id = raw_message.get("operator_id") + operator_nickname: str = None + operator_cardname: str = None + + member_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if member_info: + operator_nickname = member_info.get("nickname") + operator_cardname = member_info.get("card") + else: + logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效") + operator_nickname = "QQ用户" + + operator_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_nickname, + user_cardname=operator_cardname, + ) + + # 计算Seg + sub_type: str = None + user_nickname: str = "QQ用户" + user_cardname: str = None + lifted_user_info: UserInfo = None + + user_id = raw_message.get("user_id") + if user_id == 0: # 全体禁言解除 + sub_type = "whole_lift_ban" + self._lift_operation(group_id) + else: # 单人禁言解除 + sub_type = "lift_ban" + # 获取被解除禁言人的信息 + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + else: + logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效") + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + self._lift_operation(group_id, user_id) + + seg_data: Seg = Seg( + type="notify", + data={ + "sub_type": sub_type, + "lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None, + }, + ) + return seg_data, operator_info + + async def put_notice(self, message_base: MessageBase) -> None: + """ + 将处理后的通知消息放入通知队列 + """ + if notice_queue.full() or unsuccessful_notice_queue.full(): + logger.warning("通知队列已满,可能是多次发送失败,消息丢弃") + else: + await notice_queue.put(message_base) + + async def handle_natural_lift(self) -> None: + while True: + if len(self.lifted_list) != 0: + lift_record = self.lifted_list.pop() + group_id = lift_record.group_id + user_id = lift_record.user_id + + db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + + seg_message: Seg = await self.natural_lift(group_id, user_id) + + fetched_group_info = await get_group_info(self.server_connection, group_id) + group_name: str = None + if fetched_group_info: + group_name = fetched_group_info.get("group_name") + else: + logger.warning("无法获取notice消息所在群的名称") + group_info = GroupInfo( + platform=global_config.maibot_server.platform_name, + group_id=group_id, + group_name=group_name, + ) + + message_info: BaseMessageInfo = BaseMessageInfo( + platform=global_config.maibot_server.platform_name, + message_id="notice", + time=time.time(), + user_info=None, # 自然解除禁言没有操作者 + group_info=group_info, + template_info=None, + format_info=None, + ) + + message_base: MessageBase = MessageBase( + message_info=message_info, + message_segment=seg_message, + raw_message=json.dumps( + { + "post_type": "notice", + "notice_type": "group_ban", + "sub_type": "lift_ban", + "group_id": group_id, + "user_id": user_id, + "operator_id": None, # 自然解除禁言没有操作者 + } + ), + ) + + await self.put_notice(message_base) + await asyncio.sleep(0.5) # 确保队列处理间隔 + else: + await asyncio.sleep(5) # 每5秒检查一次 + + async def natural_lift(self, group_id: int, user_id: int) -> Seg | None: + if not group_id: + logger.error("群ID不能为空,无法处理解除禁言通知") + return None + + if user_id == 0: # 理论上永远不会触发 + return Seg( + type="notify", + data={ + "sub_type": "whole_lift_ban", + "lifted_user_info": None, + }, + ) + + user_nickname: str = "QQ用户" + user_cardname: str = None + fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if fetched_member_info: + user_nickname = fetched_member_info.get("nickname") + user_cardname = fetched_member_info.get("card") + + lifted_user_info: UserInfo = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + ) + + return Seg( + type="notify", + data={ + "sub_type": "lift_ban", + "lifted_user_info": lifted_user_info.to_dict(), + }, + ) + + async def auto_lift_detect(self) -> None: + while True: + if len(self.banned_list) == 0: + await asyncio.sleep(5) + continue + for ban_record in self.banned_list: + if ban_record.user_id == 0 or ban_record.lift_time == -1: + continue + if ban_record.lift_time <= int(time.time()): + # 触发自然解除禁言 + logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除") + self.lifted_list.append(ban_record) + self.banned_list.remove(ban_record) + await asyncio.sleep(5) + + async def send_notice(self) -> None: + """ + 发送通知消息到Napcat + """ + while True: + if not unsuccessful_notice_queue.empty(): + to_be_send: MessageBase = await unsuccessful_notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + unsuccessful_notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + continue + to_be_send: MessageBase = await notice_queue.get() + try: + send_status = await message_send_instance.message_send(to_be_send) + if send_status: + notice_queue.task_done() + else: + await unsuccessful_notice_queue.put(to_be_send) + except Exception as e: + logger.error(f"发送通知消息失败: {str(e)}") + await unsuccessful_notice_queue.put(to_be_send) + await asyncio.sleep(1) + + async def handle_group_upload_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理群文件上传通知 + """ + file_info: dict = raw_message.get("file", {}) + file_name = file_info.get("name", "未知文件") + file_size = file_info.get("size", 0) + file_id = file_info.get("id", "") + + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + logger.warning("无法获取上传者信息") + user_name = "QQ用户" + user_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + # 格式化文件大小 + if file_size < 1024: + size_str = f"{file_size}B" + elif file_size < 1024 * 1024: + size_str = f"{file_size / 1024:.2f}KB" + else: + size_str = f"{file_size / (1024 * 1024):.2f}MB" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "group_upload", + "file_name": file_name, + "file_size": size_str, + "file_id": file_id, + }, + ) + + return notify_seg, user_info + + async def handle_group_increase_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理群成员增加通知 + """ + sub_type = raw_message.get("sub_type") + operator_id = raw_message.get("operator_id") + + # 获取新成员信息 + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + logger.warning("无法获取新成员信息") + user_name = "QQ用户" + user_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + # 获取操作者信息 + operator_name = "未知" + if operator_id: + operator_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if operator_info: + operator_name = operator_info.get("card") or operator_info.get("nickname", "未知") + + if sub_type == NoticeType.GroupIncrease.invite: + action_text = f"被 {operator_name} 邀请" + elif sub_type == NoticeType.GroupIncrease.approve: + action_text = f"经 {operator_name} 同意" + else: + action_text = "加入" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "group_increase", + "action": action_text, + "increase_type": sub_type, + "operator_id": operator_id, + }, + ) + + return notify_seg, user_info + + async def handle_group_decrease_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理群成员减少通知 + """ + sub_type = raw_message.get("sub_type") + operator_id = raw_message.get("operator_id") + + # 获取离开成员信息 + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + logger.warning("无法获取离开成员信息") + user_name = "QQ用户" + user_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + # 获取操作者信息 + operator_name = "未知" + if operator_id and operator_id != 0: + operator_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if operator_info: + operator_name = operator_info.get("card") or operator_info.get("nickname", "未知") + + if sub_type == NoticeType.GroupDecrease.leave: + action_text = "主动退群" + elif sub_type == NoticeType.GroupDecrease.kick: + action_text = f"被 {operator_name} 踢出" + elif sub_type == NoticeType.GroupDecrease.kick_me: + action_text = "机器人被踢出" + else: + action_text = "离开群聊" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "group_decrease", + "action": action_text, + "decrease_type": sub_type, + "operator_id": operator_id, + }, + ) + + return notify_seg, user_info + + async def handle_group_admin_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理群管理员变动通知 + """ + sub_type = raw_message.get("sub_type") + + # 获取目标用户信息 + user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_qq_info: + user_name = user_qq_info.get("nickname") + user_cardname = user_qq_info.get("card") + else: + logger.warning("无法获取目标用户信息") + user_name = "QQ用户" + user_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + if sub_type == NoticeType.GroupAdmin.set: + action_text = "被设置为管理员" + elif sub_type == NoticeType.GroupAdmin.unset: + action_text = "被取消管理员" + else: + action_text = "管理员变动" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "group_admin", + "action": action_text, + "admin_type": sub_type, + }, + ) + + return notify_seg, user_info + + async def handle_essence_notify( + self, raw_message: dict, group_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理精华消息通知 + """ + sub_type = raw_message.get("sub_type") + sender_id = raw_message.get("sender_id") + operator_id = raw_message.get("operator_id") + message_id = raw_message.get("message_id") + + # 获取操作者信息(设置精华的人) + operator_info: dict = await get_member_info(self.server_connection, group_id, operator_id) + if operator_info: + operator_name = operator_info.get("nickname") + operator_cardname = operator_info.get("card") + else: + logger.warning("无法获取操作者信息") + operator_name = "QQ用户" + operator_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=operator_id, + user_nickname=operator_name, + user_cardname=operator_cardname, + ) + + # 获取消息发送者信息 + sender_name = "未知用户" + if sender_id: + sender_info: dict = await get_member_info(self.server_connection, group_id, sender_id) + if sender_info: + sender_name = sender_info.get("card") or sender_info.get("nickname", "未知用户") + + if sub_type == NoticeType.Essence.add: + action_text = f"将 {sender_name} 的消息设为精华" + elif sub_type == NoticeType.Essence.delete: + action_text = f"移除了 {sender_name} 的精华消息" + else: + action_text = "精华消息变动" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "essence", + "action": action_text, + "essence_type": sub_type, + "sender_id": sender_id, + "message_id": message_id, + }, + ) + + return notify_seg, user_info + + async def handle_group_name_notify( + self, raw_message: dict, group_id: int, user_id: int + ) -> Tuple[Seg | None, UserInfo | None]: + """ + 处理群名称变更通知 + """ + new_name = raw_message.get("name_new") + + if not new_name: + logger.warning("群名称变更通知缺少新名称") + return None, None + + # 获取操作者信息 + user_info_dict: dict = await get_member_info(self.server_connection, group_id, user_id) + if user_info_dict: + user_name = user_info_dict.get("nickname") + user_cardname = user_info_dict.get("card") + else: + logger.warning("无法获取修改群名称的用户信息") + user_name = "QQ用户" + user_cardname = None + + user_info = UserInfo( + platform=global_config.maibot_server.platform_name, + user_id=user_id, + user_nickname=user_name, + user_cardname=user_cardname, + ) + + action_text = f"修改群名称为: {new_name}" + + notify_seg = Seg( + type="notify", + data={ + "sub_type": "group_name", + "action": action_text, + "new_name": new_name, + }, + ) + + return notify_seg, user_info + + +notice_handler = NoticeHandler() diff --git a/src/qq_emoji_list.py b/src/recv_handler/qq_emoji_list.py similarity index 79% rename from src/qq_emoji_list.py rename to src/recv_handler/qq_emoji_list.py index 51c3232..3b3c8bb 100644 --- a/src/qq_emoji_list.py +++ b/src/recv_handler/qq_emoji_list.py @@ -31,7 +31,7 @@ "30": "[表情:奋斗]", "31": "[表情:咒骂]", "32": "[表情:疑问]", - "33": "[表情: 嘘]", + "33": "[表情:嘘]", "34": "[表情:晕]", "35": "[表情:折磨]", "36": "[表情:衰]", @@ -117,7 +117,7 @@ "268": "[表情:问号脸]", "269": "[表情:暗中观察]", "270": "[表情:emm]", - "271": "[表情:吃 瓜]", + "271": "[表情:吃瓜]", "272": "[表情:呵呵哒]", "273": "[表情:我酸了]", "277": "[表情:汪汪]", @@ -146,7 +146,7 @@ "314": "[表情:仔细分析]", "317": "[表情:菜汪]", "318": "[表情:崇拜]", - "319": "[表情: 比心]", + "319": "[表情:比心]", "320": "[表情:庆祝]", "323": "[表情:嫌弃]", "324": "[表情:吃糖]", @@ -175,13 +175,65 @@ "355": "[表情:耶]", "356": "[表情:666]", "357": "[表情:裂开]", - "392": "[表情:龙年 快乐]", + "392": "[表情:龙年快乐]", "393": "[表情:新年中龙]", "394": "[表情:新年大龙]", "395": "[表情:略略略]", + "128522": "[表情:嘿嘿]", + "128524": "[表情:羞涩]", + "128538": "[表情:亲亲]", + "128531": "[表情:汗]", + "128560": "[表情:紧张]", + "128541": "[表情:吐舌]", + "128513": "[表情:呲牙]", + "128540": "[表情:淘气]", + "9786": "[表情:可爱]", + "128532": "[表情:失落]", + "128516": "[表情:高兴]", + "128527": "[表情:哼哼]", + "128530": "[表情:不屑]", + "128563": "[表情:瞪眼]", + "128536": "[表情:飞吻]", + "128557": "[表情:大哭]", + "128514": "[表情:激动]", + "128170": "[表情:肌肉]", + "128074": "[表情:拳头]", + "128077": "[表情:厉害]", + "128079": "[表情:鼓掌]", + "128076": "[表情:好的]", + "127836": "[表情:拉面]", + "127847": "[表情:刨冰]", + "127838": "[表情:面包]", + "127866": "[表情:啤酒]", + "127867": "[表情:干杯]", + "9749": "[表情:咖啡]", + "127822": "[表情:苹果]", + "127827": "[表情:草莓]", + "127817": "[表情:西瓜]", + "127801": "[表情:玫瑰]", + "127881": "[表情:庆祝]", + "128157": "[表情:礼物]", + "10024": "[表情:闪光]", + "128168": "[表情:吹气]", + "128166": "[表情:水]", + "128293": "[表情:火]", + "128164": "[表情:睡觉]", + "128235": "[表情:邮箱]", + "128103": "[表情:女孩]", + "128102": "[表情:男孩]", + "128053": "[表情:猴]", + "128046": "[表情:牛]", + "128027": "[表情:虫]", + "128051": "[表情:鲸鱼]", + "9728": "[表情:晴天]", + "10068": "[表情:问号]", + "128147": "[表情:爱心]", + "10060": "[表情:错误]", + "128089": "[表情:内衣]", + "128104": "[表情:爸爸]", "😊": "[表情:嘿嘿]", "😌": "[表情:羞涩]", - "😚": "[ 表情:亲亲]", + "😚": "[表情:亲亲]", "😓": "[表情:汗]", "😰": "[表情:紧张]", "😝": "[表情:吐舌]", @@ -200,7 +252,7 @@ "😂": "[表情:激动]", "💪": "[表情:肌肉]", "👊": "[表情:拳头]", - "👍": "[表情 :厉害]", + "👍": "[表情:厉害]", "👏": "[表情:鼓掌]", "👎": "[表情:鄙视]", "🙏": "[表情:合十]", @@ -245,6 +297,6 @@ "☀": "[表情:晴天]", "❔": "[表情:问号]", "🔫": "[表情:手枪]", - "💓": "[表情:爱 心]", + "💓": "[表情:爱心]", "🏪": "[表情:便利店]", -} +} \ No newline at end of file diff --git a/src/message_queue.py b/src/response_pool.py similarity index 69% rename from src/message_queue.py rename to src/response_pool.py index 3720590..41feb9e 100644 --- a/src/message_queue.py +++ b/src/response_pool.py @@ -6,22 +6,21 @@ response_dict: Dict = {} response_time_dict: Dict = {} -message_queue = asyncio.Queue() -async def get_response(request_id: str) -> dict: - retry_count = 0 - max_retries = 50 # 10秒超时 - while request_id not in response_dict: - retry_count += 1 - if retry_count >= max_retries: - raise TimeoutError(f"请求超时,未收到响应,request_id: {request_id}") - await asyncio.sleep(0.2) - response = response_dict.pop(request_id) +async def get_response(request_id: str, timeout: int = 10) -> dict: + response = await asyncio.wait_for(_get_response(request_id), timeout) _ = response_time_dict.pop(request_id) logger.trace(f"响应信息id: {request_id} 已从响应字典中取出") return response +async def _get_response(request_id: str) -> dict: + """ + 内部使用的获取响应函数,主要用于在需要时获取响应 + """ + while request_id not in response_dict: + await asyncio.sleep(0.2) + return response_dict.pop(request_id) async def put_response(response: dict): echo_id = response.get("echo") @@ -36,10 +35,10 @@ async def check_timeout_response() -> None: cleaned_message_count: int = 0 now_time = time.time() for echo_id, response_time in list(response_time_dict.items()): - if now_time - response_time > global_config.napcat_heartbeat_interval: + if now_time - response_time > global_config.napcat_server.heartbeat_interval: cleaned_message_count += 1 response_dict.pop(echo_id) response_time_dict.pop(echo_id) logger.warning(f"响应消息 {echo_id} 超时,已删除") logger.info(f"已删除 {cleaned_message_count} 条超时响应消息") - await asyncio.sleep(global_config.napcat_heartbeat_interval) + await asyncio.sleep(global_config.napcat_server.heartbeat_interval) diff --git a/src/send_handler.py b/src/send_handler.py deleted file mode 100644 index 88b33cf..0000000 --- a/src/send_handler.py +++ /dev/null @@ -1,313 +0,0 @@ -import json -import websockets as Server -import uuid -from maim_message import ( - UserInfo, - GroupInfo, - Seg, - BaseMessageInfo, - MessageBase, -) -from typing import Dict, Any, Tuple - -from . import CommandType -from .config import global_config -from .message_queue import get_response -from .logger import logger -from .utils import get_image_format, convert_image_to_gif - - -class SendHandler: - def __init__(self): - self.server_connection: Server.ServerConnection = None - - async def handle_message(self, raw_message_base_dict: dict) -> None: - raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) - message_segment: Seg = raw_message_base.message_segment - logger.info("接收到来自MaiBot的消息,处理中") - if message_segment.type == "command": - return await self.send_command(raw_message_base) - else: - return await self.send_normal_message(raw_message_base) - - async def send_normal_message(self, raw_message_base: MessageBase) -> None: - """ - 处理普通消息发送 - """ - logger.info("处理普通信息中") - message_info: BaseMessageInfo = raw_message_base.message_info - message_segment: Seg = raw_message_base.message_segment - group_info: GroupInfo = message_info.group_info - user_info: UserInfo = message_info.user_info - target_id: int = None - action: str = None - id_name: str = None - processed_message: list = [] - try: - processed_message = await self.handle_seg_recursive(message_segment) - except Exception as e: - logger.error(f"处理消息时发生错误: {e}") - return - - if not processed_message: - logger.critical("现在暂时不支持解析此回复!") - return None - - if group_info and user_info: - logger.debug("发送群聊消息") - target_id = group_info.group_id - action = "send_group_msg" - id_name = "group_id" - elif user_info: - logger.debug("发送私聊消息") - target_id = user_info.user_id - action = "send_private_msg" - id_name = "user_id" - else: - logger.error("无法识别的消息类型") - return - logger.info("尝试发送到napcat") - response = await self.send_message_to_napcat( - action, - { - id_name: target_id, - "message": processed_message, - }, - ) - if response.get("status") == "ok": - logger.info("消息发送成功") - else: - logger.warning(f"消息发送失败,napcat返回:{str(response)}") - - async def send_command(self, raw_message_base: MessageBase) -> None: - """ - 处理命令类 - """ - logger.info("处理命令中") - message_info: BaseMessageInfo = raw_message_base.message_info - message_segment: Seg = raw_message_base.message_segment - group_info: GroupInfo = message_info.group_info - seg_data: Dict[str, Any] = message_segment.data - command_name: str = seg_data.get("name") - try: - match command_name: - case CommandType.GROUP_BAN.name: - command, args_dict = self.handle_ban_command(seg_data.get("args"), group_info) - case CommandType.GROUP_WHOLE_BAN.name: - command, args_dict = self.handle_whole_ban_command(seg_data.get("args"), group_info) - case CommandType.GROUP_KICK.name: - command, args_dict = self.handle_kick_command(seg_data.get("args"), group_info) - case _: - logger.error(f"未知命令: {command_name}") - return - except Exception as e: - logger.error(f"处理命令时发生错误: {e}") - return None - - if not command or not args_dict: - logger.error("命令或参数缺失") - return None - - response = await self.send_message_to_napcat(command, args_dict) - if response.get("status") == "ok": - logger.info(f"命令 {command_name} 执行成功") - else: - logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") - - def get_level(self, seg_data: Seg) -> int: - if seg_data.type == "seglist": - return 1 + max(self.get_level(seg) for seg in seg_data.data) - else: - return 1 - - async def handle_seg_recursive(self, seg_data: Seg) -> list: - payload: list = [] - if seg_data.type == "seglist": - # level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用 - if not seg_data.data: - return [] - for seg in seg_data.data: - payload = self.process_message_by_type(seg, payload) - else: - payload = self.process_message_by_type(seg_data, payload) - return payload - - def process_message_by_type(self, seg: Seg, payload: list) -> list: - # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression - new_payload = payload - if seg.type == "reply": - target_id = seg.data - if target_id == "notice": - return payload - new_payload = self.build_payload(payload, self.handle_reply_message(target_id), True) - elif seg.type == "text": - text = seg.data - if not text: - return payload - new_payload = self.build_payload(payload, self.handle_text_message(text), False) - elif seg.type == "face": - logger.warning("MaiBot 发送了qq原生表情,暂时不支持") - elif seg.type == "image": - image = seg.data - new_payload = self.build_payload(payload, self.handle_image_message(image), False) - elif seg.type == "emoji": - emoji = seg.data - new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False) - elif seg.type == "voice": - voice = seg.data - new_payload = self.build_payload(payload, self.handle_voice_message(voice), False) - return new_payload - - def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list: - # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator - """构建发送的消息体""" - if is_reply: - temp_list = [] - temp_list.append(addon) - for i in payload: - if i.get("type") == "reply": - logger.debug("检测到多个回复,使用最新的回复") - continue - temp_list.append(i) - return temp_list - else: - payload.append(addon) - return payload - - def handle_reply_message(self, id: str) -> dict: - """处理回复消息""" - return {"type": "reply", "data": {"id": id}} - - def handle_text_message(self, message: str) -> dict: - """处理文本消息""" - return {"type": "text", "data": {"text": message}} - - def handle_image_message(self, encoded_image: str) -> dict: - """处理图片消息""" - return { - "type": "image", - "data": { - "file": f"base64://{encoded_image}", - "subtype": 0, - }, - } # base64 编码的图片 - - def handle_emoji_message(self, encoded_emoji: str) -> dict: - """处理表情消息""" - encoded_image = encoded_emoji - image_format = get_image_format(encoded_emoji) - if image_format != "gif": - encoded_image = convert_image_to_gif(encoded_emoji) - return { - "type": "image", - "data": { - "file": f"base64://{encoded_image}", - "subtype": 1, - "summary": "[动画表情]", - }, - } - - def handle_voice_message(self, encoded_voice: str) -> dict: - """处理语音消息""" - if not global_config.use_tts: - logger.warning("未启用语音消息处理") - return {} - if not encoded_voice: - return {} - return { - "type": "record", - "data": {"file": f"base64://{encoded_voice}"}, - } - - def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理封禁命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - duration: int = int(args["duration"]) - user_id: int = int(args["qq_id"]) - group_id: int = int(group_info.group_id) - if duration <= 0: - raise ValueError("封禁时间必须大于0") - if not user_id or not group_id: - raise ValueError("封禁命令缺少必要参数") - if duration > 2592000: - raise ValueError("封禁时间不能超过30天") - return ( - CommandType.GROUP_BAN.value, - { - "group_id": group_id, - "user_id": user_id, - "duration": duration, - }, - ) - - def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理全体禁言命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - enable = args["enable"] - assert isinstance(enable, bool), "enable参数必须是布尔值" - group_id: int = int(group_info.group_id) - if group_id <= 0: - raise ValueError("群组ID无效") - return ( - CommandType.GROUP_WHOLE_BAN.value, - { - "group_id": group_id, - "enable": enable, - }, - ) - - def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: - """处理群成员踢出命令 - - Args: - args (Dict[str, Any]): 参数字典 - group_info (GroupInfo): 群聊信息(对应目标群聊) - - Returns: - Tuple[CommandType, Dict[str, Any]] - """ - user_id: int = int(args["qq_id"]) - group_id: int = int(group_info.group_id) - if group_id <= 0: - raise ValueError("群组ID无效") - if user_id <= 0: - raise ValueError("用户ID无效") - return ( - CommandType.GROUP_KICK.value, - { - "group_id": group_id, - "user_id": user_id, - "reject_add_request": False, # 不拒绝加群请求 - }, - ) - - async def send_message_to_napcat(self, action: str, params: dict) -> dict: - request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) - await self.server_connection.send(payload) - try: - response = await get_response(request_uuid) - except TimeoutError: - logger.error("发送消息超时,未收到响应") - return {"status": "error", "message": "timeout"} - except Exception as e: - logger.error(f"发送消息失败: {e}") - return {"status": "error", "message": str(e)} - return response - - -send_handler = SendHandler() diff --git a/src/send_handler/__init__.py b/src/send_handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/send_handler/main_send_handler.py b/src/send_handler/main_send_handler.py new file mode 100644 index 0000000..cc2c945 --- /dev/null +++ b/src/send_handler/main_send_handler.py @@ -0,0 +1,174 @@ +from typing import Any, Dict, Optional +import time +from maim_message import ( + UserInfo, + GroupInfo, + Seg, + BaseMessageInfo, + MessageBase, +) +from src.logger import logger +from .send_command_handler import SendCommandHandleClass +from .send_message_handler import SendMessageHandleClass +from .nc_sending import nc_message_sender +from src.recv_handler.message_sending import message_send_instance + + +class SendHandler: + def __init__(self): + pass + + async def handle_message(self, raw_message_base_dict: dict) -> None: + raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict) + message_segment: Seg = raw_message_base.message_segment + logger.info("接收到来自MaiBot的消息,处理中") + if message_segment.type == "command": + return await self.send_command(raw_message_base) + else: + return await self.send_normal_message(raw_message_base) + + async def send_command(self, raw_message_base: MessageBase) -> None: + """ + 处理命令类 + """ + logger.info("处理命令中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: GroupInfo = message_info.group_info + seg_data: Dict[str, Any] = message_segment.data + command_name = seg_data.get('name', 'UNKNOWN') + + try: + command, args_dict = SendCommandHandleClass.handle_command(seg_data, group_info) + except Exception as e: + logger.error(f"处理命令时出错: {str(e)}") + # 发送错误响应给麦麦 + await self._send_command_response( + platform=message_info.platform, + command_name=command_name, + success=False, + error=str(e) + ) + return + + if not command or not args_dict: + logger.error("命令或参数缺失") + await self._send_command_response( + platform=message_info.platform, + command_name=command_name, + success=False, + error="命令或参数缺失" + ) + return None + + response = await nc_message_sender.send_message_to_napcat(command, args_dict) + + # 根据响应状态发送结果给麦麦 + if response.get("status") == "ok": + logger.info(f"命令 {command_name} 执行成功") + await self._send_command_response( + platform=message_info.platform, + command_name=command_name, + success=True, + data=response.get("data") + ) + else: + logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") + await self._send_command_response( + platform=message_info.platform, + command_name=command_name, + success=False, + error=str(response), + data=response.get("data") # 有些错误响应也可能包含部分数据 + ) + + async def _send_command_response( + self, + platform: str, + command_name: str, + success: bool, + data: Optional[Dict] = None, + error: Optional[str] = None + ) -> None: + """发送命令响应回麦麦 + + Args: + platform: 平台标识 + command_name: 命令名称 + success: 是否执行成功 + data: 返回数据(成功时) + error: 错误信息(失败时) + """ + response_data = { + "command_name": command_name, + "success": success, + "timestamp": time.time() + } + + if data is not None: + response_data["data"] = data + if error: + response_data["error"] = error + + try: + await message_send_instance.send_custom_message( + custom_message=response_data, + platform=platform, + message_type="command_response" + ) + logger.debug(f"已发送命令响应: {command_name}, success={success}") + except Exception as e: + logger.error(f"发送命令响应失败: {e}") + + async def send_normal_message(self, raw_message_base: MessageBase) -> None: + """ + 处理普通消息发送 + """ + logger.info("处理普通信息中") + message_info: BaseMessageInfo = raw_message_base.message_info + message_segment: Seg = raw_message_base.message_segment + group_info: GroupInfo = message_info.group_info + user_info: UserInfo = message_info.user_info + target_id: int = None + action: str = None + id_name: str = None + processed_message: list = [] + try: + processed_message = SendMessageHandleClass.process_seg_recursive(message_segment) + except Exception as e: + logger.error(f"处理消息时发生错误: {e}") + return + + if not processed_message: + logger.critical("现在暂时不支持解析此回复!") + return None + + if group_info and user_info: + logger.debug("发送群聊消息") + target_id = group_info.group_id + action = "send_group_msg" + id_name = "group_id" + elif user_info: + logger.debug("发送私聊消息") + target_id = user_info.user_id + action = "send_private_msg" + id_name = "user_id" + else: + logger.error("无法识别的消息类型") + return + logger.info("尝试发送到napcat") + response = await nc_message_sender.send_message_to_napcat( + action, + { + id_name: target_id, + "message": processed_message, + }, + ) + if response.get("status") == "ok": + logger.info("消息发送成功") + qq_message_id = response.get("data", {}).get("message_id") + await nc_message_sender.message_sent_back(raw_message_base, qq_message_id) + else: + logger.warning(f"消息发送失败,napcat返回:{str(response)}") + +send_handler = SendHandler() \ No newline at end of file diff --git a/src/send_handler/nc_sending.py b/src/send_handler/nc_sending.py new file mode 100644 index 0000000..bb3b65e --- /dev/null +++ b/src/send_handler/nc_sending.py @@ -0,0 +1,61 @@ +import json +import uuid +import websockets as Server +from maim_message import MessageBase + +from src.response_pool import get_response +from src.logger import logger +from src.recv_handler.message_sending import message_send_instance + +class NCMessageSender: + def __init__(self): + self.server_connection: Server.ServerConnection = None + + async def set_server_connection(self, connection: Server.ServerConnection): + self.server_connection = connection + + async def send_message_to_napcat(self, action: str, params: dict) -> dict: + request_uuid = str(uuid.uuid4()) + payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) + await self.server_connection.send(payload) + try: + response = await get_response(request_uuid) + except TimeoutError: + logger.error("发送消息超时,未收到响应") + return {"status": "error", "message": "timeout"} + except Exception as e: + logger.error(f"发送消息失败: {e}") + return {"status": "error", "message": str(e)} + return response + + async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None: + # # 修改 additional_config,添加 echo 字段 + # if message_base.message_info.additional_config is None: + # message_base.message_info.additional_config = {} + + # message_base.message_info.additional_config["echo"] = True + + # # 获取原始的 mmc_message_id + # mmc_message_id = message_base.message_info.message_id + + # # 修改 message_segment 为 notify 类型 + # message_base.message_segment = Seg( + # type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id} + # ) + # await message_send_instance.message_send(message_base) + # logger.debug("已回送消息ID") + # return + platform = message_base.message_info.platform + mmc_message_id = message_base.message_info.message_id + echo_data = { + "type": "echo", + "echo": mmc_message_id, + "actual_id": qq_message_id, + } + success = await message_send_instance.send_custom_message(echo_data, platform, "message_id_echo") + if success: + logger.debug("已回送消息ID") + else: + logger.error("回送消息ID失败") + +nc_message_sender = NCMessageSender() \ No newline at end of file diff --git a/src/send_handler/send_command_handler.py b/src/send_handler/send_command_handler.py new file mode 100644 index 0000000..bdacb62 --- /dev/null +++ b/src/send_handler/send_command_handler.py @@ -0,0 +1,719 @@ +from maim_message import GroupInfo +from typing import Any, Dict, Tuple, Callable, Optional + +from src import CommandType + + +# 全局命令处理器注册表(在类外部定义以避免循环引用) +_command_handlers: Dict[str, Dict[str, Any]] = {} + + +def register_command(command_type: CommandType, require_group: bool = True): + """装饰器:注册命令处理器 + + Args: + command_type: 命令类型 + require_group: 是否需要群聊信息,默认为True + + Returns: + 装饰器函数 + """ + + def decorator(func: Callable) -> Callable: + _command_handlers[command_type.name] = { + "handler": func, + "require_group": require_group, + } + return func + + return decorator + + +class SendCommandHandleClass: + @classmethod + def handle_command(cls, raw_command_data: Dict[str, Any], group_info: Optional[GroupInfo]): + """统一命令处理入口 + + Args: + raw_command_data: 原始命令数据 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) 用于发送给NapCat + + Raises: + RuntimeError: 命令类型未知或处理失败 + """ + command_name: str = raw_command_data.get("name") + + if command_name not in _command_handlers: + raise RuntimeError(f"未知的命令类型: {command_name}") + + try: + handler_info = _command_handlers[command_name] + handler = handler_info["handler"] + require_group = handler_info["require_group"] + + # 检查群聊信息要求 + if require_group and not group_info: + raise ValueError(f"命令 {command_name} 需要在群聊上下文中使用") + + # 调用处理器 + args = raw_command_data.get("args", {}) + return handler(args, group_info) + + except Exception as e: + raise RuntimeError(f"处理命令 {command_name} 时出错: {str(e)}") from e + + # ============ 命令处理器(使用装饰器注册)============ + + @staticmethod + @register_command(CommandType.GROUP_BAN, require_group=True) + def handle_ban_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理封禁命令 + + Args: + args: 参数字典 {"qq_id": int, "duration": int} + group_info: 群聊信息(对应目标群聊) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + duration: int = int(args["duration"]) + user_id: int = int(args["qq_id"]) + group_id: int = int(group_info.group_id) + if duration < 0: + raise ValueError("封禁时间必须大于等于0") + if not user_id or not group_id: + raise ValueError("封禁命令缺少必要参数") + if duration > 2592000: + raise ValueError("封禁时间不能超过30天") + return ( + CommandType.GROUP_BAN.value, + { + "group_id": group_id, + "user_id": user_id, + "duration": duration, + }, + ) + + @staticmethod + @register_command(CommandType.GROUP_WHOLE_BAN, require_group=True) + def handle_whole_ban_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理全体禁言命令 + + Args: + args: 参数字典 {"enable": bool} + group_info: 群聊信息(对应目标群聊) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + enable = args["enable"] + assert isinstance(enable, bool), "enable参数必须是布尔值" + group_id: int = int(group_info.group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + return ( + CommandType.GROUP_WHOLE_BAN.value, + { + "group_id": group_id, + "enable": enable, + }, + ) + + @staticmethod + @register_command(CommandType.GROUP_KICK, require_group=False) + def handle_kick_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理群成员踢出命令 + + Args: + args: 参数字典 {"group_id": int, "user_id": int, "reject_add_request": bool (可选)} + group_info: 群聊信息(可选,可自动获取 group_id) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("群踢人命令缺少参数") + + # 优先从 args 获取 group_id,否则从 group_info 获取 + group_id = args.get("group_id") + if not group_id and group_info: + group_id = int(group_info.group_id) + + user_id = args.get("user_id") + + if not group_id: + raise ValueError("群踢人命令缺少必要参数: group_id") + if not user_id: + raise ValueError("群踢人命令缺少必要参数: user_id") + + group_id = int(group_id) + user_id = int(user_id) + if group_id <= 0: + raise ValueError("群组ID无效") + if user_id <= 0: + raise ValueError("用户ID无效") + + # reject_add_request 是可选参数,默认 False + reject_add_request = args.get("reject_add_request", False) + + return ( + CommandType.GROUP_KICK.value, + { + "group_id": group_id, + "user_id": user_id, + "reject_add_request": bool(reject_add_request), + }, + ) + + @staticmethod + @register_command(CommandType.GROUP_KICK_MEMBERS, require_group=False) + def handle_kick_members_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理批量踢出群成员命令 + + Args: + args: 参数字典 {"group_id": int, "user_id": List[int], "reject_add_request": bool (可选)} + group_info: 群聊信息(可选,可自动获取 group_id) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("批量踢人命令缺少参数") + + # 优先从 args 获取 group_id,否则从 group_info 获取 + group_id = args.get("group_id") + if not group_id and group_info: + group_id = int(group_info.group_id) + + user_id = args.get("user_id") + + if not group_id: + raise ValueError("批量踢人命令缺少必要参数: group_id") + if not user_id: + raise ValueError("批量踢人命令缺少必要参数: user_id") + + # 验证 user_id 是数组 + if not isinstance(user_id, list): + raise ValueError("user_id 必须是数组类型") + if len(user_id) == 0: + raise ValueError("user_id 数组不能为空") + + # 转换并验证每个 user_id + user_id_list = [] + for uid in user_id: + try: + uid_int = int(uid) + if uid_int <= 0: + raise ValueError(f"用户ID无效: {uid}") + user_id_list.append(uid_int) + except (ValueError, TypeError) as e: + raise ValueError(f"用户ID格式错误: {uid} - {str(e)}") from None + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + # reject_add_request 是可选参数,默认 False + reject_add_request = args.get("reject_add_request", False) + + return ( + CommandType.GROUP_KICK_MEMBERS.value, + { + "group_id": group_id, + "user_id": user_id_list, + "reject_add_request": bool(reject_add_request), + }, + ) + + @staticmethod + @register_command(CommandType.SEND_POKE, require_group=False) + def handle_poke_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理戳一戳命令 + + Args: + args: 参数字典 {"qq_id": int} + group_info: 群聊信息(可选,私聊时为None) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + user_id: int = int(args["qq_id"]) + if group_info is None: + group_id = None + else: + group_id: int = int(group_info.group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + if user_id <= 0: + raise ValueError("用户ID无效") + return ( + CommandType.SEND_POKE.value, + { + "group_id": group_id, + "user_id": user_id, + }, + ) + + @staticmethod + @register_command(CommandType.SET_GROUP_NAME, require_group=False) + def handle_set_group_name_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """设置群名 + + Args: + args: 参数字典 {"group_id": int, "group_name": str} + group_info: 群聊信息(可选,可自动获取 group_id) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("设置群名命令缺少参数") + + # 优先从 args 获取 group_id,否则从 group_info 获取 + group_id = args.get("group_id") + if not group_id and group_info: + group_id = int(group_info.group_id) + + group_name = args.get("group_name") + + if not group_id: + raise ValueError("设置群名命令缺少必要参数: group_id") + if not group_name: + raise ValueError("设置群名命令缺少必要参数: group_name") + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + return ( + CommandType.SET_GROUP_NAME.value, + { + "group_id": group_id, + "group_name": str(group_name), + }, + ) + + @staticmethod + @register_command(CommandType.DELETE_MSG, require_group=False) + def delete_msg_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理撤回消息命令 + + Args: + args: 参数字典 {"message_id": int} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + try: + message_id = int(args["message_id"]) + if message_id <= 0: + raise ValueError("消息ID无效") + except KeyError: + raise ValueError("缺少必需参数: message_id") from None + except (ValueError, TypeError) as e: + raise ValueError(f"消息ID无效: {args['message_id']} - {str(e)}") from None + + return (CommandType.DELETE_MSG.value, {"message_id": message_id}) + + @staticmethod + @register_command(CommandType.SET_QQ_PROFILE, require_group=False) + def handle_set_qq_profile_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """设置账号信息 + + Args: + args: 参数字典 {"nickname": str, "personal_note": str (可选), "sex": str (可选)} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("设置账号信息命令缺少参数") + + nickname = args.get("nickname") + if not nickname: + raise ValueError("设置账号信息命令缺少必要参数: nickname") + + params = {"nickname": str(nickname)} + + # 可选参数 + if "personal_note" in args: + params["personal_note"] = str(args["personal_note"]) + + if "sex" in args: + sex = str(args["sex"]).lower() + if sex not in ["male", "female", "unknown"]: + raise ValueError(f"性别参数无效: {sex},必须为 male/female/unknown 之一") + params["sex"] = sex + + return (CommandType.SET_QQ_PROFILE.value, params) + + @staticmethod + @register_command(CommandType.AI_VOICE_SEND, require_group=True) + def handle_ai_voice_send_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理AI语音发送命令 + + Args: + args: 参数字典 {"character": str, "text": str} + group_info: 群聊信息 + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not group_info or not group_info.group_id: + raise ValueError("AI语音发送命令必须在群聊上下文中使用") + if not args: + raise ValueError("AI语音发送命令缺少参数") + + group_id: int = int(group_info.group_id) + character_id = args.get("character") + text_content = args.get("text") + + if not character_id or not text_content: + raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'") + + return ( + CommandType.AI_VOICE_SEND.value, + { + "group_id": group_id, + "text": text_content, + "character": character_id, + }, + ) + + @staticmethod + @register_command(CommandType.SET_MSG_EMOJI_LIKE, require_group=False) + def handle_set_msg_emoji_like_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """处理给消息贴表情命令 + + Args: + args: 参数字典 {"message_id": int, "emoji_id": int} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("消息贴表情命令缺少参数") + + message_id = args.get("message_id") + emoji_id = args.get("emoji_id") + if not message_id: + raise ValueError("消息贴表情命令缺少必要参数: message_id") + if not emoji_id: + raise ValueError("消息贴表情命令缺少必要参数: emoji_id") + + message_id = int(message_id) + emoji_id = int(emoji_id) + if message_id <= 0: + raise ValueError("消息ID无效") + if emoji_id <= 0: + raise ValueError("表情ID无效") + + return ( + CommandType.SET_MSG_EMOJI_LIKE.value, + { + "message_id": message_id, + "emoji_id": emoji_id, + "set": True, + }, + ) + + # ============ 查询类命令处理器 ============ + + @staticmethod + @register_command(CommandType.GET_LOGIN_INFO, require_group=False) + def handle_get_login_info_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取登录号信息(Bot自身信息) + + Args: + args: 参数字典(无需参数) + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + return (CommandType.GET_LOGIN_INFO.value, {}) + + @staticmethod + @register_command(CommandType.GET_STRANGER_INFO, require_group=False) + def handle_get_stranger_info_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取陌生人信息 + + Args: + args: 参数字典 {"user_id": int} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("获取陌生人信息命令缺少参数") + + user_id = args.get("user_id") + if not user_id: + raise ValueError("获取陌生人信息命令缺少必要参数: user_id") + + user_id = int(user_id) + if user_id <= 0: + raise ValueError("用户ID无效") + + return ( + CommandType.GET_STRANGER_INFO.value, + {"user_id": user_id}, + ) + + @staticmethod + @register_command(CommandType.GET_FRIEND_LIST, require_group=False) + def handle_get_friend_list_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取好友列表 + + Args: + args: 参数字典 {"no_cache": bool} (可选,默认 false) + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # no_cache 参数是可选的,默认为 false + no_cache = args.get("no_cache", False) if args else False + + return (CommandType.GET_FRIEND_LIST.value, {"no_cache": bool(no_cache)}) + + @staticmethod + @register_command(CommandType.GET_GROUP_INFO, require_group=False) + def handle_get_group_info_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群信息 + + Args: + args: 参数字典 {"group_id": int} 或从 group_info 自动获取 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # 优先从 args 获取,否则从 group_info 获取 + group_id = args.get("group_id") if args else None + if not group_id and group_info: + group_id = int(group_info.group_id) + + if not group_id: + raise ValueError("获取群信息命令缺少必要参数: group_id") + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + return ( + CommandType.GET_GROUP_INFO.value, + {"group_id": group_id}, + ) + + @staticmethod + @register_command(CommandType.GET_GROUP_DETAIL_INFO, require_group=False) + def handle_get_group_detail_info_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群详细信息 + + Args: + args: 参数字典 {"group_id": int} 或从 group_info 自动获取 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # 优先从 args 获取,否则从 group_info 获取 + group_id = args.get("group_id") if args else None + if not group_id and group_info: + group_id = int(group_info.group_id) + + if not group_id: + raise ValueError("获取群详细信息命令缺少必要参数: group_id") + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + return ( + CommandType.GET_GROUP_DETAIL_INFO.value, + {"group_id": group_id}, + ) + + @staticmethod + @register_command(CommandType.GET_GROUP_LIST, require_group=False) + def handle_get_group_list_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群列表 + + Args: + args: 参数字典 {"no_cache": bool} (可选,默认 false) + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # no_cache 参数是可选的,默认为 false + no_cache = args.get("no_cache", False) if args else False + + return (CommandType.GET_GROUP_LIST.value, {"no_cache": bool(no_cache)}) + + @staticmethod + @register_command(CommandType.GET_GROUP_AT_ALL_REMAIN, require_group=False) + def handle_get_group_at_all_remain_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群@全体成员剩余次数 + + Args: + args: 参数字典 {"group_id": int} 或从 group_info 自动获取 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # 优先从 args 获取,否则从 group_info 获取 + group_id = args.get("group_id") if args else None + if not group_id and group_info: + group_id = int(group_info.group_id) + + if not group_id: + raise ValueError("获取群@全体成员剩余次数命令缺少必要参数: group_id") + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + return ( + CommandType.GET_GROUP_AT_ALL_REMAIN.value, + {"group_id": group_id}, + ) + + @staticmethod + @register_command(CommandType.GET_GROUP_MEMBER_INFO, require_group=False) + def handle_get_group_member_info_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群成员信息 + + Args: + args: 参数字典 {"group_id": int, "user_id": int, "no_cache": bool} 或 group_id 从 group_info 自动获取 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("获取群成员信息命令缺少参数") + + # 优先从 args 获取,否则从 group_info 获取 + group_id = args.get("group_id") + if not group_id and group_info: + group_id = int(group_info.group_id) + + user_id = args.get("user_id") + no_cache = args.get("no_cache", False) + + if not group_id: + raise ValueError("获取群成员信息命令缺少必要参数: group_id") + if not user_id: + raise ValueError("获取群成员信息命令缺少必要参数: user_id") + + group_id = int(group_id) + user_id = int(user_id) + if group_id <= 0: + raise ValueError("群组ID无效") + if user_id <= 0: + raise ValueError("用户ID无效") + + return ( + CommandType.GET_GROUP_MEMBER_INFO.value, + { + "group_id": group_id, + "user_id": user_id, + "no_cache": bool(no_cache), + }, + ) + + @staticmethod + @register_command(CommandType.GET_GROUP_MEMBER_LIST, require_group=False) + def handle_get_group_member_list_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取群成员列表 + + Args: + args: 参数字典 {"group_id": int, "no_cache": bool} 或 group_id 从 group_info 自动获取 + group_info: 群聊信息(可选) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + # 优先从 args 获取,否则从 group_info 获取 + group_id = args.get("group_id") if args else None + if not group_id and group_info: + group_id = int(group_info.group_id) + + no_cache = args.get("no_cache", False) if args else False + + if not group_id: + raise ValueError("获取群成员列表命令缺少必要参数: group_id") + + group_id = int(group_id) + if group_id <= 0: + raise ValueError("群组ID无效") + + return ( + CommandType.GET_GROUP_MEMBER_LIST.value, + { + "group_id": group_id, + "no_cache": bool(no_cache), + }, + ) + + @staticmethod + @register_command(CommandType.GET_MSG, require_group=False) + def handle_get_msg_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取消息详情 + + Args: + args: 参数字典 {"message_id": int} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("获取消息命令缺少参数") + + message_id = args.get("message_id") + if not message_id: + raise ValueError("获取消息命令缺少必要参数: message_id") + + message_id = int(message_id) + if message_id <= 0: + raise ValueError("消息ID无效") + + return ( + CommandType.GET_MSG.value, + {"message_id": message_id}, + ) + + @staticmethod + @register_command(CommandType.GET_FORWARD_MSG, require_group=False) + def handle_get_forward_msg_command(args: Dict[str, Any], group_info: Optional[GroupInfo]) -> Tuple[str, Dict[str, Any]]: + """获取合并转发消息 + + Args: + args: 参数字典 {"message_id": str} + group_info: 群聊信息(不使用) + + Returns: + Tuple[str, Dict[str, Any]]: (action, params) + """ + if not args: + raise ValueError("获取合并转发消息命令缺少参数") + + message_id = args.get("message_id") + if not message_id: + raise ValueError("获取合并转发消息命令缺少必要参数: message_id") + + return ( + CommandType.GET_FORWARD_MSG.value, + {"message_id": str(message_id)}, + ) diff --git a/src/send_handler/send_message_handler.py b/src/send_handler/send_message_handler.py new file mode 100644 index 0000000..101ef8d --- /dev/null +++ b/src/send_handler/send_message_handler.py @@ -0,0 +1,295 @@ +from maim_message import Seg, MessageBase +from typing import List, Dict + +from src.logger import logger +from src.config import global_config +from src.utils import get_image_format, convert_image_to_gif + + +class SendMessageHandleClass: + @classmethod + def parse_seg_to_nc_format(cls, message_segment: Seg): + parsed_payload: List = cls.process_seg_recursive(message_segment) + return parsed_payload + + @classmethod + def process_seg_recursive(cls, seg_data: Seg, in_forward: bool = False) -> List: + payload: List = [] + if seg_data.type == "seglist": + if not seg_data.data: + return [] + for seg in seg_data.data: + payload = cls.process_message_by_type(seg, payload, in_forward) + else: + payload = cls.process_message_by_type(seg_data, payload, in_forward) + return payload + + @classmethod + def process_message_by_type(cls, seg: Seg, payload: List, in_forward: bool = False) -> List: + # sourcery skip: for-append-to-extend, reintroduce-else, swap-if-else-branches, use-named-expression + new_payload = payload + if seg.type == "reply": + target_id = seg.data + if target_id == "notice": + return payload + new_payload = cls.build_payload(payload, cls.handle_reply_message(target_id), True) + elif seg.type == "text": + text = seg.data + if not text: + return payload + new_payload = cls.build_payload(payload, cls.handle_text_message(text), False) + elif seg.type == "face": + face_id = seg.data + new_payload = cls.build_payload(payload, cls.handle_native_face_message(face_id), False) + elif seg.type == "image": + image = seg.data + new_payload = cls.build_payload(payload, cls.handle_image_message(image), False) + elif seg.type == "emoji": + emoji = seg.data + new_payload = cls.build_payload(payload, cls.handle_emoji_message(emoji), False) + elif seg.type == "voice": + voice = seg.data + new_payload = cls.build_payload(payload, cls.handle_voice_message(voice), False) + elif seg.type == "voiceurl": + voice_url = seg.data + new_payload = cls.build_payload(payload, cls.handle_voiceurl_message(voice_url), False) + elif seg.type == "music": + music_data = seg.data + new_payload = cls.build_payload(payload, cls.handle_music_message(music_data), False) + elif seg.type == "videourl": + video_url = seg.data + new_payload = cls.build_payload(payload, cls.handle_videourl_message(video_url), False) + elif seg.type == "file": + file_path = seg.data + new_payload = cls.build_payload(payload, cls.handle_file_message(file_path), False) + elif seg.type == "imageurl": + image_url = seg.data + new_payload = cls.build_payload(payload, cls.handle_imageurl_message(image_url), False) + elif seg.type == "video": + video_path = seg.data + new_payload = cls.build_payload(payload, cls.handle_video_message(video_path), False) + elif seg.type == "forward" and not in_forward: + forward_message_content: List[Dict] = seg.data + new_payload: List[Dict] = [ + cls.handle_forward_message(MessageBase.from_dict(item)) for item in forward_message_content + ] # 转发消息不能和其他消息一起发送 + return new_payload + + @classmethod + def handle_forward_message(cls, item: MessageBase) -> Dict: + # sourcery skip: remove-unnecessary-else + message_segment: Seg = item.message_segment + if message_segment.type == "id": + return {"type": "node", "data": {"id": message_segment.data}} + else: + user_info = item.message_info.user_info + content = cls.process_seg_recursive(message_segment, True) + return { + "type": "node", + "data": {"name": user_info.user_nickname or "QQ用户", "uin": user_info.user_id, "content": content}, + } + + @staticmethod + def build_payload(payload: List, addon: dict, is_reply: bool = False) -> List: + # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator + if is_reply: + temp_list = [] + temp_list.append(addon) + for i in payload: + if i.get("type") == "reply": + logger.debug("检测到多个回复,使用最新的回复") + continue + temp_list.append(i) + return temp_list + else: + payload.append(addon) + return payload + + @staticmethod + def handle_reply_message(id: str) -> dict: + """处理回复消息""" + return {"type": "reply", "data": {"id": id}} + + @staticmethod + def handle_text_message(message: str) -> dict: + """处理文本消息""" + return {"type": "text", "data": {"text": message}} + + @staticmethod + def handle_native_face_message(face_id: int) -> dict: + # sourcery skip: remove-unnecessary-cast + """处理原生表情消息""" + return {"type": "face", "data": {"id": int(face_id)}} + + @staticmethod + def handle_image_message(encoded_image: str) -> dict: + """处理图片消息""" + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 0, + }, + } # base64 编码的图片 + + @staticmethod + def handle_emoji_message(encoded_emoji: str) -> dict: + """处理表情消息""" + encoded_image = encoded_emoji + image_format = get_image_format(encoded_emoji) + if image_format != "gif": + encoded_image = convert_image_to_gif(encoded_emoji) + return { + "type": "image", + "data": { + "file": f"base64://{encoded_image}", + "subtype": 1, + "summary": "[动画表情]", + }, + } + + @staticmethod + def handle_voice_message(encoded_voice: str) -> dict: + """处理语音消息""" + if not global_config.voice.use_tts: + logger.warning("未启用语音消息处理") + return {} + if not encoded_voice: + return {} + return { + "type": "record", + "data": {"file": f"base64://{encoded_voice}"}, + } + + @staticmethod + def handle_voiceurl_message(voice_url: str) -> dict: + """处理语音链接消息""" + return { + "type": "record", + "data": {"file": voice_url}, + } + + @staticmethod + def handle_music_message(music_data) -> dict: + """ + 处理音乐消息 + music_data 可以是: + 1. 字符串:默认为网易云音乐ID + 2. 字典:{"type": "163"/"qq", "id": "歌曲ID"} + """ + # 兼容旧格式:直接传入歌曲ID字符串 + if isinstance(music_data, str): + return { + "type": "music", + "data": {"type": "163", "id": music_data}, + } + + # 新格式:字典包含平台和ID + if isinstance(music_data, dict): + platform = music_data.get("type", "163") # 默认网易云 + song_id = music_data.get("id", "") + + # 验证平台类型 + if platform not in ["163", "qq"]: + logger.warning(f"不支持的音乐平台: {platform},使用默认平台163") + platform = "163" + + # 确保ID是字符串 + if not isinstance(song_id, str): + song_id = str(song_id) + + return { + "type": "music", + "data": {"type": platform, "id": song_id}, + } + + # 其他情况返回空 + logger.error(f"不支持的音乐数据格式: {type(music_data)}") + return {} + + @staticmethod + def handle_videourl_message(video_url: str) -> dict: + """处理视频链接消息""" + return { + "type": "video", + "data": {"file": video_url}, + } + + @staticmethod + def handle_file_message(file_data) -> dict: + """处理文件消息 + + Args: + file_data: 可以是字符串(文件路径)或字典(完整文件信息) + - 字符串:简单的文件路径 + - 字典:包含 file, name, path, thumb, url 等字段 + + Returns: + NapCat 格式的文件消息段 + """ + # 如果是简单的字符串路径(兼容旧版本) + if isinstance(file_data, str): + return { + "type": "file", + "data": {"file": f"file://{file_data}"}, + } + + # 如果是完整的字典数据 + if isinstance(file_data, dict): + data = {} + + # file 字段是必需的 + if "file" in file_data: + file_value = file_data["file"] + # 如果是本地路径且没有协议前缀,添加 file:// 前缀 + if not any(file_value.startswith(prefix) for prefix in ["file://", "http://", "https://", "base64://"]): + data["file"] = f"file://{file_value}" + else: + data["file"] = file_value + else: + # 没有 file 字段,尝试使用 path 或 url + if "path" in file_data: + data["file"] = f"file://{file_data['path']}" + elif "url" in file_data: + data["file"] = file_data["url"] + else: + logger.warning("文件消息缺少必要的 file/path/url 字段") + return None + + # 添加可选字段 + if "name" in file_data: + data["name"] = file_data["name"] + if "thumb" in file_data: + data["thumb"] = file_data["thumb"] + if "url" in file_data and "file" not in file_data: + data["file"] = file_data["url"] + + return { + "type": "file", + "data": data, + } + + logger.warning(f"不支持的文件数据类型: {type(file_data)}") + return None + + @staticmethod + def handle_imageurl_message(image_url: str) -> dict: + """处理图片链接消息""" + return { + "type": "image", + "data": {"file": image_url}, + } + + @staticmethod + def handle_video_message(encoded_video: str) -> dict: + """处理视频消息(base64格式)""" + if not encoded_video: + logger.error("视频数据为空") + return {} + + logger.info(f"处理视频消息,数据长度: {len(encoded_video)} 字符") + + return { + "type": "video", + "data": {"file": f"base64://{encoded_video}"}, + } diff --git a/src/utils.py b/src/utils.py index f85ad9a..78b0d0c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,14 +2,16 @@ import json import base64 import uuid -from .logger import logger -from .message_queue import get_response - import urllib3 import ssl +import io + +from src.database import BanUser, db_manager +from .logger import logger +from .response_pool import get_response from PIL import Image -import io +from typing import Union, List, Tuple, Optional class SSLAdapter(urllib3.PoolManager): @@ -21,7 +23,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict: +async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: """ 获取群相关信息 @@ -43,7 +45,29 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d return socket_response.get("data") -async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict: +async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: + """ + 获取群详细信息 + + 返回值需要处理可能为空的情况 + """ + logger.debug("获取群详细信息中") + request_uuid = str(uuid.uuid4()) + payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}) + try: + await websocket.send(payload) + socket_response: dict = await get_response(request_uuid) + except TimeoutError: + logger.error(f"获取群详细信息超时,群号: {group_id}") + return None + except Exception as e: + logger.error(f"获取群详细信息失败: {e}") + return None + logger.debug(socket_response) + return socket_response.get("data") + + +async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None: """ 获取群成员信息 @@ -88,6 +112,7 @@ async def get_image_base64(url: str) -> str: def convert_image_to_gif(image_base64: str) -> str: + # sourcery skip: extract-method """ 将Base64编码的图片转换为GIF格式 Parameters: @@ -108,7 +133,7 @@ def convert_image_to_gif(image_base64: str) -> str: return image_base64 -async def get_self_info(websocket: Server.ServerConnection) -> dict: +async def get_self_info(websocket: Server.ServerConnection) -> dict | None: """ 获取自身信息 Parameters: @@ -144,7 +169,7 @@ def get_image_format(raw_data: str) -> str: return Image.open(io.BytesIO(image_bytes)).format.lower() -async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict: +async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None: """ 获取陌生人信息 Parameters: @@ -169,7 +194,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> return response.get("data") -async def get_message_detail(websocket: Server.ServerConnection, message_id: str) -> dict: +async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None: """ 获取消息详情,可能为空 Parameters: @@ -183,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: str payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}) try: await websocket.send(payload) - response: dict = await get_response(request_uuid) + response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 except TimeoutError: logger.error(f"获取消息详情超时,消息ID: {message_id}") return None @@ -192,3 +217,94 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: str return None logger.debug(response) return response.get("data") + + +async def get_record_detail( + websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None +) -> dict | None: + """ + 获取语音消息内容 + Parameters: + websocket: WebSocket连接对象 + file: 文件名 + file_id: 文件ID + Returns: + dict: 返回的语音消息详情 + """ + logger.debug("获取语音消息详情中") + request_uuid = str(uuid.uuid4()) + payload = json.dumps( + { + "action": "get_record", + "params": {"file": file, "file_id": file_id, "out_format": "wav"}, + "echo": request_uuid, + } + ) + try: + await websocket.send(payload) + response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 + except TimeoutError: + logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}") + return None + except Exception as e: + logger.error(f"获取语音消息详情失败: {e}") + return None + logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长 + return response.get("data") + + +async def read_ban_list( + websocket: Server.ServerConnection, +) -> Tuple[List[BanUser], List[BanUser]]: + """ + 从根目录下的data文件夹中的文件读取禁言列表。 + 同时自动更新已经失效禁言 + Returns: + Tuple[ + 一个仍在禁言中的用户的BanUser列表, + 一个已经自然解除禁言的用户的BanUser列表, + 一个仍在全体禁言中的群的BanUser列表, + 一个已经自然解除全体禁言的群的BanUser列表, + ] + """ + try: + ban_list = db_manager.get_ban_records() + lifted_list: List[BanUser] = [] + logger.info("已经读取禁言列表") + for ban_record in ban_list: + if ban_record.user_id == 0: + fetched_group_info = await get_group_info(websocket, ban_record.group_id) + if fetched_group_info is None: + logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除") + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + group_all_shut: int = fetched_group_info.get("group_all_shut") + if group_all_shut == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + else: + fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id) + if fetched_member_info is None: + logger.warning( + f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除" + ) + lifted_list.append(ban_record) + ban_list.remove(ban_record) + continue + lift_ban_time: int = fetched_member_info.get("shut_up_timestamp") + if lift_ban_time == 0: + lifted_list.append(ban_record) + ban_list.remove(ban_record) + else: + ban_record.lift_time = lift_ban_time + db_manager.update_ban_record(ban_list) + return ban_list, lifted_list + except Exception as e: + logger.error(f"读取禁言列表失败: {e}") + return [], [] + + +def save_ban_record(list: List[BanUser]): + return db_manager.update_ban_record(list) diff --git a/template/template_config.toml b/template/template_config.toml index 1d0d830..2f786a4 100644 --- a/template/template_config.toml +++ b/template/template_config.toml @@ -1,30 +1,41 @@ -[Nickname] # 现在没用 +[inner] +version = "0.1.3" # 版本号 +# 请勿修改版本号,除非你知道自己在做什么 + +[nickname] # 现在没用 nickname = "" -[Napcat_Server] # Napcat连接的ws服务设置 -host = "localhost" # Napcat设定的主机地址 -port = 8095 # Napcat设定的端口 -heartbeat = 30 # 与Napcat设置的心跳相同(按秒计) +[napcat_server] # Napcat连接的ws服务设置 +host = "localhost" # Napcat设定的主机地址 +port = 8095 # Napcat设定的端口 +token = "" # Napcat设定的访问令牌,若无则留空 +heartbeat_interval = 30 # 与Napcat设置的心跳相同(按秒计) -[MaiBot_Server] # 连接麦麦的ws服务设置 -platform_name = "qq" # 标识adapter的名称(必填) -host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 -port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 +[maibot_server] # 连接麦麦的ws服务设置 +host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 +port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 +enable_api_server = false # 是否启用API-Server模式连接 +base_url = "ws://127.0.0.1:18095/ws" # API-Server连接地址 (ws://ip:port/path),仅在enable_api_server为true时使用 +api_key = "maibot" # API Key (仅在enable_api_server为true时使用) -[Chat] # 黑白名单功能 +[chat] # 黑白名单功能 group_list_type = "whitelist" # 群组名单类型,可选为:whitelist, blacklist -group_list = [] # 群组名单 +group_list = [] # 群组名单 # 当group_list_type为whitelist时,只有群组名单中的群组可以聊天 # 当group_list_type为blacklist时,群组名单中的任何群组无法聊天 private_list_type = "whitelist" # 私聊名单类型,可选为:whitelist, blacklist -private_list = [] # 私聊名单 +private_list = [] # 私聊名单 # 当private_list_type为whitelist时,只有私聊名单中的用户可以聊天 # 当private_list_type为blacklist时,私聊名单中的任何用户无法聊天 -ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天) +ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天) +ban_qq_bot = false # 是否屏蔽QQ官方机器人 enable_poke = true # 是否启用戳一戳功能 -[Voice] # 发送语音设置 +[voice] # 发送语音设置 use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) -[Debug] -level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR) +[forward] # 转发消息处理设置 +image_threshold = 3 # 图片数量阈值:转发消息中图片数量超过此值时使用占位符(避免麦麦VLM处理卡死) + +[debug] +level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)