diff --git a/.gitignore b/.gitignore index cd52268..f01164f 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,4 @@ run.sh .env .vscode aerich_config.py +.DS_Store \ No newline at end of file diff --git a/U1/message.py b/U1/message.py deleted file mode 100644 index b081602..0000000 --- a/U1/message.py +++ /dev/null @@ -1,38 +0,0 @@ -from io import BytesIO -from pathlib import Path - -from nonebot.adapters.onebot.v11.message import Message, MessageSegment - - -class MessageBuilder(Message): - def at(self, user_id: int | str) -> "MessageBuilder": - self.append(MessageSegment.at(user_id)) - return self - - def face(self, id_: int) -> "MessageBuilder": - self.append(MessageSegment.face(id_)) - return self - - def image( - self, - file: str | bytes | BytesIO | Path, - type_: str | None = None, - cache: bool = True, - proxy: bool = True, - timeout: int | None = None, - ) -> "MessageBuilder": - self.append(MessageSegment.image(file, type_, cache, proxy, timeout)) - return self - - def reply(self, id_: int) -> "MessageBuilder": - self.append(MessageSegment.reply(id_)) - return self - - def text(self, text: str) -> "MessageBuilder": - if self[-1].type == "text": - text = "\n" + text - self.append(MessageSegment.text(text)) - return self - - def done(self) -> str: - return "".join(map(str, self)) diff --git a/U1/model.py b/U1/model.py index 72fe29d..618c0c6 100644 --- a/U1/model.py +++ b/U1/model.py @@ -16,5 +16,5 @@ class Channel(Model): permissions = fields.TextField(null=True, default=None) createdAt = fields.DatetimeField(null=True, default=None) - class Meta: + class Meta: # type: ignore unique_together = ("platform", "flag") diff --git a/U1/utils/permission.py b/U1/utils/permission.py new file mode 100644 index 0000000..6f75fd4 --- /dev/null +++ b/U1/utils/permission.py @@ -0,0 +1,35 @@ +from nonebot.adapters.milky.event import GroupMessageEvent +from nonebot.adapters.milky.model.common import Member +from nonebot.permission import Permission + + +async def _group_admin(event: GroupMessageEvent) -> bool: + if isinstance(event.data.sender, Member): + return event.data.sender.role == "admin" + raise TypeError( + f"Expected Member, got {type(event.data.sender)}: {event.data.sender}" + ) + + +async def _group_owner(event: GroupMessageEvent) -> bool: + if isinstance(event.data.sender, Member): + return event.data.sender.role == "owner" + raise TypeError( + f"Expected Member, got {type(event.data.sender)}: {event.data.sender}" + ) + + +async def _group_member(event: GroupMessageEvent) -> bool: + if isinstance(event.data.sender, Member): + return event.data.sender.role == "member" + raise TypeError( + f"Expected Member, got {type(event.data.sender)}: {event.data.sender}" + ) + + +GROUP_MEMBER: Permission = Permission(_group_member) +"""匹配任意群员群聊消息类型事件""" +GROUP_ADMIN: Permission = Permission(_group_admin) +"""匹配任意群管理员群聊消息类型事件""" +GROUP_OWNER: Permission = Permission(_group_owner) +"""匹配任意群主群聊消息类型事件""" diff --git a/U1/utils/token_bucket.py b/U1/utils/token_bucket.py new file mode 100644 index 0000000..2541799 --- /dev/null +++ b/U1/utils/token_bucket.py @@ -0,0 +1,86 @@ +from asyncio import get_running_loop +from collections import defaultdict +from enum import IntEnum, auto + +from nonebot.adapters.milky.event import MessageEvent +from nonebot.matcher import Matcher +from nonebot.params import Depends + + +class CooldownIsolateLevel(IntEnum): + """命令冷却的隔离级别""" + + GLOBAL = auto() + """全局使用同一个冷却计时""" + GROUP = auto() + """群组内使用同一个冷却计时""" + USER = auto() + """按用户使用同一个冷却计时""" + GROUP_USER = auto() + """群组内每个用户使用同一个冷却计时""" + + +def Cooldown( + cooldown: float = 5, + *, + prompt: str | None = None, + isolate_level: CooldownIsolateLevel = CooldownIsolateLevel.USER, + parallel: int = 1, +) -> None: + """依赖注入形式的事件冷却 + + 用法: + ```python + @matcher.handle(parameterless=[Cooldown(cooldown=11.4514, ...)]) + async def handle_command(matcher: Matcher, message: Message): + ... + ``` + + 参数: + cooldown: 冷却间隔 + prompt: 当触发冷却时发送给用户的提示消息 + isolate_level: 事件冷却的隔离级别, 参考 `CooldownIsolateLevel` + parallel: 并行执行的命令数量 + """ + if not isinstance(isolate_level, CooldownIsolateLevel): + raise ValueError( + f"invalid isolate level: {isolate_level!r}, " + "isolate level must use provided enumerate value." + ) + running: defaultdict[str, int] = defaultdict(lambda: parallel) + + def increase(key: str, value: int = 1): + running[key] += value + if running[key] >= parallel: + del running[key] + return + + async def dependency(matcher: Matcher, event: MessageEvent): + loop = get_running_loop() + + + if isolate_level is CooldownIsolateLevel.GROUP: + if event.data.message_scene == "group": + key = str(event.data.peer_id) + else: + raise ValueError( + "isolate_level is set to GROUP, but event is not a group message." + ) + elif isolate_level is CooldownIsolateLevel.USER: + key = event.get_user_id() + elif isolate_level is CooldownIsolateLevel.GROUP_USER: + key = event.get_session_id() + else: + key = CooldownIsolateLevel.GLOBAL.name + + if not key: + return + + if running[key] <= 0: + await matcher.finish(prompt) + else: + running[key] -= 1 + loop.call_later(cooldown, lambda: increase(key)) + return + + return Depends(dependency) diff --git a/U1/utils/utils.py b/U1/utils/utils.py new file mode 100644 index 0000000..0874625 --- /dev/null +++ b/U1/utils/utils.py @@ -0,0 +1,16 @@ +from nonebot.adapters.milky import Message + +def extract_image_urls(message: Message) -> list[str]: + """提取消息中的图片链接 + + 参数: + message: 消息对象 + + 返回: + 图片链接列表 + """ + return [ + segment.data["url"] + for segment in message + if (segment.type == "image") and ("url" in segment.data) + ] diff --git a/bot.py b/bot.py index f8dddea..8d14132 100644 --- a/bot.py +++ b/bot.py @@ -45,14 +45,14 @@ def default_filter(record: "Record"): ) -from nonebot.adapters.onebot.v11 import Adapter as ONEBOT_V11Adapter -from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent +from nonebot.adapters.milky import Adapter as MilkyAdapter +from nonebot.adapters.milky import Bot +from nonebot.adapters.milky.event import GroupMessageEvent from nonebot.exception import IgnoredException nonebot.init() -app = nonebot.get_asgi() driver = nonebot.get_driver() -driver.register_adapter(ONEBOT_V11Adapter) +driver.register_adapter(MilkyAdapter) @driver.on_startup @@ -86,10 +86,10 @@ async def _(bot: Bot, event: GroupMessageEvent): if event.to_me: return - channel = await get_channel(str(event.group_id)) + channel = await get_channel(str(event.data.peer_id)) if channel is None: for _ in range(3): - channel = await get_channel(str(event.group_id)) + channel = await get_channel(str(event.data.peer_id)) if channel is not None: break # 重试直到找到频道 await asyncio.sleep(0.5) diff --git a/migrations/versions/adc8daa3ce44_migrate_cq_codes_to_img_base64.py b/migrations/versions/adc8daa3ce44_migrate_cq_codes_to_img_base64.py new file mode 100644 index 0000000..2bd76ab --- /dev/null +++ b/migrations/versions/adc8daa3ce44_migrate_cq_codes_to_img_base64.py @@ -0,0 +1,100 @@ +"""migrate_cq_codes_to_img_base64 + +迁移 ID: adc8daa3ce44 +父迁移: 782cb0785d08 +创建时间: 2025-07-11 01:05:56.515358 + +""" + +from __future__ import annotations + +import json +import re +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +revision: str = "adc8daa3ce44" +down_revision: str | Sequence[str] | None = "782cb0785d08" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def extract_base64_from_cq_codes(text: str) -> list[str]: + """ + 从文本中提取所有 [CQ:image,file=base64://...] 中的 base64 数据。 + """ + base64_images = [] + # 匹配 [CQ:image,file=base64://xxxxx] 格式 + cq_pattern = r"\[CQ:image,file=base64://([^]]+)\]" + matches = re.findall(cq_pattern, text) + base64_images.extend(matches) + return base64_images + + +def remove_cq_codes(text: str) -> str: + """ + 移除文本中的所有 CQ 码。 + """ + # 移除所有 [CQ:...] 格式的代码 + return re.sub(r"\[CQ:[^\]]+\]", "", text).strip() + + +def upgrade(name: str = "") -> None: + if name: + return + + # 检查 img_base64 列是否存在 + connection = op.get_bind() + inspector = sa.inspect(connection) + columns = [col["name"] for col in inspector.get_columns("cave_models")] + + # 如果字段不存在,则添加它 + if "img_base64" not in columns: + with op.batch_alter_table("cave_models", schema=None) as batch_op: + batch_op.add_column(sa.Column("img_base64", sa.JSON(), nullable=False)) + + # 数据迁移:处理现有的 CQ 码数据 + # 查询所有包含 CQ 码的记录 + result = connection.execute( + sa.text("SELECT id, details FROM cave_models WHERE details LIKE '%[CQ:image%'") + ) + + for row in result: + record_id, details = row + + # 提取 base64 数据 + base64_images = extract_base64_from_cq_codes(details) + + # 移除 CQ 码 + clean_details = remove_cq_codes(details) + + # 更新记录 + connection.execute( + sa.text( + "UPDATE cave_models SET details = :details, img_base64 = :img_base64 WHERE id = :id" + ), + { + "details": clean_details, + "img_base64": json.dumps(base64_images), + "id": record_id, + }, + ) + + # 为没有图片的记录设置空数组 + connection.execute( + sa.text("UPDATE cave_models SET img_base64 = '[]' WHERE img_base64 IS NULL OR img_base64 = ''") + ) + + connection.commit() + + +def downgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("cave_models", schema=None) as batch_op: + batch_op.drop_column("img_base64") + + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index aecdc21..a0795e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,6 @@ plugins = [ "nonebot_plugin_orm", ] plugin_dirs = ["src/plugins"] -builtin_plugins = [] -[[tool.nonebot.adapters]] -name = "OneBot V11" -module_name = "nonebot.adapters.onebot.v11" [tool.ruff] line-length = 88 @@ -22,16 +18,14 @@ select = ["F", "W", "E", "UP", "ASYNC", "C4", "T10", "PYI", "PT", "Q", "RUF"] ignore = ["E402", "E501", "UP037", "RUF001", "RUF002", "RUF003"] [tool.pyright] -typeCheckingMode = "basic" +typeCheckingMode = "standard" [project] requires-python = "<4.0,>=3.10" dependencies = [ - "nonebot2[fastapi,aiohttp,websockets]>=2.4.1", "nonebot-plugin-orm[mysql]>=0.8.1", "nonebot-plugin-apscheduler>=0.5.0", "nonebot-plugin-htmlrender==0.4.0", - "nonebot-adapter-onebot>=2.4.6", "nonebot-plugin-userinfo>=0.2.6", "aiohttp>=3.10.0", "ujson>=5.10.0", @@ -50,6 +44,10 @@ dependencies = [ "openai>=1.76.0", "tomlkit>=0.13.2", "nb-cli>=1.4.2", + "nonebot-adapter-milky>=0.4.1", + "pysilk>=0.0.1", + "qrcode>=8.2", + "nonebot2[httpx,websockets]>=2.4.2", ] name = "u1bot" version = "0.1.0" diff --git a/src/plugins/Menu/__init__.py b/src/plugins/Menu/__init__.py index d8103f5..ec8361b 100644 --- a/src/plugins/Menu/__init__.py +++ b/src/plugins/Menu/__init__.py @@ -5,9 +5,9 @@ import ujson as json from nonebot import on_command -from nonebot.adapters.onebot.v11 import Bot as V11Bot -from nonebot.adapters.onebot.v11 import Message as V11Msg -from nonebot.adapters.onebot.v11 import MessageSegment as V11MsgSeg +from nonebot.adapters.milky import Bot +from nonebot.adapters.milky import Message +from nonebot.adapters.milky import MessageSegment from nonebot.log import logger from nonebot.matcher import Matcher from nonebot.params import CommandArg @@ -39,11 +39,11 @@ async def get_reply(name: str): @menu.handle() -async def _(bot: V11Bot, matcher: Matcher, arg: V11Msg = CommandArg()): +async def _(bot: Bot, matcher: Matcher, arg: Message = CommandArg()): msg = arg.extract_plain_text().strip() if not msg: # 参数为空,主菜单 - await matcher.finish(V11MsgSeg.image(base64_str)) + await matcher.finish(MessageSegment.image(base64_str)) match_result = re.match(r"^(?P.*?)$", msg) if not match_result: @@ -76,7 +76,7 @@ async def _(bot: V11Bot, matcher: Matcher, arg: V11Msg = CommandArg()): if isinstance(result, str) and not result.startswith("base64://"): await matcher.finish(result) - if isinstance(bot, V11Bot): - await matcher.finish(V11MsgSeg.image(result)) + if isinstance(bot, Bot): + await matcher.finish(MessageSegment.image(result)) else: raise NotImplementedError diff --git a/src/plugins/addfirend/__init__.py b/src/plugins/addfirend/__init__.py index 176345f..c2880fb 100644 --- a/src/plugins/addfirend/__init__.py +++ b/src/plugins/addfirend/__init__.py @@ -2,8 +2,10 @@ import time from nonebot import get_bots, get_driver, logger, on_request -from nonebot.adapters.onebot.v11 import ( +from nonebot.adapters.milky import ( Bot, +) +from nonebot.adapters.milky.event import ( FriendRequestEvent, GroupRequestEvent, RequestEvent, @@ -28,7 +30,7 @@ async def check_bot_priority_in_group( 返回: (是否应该进群, 原因) """ # 跳过免疫群组 - if group_id == 966016220 or group_id == 713478803: + if group_id in {966016220, 713478803}: return True, "免疫群组,直接进入" bots = get_bots() @@ -93,26 +95,23 @@ def format_time(_time: int) -> str: @addfriend.handle() async def _(bot: Bot, event: RequestEvent): if isinstance(event, GroupRequestEvent): - if event.sub_type != "invite": + if event.data.request_type != "invite": return # 获取好友列表对比 friend_list = await bot.get_friend_list() - if event.user_id not in [friend["user_id"] for friend in friend_list]: + if event.data.initiator_id not in [friend.user_id for friend in friend_list]: try: - await bot.send_private_msg( - user_id=event.user_id, + await bot.send_private_message( + user_id=event.data.initiator_id, message="我们还是先成为好友再带我去别的地方吧~", ) except Exception as e: - logger.warning(f"发送私信失败 (用户: {event.user_id}): {e}") + logger.warning(f"发送私信失败 (用户: {event.data.initiator_id}): {e}") return - nickname = (await bot.get_stranger_info(user_id=event.user_id, no_cache=True))[ - "nickname" - ] try: group_info = await bot.get_group_info( - group_id=event.group_id, no_cache=True + group_id=event.data.group_id, no_cache=True ) except Exception: logger.exception("获取群信息失败") @@ -121,14 +120,14 @@ async def _(bot: Bot, event: RequestEvent): # 检查机器人优先级 current_bot_id = int(bot.self_id) can_join, priority_reason = await check_bot_priority_in_group( - event.group_id, current_bot_id + event.data.group_id, current_bot_id ) # 基本条件检查:人数 > 15 或者是免疫群组 basic_approve = ( - group_info["member_count"] > 15 - or event.group_id == 966016220 - or event.group_id == 713478803 + group_info.member_count > 15 + or event.data.group_id == 966016220 + or event.data.group_id == 713478803 ) # 最终决定:基本条件 AND 优先级检查 @@ -137,13 +136,11 @@ async def _(bot: Bot, event: RequestEvent): # 构建消息 msg = ( "⚠收到一条拉群邀请:\n" - f"flag: {event.flag}\n" - f"user: {event.user_id}\n" - f"name: {nickname}\n" - f"group: {event.group_id}\n" - f"name: {group_info['group_name']}\n" + f"user: {event.data.initiator_id}\n" + f"group: {event.data.group_id}\n" + f"name: {group_info.name}\n" f"time: {format_time(event.time)}\n" - f"人数: {group_info['member_count']}\n" + f"人数: {group_info.member_count}\n" f"机器人优先级检查: {priority_reason}\n" f"自动同意/拒绝: {approve}\n" ) @@ -157,45 +154,43 @@ async def _(bot: Bot, event: RequestEvent): reasons.append(f"优先级冲突 ({priority_reason})") msg += f"拒绝原因: {', '.join(reasons)}\n" - msg += f"验证信息:\n{event.comment}" + msg += f"验证信息:\n{event.data.comment}" # 设置群邀请结果 - await bot.set_group_add_request( - flag=event.flag, sub_type="invite", approve=approve - ) + await bot.accept_request(request_id=event.data.request_id) # 如果拒绝进群,发送说明消息 if not approve: try: if not can_join: - await bot.send_private_msg( - user_id=event.user_id, + await bot.send_private_message( + user_id=event.data.initiator_id, message=f"抱歉,由于群内已有其他机器人且优先级更高,无法进入该群。原因: {priority_reason}", ) else: - await bot.send_private_msg( - user_id=event.user_id, + await bot.send_private_message( + user_id=event.data.initiator_id, message="由于机器人的群数量过多,对新的群要求人数超过15人以上,请见谅!", ) except Exception as e: - logger.warning(f"发送拒绝私信失败 (用户: {event.user_id}): {e}") + logger.warning( + f"发送拒绝私信失败 (用户: {event.data.initiator_id}): {e}" + ) elif isinstance(event, FriendRequestEvent): - nickname = (await bot.get_stranger_info(user_id=event.user_id, no_cache=True))[ - "nickname" - ] approve = True msg = ( "⚠收到一条好友请求:\n" - f"flag: {event.flag}\n" - f"user: {event.user_id}\n" - f"name: {nickname}\n" + f"user: {event.data.initiator_id}\n" f"time: {format_time(event.time)}\n" f"自动同意/拒绝: {approve}\n" f"验证信息:\n" - f"{event.comment}" + f"{event.data.comment}" ) - await bot.set_friend_add_request(flag=event.flag, approve=approve) + if approve: + await bot.accept_request(request_id=event.data.request_id) + else: + await bot.reject_request(request_id=event.data.request_id) else: return for super_id in SUPERUSER_list: - await bot.send_private_msg(user_id=int(super_id), message=msg) + await bot.send_private_message(user_id=int(super_id), message=msg) diff --git a/src/plugins/cave/__init__.py b/src/plugins/cave/__init__.py index c3f2a9e..fd4de07 100644 --- a/src/plugins/cave/__init__.py +++ b/src/plugins/cave/__init__.py @@ -1,23 +1,25 @@ import random from nonebot import get_driver, logger, on_command -from nonebot.adapters.onebot.v11 import ( +from nonebot.adapters.milky import ( Bot, - GroupMessageEvent, Message, MessageEvent, - PrivateMessageEvent, + MessageSegment, ) -from nonebot.adapters.onebot.v11.helpers import extract_image_urls +from nonebot.adapters.milky.event import FriendMessageEvent +from nonebot.adapters.milky.message import OutgoingForwardedMessage from nonebot.params import CommandArg from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata from nonebot_plugin_orm import get_session from sqlalchemy import delete, select +from U1.utils.utils import extract_image_urls + from ..coin.api import subtract_coin from .models import cave_models -from .tool import is_image_message +from .tool import process_image_message nickname_list = list(get_driver().config.nickname) Bot_NICKNAME = nickname_list[0] if nickname_list else "bot" @@ -123,7 +125,7 @@ async def condition(event: MessageEvent, key: str) -> tuple[bool, str | None]: urllist = extract_image_urls(event.get_message()) if len(urllist) > 1: return False, "呃,投稿只能包含一张图片诶~\n再斟酌一下你的投稿内容吧~" - if not isinstance(event, PrivateMessageEvent): + if not isinstance(event, FriendMessageEvent): return False, "还是请来私聊我投稿罢~" if not key: return ( @@ -137,76 +139,105 @@ async def condition(event: MessageEvent, key: str) -> tuple[bool, str | None]: @cave_add.handle() async def _(bot: Bot, event: MessageEvent): - is_image = await is_image_message(event) - details = is_image[1] if is_image[0] else str(event.get_message()) + has_image, clean_text, base64_images = await process_image_message(event) + details = clean_text if has_image else str(event.get_message()) details = details.replace("投稿", "", 1).strip() result = await condition(event, details) if result[0] is False: # 审核 await cave_add.finish(result[1]) # 扣除次元币 - user_id = str(event.user_id) - success, remaining_coin = await subtract_coin(user_id, 200) + user_id = event.data.sender_id + success, remaining_coin = await subtract_coin(str(user_id), 200) if not success: await cave_add.finish( f"投稿需要消耗 200 次元币,您当前只有 {remaining_coin:.1f} 次元币,余额不足!" ) async with get_session() as session: - caves = cave_models(details=details, user_id=event.user_id) + caves = cave_models(details=details, user_id=user_id, img_base64=base64_images) session.add(caves) await session.commit() await session.refresh(caves) - result = f"[投稿成功 #{caves.id}]\n" - result += f"{caves.details}\n" - result += "————————————\n" - result += f"投稿时间: {caves.time.strftime('%Y-%m-%d %H:%M:%S')}\n" - result += f"消耗次元币: 200 | 余额: {remaining_coin:.1f}" + # 构建包含图片的消息序列 + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[投稿成功 #{caves.id}]\n{caves.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if caves.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in caves.img_base64 + ] + ) + + # 分割线和信息部分 + info_part = f"\n————————————\n投稿时间: {caves.time.strftime('%Y-%m-%d %H:%M:%S')}\n消耗次元币: 200 | 余额: {remaining_coin:.1f}" + img_seq.append(MessageSegment.text(content=info_part)) + for i in SUPERUSER_list: - await bot.send_private_msg( + await bot.send_private_message( user_id=int(i), message=Message(f"来自用户{event.get_user_id()}\n{result}"), ) - await cave_add.finish(Message(f"{result}")) + await cave_add.finish(message=Message(img_seq)) @cave_am_add.handle() async def _(bot: Bot, event: MessageEvent): "匿名发布回声洞" - is_image = await is_image_message(event) - details = is_image[1] if is_image[0] else str(event.get_message()) + has_image, clean_text, base64_images = await process_image_message(event) + details = clean_text if has_image else str(event.get_message()) details = details.replace("匿名投稿", "", 1).strip() result = await condition(event, details) if result[0] is False: # 审核 await cave_am_add.finish(result[1]) # 扣除次元币 - user_id = str(event.user_id) - success, remaining_coin = await subtract_coin(user_id, 400) + user_id = event.data.sender_id + success, remaining_coin = await subtract_coin(str(user_id), 400) if not success: await cave_am_add.finish( f"匿名投稿需要消耗 400 次元币,您当前只有 {remaining_coin:.1f} 次元币,余额不足!" ) async with get_session() as session: - caves = cave_models(details=details, user_id=event.user_id, anonymous=True) + caves = cave_models( + details=details, user_id=user_id, anonymous=True, img_base64=base64_images + ) session.add(caves) await session.commit() - await session.refresh(caves) + await session.refresh(caves) # 构建包含图片的消息序列 + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[匿名投稿成功 #{caves.id}]\n{caves.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if caves.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in caves.img_base64 + ] + ) + + # 分割线和信息部分 + info_part = f"\n————————————\n投稿时间: {caves.time.strftime('%Y-%m-%d %H:%M:%S')}\n匿名投稿会保存用户信息但其他用户无法看到作者\n消耗次元币: 400 | 余额: {remaining_coin:.1f}" + img_seq.append(MessageSegment.text(content=info_part)) - result = f"[匿名投稿成功 #{caves.id}]\n" - result += f"{caves.details}\n" - result += "————————————\n" - result += f"投稿时间: {caves.time.strftime('%Y-%m-%d %H:%M:%S')}\n" - result += "匿名投稿会保存用户信息但其他用户无法看到作者\n" - result += f"消耗次元币: 400 | 余额: {remaining_coin:.1f}" for i in SUPERUSER_list: - await bot.send_private_msg( + await bot.send_private_message( user_id=int(i), message=Message(f"来自用户{event.get_user_id()}\n{result}"), ) - await cave_am_add.finish(Message(f"{result}")) + await cave_am_add.finish(message=Message(img_seq)) @cave_del.handle() @@ -230,9 +261,9 @@ async def _(bot: Bot, event: MessageEvent): await cave_del.finish("没有这个序号的投稿") # 判断是否是超级用户或者是投稿人 - if str(event.user_id) in SUPERUSER_list: + if str(event.data.sender_id) in SUPERUSER_list: try: - await bot.send_private_msg( + await bot.send_private_message( user_id=data.user_id, message=Message( f"您的投稿 #{key} 已被管理员删除\n内容: {data.details}\n删除原因: {reason}" @@ -243,24 +274,51 @@ async def _(bot: Bot, event: MessageEvent): f"回声洞删除投稿私聊通知失败,投稿人 id:{data.user_id}" ) await cave_del.send("删除失败,私聊通知失败") - elif event.user_id == data.user_id: - result_content = data.details + elif event.data.sender_id == data.user_id: await session.delete(data) await session.commit() - await cave_del.finish( - Message(f"[删除成功] 编号 {key} 的投稿已删除\n内容: {result_content}") - ) + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[删除成功] 编号 {key} 的投稿已删除\n内容: {data.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if data.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in data.img_base64 + ] + ) + + await cave_del.finish(message=Message(img_seq)) else: await cave_del.finish("您没有权限删除此投稿") - result_content = data.details await session.delete(data) await session.commit() - await cave_del.finish( - Message( - f"[删除成功] 编号 {key} 的投稿已删除\n内容: {result_content}\n删除原因: {reason}" + + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[删除成功] 编号 {key} 的投稿已删除\n内容: {data.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if data.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in data.img_base64 + ] ) - ) + + # 分割线和删除原因部分 + reason_part = f"\n————————————\n删除原因: {reason}" + img_seq.append(MessageSegment.text(content=reason_part)) + + await cave_main.finish(message=Message(img_seq)) @cave_main.handle() @@ -280,13 +338,27 @@ async def _(args: Message = CommandArg()): displayname = ( "匿名用户" if random_cave.anonymous else f"用户{random_cave.user_id}" ) - result = f"[回声洞 #{random_cave.id}]\n" - result += f"{random_cave.details}\n" - result += "————————————\n" - result += f"投稿人:{displayname}\n" - result += f"时间:{random_cave.time.strftime('%Y-%m-%d %H:%M:%S')}\n" - result += "\n私聊机器人可以投稿:\n投稿 [内容] | 匿名投稿 [内容]" - await cave_main.finish(Message(result)) + + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[回声洞 #{random_cave.id}]\n{random_cave.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if random_cave.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in random_cave.img_base64 + ] + ) + + # 分割线和信息部分 + info_part = f"\n————————————\n投稿人:{displayname}\n时间:{random_cave.time.strftime('%Y-%m-%d %H:%M:%S')}\n\n私聊机器人可以投稿:\n投稿 [内容] | 匿名投稿 [内容]" + img_seq.append(MessageSegment.text(content=info_part)) + + await cave_main.finish(message=Message(img_seq)) else: # 验证输入是否为有效的数字 try: @@ -303,68 +375,80 @@ async def _(args: Message = CommandArg()): # 判断是否是匿名 displayname = "匿名用户" if cave.anonymous else f"用户{cave.user_id}" - result = f"[回声洞 #{cave.id}]\n" - result += f"{cave.details}\n" - result += "————————————\n" - result += f"投稿人: {displayname}\n" - result += f"时间: {cave.time.strftime('%Y-%m-%d %H:%M:%S')}\n" - result += "\n私聊机器人可以投稿:\n投稿 [内容] | 匿名投稿 [内容]" - await cave_main.finish(Message(result)) + + img_seq: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[回声洞 #{cave.id}]\n{cave.details}" + img_seq.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if cave.img_base64: + img_seq.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in cave.img_base64 + ] + ) + + # 分割线和信息部分 + info_part = f"\n————————————\n投稿人: {displayname}\n时间: {cave.time.strftime('%Y-%m-%d %H:%M:%S')}\n\n私聊机器人可以投稿:\n投稿 [内容] | 匿名投稿 [内容]" + img_seq.append(MessageSegment.text(content=info_part)) + + await cave_main.finish(message=Message(img_seq)) @cave_history.handle() async def _(bot: Bot, event: MessageEvent): # 查询 userid 写所有数据 async with get_session() as session: - stmt = select(cave_models).where(cave_models.user_id == event.user_id) + stmt = select(cave_models).where(cave_models.user_id == event.data.sender_id) result = await session.execute(stmt) all_caves = result.scalars().all() - msg_list = [ - "您的回声洞投稿记录:", - *[ - Message( - f"[编号 #{i.id}]\n" - f"{i.details}\n" - f"————————————\n" - f"投稿时间: {i.time.strftime('%Y-%m-%d %H:%M:%S')}" + if not all_caves: + await cave_history.finish("您还没有任何投稿记录") + + # 构造转发消息 + messages: list[list[MessageSegment]] = [] + messages.append([MessageSegment.text("您的回声洞投稿记录:")]) + + # 添加每个投稿记录 + for i in all_caves: + # 为每个投稿构建消息段列表 + msg_segments: list[MessageSegment] = [] + + # 投稿内容部分 + content_part = f"[编号 #{i.id}]\n{i.details}" + msg_segments.append(MessageSegment.text(content=content_part)) + + # 图片部分 + if i.img_base64: + msg_segments.extend( + [ + MessageSegment.image(base64=img_base64) + for img_base64 in i.img_base64 + ] ) - for i in all_caves - ], - ] - await send_forward_msg(bot, event, Bot_NICKNAME, bot.self_id, msg_list) - - -async def send_forward_msg( - bot: Bot, - event: MessageEvent, - name: str, - uin: str, - msgs: list, -) -> dict: - """ - 发送转发消息的异步函数。 - 参数: - bot (Bot): 机器人实例 - event (MessageEvent): 消息事件 - name (str): 转发消息的名称 - uin (str): 转发消息的 UIN - msgs (list): 转发的消息列表 + # 分割线和信息部分 + info_part = ( + f"\n————————————\n投稿时间: {i.time.strftime('%Y-%m-%d %H:%M:%S')}" + ) + msg_segments.append(MessageSegment.text(content=info_part)) - 返回: - dict: API 调用结果 - """ + messages.append(msg_segments) - def to_json(msg: Message): - return {"type": "node", "data": {"name": name, "uin": uin, "content": msg}} + forward_msgs = [ + OutgoingForwardedMessage( + name=Bot_NICKNAME, + user_id=int(bot.self_id), + segments=msg_segments, + ) + for msg_segments in messages + ] - messages = [to_json(msg) for msg in msgs] - if isinstance(event, GroupMessageEvent): - return await bot.send_group_forward_msg( - group_id=event.group_id, messages=messages - ) - return await bot.send_private_forward_msg(user_id=event.user_id, messages=messages) + await cave_history.finish(MessageSegment.forward(forward_msgs)) def extract_deletion_reason(text): diff --git a/src/plugins/cave/models.py b/src/plugins/cave/models.py index 8989203..41dfa86 100644 --- a/src/plugins/cave/models.py +++ b/src/plugins/cave/models.py @@ -6,26 +6,18 @@ require("nonebot_plugin_orm") from nonebot_plugin_orm import Model -from sqlalchemy import BigInteger, Boolean, DateTime, Integer +from sqlalchemy import JSON, BigInteger, Boolean, DateTime, Integer from sqlalchemy.dialects.mysql import LONGTEXT -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm.properties import MappedColumn class cave_models(Model): - """ - Model representing the cave_models table in the database. - - Attributes: - id (int): The primary key of the cave model. - details (str): The details of the cave model. - user_id (int): The user ID associated with the cave model. - time (datetime): The timestamp when the cave model was created. - """ - __tablename__ = "cave_models" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - details: Mapped[str] = mapped_column(LONGTEXT) - user_id: Mapped[int] = mapped_column(BigInteger) - time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) - anonymous: Mapped[bool] = mapped_column(Boolean, default=False) + id: MappedColumn[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + details: MappedColumn[str] = mapped_column(LONGTEXT) + img_base64: MappedColumn[list[str]] = mapped_column(JSON, default=[]) + user_id: MappedColumn[int] = mapped_column(BigInteger) + time: MappedColumn[datetime] = mapped_column(DateTime, default=datetime.now) + anonymous: MappedColumn[bool] = mapped_column(Boolean, default=False) diff --git a/src/plugins/cave/tool.py b/src/plugins/cave/tool.py index 09f147d..9e3d473 100644 --- a/src/plugins/cave/tool.py +++ b/src/plugins/cave/tool.py @@ -3,8 +3,8 @@ import ssl import aiohttp -from nonebot.adapters.onebot.v11 import MessageEvent -from nonebot.adapters.onebot.v11.utils import unescape +from nonebot.adapters.milky import MessageEvent +from nonebot.adapters.milky.message import Image, IncomingImageData async def url_to_base64(image_url) -> str: @@ -24,9 +24,7 @@ def extract_image_url(message: str) -> str: - message (str): 消息文本。 Returns: - - tuple: 包含两个元素: - - is_image (bool): 是否找到图片 URL。 - - image_url (str): 图片 URL。 + - str: 图片 URL。 """ url_pattern = r"url=(https?[^,]+)" if image_match := re.search(url_pattern, message): @@ -39,71 +37,59 @@ def extract_image_url(message: str) -> str: return "" -def replace_cq_with_caption(text: str, base64_image: str) -> str: +def extract_base64_from_cq_codes(text: str) -> list[str]: """ - 将文本中的 [CQ:...] 标签替换为指定的描述。 + 从文本中提取所有 [CQ:image,file=base64://...] 中的 base64 数据。 参数: - - text: 包含 [CQ:...] 标签的原始文本 - - caption: 用于替换 [CQ:...] 的描述文本 + - text: 包含 CQ 码的文本 返回值: - - 替换后的文本 + - list[str]: 所有提取到的 base64 数据列表 """ - # 反转义 - text = unescape(text) - potential_matches = re.finditer(r"\[CQ:image", text) - result = [] - last_pos = 0 - replacement_template = f"[CQ:image,file=base64://{base64_image}]" - - for match in potential_matches: - start = match.start() - result.append(text[last_pos:start]) # 添加上次匹配结束到这次匹配开始的部分 - # 从匹配位置开始逐字符解析,寻找完整的 [CQ:image,...] - i = start - depth = 0 - while i < len(text): - if text[i] == "[": - depth += 1 - elif text[i] == "]": - depth -= 1 - if depth == 0: - # 匹配到完整的 [CQ:image,...] - result.append(replacement_template) - last_pos = i + 1 # 更新最后的结束位置 - break - i += 1 - else: - # 如果没能闭合,直接保留原始文本 - last_pos = start - - # 添加剩余未处理的部分 - result.append(text[last_pos:]) - return "".join(result) - - -async def is_image_message( - data: MessageEvent, is_cq_code: bool = False -) -> tuple[bool, str]: - if is_cq_code: - image_url = extract_image_url(str(data.message)) - return ( - ( - True, - replace_cq_with_caption( - str(data.message), await url_to_base64(image_url) - ), - ) - if image_url - else (False, "") - ) + base64_images = [] + # 匹配 [CQ:image,file=base64://xxxxx] 格式 + cq_pattern = r"\[CQ:image,file=base64://([^]]+)\]" + matches = re.findall(cq_pattern, text) + base64_images.extend(matches) + + return base64_images + + +def remove_cq_codes(text: str) -> str: + """ + 移除文本中的所有 CQ 码。 + + 参数: + - text: 包含 CQ 码的文本 + + 返回值: + - str: 移除 CQ 码后的纯文本 + """ + # 移除所有 [CQ:...] 格式的代码 + return re.sub(r"\[CQ:[^\]]+\]", "", text).strip() + + +async def process_image_message(data: MessageEvent) -> tuple[bool, str, list[str]]: + """ + 处理图片消息,返回是否包含图片、处理后的文本和图片base64数据。 + + 参数: + - data: 消息事件数据 + + 返回值: + - tuple: (是否包含图片, 处理后的文本, 图片base64数据列表) + """ + has_image = False + base64_images = [] for msg in data.message: - print(msg) - if msg.type == "image" and (image_url := msg.data.get("url", "")): - return True, replace_cq_with_caption( - str(data.message), await url_to_base64(image_url) - ) + if isinstance(msg, Image) and msg.data is IncomingImageData: + has_image = True + base64_data = await url_to_base64(msg.data["temp_url"]) + base64_images.append(base64_data) + + # 移除消息中的所有 CQ 码 + clean_text = remove_cq_codes(str(data.message)) - return False, "" + return has_image, clean_text, base64_images diff --git a/src/plugins/coin/__init__.py b/src/plugins/coin/__init__.py index 9fca8d4..3fd5b11 100644 --- a/src/plugins/coin/__init__.py +++ b/src/plugins/coin/__init__.py @@ -1,5 +1,5 @@ from nonebot import on_command -from nonebot.adapters.onebot.v11 import MessageEvent +from nonebot.adapters import Event from nonebot.params import CommandArg from nonebot_plugin_orm import get_session from sqlalchemy import select @@ -10,14 +10,13 @@ @coin.handle() -async def handle_coin(event: MessageEvent, args=CommandArg()): - user_id = str(event.user_id) +async def handle_coin(event: Event, args=CommandArg()): + user_id = event.get_user_id() async with get_session() as session: result = await session.execute( select(CoinRecord).where(CoinRecord.user_id == user_id) ) - record = result.scalar_one_or_none() - if record: + if record := result.scalar_one_or_none(): await coin.finish( f"你当前拥有金币:{record.coin}\n历史总金币:{record.count_coin}" ) diff --git a/src/plugins/coin/models.py b/src/plugins/coin/models.py index 9e1050f..111cb2b 100644 --- a/src/plugins/coin/models.py +++ b/src/plugins/coin/models.py @@ -1,12 +1,13 @@ from nonebot_plugin_orm import Model from sqlalchemy import Float, Integer, String -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm.properties import MappedColumn class CoinRecord(Model): __tablename__ = "coin_coinrecord" - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[str] = mapped_column(String(32)) - coin: Mapped[float] = mapped_column(Float, default=0) - count_coin: Mapped[float] = mapped_column(Float, default=0) + id: MappedColumn[int] = mapped_column(Integer, primary_key=True) + user_id: MappedColumn[str] = mapped_column(String(32)) + coin: MappedColumn[float] = mapped_column(Float, default=0) + count_coin: MappedColumn[float] = mapped_column(Float, default=0) diff --git a/src/plugins/fishing/__init__.py b/src/plugins/fishing/__init__.py index f311dbd..fb87fbb 100644 --- a/src/plugins/fishing/__init__.py +++ b/src/plugins/fishing/__init__.py @@ -2,22 +2,25 @@ import random from nonebot import get_driver, on_command, on_fullmatch -from nonebot.adapters import Event, Message -from nonebot.adapters.onebot.v11 import ( +from nonebot.adapters import Event +from nonebot.adapters.milky import ( Bot, - GroupMessageEvent, + Message, MessageEvent, - PrivateMessageEvent, + MessageSegment, ) -from nonebot.adapters.onebot.v11.helpers import ( - Cooldown, - CooldownIsolateLevel, -) -from nonebot.adapters.onebot.v11.permission import GROUP_ADMIN, GROUP_OWNER +from nonebot.adapters.milky.event import FriendMessageEvent, GroupMessageEvent +from nonebot.adapters.milky.message import OutgoingForwardedMessage from nonebot.params import CommandArg from nonebot.permission import SUPERUSER from nonebot.plugin import PluginMetadata +from U1.utils.permission import GROUP_ADMIN, GROUP_OWNER +from U1.utils.token_bucket import ( + Cooldown, + CooldownIsolateLevel, +) + from .config import Config, config from .data_source import ( choice, @@ -75,16 +78,21 @@ async def _update(event: Event): ) ] ) -async def _fishing(event: GroupMessageEvent | PrivateMessageEvent, bot: Bot): +async def _fishing(event: GroupMessageEvent | FriendMessageEvent, bot: Bot): """钓鱼""" if isinstance(event, GroupMessageEvent) and not await get_switch_fish(event): await fishing.finish("钓鱼在本群处于关闭状态,请看菜单重新打开") user_id = event.get_user_id() fish = await choice(user_id=user_id) - await bot.set_msg_emoji_like( - message_id=event.message_id, - emoji_id="127881", + + await fishing.send( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(f"* {Bot_NICKNAME} 正在钓鱼..."), + ], + ), ) fish_name = fish[0] @@ -101,14 +109,29 @@ async def _fishing(event: GroupMessageEvent | PrivateMessageEvent, bot: Bot): result = f"* 你钓到了一条 {get_quality(fish_name)} {fish_name},长度为 {fish_long}cm!" await save_fish(user_id, fish_name, fish_long) await asyncio.sleep(sleep_time) - await fishing.finish(result, reply_message=True) + await fishing.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(result), + ], + ), + ) @stats.handle() -async def _stats(event: Event): +async def _stats(event: MessageEvent): """统计信息""" user_id = event.get_user_id() - await stats.finish(await get_stats(user_id), reply_message=True) + + await stats.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(await get_stats(user_id)), + ], + ) + ) @backpack.handle() @@ -116,70 +139,79 @@ async def _backpack(bot: Bot, event: MessageEvent): """背包""" user_id = event.get_user_id() fmt = await get_backpack(user_id) - return ( - await backpack.send(fmt) - if isinstance(fmt, str) - else await send_forward_msg(bot, event, Bot_NICKNAME, bot.self_id, fmt) - ) + + if isinstance(fmt, str): + await backpack.finish(fmt) + else: + messages: list[MessageSegment] = [] + # 将每个品质的信息转换为消息段 + messages.extend(MessageSegment.text(quality_info) for quality_info in fmt) + # 创建转发消息 + forward_msg = [ + OutgoingForwardedMessage( + name=Bot_NICKNAME, + user_id=int(bot.self_id), + segments=[messages_seq], + ) + for messages_seq in messages + ] + + await backpack.finish(MessageSegment.forward(forward_msg)) @sell.handle() -async def _sell(event: Event, arg: Message = CommandArg()): +async def _sell(event: MessageEvent, arg: Message = CommandArg()): """卖鱼""" msg = arg.extract_plain_text() user_id = event.get_user_id() if msg == "": await sell.finish("请输入要卖出的鱼的名字,如:卖鱼 小鱼") if msg == "全部": - await sell.finish(await sell_all_fish(user_id), reply_message=True) + await sell.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(await sell_all_fish(user_id)), + ], + ) + ) if msg in fish_quality.keys(): # 判断是否是为品质 - await sell.finish(await sell_quality_fish(user_id, msg), reply_message=True) - await sell.finish(await sell_fish(user_id, msg), reply_message=True) + await sell.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(await sell_quality_fish(user_id, msg)), + ], + ) + ) + await sell.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(await sell_fish(user_id, msg)), + ], + ) + ) @balance.handle() -async def _balance(event: Event): +async def _balance(event: MessageEvent): """余额""" user_id = event.get_user_id() - await balance.finish(await get_balance(user_id), reply_message=True) + await balance.finish( + Message( + [ + MessageSegment.reply(event.data.message_seq), + MessageSegment.text(await get_balance(user_id)), + ], + ) + ) @switch.handle() -async def _switch(event: GroupMessageEvent | PrivateMessageEvent): +async def _switch(event: GroupMessageEvent | FriendMessageEvent): """钓鱼开关""" if await switch_fish(event): await switch.finish("钓鱼 已打开") else: await switch.finish("钓鱼 已关闭") - - -async def send_forward_msg( - bot: Bot, - event: MessageEvent, - name: str, - uin: str, - msgs: list, -) -> dict: - """ - 发送转发消息的异步函数。 - - 参数: - bot (Bot): 机器人实例 - event (MessageEvent): 消息事件 - name (str): 转发消息的名称 - uin (str): 转发消息的 UIN - msgs (list): 转发的消息列表 - - 返回: - dict: API 调用结果 - """ - - def to_json(msg: Message): - return {"type": "node", "data": {"name": name, "uin": uin, "content": msg}} - - messages = [to_json(msg) for msg in msgs] - if isinstance(event, GroupMessageEvent): - return await bot.send_group_forward_msg( - group_id=event.group_id, messages=messages - ) - return await bot.send_private_forward_msg(user_id=event.user_id, messages=messages) diff --git a/src/plugins/fishing/data_source.py b/src/plugins/fishing/data_source.py index d4705a6..3f601e1 100644 --- a/src/plugins/fishing/data_source.py +++ b/src/plugins/fishing/data_source.py @@ -2,7 +2,7 @@ import time import ujson as json -from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent +from nonebot.adapters.milky.event import FriendMessageEvent, GroupMessageEvent from nonebot_plugin_orm import get_session from sqlalchemy import select, update @@ -174,9 +174,7 @@ async def save_fish(user_id: str, fish_name: str, fish_long: int) -> None: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - record = result.scalar_one_or_none() - - if record: + if record := result.scalar_one_or_none(): loads_fishes: dict[str, list[int]] = json.loads(record.fishes) try: loads_fishes[fish_name].append(fish_long) @@ -186,7 +184,6 @@ async def save_fish(user_id: str, fish_name: str, fish_long: int) -> None: record.time = time_now + fishing_limit record.frequency += 1 record.fishes = dump_fishes - await session.commit() else: data = {fish_name: [fish_long]} dump_fishes = json.dumps(data) @@ -197,7 +194,7 @@ async def save_fish(user_id: str, fish_name: str, fish_long: int) -> None: fishes=dump_fishes, ) session.add(new_record) - await session.commit() + await session.commit() async def get_stats(user_id: str) -> str: @@ -205,9 +202,7 @@ async def get_stats(user_id: str) -> str: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - fishing_record = result.scalar_one_or_none() - - if fishing_record: + if fishing_record := result.scalar_one_or_none(): total_length = sum( sum(fish_long) for fish_long in json.loads(fishing_record.fishes).values() @@ -226,9 +221,7 @@ async def get_backpack(user_id: str) -> str | list: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - fishes_record = result.scalar_one_or_none() - - if fishes_record: + if fishes_record := result.scalar_one_or_none(): load_fishes = json.loads(fishes_record.fishes) if not load_fishes: return "你的背包里空无一物" @@ -266,9 +259,7 @@ async def sell_quality_fish(user_id: str, quality: str) -> str: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - fishes_record = result.scalar_one_or_none() - - if fishes_record: + if fishes_record := result.scalar_one_or_none(): load_fishes = json.loads(fishes_record.fishes) if not load_fishes: return "你的背包里空无一物" @@ -303,9 +294,7 @@ async def sell_all_fish(user_id: str) -> str: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - fishes_record = result.scalar_one_or_none() - - if fishes_record: + if fishes_record := result.scalar_one_or_none(): load_fishes = json.loads(fishes_record.fishes) if not load_fishes: return "你的背包里空无一物" @@ -340,9 +329,7 @@ async def sell_fish(user_id: str, fish_name: str) -> str: async with get_session() as session: stmt = select(FishingRecord).where(FishingRecord.user_id == user_id) result = await session.execute(stmt) - fishes_record = result.scalar_one_or_none() - - if fishes_record: + if fishes_record := result.scalar_one_or_none(): load_fishes = json.loads(fishes_record.fishes) if fish_name not in load_fishes: return "你的背包里没有这种鱼" @@ -370,34 +357,32 @@ async def get_balance(user_id: str) -> str: return f"你有 {coin} {fishing_coin_name}" if coin else "你什么也没有 :)" -async def switch_fish(event: GroupMessageEvent | PrivateMessageEvent) -> bool: +async def switch_fish(event: GroupMessageEvent | FriendMessageEvent) -> bool: """钓鱼开关切换,没有就创建""" - if isinstance(event, PrivateMessageEvent): + if isinstance(event, FriendMessageEvent): return True async with get_session() as session: - stmt = select(FishingSwitch).where(FishingSwitch.group_id == event.group_id) + stmt = select(FishingSwitch).where(FishingSwitch.group_id == event.data.peer_id) result = await session.execute(stmt) - switch = result.scalar_one_or_none() - - if switch: + if switch := result.scalar_one_or_none(): switch.switch = not switch.switch await session.commit() return switch.switch else: - new_switch = FishingSwitch(group_id=event.group_id, switch=False) + new_switch = FishingSwitch(group_id=event.data.peer_id, switch=False) session.add(new_switch) await session.commit() return False -async def get_switch_fish(event: GroupMessageEvent | PrivateMessageEvent) -> bool: +async def get_switch_fish(event: GroupMessageEvent | FriendMessageEvent) -> bool: """获取钓鱼开关""" - if isinstance(event, PrivateMessageEvent): + if isinstance(event, FriendMessageEvent): return True async with get_session() as session: - stmt = select(FishingSwitch).where(FishingSwitch.group_id == event.group_id) + stmt = select(FishingSwitch).where(FishingSwitch.group_id == event.data.peer_id) result = await session.execute(stmt) switch = result.scalar_one_or_none() diff --git a/src/plugins/fishing/models.py b/src/plugins/fishing/models.py index 6f6843d..8dffb4c 100644 --- a/src/plugins/fishing/models.py +++ b/src/plugins/fishing/models.py @@ -4,21 +4,22 @@ from nonebot_plugin_orm import Model from sqlalchemy import Boolean, Integer, String, Text -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm.properties import MappedColumn class FishingRecord(Model): __tablename__ = "fishing_fishingrecord" - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[str] = mapped_column(String(32)) - time: Mapped[int] = mapped_column(Integer) - frequency: Mapped[int] = mapped_column(Integer) - fishes: Mapped[str] = mapped_column(Text) + id: MappedColumn[int] = mapped_column(Integer, primary_key=True) + user_id: MappedColumn[str] = mapped_column(String(32)) + time: MappedColumn[int] = mapped_column(Integer) + frequency: MappedColumn[int] = mapped_column(Integer) + fishes: MappedColumn[str] = mapped_column(Text) class FishingSwitch(Model): __tablename__ = "fishing_fishingswitch" - group_id: Mapped[int] = mapped_column(Integer, primary_key=True) - switch: Mapped[bool] = mapped_column(Boolean, default=True) + group_id: MappedColumn[int] = mapped_column(Integer, primary_key=True) + switch: MappedColumn[bool] = mapped_column(Boolean, default=True) diff --git a/src/plugins/nonebot_plugin_heweather/render_pic.py b/src/plugins/nonebot_plugin_heweather/render_pic.py index aa93cdd..deb6757 100644 --- a/src/plugins/nonebot_plugin_heweather/render_pic.py +++ b/src/plugins/nonebot_plugin_heweather/render_pic.py @@ -1,7 +1,6 @@ import platform from datetime import datetime from pathlib import Path -from typing import List from nonebot_plugin_htmlrender import template_to_pic @@ -14,9 +13,8 @@ async def render(weather: Weather) -> bytes: template_path = str(Path(__file__).parent / "templates") air = None - if weather.air: - if weather.air.now: - air = add_tag_color(weather.air.now) + if weather.air and weather.air.now: + air = add_tag_color(weather.air.now) return await template_to_pic( template_path=template_path, @@ -36,9 +34,9 @@ async def render(weather: Weather) -> bytes: ) -def add_hour_data(hourly: List[Hourly]): - min_temp = min([int(hour.temp) for hour in hourly]) - high = max([int(hour.temp) for hour in hourly]) +def add_hour_data(hourly: list[Hourly]): + min_temp = min(int(hour.temp) for hour in hourly) + high = max(int(hour.temp) for hour in hourly) low = int(min_temp - (high - min_temp)) for hour in hourly: date_time = datetime.fromisoformat(hour.fxTime) @@ -57,7 +55,7 @@ def add_hour_data(hourly: List[Hourly]): return hourly -def add_date(daily: List[Daily]): +def add_date(daily: list[Daily]): week_map = [ "周日", "周一", @@ -90,4 +88,4 @@ def add_tag_color(air: Air): "严重污染": "#D94371", } air.tag_color = color[air.category] - return air \ No newline at end of file + return air diff --git a/src/plugins/nonebot_plugin_heweather/weather_data.py b/src/plugins/nonebot_plugin_heweather/weather_data.py index 549cdcd..12d494b 100644 --- a/src/plugins/nonebot_plugin_heweather/weather_data.py +++ b/src/plugins/nonebot_plugin_heweather/weather_data.py @@ -19,7 +19,7 @@ class CityNotFoundError(Exception): ... class Weather: def __url__(self): self.url_geoapi = "https://geoapi.qweather.com/v2/city/" - if self.api_type == 2 or self.api_type == 1: + if self.api_type in [2, 1]: self.url_weather_api = "https://api.qweather.com/v7/weather/" self.url_weather_warning = "https://api.qweather.com/v7/warning/now" self.url_air = "https://api.qweather.com/v7/air/now" @@ -43,9 +43,8 @@ def __url__(self): def _forecast_days(self): self.forecast_days = QWEATHER_FORECASE_DAYS - if self.forecast_days: - if self.api_type == 0 and not (3 <= self.forecast_days <= 7): - raise ConfigError("api_type = 0 免费订阅 预报天数必须 3<= x <=7") + if self.forecast_days and (self.api_type == 0 and not (3 <= self.forecast_days <= 7)): + raise ConfigError("api_type = 0 免费订阅 预报天数必须 3<= x <=7") def __init__(self, city_name: str, api_key: str, api_type: int = 0): self.city_name = city_name @@ -90,35 +89,30 @@ async def _get_city_id(self, api_type: str = "lookup"): if res["code"] == "404": raise CityNotFoundError() elif res["code"] != "200": - raise APIError("错误! 错误代码: {}".format(res["code"]) + self.__reference) + raise APIError(f'错误! 错误代码: {res["code"]}{self.__reference}') else: self.city_name = res["location"][0]["name"] return res["location"][0]["id"] def _data_validate(self): - if self.now.code == "200" and self.daily.code == "200": - pass - else: + if self.now.code != "200" or self.daily.code != "200": raise APIError( - "错误! 请检查配置! " - f"错误代码: now: {self.now.code} " - f"daily: {self.daily.code} " - + "air: {} ".format(self.air.code if self.air else "None") - + "warning: {}".format(self.warning.code if self.warning else "None") + f"错误! 请检查配置! 错误代码: now: {self.now.code} daily: {self.daily.code} " + + f'air: {self.air.code if self.air else "None"} ' + + f'warning: {self.warning.code if self.warning else "None"}' + self.__reference ) def _check_response(self, response: Response) -> bool: - if response.status_code == 200: - logger.debug(f"{response.json()}") - return True - else: + if response.status_code != 200: raise APIError(f"Response code:{response.status_code}") + logger.debug(f"{response.json()}") + return True @property async def _now(self) -> NowApi: res = await self._get_data( - url=self.url_weather_api + "now", + url=f"{self.url_weather_api}now", params={"location": self.city_id, "key": self.apikey}, ) self._check_response(res) diff --git a/src/plugins/nonebot_plugin_multincm/__init__.py b/src/plugins/nonebot_plugin_multincm/__init__.py index 556b1db..398348d 100644 --- a/src/plugins/nonebot_plugin_multincm/__init__.py +++ b/src/plugins/nonebot_plugin_multincm/__init__.py @@ -1,10 +1,13 @@ # ruff: noqa: E402 +import asyncio + from nonebot import get_driver from nonebot.plugin import PluginMetadata, inherit_supported_adapters, require require("nonebot_plugin_alconna") require("nonebot_plugin_waiter") +require("nonebot_plugin_localstore") require("nonebot_plugin_htmlrender") from . import interaction as interaction @@ -13,7 +16,12 @@ from .interaction import load_commands driver = get_driver() -driver.on_startup(login) + + +@driver.on_startup +async def _(): + asyncio.create_task(login()) + load_commands() @@ -29,7 +37,7 @@ "▶ Bot 会自动解析你发送的网易云链接\n" if config.ncm_auto_resolve else "" ) -__version__ = "1.1.5" +__version__ = "1.2.6" __plugin_meta__ = PluginMetadata( name="MultiNCM", description="网易云多选点歌", diff --git a/src/plugins/nonebot_plugin_multincm/config.py b/src/plugins/nonebot_plugin_multincm/config.py index 11399ac..33dea88 100644 --- a/src/plugins/nonebot_plugin_multincm/config.py +++ b/src/plugins/nonebot_plugin_multincm/config.py @@ -17,6 +17,7 @@ class ConfigModel(BaseConfigModel): ncm_email: Optional[str] = None ncm_password: Optional[str] = None ncm_password_hash: Optional[str] = None + ncm_anonymous: bool = False # ui ncm_list_limit: int = 20 @@ -41,7 +42,6 @@ class ConfigModel(BaseConfigModel): ncm_resolve_cool_down_cache_size: int = 1024 ncm_card_sign_url: Optional[Annotated[str, AnyHttpUrl]] = None ncm_card_sign_timeout: int = 5 - ncm_ob_v11_local_mode: bool = False ncm_ffmpeg_executable: str = "ffmpeg" diff --git a/src/plugins/nonebot_plugin_multincm/const.py b/src/plugins/nonebot_plugin_multincm/const.py index 4301e65..25b5566 100644 --- a/src/plugins/nonebot_plugin_multincm/const.py +++ b/src/plugins/nonebot_plugin_multincm/const.py @@ -1,13 +1,37 @@ from pathlib import Path -DATA_DIR = Path.cwd() / "data" / "multincm" -SONG_CACHE_DIR = DATA_DIR / "song_cache" -for _p in (DATA_DIR, SONG_CACHE_DIR): - _p.mkdir(parents=True, exist_ok=True) +from cookit.nonebot.localstore import ensure_localstore_path_config +from nonebot_plugin_localstore import get_plugin_data_dir + +ensure_localstore_path_config() -DEBUG_ROOT_DIR = Path.cwd() / "debug" -DEBUG_DIR = DEBUG_ROOT_DIR / "multincm" +DATA_DIR = get_plugin_data_dir() +SONG_CACHE_DIR = DATA_DIR / "song_cache" URL_REGEX = r"music\.163\.com/(.*?)(?P[a-zA-Z]+)(/?\?id=|/)(?P[0-9]+)&?" SHORT_URL_BASE = "https://163cn.tv" SHORT_URL_REGEX = r"163cn\.tv/(?P[a-zA-Z0-9]+)" + +SESSION_FILE_NAME = "session.cache" +SESSION_FILE_PATH = DATA_DIR / SESSION_FILE_NAME + + +def migrate_old_data(): + old_data_dir = Path.cwd() / "data" / "multincm" + if not old_data_dir.exists(): + return + + import shutil + + from nonebot import logger + + old_session_file_path = old_data_dir / SESSION_FILE_NAME + if old_session_file_path.exists(): + shutil.move(old_session_file_path, SESSION_FILE_PATH) + logger.info("已迁移旧登录态文件") + + shutil.rmtree(old_data_dir) + logger.info("已删除旧缓存目录") + + +migrate_old_data() diff --git a/src/plugins/nonebot_plugin_multincm/data_source/base.py b/src/plugins/nonebot_plugin_multincm/data_source/base.py index 91c624c..72c4228 100644 --- a/src/plugins/nonebot_plugin_multincm/data_source/base.py +++ b/src/plugins/nonebot_plugin_multincm/data_source/base.py @@ -6,8 +6,11 @@ from typing import Any, ClassVar, Generic, Optional, TypeVar, Union from typing_extensions import Self, TypeAlias, TypeGuard, override +from yarl import URL + from ..config import config from ..utils import ( + NCMLrcGroupLine, build_item_link, calc_max_page, calc_min_index, @@ -120,9 +123,7 @@ def display_duration(self) -> str: @property def file_suffix(self) -> Optional[str]: - with suppress(Exception): - return self.playable_url.rsplit("/", 1)[-1].rsplit(".", 1)[-1] - return None + return URL(self.playable_url).suffix.removeprefix(".") or None @property def display_filename(self) -> str: @@ -179,7 +180,7 @@ async def get_cover_url(self) -> str: ... async def get_playable_url(self) -> str: ... @abstractmethod - async def get_lyrics(self) -> Optional[list[list[str]]]: ... + async def get_lyrics(self) -> Optional[list[NCMLrcGroupLine]]: ... async def get_info(self) -> SongInfo: ( diff --git a/src/plugins/nonebot_plugin_multincm/data_source/raw/login.py b/src/plugins/nonebot_plugin_multincm/data_source/raw/login.py index f2fa2ac..51b0ee9 100644 --- a/src/plugins/nonebot_plugin_multincm/data_source/raw/login.py +++ b/src/plugins/nonebot_plugin_multincm/data_source/raw/login.py @@ -1,6 +1,11 @@ -from typing import cast +import asyncio +import time +from pathlib import Path +from typing import Any, Optional import anyio +import qrcode +from cookit.loguru import warning_suppress from nonebot import logger from nonebot.utils import run_sync from pyncm import ( @@ -11,75 +16,215 @@ ) from pyncm.apis.login import ( GetCurrentLoginStatus, + LoginFailedException, + LoginQrcodeCheck, + LoginQrcodeUnikey, LoginViaAnonymousAccount, LoginViaCellphone, LoginViaEmail, + SetSendRegisterVerifcationCodeViaCellphone, + WriteLoginInfo, ) from ...config import config -from ...const import DATA_DIR +from ...const import SESSION_FILE_PATH +from .request import NCMResponseError, ncm_request -SESSION_FILE = DATA_DIR / "session.cache" +async def sms_login(phone: str, country_code: int = 86): + timeout = 60 -async def do_login(retry: bool = True): - if SESSION_FILE.exists(): - logger.info(f"使用缓存登录态 ({SESSION_FILE})") + while True: + await ncm_request( + SetSendRegisterVerifcationCodeViaCellphone, + phone, + country_code, + ) + last_send_time = time.time() + logger.success( + f"已发送验证码到 +{country_code} {'*' * (len(phone) - 3)}{phone[-3:]}", + ) + + while True: + captcha = input("> 请输入验证码,留空直接回车代表重发: ").strip() + if not captcha: + if (time_passed := (time.time() - last_send_time)) >= timeout: + break + logger.warning(f"请等待 {timeout - time_passed:.0f} 秒后再重发") + continue + + try: + await ncm_request( + LoginViaCellphone, + phone=phone, + ctcode=country_code, + captcha=captcha, + ) + except LoginFailedException as e: + data: dict[str, Any] = e.args[0] + if data.get("code") != 503: + raise + logger.error("验证码错误,请重新输入") + else: + return + + +async def phone_login( + phone: str, + password: str = "", + password_hash: str = "", + country_code: int = 86, +): + await run_sync(LoginViaCellphone)( + ctcode=country_code, + phone=phone, + password=password, + passwordHash=password_hash, + ) + + +async def email_login( + email: str, + password: str = "", + password_hash: str = "", +): + await run_sync(LoginViaEmail)( + email=email, + password=password, + passwordHash=password_hash, + ) + + +async def qrcode_login(): + async def wait_scan(uni_key: str) -> bool: + last_status: Optional[int] = None + while True: + await asyncio.sleep(2) + try: + await ncm_request(LoginQrcodeCheck, uni_key) + except NCMResponseError as e: + code = e.code + if code != last_status: + last_status = code + extra_tip = ( + f" (用户:{e.data.get('nickname')})" if code == 802 else "" + ) + logger.info(f"当前二维码状态:[{code}] {e.message}{extra_tip}") + if code == 800: + return False # 二维码过期 + if code == 803: + return True # 授权成功 + if code and (code >= 1000): + raise + + while True: + uni_key: str = (await ncm_request(LoginQrcodeUnikey))["unikey"] + + url = f"https://music.163.com/login?codekey={uni_key}" + qr = qrcode.QRCode() + qr.add_data(url) + + logger.info("请使用网易云音乐 APP 扫描下方二维码完成登录") + qr.print_ascii() + + qr_img_filename = "multincm-qrcode.png" + qr_img_path = Path.cwd() / qr_img_filename + with warning_suppress("Failed to save qrcode image"): + qr.make_image().save( + str(qr_img_path), # type: ignore + ) + logger.info( + f"二维码图片已保存至 Bot 根目录的 {qr_img_filename} 文件" + f",如终端中二维码无法扫描可使用", + ) + + logger.info("或使用下方 URL 生成二维码扫描登录:") + logger.info(url) + + try: + scan_res = await wait_scan(uni_key) + finally: + with warning_suppress("Failed to delete qrcode image"): + qr_img_path.unlink(missing_ok=True) + if scan_res: + return + + +async def anonymous_login(): + await ncm_request(LoginViaAnonymousAccount) + + +async def validate_login(): + with warning_suppress("Failed to get login status"): + ret = await ncm_request(GetCurrentLoginStatus) + ok = bool(ret.get("account")) + if ok: + WriteLoginInfo(ret, GetCurrentSession()) + return ok + return False + + +async def do_login(anonymous: bool = False): + using_cached_session = False + + if anonymous: + logger.info("使用游客身份登录") + await anonymous_login() + + elif using_cached_session := SESSION_FILE_PATH.exists(): + logger.info(f"使用缓存登录态 ({SESSION_FILE_PATH})") SetCurrentSession( LoadSessionFromString( - (await anyio.Path(SESSION_FILE).read_text(encoding="u8")), + (await anyio.Path(SESSION_FILE_PATH).read_text(encoding="u8")), ), ) - elif (config.ncm_phone or config.ncm_email) and ( - config.ncm_password or config.ncm_password_hash - ): - retry = False - - if config.ncm_phone: - logger.info("使用手机号登录") - await run_sync(LoginViaCellphone)( - ctcode=config.ncm_ctcode, - phone=config.ncm_phone, - password=config.ncm_password or "", - passwordHash=config.ncm_password_hash or "", + elif config.ncm_phone: + if config.ncm_password or config.ncm_password_hash: + logger.info("使用手机号与密码登录") + await phone_login( + config.ncm_phone, + config.ncm_password or "", + config.ncm_password_hash or "", ) - else: - logger.info("使用邮箱登录") - await run_sync(LoginViaEmail)( - email=config.ncm_email or "", - password=config.ncm_password or "", - passwordHash=config.ncm_password_hash or "", - ) + logger.info("使用短信验证登录") + await sms_login(config.ncm_phone) - await anyio.Path(SESSION_FILE).write_text( - DumpSessionAsString(GetCurrentSession()), - encoding="u8", + elif (has_password := bool(config.ncm_password or config.ncm_password_hash)) and ( + config.ncm_email + ): + logger.info("使用邮箱与密码登录") + await email_login( + config.ncm_email, + config.ncm_password or "", + config.ncm_password_hash or "", ) else: - retry = False - logger.warning("账号或密码未填写,使用游客账号登录") - await run_sync(LoginViaAnonymousAccount)() - - try: - ret = cast(dict, await run_sync(GetCurrentLoginStatus)()) - assert ret["code"] == 200 - assert ret["account"] - except Exception as e: - if await (pth := anyio.Path(SESSION_FILE)).exists(): - await pth.unlink() - - if retry: - logger.warning("恢复缓存会话失败,尝试使用正常流程登录") - await do_login(retry=False) - return - - raise RuntimeError("登录态异常,请重新登录") from e + if config.ncm_email and (not has_password): + logger.warning("配置文件中提供了邮箱,但是通过邮箱登录需要提供密码") + logger.info("使用二维码登录") + await qrcode_login() + + if not (await validate_login()) and using_cached_session: + SESSION_FILE_PATH.unlink() + logger.warning("恢复缓存会话失败,尝试使用正常流程登录") + await do_login() + return - session = GetCurrentSession() - logger.info(f"登录成功,欢迎您,{session.nickname} [{session.uid}]") + session_exists = GetCurrentSession() + if anonymous: + logger.success("游客登录成功") + else: + if not using_cached_session: + SESSION_FILE_PATH.write_text( + DumpSessionAsString(session_exists), + "u8", + ) + logger.success( + f"登录成功,欢迎您,{session_exists.nickname} [{session_exists.uid}]", + ) async def login(): @@ -92,4 +237,10 @@ async def login(): logger.info("检测到当前全局 Session 已登录,插件将跳过登录步骤") return - await do_login() + if not config.ncm_anonymous: + with warning_suppress("登录失败,回落到游客登录"): + await do_login() + return + + with warning_suppress("登录失败"): + await do_login(anonymous=True) diff --git a/src/plugins/nonebot_plugin_multincm/data_source/raw/request.py b/src/plugins/nonebot_plugin_multincm/data_source/raw/request.py index b3d4202..509a91a 100644 --- a/src/plugins/nonebot_plugin_multincm/data_source/raw/request.py +++ b/src/plugins/nonebot_plugin_multincm/data_source/raw/request.py @@ -1,5 +1,6 @@ from functools import partial from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload +from typing_extensions import ParamSpec from nonebot.utils import run_sync from pydantic import BaseModel @@ -10,7 +11,7 @@ from pyncm.apis.track import GetTrackAudio, GetTrackDetail, GetTrackLyrics from ...config import config -from ...utils import calc_min_index, is_debug_mode, write_debug_file +from ...utils import calc_min_index, debug from .models import ( AlbumInfo, AlbumSearchResult, @@ -29,14 +30,36 @@ ) TModel = TypeVar("TModel", bound=BaseModel) +P = ParamSpec("P") -async def ncm_request(api: Callable, *args, **kwargs) -> dict[str, Any]: +class NCMResponseError(Exception): + def __init__(self, name: str, data: dict[str, Any]): + self.name = name + self.data = data + + @property + def code(self) -> Optional[int]: + return self.data.get("code") + + @property + def message(self) -> Optional[str]: + return self.data.get("message") + + def __str__(self): + return f"{self.name} failed: [{self.code}] {self.message}" + + +async def ncm_request( + api: Callable[P, Any], + *args: P.args, + **kwargs: P.kwargs, +) -> dict[str, Any]: ret = await run_sync(api)(*args, **kwargs) - if is_debug_mode(): - write_debug_file(f"{api.__name__}_{{time}}.json", ret) + if debug.enabled: + debug.write(ret, f"{api.__name__}_{{time}}.json") if ret.get("code", 200) != 200: - raise RuntimeError(f"请求 {api.__name__} 失败\n{ret}") + raise NCMResponseError(api.__name__, ret) return ret @@ -143,7 +166,7 @@ async def get_track_audio( **kwargs, ) -> list[TrackAudio]: res = await ncm_request(GetTrackAudio, song_ids, bitrate=bit_rate, **kwargs) - return [TrackAudio(**x) for x in cast(list[dict], res["data"])] + return [TrackAudio(**x) for x in cast("list[dict]", res["data"])] async def get_track_info(ids: list[int], **kwargs) -> list[Song]: @@ -163,7 +186,7 @@ async def get_track_info(ids: list[int], **kwargs) -> list[Song]: async def get_track_lrc(song_id: int): - res = await ncm_request(GetTrackLyrics, song_id) + res = await ncm_request(GetTrackLyrics, str(song_id)) return LyricData(**res) @@ -204,5 +227,5 @@ async def get_playlist_info(playlist_id: int): async def get_album_info(album_id: int): - res = await ncm_request(GetAlbumInfo, album_id) + res = await ncm_request(GetAlbumInfo, str(album_id)) return AlbumInfo(**res) diff --git a/src/plugins/nonebot_plugin_multincm/interaction/message/common.py b/src/plugins/nonebot_plugin_multincm/interaction/message/common.py index 404cc69..263f9da 100644 --- a/src/plugins/nonebot_plugin_multincm/interaction/message/common.py +++ b/src/plugins/nonebot_plugin_multincm/interaction/message/common.py @@ -1,3 +1,5 @@ +from typing import Union + from cookit.loguru import warning_suppress from nonebot_plugin_alconna.uniseg import UniMessage @@ -13,7 +15,7 @@ async def construct_info_msg( - it: BaseSong | BasePlaylist, + it: Union[BaseSong, BasePlaylist], tip_command: bool = True, ) -> UniMessage: tip = ( @@ -41,15 +43,8 @@ async def send(): if config.ncm_send_media: with warning_suppress(f"Send {song} file failed"): receipt = await send_song_media(song) - reply = None - if receipt and (receipt is not ...): - r = receipt.get_reply() - if isinstance(r, list): - reply = r[0] if r else None - else: - reply = r await (await construct_info_msg(song, tip_command=(receipt is ...))).send( - reply_to=reply, + reply_to=receipt.get_reply() if receipt and (receipt is not ...) else None, ) await send() diff --git a/src/plugins/nonebot_plugin_multincm/interaction/message/song_file.py b/src/plugins/nonebot_plugin_multincm/interaction/message/song_file.py index f26478d..91f6cde 100644 --- a/src/plugins/nonebot_plugin_multincm/interaction/message/song_file.py +++ b/src/plugins/nonebot_plugin_multincm/interaction/message/song_file.py @@ -21,13 +21,15 @@ async def download_song(info: "SongInfo"): filename = info.download_filename file_path = SONG_CACHE_DIR / filename - if not file_path.exists(): - async with AsyncClient(follow_redirects=True) as cli: - async with cli.stream("GET", info.playable_url) as resp: - resp.raise_for_status() - with file_path.open("wb") as f: - async for chunk in resp.aiter_bytes(): - f.write(chunk) + if file_path.exists(): + return file_path + + async with AsyncClient(follow_redirects=True) as cli, cli.stream("GET", info.playable_url) as resp: # fmt: skip + resp.raise_for_status() + SONG_CACHE_DIR.mkdir(parents=True, exist_ok=True) + with file_path.open("wb") as f: + async for chunk in resp.aiter_bytes(): + f.write(chunk) return file_path @@ -55,7 +57,6 @@ async def get_current_ev_receipt(msg_ids: Any): context=ev, exporter=exporter, msg_ids=msg_ids if isinstance(msg_ids, list) else [msg_ids], - uni_factory=UniMessage, ) @@ -63,7 +64,7 @@ async def send_song_media_telegram(info: "SongInfo", as_file: bool = False): return await send_song_media_uni_msg(await download_song(info), info, as_file=False) -async def send_song_media_onebot_v11(info: "SongInfo", as_file: bool = False): +async def send_song_media_milky(info: "SongInfo", as_file: bool = False): async def send_voice(): if not await ffmpeg_exists(): logger.warning( @@ -76,27 +77,23 @@ async def send_voice(): ).send() async def send_file(): - from nonebot.adapters.onebot.v11 import ( - Bot as OB11Bot, + from nonebot.adapters.milky import ( + Bot as MilkyBot, ) - from nonebot.adapters.onebot.v11 import ( + from nonebot.adapters.milky import ( + FriendMessageEvent, GroupMessageEvent, - PrivateMessageEvent, ) - bot = cast(OB11Bot, current_bot.get()) + bot = cast(MilkyBot, current_bot.get()) event = current_event.get() - if not isinstance(event, GroupMessageEvent | PrivateMessageEvent): + if not isinstance(event, (GroupMessageEvent, FriendMessageEvent)): raise TypeError("Event not supported") - file = ( - (await download_song(info)) - if config.ncm_ob_v11_local_mode - else cast(str, (await bot.download_file(url=info.playable_url))["file"]) - ) + file = await download_song(info) - if isinstance(event, PrivateMessageEvent): + if isinstance(event, FriendMessageEvent): await bot.upload_private_file( user_id=event.user_id, file=file, @@ -120,7 +117,7 @@ async def send_song_media_platform_specific( adapter_name = bot.adapter.get_name() processors = { "Telegram": send_song_media_telegram, - "OneBot V11": send_song_media_onebot_v11, + "Milky": send_song_media_milky, } if adapter_name not in processors: raise TypeError("This adapter is not supported") diff --git a/src/plugins/nonebot_plugin_multincm/interaction/resolver.py b/src/plugins/nonebot_plugin_multincm/interaction/resolver.py index b0afc7f..f727573 100644 --- a/src/plugins/nonebot_plugin_multincm/interaction/resolver.py +++ b/src/plugins/nonebot_plugin_multincm/interaction/resolver.py @@ -1,13 +1,13 @@ import re from dataclasses import dataclass -from typing import Annotated, TypeAlias +from typing import Annotated, Optional, Union +from typing_extensions import TypeAlias from cachetools import TTLCache from cookit import flatten, queued from cookit.loguru import warning_suppress from httpx import AsyncClient -from nonebot.adapters import Bot as BaseBot -from nonebot.adapters import Message as BaseMessage +from nonebot.adapters import Bot as BaseBot, Message as BaseMessage from nonebot.consts import REGEX_MATCHED from nonebot.matcher import Matcher from nonebot.params import Depends @@ -28,9 +28,10 @@ from ..utils import is_song_card_supported from .cache import get_cache -ExpectedTypeType: TypeAlias = ( - type[GeneralSongOrPlaylist] | tuple[type[GeneralSongOrPlaylist], ...] -) +ExpectedTypeType: TypeAlias = Union[ + type[GeneralSongOrPlaylist], + tuple[type[GeneralSongOrPlaylist], ...], +] resolved_cache: TTLCache[int, "ResolveCache"] = TTLCache( @@ -61,7 +62,7 @@ async def resolve_from_link_params_cool_down(link_type: str, link_id: int): def check_is_expected_type( item_type: str, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, ) -> bool: if not expected_type: return True @@ -76,8 +77,8 @@ def check_is_expected_type( def extract_song_card_hyper( msg: UniMessage, - bot: BaseBot | None = None, -) -> Hyper | None: + bot: Optional[BaseBot] = None, +) -> Optional[Hyper]: if (Hyper in msg) and is_song_card_supported(bot): return msg[Hyper, 0] return None @@ -85,7 +86,7 @@ def extract_song_card_hyper( async def resolve_short_url( suffix: str, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, use_cool_down: bool = False, ) -> GeneralSongOrPlaylist: async with AsyncClient(base_url=SHORT_URL_BASE) as client: @@ -93,7 +94,8 @@ async def resolve_short_url( if resp.status_code // 100 != 3: raise ValueError( - f"Short url {suffix} returned invalid status code {resp.status_code}", + f"Short url {suffix} " + f"returned invalid status code {resp.status_code}", ) location = resp.headers.get("Location") @@ -113,9 +115,9 @@ async def resolve_short_url( async def resolve_from_matched( matched: re.Match[str], - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, use_cool_down: bool = False, -) -> GeneralSongOrPlaylist | None: +) -> Optional[GeneralSongOrPlaylist]: groups = matched.groupdict() if "suffix" in groups: @@ -143,9 +145,9 @@ async def resolve_from_matched( async def resolve_from_plaintext( text: str, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, use_cool_down: bool = False, -) -> GeneralSongOrPlaylist | None: +) -> Optional[GeneralSongOrPlaylist]: for regex in (SHORT_URL_REGEX, URL_REGEX): if m := re.search(regex, text, re.IGNORECASE): return await resolve_from_matched(m, expected_type, use_cool_down) @@ -155,9 +157,9 @@ async def resolve_from_plaintext( async def resolve_from_card( card: Hyper, resolve_playable: bool = True, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, use_cool_down: bool = False, -) -> GeneralSongOrPlaylist | None: +) -> Optional[GeneralSongOrPlaylist]: if not (raw := card.raw): return None @@ -171,10 +173,10 @@ async def resolve_from_card( async def resolve_from_msg( msg: UniMessage, resolve_playable_card: bool = True, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, use_cool_down: bool = False, - bot: BaseBot | None = None, -) -> GeneralSongOrPlaylist | None: + bot: Optional[BaseBot] = None, +) -> Optional[GeneralSongOrPlaylist]: if (h := extract_song_card_hyper(msg, bot)) and ( it := await resolve_from_card( h, @@ -195,9 +197,9 @@ async def resolve_from_ev_msg( state: T_State, bot: BaseBot, matcher: Matcher, - expected_type: ExpectedTypeType | None = None, + expected_type: Optional[ExpectedTypeType] = None, ) -> GeneralSongOrPlaylist: - regex_matched: re.Match[str] | None = state.get(REGEX_MATCHED) + regex_matched: Optional[re.Match[str]] = state.get(REGEX_MATCHED) if regex_matched: # auto resolve if h := extract_song_card_hyper(msg, bot): if it := await resolve_from_card( @@ -227,7 +229,7 @@ async def resolve_from_ev_msg( ): return it - await matcher.finish() + await matcher.finish() # noqa: RET503: NoReturn async def dependency_resolve_from_ev( @@ -270,7 +272,7 @@ async def dependency_resolve_playlist_from_ev( async def dependency_is_auto_resolve(state: T_State) -> bool: - return not not state.get(REGEX_MATCHED) + return bool(state.get(REGEX_MATCHED)) ResolvedItem = Annotated[ diff --git a/src/plugins/nonebot_plugin_multincm/render/__init__.py b/src/plugins/nonebot_plugin_multincm/render/__init__.py index 146dc35..1cde1c5 100644 --- a/src/plugins/nonebot_plugin_multincm/render/__init__.py +++ b/src/plugins/nonebot_plugin_multincm/render/__init__.py @@ -6,6 +6,5 @@ render_track_card_html as render_track_card_html, ) from .lyrics import ( - LyricsRenderParams as LyricsRenderParams, render_lyrics as render_lyrics, ) diff --git a/src/plugins/nonebot_plugin_multincm/render/lyrics.py b/src/plugins/nonebot_plugin_multincm/render/lyrics.py index 7ffa78d..f200d4a 100644 --- a/src/plugins/nonebot_plugin_multincm/render/lyrics.py +++ b/src/plugins/nonebot_plugin_multincm/render/lyrics.py @@ -1,12 +1,19 @@ -from typing import TypedDict -from typing_extensions import Unpack +from typing import TYPE_CHECKING from .utils import render_html, render_template +if TYPE_CHECKING: + from ..utils import NCMLrcGroupLine -class LyricsRenderParams(TypedDict): - groups: list[list[str]] - -async def render_lyrics(**kwargs: Unpack[LyricsRenderParams]) -> bytes: - return await render_html(await render_template("lyrics.html.jinja", **kwargs)) +async def render_lyrics(groups: list["NCMLrcGroupLine"]) -> bytes: + group_tuples = [[(n, r) for n, r in x.lrc.items()] for x in groups] + sort_order = ("roma", "main", "trans") + for group in group_tuples: + group.sort(key=lambda x: sort_order.index(x[0]) if x[0] in sort_order else 999) + return await render_html( + await render_template( + "lyrics.html.jinja", + groups=group_tuples, + ), + ) diff --git a/src/plugins/nonebot_plugin_multincm/render/templates/lyrics.html.jinja b/src/plugins/nonebot_plugin_multincm/render/templates/lyrics.html.jinja index f46163a..c7e63ab 100644 --- a/src/plugins/nonebot_plugin_multincm/render/templates/lyrics.html.jinja +++ b/src/plugins/nonebot_plugin_multincm/render/templates/lyrics.html.jinja @@ -19,15 +19,18 @@ font-size: 20px; font-weight: bold; } + + .lyric-group .roma { + font-size: 14px; + } {%- endblock %} {% block main -%} {% for group in groups -%}
- {% for it in group -%} - {% if it %}
{{ it }}
{% endif %} - {%- endfor %} + {% for n, r in group %}
{{ r }}
+ {% endfor -%}
{%- endfor %}