From 706b5485149ff166063db383d7cecc70c513a0cc Mon Sep 17 00:00:00 2001 From: aicorein Date: Thu, 6 Feb 2025 19:58:56 +0800 Subject: [PATCH] [All] Bumped to 3.1 (#35) Updated 3.1 docs & bumped to 3.1 --- docs/source/api/index.rst | 6 +- docs/source/api/melobot.adapter.rst | 4 + docs/source/api/melobot.handle.rst | 31 + docs/source/api/melobot.io.rst | 31 + docs/source/api/melobot.mixin.rst | 14 + docs/source/api/melobot.session.rst | 9 + docs/source/api/melobot.typ.rst | 21 +- docs/source/api/melobot.utils.rst | 70 ++ docs/source/index.md | 1 + docs/source/intro/action-echo.md | 16 +- docs/source/intro/async-callable.md | 5 + docs/source/intro/create-bot.md | 18 +- docs/source/intro/event-preprocess.md | 144 ++-- docs/source/intro/event-process.md | 74 +- docs/source/intro/msg-action.md | 26 +- docs/source/ob_api/index.rst | 10 +- docs/source/ob_api/v11.const.rst | 4 + docs/source/ob_api/v11.handle.rst | 1 - docs/source/ob_api/v11.io.rst | 20 +- docs/source/ob_api/v11.rst | 5 + docs/source/ob_api/v11.utils.rst | 54 +- docs/source/ob_refer/preprocess.md | 10 +- docs/source/update-log.md | 258 +++++++ src/melobot/__init__.py | 21 +- src/melobot/_hook.py | 46 +- src/melobot/_imp.py | 24 +- src/melobot/_meta.py | 4 +- src/melobot/adapter/__init__.py | 14 +- src/melobot/adapter/base.py | 53 +- src/melobot/adapter/generic.py | 28 +- src/melobot/adapter/model.py | 38 +- src/melobot/bot/base.py | 79 +- src/melobot/bot/dispatch.py | 249 ++++--- src/melobot/ctx.py | 105 ++- src/melobot/di.py | 127 ++-- src/melobot/exceptions.py | 15 +- src/melobot/handle/__init__.py | 17 +- src/melobot/handle/base.py | 472 ++++++++++-- src/melobot/handle/process.py | 399 ---------- src/melobot/handle/register.py | 323 ++++++++ src/melobot/io/__init__.py | 7 + src/melobot/io/base.py | 56 +- src/melobot/log/base.py | 7 +- src/melobot/mixin.py | 244 ++++++ src/melobot/plugin/base.py | 108 +-- src/melobot/plugin/ipc.py | 9 +- src/melobot/plugin/load.py | 4 +- src/melobot/protocols/base.py | 8 +- src/melobot/protocols/onebot/v11/__init__.py | 59 +- .../protocols/onebot/v11/adapter/__init__.py | 138 +++- .../protocols/onebot/v11/adapter/action.py | 6 +- .../protocols/onebot/v11/adapter/base.py | 103 ++- .../protocols/onebot/v11/adapter/echo.py | 17 +- .../protocols/onebot/v11/adapter/event.py | 68 +- src/melobot/protocols/onebot/v11/const.py | 6 +- src/melobot/protocols/onebot/v11/handle.py | 337 ++------- .../protocols/onebot/v11/io/__init__.py | 2 +- src/melobot/protocols/onebot/v11/io/base.py | 84 ++- .../protocols/onebot/v11/io/duplex_http.py | 9 +- .../protocols/onebot/v11/io/forward.py | 9 +- .../protocols/onebot/v11/io/reverse.py | 9 +- .../protocols/onebot/v11/utils/__init__.py | 32 +- src/melobot/protocols/onebot/v11/utils/abc.py | 209 ------ .../protocols/onebot/v11/utils/check.py | 22 +- src/melobot/session/__init__.py | 18 +- src/melobot/session/base.py | 186 +++-- src/melobot/session/option.py | 14 +- src/melobot/typ.py | 482 ------------ src/melobot/typ/__init__.py | 9 + src/melobot/typ/_enum.py | 173 +++++ src/melobot/typ/base.py | 49 ++ src/melobot/typ/cls.py | 115 +++ src/melobot/utils.py | 700 ------------------ src/melobot/utils/__init__.py | 12 + src/melobot/utils/atool.py | 104 +++ src/melobot/utils/base.py | 92 +++ src/melobot/utils/check/__init__.py | 1 + src/melobot/utils/check/base.py | 129 ++++ src/melobot/utils/common.py | 251 +++++++ src/melobot/utils/deco.py | 360 +++++++++ src/melobot/utils/match/__init__.py | 9 + .../utils/match.py => utils/match/base.py} | 82 +- src/melobot/utils/parse/__init__.py | 2 + src/melobot/utils/parse/base.py | 40 + .../v11/utils/parse.py => utils/parse/cmd.py} | 75 +- tests/onebot/v11/test_adapter_base.py | 4 +- tests/onebot/v11/test_adapter_echo.py | 64 +- tests/onebot/v11/test_handle.py | 22 +- tests/test_handle_process.py | 2 +- tests/test_utils.py | 5 +- tests/{onebot/v11 => }/test_utils_match.py | 4 +- tests/{onebot/v11 => }/test_utils_parse.py | 4 +- 92 files changed, 4390 insertions(+), 3017 deletions(-) create mode 100644 docs/source/api/melobot.mixin.rst create mode 100644 docs/source/ob_api/v11.rst create mode 100644 docs/source/update-log.md delete mode 100644 src/melobot/handle/process.py create mode 100644 src/melobot/handle/register.py create mode 100644 src/melobot/mixin.py delete mode 100644 src/melobot/protocols/onebot/v11/utils/abc.py delete mode 100644 src/melobot/typ.py create mode 100644 src/melobot/typ/__init__.py create mode 100644 src/melobot/typ/_enum.py create mode 100644 src/melobot/typ/base.py create mode 100644 src/melobot/typ/cls.py delete mode 100644 src/melobot/utils.py create mode 100644 src/melobot/utils/__init__.py create mode 100644 src/melobot/utils/atool.py create mode 100644 src/melobot/utils/base.py create mode 100644 src/melobot/utils/check/__init__.py create mode 100644 src/melobot/utils/check/base.py create mode 100644 src/melobot/utils/common.py create mode 100644 src/melobot/utils/deco.py create mode 100644 src/melobot/utils/match/__init__.py rename src/melobot/{protocols/onebot/v11/utils/match.py => utils/match/base.py} (61%) create mode 100644 src/melobot/utils/parse/__init__.py create mode 100644 src/melobot/utils/parse/base.py rename src/melobot/{protocols/onebot/v11/utils/parse.py => utils/parse/cmd.py} (83%) rename tests/{onebot/v11 => }/test_utils_match.py (95%) rename tests/{onebot/v11 => }/test_utils_parse.py (94%) diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index ad1ee287..76c824be 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -15,10 +15,11 @@ melobot API - :class:`~melobot.adapter.base.Adapter`, :class:`~melobot.adapter.model.Event`, :class:`~melobot.adapter.model.Action`, :class:`~melobot.adapter.model.Echo` - :func:`.send_text`, :func:`.send_image` - :class:`.Flow`, :class:`.FlowStore`, :func:`.node`, :func:`.rewind`, :func:`.stop` +- :class:`.FlowDecorator`, :func:`~melobot.handle.on_event`, :func:`.on_text`, :func:`.on_start_match`, :func:`.on_contain_match`, :func:`.on_full_match`, :func:`.on_end_match`, :func:`.on_regex_match`, :func:`.on_command` - :class:`.Depends` -- :class:`.Rule`, :func:`.enter_session`, :class:`.SessionStore`, :func:`.suspend` +- :class:`.Session`, :class:`.Rule`, :class:`.DefaultRule`, :func:`.enter_session`, :class:`.SessionStore`, :func:`.suspend` - :class:`.GenericLogger`, :class:`.Logger`, :class:`.LogLevel`, :func:`.get_logger` -- :class:`.HandleLevel`, :class:`.LogicMode` +- :class:`.LogicMode` - :class:`.Context` 各模块 API 文档索引: @@ -40,3 +41,4 @@ melobot API melobot.typ melobot.exceptions melobot.ctx + melobot.mixin diff --git a/docs/source/api/melobot.adapter.rst b/docs/source/api/melobot.adapter.rst index 30bb2c18..865d3a1d 100644 --- a/docs/source/api/melobot.adapter.rst +++ b/docs/source/api/melobot.adapter.rst @@ -28,6 +28,10 @@ melobot.adapter :members: :exclude-members: __init__ +.. autoclass:: melobot.adapter.TextEvent + :members: + :exclude-members: __init__ + .. autoclass:: melobot.adapter.Action :members: :exclude-members: __init__ diff --git a/docs/source/api/melobot.handle.rst b/docs/source/api/melobot.handle.rst index c26541d9..7414e2ae 100644 --- a/docs/source/api/melobot.handle.rst +++ b/docs/source/api/melobot.handle.rst @@ -20,6 +20,32 @@ melobot.handle .. autofunction:: melobot.handle.no_deps_node +流装饰器 +---------- + +流装饰相关的组件,可以将一个普通函数装饰为一个处理流。 + +这些组件,在 melobot 的教程中,早期我们称它们为“事件绑定方法”或“事件绑定函数”。 + +.. autoclass:: melobot.handle.FlowDecorator + :exclude-members: auto_flow_wrapped + +.. autofunction:: melobot.handle.on_event + +.. autofunction:: melobot.handle.on_text + +.. autofunction:: melobot.handle.on_start_match + +.. autofunction:: melobot.handle.on_contain_match + +.. autofunction:: melobot.handle.on_end_match + +.. autofunction:: melobot.handle.on_full_match + +.. autofunction:: melobot.handle.on_regex_match + +.. autofunction:: melobot.handle.on_command + 处理流控制 ------------- @@ -55,3 +81,8 @@ melobot.handle .. autofunction:: melobot.handle.get_event .. autofunction:: melobot.handle.try_get_event + +弃用项,临时存在 +---------------- + +.. autofunction:: melobot.handle.GetParseArgs diff --git a/docs/source/api/melobot.io.rst b/docs/source/api/melobot.io.rst index 191b2d9b..50ccb458 100644 --- a/docs/source/api/melobot.io.rst +++ b/docs/source/api/melobot.io.rst @@ -37,3 +37,34 @@ melobot.io .. autoclass:: melobot.io.EchoPacket :members: :exclude-members: __init__, ok, status, prompt, noecho + +泛型 +------ + +.. data:: melobot.io.InPacketT + + 输入包泛型 + +.. data:: melobot.io.OutPacketT + + 输出包泛型 + +.. data:: melobot.io.EchoPacketT + + 回应包泛型 + +.. data:: melobot.io.InSourceT + + 输入源泛型 + +.. data:: melobot.io.OutSourceT + + 输出源泛型 + +.. data:: melobot.io.InOrOutSourceT + + 输入或输出源泛型 + +.. data:: melobot.io.IOSourceT + + 输入输出源泛型 diff --git a/docs/source/api/melobot.mixin.rst b/docs/source/api/melobot.mixin.rst new file mode 100644 index 00000000..4b602899 --- /dev/null +++ b/docs/source/api/melobot.mixin.rst @@ -0,0 +1,14 @@ +melobot.mixin +============= + +.. autoclass:: melobot.mixin.LogMixin + :members: + +.. autoclass:: melobot.mixin.FlagMixin + :members: + +.. autoclass:: melobot.mixin.AttrReprMixin + :members: + +.. autoclass:: melobot.mixin.HookMixin + :members: diff --git a/docs/source/api/melobot.session.rst b/docs/source/api/melobot.session.rst index ee781981..34c03a9c 100644 --- a/docs/source/api/melobot.session.rst +++ b/docs/source/api/melobot.session.rst @@ -14,6 +14,13 @@ melobot.session .. autoclass:: melobot.session.Rule :members: +.. autoclass:: melobot.session.CompareInfo + :members: + :exclude-members: __init__ + +.. autoclass:: melobot.session.DefaultRule + :members: + .. autofunction:: melobot.session.enter_session 会话状态 @@ -21,6 +28,8 @@ melobot.session .. autofunction:: melobot.session.suspend +.. autofunction:: melobot.session.get_session + .. autofunction:: melobot.session.get_rule .. autofunction:: melobot.session.get_session_store diff --git a/docs/source/api/melobot.typ.rst b/docs/source/api/melobot.typ.rst index ef263518..f2395cad 100644 --- a/docs/source/api/melobot.typ.rst +++ b/docs/source/api/melobot.typ.rst @@ -3,17 +3,11 @@ melobot.typ =========== -.. autoclass:: melobot.typ.HandleLevel - :members: - :exclude-members: __new__ - .. autoclass:: melobot.typ.LogicMode :members: .. autofunction:: melobot.typ.is_type -.. autofunction:: melobot.typ.abstractattr - .. autoclass:: melobot.typ.BetterABCMeta :members: :exclude-members: __call__, DummyAttribute @@ -21,19 +15,23 @@ melobot.typ .. autoclass:: melobot.typ.BetterABC :members: +.. autofunction:: melobot.typ.abstractattr + .. autoclass:: melobot.typ.SingletonMeta :exclude-members: __call__ .. autoclass:: melobot.typ.SingletonBetterABCMeta :exclude-members: __call__ -.. autoclass:: melobot.typ.Markable - :members: - :exclude-members: __init__ - .. autoclass:: melobot.typ.VoidType :members: +.. autoclass:: melobot.typ.AsyncCallable + :exclude-members: __call__, __init__ + +.. autoclass:: melobot.typ.SyncOrAsyncCallable + :exclude-members: __call__, __init__ + .. data:: melobot.typ.T 泛型 T,无约束 @@ -45,6 +43,3 @@ melobot.typ .. data:: melobot.typ.P :obj:`~typing.ParamSpec` 泛型 P,无约束 - -.. autoclass:: melobot.typ.AsyncCallable - :exclude-members: __call__, __init__ diff --git a/docs/source/api/melobot.utils.rst b/docs/source/api/melobot.utils.rst index f54aa325..70075d3f 100644 --- a/docs/source/api/melobot.utils.rst +++ b/docs/source/api/melobot.utils.rst @@ -1,6 +1,9 @@ melobot.utils ============= +基础工具 +---------- + .. autofunction:: melobot.utils.get_obj_name .. autofunction:: melobot.utils.singleton @@ -14,6 +17,8 @@ melobot.utils .. autofunction:: melobot.utils.to_coro +.. autofunction:: melobot.utils.to_sync + .. autofunction:: melobot.utils.if_not .. autofunction:: melobot.utils.unfold_ctx @@ -37,3 +42,68 @@ melobot.utils .. autofunction:: melobot.utils.async_at .. autofunction:: melobot.utils.async_interval + +检查/验证 +----------- + +.. autoclass:: melobot.utils.check.Checker + :exclude-members: __init__ + +.. autoclass:: melobot.utils.check.WrappedChecker + :exclude-members: __init__, check + +基础检查/验证工具 +------------------ + +.. autofunction:: melobot.utils.check.checker_join + +.. _melobot_match: + +匹配 +------ + +.. autoclass:: melobot.utils.match.Matcher + :exclude-members: __init__ + +.. autoclass:: melobot.utils.match.WrappedMatcher + :exclude-members: __init__, match + +基础匹配工具 +------------- + +.. autoclass:: melobot.utils.match.StartMatcher + +.. autoclass:: melobot.utils.match.ContainMatcher + +.. autoclass:: melobot.utils.match.EndMatcher + +.. autoclass:: melobot.utils.match.FullMatcher + +.. autoclass:: melobot.utils.match.RegexMatcher + +.. _melobot_parse: + +解析 +------- + +.. autoclass:: melobot.utils.parse.Parser + :exclude-members: __init__ + +.. autoclass:: melobot.utils.parse.AbstractParseArgs + :exclude-members: __init__ + +基础解析工具 +------------- + +.. autoclass:: melobot.utils.parse.CmdParser + :exclude-members: format + +.. autoclass:: melobot.utils.parse.CmdArgs + :exclude-members: __init__, vals + +.. autoclass:: melobot.utils.parse.CmdParserFactory + +.. autoclass:: melobot.utils.parse.CmdArgFormatter + +.. autoclass:: melobot.utils.parse.CmdArgFormatInfo + :exclude-members: __init__ diff --git a/docs/source/index.md b/docs/source/index.md index a7286957..ea953111 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -31,6 +31,7 @@ ob_api/index :caption: 更多 :hidden: +update-log melobot-prev ``` diff --git a/docs/source/intro/action-echo.md b/docs/source/intro/action-echo.md index 825f74e3..dfbc1a0c 100644 --- a/docs/source/intro/action-echo.md +++ b/docs/source/intro/action-echo.md @@ -11,7 +11,7 @@ ## 行为句柄 -当直接使用行为方法时,默认是尽快完成的。即不等待也不关心 OneBot 实现端是否成功完成了行为: +当直接使用 OneBot 的行为方法时,默认是尽快完成的。即不等待也不关心 OneBot 实现端是否成功完成了行为: ```python # 发送消息而不等待,也不关心是否成功 @@ -64,7 +64,7 @@ async def _(adapter: Adapter): await adapter.send("我一定是第二条消息") ``` -一整个函数都需要等待时,可以使用 {func}`.unfold_ctx` 装饰器: +一整个事件处理函数都需要等待时,可以使用 {func}`.unfold_ctx` 装饰器: ```python from melobot.protocols.onebot.v11 import Adapter, on_message, EchoRequireCtx @@ -78,6 +78,12 @@ async def _(adapter: Adapter): await adapter.send("我一定是第二条消息") ``` +注意:melobot 的所有行为操作,包括刚才提及的协议特定操作,或通用的操作(例如 {func}`.send_text`),都会返回行为句柄。 + +但这里提及的 {meth}`~.v11.Adapter.with_echo` 以及 {class}`.EchoRequireCtx` 等操作,只对 OneBot 协议特有行为操作有效。 + +**其他协议的行为句柄,是否需要类似的操作,取决于它们的实现。但不变的是:使用行为句柄,可以获知行为的执行情况。** + ```{admonition} 提示 :class: tip **不建议频繁等待行为操作**。等待总是需要更多时间,大量使用会降低运行效率。 @@ -97,7 +103,7 @@ async def _(adapter: Adapter): 和自定义消息段类似,有时候我们总是会需要自定义的 OneBot 行为类型的。一般这样构造: ```python -from melobot.protocols.onebot.v11.adapter.action import Action +from melobot.protocols.onebot.v11 import Action # 临时构造一个自定义行为 action = Action(type="action_type", params={"param1": 123456}) @@ -117,9 +123,9 @@ action.set_echo(True) handles = await adapter.call_output(action) ``` -实际上,适配器所有行为操作,都是先在内部构建 {class}`~.v11.adapter.action.Action` 对象,再通过 {meth}`~.v11.Adapter.call_output` 输出。 +实际上,适配器所有行为操作,都是先在内部构建 {class}`~melobot.adapter.model.Action` 对象,再通过 {meth}`~melobot.adapter.base.Adapter.call_output` 输出。 -而所有内置行为对象,也可以在文档 [OneBot v11 行为类型](onebot_v11_action) 中找到。你完全可以手动构造,再使用 {meth}`~.v11.Adapter.call_output` 输出,这适用于更精细的控制需求。 +而所有 OneBot v11 的行为对象,也可以在文档 [OneBot v11 行为类型](onebot_v11_action) 中找到。你完全可以手动构造,再使用 {meth}`~.v11.Adapter.call_output` 输出,这适用于更精细的控制需求。 ## 总结 diff --git a/docs/source/intro/async-callable.md b/docs/source/intro/async-callable.md index 683b28c8..c98e8d7e 100644 --- a/docs/source/intro/async-callable.md +++ b/docs/source/intro/async-callable.md @@ -79,3 +79,8 @@ aprint = to_async(print) {func}`.to_async` 只是将原对象包裹在一个异步函数中,从而满足异步可调用的接口。 **即:{func}`.to_async` 不做接口兼容外的处理,因此也就不会提供并发/并行的能力。** + +另外,有的接口支持 {class}`.SyncOrAsyncCallable` 类型的参数,这表明它是**同时支持同步可调用、异步可调用**的。{class}`.SyncOrAsyncCallable` 有以下特性: + +{class}`.SyncOrAsyncCallable`\[{data}`.P`, {data}`.T`\] {math}`\iff` +{external:class}`~collections.abc.Callable`\[{data}`.P`, {external:class}`~collections.abc.Awaitable`\[{data}`.T` | {external:class}`~typing.Awaitable`\[{data}`.T`\]\]\] diff --git a/docs/source/intro/create-bot.md b/docs/source/intro/create-bot.md index 325af418..8d02436a 100644 --- a/docs/source/intro/create-bot.md +++ b/docs/source/intro/create-bot.md @@ -4,6 +4,8 @@ 我们先从最简单的 OneBot v11 协议开始学习如何搭建一个机器人,并以此学习 melobot 的基本特性。 +虽然你可能不需要使用 OneBot v11 协议,但也可以简单浏览这部分教程,从而理解一些 melobot 中的基本概念。在后续部分,我们会拓展教程到协议通用领域。 + 首先需要一个“OneBot 实现程序”作为“前端”,完成与 qq 服务器的通信过程。请自行配置 OneBot 协议实现。 ```{admonition} 相关知识 @@ -23,8 +25,8 @@ ```{code} python :number-lines: -from melobot import Bot, PluginPlanner, send_text -from melobot.protocols.onebot.v11 import Adapter, ForwardWebSocketIO, on_start_match +from melobot import Bot, PluginPlanner, on_start_match, send_text +from melobot.protocols.onebot.v11 import ForwardWebSocketIO, OneBotV11Protocol @on_start_match(".sayhi") async def echo_hi() -> None: @@ -35,11 +37,11 @@ test_plugin = PluginPlanner(version="1.0.0", flows=[echo_hi]) if __name__ == "__main__": ( Bot(__name__) - .add_io(ForwardWebSocketIO("ws://127.0.0.1:8080")) - .add_adapter(Adapter()) + .add_protocol(OneBotV11Protocol(ForwardWebSocketIO("ws://127.0.0.1:8080"))) .load_plugin(test_plugin) .run() ) + ``` 运行后,在机器人加入的任何一个群聊中,或与机器人的私聊中,输入以 `.sayhi` 起始的消息,即可回复:`Hello, melobot!`。 @@ -67,16 +69,14 @@ test_plugin = PluginPlanner(version="1.0.0", flows=[echo_hi]) 接下来开始按以下步骤创建、初始化和启动一个 bot: 1. 通过 {class}`.Bot` 创建一个 bot; -2. 通过 {class}`.ForwardWebSocketIO` 添加一个 OneBot v11 协议的输入输出源; -3. 通过 {class}`.v11.Adapter` 添加 OneBot v11 协议的适配器; -4. 通过 {meth}`.load_plugin` 创建并加载插件 +2. 通过 {meth}`~.Bot.add_protocol` 添加 OneBot v11 相关的协议栈支持,此处需要提供一个输入输出源; +4. 通过 {meth}`~.Bot.load_plugin` 创建并加载插件 5. 启动 bot ```python ( Bot(__name__) - .add_io(ForwardWebSocketIO("ws://127.0.0.1:8080")) - .add_adapter(Adapter()) + .add_protocol(OneBotV11Protocol(ForwardWebSocketIO("ws://127.0.0.1:8080"))) .load_plugin(test_plugin) .run() ) diff --git a/docs/source/intro/event-preprocess.md b/docs/source/intro/event-preprocess.md index 7f691390..c0c28ac6 100644 --- a/docs/source/intro/event-preprocess.md +++ b/docs/source/intro/event-preprocess.md @@ -15,13 +15,12 @@ 有些时候,我们需要事件满足某些条件,才决定处理它。这就是检查器需要做的事。 -内置支持基于两种权限等级的检查:{class}`.LevelRole` 和 {class}`.GroupRole`。 +OneBot 协议组件,内置支持基于两种权限等级的检查:{class}`.LevelRole` 和 {class}`.GroupRole`。 {class}`.LevelRole` 总共分为五级权限(OWNER > SUPER > WHITE > NORMAL > BLACK)。使用例子如下所示: ```python -from melobot.protocols.onebot.v11 import on_message -from melobot.protocols.onebot.v11.utils import MsgChecker, LevelRole +from melobot.protocols.onebot.v11 import on_message, MsgChecker, LevelRole # 这些整型值都代表 qq 号 OWNER = 10001 @@ -71,7 +70,7 @@ async def _(): 频繁地传入各个等级包含的 id 很不方便,因此可以使用工厂类 {class}`.MsgCheckerFactory`: ```python -from melobot.protocols.onebot.v11.utils import MsgCheckerFactory +from melobot.protocols.onebot.v11 import MsgCheckerFactory checker_ft = MsgCheckerFactory( role=LevelRole.OWNER, @@ -93,8 +92,7 @@ priv_checker: PrivateMsgChecker = checker_ft.get_private(role=LevelRole.WHITE) {class}`.GroupRole` 分为三种:(OWNER、ADMIN、MEMBER)。使用例子如下: ```python -from melobot.protocols.onebot.v11 import on_message -from melobot.protocols.onebot.v11.utils import MsgChecker, GroupRole +from melobot.protocols.onebot.v11 import on_message, MsgChecker, GroupRole # 与刚才的 LevelRole 类似,但此时其他参数传递无效 @on_message(checker=MsgChecker(role=GroupRole.OWNER)) @@ -118,7 +116,7 @@ async def _(): 此外,检查器之间也支持逻辑或与非,及逻辑异或运算,利用这一特性可以构建精巧的检查逻辑: ```python -from melobot.protocols.onebot.v11.utils import MsgCheckerFactory, LevelRole, GroupRole +from melobot.protocols.onebot.v11 import MsgCheckerFactory, LevelRole, GroupRole # 构建一个常用的检查逻辑: # 私聊只有 SUPER 级别可以使用;在群聊白名单的群中,成员白名单中的成员或任何群管可以使用 @@ -139,11 +137,11 @@ final_checker = priv_c | grp_c1 | grp_c2 其他高级特性:自定义检查失败回调等,请参考 [内置检查器与检查器工厂](onebot_v11_check) 中各种对象的参数。 -除了这些接口,melobot 内部其实也有一种隐式检查,这就是**基于依赖注入的区分调用**: +除了这些接口,先前教程中提到的“基于依赖注入的类型收窄”,实际上就是一种内部隐式检查: ```python -from melobot.protocols.onebot.v11 import on_message, on_event -from melobot.protocols.onebot.v11.adapter.event import GroupMessageEvent, PrivateMessageEvent +from melobot.protocols.onebot.v11 import on_message, on_event, GroupMessageEvent, \ + PrivateMessageEvent @on_message(...) async def msg_handle1(ev: GroupMessageEvent): @@ -155,8 +153,7 @@ async def msg_handle2(ev: PrivateMessageEvent): # 只有触发事件属于 私聊消息事件 时,才会进入这个处理方法 ... -from melobot.protocols.onebot.v11 import on_event -from melobot.protocols.onebot.v11.adapter.event import MessageEvent +from melobot.protocols.onebot.v11 import on_event, MessageEvent from melobot.log import Logger as MeloLogger @on_event(...) @@ -181,14 +178,13 @@ async def owner_only_echo(): ... ``` -或者使用更高级的方法(实现子类),这适用于更复杂的需求,例如检查/验证时需要保存某些状态信息: +或者使用更高级的方法(实现 melobot core 的抽象类),这适用于更复杂的需求,例如检查/验证时需要保存某些状态信息: ```python -from melobot.protocols.onebot.v11 import on_message -from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11.utils import Checker +from melobot.utils.check import Checker +from melobot.protocols.onebot.v11 import on_message, MessageEvent -class FreqGuard(Checker): +class FreqGuard(Checker[MessageEvent]): def __init__(self) -> None: super().__init__() self.freq = 0 @@ -210,15 +206,17 @@ async def _(): ## 匹配 -匹配只对消息事件的文本内容生效。只有在匹配通过后,才能运行后续操作。其他事件绑定方法无法指定匹配。 +匹配只对文本事件 {class}`.TextEvent` 生效,所以在 OneBot 协议中就只对 {class}`.MessageEvent` 有效。只有在匹配通过后,才能运行后续操作。其他事件绑定方法无法指定匹配。 -常用的几个事件绑定接口,就是内置了匹配的流程:{func}`~.v11.handle.on_command`、{func}`~.v11.handle.on_start_match`、{func}`~.v11.handle.on_contain_match`、{func}`~.v11.handle.on_full_match`、{func}`~.v11.handle.on_end_match`、{func}`~.v11.handle.on_regex_match`。 +常用的几个通用事件绑定方法,就是内置了匹配的流程:{func}`~melobot.handle.on_command`、{func}`~melobot.handle.on_start_match`、{func}`~melobot.handle.on_contain_match`、{func}`~melobot.handle.on_full_match`、{func}`~melobot.handle.on_end_match`、{func}`~melobot.handle.on_regex_match`。 -对应的匹配器可查看文档:[内置匹配器](onebot_v11_match)。你也可以自定义匹配器: +melobot core 预置的匹配器可查看文档:[内置匹配器](melobot_match),OneBot 协议支持没有实现更多匹配器类型。 + +你也可以自定义匹配器: ```python +from melobot.utils.match import Matcher from melobot.protocols.onebot.v11 import on_message -from melobot.protocols.onebot.v11.utils import Matcher class StartEndMatch(Matcher): def __init__(self, start: str, end: str) -> None: @@ -233,11 +231,11 @@ async def _(): ... ``` -其他高级特性:自定义匹配成功回调,自定义匹配失败回调等,请参考 [内置匹配器](onebot_v11_match) 中各种对象的参数。 +其他高级特性:自定义匹配失败回调等,请参考 [内置匹配器](melobot_match) 中各种对象的参数。 ## 解析 -解析只对消息事件的文本内容生效。解析完成后将会生成一个 {class}`.ParseArgs` 对象。其他事件绑定方法无法指定解析。 +解析只对文本事件 {class}`.TextEvent` 生效,所以在 OneBot 协议中就只对 {class}`.MessageEvent` 有效。解析完成后将会生成一个 {class}`.AbstractParseArgs` 对象。其他事件绑定方法无法指定解析。 想象一个典型的使用案例,你需要: @@ -245,16 +243,15 @@ async def _(): - 匹配到“天气”指令的处理方法 - 传递参数列表 `["杭州", "7"]` 给处理方法,实现具体的逻辑。 -显然,自己编写指令解析是比较费劲的。可以使用 {class}`.CmdParser`,并利用 {func}`~.v11.handle.GetParseArgs` 获取解析参数: +显然,自己编写指令解析是比较费劲的。可以使用 melobot core 内置的命令解析器 {class}`.CmdParser`,并利用依赖注入获取解析参数: ```python -from melobot.protocols.onebot.v11 import on_message, ParseArgs -from melobot.protocols.onebot.v11.utils import CmdParser -from melobot.protocols.onebot.v11.handle import GetParseArgs +from melobot.utils.parse import CmdParser, CmdArgs +from melobot.protocols.onebot.v11 import on_message @on_message(parser=CmdParser(cmd_start='.', cmd_sep=' ', targets='天气')) -# 使用 GetParseArgs 进行依赖注入 -async def _(args: ParseArgs = GetParseArgs()): +# 使用 CmdArgs 进行依赖注入,它实际上是 AbstractParseArgs 的子类型 +async def _(args: CmdArgs): assert args.name == "天气" assert args.vals == ["杭州", "7"] ``` @@ -262,16 +259,15 @@ async def _(args: ParseArgs = GetParseArgs()): 需要多个指令起始符,多个指令间隔符,多个匹配的目标?这些也同样支持: ```python -from melobot.protocols.onebot.v11 import on_message, ParseArgs -from melobot.protocols.onebot.v11.utils import CmdParser -from melobot.protocols.onebot.v11.handle import GetParseArgs +from melobot.utils.parse import CmdParser, CmdArgs +from melobot.protocols.onebot.v11 import on_message @on_message(parser=CmdParser( cmd_start=[".", "~"], cmd_sep=[" ", "#"], targets=["天气", "weather"] )) -async def _(args: ParseArgs = GetParseArgs()): +async def _(args: CmdArgs): ... ``` @@ -290,7 +286,7 @@ async def _(args: ParseArgs = GetParseArgs()): cmd_sep=[" ", "#"], targets=["功能1", "功能2", "功能3"] )) -async def _(args: ParseArgs = GetParseArgs()): +async def _(args: CmdArgs): match args.name: case "功能1": func1(args.vals) @@ -304,14 +300,40 @@ async def _(args: ParseArgs = GetParseArgs()): 同理也可以实现子命令支持,这里不再演示。 +另外,为了方便识别出一组有不同 `name` 的解析参数,实际上都是同一个解析器解析出的结果,可以使用 `tag` 参数: + +```python +@on_message(parser=CmdParser( + cmd_start=".", + cmd_sep=" ", + targets=["echo", "回显"], + tag="bar" +)) +async def _(args: CmdArgs): + # 如果文本内容为:".回显 hi" + assert args.name == "回显" + assert args.tag == "bar" + +# 不指定 tag 时,自动设置为 targets 第一元素,或 targets 本身(如果为字符串) +@on_message(parser=CmdParser( + cmd_start=".", + cmd_sep=" ", + targets=["echo", "回显"] +)) +async def _(args: CmdArgs): + # 如果文本内容为:".回显 你好呀" + assert args.name == "回显" + assert args.tag == "echo" +``` + 使用 {func}`.on_message` 手动给定 {class}`.CmdParser` 还是略显麻烦。一般的情景,更建议使用 {func}`.on_command`: ```python -from melobot.protocols.onebot.v11 import on_command, ParseArgs -from melobot.protocols.onebot.v11.handle import GetParseArgs +from melobot.handle import on_command +from melobot.utils.parse import CmdArgs @on_command(cmd_start=[".", "~"], cmd_sep=[" ", "#"], targets=["天气", "weather"]) -async def _(args: ParseArgs = GetParseArgs()): +async def _(args: CmdArgs): ... ``` @@ -324,9 +346,9 @@ async def _(args: ParseArgs = GetParseArgs()): 下面是一个例子。这个 `add` 指令,接受两个浮点数,且第二参数可以有默认值: ```python -from melobot.protocols.onebot.v11 import on_command, ParseArgs -from melobot.protocols.onebot.v11.handle import GetParseArgs -from melobot.protocols.onebot.v11.utils import CmdArgFormatter as Fmtter +from melobot.handle import on_command +from melobot.utils.parse import CmdArgs +from melobot.utils.parse import CmdArgFormatter as Fmtter @on_command( cmd_start=".", @@ -353,7 +375,7 @@ from melobot.protocols.onebot.v11.utils import CmdArgFormatter as Fmtter ), ], ) -async def _(args: ParseArgs = GetParseArgs()): +async def _(args: CmdArgs): pass ``` @@ -376,9 +398,7 @@ fmtters = [Fmtter(...), None, Fmtter(...)] 此外,你还可以自定义“参数转换失败”、“参数验证失败”、“参数缺少”时的回调。比如直接静默,而不是在日志提示: ```python -from melobot.utils import to_async - -do_nothing = to_async(lambda *_: None) +do_nothing = lambda *_: None fmtters = [ Fmtter( @@ -391,11 +411,11 @@ fmtters = [ ] ``` -或者利用回调函数 {class}`.FormatInfo` 参数提供的信息,给用户回复提示: +或者利用回调函数 {class}`.CmdArgFormatInfo` 参数提供的信息,给用户回复提示: ```python from melobot import send_text -from melobot.protocols.onebot.v11.utils import FormatInfo +from melobot.utils.parse import CmdArgFormatInfo async def convert_fail(self, info: FormatInfo) -> None: e_class = f"{info.exc.__class__.__module__}.{info.exc.__class__.__qualname__}" @@ -451,19 +471,41 @@ fmtters = [ ## 自定义解析器 -使用内置的抽象类来自定义解析器: +实现 melobot core 内置的抽象类来自定义解析器: ```python +from melobot.utils.parse import Parser from melobot.protocols.onebot.v11 import on_message -from melobot.protocols.onebot.v11.utils import Parser class MyParser(Parser): - async def parse(text: str) -> ParseArgs | None: + async def parse(text: str) -> AbstractParseArgs | None: # 返回 None 代表没有有效的解析结果 ... @on_message(parser=MyParser()) -async def _(): +async def _(args: AbstractParseArgs): + ... +``` + +```python +from melobot.utils.parse import Parser, AbstractParseArgs +from dataclasses import dataclass + +# 还可以进一步子类化 AbstractParseArgs 提供信息更丰富的解析参数: +@dataclass +class MyCmdArgs(AbstractParseArgs): + name: str + tag: str | None + vals: list[Any] + +class MyCmdParser(Parser): + async def parse(text: str) -> MyCmdArgs | None: + ... + +# 是不是感觉很熟悉? +# 实际上,内置的 CmdParser,就是像上面这样实现的 :) +@on_message(parser=MyCmdParser()) +async def _(args: MyCmdArgs): ... ``` @@ -471,7 +513,7 @@ async def _(): 本篇主要说明了预处理机制中的检查、匹配和解析。 -消息事件绑定方法,检查、匹配和解析可以同时指定。顺序是:先检查,再匹配,最后解析。其他事件绑定方法,只能指定检查。 +对于文本事件的绑定方法,检查、匹配和解析可以同时指定。顺序是:先检查,再匹配,最后解析。而其他事件绑定方法,只能指定检查。 再次提醒,所有内置预处理机制,**均不是异步安全的**。若需要异步安全,请实现自定义类。 diff --git a/docs/source/intro/event-process.md b/docs/source/intro/event-process.md index 23d4eae5..b72422ef 100644 --- a/docs/source/intro/event-process.md +++ b/docs/source/intro/event-process.md @@ -6,8 +6,15 @@ 本文档将这些方法称为“绑定方法”,同时将绑定方法绑定的函数称为:“处理方法”或“处理函数”。 -- 绑定一个任意事件的处理方法:{func}`~.v11.handle.on_event` -- 绑定一个消息事件的处理方法:{func}`~.v11.handle.on_message`、{func}`~.v11.handle.on_at_qq`、{func}`~.v11.handle.on_command`、{func}`~.v11.handle.on_start_match`、{func}`~.v11.handle.on_contain_match`、{func}`~.v11.handle.on_full_match`、{func}`~.v11.handle.on_end_match`、{func}`~.v11.handle.on_regex_match` +协议独立的绑定方法有: + +- 绑定一个来自任意协议的,任意事件的处理方法:{func}`~melobot.handle.on_event` +- 绑定一个来自任意协议的,任意文本事件的处理方法:{func}`~melobot.handle.on_text`, {func}`~melobot.handle.on_command`、{func}`~melobot.handle.on_start_match`、{func}`~melobot.handle.on_contain_match`、{func}`~melobot.handle.on_full_match`、{func}`~melobot.handle.on_end_match`、{func}`~melobot.handle.on_regex_match` + +OneBot v11 协议特有的绑定方法有: + +- 绑定一个任意 OneBot v11 事件的处理方法:{func}`~.v11.handle.on_event` +- 绑定一个消息事件的处理方法:{func}`~.v11.handle.on_message`、{func}`~.v11.handle.on_at_qq`、 - 绑定一个请求事件的处理方法:{func}`~.v11.handle.on_request` - 绑定一个通知事件的处理方法:{func}`~.v11.handle.on_notice` - 绑定一个元事件的处理方法:{func}`~.v11.handle.on_meta` @@ -15,7 +22,7 @@ 这些绑定方法的参数很多,你可以先简单浏览。关于这些方法的使用,后续会详细讲解。现在让我们先学习一些基础知识。绑定方法的使用都一样,直接用作装饰器即可: ```python -from melobot.protocols.onebot.v11 import on_start_match +from melobot.handle import on_start_match @on_start_match(...) async def func() -> None: @@ -35,8 +42,7 @@ async def func() -> None: 通过类型注解驱动的依赖注入,即可方便地在处理方法中获得触发的事件。例如使用 {class}`.MessageEvent` 注解参数,melobot 将知道你需要一个消息事件作为 event 参数的值: ```python -from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11 import on_message +from melobot.protocols.onebot.v11 import MessageEvent, on_message @on_message(...) async def func1(event: MessageEvent): @@ -55,7 +61,7 @@ async def func1(): ... ``` -需要注意的是,通用接口返回值将标注为事件基类型 {class}`~melobot.adapter.model.Event`,这可能不是你想要的,因此可以自行添加标注: +需要注意的是,通用接口返回值将标注为 melobot 的事件基类型 {class}`~melobot.adapter.model.Event`,这可能不是你想要的,因此可以自行添加标注: ```python e: MessageEvent = get_event() @@ -65,13 +71,58 @@ e: MessageEvent = get_event() 这些事件也有着各自的属性和方法,API 文档中也已说明。 +## 通用绑定函数与依赖注入 + +另外,通用的绑定方法,依然可以使用协议特定的事件进行注入: + +```python +from melobot.handle import on_event +from melobot.protocols.onebot.v11 import MessageEvent + +# 此通用接口支持任意事件类型,因此可以接收到 MessageEvent 这种子类型 +@on_event(...) +async def func(event: MessageEvent) -> None: + # 依赖注入会有类型担保,由于标注了 MessageEvent 类型, + # 因此 event 如果不是 MessageEvent 子类型,则不会进入处理方法 + # 由此实现了智能的类型收窄 + ... +``` + +同理,对于上面提到的,通用的文本事件的绑定接口,由于 {class}`.MessageEvent` 是文本事件基类 {class}`.TextEvent` 的子类,因此这样也是可以的: + +```python +from melobot.handle import on_start_match +from melobot.protocols.onebot.v11 import MessageEvent + +# 此接口首先限制必须为 TextEvent +@on_start_match(...) +async def func(event: MessageEvent) -> None: + # 随后注解将其收窄到 MessageEvent 类型 + ... +``` + +但这样显然就不太可以了: + +```python +from melobot.handle import on_start_match +from melobot.protocols.onebot.v11 import NoticeEvent + +# 此接口首先限制必须为 TextEvent +@on_start_match(...) +async def func(event: NoticeEvent) -> None: + # NoticeEvent 不是 TextEvent 子类, + # 还没到依赖注入类型收窄,NoticeEvent 就过不了 on_start_match 这一关 + ... +``` + +注意:OneBot v11 协议中,只有 {class}`.MessageEvent` 是 {class}`.TextEvent` 的子类。 + ## 基于事件信息的处理 通过事件对象提供的信息,可以实现更有趣的处理逻辑: ```{code} python -from melobot.protocols.onebot.v11 import on_start_match -from melobot.protocols.onebot.v11.adapter.event import MessageEvent +from melobot.protocols.onebot.v11 import on_start_match, MessageEvent from melobot import send_text OWNER_QID = 10001 @@ -111,9 +162,7 @@ async def say_hi(e: MessageEvent) -> None: 需要哪种类型的消息段,就传递哪种消息段的 `type` 作为参数: ```python -from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11.adapter.segment import ImageSegment -from melobot.protocols.onebot.v11 import on_message +from melobot.protocols.onebot.v11 import MessageEvent, ImageSegment, on_message @on_message(...) async def _(e: MessageEvent): @@ -143,8 +192,7 @@ for img in e.get_segments("image"): 如果只需要 data 字段的某一参数,使用 {meth}`~.MessageEvent.get_datas` 即可: ```python -from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11 import on_message +from melobot.protocols.onebot.v11 import MessageEvent, on_message @on_message(...) async def _(e: MessageEvent): diff --git a/docs/source/intro/msg-action.md b/docs/source/intro/msg-action.md index b4428226..d9fc1642 100644 --- a/docs/source/intro/msg-action.md +++ b/docs/source/intro/msg-action.md @@ -33,8 +33,7 @@ async def _(adapter: Adapter): 如果要发送多媒体内容,则只能使用适配器的 {meth}`~.v11.Adapter.send` 接口。首先构造**消息段对象**,然后传入 {meth}`~.v11.Adapter.send` 作为参数。例如: ```python -from melobot.protocols.onebot.v11 import Adapter, on_message -from melobot.protocols.onebot.v11.adapter.segment import ImageSegment +from melobot.protocols.onebot.v11 import Adapter, on_message, ImageSegment @on_message(...) async def _(adapter: Adapter): @@ -65,8 +64,7 @@ async def _(adapter: Adapter): 单条消息中,自然可能有多种类型的消息段同时存在。此时这样处理: ```python -from melobot.protocols.onebot.v11 import Adapter, on_message -from melobot.protocols.onebot.v11.adapter.segment import ImageSegment, TextSegment +from melobot.protocols.onebot.v11 import Adapter, on_message, ImageSegment, TextSegment @on_message(...) async def _(): @@ -140,7 +138,7 @@ async def _(adapter: Adapter): 除使用消息段对象外,也可以使用**CQ 字符串**直接表示单条消息的所有消息内容。但只能从消息段对象生成 cq 字符串: ```python -from melobot.protocols.onebot.v11.adapter.segment import ImageSegment +from melobot.protocols.onebot.v11 import ImageSegment img_cq: str = ImageSegment(file="https://example.com/test.jpg").to_cq() ``` @@ -162,7 +160,7 @@ CQ 字符串存在注入攻击的安全隐患。因此 melobot 不提供将 cq 构造转发消息段: ```python -from melobot.protocols.onebot.v11.adapter.segment import ForwardSegment +from melobot.protocols.onebot.v11 import ForwardSegment # forward_id 是转发 id,可通过消息事件的 get_datas("forward", "id") 获得 seg = ForwardSegment(forward_id) @@ -175,16 +173,20 @@ seg = ForwardSegment(forward_id) 构造合并转发结点: ```python -from melobot.protocols.onebot.v11.adapter.segment import NodeSegment +from melobot.protocols.onebot.v11 import NodeSegment, NodeReferSegment # 这里的 msg_id 是已存在的消息的 id,可通过消息事件的 id 获得 refer_node = NodeSegment(id=msg_id) + +# 等价的从子类构造形式,拥有更好的语义: +refer_node = NodeReferSegment(id=msg_id) ``` 构造合并转发自定义结点: ```python -from melobot.protocols.onebot.v11.adapter.segment import NodeSegment +from melobot.protocols.onebot.v11 import NodeSegment, NodeGocqCustomSegment, \ + NodeStdCustomSegment # content 是消息内容,与上述消息段发送方法(例如 send, send_custom)的第一参数相同 # 后续参数是在转发消息中显示的,发送人昵称 和 发送人的qq号(int 类型) @@ -201,6 +203,14 @@ node3 = NodeSegment( name="melobot instance", uin=10001 ) + +# 以上方法是按照 go-cq 风格构造的,如果需要使用 onebot v11 标准规定的格式: +node4 = NodeSegment(content=..., name=..., uin=..., use_std=True) + +# 等价的从子类构造 go-cq 风格的:(参数名稍有不同,可自行查阅 API 文档) +node5 = NodeGocqCustomSegment(...) +# 等价的从子类构造 标准 风格的:(参数名稍有不同,可自行查阅 API 文档) +node6 = NodeStdCustomSegment(...) ``` 将消息结点组成列表,就是一条转发消息的等价表达了,使用 {meth}`~.v11.Adapter.send_forward` 来发送它: diff --git a/docs/source/ob_api/index.rst b/docs/source/ob_api/index.rst index 404d849c..47dc1d94 100644 --- a/docs/source/ob_api/index.rst +++ b/docs/source/ob_api/index.rst @@ -12,20 +12,14 @@ protocols.onebot API - 目前只支持 OneBot v11,以下所有 API 均为 v11 的 API -以下组件可从 `melobot.protocols.onebot.v11` 命名空间直接导入: - -- :data:`.PROTOCOL_IDENTIFIER` -- :class:`~.v11.adapter.base.Adapter`, :class:`.EchoRequireCtx` -- :class:`.ForwardWebSocketIO`, :class:`.ReverseWebSocketIO`, :class:`.HttpIO` -- :class:`~.adapter.event.Event`, :class:`~.adapter.segment.Segment`, :class:`~.adapter.action.Action`, :class:`~.adapter.echo.Echo` -- :func:`.on_event`, :func:`.on_message`, :func:`.on_start_match`, :func:`.on_contain_match`, :func:`.on_full_match`, :func:`.on_end_match`, :func:`.on_regex_match`, :func:`.on_command`, :func:`.on_at_qq`, :func:`.on_notice`, :func:`.on_request`, :func:`.on_meta`, :func:`.msg_session` -- :class:`.LevelRole`, :class:`.GroupRole`, :class:`.ParseArgs` +所有公开的组件都可以从 `melobot.protocols.onebot.v11` 直接导入。 各模块 API 文档索引: .. toctree:: :maxdepth: 1 + v11 v11.const v11.adapter v11.io diff --git a/docs/source/ob_api/v11.const.rst b/docs/source/ob_api/v11.const.rst index aa017424..69d8441b 100644 --- a/docs/source/ob_api/v11.const.rst +++ b/docs/source/ob_api/v11.const.rst @@ -12,6 +12,10 @@ v11 协议常量 OneBot v11 协议版本 +.. data:: melobot.protocols.onebot.v11.PROTOCOL_SUPPORT_AUTHOR + + OneBot v11 协议支持模块的作者 + .. data:: melobot.protocols.onebot.v11.PROTOCOL_IDENTIFIER OneBot v11 协议 melobot 侧实现的唯一标识 diff --git a/docs/source/ob_api/v11.handle.rst b/docs/source/ob_api/v11.handle.rst index 3f364750..1f4f2892 100644 --- a/docs/source/ob_api/v11.handle.rst +++ b/docs/source/ob_api/v11.handle.rst @@ -5,4 +5,3 @@ v11 处理流相关接口 --------------------- .. automodule:: melobot.protocols.onebot.v11.handle - :exclude-members: ParseArgsCtx diff --git a/docs/source/ob_api/v11.io.rst b/docs/source/ob_api/v11.io.rst index b148c70b..f966627e 100644 --- a/docs/source/ob_api/v11.io.rst +++ b/docs/source/ob_api/v11.io.rst @@ -1,11 +1,23 @@ v11.io ====== -v11 输出输出层 ----------------- +v11 输出输出层抽象类 +-------------------- -.. autoclass:: melobot.protocols.onebot.v11.io.BaseIO - :exclude-members: __init__, open, opened, close, input, output, logger +.. autoclass:: melobot.protocols.onebot.v11.io.BaseSource + :exclude-members: __init__, open, opened, close + +.. autoclass:: melobot.protocols.onebot.v11.io.BaseInSource + :exclude-members: __init__, open, opened, close, input + +.. autoclass:: melobot.protocols.onebot.v11.io.BaseOutSource + :exclude-members: __init__, open, opened, close, output + +.. autoclass:: melobot.protocols.onebot.v11.io.BaseIOSource + :exclude-members: __init__, open, opened, close, input, output + +v11 输出输出层实现类 +-------------------- .. autoclass:: melobot.protocols.onebot.v11.io.ForwardWebSocketIO :exclude-members: open, close, input, output diff --git a/docs/source/ob_api/v11.rst b/docs/source/ob_api/v11.rst new file mode 100644 index 00000000..b99bba8c --- /dev/null +++ b/docs/source/ob_api/v11.rst @@ -0,0 +1,5 @@ +v11 +=== + +.. autoclass:: melobot.protocols.onebot.v11.OneBotV11Protocol + :members: diff --git a/docs/source/ob_api/v11.utils.rst b/docs/source/ob_api/v11.utils.rst index 88abd070..ef245b1e 100644 --- a/docs/source/ob_api/v11.utils.rst +++ b/docs/source/ob_api/v11.utils.rst @@ -3,21 +3,19 @@ v11.utils .. _onebot_v11_check: -v11 检查(验证)器 +检查/验证 -------------------- -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.Checker - :exclude-members: __init__ - -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.WrappedChecker - :exclude-members: __init__, check - .. autoclass:: melobot.protocols.onebot.v11.utils.LevelRole :exclude-members: __new__ .. autoclass:: melobot.protocols.onebot.v11.utils.GroupRole :exclude-members: __new__ +.. autofunction:: melobot.protocols.onebot.v11.utils.get_level_role + +.. autofunction:: melobot.protocols.onebot.v11.utils.get_group_role + .. autoclass:: melobot.protocols.onebot.v11.utils.MsgChecker .. autoclass:: melobot.protocols.onebot.v11.utils.GroupMsgChecker @@ -27,45 +25,3 @@ v11 检查(验证)器 .. autoclass:: melobot.protocols.onebot.v11.utils.MsgCheckerFactory .. autoclass:: melobot.protocols.onebot.v11.utils.AtMsgChecker - -.. _onebot_v11_match: - -v11 匹配器 --------------- - -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.Matcher - :exclude-members: __init__ - -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.WrappedMatcher - :exclude-members: __init__, match - -.. autoclass:: melobot.protocols.onebot.v11.utils.StartMatcher - -.. autoclass:: melobot.protocols.onebot.v11.utils.ContainMatcher - -.. autoclass:: melobot.protocols.onebot.v11.utils.EndMatcher - -.. autoclass:: melobot.protocols.onebot.v11.utils.FullMatcher - -.. autoclass:: melobot.protocols.onebot.v11.utils.RegexMatcher - -.. _onebot_v11_parse: - -v11 解析器 --------------- - -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.Parser - :exclude-members: __init__ - -.. autoclass:: melobot.protocols.onebot.v11.utils.abc.ParseArgs - :exclude-members: __init__ - -.. autoclass:: melobot.protocols.onebot.v11.utils.CmdParser - :exclude-members: format - -.. autoclass:: melobot.protocols.onebot.v11.utils.CmdArgFormatter - -.. autoclass:: melobot.protocols.onebot.v11.utils.FormatInfo - :exclude-members: __init__ - -.. autoclass:: melobot.protocols.onebot.v11.utils.CmdParserFactory diff --git a/docs/source/ob_refer/preprocess.md b/docs/source/ob_refer/preprocess.md index 2cebad01..880ffdbe 100644 --- a/docs/source/ob_refer/preprocess.md +++ b/docs/source/ob_refer/preprocess.md @@ -2,22 +2,26 @@ ## 预处理简介 -melobot 的 OneBot 支持中提供了一些常用的工具,用来在事件处理前进行一些特定的操作。 - -这些操作一般称为预处理。以下是内置的预处理支持。 +melobot core 以及 melobot 的 OneBot 支持中提供了一些常用的工具,用来在事件处理前进行一些特定的操作。这些操作一般称为预处理。以下是内置的预处理流程。 ## 检查(校验) 检查当前事件是否满足某些条件,如果满足则通过检查。 +OneBot 协议支持部分提供了一些额外的检查组件,melobot core 组件只提供了抽象接口。 + ## 匹配 检查当前消息事件的消息内容,是否满足某些匹配条件。例如:以 `xxx` 起始,包含 `xxx` 或可以通过指定的正则表达式匹配等。 +OneBot 协议支持部分没有提供额外的匹配组件,主要依靠来自 melobot core 的匹配组件。 + ## 解析 对当前消息事件的消息内容进行解析,并返回一个包含解析结果的 {class}`.ParseArgs` 对象。 +OneBot 协议支持部分没有提供额外的解析组件,主要依靠来自 melobot core 的解析组件。 + ```{admonition} 重要提示 :class: attention 内置的所有预处理支持,**都不是异步安全的**。 diff --git a/docs/source/update-log.md b/docs/source/update-log.md new file mode 100644 index 00000000..bc3f2b3b --- /dev/null +++ b/docs/source/update-log.md @@ -0,0 +1,258 @@ + + +# 更新日志 + +## v3.1.0 + +### ⏩变更 + +- [core] 改进了内部事件分发机制,现在所有情况下的事件处理都不再阻塞分发。原始的处理流优先级枚举 `HandleLevel` 已移除,现在通过 int 值定义优先级,默认处理流优先级为 0 ([e2aaa72](https://github.com/Meloland/melobot/commit/e2aaa72)) + +- [core] {func}`.async_later` 和 {func}`.async_at` 现在返回 {external:class}`~asyncio.Task` 而不是 {external:class}`~asyncio.Future` ([3b7bea2](https://github.com/Meloland/melobot/commit/3b7bea2)) + +- [core] 现在插件的 {attr}`.PluginLifeSpan.INITED` 生命周期钩子结束前,该插件所有处理流不会生效。避免通过此钩子运行异步初始化,初始化未完成处理流就先启动的不合理现象。如果需要避免额外的等待,请在钩子函数内使用 {external:func}`~asyncio.create_task` ([ca06a05](https://github.com/Meloland/melobot/commit/ca06a05)) + +- [All] 绝大多数只支持 {class}`.AsyncCallable` 参数的接口,现在变更为支持 {class}`.SyncOrAsyncCallable`,参数可接受同步或异步的可调用对象 ([b6c7f24](https://github.com/Meloland/melobot/commit/b6c7f24)) + +- [All] 为避免依赖注入出现问题,现在不能在 `on_xxx` 函数下方使用装饰器 ([9ba2265](https://github.com/Meloland/melobot/commit/9ba2265)),必须通过 `decos` 参数: + +```python +# 现在不能再使用以下形式: +@on_xxx(...) +@aaa(...) +@bbb(...) +async def _(): ... + +# 需要换为: +@on_xxx(..., decos=[aaa(...), bbb(...)]) +async def _(): ... +``` + +- [OneBot] 移除了有严重问题而无法修复的 `msg_session` 函数 ([1a372de](https://github.com/Meloland/melobot/commit/1a372de)),推荐使用 {class}`.DefaultRule` 或 `legacy_session` 参数或 `rule` 参数替代: + +```python +# 原始用法 +with msg_session(): ... + +# 现在的替代方法: + +_RULE = DefaultRule() +# 注意不要直接在 enter_session 中初始化 +# 这样会导致每次生成一个新的 rule +with enter_session(_RULE): ... + +# 或者 + +# 对于 on_xxx 接口,如有 legacy_session 参数, +# 置为 True 实现类似 msg_session 效果 +@on_xxx(..., legacy_session=True) +async def _(): + # 注意进入会话在所有 decos 装饰器之前 + # 如果这个顺序不符合你的需求,还是建议在 decos 中使用 unfold_ctx(enter_session(...)) + ... + +# 或者 + +class MyRule(Rule): ... +# 对于 on_xxx 接口,如有 rule 参数 +# 可以直接在这里初始化规则,并提供 +@on_xxx(..., rule=MyRule()) +async def _(): + # 注意进入会话在所有 decos 装饰器之前 + ... +``` + +- [OneBot] 部分 api 已并入 melobot core。尝试按原样导入并使用这些 api 依然可以工作,但会发出弃用警告。兼容原样导入的行为将在 `3.1.1` 移除 ([841eddd](https://github.com/Meloland/melobot/commit/841eddd)),请及时迁移。我们强烈建议您重新阅读一遍 [相关使用方法](./intro/event-preprocess) 来了解**新 api 的使用技巧**。以下是变动的 api: + +```shell +# 原始位置 (onebot 模块是 melobot.protocols.onebot) -> 新位置 +onebot.v11.utils.Checker -> melobot.utils.check.Checker + +onebot.v11.utils.Matcher -> melobot.utils.match.Matcher +onebot.v11.utils.StartMatcher -> melobot.utils.match.StartMatcher +onebot.v11.utils.ContainMatcher -> melobot.utils.match.ContainMatcher +onebot.v11.utils.EndMatcher -> melobot.utils.match.EndMatcher +onebot.v11.utils.FullMatcher -> melobot.utils.match.FullMatcher +onebot.v11.utils.RegexMatcher -> melobot.utils.match.RegexMatcher + +onebot.v11.utils.Parser -> melobot.utils.parse.Parser +onebot.v11.utils.ParseArgs -> melobot.utils.parse.CmdArgs +onebot.v11.utils.CmdParser -> melobot.utils.parse.CmdParser +onebot.v11.utils.CmdParserFactory -> melobot.utils.parse.CmdParserFactory +onebot.v11.utils.CmdArgFormatter -> melobot.utils.parse.CmdArgFormatter +onebot.v11.utils.FormatInfo -> melobot.utils.parse.CmdArgFormatInfo + +onebot.v11.handle.on_start_match -> melobot.handle.on_start_match +onebot.v11.handle.on_contain_match -> melobot.handle.on_contain_match +onebot.v11.handle.on_end_match -> melobot.handle.on_end_match +onebot.v11.handle.on_full_match -> melobot.handle.on_full_match +onebot.v11.handle.on_regex_match -> melobot.handle.on_regex_match +onebot.v11.handle.on_command -> melobot.handle.on_command + +# 特别注意,此 api 原本用作默认值,表示需要一个解析参数。但现在只需要注解类型即可 +# 但此 api 依然可以使用,但下一版本直接删除,在整个项目中将完全不存在 +onebot.v11.handle.GetParseArgs -> melobot.handle.GetParseArgs +``` + +- [OneBot] 除以上变更的 api 外,其余 onebot 协议支持部分的公开接口可以从 `melobot.protocols.onebot.v11` 直接导入 ([e2aaa72](https://github.com/Meloland/melobot/commit/e2aaa72)) + +### ✨ 新增 + +- [core] {class}`.Rule` 现在支持两种抽象方法 {meth}`~.Rule.compare` 与 {meth}`~.Rule.compare_with` (提供更有用的对比信息),二选一实现即可 ([ef173c6](https://github.com/Meloland/melobot/commit/ef173c6)) + +- [core] {class}`.SessionStore` 现在可以使用 set 方法设置值,方便链式调用 ([36b555e](https://github.com/Meloland/melobot/commit/36b555e)) + +```python +# 等价于 store[key] = value +store.set(key, value) +``` + +- [core] 现在会话可以被直接依赖注入,或在当前上下文通过 {func}`.get_session` 获取 ([e2aaa72](https://github.com/Meloland/melobot/commit/e2aaa72))。例如: + +```python +from melobot.session import Session, get_session + +@on_xxx(...) +async def _(session: Session, ...): ... + +# 或者 + +@on_xxx(...) +async def _(...): + # 获取当前上下文中的会话对象 + session = get_session() +``` + +- [core] 新增接口兼容装饰器函数 {func}`.to_sync`,非常不常用。极少数需要兼容同步接口时使用 ([ca06a05](https://github.com/Meloland/melobot/commit/ca06a05)) + +- [core] {func}`.if_not` 现在支持新参数 `accept`,作为条件为真时执行的回调 ([d89d62e](https://github.com/Meloland/melobot/commit/d89d62e)) + +- [core] 新增了跨协议的 `on_xxx` 方法:{func}`~.melobot.handle.on_event`(用于绑定任意协议的任意事件处理方法) 和 {func}`~.melobot.handle.on_text`(用于绑定任意文本事件处理方法)([841eddd](https://github.com/Meloland/melobot/commit/841eddd)) + +- [core] 现在 {class}`.CmdParser` 支持初始化参数 `tag` ([a7a183e](https://github.com/Meloland/melobot/commit/a7a183e), [9fdde3b](https://github.com/Meloland/melobot/commit/9fdde3b)),该值会传递给解析参数,用于标识: + +```python +parser = CmdParser(cmd_start=".", cmd_sep=" ", targets=["echo", "回显"], tag="bar") +args = await parser.parse(".回显 hi") +if args is not None: + assert args.name == "回显" + assert args.tag == "bar" + +# 不指定 tag 时,自动设置为 targets 第一元素,或 targets 本身(如果为字符串) +parser = CmdParser(cmd_start=".", cmd_sep=" ", targets=["echo", "回显"]) +args = await parser.parse(".回显 你好呀") +if args is not None: + assert args.name == "回显" + assert args.tag == "echo" +``` + +- [core] 新增了用于合并检查器序列的函数 {func}`.checker_join`。相比于使用 | & ^ ~ 运算符,此函数可以接受检查器序列,并返回一个合并检查器。检查器序列可以为检查器对象,检查函数或空值 ([841eddd](https://github.com/Meloland/melobot/commit/841eddd)) + +- [core] 现在支持动态增加、删除处理流,以及变更处理流的优先级 ([e2aaa72](https://github.com/Meloland/melobot/commit/e2aaa72))。例如: + +```python +# 在 BotLifeSpan.STARTED 生命周期之后,可以动态的增加处理流: +bot.add_flows(...) + +# 在任何时候拿到处理流对象后,可以移除该处理流 +# 如果在某处理流内部移除此处理流,依然不影响本次处理过程 +flow.dismiss() + +# 在任何时候拿到处理流对象后,可以更新优先级 +# 如果在某处理流内部更新此处理流优先级,依然不影响本次处理过程 +flow.update_priority(priority=3) +``` + +- [core] 新增了一些 mixin 类,主要提供给协议支持的开发者,参考文档中的 [melobot.mixin](./api/melobot.mixin) 部分。插件与 bot 开发者无需关心 ([e2aaa72](https://github.com/Meloland/melobot/commit/e2aaa72), [6f8253e](https://github.com/Meloland/melobot/commit/6f8253e)) + +- [All] 多数 `on_xxx` 接口提供了新参数 `rule`,用于在内部自动展开会话 ([52a1e7b](https://github.com/Meloland/melobot/commit/52a1e7b))。先前已有示例,此处不再演示。 + +- [OneBot] 新增 {func}`.get_group_role` 和 {func}`.get_level_role` 用于获取权限等级 ([65d447e](https://github.com/Meloland/melobot/commit/65d447e)) + +- [OneBot] 新增 {class}`.OneBotV11Protocol` 协议栈对象,启动代码现在更为简洁 ([6f8253e](https://github.com/Meloland/melobot/commit/6f8253e)),例如: + +```python +from melobot import Bot +from melobot.protocols.onebot.v11 import ForwardWebSocketIO, OneBotV11Protocol + +bot = Bot() +# 无需再手动添加适配器 +bot.add_protocol(OneBotV11Protocol(ForwardWebSocketIO(...))) +bot.load_plugin(...) +... +bot.run() +``` + +### 👍修复 + +- [core] 通过依赖注入获取适配器时,返回值可能为空的错误 ([1c7d170](https://github.com/Meloland/melobot/commit/1c7d170)) + +- [core] 更改部分内置数据结构为集合,避免重复添加元素导致未定义行为 ([7ec9709](https://github.com/Meloland/melobot/commit/7ec9709)) + +- [All] 小幅度提升异步任务创建的性能,修复一些异步任务序列为空可能导致的错误,以及更好的异常提示 ([33c1c68](https://github.com/Meloland/melobot/commit/33c1c68)) + +- [OneBot] 优化了 event, action 和 echo 对象的 repr 显示。在调试时或错误日志中,repr 不再显示为超长字符串 ([33c1c68](https://github.com/Meloland/melobot/commit/33c1c68)) + +- [OneBot] 现在使用更安全的校验。意外传递反射依赖项到 checker 不再会导致校验默认通过 ([dcf782f](https://github.com/Meloland/melobot/commit/dcf782f)) + +- [OneBot] 小幅度提升了 event 与 echo 验证错误时的回调的执行性能 ([4af5422](https://github.com/Meloland/melobot/commit/4af5422)) + + +## v3.0.0 + +### ⚠️特别: + +- **melobot v3 是跨平台、跨协议、支持多路 IO 及其他高级特性的 bot 开发框架,与 v2 完全不兼容。** + +- v3 文档教程:[melobot docs](https://docs.melobot.org) + +| 特色 | 描述 | +| -------------- | ------------------------------------------------------------ | +| 实用接口 | 封装高频使用的异步逻辑,使业务开发更简洁 | +| 插件管理 | 低耦合度、无序的插件加载与通信 | +| 处理流设计 | 可自由组合“处理中间件”为处理流,提升了各组件的复用率 | +| 热插拔/重启 | 处理流支持动态热插拔,支持 bot 级别的重启 | +| 会话支持 | 可在处理流中自动传递的、可自定义的会话上下文 | +| 协议支持 | 所有协议被描述为 IO 过程,因此支持各类协议 | +| 跨平台 | 更简洁的跨平台接口,便捷实现跨平台插件开发 | +| 跨协议 IO | 支持多个协议实现端同时输入,自由输出到指定协议实现端 | +| 日志支持 | 日志记录兼容标准库和绝大多数日志框架,可自行选择 | + +对比上一预发布版本 `3.0.0rc21`,主要有: + +### ⏩变更 + +- [core] 移除计划移除的 api 和组件(移除了方法 `Args`, `Context.in_ctx` 与传统插件类 `Plugin`)([ec518f5](https://github.com/Meloland/melobot/commit/ec518f5)) + +- [core] 改进了 io 层的 packet 限制,现在所有 packet 不再是 `frozen` 的 ([88eeb85](https://github.com/Meloland/melobot/commit/88eeb85)) + +- [core] 改进了 adapter 层的组件,现在钩子 `BEFORE_EVENT` 重命名为 `BEFORE_EVENT_CREATE`,钩子 `BEFORE_ACTION` 重命名为 `BEFORE_ACTION_EXEC` ([d50d3a3](https://github.com/Meloland/melobot/commit/d50d3a3)) + +### ✨ 新增 + +- [core] 内置日志器添加 `yellow_warn` 参数,可在智能着色模式下强制警告消息为醒目的黄色 ([0dae81d](https://github.com/Meloland/melobot/commit/0dae81d)) + +- [core] 现在使用 {class}`.PluginPlanner` 声明插件及各种插件功能 ([4508081](https://github.com/Meloland/melobot/commit/4508081)) + +- [core] {class}`.PluginPlanner` 现在支持使用 {meth}`~.PluginPlanner.use` 装饰器来收集各种插件组件(处理流、共享对象与导出函数)([ecec685](https://github.com/Meloland/melobot/commit/ecec685)) + +- [OneBot] 添加了用于处理 OneBot v11 实体(事件、动作与回应)数据模型验证异常的 OneBot v11 适配器接口 {meth}`~.protocols.onebot.v11.adapter.base.Adapter.when_validate_error` ([4bddb6a](https://github.com/Meloland/melobot/commit/4bddb6a), [0589f3a](https://github.com/Meloland/melobot/commit/0589f3a), [a4d35b3](https://github.com/Meloland/melobot/commit/a4d35b3)) + +### 👍修复 + +- [OneBot] 自定义消息段类型创建和解析 ([3026543](https://github.com/Meloland/melobot/commit/3026543), [51f7cbe](https://github.com/Meloland/melobot/commit/51f7cbe), [f006ee0](https://github.com/Meloland/melobot/commit/f006ee0), [819489f](https://github.com/Meloland/melobot/commit/819489f)) + +- [OneBot] 正向 websocket IO 源忽略 bot 停止信号 ([da0e3df](https://github.com/Meloland/melobot/commit/da0e3df)) + +- [All] 项目各处类型注解的改进 ([1bd8760](https://github.com/Meloland/melobot/commit/1bd8760)) + +- [All] 文档与内置异常提示更正 + +### ♥️新贡献者 + +* [@Asankilp](https://github.com/Asankilp) 首次提交 [#14](https://github.com/Meloland/melobot/pull/14) + +* [@NingmengLemon](https://github.com/NingmengLemon) 首次提交 [#15](https://github.com/Meloland/melobot/pull/15) diff --git a/src/melobot/__init__.py b/src/melobot/__init__.py index 679c9dbc..a86ff152 100644 --- a/src/melobot/__init__.py +++ b/src/melobot/__init__.py @@ -4,8 +4,23 @@ from .bot import Bot, get_bot from .ctx import Context from .di import Depends -from .handle import Flow, FlowStore, node, rewind, stop +from .handle import ( + Flow, + FlowDecorator, + FlowStore, + node, + on_command, + on_contain_match, + on_end_match, + on_event, + on_full_match, + on_regex_match, + on_start_match, + on_text, + rewind, + stop, +) from .log import GenericLogger, Logger, LogLevel, get_logger from .plugin import AsyncShare, PluginInfo, PluginLifeSpan, PluginPlanner, SyncShare -from .session import Rule, SessionStore, enter_session, suspend -from .typ import HandleLevel, LogicMode +from .session import DefaultRule, Rule, Session, SessionStore, enter_session, suspend +from .typ._enum import LogicMode diff --git a/src/melobot/_hook.py b/src/melobot/_hook.py index 5be0569c..de747aaa 100644 --- a/src/melobot/_hook.py +++ b/src/melobot/_hook.py @@ -1,13 +1,15 @@ import asyncio import time +from asyncio import Task from enum import Enum -from typing_extensions import Any, Callable, Generic, TypeVar +from typing_extensions import Any, Generic, TypeVar from .ctx import LoggerCtx from .di import inject_deps from .log.base import LogLevel -from .typ import AsyncCallable, P +from .typ.base import AsyncCallable, SyncOrAsyncCallable +from .utils import to_async, to_sync HookEnumT = TypeVar("HookEnumT", bound=Enum) @@ -55,10 +57,10 @@ def set_tag(self, tag: str | None) -> None: def register( self, hook_type: HookEnumT, - hook_func: AsyncCallable[..., None], + hook_func: SyncOrAsyncCallable[..., None], once: bool = True, ) -> None: - runner = HookRunner(hook_type, hook_func, once) + runner = HookRunner(hook_type, to_async(hook_func), once) self._hooks[hook_type].append(runner) def get_evoke_time(self, hook_type: HookEnumT) -> float: @@ -72,6 +74,7 @@ async def emit( *, args: tuple | None = None, kwargs: dict[str, Any] | None = None, + callback: SyncOrAsyncCallable[[Task[None] | None], None] | None = None, ) -> None: self._stamps[hook_type] = time.time_ns() / 1e9 args = args if args is not None else () @@ -85,31 +88,18 @@ async def emit( msg = f"开始 hook: {msg}" logger.debug(msg) - tasks = [ + tasks = tuple( asyncio.create_task(runner.run(*args, **kwargs)) for runner in self._hooks[hook_type] - ] + ) + + if callback is not None: + if len(tasks): + for t in tasks: + t.add_done_callback(to_sync(callback)) + else: + to_sync(callback)(None) + return + if wait and len(tasks): await asyncio.wait(tasks) - - -class Hookable(Generic[HookEnumT]): - def __init__(self, hook_type: type[HookEnumT], tag: str | None = None): - super().__init__() - self._hook_bus = HookBus[HookEnumT](hook_type, tag) - - def on( - self, *periods: HookEnumT - ) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: - """注册一个 hook - - :param periods: 要绑定的 hook 类型 - :return: 装饰器 - """ - - def wrapped(func: AsyncCallable[P, None]) -> AsyncCallable[P, None]: - for type in periods: - self._hook_bus.register(type, func) - return func - - return wrapped diff --git a/src/melobot/_imp.py b/src/melobot/_imp.py index 7a055350..f68487b3 100644 --- a/src/melobot/_imp.py +++ b/src/melobot/_imp.py @@ -24,7 +24,7 @@ from typing_extensions import Any, Sequence, cast from .exceptions import DynamicImpError -from .utils import singleton +from .utils.common import singleton ALL_EXTS = tuple(all_suffixes()) EMPTY_PKG_TAG = "__melobot_namespace_pkg__" @@ -86,13 +86,17 @@ def find_spec( # 再次是 zip 文件导入 if entry_path.suffix == ".zip" and entry_path.exists(): - zip_importer = zipimport.zipimporter( # pylint: disable=no-member - str(entry_path) - ) + zip_importer = zipimport.zipimporter(str(entry_path)) spec = zip_importer.find_spec(fullname, target) if spec is not None: - assert spec.origin is not None and spec.origin != "" - assert spec.loader is not None + assert spec.origin is not None and spec.origin != "", ( + f"zip file from {entry_path}, module named {fullname} from {target}, " + "failed to get spec origin" + ) + assert spec.loader is not None, ( + f"zip file from {entry_path}, module named {fullname} from {target}, " + "spec has no loader" + ) spec.loader = ModuleLoader( fullname, Path(spec.origin).resolve(), @@ -130,7 +134,9 @@ def find_spec( ), submodule_search_locations=loader._path, # type: ignore[attr-defined] ) - assert spec is not None + assert ( + spec is not None + ), f"package from {dir_path} without __init__.py create spec failed" spec.has_location = False spec.origin = None setattr(spec, EMPTY_PKG_TAG, True) @@ -341,7 +347,9 @@ def import_mod( ) mod = module_from_spec(spec) - assert spec.loader is not None + assert ( + spec.loader is not None + ), f"module named {name} and path from {path} has no loader" spec.loader.exec_module(mod) return mod diff --git a/src/melobot/_meta.py b/src/melobot/_meta.py index 0882ee5b..880f8f38 100644 --- a/src/melobot/_meta.py +++ b/src/melobot/_meta.py @@ -2,9 +2,9 @@ from typing_extensions import Any, ClassVar, Generic, Literal, NamedTuple, NoReturn -from .typ import T +from .typ.base import T -__version__ = "3.0.0" +__version__ = "3.1.0" def _version_str_to_info(s: str) -> VersionInfo: diff --git a/src/melobot/adapter/__init__.py b/src/melobot/adapter/__init__.py index cd242fb3..b06ded84 100644 --- a/src/melobot/adapter/__init__.py +++ b/src/melobot/adapter/__init__.py @@ -6,4 +6,16 @@ AdapterLifeSpan, ) from .content import Content -from .model import Action, ActionChain, ActionHandle, Echo, Event, open_chain +from .model import ( + Action, + ActionChain, + ActionHandle, + ActionRetT, + ActionT, + Echo, + EchoT, + Event, + EventT, + TextEvent, + open_chain, +) diff --git a/src/melobot/adapter/base.py b/src/melobot/adapter/base.py index 3cf7f8ab..f2f4b9c0 100644 --- a/src/melobot/adapter/base.py +++ b/src/melobot/adapter/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from abc import abstractmethod from asyncio import create_task from contextlib import AsyncExitStack, _GeneratorContextManager, asynccontextmanager from enum import Enum @@ -21,8 +22,7 @@ final, ) -from .._hook import Hookable -from ..ctx import EventBuildInfo, EventBuildInfoCtx, LoggerCtx, OutSrcFilterCtx +from ..ctx import EventOrigin, FlowCtx, LoggerCtx, OutSrcFilterCtx from ..exceptions import AdapterError from ..io.base import ( AbstractInSource, @@ -34,7 +34,8 @@ OutSourceT, ) from ..log.base import LogLevel -from ..typ import AsyncCallable, BetterABC, P, abstractmethod +from ..mixin import HookMixin +from ..typ.cls import BetterABC from .content import Content from .model import ActionHandle, ActionT, EchoT, Event, EventT @@ -42,8 +43,8 @@ from ..bot.dispatch import Dispatcher -_EVENT_BUILD_INFO_CTX = EventBuildInfoCtx() _OUT_SRC_FILTER_CTX = OutSrcFilterCtx() +_FLOW_CTX = FlowCtx() class AbstractEventFactory(BetterABC, Generic[InPacketT, EventT]): @@ -104,7 +105,7 @@ class AdapterLifeSpan(Enum): class Adapter( - BetterABC, + HookMixin[AdapterLifeSpan], Generic[ EventFactoryT, OutputFactoryT, @@ -113,7 +114,7 @@ class Adapter( InSourceT, OutSourceT, ], - Hookable[AdapterLifeSpan], + BetterABC, ): """适配器基类 @@ -130,29 +131,20 @@ def __init__( output_factory: OutputFactoryT, echo_factory: EchoFactoryT, ) -> None: - Hookable.__init__(self, AdapterLifeSpan, tag=protocol) + super().__init__(hook_type=AdapterLifeSpan, hook_tag=protocol) self.protocol = protocol - self.in_srcs: list[InSourceT] = [] - self.out_srcs: list[OutSourceT] = [] + self.in_srcs: set[InSourceT] = set() + self.out_srcs: set[OutSourceT] = set() self.dispatcher: "Dispatcher" self._event_factory = event_factory self._output_factory = output_factory self._echo_factory = echo_factory self._inited = False - - def on( - self, *periods: AdapterLifeSpan - ) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: - groups = (AdapterLifeSpan.BEFORE_EVENT_HANDLE, AdapterLifeSpan.BEFORE_ACTION_EXEC) - - def wrapped(func: AsyncCallable[P, None]) -> AsyncCallable[P, None]: - for type in periods: - self._hook_bus.register(type, func, once=type not in groups) - return func - - return wrapped + self.__mark_repeatable_hooks__( + AdapterLifeSpan.BEFORE_EVENT_HANDLE, AdapterLifeSpan.BEFORE_ACTION_EXEC + ) @final def get_isrcs(self, filter: Callable[[InSourceT], bool]) -> set[InSourceT]: @@ -178,12 +170,12 @@ async def __adapter_input_loop__(self, src: InSourceT) -> NoReturn: while True: try: packet = await src.input() - event = await self._event_factory.create(packet) - with _EVENT_BUILD_INFO_CTX.unfold(EventBuildInfo(self, src)): - await self._hook_bus.emit( - AdapterLifeSpan.BEFORE_EVENT_HANDLE, True, args=(event,) - ) - asyncio.create_task(self.dispatcher.broadcast(event)) + event: Event = await self._event_factory.create(packet) + EventOrigin.set_origin(event, EventOrigin(self, src)) + await self._hook_bus.emit( + AdapterLifeSpan.BEFORE_EVENT_HANDLE, True, args=(event,) + ) + self.dispatcher.broadcast(event) except Exception: logger.exception(f"适配器 {self} 处理输入与分发事件时发生异常") logger.generic_obj("异常点局部变量:", locals(), level=LogLevel.ERROR) @@ -244,8 +236,9 @@ async def call_output(self, action: ActionT) -> tuple[ActionHandle, ...]: osrcs: Iterable[OutSourceT] filter = _OUT_SRC_FILTER_CTX.try_get() cur_isrc: AbstractInSource | None - if info := _EVENT_BUILD_INFO_CTX.try_get(): - cur_isrc = info.in_src + + if event := _FLOW_CTX.try_get_event(): + cur_isrc = EventOrigin.get_origin(event).in_src else: cur_isrc = None @@ -254,7 +247,7 @@ async def call_output(self, action: ActionT) -> tuple[ActionHandle, ...]: elif isinstance(cur_isrc, AbstractOutSource): osrcs = (cast(OutSourceT, cur_isrc),) else: - osrcs = (self.out_srcs[0],) if len(self.out_srcs) else () + osrcs = self.out_srcs if len(self.out_srcs) else () await self._hook_bus.emit( AdapterLifeSpan.BEFORE_ACTION_EXEC, True, args=(action,) diff --git a/src/melobot/adapter/generic.py b/src/melobot/adapter/generic.py index d62f568e..5cdb003d 100644 --- a/src/melobot/adapter/generic.py +++ b/src/melobot/adapter/generic.py @@ -2,16 +2,17 @@ from typing_extensions import Sequence -from ..ctx import EventBuildInfoCtx +from ..ctx import EventOrigin, FlowCtx +from .base import Adapter from .content import Content from .model import ActionHandle, Event -_CTX = EventBuildInfoCtx() +_CTX = FlowCtx() async def send_text(text: str) -> tuple[ActionHandle, ...]: """通用文本输出方法""" - return await _CTX.get().adapter.__send_text__(text) + return await _get_ctx_adapter().__send_text__(text) async def send_media( @@ -21,7 +22,7 @@ async def send_media( mimetype: str | None = None, ) -> tuple[ActionHandle, ...]: """通用媒体内容输出方法""" - return await _CTX.get().adapter.__send_media__(name, raw, url, mimetype) + return await _get_ctx_adapter().__send_media__(name, raw, url, mimetype) async def send_image( @@ -31,7 +32,7 @@ async def send_image( mimetype: str | None = None, ) -> tuple[ActionHandle, ...]: """通用图像内容输出方法""" - return await _CTX.get().adapter.__send_image__(name, raw, url, mimetype) + return await _get_ctx_adapter().__send_image__(name, raw, url, mimetype) async def send_audio( @@ -41,7 +42,7 @@ async def send_audio( mimetype: str | None = None, ) -> tuple[ActionHandle, ...]: """通用音频内容输出方法""" - return await _CTX.get().adapter.__send_audio__(name, raw, url, mimetype) + return await _get_ctx_adapter().__send_audio__(name, raw, url, mimetype) async def send_voice( @@ -51,7 +52,7 @@ async def send_voice( mimetype: str | None = None, ) -> tuple[ActionHandle, ...]: """通用语音内容输出方法""" - return await _CTX.get().adapter.__send_voice__(name, raw, url, mimetype) + return await _get_ctx_adapter().__send_voice__(name, raw, url, mimetype) async def send_video( @@ -61,21 +62,26 @@ async def send_video( mimetype: str | None = None, ) -> tuple[ActionHandle, ...]: """通用视频内容输出方法""" - return await _CTX.get().adapter.__send_video__(name, raw, url, mimetype) + return await _get_ctx_adapter().__send_video__(name, raw, url, mimetype) async def send_file(name: str, path: str | PathLike[str]) -> tuple[ActionHandle, ...]: """通用文件输出方法""" - return await _CTX.get().adapter.__send_file__(name, path) + return await _get_ctx_adapter().__send_file__(name, path) async def send_refer( event: Event, contents: Sequence[Content] | None = None ) -> tuple[ActionHandle, ...]: """通用过往事件引用输出方法""" - return await _CTX.get().adapter.__send_refer__(event, contents) + return await _get_ctx_adapter().__send_refer__(event, contents) async def send_resource(name: str, url: str) -> tuple[ActionHandle, ...]: """通用其他资源输出方法""" - return await _CTX.get().adapter.__send_resource__(name, url) + return await _get_ctx_adapter().__send_resource__(name, url) + + +def _get_ctx_adapter() -> Adapter: + event = _CTX.get_event() + return EventOrigin.get_origin(event).adapter diff --git a/src/melobot/adapter/model.py b/src/melobot/adapter/model.py index e1b0e85d..a045c6d8 100644 --- a/src/melobot/adapter/model.py +++ b/src/melobot/adapter/model.py @@ -26,32 +26,37 @@ from ..exceptions import AdapterError from ..io.base import AbstractOutSource from ..log.base import LogLevel -from ..typ import AttrsReprable, T -from ..utils import get_id, to_coro +from ..mixin import AttrReprMixin, FlagMixin +from ..typ.base import T +from ..typ.cls import abstractattr +from ..utils.base import to_coro +from ..utils.common import get_id from .content import Content if TYPE_CHECKING: - from .base import AbstractEchoFactory, AbstractOutputFactory + from .base import AbstractEchoFactory, AbstractOutputFactory, Adapter -class Event(AttrsReprable): +class Event(AttrReprMixin, FlagMixin): """事件基类 + :ivar typing.LiteralString protocol: 遵循的协议,为空则协议无关 :ivar float time: 时间戳 :ivar str id: id 标识 - :ivar typing.LiteralString | None protocol: 遵循的协议,为空则协议无关 - :ivar typing.Sequence[Content] contents: 附加的通用内容序列 :ivar typing.Hashable | None scope: 所在的域,可空 + :ivar typing.Sequence[Content] contents: 附加的通用内容序列 """ def __init__( self, + protocol: LiteralString, time: float = -1, id: str = "", - protocol: LiteralString | None = None, scope: Hashable | None = None, contents: Sequence[Content] | None = None, ) -> None: + super().__init__() + self.time = time_ns() / 1e9 if time == -1 else time self.id = get_id() if id == "" else id self.protocol = protocol @@ -64,7 +69,17 @@ def __init__( EventT = TypeVar("EventT", bound=Event) -class Action(AttrsReprable): +class TextEvent(Event): + """文本事件类 + + :ivar str text: 文本内容 + """ + + text: str = abstractattr() + textlines: list[str] = abstractattr() + + +class Action(AttrReprMixin, FlagMixin): """行为基类 :ivar float time: 时间戳 @@ -84,6 +99,7 @@ def __init__( contents: Sequence[Content] | None = None, trigger: Event | None = None, ) -> None: + super().__init__() self.time = time_ns() / 1e9 if time == -1 else time self.id = get_id() if id == "" else id self.protocol = protocol @@ -92,7 +108,7 @@ def __init__( self.trigger = trigger -class Echo(AttrsReprable): +class Echo(AttrReprMixin, FlagMixin): """回应基类 :ivar float time: 时间戳 @@ -116,6 +132,7 @@ def __init__( prompt: str = "", data: Any = None, ) -> None: + super().__init__() self.time = time_ns() / 1e9 if time == -1 else time self.id = get_id() if id == "" else id self.protocol = protocol @@ -231,7 +248,8 @@ async def _exec_handle(self, handles: Awaitable[tuple[ActionHandle, ...]]) -> No _handles = await handles for handle in _handles: handle.execute() - await asyncio.wait(map(lambda h: create_task(to_coro(h)), _handles)) + if len(_handles): + await asyncio.wait(map(create_task, map(to_coro, _handles))) def add( self, diff --git a/src/melobot/bot/base.py b/src/melobot/bot/base.py index 3d50ef3c..85e21150 100644 --- a/src/melobot/bot/base.py +++ b/src/melobot/bot/base.py @@ -22,18 +22,19 @@ NoReturn, ) -from .._hook import Hookable from .._meta import MetaInfo from ..adapter.base import Adapter from ..ctx import BotCtx, LoggerCtx from ..exceptions import BotError +from ..handle.base import Flow from ..io.base import AbstractInSource, AbstractIOSource, AbstractOutSource from ..log.base import GenericLogger, Logger, NullLogger +from ..mixin import HookMixin from ..plugin.base import Plugin, PluginLifeSpan, PluginPlanner from ..plugin.ipc import AsyncShare, IPCManager, SyncShare from ..plugin.load import PluginLoader from ..protocols.base import ProtocolStack -from ..typ import AsyncCallable, P +from ..typ.base import AsyncCallable, P, SyncOrAsyncCallable from .dispatch import Dispatcher @@ -71,7 +72,7 @@ def _start_log(logger: GenericLogger) -> None: logger.info("=" * 40) -class Bot(Hookable[BotLifeSpan]): +class Bot(HookMixin[BotLifeSpan]): """bot 类 :ivar str name: bot 对象的名称 @@ -103,7 +104,7 @@ def __init__( 可使用 melobot 内置的 :class:`.Logger`,或经过 :func:`.logger_patch` 修补的日志器 :param enable_log: 是否启用日志功能 """ - Hookable.__init__(self, BotLifeSpan, tag=name) + super().__init__(hook_type=BotLifeSpan, hook_tag=name) self.name = name self.logger: GenericLogger @@ -117,12 +118,11 @@ def __init__( self.adapters: dict[str, Adapter] = {} self.ipc_manager = IPCManager() - self._in_srcs: dict[str, list[AbstractInSource]] = {} - self._out_srcs: dict[str, list[AbstractOutSource]] = {} + self._in_srcs: dict[str, set[AbstractInSource]] = {} + self._out_srcs: dict[str, set[AbstractOutSource]] = {} self._loader = PluginLoader() self._plugins: dict[str, Plugin] = {} self._dispatcher = Dispatcher() - self._tasks: list[asyncio.Task] = [] self._inited = False self._running = False @@ -159,7 +159,7 @@ def add_input(self, src: AbstractInSource) -> Bot: if self._inited: raise BotError(f"{self} 已不在初始化期,无法再绑定输入源") - self._in_srcs.setdefault(src.protocol, []).append(src) + self._in_srcs.setdefault(src.protocol, set()).add(src) return self def add_output(self, src: AbstractOutSource) -> Bot: @@ -171,7 +171,7 @@ def add_output(self, src: AbstractOutSource) -> Bot: if self._inited: raise BotError(f"{self} 已不在初始化期,无法再绑定输出源") - self._out_srcs.setdefault(src.protocol, []).append(src) + self._out_srcs.setdefault(src.protocol, set()).add(src) return self def add_io(self, src: AbstractIOSource) -> Bot: @@ -221,7 +221,7 @@ def _core_init(self) -> None: for isrc in srcs: adapter = self.adapters.get(protocol) if adapter is not None: - adapter.in_srcs.append(isrc) + adapter.in_srcs.add(isrc) else: self.logger.warning( f"输入源 {isrc.__class__.__name__} 没有对应的适配器" @@ -231,7 +231,7 @@ def _core_init(self) -> None: for osrc in outsrcs: adapter = self.adapters.get(protocol) if adapter is not None: - adapter.out_srcs.append(osrc) + adapter.out_srcs.add(osrc) else: self.logger.warning( f"输出源 {osrc.__class__.__name__} 没有对应的适配器" @@ -273,12 +273,6 @@ def load_plugin( return self self._plugins[p.name] = p - - if self._hook_bus.get_evoke_time(BotLifeSpan.STARTED) != -1: - asyncio.create_task(self._dispatcher.add(*p.handlers)) - else: - self._dispatcher.add_nowait(*p.handlers) - for share in p.shares: self.ipc_manager.add(p.name, share) for func in p.funcs: @@ -286,7 +280,12 @@ def load_plugin( self.logger.info(f"成功加载插件:{p.name}") if self._hook_bus.get_evoke_time(BotLifeSpan.STARTED) != -1: - asyncio.create_task(p.hook_bus.emit(PluginLifeSpan.INITED)) + asyncio.create_task( + p.hook_bus.emit( + PluginLifeSpan.INITED, + callback=lambda _, p=p: self._dispatcher.add(*p.init_flows), + ) + ) return self def load_plugins( @@ -352,10 +351,13 @@ async def core_run(self) -> None: await self._hook_bus.emit(BotLifeSpan.RELOADED) for p in self._plugins.values(): - await p.hook_bus.emit(PluginLifeSpan.INITED) + await p.hook_bus.emit( + PluginLifeSpan.INITED, + callback=lambda _, p=p: self._dispatcher.add(*p.init_flows), + ) + + self._dispatcher.start() - timed_task = asyncio.create_task(self._dispatcher.timed_gc()) - self._tasks.append(timed_task) ts = tuple( asyncio.create_task( stack.enter_async_context(adapter.__adapter_launch__()) @@ -370,9 +372,6 @@ async def core_run(self) -> None: finally: async with self._common_async_ctx() as stack: - for t in self._tasks: - t.cancel() - await self._hook_bus.emit(BotLifeSpan.STOPPED, True) self.logger.info(f"{self} 已安全停止运行") self._running = False @@ -397,7 +396,8 @@ async def bots_run() -> None: for bot in bots: tasks.append(asyncio.create_task(bot.core_run())) try: - await asyncio.wait(tasks) + if len(tasks): + await asyncio.wait(tasks) except asyncio.CancelledError: pass @@ -481,28 +481,47 @@ def get_share(self, plugin: str, share: str) -> SyncShare | AsyncShare: """ return self.ipc_manager.get(plugin, share) + def add_flows(self, *flows: Flow) -> None: + """添加处理流 + + :param flows: 流对象 + """ + if self._hook_bus.get_evoke_time(BotLifeSpan.STARTED) == -1: + raise BotError(f"只有在 {BotLifeSpan.STARTED} 生命周期后才能动态添加处理流") + self._dispatcher.add(*flows) + @property - def on_loaded(self) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: + def on_loaded( + self, + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: """给 bot 注册 :obj:`.BotLifeSpan.LOADED` 阶段 hook 的装饰器""" return self.on(BotLifeSpan.LOADED) @property - def on_reloaded(self) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: + def on_reloaded( + self, + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: """给 bot 注册 :obj:`.BotLifeSpan.RELOADED` 阶段 hook 的装饰器""" return self.on(BotLifeSpan.RELOADED) @property - def on_started(self) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: + def on_started( + self, + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: """给 bot 注册 :obj:`.BotLifeSpan.STARTED` 阶段 hook 的装饰器""" return self.on(BotLifeSpan.STARTED) @property - def on_close(self) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: + def on_close( + self, + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: """给 bot 注册 :obj:`.BotLifeSpan.CLOSE` 阶段 hook 的装饰器""" return self.on(BotLifeSpan.CLOSE) @property - def on_stopped(self) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: + def on_stopped( + self, + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: """给 bot 注册 :obj:`.BotLifeSpan.STOPPED` 阶段 hook 的装饰器""" return self.on(BotLifeSpan.STOPPED) diff --git a/src/melobot/bot/dispatch.py b/src/melobot/bot/dispatch.py index 209744ae..2d14e765 100644 --- a/src/melobot/bot/dispatch.py +++ b/src/melobot/bot/dispatch.py @@ -1,117 +1,160 @@ -import asyncio - -from typing_extensions import Any, TypeVar - -from ..adapter.model import Event -from ..handle.base import EventHandler -from ..typ import AsyncCallable, HandleLevel -from ..utils import RWContext +from __future__ import annotations -KeyT = TypeVar("KeyT", bound=float, default=float) -ValT = TypeVar("ValT", default=Any) +import asyncio +from asyncio import Queue, Task, get_running_loop +from ..adapter.base import Event +from ..handle.base import Flow +from ..mixin import LogMixin -class _KeyOrderDict(dict[KeyT, ValT]): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.update(*args, **kwargs) - self.__buf: list[tuple[KeyT, ValT]] = [] - def __setitem__(self, key: KeyT, value: ValT) -> None: - if len(self) == 0: - return super().__setitem__(key, value) +class Dispatcher(LogMixin): + def __init__(self) -> None: + self.first_chan: EventChannel | None = None - if key <= next(reversed(self.items()))[0]: - return super().__setitem__(key, value) + self._pending_chans: list[EventChannel] = [] - cnt = 0 - for k, _ in reversed(self.items()): - if key > k: - cnt += 1 - else: - break + def _arrange_chan(self, chan: EventChannel) -> None: + try: + get_running_loop() + asyncio.create_task(chan.run()) + except RuntimeError: + self._pending_chans.append(chan) - for _ in range(cnt): - self.__buf.append(self.popitem()) - super().__setitem__(key, value) - while len(self.__buf): - super().__setitem__(*self.__buf.pop()) + def add(self, *flows: Flow) -> None: + for f in flows: + lvl = f.priority - return None + if self.first_chan is None: + self.first_chan = EventChannel(self, priority=lvl) + self.first_chan.flow_que.put_nowait(f) - def update(self, *args: Any, **kwargs: Any) -> None: - for k, v in dict(*args, **kwargs).items(): - self[k] = v + elif lvl == self.first_chan.priority: + self.first_chan.flow_que.put_nowait(f) - def setdefault(self, key: KeyT, default: ValT) -> ValT: - if key not in self: - self[key] = default - return self[key] + elif lvl > self.first_chan.priority: + chan = EventChannel(self, priority=lvl) + chan.set_next(self.first_chan) + self.first_chan = chan + chan.flow_que.put_nowait(f) + else: + chan = self.first_chan + while chan.next is not None and lvl <= chan.next.priority: + chan = chan.next + + if lvl == chan.priority: + chan.flow_que.put_nowait(f) + else: + new_chan = EventChannel(self, priority=lvl) + chan_next = chan.next + new_chan.set_pre(chan) + new_chan.set_next(chan_next) + new_chan.flow_que.put_nowait(f) + + f._active = True + + def remove(self, *flows: Flow) -> None: + for f in flows: + f._active = False + + def update(self, priority: int, *flows: Flow) -> None: + self.remove(*flows) + for f in flows: + f.priority = priority + self.add(*flows) + + def broadcast(self, event: Event) -> None: + if self.first_chan is not None: + self.first_chan.event_que.put_nowait(event) + else: + self.logger.warning(f"没有任何可用的事件处理流,事件 {event.id} 将被丢弃") + + def start(self) -> None: + for chan in self._pending_chans: + asyncio.create_task(chan.run()) + self._pending_chans.clear() + + +class EventChannel: + def __init__(self, owner: Dispatcher, priority: int) -> None: + self.owner = owner + self.event_que: Queue[Event] = Queue() + self.flow_que: Queue[Flow] = Queue() + self.priority = priority + + self.pre: EventChannel | None = None + self.next: EventChannel | None = None + + self.owner._arrange_chan(self) + + def set_pre(self, pre: EventChannel | None) -> None: + self.pre = pre + if self.pre is not None: + self.pre.next = self + + def set_next(self, next: EventChannel | None) -> None: + self.next = next + if self.next is not None: + self.next.pre = self + + async def run(self) -> None: + handle_tasks: list[Task] = [] + events: list[Event] = [] + valid_flows: list[Flow] = [] -class Dispatcher: - def __init__(self) -> None: - self.handlers: _KeyOrderDict[HandleLevel, set[EventHandler]] = _KeyOrderDict() - self.dispatch_ctrl = RWContext() - self.gc_interval = 5 - - def add_nowait(self, *handlers: EventHandler) -> None: - for h in handlers: - self.handlers.setdefault(h.flow.priority, set()).add(h) - h.flow.on_priority_reset( - lambda new_prior, h=h: self._reset_hook(h, new_prior) - ) - - async def add( - self, *handlers: EventHandler, callback: AsyncCallable[[], None] | None = None - ) -> None: - async with self.dispatch_ctrl.write(): - self.add_nowait(*handlers) - if callback is not None: - await callback() - - async def _remove(self, *handlers: EventHandler) -> None: - for h in handlers: - await h.expire() - h_set = self.handlers[h.flow.priority] - h_set.remove(h) - if len(h_set) == 0: - self.handlers.pop(h.flow.priority) - - async def remove( - self, *handlers: EventHandler, callback: AsyncCallable[[], None] | None = None - ) -> None: - async with self.dispatch_ctrl.write(): - await self._remove(*handlers) - if callback is not None: - await callback() - - async def _reset_hook(self, handler: EventHandler, new_prior: HandleLevel) -> None: - if handler.flow.priority == new_prior: + while True: + events.clear() + if self.event_que.qsize() == 0: + ev = await self.event_que.get() + events.append(ev) + for _ in range(self.event_que.qsize()): + events.append(self.event_que.get_nowait()) + + for ev in events: + ev.flag_set_default(self.owner, self.owner, set()) + handle_tasks.clear() + valid_flows.clear() + + if self.flow_que.qsize() == 0: + self._dispose(*events) + return + + for _ in range(self.flow_que.qsize()): + handled_fs: set[Flow] = ev.flag_get(self.owner, self.owner) + f = self.flow_que.get_nowait() + if f._active and f.priority == self.priority: + if f not in handled_fs: + handle_tasks.append(asyncio.create_task(f._handle(ev))) + handled_fs.add(f) + valid_flows.append(f) + + for f in valid_flows: + self.flow_que.put_nowait(f) + if len(valid_flows): + coro = self._determine_spread(ev, handle_tasks) + asyncio.create_task(coro) + else: + self._dispose(*events) + return + + def _dispose(self, *events: Event) -> None: + if self.pre is not None: + self.pre.set_next(self.next) + + if self.next is not None: + for ev in events: + self.next.event_que.put_nowait(ev) + + if self is self.owner.first_chan: + self.owner.first_chan = self.next + + async def _determine_spread(self, ev: Event, handle_tasks: list[Task]) -> None: + if not len(handle_tasks): + if self.next is not None: + self.next.event_que.put_nowait(ev) return - async with self.dispatch_ctrl.write(): - old_prior = handler.flow.priority - if old_prior == new_prior: - return - h_set = self.handlers[old_prior] - h_set.remove(handler) - if len(h_set) == 0: - self.handlers.pop(old_prior) - self.handlers.setdefault(new_prior, set()).add(handler) - - async def broadcast(self, event: Event) -> None: - async with self.dispatch_ctrl.read(): - for h_set in self.handlers.values(): - tasks = tuple(asyncio.create_task(h.handle(event)) for h in h_set) - await asyncio.wait(tasks) - if not event.spread: - break - - async def timed_gc(self) -> None: - while True: - await asyncio.sleep(self.gc_interval) - async with self.dispatch_ctrl.write(): - hs = tuple( - h for h_set in self.handlers.values() for h in h_set if h.invalid - ) - await self._remove(*hs) + await asyncio.wait(handle_tasks) + if self.next is not None and ev.spread: + self.next.event_que.put_nowait(ev) diff --git a/src/melobot/ctx.py b/src/melobot/ctx.py index 4d5c2bb9..ad47b865 100644 --- a/src/melobot/ctx.py +++ b/src/melobot/ctx.py @@ -1,22 +1,33 @@ +from asyncio import Future from contextlib import contextmanager from contextvars import ContextVar, Token from dataclasses import dataclass, field from enum import Enum -from typing_extensions import TYPE_CHECKING, Any, Callable, Generator, Generic, Union +from typing_extensions import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Self, + Union, + cast, +) from .exceptions import AdapterError, BotError, FlowError, LogError, SessionError -from .typ import SingletonMeta, T +from .typ.base import T +from .typ.cls import SingletonMeta if TYPE_CHECKING: - from .adapter import model - from .adapter.base import Adapter + from .adapter import Adapter, model from .bot.base import Bot - from .handle.process import Flow, FlowNode + from .handle.base import Flow, FlowNode from .io.base import AbstractInSource, OutSourceT from .log.base import GenericLogger from .session.base import Session, SessionStore from .session.option import Rule + from .utils.parse import AbstractParseArgs class Context(Generic[T], metaclass=SingletonMeta): @@ -102,26 +113,6 @@ def __init__(self) -> None: super().__init__("MELOBOT_OUT_SRC_FILTER", AdapterError) -@dataclass -class EventBuildInfo: - adapter: "Adapter" - in_src: "AbstractInSource" - - -class EventBuildInfoCtx(Context[EventBuildInfo]): - def __init__(self) -> None: - super().__init__( - "MELOBOT_EVENT_BUILD_INFO", - AdapterError, - "此时不在活动的事件处理流中,无法获取适配器与输入源的上下文信息", - ) - - def get_adapter_type(self) -> type["Adapter"]: - from .adapter.base import Adapter - - return Adapter - - class FlowRecordStage(Enum): """流记录阶段的枚举""" @@ -161,10 +152,10 @@ class FlowStore(dict[str, Any]): @dataclass class FlowStatus: - event: "model.Event" flow: "Flow" node: "FlowNode" next_valid: bool + completion: "EventCompletion" records: FlowRecords = field(default_factory=FlowRecords) store: FlowStore = field(default_factory=FlowStore) @@ -181,7 +172,7 @@ def get_event(self) -> "model.Event": session = SessionCtx().try_get() if session is not None: return session.event - return self.get().event + return self.get().completion.event def try_get_event(self) -> Union["model.Event", None]: session = SessionCtx().try_get() @@ -190,7 +181,7 @@ def try_get_event(self) -> Union["model.Event", None]: status = self.try_get() if status is not None: - return status.event + return status.completion.event return None def get_event_type(self) -> type["model.Event"]: @@ -198,6 +189,12 @@ def get_event_type(self) -> type["model.Event"]: return Event + def get_records(self) -> tuple[FlowRecord, ...]: + return tuple(self.get().records) + + def get_completion(self) -> "EventCompletion": + return self.get().completion + def get_store(self) -> FlowStore: return self.get().store @@ -226,6 +223,16 @@ def __init__(self) -> None: def get_store(self) -> "SessionStore": return self.get().store + def get_rule(self) -> "Rule": + rule = self.get().rule + assert rule is not None, "预期之外的会话规则为空" + return rule + + def get_session_type(self) -> type["Session"]: + from .session.base import Session + + return Session + def get_store_type(self) -> type["SessionStore"]: from .session.base import SessionStore @@ -247,6 +254,48 @@ def get_type(self) -> type["GenericLogger"]: return GenericLogger +class EventOrigin: + _FLAG_KEYS = (object(), object()) + + def __init__(self, adapter: "Adapter", in_src: "AbstractInSource") -> None: + self.adapter = adapter + self.in_src = in_src + + @classmethod + def set_origin(cls, event: "model.Event", origin: "EventOrigin") -> None: + event.flag_set(cls._FLAG_KEYS[0], cls._FLAG_KEYS[1], origin) + + @classmethod + def get_origin(cls, event: "model.Event") -> Self: + origin = event.flag_get(cls._FLAG_KEYS[0], cls._FLAG_KEYS[1]) + return cast(Self, origin) + + +# 不使用 dataclass,不用重写任何方法就可哈希 +class EventCompletion: + def __init__( + self, + event: "model.Event", + completed: Future[None], + owner_flow: "Flow", + under_session: bool = False, + ) -> None: + self.event = event + self.completed = completed + self.owner_flow = owner_flow + self.under_session = under_session + + class ActionManualSignalCtx(Context[bool]): def __init__(self) -> None: super().__init__("MELOBOT_ACTION_MANUAL_SIGNAL", AdapterError) + + +class ParseArgsCtx(Context["AbstractParseArgs"]): + def __init__(self) -> None: + super().__init__("MELOBOT_PARSE_ARGS", LookupError, "当前上下文中不存在解析参数") + + def get_args_type(self) -> type["AbstractParseArgs"]: + from .utils.parse import AbstractParseArgs + + return AbstractParseArgs diff --git a/src/melobot/di.py b/src/melobot/di.py index 30f574b3..083bf911 100644 --- a/src/melobot/di.py +++ b/src/melobot/di.py @@ -1,9 +1,10 @@ from __future__ import annotations +from abc import abstractmethod from asyncio import Lock from collections import deque from dataclasses import dataclass -from functools import wraps +from functools import partial, wraps from inspect import Parameter, isawaitable, signature, unwrap from sys import version_info from types import BuiltinFunctionType, FunctionType, LambdaType @@ -20,19 +21,13 @@ get_origin, ) -from .ctx import BotCtx, EventBuildInfoCtx, FlowCtx, LoggerCtx, SessionCtx +from .ctx import BotCtx, EventOrigin, FlowCtx, LoggerCtx, ParseArgsCtx, SessionCtx from .exceptions import DependBindError, DependInitError -from .typ import ( - AsyncCallable, - BetterABC, - P, - T, - VoidType, - abstractmethod, - is_subhint, - is_type, -) -from .utils import get_obj_name, to_async +from .typ._enum import VoidType +from .typ.base import AsyncCallable, P, SyncOrAsyncCallable, T, is_subhint, is_type +from .typ.cls import BetterABC +from .utils.base import to_async +from .utils.common import get_obj_name if TYPE_CHECKING: from .adapter.base import Adapter @@ -49,11 +44,11 @@ def __init__( self.hint = hint -class Depends: +class Depends(Generic[T]): def __init__( self, - dep: Callable[[], Any] | AsyncCallable[[], Any] | Depends, - sub_getter: Callable[[Any], Any] | AsyncCallable[[Any], Any] | None = None, + dep: SyncOrAsyncCallable[[], T] | Depends[T], + sub_getter: SyncOrAsyncCallable[[T], T] | None = None, cache: bool = False, recursive: bool = True, ) -> None: @@ -65,8 +60,8 @@ def __init__( :param recursive: 是否启用递归满足(默认启用,如果当前依赖来源中存在依赖项,会被递归满足;关闭可节约性能) """ super().__init__() - self.ref: Depends | None - self.getter: AsyncCallable[[], Any] | None + self.ref: Depends[T] | None + self.getter: AsyncCallable[[], T] | None if isinstance(dep, Depends): self.ref = dep @@ -93,22 +88,22 @@ def __repr__(self) -> str: ref_str = f"ref={self.ref}" if self.ref is not None else "" return f"{self.__class__.__name__}({ref_str if ref_str != '' else getter_str})" - async def _get(self, dep_scope: dict[Depends, Any]) -> Any: + async def _get(self, dep_scope: dict[Depends, Any]) -> T: + val: T | VoidType + if self.getter is not None: val = await self.getter() else: - ref = cast(Depends, self.ref) + ref = cast(Depends[T], self.ref) val = dep_scope.get(ref, VoidType.VOID) if val is VoidType.VOID: val = await ref.fulfill(dep_scope) if self.sub_getter is not None: - val = self.sub_getter(val) - if isawaitable(val): - val = await val + val = await self.sub_getter(val) return val - async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any: + async def fulfill(self, dep_scope: dict[Depends, Any]) -> T: if self._lock is None: val = await self._get(dep_scope) elif self._cached is not VoidType.VOID: @@ -123,24 +118,6 @@ async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any: return val -def _adapter_get(hint: Any) -> "Adapter": - ctx = EventBuildInfoCtx() - try: - return ctx.get().adapter - except ctx.lookup_exc_cls as e: - adapter = BotCtx().get().get_adapter(hint) - if adapter is None: - raise e - return adapter - - -def _custom_logger_get(hint: Any, data: CustomLogger) -> Any: - val = LoggerCtx().get() - if not is_type(val, hint): - val = data.getter() - return val - - class AutoDepends(Depends): def __init__(self, func: Callable, name: str, hint: Any) -> None: self.hint = hint @@ -158,7 +135,7 @@ def __init__(self, func: Callable, name: str, hint: Any) -> None: else: self.metadatas = () - self.orig_getter: Callable[[], Any] | AsyncCallable[[], Any] | None = None + self.orig_getter: SyncOrAsyncCallable[[], Any] | None = None if is_subhint(hint, FlowCtx().get_event_type()): self.orig_getter = FlowCtx().get_event @@ -166,8 +143,8 @@ def __init__(self, func: Callable, name: str, hint: Any) -> None: elif is_subhint(hint, BotCtx().get_type()): self.orig_getter = BotCtx().get - elif is_subhint(hint, EventBuildInfoCtx().get_adapter_type()): - self.orig_getter = cast(Callable[[], Any], lambda h=hint: _adapter_get(h)) + elif is_subhint(hint, _get_adapter_type()): + self.orig_getter = cast(Callable[[], Any], partial(_adapter_get, self, hint)) elif is_subhint(hint, LoggerCtx().get_type()): self.orig_getter = LoggerCtx().get @@ -175,16 +152,22 @@ def __init__(self, func: Callable, name: str, hint: Any) -> None: elif is_subhint(hint, FlowCtx().get_store_type()): self.orig_getter = FlowCtx().get_store + elif is_subhint(hint, SessionCtx().get_session_type()): + self.orig_getter = SessionCtx().get + elif is_subhint(hint, SessionCtx().get_store_type()): self.orig_getter = SessionCtx().get_store - elif is_subhint(hint, SessionCtx().get_rule_type() | None): - self.orig_getter = lambda: SessionCtx().get().rule + elif is_subhint(hint, SessionCtx().get_rule_type()): + self.orig_getter = SessionCtx().get_rule + + elif is_subhint(hint, ParseArgsCtx().get_args_type()): + self.orig_getter = ParseArgsCtx().get for data in self.metadatas: if isinstance(data, CustomLogger): self.orig_getter = cast( - Callable[[], Any], lambda h=hint, d=data: _custom_logger_get(h, d) + Callable[[], Any], partial(_custom_logger_get, hint, data) ) break @@ -197,8 +180,7 @@ def __init__(self, func: Callable, name: str, hint: Any) -> None: for data in self.metadatas: if isinstance(data, Reflect): self.orig_getter = cast( - Callable[[], Any], - lambda g=self.orig_getter: Reflection(cast(Callable[[], Any], g)), + Callable[[], Any], partial(Reflection, self.orig_getter) ) break @@ -230,7 +212,7 @@ async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any: val = await super().fulfill(dep_scope) if isinstance(val, Reflection): - inner_val = val.__obj_getter__() + inner_val = val.__origin__ if isawaitable(inner_val): raise AttributeError( f"异步依赖项不能通过 {Reflect.__name__} 创建反射依赖" @@ -243,6 +225,31 @@ async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any: return val +def _get_adapter_type() -> type["Adapter"]: + from .adapter.base import Adapter + + return Adapter + + +def _adapter_get(deps: AutoDepends, hint: Any) -> "Adapter": + flow_ctx = FlowCtx() + try: + event = flow_ctx.get_event() + return EventOrigin.get_origin(event).adapter + except flow_ctx.lookup_exc_cls: + adapter = BotCtx().get().get_adapter(hint) + if adapter is None: + raise deps._unmatch_exc(VoidType) from None + return adapter + + +def _custom_logger_get(hint: Any, data: CustomLogger) -> Any: + val = LoggerCtx().get() + if not is_type(val, hint): + val = data.getter() + return val + + @dataclass class Exclude: """数据类。`types` 指定的类别会在依赖注入时被排除 @@ -265,7 +272,7 @@ class CustomLogger: # 如果 bot 设置的 logger 是 MyLogger 类型,则成功依赖注入 # 否则使用 getter 获取一个日志器 - NewLoggerHint = Annotated[MyLogger, CustomLogger(getter=lambda: MyLogger())] + NewLoggerHint = Annotated[MyLogger, CustomLogger(getter=MyLogger)] """ getter: Callable[[], Any] @@ -389,7 +396,7 @@ def _get_bound_args( return list(bind.args), bind.kwargs -class DependsHook(Depends, BetterABC, Generic[T]): +class DependsHook(Depends[T], BetterABC): """依赖钩子 包装一个依赖项,依赖满足后内部的 hook 将会执行 @@ -397,11 +404,11 @@ class DependsHook(Depends, BetterABC, Generic[T]): def __init__( self, - func: Callable[P, T] | AsyncCallable[P, T], + dep: SyncOrAsyncCallable[[], T], cache: bool = False, recursive: bool = False, ) -> None: - super().__init__(func, cache=cache, recursive=recursive) + super().__init__(dep, cache=cache, recursive=recursive) @abstractmethod async def deps_callback(self, val: T) -> None: @@ -411,14 +418,14 @@ async def deps_callback(self, val: T) -> None: """ raise NotImplementedError - async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any: + async def fulfill(self, dep_scope: dict[Depends, Any]) -> T: val = await super().fulfill(dep_scope) await self.deps_callback(val) return val def inject_deps( - injectee: Callable[..., T] | AsyncCallable[..., T], manual_arg: bool = False + injectee: SyncOrAsyncCallable[..., T], manual_arg: bool = False ) -> AsyncCallable[..., T]: """依赖注入标记装饰器,标记当前对象需要被依赖注入 @@ -431,7 +438,7 @@ def inject_deps( """ @wraps(injectee) - async def di_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + async def inject_deps_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: _args, _kwargs = _get_bound_args(injectee, *args, **kwargs) dep_scope: dict[Depends, Any] = {} @@ -452,9 +459,9 @@ async def di_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: if isinstance(injectee, (FunctionType, BuiltinFunctionType)): _init_auto_deps(injectee, manual_arg) - return di_wrapped + return inject_deps_wrapped if isinstance(injectee, LambdaType): - return di_wrapped + return inject_deps_wrapped raise DependInitError( f"{injectee} 对象不属于以下类别中的任何一种:" diff --git a/src/melobot/exceptions.py b/src/melobot/exceptions.py index dca59893..18c95cfa 100644 --- a/src/melobot/exceptions.py +++ b/src/melobot/exceptions.py @@ -37,7 +37,11 @@ def __str__(self) -> str: return self.err -class ValidateError(BotException): +class UtilError(BotException): + """melobot.utils 异常""" + + +class UtilValidateError(UtilError): """:py:mod:`melobot.utils` 函数参数验证异常""" @@ -69,6 +73,15 @@ class SessionError(BotException): """melobot 会话异常""" +class SessionStateFailed(SessionError): + def __init__(self, cur_state: str, meth: str) -> None: + self.cur_state = cur_state + super().__init__(f"当前会话状态 {cur_state} 不支持的操作:{meth}") + + +class SessionRuleLacked(SessionError): ... + + class FlowError(BotException): """melobot 处理流异常""" diff --git a/src/melobot/handle/__init__.py b/src/melobot/handle/__init__.py index ba080b3e..367b0cc1 100644 --- a/src/melobot/handle/__init__.py +++ b/src/melobot/handle/__init__.py @@ -1,7 +1,8 @@ from ..adapter.model import Event as _Event +from ..ctx import BotCtx as _BotCtx from ..ctx import FlowCtx as _FlowCtx from ..ctx import FlowRecord, FlowRecordStage, FlowStore -from .process import ( +from .base import ( Flow, FlowNode, block, @@ -13,6 +14,18 @@ rewind, stop, ) +from .register import ( + FlowDecorator, + GetParseArgs, + on_command, + on_contain_match, + on_end_match, + on_event, + on_full_match, + on_regex_match, + on_start_match, + on_text, +) def get_flow_records() -> tuple[FlowRecord, ...]: @@ -20,7 +33,7 @@ def get_flow_records() -> tuple[FlowRecord, ...]: :return: 流记录 """ - return tuple(_FlowCtx().get().records) + return _FlowCtx().get_records() def get_flow_store() -> FlowStore: diff --git a/src/melobot/handle/base.py b/src/melobot/handle/base.py index 215439ee..0eb8956c 100644 --- a/src/melobot/handle/base.py +++ b/src/melobot/handle/base.py @@ -1,53 +1,441 @@ -from typing_extensions import TYPE_CHECKING +from __future__ import annotations -from ..adapter.model import Event -from ..ctx import LoggerCtx -from ..log.base import LogLevel -from ..utils import RWContext -from .process import Flow +from asyncio import create_task, get_running_loop +from dataclasses import dataclass +from itertools import tee -if TYPE_CHECKING: - from ..plugin.base import Plugin +from typing_extensions import Iterable, NoReturn, Sequence +from ..adapter.base import Event +from ..ctx import BotCtx, EventCompletion, FlowCtx, FlowRecord, FlowRecords +from ..ctx import FlowRecordStage as RecordStage +from ..ctx import FlowStatus, FlowStore +from ..di import DependNotMatched, inject_deps +from ..exceptions import FlowError +from ..log import LogLevel +from ..mixin import LogMixin +from ..typ.base import AsyncCallable, SyncOrAsyncCallable +from ..utils.base import to_async +from ..utils.common import get_obj_name -class EventHandler: - def __init__(self, plugin: "Plugin", flow: Flow) -> None: - self.flow = flow - self.logger = LoggerCtx().get() - self.name = flow.name +_FLOW_CTX = FlowCtx() - self._plugin = plugin - self._handle_ctrl = RWContext() - self._temp = flow.temp - self.invalid = False - async def _handle(self, event: Event) -> None: +def node(func: SyncOrAsyncCallable[..., bool | None]) -> FlowNode: + """处理结点装饰器,将当前异步可调用对象装饰为一个处理结点""" + return FlowNode(func) + + +def no_deps_node(func: SyncOrAsyncCallable[..., bool | None]) -> FlowNode: + """与 :func:`node` 类似,但是不自动为结点标记依赖注入。 + + 需要后续使用 :func:`.inject_deps` 手动标记依赖注入, + 这适用于某些对处理结点进行再装饰的情况 + """ + return FlowNode(func, no_deps=True) + + +class FlowNode: + """处理流结点""" + + def __init__( + self, func: SyncOrAsyncCallable[..., bool | None], no_deps: bool = False + ) -> None: + self.name = get_obj_name(func, otype="callable") + self.processor: AsyncCallable[..., bool | None] = ( + to_async(func) if no_deps else inject_deps(func) + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name})" + + async def process(self, flow: Flow, completion: EventCompletion) -> None: + event = completion.event try: - await self.flow.run(event) - except Exception: - self.logger.exception(f"事件处理流 {self.name} 发生异常") - self.logger.generic_obj( - f"异常点 event {event.id}", event.__dict__, level=LogLevel.ERROR - ) - self.logger.generic_obj("异常点局部变量:", locals(), level=LogLevel.ERROR) - - async def handle(self, event: Event) -> None: - if self.invalid: + status = _FLOW_CTX.get() + records, store = status.records, status.store + except _FLOW_CTX.lookup_exc_cls: + records, store = FlowRecords(), FlowStore() + + with _FLOW_CTX.unfold(FlowStatus(flow, self, True, completion, records, store)): + try: + records.append( + FlowRecord(RecordStage.NODE_START, flow.name, self.name, event) + ) + + try: + ret = await self.processor() + records.append( + FlowRecord(RecordStage.NODE_FINISH, flow.name, self.name, event) + ) + except DependNotMatched as e: + ret = False + records.append( + FlowRecord( + RecordStage.DEPENDS_NOT_MATCH, + flow.name, + self.name, + event, + f"Real({e.real_type}) <=> Annotation({e.hint})", + ) + ) + + if ret in (None, True) and _FLOW_CTX.get().next_valid: + await nextn() + + except FlowContinued: + await nextn() + + +@dataclass +class NodeInfo: + nexts: list[FlowNode] + in_deg: int + out_deg: int + + def copy(self) -> NodeInfo: + return NodeInfo(self.nexts, self.in_deg, self.out_deg) + + +class Flow(LogMixin): + """处理流 + + :ivar str name: 处理流的标识 + """ + + def __init__( + self, + name: str, + *edge_maps: Iterable[Iterable[FlowNode] | FlowNode], + priority: int = 0, + ) -> None: + """初始化处理流 + + :param name: 处理流的标识 + :param edge_maps: 边映射,遵循 melobot 的 graph edges 表示方法 + :param priority: 处理流的优先级 + """ + self.name = name + self.graph: dict[FlowNode, NodeInfo] = {} + self.priority = priority + + self._active = True + + _edge_maps = tuple( + tuple((elem,) if isinstance(elem, FlowNode) else elem for elem in emap) + for emap in edge_maps + ) + edges = self._get_edges(_edge_maps) + + for n1, n2 in edges: + self._add(n1, n2) + + if not self._valid_check(): + raise FlowError(f"定义的处理流 {self.name} 中存在环路") + + def __repr__(self) -> str: + output = f"{self.__class__.__name__}(name={self.name}, nums={len(self.graph)}" + + if len(self.graph): + output += f", starts=[{', '.join(repr(n) for n in self.starts)}])" + else: + output += ")" + return output + + @property + def starts(self) -> tuple[FlowNode, ...]: + return tuple(n for n, info in self.graph.items() if info.in_deg == 0) + + @property + def ends(self) -> tuple[FlowNode, ...]: + return tuple(n for n, info in self.graph.items() if info.out_deg == 0) + + def update_priority(self, priority: int) -> None: + """更新处理流优先级 + + :param priority: 新优先级 + """ + BotCtx().get()._dispatcher.update(priority, self) + + def dismiss(self) -> None: + """停用处理流 + + 停用后将无法处理任何新事件,也无法再次恢复使用 + """ + self._active = False + + def is_active(self) -> bool: + """判断处理流是否处于可用状态 + + :return: 是否可用 + """ + return self._active + + def link(self, flow: Flow, priority: int | None = None) -> Flow: + """连接另一处理流返回新处理流,并设置新优先级 + + :param flow: 连接的新流 + :param priority: 新优先级,若为空,则使用两者中较小的优先级 + :return: 新的处理流 + """ + _froms = self.ends + _tos = flow.starts + new_edges = tuple((n1, n2) for n1 in _froms for n2 in _tos) + + new_flow = Flow( + f"{self.name} ~ {flow.name}", + *new_edges, + priority=priority if priority else min(self.priority, flow.priority), + ) + + for n1, info in (self.graph | flow.graph).items(): + if not len(info.nexts): + new_flow._add(n1, None) + continue + for n2 in info.nexts: + new_flow._add(n1, n2) + + if not self._valid_check(): + raise FlowError(f"定义的处理流 {self.name} 中存在环路") + + return new_flow + + def _get_edges( + self, edge_maps: Sequence[Sequence[Iterable[FlowNode]]] + ) -> list[tuple[FlowNode, FlowNode]]: + edges: list[tuple[FlowNode, FlowNode]] = [] + + for emap in edge_maps: + iter1, iter2 = tee(emap, 2) + try: + next(iter2) + except StopIteration: + continue + + if len(emap) == 1: + for n in emap[0]: + self._add(n, None) + continue + + for from_seq, to_seq in zip(iter1, iter2): + for n1 in from_seq: + for n2 in to_seq: + if (n1, n2) not in edges: + edges.append((n1, n2)) + + return edges + + def _add(self, _from: FlowNode, to: FlowNode | None) -> None: + from_info = self.graph.setdefault(_from, NodeInfo([], 0, 0)) + + if to is not None: + to_info = self.graph.setdefault(to, NodeInfo([], 0, 0)) + to_info.in_deg += 1 + from_info.out_deg += 1 + from_info.nexts.append(to) + + def _valid_check(self) -> bool: + graph = {n: info.copy() for n, info in self.graph.items()} + + while len(graph): + for n, info in graph.items(): + nexts, in_deg = info.nexts, info.in_deg + + if in_deg == 0: + graph.pop(n) + for next_n in nexts: + graph[next_n].in_deg -= 1 + break + + else: + return False + + return True + + async def _handle(self, event: Event) -> None: + fut = get_running_loop().create_future() + create_task(self._run(EventCompletion(event, fut, self))) + await fut + + async def _run(self, completion: EventCompletion) -> None: + if not len(self.starts): + if ( + completion.owner_flow is self + and not completion.under_session + and not completion.completed.done() + ): + completion.completed.set_result(None) return - if not self._temp: - async with self._handle_ctrl.read(): - if self.invalid: - return - return await self._handle(event) - - async with self._handle_ctrl.write(): - if self.invalid: - return - await self._handle(event) - self.invalid = True + event = completion.event + try: + status = _FLOW_CTX.get() + records, store = status.records, status.store + except _FLOW_CTX.lookup_exc_cls: + records, store = FlowRecords(), FlowStore() + + with _FLOW_CTX.unfold( + FlowStatus(self, self.starts[0], True, completion, records, store) + ): + try: + records.append( + FlowRecord( + RecordStage.FLOW_START, self.name, self.starts[0].name, event + ) + ) + + idx = 0 + while idx < len(self.starts): + try: + await self.starts[idx].process(self, completion) + idx += 1 + except FlowRewound: + pass + + records.append( + FlowRecord( + RecordStage.FLOW_FINISH, self.name, self.starts[0].name, event + ) + ) + + except FlowBroke: + pass + + except Exception: + self.logger.exception(f"事件处理流 {self.name} 发生异常") + self.logger.generic_obj( + f"异常点 event {event.id}", event.__dict__, level=LogLevel.ERROR + ) + self.logger.generic_obj( + "异常点局部变量:", locals(), level=LogLevel.ERROR + ) + + finally: + if ( + completion.owner_flow is self + and not completion.under_session + and not completion.completed.done() + ): + completion.completed.set_result(None) + + +class _FlowSignal(BaseException): ... + + +class FlowBroke(_FlowSignal): ... + + +class FlowContinued(_FlowSignal): ... + + +class FlowRewound(_FlowSignal): ... + + +async def nextn() -> None: + """运行下一处理结点(在处理流中使用)""" + try: + status = _FLOW_CTX.get() + nexts = status.flow.graph[status.node].nexts + if not status.next_valid: return - async def expire(self) -> None: - async with self._handle_ctrl.write(): - self.invalid = True + idx = 0 + while idx < len(nexts): + try: + await nexts[idx].process(status.flow, status.completion) + idx += 1 + except FlowRewound: + pass + + except _FLOW_CTX.lookup_exc_cls: + raise FlowError("此时不在活动的事件处理流中,无法调用下一处理结点") from None + finally: + status.next_valid = False + + +async def block() -> None: + """阻止当前事件向更低优先级的处理流传播(在处理流中使用)""" + status = _FLOW_CTX.get() + status.records.append( + FlowRecord( + RecordStage.BLOCK, status.flow.name, status.node.name, status.completion.event + ) + ) + status.completion.event.spread = False + + +async def stop() -> NoReturn: + """立即停止当前处理流(在处理流中使用)""" + status = _FLOW_CTX.get() + status.records.append( + FlowRecord( + RecordStage.STOP, status.flow.name, status.node.name, status.completion.event + ) + ) + status.records.append( + FlowRecord( + RecordStage.NODE_EARLY_FINISH, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + status.records.append( + FlowRecord( + RecordStage.FLOW_EARLY_FINISH, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + raise FlowBroke("事件处理流被安全地提早结束,请无视这个内部工作信号") + + +async def bypass() -> NoReturn: + """立即跳过当前处理结点剩下的步骤,运行下一处理结点(在处理流中使用)""" + status = _FLOW_CTX.get() + status.records.append( + FlowRecord( + RecordStage.BYPASS, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + status.records.append( + FlowRecord( + RecordStage.NODE_EARLY_FINISH, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + raise FlowContinued("事件处理流安全地跳过结点执行,请无视这个内部工作信号") + + +async def rewind() -> NoReturn: + """立即重新运行当前处理结点(在处理流中使用)""" + status = _FLOW_CTX.get() + status.records.append( + FlowRecord( + RecordStage.REWIND, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + status.records.append( + FlowRecord( + RecordStage.NODE_EARLY_FINISH, + status.flow.name, + status.node.name, + status.completion.event, + ) + ) + raise FlowRewound("事件处理流安全地重复执行处理结点,请无视这个内部工作信号") + + +async def flow_to(flow: Flow) -> None: + """立即进入一个其他处理流(在处理流中使用)""" + status = _FLOW_CTX.get() + await flow._run(status.completion) diff --git a/src/melobot/handle/process.py b/src/melobot/handle/process.py deleted file mode 100644 index c1f7ca57..00000000 --- a/src/melobot/handle/process.py +++ /dev/null @@ -1,399 +0,0 @@ -from __future__ import annotations - -from asyncio import create_task -from dataclasses import dataclass -from itertools import tee - -from typing_extensions import Iterable, NoReturn, Sequence - -from ..adapter.model import Event -from ..ctx import FlowCtx, FlowRecord, FlowRecords -from ..ctx import FlowRecordStage as RecordStage -from ..ctx import FlowStatus, FlowStore -from ..di import DependNotMatched, inject_deps -from ..exceptions import FlowError -from ..typ import AsyncCallable, HandleLevel -from ..utils import get_obj_name - -_FLOW_CTX = FlowCtx() - - -def node(func: AsyncCallable[..., bool | None]) -> FlowNode: - """处理结点装饰器,将当前异步可调用对象装饰为一个处理结点""" - return FlowNode(func) - - -def no_deps_node(func: AsyncCallable[..., bool | None]) -> FlowNode: - """与 :func:`node` 类似,但是不自动为结点标记依赖注入。 - - 需要后续使用 :func:`.inject_deps` 手动标记依赖注入, - 这适用于某些对处理结点进行再装饰的情况 - """ - return FlowNode(func, no_deps=True) - - -class FlowNode: - """处理流结点""" - - def __init__( - self, func: AsyncCallable[..., bool | None], no_deps: bool = False - ) -> None: - self.name = get_obj_name(func, otype="callable") - self.processor: AsyncCallable[..., bool | None] = ( - func if no_deps else inject_deps(func) - ) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(name={self.name})" - - async def process(self, event: Event, flow: Flow) -> None: - try: - status = _FLOW_CTX.get() - records, store = status.records, status.store - except _FLOW_CTX.lookup_exc_cls: - records, store = FlowRecords(), FlowStore() - - with _FLOW_CTX.unfold(FlowStatus(event, flow, self, True, records, store)): - try: - records.append( - FlowRecord(RecordStage.NODE_START, flow.name, self.name, event) - ) - - try: - ret = await self.processor() - records.append( - FlowRecord(RecordStage.NODE_FINISH, flow.name, self.name, event) - ) - except DependNotMatched as e: - ret = False - records.append( - FlowRecord( - RecordStage.DEPENDS_NOT_MATCH, - flow.name, - self.name, - event, - f"Real({e.real_type}) <=> Annotation({e.hint})", - ) - ) - - if ret in (None, True) and _FLOW_CTX.get().next_valid: - await nextn() - - except FlowContinued: - await nextn() - - -@dataclass -class _NodeInfo: - nexts: list[FlowNode] - in_deg: int - out_deg: int - - def copy(self) -> _NodeInfo: - return _NodeInfo(self.nexts, self.in_deg, self.out_deg) - - -class Flow: - """处理流 - - :ivar str name: 处理流的标识 - """ - - def __init__( - self, - name: str, - *edge_maps: Iterable[Iterable[FlowNode] | FlowNode], - priority: HandleLevel = HandleLevel.NORMAL, - temp: bool = False, - ) -> None: - """初始化处理流 - - :param name: 处理流的标识 - :param edge_maps: 边映射,遵循 melobot 的 graph edges 表示方法 - :param priority: 处理流的优先级 - :param temp: 处理流是否只运行一次 - """ - self.name = name - self.temp = temp - self.graph: dict[FlowNode, _NodeInfo] = {} - - self._priority = priority - self._priority_cb: AsyncCallable[[HandleLevel], None] - - _edge_maps = tuple( - tuple((elem,) if isinstance(elem, FlowNode) else elem for elem in emap) - for emap in edge_maps - ) - edges = self._get_edges(_edge_maps) - - for n1, n2 in edges: - self._add(n1, n2) - - if not self._valid_check(): - raise FlowError(f"定义的处理流 {self.name} 中存在环路") - - def _get_edges( - self, edge_maps: Sequence[Sequence[Iterable[FlowNode]]] - ) -> list[tuple[FlowNode, FlowNode]]: - edges: list[tuple[FlowNode, FlowNode]] = [] - - for emap in edge_maps: - iter1, iter2 = tee(emap, 2) - try: - next(iter2) - except StopIteration: - continue - - if len(emap) == 1: - for n in emap[0]: - self._add(n, None) - continue - - for from_seq, to_seq in zip(iter1, iter2): - for n1 in from_seq: - for n2 in to_seq: - if (n1, n2) not in edges: - edges.append((n1, n2)) - - return edges - - @property - def priority(self) -> HandleLevel: - """处理流的优先级""" - return self._priority - - def on_priority_reset(self, callback: AsyncCallable[[HandleLevel], None]) -> None: - self._priority_cb = callback - - async def reset_priority(self, new_prior: HandleLevel) -> None: - """重设处理流的优先级 - - 不会立即生效,通常会在下一次进入处理流前生效。 - 因此返回时不代表已经更改优先级,只是添加了计划重设优先级的任务 - - :param new_prior: 新优先级 - """ - - async def _reset_wrapper() -> None: - await self._priority_cb(new_prior) - self._priority = new_prior - - create_task(_reset_wrapper()) - - @property - def starts(self) -> tuple[FlowNode, ...]: - return tuple(n for n, info in self.graph.items() if info.in_deg == 0) - - @property - def ends(self) -> tuple[FlowNode, ...]: - return tuple(n for n, info in self.graph.items() if info.out_deg == 0) - - def __repr__(self) -> str: - output = f"{self.__class__.__name__}(name={self.name}, nums={len(self.graph)}" - - if len(self.graph): - output += f", starts=[{', '.join(repr(n) for n in self.starts)}])" - else: - output += ")" - return output - - def _add(self, _from: FlowNode, to: FlowNode | None) -> None: - from_info = self.graph.setdefault(_from, _NodeInfo([], 0, 0)) - - if to is not None: - to_info = self.graph.setdefault(to, _NodeInfo([], 0, 0)) - to_info.in_deg += 1 - from_info.out_deg += 1 - from_info.nexts.append(to) - - def _valid_check(self) -> bool: - graph = {n: info.copy() for n, info in self.graph.items()} - - while len(graph): - for n, info in graph.items(): - nexts, in_deg = info.nexts, info.in_deg - - if in_deg == 0: - graph.pop(n) - for next_n in nexts: - graph[next_n].in_deg -= 1 - break - - else: - return False - - return True - - def link(self, flow: Flow, priority: HandleLevel | None = None) -> Flow: - """连接另一处理流返回新处理流,并设置新优先级 - - :param flow: 连接的新流 - :param priority: 新优先级,若为空,则使用两者中较小的优先级 - :return: 新的处理流 - """ - _froms = self.ends - _tos = flow.starts - new_edges = tuple((n1, n2) for n1 in _froms for n2 in _tos) - - new_flow = Flow( - f"{self.name} ~ {flow.name}", - *new_edges, - priority=priority if priority else min(self.priority, flow.priority), - temp=self.temp or flow.temp, - ) - - for n1, info in (self.graph | flow.graph).items(): - if not len(info.nexts): - new_flow._add(n1, None) - continue - for n2 in info.nexts: - new_flow._add(n1, n2) - - if not self._valid_check(): - raise FlowError(f"定义的处理流 {self.name} 中存在环路") - - return new_flow - - async def run(self, event: Event) -> None: - if not len(self.starts): - return - - try: - status = _FLOW_CTX.get() - records, store = status.records, status.store - except _FLOW_CTX.lookup_exc_cls: - records, store = FlowRecords(), FlowStore() - - with _FLOW_CTX.unfold( - FlowStatus(event, self, self.starts[0], True, records, store) - ): - try: - records.append( - FlowRecord( - RecordStage.FLOW_START, self.name, self.starts[0].name, event - ) - ) - - idx = 0 - while idx < len(self.starts): - try: - await self.starts[idx].process(event, self) - idx += 1 - except FlowRewound: - pass - - records.append( - FlowRecord( - RecordStage.FLOW_FINISH, self.name, self.starts[0].name, event - ) - ) - except FlowBroke: - pass - - -class _FlowSignal(BaseException): ... - - -class FlowBroke(_FlowSignal): ... - - -class FlowContinued(_FlowSignal): ... - - -class FlowRewound(_FlowSignal): ... - - -async def nextn() -> None: - """运行下一处理结点(在处理流中使用)""" - try: - status = _FLOW_CTX.get() - nexts = status.flow.graph[status.node].nexts - if not status.next_valid: - return - - idx = 0 - while idx < len(nexts): - try: - await nexts[idx].process(status.event, status.flow) - idx += 1 - except FlowRewound: - pass - - except _FLOW_CTX.lookup_exc_cls: - raise FlowError("此时不在活动的事件处理流中,无法调用下一处理结点") from None - finally: - status.next_valid = False - - -async def block() -> None: - """阻止当前事件向更低优先级的处理流传播(在处理流中使用)""" - status = _FLOW_CTX.get() - status.records.append( - FlowRecord(RecordStage.BLOCK, status.flow.name, status.node.name, status.event) - ) - status.event.spread = False - - -async def stop() -> NoReturn: - """立即停止当前处理流(在处理流中使用)""" - status = _FLOW_CTX.get() - status.records.append( - FlowRecord(RecordStage.STOP, status.flow.name, status.node.name, status.event) - ) - status.records.append( - FlowRecord( - RecordStage.NODE_EARLY_FINISH, - status.flow.name, - status.node.name, - status.event, - ) - ) - status.records.append( - FlowRecord( - RecordStage.FLOW_EARLY_FINISH, - status.flow.name, - status.node.name, - status.event, - ) - ) - raise FlowBroke("事件处理流被安全地提早结束,请无视这个内部工作信号") - - -async def bypass() -> NoReturn: - """立即跳过当前处理结点剩下的步骤,运行下一处理结点(在处理流中使用)""" - status = _FLOW_CTX.get() - status.records.append( - FlowRecord(RecordStage.BYPASS, status.flow.name, status.node.name, status.event) - ) - status.records.append( - FlowRecord( - RecordStage.NODE_EARLY_FINISH, - status.flow.name, - status.node.name, - status.event, - ) - ) - raise FlowContinued("事件处理流安全地跳过结点执行,请无视这个内部工作信号") - - -async def rewind() -> NoReturn: - """立即重新运行当前处理结点(在处理流中使用)""" - status = _FLOW_CTX.get() - status.records.append( - FlowRecord(RecordStage.REWIND, status.flow.name, status.node.name, status.event) - ) - status.records.append( - FlowRecord( - RecordStage.NODE_EARLY_FINISH, - status.flow.name, - status.node.name, - status.event, - ) - ) - raise FlowRewound("事件处理流安全地重复执行处理结点,请无视这个内部工作信号") - - -async def flow_to(flow: Flow) -> None: - """立即进入一个其他处理流(在处理流中使用)""" - status = _FLOW_CTX.get() - await flow.run(status.event) diff --git a/src/melobot/handle/register.py b/src/melobot/handle/register.py new file mode 100644 index 00000000..ac2f67e5 --- /dev/null +++ b/src/melobot/handle/register.py @@ -0,0 +1,323 @@ +from asyncio import Lock +from functools import partial, wraps + +from typing_extensions import Callable, Sequence, cast + +from ..adapter.model import Event, TextEvent +from ..ctx import FlowCtx, ParseArgsCtx +from ..di import Depends, inject_deps +from ..session.base import enter_session +from ..session.option import DefaultRule, Rule +from ..typ._enum import LogicMode +from ..typ.base import AsyncCallable, SyncOrAsyncCallable +from ..utils.check import Checker, checker_join +from ..utils.common import get_obj_name +from ..utils.match import ( + ContainMatcher, + EndMatcher, + FullMatcher, + Matcher, + RegexMatcher, + StartMatcher, +) +from ..utils.parse import AbstractParseArgs, CmdArgFormatter, CmdParser, Parser +from .base import Flow, no_deps_node + + +def GetParseArgs() -> AbstractParseArgs: # pylint: disable=invalid-name + """获取解析参数 + + :return: 解析参数 + """ + return cast(AbstractParseArgs, Depends(ParseArgsCtx().get, recursive=False)) + + +class FlowDecorator: + def __init__( + self, + checker: Checker | None | Callable[[Event], bool] = None, + matcher: Matcher | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + rule: Rule[Event] | None = None, + ) -> None: + self.checker: Checker | None + if callable(checker): + self.checker = Checker.new(checker) + else: + self.checker = checker + + self.matcher = matcher + self.parser = parser + self.priority = priority + self.block = block + self.decos = decos + self.rule = rule + + self._temp = temp + self._invalid = False + self._lock = Lock() + self._flow: Flow + + async def _pre_process(self, event: Event) -> tuple[bool, AbstractParseArgs | None]: + if self.checker: + status = await self.checker.check(event) + if not status: + return (False, None) + + args: AbstractParseArgs | None = None + if isinstance(event, TextEvent): + if self.matcher: + status = await self.matcher.match(event.text) + if not status: + return (False, None) + + if self.parser: + args = await self.parser.parse(event.text) + if args: + return (True, args) + return (False, None) + + return (True, None) + + async def _process( + self, + func: AsyncCallable[..., bool | None], + event: Event, + args: AbstractParseArgs | None, + ) -> bool | None: + event.spread = not self.block + parse_args_ctx = ParseArgsCtx() + if args is not None: + args_token = parse_args_ctx.add(args) + else: + args_token = None + + try: + if self.rule is not None: + async with enter_session(self.rule): + return await func() + return await func() + finally: + if args_token: + parse_args_ctx.remove(args_token) + + async def auto_flow_wrapped( + self, func: AsyncCallable[..., bool | None] + ) -> bool | None: + if self._invalid: + self._flow.dismiss() + return None + + event = FlowCtx().get().completion.event + if not self._temp: + passed, args = await self._pre_process(event) + if not passed: + return None + return await self._process(func, event, args) + + async with self._lock: + if self._invalid: + self._flow.dismiss() + return None + + passed, args = await self._pre_process(event) + if not passed: + return None + self._invalid = True + + return await self._process(func, event, args) + + def __call__(self, func: SyncOrAsyncCallable[..., bool | None]) -> Flow: + func = inject_deps(func) + if self.decos is not None: + for deco in reversed(self.decos): + func = deco(func) + + n = no_deps_node(wraps(func)(partial(self.auto_flow_wrapped, func))) + n.name = get_obj_name(func, otype="callable") + self._flow = Flow(f"OneBotV11Flow[{n.name}]", (n,), priority=self.priority) + return self._flow + + +def on_event( + checker: Checker | None | Callable[[Event], bool] = None, + matcher: Matcher | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + rule: Rule[Event] | None = None, +) -> FlowDecorator: + return FlowDecorator(checker, matcher, parser, priority, block, temp, decos, rule) + + +def on_text( + checker: Checker | None | Callable[[TextEvent], bool] = None, + matcher: Matcher | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + if legacy_session: + rule = DefaultRule() + else: + rule = None + + return on_event( + checker_join(lambda e: isinstance(e, TextEvent), checker), # type: ignore[arg-type] + matcher, + parser, + priority, + block, + temp, + decos, + cast(Rule[Event] | None, rule), + ) + + +def on_command( + cmd_start: str | list[str], + cmd_sep: str | list[str], + targets: str | list[str], + fmtters: list[CmdArgFormatter | None] | None = None, + checker: Checker | None | Callable[[TextEvent], bool] = None, + matcher: Matcher | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + matcher, + CmdParser(cmd_start=cmd_start, cmd_sep=cmd_sep, targets=targets, fmtters=fmtters), + priority, + block, + temp, + decos, + legacy_session, + ) + + +def on_start_match( + target: str | list[str], + logic_mode: LogicMode = LogicMode.OR, + checker: Checker | Callable[[TextEvent], bool] | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + StartMatcher(target, logic_mode), + parser, + priority, + block, + temp, + decos, + legacy_session, + ) + + +def on_contain_match( + target: str | list[str], + logic_mode: LogicMode = LogicMode.OR, + checker: Checker | Callable[[TextEvent], bool] | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + ContainMatcher(target, logic_mode), + parser, + priority, + block, + temp, + decos, + legacy_session, + ) + + +def on_full_match( + target: str | list[str], + logic_mode: LogicMode = LogicMode.OR, + checker: Checker | Callable[[TextEvent], bool] | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + FullMatcher(target, logic_mode), + parser, + priority, + block, + temp, + decos, + legacy_session, + ) + + +def on_end_match( + target: str | list[str], + logic_mode: LogicMode = LogicMode.OR, + checker: Checker | Callable[[TextEvent], bool] | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + EndMatcher(target, logic_mode), + parser, + priority, + block, + temp, + decos, + legacy_session, + ) + + +def on_regex_match( + target: str, + logic_mode: LogicMode = LogicMode.OR, + checker: Checker | Callable[[TextEvent], bool] | None = None, + parser: Parser | None = None, + priority: int = 0, + block: bool = False, + temp: bool = False, + decos: Sequence[Callable[[Callable], Callable]] | None = None, + legacy_session: bool = False, +) -> FlowDecorator: + return on_text( + checker, + RegexMatcher(target, logic_mode), + parser, + priority, + block, + temp, + decos, + legacy_session, + ) diff --git a/src/melobot/io/__init__.py b/src/melobot/io/__init__.py index 35472279..a9ad934e 100644 --- a/src/melobot/io/__init__.py +++ b/src/melobot/io/__init__.py @@ -4,7 +4,14 @@ AbstractOutSource, AbstractSource, EchoPacket, + EchoPacketT, + InOrOutSourceT, InPacket, + InPacketT, + InSourceT, + IOSourceT, OutPacket, + OutPacketT, + OutSourceT, SourceLifeSpan, ) diff --git a/src/melobot/io/base.py b/src/melobot/io/base.py index 6b8e02d3..35926ea1 100644 --- a/src/melobot/io/base.py +++ b/src/melobot/io/base.py @@ -1,25 +1,20 @@ +from __future__ import annotations + import time +from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum from types import TracebackType from typing_extensions import Any, Generic, LiteralString, Self, TypeVar -from .._hook import Hookable -from ..typ import BetterABC, abstractmethod -from ..utils import get_id - - -@dataclass -class _Packet: - time: float = field(default_factory=lambda: time.time_ns() / 1e9) - id: str = field(default_factory=get_id) - protocol: LiteralString | None = None - data: Any = None +from ..mixin import HookMixin, LogMixin +from ..typ.cls import BetterABC, abstractattr +from ..utils.common import get_id @dataclass -class InPacket(_Packet): +class InPacket: """输入包基类(数据类) :ivar float time: 时间戳 @@ -28,9 +23,14 @@ class InPacket(_Packet): :ivar Any data: 附加的数据 """ + time: float = field(default_factory=lambda: time.time_ns() / 1e9) + id: str = field(default_factory=get_id) + protocol: LiteralString | None = None + data: Any = None + @dataclass -class OutPacket(_Packet): +class OutPacket: """输出包基类(数据类) :ivar float time: 时间戳 @@ -39,9 +39,14 @@ class OutPacket(_Packet): :ivar Any data: 附加的数据 """ + time: float = field(default_factory=lambda: time.time_ns() / 1e9) + id: str = field(default_factory=get_id) + protocol: LiteralString | None = None + data: Any = None + @dataclass -class EchoPacket(_Packet): +class EchoPacket: """回应包基类(数据类) :ivar float time: 时间戳 @@ -54,6 +59,10 @@ class EchoPacket(_Packet): :ivar bool noecho: 是否并无回应产生 """ + time: float = field(default_factory=lambda: time.time_ns() / 1e9) + id: str = field(default_factory=get_id) + protocol: LiteralString | None = None + data: Any = None ok: bool = True status: int = 0 prompt: str = "" @@ -73,15 +82,16 @@ class SourceLifeSpan(Enum): STOPPED = "sto" -class AbstractSource(BetterABC, Hookable[SourceLifeSpan]): +class AbstractSource(HookMixin[SourceLifeSpan], LogMixin, BetterABC): """抽象源基类""" - def __init__(self, protocol: LiteralString) -> None: - Hookable.__init__( - self, SourceLifeSpan, tag=f"{protocol}/{self.__class__.__name__}" - ) + protocol: LiteralString = abstractattr() - self.protocol = protocol + def __init__(self) -> None: + super().__init__( + hook_type=SourceLifeSpan, + hook_tag=f"{self.__class__.__module__}.{self.__class__.__name__}", + ) @abstractmethod async def open(self) -> None: @@ -122,7 +132,7 @@ async def __aexit__( return None -class AbstractInSource(AbstractSource, BetterABC, Generic[InPacketT]): +class AbstractInSource(AbstractSource, Generic[InPacketT]): """抽象输入源基类""" @abstractmethod @@ -149,7 +159,7 @@ async def input(self) -> InPacketT: InSourceT = TypeVar("InSourceT", bound=AbstractInSource) -class AbstractOutSource(AbstractSource, BetterABC, Generic[OutPacketT, EchoPacketT]): +class AbstractOutSource(AbstractSource, Generic[OutPacketT, EchoPacketT]): """抽象输出源基类""" @abstractmethod @@ -179,7 +189,7 @@ async def output(self, packet: OutPacketT) -> EchoPacketT: class AbstractIOSource( - AbstractInSource[InPacketT], AbstractOutSource[OutPacketT, EchoPacketT], BetterABC + AbstractInSource[InPacketT], AbstractOutSource[OutPacketT, EchoPacketT] ): """抽象输入输出源基类""" diff --git a/src/melobot/log/base.py b/src/melobot/log/base.py index a4780bb1..502ac401 100644 --- a/src/melobot/log/base.py +++ b/src/melobot/log/base.py @@ -8,6 +8,7 @@ import sys import traceback import types +from abc import abstractmethod from contextlib import contextmanager from enum import Enum from inspect import currentframe @@ -24,8 +25,10 @@ from rich.text import Text from typing_extensions import Any, Callable, Generator, Literal, Optional -from ..typ import BetterABC, T, VoidType, abstractmethod -from ..utils import singleton +from ..typ._enum import VoidType +from ..typ.base import T +from ..typ.cls import BetterABC +from ..utils.common import singleton _CONSOLE_IO = io.StringIO() _CONSOLE = rich.console.Console(file=_CONSOLE_IO, record=True, color_system="256") diff --git a/src/melobot/mixin.py b/src/melobot/mixin.py new file mode 100644 index 00000000..dd5837a4 --- /dev/null +++ b/src/melobot/mixin.py @@ -0,0 +1,244 @@ +import inspect +from asyncio import Future, get_running_loop + +from typing_extensions import Any, Callable, Generic, Self, cast + +from ._hook import HookBus, HookEnumT +from .ctx import LoggerCtx +from .log.base import GenericLogger +from .typ.base import AsyncCallable, P, SyncOrAsyncCallable +from .utils.base import to_async + + +class LogMixin: + @property + def logger(self) -> GenericLogger: + return LoggerCtx().get() + + +class FlagMixin: + def __init__(self) -> None: + self.__flag_mixin_flags__: dict[Any, dict[Any, Any]] = {} + self.__flag_mixin_waitings__: dict[ + tuple[Any, Any], list[tuple[Any, Future, bool, bool]] + ] = {} + + def __flag_waitings_fulfill__(self, namespace: Any, flag: Any, val: Any) -> None: + waitings = self.__flag_mixin_waitings__.get((namespace, flag)) + if waitings is None: + return + + for waiting in waitings: + expect_val, signal, use_id, wait_val = waiting + if not wait_val: + signal.set_result(None) + continue + + if use_id and val is expect_val: + signal.set_result(None) + continue + + if not use_id and val == expect_val: + signal.set_result(None) + continue + + def flag_set( + self, + namespace: Any, + flag: Any, + val: Any = None, + strict: bool = True, + ) -> None: + """设置标记 + + 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 + + :param namespace: 命名空间 + :param flag: 标记 + :param val: 标记值 + :param strict: 严格模式,启用严格模式,则不允许 `flag` 标记已经存在 + """ + self.__flag_mixin_flags__.setdefault(namespace, {}) + + if strict and flag in self.__flag_mixin_flags__[namespace].keys(): + raise ValueError( + f"标记失败。对象 {self} 的命名空间 {namespace} 中已存在名为 {flag} 的标记" + ) + + self.__flag_mixin_flags__[namespace][flag] = val + self.__flag_waitings_fulfill__(namespace, flag, val) + + def flag_set_default(self, namespace: Any, flag: Any, default: Any) -> None: + """设置标记,并在标记不存在时使用 `default` 初始化 + + 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 + + :param namespace: 命名空间 + :param flag: 标记 + :param default: 标记不存在时的默认值 + """ + self.__flag_mixin_flags__.setdefault(namespace, {}) + self.__flag_mixin_flags__[namespace].setdefault(flag, default) + val = self.__flag_mixin_flags__[namespace][flag] + self.__flag_waitings_fulfill__(namespace, flag, val) + + def flag_get( + self, namespace: Any, flag: Any, raise_exc: bool = True, default: Any = None + ) -> Any: + """获取标记值 + + 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 + + :param namespace: 命名空间 + :param flag: 标记 + :param raise_exc: 为 `True`,则在标记不存在时引发 `KeyError` + :param default: 标记不存在时的默认值,只在 `raise_exc` 为 `False` 时有效 + :return: 标记值 + """ + try: + return self.__flag_mixin_flags__[namespace][flag] + except KeyError: + if raise_exc: + raise KeyError( + f"对象 {self} 的命名空间 {namespace} 中不存在名为 {flag} 的标记" + ) from None + return default + + def flag_check( + self, + namespace: Any, + flag: Any, + val: Any = None, + check_val: bool = True, + use_id: bool = False, + ) -> bool: + """检查标记 + + 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 + + :param namespace: 命名空间 + :param flag: 标记 + :param val: 标记值 + :param check_val: 为 `True` 则需要值也一致 + :param use_id: 为 `True` 则使用 `is` 判断 `val`,否则调用 `==` 判断 `val` + :return: 是否通过检查 + """ + # pylint: disable=consider-iterating-dictionary + if namespace not in self.__flag_mixin_flags__.keys(): + return False + if flag not in self.__flag_mixin_flags__[namespace].keys(): + return False + flag = self.__flag_mixin_flags__[namespace][flag] + + if not check_val: + return True + if use_id: + return flag is val + return cast(bool, flag == val) + + async def flag_wait( + self, + namespace: Any, + flag: Any, + val: Any = None, + wait_val: bool = True, + use_id: bool = False, + ) -> None: + """等待标记 + + 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 + + :param namespace: 命名空间 + :param flag: 标记 + :param val: 标记值 + :param wait_val: 为 `True` 则需要值也一致 + :param use_id: 为 `True` 则使用 `is` 判断 `val`,否则调用 `==` 判断 `val` + :return: Future 对象 + """ + if self.flag_check(namespace, flag, val, wait_val, use_id): + return None + + signal: Future[None] = get_running_loop().create_future() + waitings = self.__flag_mixin_waitings__.setdefault((namespace, flag), []) + waitings.append((val, signal, use_id, wait_val)) + await signal + waitings = list(filter(lambda x: not x[1].done(), waitings)) + if not len(waitings): + self.__flag_mixin_waitings__.pop((namespace, flag)) + + +class AttrReprMixin: + def __repr__(self) -> str: + attrs = ", ".join( + f"{k}={repr(v)}" for k, v in self.__dict__.items() if not k.startswith("_") + ) + if len(attrs) >= 100: + attrs = attrs[:100] + "..." + return f"{self.__class__.__name__}({attrs})" + + +class LocateMixin: + def __new__(cls, *_args: Any, **_kwargs: Any) -> Self: + obj = super().__new__(cls) + obj.__obj_location__ = obj.__location_init__() # type: ignore[attr-defined] + return obj + + def __init__(self) -> None: + self.__obj_location__: tuple[str, str, int] + + @staticmethod + def __location_init__() -> tuple[str, str, int]: + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "": + return ( + frame.f_globals["__name__"], + frame.f_globals["__file__"], + frame.f_lineno, + ) + frame = frame.f_back + + return ("", "", -1) + + @property + def __obj_module__(self) -> str: + return self.__obj_location__[0] + + @property + def __obj_file__(self) -> str: + return self.__obj_location__[1] + + @property + def __obj_line__(self) -> int: + return self.__obj_location__[2] + + +class HookMixin(Generic[HookEnumT]): + def __init__(self, hook_type: type[HookEnumT], hook_tag: str | None = None): + super().__init__() + self._hook_bus = HookBus[HookEnumT](hook_type, hook_tag) + self.__repeatable_hook_types__: set[HookEnumT] = set() + + def __mark_repeatable_hooks__(self, *types: HookEnumT) -> None: + for t in types: + self.__repeatable_hook_types__.add(t) + + def on( + self, *periods: HookEnumT + ) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]: + """注册一个 hook + + :param periods: 要绑定的 hook 类型 + :return: 装饰器 + """ + + def hook_register_wrapped( + func: SyncOrAsyncCallable[P, None] + ) -> AsyncCallable[P, None]: + f = to_async(func) + for type in periods: + once = type not in self.__repeatable_hook_types__ + self._hook_bus.register(type, func, once) + return f + + return hook_register_wrapped diff --git a/src/melobot/plugin/base.py b/src/melobot/plugin/base.py index 048884a2..a8568304 100644 --- a/src/melobot/plugin/base.py +++ b/src/melobot/plugin/base.py @@ -1,19 +1,15 @@ from __future__ import annotations -import asyncio from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing_extensions import final, overload +from typing_extensions import Callable, final, overload -from .._hook import HookBus -from ..ctx import BotCtx from ..exceptions import PluginLoadError -from ..handle.base import EventHandler -from ..handle.process import Flow -from ..typ import AsyncCallable, Callable, P, T -from ..utils import to_async +from ..handle.base import Flow +from ..mixin import HookMixin +from ..typ.base import P, T from .ipc import AsyncShare, SyncShare @@ -35,7 +31,7 @@ class PluginInfo: author: str = "" -class PluginPlanner: +class PluginPlanner(HookMixin[PluginLifeSpan]): """插件管理器类 用于声明一个插件,并为插件添加功能 @@ -58,33 +54,25 @@ def __init__( :param funcs: 导出函数列表。可以先指定为空,后续使用 :meth:`use` 绑定 :param info: 插件信息 """ + super().__init__(hook_type=PluginLifeSpan) self.version = version - self.flows = [] if flows is None else flows + self.init_flows = [] if flows is None else flows self.shares = [] if shares is None else shares self.funcs = [] if funcs is None else funcs self.info = PluginInfo() if info is None else info self._pname: str = "" - self._hook_bus = HookBus[PluginLifeSpan](PluginLifeSpan) self._built: bool = False self._plugin: Plugin @final - def on( - self, *periods: PluginLifeSpan - ) -> Callable[[AsyncCallable[P, None]], AsyncCallable[P, None]]: - """注册一个 hook - - :param periods: 要绑定的 hook 类型 - :return: 装饰器 - """ - - def wrapped(func: AsyncCallable[P, None]) -> AsyncCallable[P, None]: - for type in periods: - self._hook_bus.register(type, func) - return func - - return wrapped + def __p_build__(self, name: str) -> Plugin: + if not self._built: + self._pname = name + self._hook_bus.set_tag(name) + self._plugin = Plugin(self) + self._built = True + return self._plugin @overload def use(self, obj: Flow) -> Flow: ... @@ -105,7 +93,7 @@ def use(self, obj: T) -> T: :return: 被绑定的组件本身 """ if isinstance(obj, Flow): - self.flows.append(obj) + self.init_flows.append(obj) elif isinstance(obj, (SyncShare, AsyncShare)): self.shares.append(obj) elif callable(obj): @@ -114,70 +102,6 @@ def use(self, obj: T) -> T: raise PluginLoadError(f"插件无法使用 {type(obj)} 类型的对象") return obj - @final - def __p_build__(self, name: str) -> Plugin: - if not self._built: - self._pname = name - self._hook_bus.set_tag(name) - self._plugin = Plugin(self) - self._built = True - return self._plugin - - @final - def add_flow(self, *flows: Flow) -> None: - """在运行期为指定的插件添加一批处理流 - - 在 :obj:`.PluginLifeSpan.INITED` 及其之后的阶段可以使用 - - 注意:不会立即生效,通常会在下一次事件处理前生效。 - 因此返回时不代表已经添加了处理流,只是增添了添加处理流的任务 - - :param flows: 处理流 - """ - try: - self._plugin - except AttributeError as e: - raise PluginLoadError("插件尚未创建,此时无法运行此方法") from e - - hs = tuple(EventHandler(self._plugin, f) for f in flows) - - async def _add() -> None: - await BotCtx().get()._dispatcher.add( - *hs, callback=to_async(lambda: self._plugin.handlers.extend(hs)) - ) - - asyncio.create_task(_add()) - - @final - def remove_flow(self, *flows: Flow) -> None: - """在运行期为指定的插件移除一批处理流 - - 如果插件没有启用对应的处理流,不会发出异常,而是忽略 - - 在 :obj:`.PluginLifeSpan.INITED` 及其之后的阶段可以使用 - - 注意:不会立即生效,通常会在下一次事件处理前生效。 - 因此返回时不代表已经移除了处理流,只是增添了移除处理流的任务 - - :param flows: 处理流 - """ - try: - self._plugin - except AttributeError as e: - raise PluginLoadError("插件尚未创建,此时无法运行此方法") from e - - hs = tuple(filter(lambda x: x.flow in flows, self._plugin.handlers)) - - async def _del() -> None: - await BotCtx().get()._dispatcher.remove(*hs, callback=_after_del) - - async def _after_del() -> None: - self._plugin.handlers = list( - filter(lambda x: x not in hs, self._plugin.handlers) - ) - - asyncio.create_task(_del()) - class Plugin: def __init__(self, planner: PluginPlanner) -> None: @@ -187,4 +111,4 @@ def __init__(self, planner: PluginPlanner) -> None: self.shares = planner.shares self.funcs = planner.funcs - self.handlers = list(EventHandler(self, f) for f in self.planner.flows) + self.init_flows = planner.init_flows diff --git a/src/melobot/plugin/ipc.py b/src/melobot/plugin/ipc.py index e0b589c5..1309872b 100644 --- a/src/melobot/plugin/ipc.py +++ b/src/melobot/plugin/ipc.py @@ -2,11 +2,12 @@ from ..di import inject_deps from ..exceptions import PluginIpcError -from ..typ import AsyncCallable, AttrsReprable, Locatable, T -from ..utils import RWContext +from ..mixin import AttrReprMixin, LocateMixin +from ..typ.base import AsyncCallable, T +from ..utils.common import RWContext -class AsyncShare(Generic[T], Locatable, AttrsReprable): +class AsyncShare(Generic[T], LocateMixin, AttrReprMixin): """异步共享对象""" def __init__( @@ -88,7 +89,7 @@ async def set(self, val: T) -> None: return await self.__callback(val) -class SyncShare(Generic[T], Locatable, AttrsReprable): +class SyncShare(Generic[T], LocateMixin, AttrReprMixin): """同步共享对象""" def __init__( diff --git a/src/melobot/plugin/load.py b/src/melobot/plugin/load.py index 4012e2df..d6c392e3 100644 --- a/src/melobot/plugin/load.py +++ b/src/melobot/plugin/load.py @@ -13,7 +13,7 @@ from .._imp import Importer from ..ctx import BotCtx, LoggerCtx from ..exceptions import PluginAutoGenError, PluginLoadError -from ..utils import singleton +from ..utils.common import singleton from .base import Plugin, PluginPlanner from .ipc import AsyncShare, SyncShare @@ -41,7 +41,7 @@ class PluginInitHelper: def _get_init_py_str() -> str: return re.sub( r"_VAR(\d+)", - lambda match: f"_{int(time()):#x}{match.group(1)}", + lambda matched: f"_{int(time()):#x}{matched.group(1)}", PluginInitHelper._BASE_INIT_PY_STR, ) diff --git a/src/melobot/protocols/base.py b/src/melobot/protocols/base.py index a42b0a15..44eabdb6 100644 --- a/src/melobot/protocols/base.py +++ b/src/melobot/protocols/base.py @@ -1,8 +1,6 @@ -from typing_extensions import Sequence - from ..adapter.base import Adapter from ..io.base import AbstractInSource, AbstractOutSource -from ..typ import BetterABC, abstractattr +from ..typ.cls import BetterABC, abstractattr class ProtocolStack(BetterABC): @@ -11,13 +9,13 @@ class ProtocolStack(BetterABC): 子类需要把以下属性按 :func:`.abstractattr` 的要求实现 """ - inputs: Sequence[AbstractInSource] = abstractattr() + inputs: set[AbstractInSource] = abstractattr() """该协议栈兼容的输入源序列 :meta hide-value: """ - outputs: Sequence[AbstractOutSource] = abstractattr() + outputs: set[AbstractOutSource] = abstractattr() """该协议栈兼容的输出源序列 :meta hide-value: diff --git a/src/melobot/protocols/onebot/v11/__init__.py b/src/melobot/protocols/onebot/v11/__init__.py index 07809f23..606de096 100644 --- a/src/melobot/protocols/onebot/v11/__init__.py +++ b/src/melobot/protocols/onebot/v11/__init__.py @@ -1,24 +1,39 @@ +from melobot.protocols import ProtocolStack +from melobot.utils.common import DeprecatedLoader as _DeprecatedLoader + from .. import __version__ -from .adapter import Adapter, EchoRequireCtx -from .adapter.action import Action -from .adapter.echo import Echo -from .adapter.event import Event -from .adapter.segment import Segment -from .const import PROTOCOL_IDENTIFIER -from .handle import ( - DefaultRule, - on_at_qq, - on_command, - on_contain_match, - on_end_match, - on_event, - on_full_match, - on_message, - on_meta, - on_notice, - on_regex_match, - on_request, - on_start_match, +from .adapter import * +from .const import ( + PROTOCOL_IDENTIFIER, + PROTOCOL_NAME, + PROTOCOL_SUPPORT_AUTHOR, + PROTOCOL_VERSION, ) -from .io import ForwardWebSocketIO, HttpIO, ReverseWebSocketIO -from .utils import GroupRole, LevelRole, ParseArgs +from .handle import _LOADER as _HANDLE_LOADER +from .handle import on_at_qq, on_event, on_message, on_meta, on_notice, on_request +from .io import * +from .utils import _LOADER as _UTILS_LOADER +from .utils import * + + +class OneBotV11Protocol(ProtocolStack): + def __init__(self, *srcs: BaseSource) -> None: + super().__init__() + self.adapter = Adapter() + self.inputs = set() + self.outputs = set() + + for src in srcs: + if not isinstance(src, BaseSource): + raise TypeError(f"不支持的 OneBot v11 源类型: {type(src)}") + if isinstance(src, BaseInSource): + self.inputs.add(src) + if isinstance(src, BaseOutSource): + self.outputs.add(src) + + +_LOADER = _DeprecatedLoader.merge(__name__, _HANDLE_LOADER, _UTILS_LOADER) + + +def __getattr__(name: str) -> Any: + return _LOADER.get(name) diff --git a/src/melobot/protocols/onebot/v11/adapter/__init__.py b/src/melobot/protocols/onebot/v11/adapter/__init__.py index 846e8921..71f1e881 100644 --- a/src/melobot/protocols/onebot/v11/adapter/__init__.py +++ b/src/melobot/protocols/onebot/v11/adapter/__init__.py @@ -1 +1,137 @@ -from .base import Adapter, EchoRequireCtx +from .action import ( + Action, + CanSendImageAction, + CanSendRecordAction, + CleanCacheAction, + DeleteMsgAction, + GetCookiesAction, + GetCredentialsAction, + GetCsrfTokenAction, + GetForwardMsgAction, + GetFriendlistAction, + GetGroupHonorInfoAction, + GetGroupInfoAction, + GetGrouplistAction, + GetGroupMemberInfoAction, + GetGroupMemberlistAction, + GetImageAction, + GetLoginInfoAction, + GetMsgAction, + GetRecordAction, + GetStatusAction, + GetStrangerInfoAction, + GetVersionInfoAction, + SendForwardMsgAction, + SendLikeAction, + SendMsgAction, + SetFriendAddRequestAction, + SetGroupAddRequestAction, + SetGroupAdminAction, + SetGroupAnonymousAction, + SetGroupAnonymousBanAction, + SetGroupBanAction, + SetGroupCardAction, + SetGroupKickAction, + SetGroupLeaveAction, + SetGroupNameAction, + SetGroupSpecialTitleAction, + SetGroupWholeBanAction, + SetRestartAction, + msgs_to_dicts, +) +from .base import Adapter, EchoFactory, EchoRequireCtx, EventFactory, OutputFactory +from .echo import ( + CanSendImageEcho, + CanSendRecordEcho, + Echo, + EmptyEcho, + GetCookiesEcho, + GetCredentialsEcho, + GetCsrfTokenEcho, + GetForwardMsgEcho, + GetFriendListEcho, + GetGroupHonorInfoEcho, + GetGroupInfoEcho, + GetGroupListEcho, + GetGroupMemberInfoEcho, + GetGroupMemberListEcho, + GetImageEcho, + GetLoginInfoEcho, + GetMsgEcho, + GetRecordEcho, + GetStatusEcho, + GetStrangerInfoEcho, + GetVersionInfoEcho, + SendForwardMsgEcho, + SendMsgEcho, +) +from .event import ( + Event, + FriendAddNoticeEvent, + FriendRecallNoticeEvent, + FriendRequestEvent, + GroupAdminNoticeEvent, + GroupBanNoticeEvent, + GroupDecreaseNoticeEvent, + GroupIncreaseNoticeEvent, + GroupMessageEvent, + GroupRecallNoticeEvent, + GroupRequestEvent, + GroupUploadNoticeEvent, + HeartBeatMetaEvent, + HonorNotifyEvent, + LifeCycleMetaEvent, + LuckyKingNotifyEvent, + MessageEvent, + MetaEvent, + NoticeEvent, + NotifyNoticeEvent, + PokeNotifyEvent, + PrivateMessageEvent, + RequestEvent, +) +from .segment import ( + AnonymousSegment, + AtSegment, + ContactFriendSegment, + ContactGroupSegment, + ContactSegment, + DiceSegment, + FaceSegment, + ForwardSegment, + ImageRecvSegment, + ImageSegment, + ImageSendSegment, + JsonSegment, + LocationSegment, + MediaUrl, + MusicCustomSegment, + MusicPlatformSegment, + MusicSegment, + NodeGocqCustomSegment, + NodeReferSegment, + NodeSegment, + NodeStdCustomSegment, + PokeRecvSegment, + PokeSegment, + PokeSendSegment, + RecordRecvSegment, + RecordSegment, + RecordSendSegment, + ReplySegment, + RpsSegment, + Segment, + ShakeSegment, + ShareSegment, + TextSegment, + VideoRecvSegment, + VideoSegment, + VideoSendSegment, + XmlSegment, + base64_encode, + contents_to_segs, + cq_anti_escape, + cq_escape, + cq_filter_text, + segs_to_contents, +) diff --git a/src/melobot/protocols/onebot/v11/adapter/action.py b/src/melobot/protocols/onebot/v11/adapter/action.py index c08f113c..4338168b 100644 --- a/src/melobot/protocols/onebot/v11/adapter/action.py +++ b/src/melobot/protocols/onebot/v11/adapter/action.py @@ -2,7 +2,7 @@ from typing_extensions import Any, Iterable, Literal, Optional, TypedDict -from melobot.adapter.model import Action as RootAction +from melobot.adapter import Action as RootAction from melobot.handle import try_get_event from ..const import PROTOCOL_IDENTIFIER @@ -19,6 +19,9 @@ def __init__(self, type: str, params: dict[str, Any]) -> None: self.params = params self.need_echo = False + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type}, echo={self.need_echo})" + def set_echo(self, status: bool) -> None: self.need_echo = status @@ -60,6 +63,7 @@ def __init__( type = "send_msg" _msgs = msgs_to_dicts(msgs) if group_id is None: + assert user_id is not None, "group_id 为空时,user_id 必须为非空" params = { "message_type": "private", "user_id": user_id, diff --git a/src/melobot/protocols/onebot/v11/adapter/base.py b/src/melobot/protocols/onebot/v11/adapter/base.py index 3b5450fc..209c68a8 100644 --- a/src/melobot/protocols/onebot/v11/adapter/base.py +++ b/src/melobot/protocols/onebot/v11/adapter/base.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import asyncio -from functools import wraps from os import PathLike from typing_extensions import Any, Callable, Iterable, Literal, Optional, Sequence, cast @@ -8,65 +9,56 @@ AbstractEchoFactory, AbstractEventFactory, AbstractOutputFactory, + ActionHandle, ) from melobot.adapter import Adapter as RootAdapter +from melobot.adapter import Content, EchoT from melobot.adapter import content as mc -from melobot.adapter.content import Content -from melobot.adapter.model import ActionHandle, EchoT from melobot.ctx import Context from melobot.exceptions import AdapterError from melobot.handle import try_get_event -from melobot.typ import AsyncCallable -from melobot.utils import to_coro +from melobot.typ import AsyncCallable, SyncOrAsyncCallable +from melobot.utils import to_async, to_coro -from ..const import PROTOCOL_IDENTIFIER, P, T -from ..io.base import BaseIO +from ..const import ACTION_TYPE_KEY_NAME, PROTOCOL_IDENTIFIER, P, T +from ..io.base import BaseIOSource from ..io.packet import EchoPacket, InPacket, OutPacket from . import action as ac from . import echo as ec from . import event as ev from . import segment as se -_ValidateErrHandler = AsyncCallable[[dict[str, Any], Exception], None] +_ValidateHandler = AsyncCallable[[dict[str, Any], Exception], None] -class ValidateErrHandleable: +class ValidateHandleMixin: def __init__(self) -> None: - self.err_handlers: list[_ValidateErrHandler] = [] - - def add_validate_handler(self, callback: _ValidateErrHandler) -> None: - self.err_handlers.append(callback) - - def validate_handle( - self, data: dict[str, Any] - ) -> Callable[[Callable[P, T]], AsyncCallable[P, T]]: - - def wrapper(func: Callable[P, T]) -> AsyncCallable[P, T]: - - @wraps(func) - async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: - try: - return func(*args, **kwargs) - - except Exception as e: - tasks = tuple( - asyncio.create_task(to_coro(cb(data, e))) - for cb in self.err_handlers - ) - if len(tasks): - await asyncio.wait(tasks) - - return func(*args, **kwargs) - - return wrapped - - return wrapper + self.validate_handlers: list[_ValidateHandler] = [] + + def add_validate_handler(self, callback: _ValidateHandler) -> None: + self.validate_handlers.append(callback) + + async def validate_handle( + self, data: dict[str, Any], func: Callable[[dict[str, Any]], T] + ) -> T: + try: + return func(data) + except Exception as e: + tasks = map( + asyncio.create_task, + map( + to_coro, + (cb(data, e) for cb in self.validate_handlers), + ), + ) + if len(self.validate_handlers): + await asyncio.wait(tasks) + return func(data) -class EventFactory(AbstractEventFactory[InPacket, ev.Event], ValidateErrHandleable): +class EventFactory(AbstractEventFactory[InPacket, ev.Event], ValidateHandleMixin): async def create(self, packet: InPacket) -> ev.Event: - data = packet.data - return await self.validate_handle(data)(ev.Event.resolve)(data) + return await self.validate_handle(packet.data, ev.Event.resolve) class OutputFactory(AbstractOutputFactory[OutPacket, ac.Action]): @@ -79,14 +71,14 @@ async def create(self, action: ac.Action) -> OutPacket: ) -class EchoFactory(AbstractEchoFactory[EchoPacket, ec.Echo], ValidateErrHandleable): +class EchoFactory(AbstractEchoFactory[EchoPacket, ec.Echo], ValidateHandleMixin): async def create(self, packet: EchoPacket) -> ec.Echo | None: if packet.noecho: return None data = packet.data - data["action_type"] = packet.action_type - return await self.validate_handle(data)(ec.Echo.resolve)(data) + data[ACTION_TYPE_KEY_NAME] = packet.action_type + return await self.validate_handle(data, ec.Echo.resolve) class EchoRequireCtx(Context[bool]): @@ -95,7 +87,9 @@ def __init__(self) -> None: class Adapter( - RootAdapter[EventFactory, OutputFactory, EchoFactory, ac.Action, BaseIO, BaseIO] + RootAdapter[ + EventFactory, OutputFactory, EchoFactory, ac.Action, BaseIOSource, BaseIOSource + ] ): def __init__(self) -> None: super().__init__( @@ -103,21 +97,22 @@ def __init__(self) -> None: ) def when_validate_error(self, validate_type: Literal["event", "echo"]) -> Callable[ - [AsyncCallable[[dict[str, Any], Exception], None]], + [SyncOrAsyncCallable[[dict[str, Any], Exception], None]], AsyncCallable[[dict[str, Any], Exception], None], ]: - def wrapper( - func: AsyncCallable[[dict[str, Any], Exception], None] + def when_validate_error_wrapper( + func: SyncOrAsyncCallable[[dict[str, Any], Exception], None] ) -> AsyncCallable[[dict[str, Any], Exception], None]: + f = to_async(func) if validate_type == "event": - self._event_factory.add_validate_handler(func) + self._event_factory.add_validate_handler(f) elif validate_type == "echo": - self._echo_factory.add_validate_handler(func) + self._echo_factory.add_validate_handler(f) else: raise AdapterError("无效的验证类型,合法值是 'event', 'echo' 之一") - return func + return f - return wrapper + return when_validate_error_wrapper async def call_output(self, action: ac.Action) -> tuple[ActionHandle, ...]: """输出行为的底层方法 @@ -132,14 +127,14 @@ async def call_output(self, action: ac.Action) -> tuple[ActionHandle, ...]: def with_echo( self, func: AsyncCallable[P, tuple[ActionHandle[EchoT | None], ...]] ) -> AsyncCallable[P, tuple[ActionHandle[EchoT], ...]]: - async def wrapped_api( + async def with_echo_wrapped( *args: P.args, **kwargs: P.kwargs ) -> tuple[ActionHandle[EchoT], ...]: with EchoRequireCtx().unfold(True): handles = await func(*args, **kwargs) return cast(tuple[ActionHandle[EchoT], ...], handles) - return wrapped_api + return with_echo_wrapped async def __send_text__( self, text: str diff --git a/src/melobot/protocols/onebot/v11/adapter/echo.py b/src/melobot/protocols/onebot/v11/adapter/echo.py index 1378de00..df1a7838 100644 --- a/src/melobot/protocols/onebot/v11/adapter/echo.py +++ b/src/melobot/protocols/onebot/v11/adapter/echo.py @@ -3,9 +3,9 @@ from pydantic import BaseModel from typing_extensions import Any, Literal, Mapping, TypedDict, cast -from melobot.adapter.model import Echo as RootEcho +from melobot.adapter import Echo as RootEcho -from ..const import PROTOCOL_IDENTIFIER +from ..const import ACTION_TYPE_KEY_NAME, PROTOCOL_IDENTIFIER from .event import _GroupMessageSender, _MessageSender from .segment import NodeSegment, Segment @@ -19,7 +19,10 @@ class Model(BaseModel): def __init__(self, **kv_pairs: Any) -> None: self._model = self.Model(**kv_pairs) - self.raw = kv_pairs + + _dic = kv_pairs.copy() + self.action_type: str = _dic.pop(ACTION_TYPE_KEY_NAME) + self.raw = _dic super().__init__( protocol=PROTOCOL_IDENTIFIER, @@ -28,6 +31,12 @@ def __init__(self, **kv_pairs: Any) -> None: data=self._model.data, ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(status={self._model.status}," + f" retcode={self._model.retcode}, action_type={self.action_type})" + ) + def is_ok(self) -> bool: return self._model.status == "ok" @@ -39,7 +48,7 @@ def is_failed(self) -> bool: @classmethod def resolve(cls, raw: dict[str, Any]) -> Echo: - match raw["action_type"]: + match raw[ACTION_TYPE_KEY_NAME]: case "send_private_msg" | "send_group_msg" | "send_msg": return SendMsgEcho(**raw) case "send_private_forward_msg" | "send_group_forward_msg": diff --git a/src/melobot/protocols/onebot/v11/adapter/event.py b/src/melobot/protocols/onebot/v11/adapter/event.py index 0ff22e56..78c17f91 100644 --- a/src/melobot/protocols/onebot/v11/adapter/event.py +++ b/src/melobot/protocols/onebot/v11/adapter/event.py @@ -3,8 +3,9 @@ from pydantic import BaseModel from typing_extensions import Any, Literal, Sequence, cast +from melobot.adapter import Event as RootEvent +from melobot.adapter import TextEvent as RootTextEvent from melobot.adapter import content -from melobot.adapter.model import Event as RootEvent from ..const import PROTOCOL_IDENTIFIER from .segment import Segment, TextSegment, segs_to_contents @@ -21,7 +22,7 @@ def __init__(self, **event_data: Any) -> None: #: 时间戳 self.time: int - super().__init__(self._model.time, protocol=PROTOCOL_IDENTIFIER) + super().__init__(PROTOCOL_IDENTIFIER, self._model.time) #: 机器人自己的 qq 号 self.self_id: int = self._model.self_id #: 事件类型 @@ -31,6 +32,9 @@ def __init__(self, **event_data: Any) -> None: #: 事件原始数据 self.raw: dict[str, Any] = event_data + def __repr__(self) -> str: + return f"{self.__class__.__name__}(post_type={self.post_type})" + @classmethod def resolve(cls, event_data: dict[str, Any]) -> Event: cls_map: dict[str, type[Event]] = { @@ -56,7 +60,7 @@ def is_meta(self) -> bool: return self.post_type == "meta_event" -class MessageEvent(Event): +class MessageEvent(RootTextEvent, Event): class Model(Event.Model): post_type: Literal["message"] message_type: Literal["private", "group"] | str @@ -112,6 +116,21 @@ def __init__(self, **event_data: Any) -> None: #: 消息字体 self.font: int = self._model.font + #: 消息内容 + self.text = "".join( + seg.data["text"] for seg in self.message if isinstance(seg, TextSegment) + ) + #: 消息内容行 + self.textlines = "\n".join( + seg.data["text"] for seg in self.message if isinstance(seg, TextSegment) + ).split("\n") + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(text={self.text!r}," + f" user_id={self.user_id}), sub_type={self.sub_type})" + ) + @classmethod def resolve(cls, event_data: dict[str, Any]) -> MessageEvent: cls_map: dict[str, type[MessageEvent]] = { @@ -122,18 +141,6 @@ def resolve(cls, event_data: dict[str, Any]) -> MessageEvent: return cls_map[mtype](**event_data) return cls(**event_data) - @property - def text(self) -> str: - return "".join( - seg.data["text"] for seg in self.message if isinstance(seg, TextSegment) - ) - - @property - def textlines(self) -> str: - return "\n".join( - seg.data["text"] for seg in self.message if isinstance(seg, TextSegment) - ) - def get_segments(self, type: type[Segment] | str) -> list[Segment]: if isinstance(type, str): return [seg for seg in self.message if seg.type == type] @@ -329,6 +336,13 @@ def __init__(self, **event_data: Any) -> None: #: 消息子类型 self.sub_type: Literal["normal", "anonymous", "notice", "group_self"] + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(text={self.text!r}," + f" user_id={self.user_id}, group_id={self.group_id}," + f" sub_type={self.sub_type})" + ) + class MetaEvent(Event): @@ -345,6 +359,14 @@ def __init__(self, **event_data: Any) -> None: self._model.meta_event_type ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(meta_type={self.meta_event_type}" + + f", sub_type={self.sub_type})" + if hasattr(self, "sub_type") + else ")" + ) + @classmethod def resolve(cls, event_data: dict[str, Any]) -> MetaEvent: cls_map: dict[str, type[MetaEvent]] = { @@ -464,6 +486,14 @@ def __init__(self, **event_data: Any) -> None: | str ) = self._model.notice_type + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(notice_type={self.notice_type}" + + f", sub_type={self.sub_type})" + if hasattr(self, "sub_type") + else ")" + ) + @classmethod def resolve(cls, event_data: dict[str, Any]) -> NoticeEvent: cls_map: dict[str, type[NoticeEvent]] = { @@ -886,6 +916,14 @@ def __init__(self, **event_data: Any) -> None: #: 请求事件类型 self.request_type: Literal["friend", "group"] | str = self._model.request_type + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(request_type={self.request_type}" + + f", sub_type={self.sub_type})" + if hasattr(self, "sub_type") + else ")" + ) + @classmethod def resolve(cls, event_data: dict[str, Any]) -> RequestEvent: cls_map: dict[str, type[RequestEvent]] = { diff --git a/src/melobot/protocols/onebot/v11/const.py b/src/melobot/protocols/onebot/v11/const.py index 0781180d..b2f78c9f 100644 --- a/src/melobot/protocols/onebot/v11/const.py +++ b/src/melobot/protocols/onebot/v11/const.py @@ -7,5 +7,7 @@ PROTOCOL_NAME = "OneBot" PROTOCOL_VERSION = "11" -MODULE_AUTHOR = "Meloland" -PROTOCOL_IDENTIFIER = f"{PROTOCOL_NAME}-v{PROTOCOL_VERSION}@{MODULE_AUTHOR}" +PROTOCOL_SUPPORT_AUTHOR = "Meloland" +PROTOCOL_IDENTIFIER = f"{PROTOCOL_NAME}-v{PROTOCOL_VERSION}@{PROTOCOL_SUPPORT_AUTHOR}" + +ACTION_TYPE_KEY_NAME = "action_type" diff --git a/src/melobot/protocols/onebot/v11/handle.py b/src/melobot/protocols/onebot/v11/handle.py index e8e0e3a7..454255dc 100644 --- a/src/melobot/protocols/onebot/v11/handle.py +++ b/src/melobot/protocols/onebot/v11/handle.py @@ -1,169 +1,60 @@ -from functools import wraps -from typing import Sequence - -from typing_extensions import Callable, cast - -from melobot.ctx import Context -from melobot.di import Depends, inject_deps -from melobot.handle import Flow, get_event, no_deps_node -from melobot.session import Rule, enter_session -from melobot.typ import AsyncCallable, HandleLevel, LogicMode -from melobot.utils import get_obj_name +from typing_extensions import Any, Callable, Sequence, cast + +from melobot.adapter.model import Event as RootEvent +from melobot.handle import FlowDecorator +from melobot.handle import on_event as on_root_event +from melobot.handle import on_text +from melobot.session import Rule +from melobot.utils.check import Checker, checker_join +from melobot.utils.common import DeprecatedLoader as _DeprecatedLoader +from melobot.utils.match import Matcher +from melobot.utils.parse import Parser from .adapter.event import Event, MessageEvent, MetaEvent, NoticeEvent, RequestEvent -from .utils import check, match -from .utils.abc import Checker, Matcher, ParseArgs, Parser -from .utils.parse import CmdArgFormatter, CmdParser - - -class ParseArgsCtx(Context[ParseArgs | None]): - def __init__(self) -> None: - super().__init__( - "ONEBOT_V11_PARSE_ARGS", LookupError, "当前上下文中不存在解析参数" - ) - - -def GetParseArgs() -> ParseArgs: # pylint: disable=invalid-name - """获取解析参数 - - :return: 解析参数 - """ - return cast(ParseArgs, Depends(ParseArgsCtx().get, recursive=False)) - - -class DefaultRule(Rule[MessageEvent]): - """传统的会话规则(只针对消息事件) - - 两消息事件如果在同一发送渠道,且由同一人发送,则在同一会话中 - """ - - async def compare(self, e1: MessageEvent, e2: MessageEvent) -> bool: - return e1.scope == e2.scope - - -class FlowDecorator: - def __init__( - self, - checker: Checker | None | Callable[[Event], bool] = None, - matcher: Matcher | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - rule: Rule[Event] | None = None, - ) -> None: - self.checker = checker - self.matcher = matcher - self.parser = parser - self.priority = priority - self.block = block - self.temp = temp - self.decos = decos - self.rule = rule - - def __call__(self, func: AsyncCallable[..., bool | None]) -> Flow: - if not isinstance(self.checker, Checker) and callable(self.checker): - _checker = Checker.new(self.checker) - else: - _checker = cast(Checker, self.checker) - - func = inject_deps(func) - if self.decos is not None: - for deco in reversed(self.decos): - func = deco(func) - - @wraps(func) - async def wrapped() -> bool | None: - event = cast(Event, get_event()) - status = await _checker.check(event) - if not status: - return None - - p_args: ParseArgs | None = None - if isinstance(event, MessageEvent): - if self.matcher: - status = await self.matcher.match(event.text) - if not status: - return None - - if self.parser: - parse_args = await self.parser.parse(event.text) - if parse_args is None: - return None - p_args = parse_args - - event.spread = not self.block - afunc = cast(AsyncCallable[..., bool | None], func) - - with ParseArgsCtx().unfold(p_args): - if self.rule is not None: - async with enter_session(self.rule): - return await afunc() - return await afunc() - - n = no_deps_node(wrapped) - n.name = get_obj_name(func, otype="callable") - return Flow( - f"OneBotV11Flow[{n.name}]", - (n,), - priority=self.priority, - temp=self.temp, - ) +from .utils import check def on_event( checker: Checker | None | Callable[[Event], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, rule: Rule[Event] | None = None, ) -> FlowDecorator: - return FlowDecorator(checker, matcher, parser, priority, block, temp, decos, rule) - - -def _checker_join(*checkers: Checker | None | Callable[[Event], bool]) -> Checker: - checker: Checker | None = None - for c in checkers: - if c is None: - continue - if isinstance(c, Checker): - checker = checker & c if checker else c - else: - checker = checker & Checker.new(c) if checker else Checker.new(c) - - if checker is None: - raise ValueError("检查器序列不能全为空") - return checker + return on_root_event( + checker=checker_join(lambda e: isinstance(e, Event), checker), + matcher=matcher, + parser=parser, + priority=priority, + block=block, + temp=temp, + decos=decos, + rule=cast(Rule[RootEvent] | None, rule), + ) def on_message( checker: Checker | None | Callable[[MessageEvent], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, legacy_session: bool = False, ) -> FlowDecorator: - if legacy_session: - rule = DefaultRule() - else: - rule = None - - return on_event( - _checker_join(lambda e: e.is_message(), checker), # type: ignore[arg-type] + return on_text( + checker_join(lambda e: isinstance(e, MessageEvent), checker), matcher, parser, priority, block, temp, decos, - cast(Rule[Event], rule), + legacy_session, ) @@ -172,14 +63,14 @@ def on_at_qq( checker: Checker | None | Callable[[MessageEvent], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, legacy_session: bool = False, ) -> FlowDecorator: return on_message( - _checker_join(check.AtMsgChecker(qid if qid else "all"), checker), # type: ignore[arg-type] + checker_join(check.AtMsgChecker(qid if qid else "all"), checker), # type: ignore[arg-type] matcher, parser, priority, @@ -190,158 +81,18 @@ def on_at_qq( ) -def on_command( - cmd_start: str | list[str], - cmd_sep: str | list[str], - targets: str | list[str], - fmtters: list[CmdArgFormatter | None] | None = None, - checker: Checker | None | Callable[[MessageEvent], bool] = None, - matcher: Matcher | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - matcher, - CmdParser(cmd_start=cmd_start, cmd_sep=cmd_sep, targets=targets, fmtters=fmtters), - priority, - block, - temp, - decos, - legacy_session, - ) - - -def on_start_match( - target: str | list[str], - logic_mode: LogicMode = LogicMode.OR, - checker: Checker | Callable[[MessageEvent], bool] | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - match.StartMatcher(target, logic_mode), - parser, - priority, - block, - temp, - decos, - legacy_session, - ) - - -def on_contain_match( - target: str | list[str], - logic_mode: LogicMode = LogicMode.OR, - checker: Checker | Callable[[MessageEvent], bool] | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - match.ContainMatcher(target, logic_mode), - parser, - priority, - block, - temp, - decos, - legacy_session, - ) - - -def on_full_match( - target: str | list[str], - logic_mode: LogicMode = LogicMode.OR, - checker: Checker | Callable[[MessageEvent], bool] | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - match.FullMatcher(target, logic_mode), - parser, - priority, - block, - temp, - decos, - legacy_session, - ) - - -def on_end_match( - target: str | list[str], - logic_mode: LogicMode = LogicMode.OR, - checker: Checker | Callable[[MessageEvent], bool] | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - match.EndMatcher(target, logic_mode), - parser, - priority, - block, - temp, - decos, - legacy_session, - ) - - -def on_regex_match( - target: str, - logic_mode: LogicMode = LogicMode.OR, - checker: Checker | Callable[[MessageEvent], bool] | None = None, - parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, - block: bool = False, - temp: bool = False, - decos: Sequence[Callable[[Callable], Callable]] | None = None, - legacy_session: bool = False, -) -> FlowDecorator: - return on_message( - checker, - match.RegexMatcher(target, logic_mode), - parser, - priority, - block, - temp, - decos, - legacy_session, - ) - - def on_request( checker: Checker | None | Callable[[RequestEvent], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, rule: Rule[Event] | None = None, ) -> FlowDecorator: return on_event( - _checker_join(lambda e: e.is_request(), checker), # type: ignore[arg-type] + checker_join(lambda e: isinstance(e, RequestEvent), checker), # type: ignore[arg-type] matcher, parser, priority, @@ -356,14 +107,14 @@ def on_notice( checker: Checker | None | Callable[[NoticeEvent], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, rule: Rule[Event] | None = None, ) -> FlowDecorator: return on_event( - _checker_join(lambda e: e.is_notice(), checker), # type: ignore[arg-type] + checker_join(lambda e: isinstance(e, NoticeEvent), checker), # type: ignore[arg-type] matcher, parser, priority, @@ -378,14 +129,14 @@ def on_meta( checker: Checker | None | Callable[[MetaEvent], bool] = None, matcher: Matcher | None = None, parser: Parser | None = None, - priority: HandleLevel = HandleLevel.NORMAL, + priority: int = 0, block: bool = False, temp: bool = False, decos: Sequence[Callable[[Callable], Callable]] | None = None, rule: Rule[Event] | None = None, ) -> FlowDecorator: return on_event( - _checker_join(lambda e: e.is_meta(), checker), # type: ignore[arg-type] + checker_join(lambda e: isinstance(e, MetaEvent), checker), # type: ignore[arg-type] matcher, parser, priority, @@ -394,3 +145,21 @@ def on_meta( decos, rule, ) + + +_LOADER = _DeprecatedLoader( + __name__, + { + "GetParseArgs": ("melobot.handle", "GetParseArgs", "3.1.1"), + "on_command": ("melobot.handle", "on_command", "3.1.1"), + "on_start_match": ("melobot.handle", "on_start_match", "3.1.1"), + "on_contain_match": ("melobot.handle", "on_contain_match", "3.1.1"), + "on_full_match": ("melobot.handle", "on_full_match", "3.1.1"), + "on_end_match": ("melobot.handle", "on_end_match", "3.1.1"), + "on_regex_match": ("melobot.handle", "on_regex_match", "3.1.1"), + }, +) + + +def __getattr__(name: str) -> Any: + return _LOADER.get(name) diff --git a/src/melobot/protocols/onebot/v11/io/__init__.py b/src/melobot/protocols/onebot/v11/io/__init__.py index d41dff36..b31d9512 100644 --- a/src/melobot/protocols/onebot/v11/io/__init__.py +++ b/src/melobot/protocols/onebot/v11/io/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseIO +from .base import BaseInSource, BaseIOSource, BaseOutSource, BaseSource from .duplex_http import HttpIO from .forward import ForwardWebSocketIO from .packet import InPacket diff --git a/src/melobot/protocols/onebot/v11/io/base.py b/src/melobot/protocols/onebot/v11/io/base.py index c8a44d5c..ff146692 100644 --- a/src/melobot/protocols/onebot/v11/io/base.py +++ b/src/melobot/protocols/onebot/v11/io/base.py @@ -1,12 +1,80 @@ -from melobot.io import AbstractIOSource -from melobot.log import GenericLogger, get_logger -from melobot.typ import abstractmethod +from abc import abstractmethod + +from melobot.io import ( + AbstractInSource, + AbstractIOSource, + AbstractOutSource, + AbstractSource, +) from ..const import PROTOCOL_IDENTIFIER from .packet import EchoPacket, InPacket, OutPacket -class BaseIO(AbstractIOSource[InPacket, OutPacket, EchoPacket]): +class BaseSource(AbstractSource): + def __init__(self) -> None: + super().__init__() + self.protocol = PROTOCOL_IDENTIFIER + self._hook_bus.set_tag(f"{self.protocol}/{self.__class__.__name__}") + + @abstractmethod + async def open(self) -> None: + raise NotImplementedError + + @abstractmethod + async def close(self) -> None: + raise NotImplementedError + + @abstractmethod + def opened(self) -> bool: + raise NotImplementedError + + +class BaseInSource(AbstractInSource[InPacket], BaseSource): + def __init__(self) -> None: + super().__init__() + + @abstractmethod + async def open(self) -> None: + raise NotImplementedError + + @abstractmethod + async def close(self) -> None: + raise NotImplementedError + + @abstractmethod + def opened(self) -> bool: + raise NotImplementedError + + @abstractmethod + async def input(self) -> InPacket: + raise NotImplementedError + + +class BaseOutSource(AbstractOutSource[OutPacket, EchoPacket], BaseSource): + def __init__(self) -> None: + super().__init__() + + @abstractmethod + async def open(self) -> None: + raise NotImplementedError + + @abstractmethod + async def close(self) -> None: + raise NotImplementedError + + @abstractmethod + def opened(self) -> bool: + raise NotImplementedError + + @abstractmethod + async def output(self, packet: OutPacket) -> EchoPacket: + raise NotImplementedError + + +class BaseIOSource( + AbstractIOSource[InPacket, OutPacket, EchoPacket], BaseInSource, BaseOutSource +): """ :ivar float cd_time: 发送行为操作的冷却时间(防风控) @@ -14,12 +82,8 @@ class BaseIO(AbstractIOSource[InPacket, OutPacket, EchoPacket]): # pylint: disable=duplicate-code def __init__(self, cd_time: float) -> None: - super().__init__(PROTOCOL_IDENTIFIER) - self.cd_time = cd_time if cd_time > 0 else 0.01 - - @property - def logger(self) -> GenericLogger: - return get_logger() + super().__init__() + self.cd_time = cd_time if cd_time >= 0 else 0 @abstractmethod async def open(self) -> None: diff --git a/src/melobot/protocols/onebot/v11/io/duplex_http.py b/src/melobot/protocols/onebot/v11/io/duplex_http.py index 2ffe9266..716b07b1 100644 --- a/src/melobot/protocols/onebot/v11/io/duplex_http.py +++ b/src/melobot/protocols/onebot/v11/io/duplex_http.py @@ -12,11 +12,11 @@ from melobot.io import SourceLifeSpan from melobot.log import LogLevel -from .base import BaseIO +from .base import BaseIOSource from .packet import EchoPacket, InPacket, OutPacket -class HttpIO(BaseIO): +class HttpIO(BaseIOSource): def __init__( self, onebot_host: str, @@ -25,7 +25,7 @@ def __init__( serve_port: int, secret: str | None = None, access_token: str | None = None, - cd_time: float = 0.2, + cd_time: float = 0, ) -> None: super().__init__(cd_time) self.onebot_url = f"http://{onebot_host}:{onebot_port}" @@ -182,7 +182,8 @@ async def close(self) -> None: await self.client_session.close() for t in self._tasks: t.cancel() - await asyncio.wait(self._tasks) + if len(self._tasks): + await asyncio.wait(self._tasks) self._tasks.clear() self._opened.clear() diff --git a/src/melobot/protocols/onebot/v11/io/forward.py b/src/melobot/protocols/onebot/v11/io/forward.py index 093c50e6..9badd6e9 100644 --- a/src/melobot/protocols/onebot/v11/io/forward.py +++ b/src/melobot/protocols/onebot/v11/io/forward.py @@ -13,17 +13,17 @@ from melobot.io import SourceLifeSpan from melobot.log import LogLevel -from .base import BaseIO +from .base import BaseIOSource from .packet import EchoPacket, InPacket, OutPacket -class ForwardWebSocketIO(BaseIO): +class ForwardWebSocketIO(BaseIOSource): def __init__( self, url: str, max_retry: int = -1, retry_delay: float = 4.0, - cd_time: float = 0.2, + cd_time: float = 0, access_token: str | None = None, ) -> None: super().__init__(cd_time) @@ -175,7 +175,8 @@ async def close(self) -> None: for t in self._tasks: t.cancel() - await asyncio.wait(self._tasks) + if len(self._tasks): + await asyncio.wait(self._tasks) self._tasks.clear() self.logger.info("OneBot v11 正向 WebSocket IO 源已断开连接") diff --git a/src/melobot/protocols/onebot/v11/io/reverse.py b/src/melobot/protocols/onebot/v11/io/reverse.py index c6f684b3..224c7cef 100644 --- a/src/melobot/protocols/onebot/v11/io/reverse.py +++ b/src/melobot/protocols/onebot/v11/io/reverse.py @@ -14,13 +14,13 @@ from melobot.io import SourceLifeSpan from melobot.log import LogLevel -from .base import BaseIO +from .base import BaseIOSource from .packet import EchoPacket, InPacket, OutPacket -class ReverseWebSocketIO(BaseIO): +class ReverseWebSocketIO(BaseIOSource): def __init__( - self, host: str, port: int, cd_time: float = 0.2, access_token: str | None = None + self, host: str, port: int, cd_time: float = 0, access_token: str | None = None ) -> None: super().__init__(cd_time) self.host = host @@ -173,7 +173,8 @@ async def close(self) -> None: for t in self._tasks: t.cancel() - await asyncio.wait(self._tasks) + if len(self._tasks): + await asyncio.wait(self._tasks) self._tasks.clear() self.logger.info("OneBot v11 反向 WebSocket IO 源的服务已停止运行") diff --git a/src/melobot/protocols/onebot/v11/utils/__init__.py b/src/melobot/protocols/onebot/v11/utils/__init__.py index 8418edf4..98f42e6d 100644 --- a/src/melobot/protocols/onebot/v11/utils/__init__.py +++ b/src/melobot/protocols/onebot/v11/utils/__init__.py @@ -1,4 +1,9 @@ -from .abc import Checker, Matcher, ParseArgs, Parser +from types import ModuleType + +from typing_extensions import Any + +from melobot.utils.common import DeprecatedLoader as _DeprecatedLoader + from .check import ( AtMsgChecker, GroupMsgChecker, @@ -10,5 +15,26 @@ get_group_role, get_level_role, ) -from .match import ContainMatcher, EndMatcher, FullMatcher, RegexMatcher, StartMatcher -from .parse import CmdArgFormatter, CmdParser, CmdParserFactory, FormatInfo + +_LOADER = _DeprecatedLoader( + __name__, + { + "Checker": ("melobot.utils.check", "Checker", "3.1.1"), + "Matcher": ("melobot.utils.match", "Matcher", "3.1.1"), + "ParseArgs": ("melobot.utils.parse", "CmdArgs", "3.1.1"), + "Parser": ("melobot.utils.parse", "Parser", "3.1.1"), + "ContainMatcher": ("melobot.utils.match", "ContainMatcher", "3.1.1"), + "EndMatcher": ("melobot.utils.match", "EndMatcher", "3.1.1"), + "FullMatcher": ("melobot.utils.match", "FullMatcher", "3.1.1"), + "RegexMatcher": ("melobot.utils.match", "RegexMatcher", "3.1.1"), + "StartMatcher": ("melobot.utils.match", "StartMatcher", "3.1.1"), + "CmdArgFormatter": ("melobot.utils.parse", "CmdArgFormatter", "3.1.1"), + "CmdParser": ("melobot.utils.parse", "CmdParser", "3.1.1"), + "CmdParserFactory": ("melobot.utils.parse", "CmdParserFactory", "3.1.1"), + "FormatInfo": ("melobot.utils.parse", "CmdArgFormatInfo", "3.1.1"), + }, +) + + +def __getattr__(name: str) -> Any: + return _LOADER.get(name) diff --git a/src/melobot/protocols/onebot/v11/utils/abc.py b/src/melobot/protocols/onebot/v11/utils/abc.py deleted file mode 100644 index e65cc992..00000000 --- a/src/melobot/protocols/onebot/v11/utils/abc.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import annotations - -from copy import deepcopy -from dataclasses import dataclass - -from typing_extensions import Any, Callable, Coroutine, Self - -from melobot.exceptions import BotException -from melobot.typ import AsyncCallable, BetterABC, LogicMode, abstractmethod -from melobot.utils import to_async - -from ..adapter.event import Event - - -class UtilsError(BotException): ... - - -class Cloneable: - def copy(self) -> Self: - return deepcopy(self) - - -class Checker(BetterABC, Cloneable): - """检查器基类""" - - def __init__(self, fail_cb: AsyncCallable[[], None] | None = None) -> None: - super().__init__() - self.fail_cb = fail_cb - - def __and__(self, other: Checker) -> WrappedChecker: - if not isinstance(other, Checker): - raise UtilsError(f"联合检查器定义时出现了非检查器对象,其值为:{other}") - return WrappedChecker(LogicMode.AND, self, other) - - def __or__(self, other: Checker) -> WrappedChecker: - if not isinstance(other, Checker): - raise UtilsError(f"联合检查器定义时出现了非检查器对象,其值为:{other}") - return WrappedChecker(LogicMode.OR, self, other) - - def __invert__(self) -> WrappedChecker: - return WrappedChecker(LogicMode.NOT, self) - - def __xor__(self, other: Checker) -> WrappedChecker: - if not isinstance(other, Checker): - raise UtilsError(f"联合检查器定义时出现了非检查器对象,其值为:{other}") - return WrappedChecker(LogicMode.XOR, self, other) - - @abstractmethod - async def check(self, event: Event) -> bool: - """检查器检查方法 - - 任何检查器应该实现此抽象方法。 - - :param event: 给定的事件 - :return: 检查是否通过 - """ - raise NotImplementedError - - @staticmethod - def new(func: Callable[[Event], bool]) -> Checker: - return CustomChecker(func) - - -class CustomChecker(Checker): - def __init__(self, func: Callable[[Event], bool]) -> None: - super().__init__() - self.func = func - - async def check(self, event: Event) -> bool: - return self.func(event) - - -class WrappedChecker(Checker): - """合并检查器 - - 在两个 :class:`Checker` 对象间使用 | & ^ ~ 运算符即可返回合并检查器。 - """ - - def __init__( - self, - mode: LogicMode, - checker1: Checker, - checker2: Checker | None = None, - ) -> None: - """初始化一个合并检查器 - - :param mode: 合并检查的逻辑模式 - :param checker1: 检查器1 - :param checker2: 检查器2 - """ - super().__init__() - self.mode = mode - self.c1, self.c2 = checker1, checker2 - - def set_fail_cb(self, fail_cb: AsyncCallable[[], None] | None) -> None: - self.fail_cb = fail_cb - - async def check(self, event: Event) -> bool: - c2_check: Callable[[], Coroutine[Any, Any, bool | None]] = ( - (lambda: self.c2.check(event)) # type: ignore[union-attr] - if self.c2 is not None - else to_async(lambda: None) # type: ignore[misc,arg-type] - ) - status = await LogicMode.async_short_calc( - self.mode, lambda: self.c1.check(event), c2_check - ) - - if not status and self.fail_cb is not None: - await self.fail_cb() - return status - - -class Matcher(BetterABC, Cloneable): - """匹配器基类""" - - def __init__(self) -> None: - super().__init__() - - def __and__(self, other: Matcher) -> WrappedMatcher: - if not isinstance(other, Matcher): - raise UtilsError(f"联合匹配器定义时出现了非匹配器对象,其值为:{other}") - return WrappedMatcher(LogicMode.AND, self, other) - - def __or__(self, other: Matcher) -> WrappedMatcher: - if not isinstance(other, Matcher): - raise UtilsError(f"联合匹配器定义时出现了非匹配器对象,其值为:{other}") - return WrappedMatcher(LogicMode.OR, self, other) - - def __invert__(self) -> WrappedMatcher: - return WrappedMatcher(LogicMode.NOT, self) - - def __xor__(self, other: Matcher) -> WrappedMatcher: - if not isinstance(other, Matcher): - raise UtilsError(f"联合匹配器定义时出现了非匹配器对象,其值为:{other}") - return WrappedMatcher(LogicMode.XOR, self, other) - - @abstractmethod - async def match(self, text: str) -> bool: - """匹配器匹配方法 - - 任何匹配器应该实现此抽象方法。 - - :param text: 消息事件的文本内容 - :return: 是否匹配 - """ - raise NotImplementedError - - -class WrappedMatcher(Matcher): - """合并匹配器 - - 在两个 :class:`Matcher` 对象间使用 | & ^ ~ 运算符即可返回合并匹配器 - """ - - def __init__( - self, - mode: LogicMode, - matcher1: Matcher, - matcher2: Matcher | None = None, - ) -> None: - """初始化一个合并匹配器 - - :param mode: 合并匹配的逻辑模式 - :param matcher1: 匹配器1 - :param matcher2: 匹配器2 - """ - super().__init__() - self.mode = mode - self.m1, self.m2 = matcher1, matcher2 - - async def match(self, text: str) -> bool: - m2_match: Callable[[], Coroutine[Any, Any, bool | None]] = ( - (lambda: self.m2.match(text)) # type: ignore[union-attr] - if self.m2 is not None - else to_async(lambda: None) # type: ignore[misc,arg-type] - ) - return await LogicMode.async_short_calc( - self.mode, lambda: self.m1.match(text), m2_match - ) - - -@dataclass -class ParseArgs: - """解析参数""" - - name: str - vals: list[Any] - - -class Parser(BetterABC): - """解析器基类 - - 解析器一般用作从消息文本中按规则批量提取参数 - """ - - def __init__(self) -> None: - super().__init__() - - @abstractmethod - async def parse(self, text: str) -> ParseArgs | None: - """解析方法 - - 任何解析器应该实现此抽象方法 - - :param text: 消息文本内容 - :return: 解析结果 - - """ - raise NotImplementedError diff --git a/src/melobot/protocols/onebot/v11/utils/check.py b/src/melobot/protocols/onebot/v11/utils/check.py index 634ec54c..af1ac5f8 100644 --- a/src/melobot/protocols/onebot/v11/utils/check.py +++ b/src/melobot/protocols/onebot/v11/utils/check.py @@ -4,11 +4,11 @@ from typing_extensions import Literal, Optional, cast -from melobot.typ import AsyncCallable +from melobot.typ import SyncOrAsyncCallable +from melobot.utils.check import Checker from ..adapter.event import Event, GroupMessageEvent, MessageEvent from ..adapter.segment import AtSegment -from .abc import Checker class LevelRole(int, Enum): @@ -64,7 +64,7 @@ def get_group_role(event: MessageEvent) -> GroupRole: return GroupRole.MEMBER -class MsgChecker(Checker): +class MsgChecker(Checker[Event]): """消息事件分级权限检查器 主要分 主人、超级用户、白名单用户、普通用户、黑名单用户 五级 @@ -77,7 +77,7 @@ def __init__( super_users: Optional[list[int]] = None, white_users: Optional[list[int]] = None, black_users: Optional[list[int]] = None, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> None: """初始化一个消息事件分级权限检查器 @@ -140,7 +140,7 @@ def __init__( white_users: Optional[list[int]] = None, black_users: Optional[list[int]] = None, white_groups: Optional[list[int]] = None, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> None: """初始化一个群聊消息事件分级权限检查器 @@ -182,7 +182,7 @@ def __init__( super_users: Optional[list[int]] = None, white_users: Optional[list[int]] = None, black_users: Optional[list[int]] = None, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> None: """初始化一个私聊消息事件分级权限检查器 @@ -215,7 +215,7 @@ def __init__( white_users: Optional[list[int]] = None, black_users: Optional[list[int]] = None, white_groups: Optional[list[int]] = None, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> None: """初始化一个消息事件分级权限检查器的工厂 @@ -237,7 +237,7 @@ def __init__( def get_base( self, role: LevelRole | GroupRole, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> MsgChecker: """根据内部依据和给定等级,生成一个 :class:`MsgChecker` 对象 @@ -257,7 +257,7 @@ def get_base( def get_group( self, role: LevelRole | GroupRole, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> GroupMsgChecker: """根据内部依据和给定等级,生成一个 :class:`GroupMsgChecker` 对象 @@ -278,7 +278,7 @@ def get_group( def get_private( self, role: LevelRole, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> PrivateMsgChecker: """根据内部依据和给定等级,生成一个 :class:`PrivateMsgChecker` 对象 @@ -302,7 +302,7 @@ class AtMsgChecker(Checker): def __init__( self, qid: int | Literal["all"] | None = None, - fail_cb: Optional[AsyncCallable[[], None]] = None, + fail_cb: Optional[SyncOrAsyncCallable[[], None]] = None, ) -> None: """初始化一个艾特消息事件检查器 diff --git a/src/melobot/session/__init__.py b/src/melobot/session/__init__.py index 99fd20f9..5974fbb8 100644 --- a/src/melobot/session/__init__.py +++ b/src/melobot/session/__init__.py @@ -1,6 +1,14 @@ from ..ctx import SessionCtx as _SessionCtx from .base import Session, SessionStore, enter_session, suspend -from .option import Rule +from .option import CompareInfo, DefaultRule, Rule + + +def get_session() -> Session: + """获取当前上下文中的会话 + + :return: 会话 + """ + return _SessionCtx().get() def get_session_store() -> SessionStore: @@ -11,9 +19,11 @@ def get_session_store() -> SessionStore: return _SessionCtx().get().store -def get_rule() -> Rule | None: +def get_rule() -> Rule: """获取当前上下文中的会话规则 - :return: 会话规则或空 + :return: 会话规则 """ - return _SessionCtx().get().rule + rule = _SessionCtx().get().rule + assert rule is not None, "预期之外的会话规则为空" + return rule diff --git a/src/melobot/session/base.py b/src/melobot/session/base.py index ca85e79a..7ea8d799 100644 --- a/src/melobot/session/base.py +++ b/src/melobot/session/base.py @@ -1,38 +1,27 @@ from __future__ import annotations import asyncio -from asyncio import Condition, Lock +import inspect +from asyncio import Condition, Future, Lock from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from typing_extensions import Any, AsyncGenerator from ..adapter.model import Event from ..ctx import FlowCtx, SessionCtx -from ..exceptions import BotException -from ..handle.process import stop -from ..typ import AsyncCallable +from ..exceptions import SessionRuleLacked, SessionStateFailed +from ..handle.base import EventCompletion, stop +from ..typ.base import SyncOrAsyncCallable from .option import CompareInfo, Rule _SESSION_CTX = SessionCtx() -class SessionError(BotException): ... - - -class SessionStateFailed(SessionError): - def __init__(self, cur_state: str, meth: str) -> None: - self.cur_state = cur_state - super().__init__(f"当前会话状态 {cur_state} 不支持的操作:{meth}") - - -class SessionRuleLacked(SessionError): ... - - class SessionState: def __init__(self, session: "Session") -> None: self.session = session - async def work(self, event: Event) -> None: + async def work(self, completion: EventCompletion) -> None: raise SessionStateFailed(self.__class__.__name__, SessionState.work.__name__) async def rest(self) -> None: @@ -41,7 +30,7 @@ async def rest(self) -> None: async def suspend(self, timeout: float | None) -> bool: raise SessionStateFailed(self.__class__.__name__, SessionState.suspend.__name__) - async def wakeup(self, event: Event) -> None: + async def wakeup(self, completion: EventCompletion | None) -> None: raise SessionStateFailed(self.__class__.__name__, SessionState.wakeup.__name__) async def expire(self) -> None: @@ -49,8 +38,9 @@ async def expire(self) -> None: class SpareSessionState(SessionState): - async def work(self, event: Event) -> None: - self.session.event = event + async def work(self, completion: EventCompletion) -> None: + self.session._completions.add(completion) + self.session.event = completion.event self.session.__to_state__(WorkingSessionState) @@ -59,31 +49,33 @@ async def rest(self) -> None: if self.session.rule is None: raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“空闲态”") - cond = self.session.refresh_cond + cond = self.session._refresh_cond self.session.__to_state__(SpareSessionState) async with cond: cond.notify() async def suspend(self, timeout: float | None) -> bool: + self.session.set_completed() + if self.session.rule is None: raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“挂起态”") - cond = self.session.refresh_cond + cond = self.session._refresh_cond self.session.__to_state__(SuspendSessionState) async with cond: cond.notify() - async with self.session.wakeup_cond: + async with self.session._wakeup_cond: if timeout is None: - await self.session.wakeup_cond.wait() + await self.session._wakeup_cond.wait() return True try: - await asyncio.wait_for(self.session.wakeup_cond.wait(), timeout=timeout) + await asyncio.wait_for(self.session._wakeup_cond.wait(), timeout=timeout) return True except asyncio.TimeoutError: - if self.session.is_state(WorkingSessionState): + if self.session.__is_state__(WorkingSessionState): return True self.session.__to_state__(WorkingSessionState) return False @@ -91,18 +83,23 @@ async def suspend(self, timeout: float | None) -> bool: async def expire(self) -> None: self.session.__to_state__(ExpireSessionState) if self.session.rule is not None: - cond = self.session.refresh_cond + cond = self.session._refresh_cond async with cond: cond.notify() + self.session.set_completed() + class SuspendSessionState(SessionState): - async def wakeup(self, event: Event) -> None: - if self.session.is_state(WorkingSessionState): + async def wakeup(self, completion: EventCompletion | None) -> None: + if self.session.__is_state__(WorkingSessionState): return - self.session.event = event - cond = self.session.wakeup_cond + + if completion is not None: + self.session._completions.add(completion) + self.session.event = completion.event + cond = self.session._wakeup_cond self.session.__to_state__(WorkingSessionState) async with cond: cond.notify() @@ -129,27 +126,53 @@ class Session: __instance_locks__: dict[Rule, Lock] = {} __cls_lock__ = Lock() - def __init__(self, event: Event, rule: Rule | None, keep: bool = False) -> None: + def __init__( + self, + rule: Rule | None, + first_completion: EventCompletion, + keep: bool = False, + auto_complete: bool = True, + ) -> None: self.store: SessionStore = SessionStore() - self.event = event + self.event = first_completion.event self.rule = rule - self.refresh_cond = Condition() - self.wakeup_cond = Condition() - self.keep = keep + self.auto_complete = auto_complete + self._completions: set[EventCompletion] = set() + self._completions.add(first_completion) + self._refresh_cond = Condition() + self._wakeup_cond = Condition() + self._keep = keep self._state: SessionState = WorkingSessionState(self) + def stop_keep(self) -> None: + self._keep = False + + def set_completed(self, event: Event | None = None) -> None: + if event is None: + for c in self._completions: + c.completed.set_result(None) + self._completions.clear() + return + + comps = filter(lambda c: c.event is event, self._completions) + for c in comps: + c.completed.set_result(None) + self._completions.remove(c) + + def get_incompletions(self) -> list[tuple[Event, Future]]: + return [ + (c.event, c.completed) for c in self._completions if not c.completed.done() + ] + def __to_state__(self, state_class: type[SessionState]) -> None: self._state = state_class(self) - def is_state(self, state_class: type[SessionState]) -> bool: + def __is_state__(self, state_class: type[SessionState]) -> bool: return isinstance(self._state, state_class) - def mark_expire(self) -> None: - self.keep = False - - async def __work__(self, event: Event) -> None: - await self._state.work(event) + async def __work__(self, completion: EventCompletion) -> None: + await self._state.work(completion) async def __rest__(self) -> None: await self._state.rest() @@ -157,8 +180,8 @@ async def __rest__(self) -> None: async def __suspend__(self, timeout: float | None = None) -> bool: return await self._state.suspend(timeout) - async def __wakeup__(self, event: Event) -> None: - await self._state.wakeup(event) + async def __wakeup__(self, completion: EventCompletion | None) -> None: + await self._state.wakeup(completion) async def __expire__(self) -> None: await self._state.expire() @@ -166,14 +189,21 @@ async def __expire__(self) -> None: @classmethod async def get( cls, - event: Event, + completion: EventCompletion, rule: Rule | None = None, wait: bool = True, - nowait_cb: AsyncCallable[[], None] | None = None, + nowait_cb: SyncOrAsyncCallable[[], None] | None = None, keep: bool = False, + auto_complete: bool = True, ) -> Session | None: + event = completion.event if rule is None: - return Session(event, rule=None, keep=keep) + return Session( + rule=None, + first_completion=completion, + keep=keep, + auto_complete=auto_complete, + ) async with cls.__cls_lock__: cls.__instance_locks__.setdefault(rule, Lock()) @@ -182,24 +212,24 @@ async def get( try: _set = cls.__instances__.setdefault(rule, set()) - suspends = filter(lambda s: s.is_state(SuspendSessionState), _set) + suspends = filter(lambda s: s.__is_state__(SuspendSessionState), _set) for session in suspends: if await rule.compare_with( CompareInfo(session, session.event, event) ): - await session.__wakeup__(event) + await session.__wakeup__(completion) return None - spares = filter(lambda s: s.is_state(SpareSessionState), _set) + spares = filter(lambda s: s.__is_state__(SpareSessionState), _set) for session in spares: if await rule.compare_with( CompareInfo(session, session.event, event) ): - await session.__work__(event) - session.keep = keep + await session.__work__(completion) + session._keep = keep return session - workings = filter(lambda s: s.is_state(WorkingSessionState), _set) + workings = filter(lambda s: s.__is_state__(WorkingSessionState), _set) for session in workings: if not await rule.compare_with( CompareInfo(session, session.event, event) @@ -208,31 +238,41 @@ async def get( if not wait: if nowait_cb is not None: - await nowait_cb() + ret = nowait_cb() + if inspect.isawaitable(ret): + await ret + completion.completed.set_result(None) return None - cond = session.refresh_cond + cond = session._refresh_cond async with cond: await cond.wait() - if session.is_state(ExpireSessionState): + if session.__is_state__(ExpireSessionState): pass - elif session.is_state(SuspendSessionState): - await session.__wakeup__(event) + elif session.__is_state__(SuspendSessionState): + await session.__wakeup__(completion) return None else: - await session.__work__(event) - session.keep = keep + await session.__work__(completion) + session._keep = keep return session - session = Session(event, rule=rule, keep=keep) + session = Session( + rule=rule, + first_completion=completion, + keep=keep, + auto_complete=auto_complete, + ) Session.__instances__[rule].add(session) return session finally: - expires = tuple(filter(lambda s: s.is_state(ExpireSessionState), _set)) + expires = tuple( + filter(lambda s: s.__is_state__(ExpireSessionState), _set) + ) for session in expires: Session.__instances__[rule].remove(session) @@ -242,15 +282,21 @@ async def enter( cls, rule: Rule, wait: bool = True, - nowait_cb: AsyncCallable[[], None] | None = None, + nowait_cb: SyncOrAsyncCallable[[], None] | None = None, keep: bool = False, + auto_complete: bool = True, ) -> AsyncGenerator[Session, None]: + flow_ctx = FlowCtx() + completion = flow_ctx.get_completion() + completion.under_session = True + session = await cls.get( - FlowCtx().get_event(), + completion, rule=rule, wait=wait, nowait_cb=nowait_cb, keep=keep, + auto_complete=auto_complete, ) if session is None: await stop() @@ -259,10 +305,10 @@ async def enter( try: yield session except asyncio.CancelledError: - if session.is_state(SuspendSessionState): - await session.__wakeup__(session.event) + if session.__is_state__(SuspendSessionState): + await session.__wakeup__(completion=None) finally: - if session.keep: + if session._keep: await session.__rest__() else: await session.__expire__() @@ -280,8 +326,9 @@ async def suspend(timeout: float | None = None) -> bool: def enter_session( rule: Rule, wait: bool = True, - nowait_cb: AsyncCallable[[], None] | None = None, + nowait_cb: SyncOrAsyncCallable[[], None] | None = None, keep: bool = False, + auto_complete: bool = True, ) -> _AsyncGeneratorContextManager[Session]: """上下文管理器,提供一个会话上下文,在此上下文中可使用会话的高级特性 @@ -289,6 +336,7 @@ def enter_session( :param wait: 当出现会话冲突时,是否需要等待 :param nowait_cb: 指定了 `wait=False` 后,会话冲突时执行的回调 :param keep: 会话在退出会话上下文后是否继续保持 + :param auto_complete: 当前会话挂起后,事件是否自动向更低优先级传播 :yield: 会话对象 """ - return Session.enter(rule, wait, nowait_cb, keep) + return Session.enter(rule, wait, nowait_cb, keep, auto_complete) diff --git a/src/melobot/session/option.py b/src/melobot/session/option.py index 99bc7d41..5cd2b33e 100644 --- a/src/melobot/session/option.py +++ b/src/melobot/session/option.py @@ -4,8 +4,8 @@ from typing_extensions import TYPE_CHECKING, Any, Callable, Generic, final -from ..adapter.model import EventT -from ..typ import BetterABC +from ..adapter.model import Event, EventT +from ..typ.cls import BetterABC if TYPE_CHECKING: from .base import Session @@ -85,3 +85,13 @@ def __init__(self, meth: Callable[[EventT, EventT], bool]) -> None: async def compare(self, e1: EventT, e2: EventT) -> bool: return self.meth(e1, e2) + + +class DefaultRule(Rule[Event]): + """传统的会话规则 + + 判断事件的 `scope` 是否相同 + """ + + async def compare(self, e1: Event, e2: Event) -> bool: + return e1.scope == e2.scope diff --git a/src/melobot/typ.py b/src/melobot/typ.py deleted file mode 100644 index 9d0084c8..00000000 --- a/src/melobot/typ.py +++ /dev/null @@ -1,482 +0,0 @@ -import inspect -import warnings -from abc import ABCMeta, abstractmethod -from enum import Enum -from functools import wraps - -from beartype import BeartypeConf as _BeartypeConf -from beartype.door import is_bearable as _is_type -from beartype.door import is_subhint -from typing_extensions import ( - Any, - Awaitable, - Callable, - ParamSpec, - Protocol, - Self, - Sequence, - TypeGuard, - TypeVar, - cast, -) - -__all__ = ( - "T", - "T_co", - "P", - "AsyncCallable", - "is_type", - "is_subhint", - "HandleLevel", - "LogicMode", - "BetterABCMeta", - "BetterABCMeta", - "BetterABC", - "SingletonMeta", - "SingletonBetterABCMeta", - "abstractattr", - "abstractmethod", - "Markable", - "AttrsReprable", - "Locatable", - "VoidType", -) - -#: 泛型 T,无约束 -T = TypeVar("T", default=Any) -#: 泛型 T_co,协变无约束 -T_co = TypeVar("T_co", covariant=True, default=Any) -#: :obj:`~typing.ParamSpec` 泛型 P,无约束 -P = ParamSpec("P", default=Any) - - -class AsyncCallable(Protocol[P, T_co]): - """用法:AsyncCallable[P, T] - - 是该类型的等价形式:Callable[P, Awaitable[T]] - """ - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T_co]: ... - - -_DEFAULT_BEARTYPE_CONF = _BeartypeConf(is_pep484_tower=True) - - -def is_type(obj: T, hint: type[Any]) -> TypeGuard[T]: - """检查 `obj` 是否是类型注解 `hint` 所表示的类型 - - :param obj: 任意对象 - :param hint: 任意类型注解 - :return: 布尔值 - """ - ret = _is_type(obj, hint, conf=_DEFAULT_BEARTYPE_CONF) - return ret # type: ignore[no-any-return] - - -class HandleLevel(float, Enum): - """事件处理流优先级枚举类型""" - - MAX = 1 << 6 - ULTRA_HIGH = 1 << 5 - HIGH = 1 << 4 - NORMAL = 1 << 3 - LOW = 1 << 2 - ULTRA_LOW = 1 << 1 - MIN = 1 - - -class LogicMode(Enum): - """逻辑模式枚举类型""" - - AND = 1 - OR = 2 - NOT = 3 - XOR = 4 - - @classmethod - def calc(cls, logic: "LogicMode", v1: Any, v2: Any = None) -> bool: - """将两个值使用指定逻辑模式运算 - - :param logic: 逻辑模式 - :param v1: 值 1 - :param v2: 值 2 - :return: 布尔值 - """ - if logic == LogicMode.AND: - return (v1 and v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] - if logic == LogicMode.OR: - return (v1 or v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] - if logic == LogicMode.NOT: - return not v1 - return (v1 ^ v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] - - @classmethod - def short_calc( - cls, logic: "LogicMode", v1: Callable[[], Any], v2: Callable[[], Any] - ) -> bool: - """与 :func:`calc` 功能类似,但运算支持短路 - - :param logic: 逻辑模式 - :param v1: 生成值 1 的可调用对象 - :param v2: 生成值 2 的可调用对象 - :return: 布尔值 - """ - if logic == LogicMode.AND: - return (v1() and v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] - if logic == LogicMode.OR: - return (v1() or v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] - if logic == LogicMode.NOT: - return not v1() - return (v1() ^ v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] - - @classmethod - async def async_short_calc( - cls, logic: "LogicMode", v1: AsyncCallable[[], Any], v2: AsyncCallable[[], Any] - ) -> bool: - """与 :func:`short_calc` 功能类似,但运算支持异步 - - :param logic: 逻辑模式 - :param v1: 生成值 1 的异步可调用对象 - :param v2: 生成值 2 的异步可调用对象 - :return: 布尔值 - """ - if logic == LogicMode.AND: - res = (await v1() and await v2()) if v2 is not None else bool(await v1()) - return res # type: ignore[no-any-return] - if logic == LogicMode.OR: - res = (await v1() or await v2()) if v2 is not None else bool(await v1()) - return res # type: ignore[no-any-return] - if logic == LogicMode.NOT: - return not await v1() - res = (await v1() ^ await v2()) if v2 is not None else bool(await v1()) - return res # type: ignore[no-any-return] - - @classmethod - def seq_calc(cls, logic: "LogicMode", values: list[Any]) -> bool: - """使用指定的逻辑模式,对值序列进行运算 - - .. code:: python - - # 操作等价与:True and False and True - LogicMode.seq_calc(LogicMode.AND, [True, False, True]) - - :param logic: 逻辑模式 - :param values: 值序列 - :return: 布尔值 - """ - if len(values) <= 0: - return False - if len(values) <= 1: - return bool(values[0]) - - idx = 0 - res: bool - while idx < len(values): - if idx == 0: - res = cls.calc(logic, values[idx], values[idx + 1]) - idx += 1 - else: - res = cls.calc(logic, res, values[idx]) - idx += 1 - return res - - @classmethod - def short_seq_calc( - cls, logic: "LogicMode", getters: Sequence[Callable[[], Any]] - ) -> bool: - """与 :func:`seq_calc` 功能类似,但运算支持短路 - - :param logic: 逻辑模式 - :param getters: 一组获取值的可调用对象 - :return: 布尔值 - """ - if len(getters) <= 0: - return False - if len(getters) <= 1: - return bool(getters[0]()) - - idx = 0 - res: bool - while idx < len(getters): - if idx == 0: - res = cls.short_calc(logic, getters[idx], getters[idx + 1]) - idx += 1 - else: - res = cls.short_calc(logic, lambda: res, getters[idx]) - idx += 1 - return res - - @classmethod - async def async_short_seq_calc( - cls, logic: "LogicMode", getters: Sequence[AsyncCallable[[], Any]] - ) -> bool: - """与 :func:`short_seq_calc` 功能类似,但运算支持异步 - - :param logic: 逻辑模式 - :param getters: 一组获取值的异步可调用对象 - :return: 布尔值 - """ - if len(getters) <= 0: - return False - if len(getters) <= 1: - return bool(await getters[0]()) - - idx = 0 - res: bool - while idx < len(getters): - if idx == 0: - res = await cls.async_short_calc(logic, getters[idx], getters[idx + 1]) - idx += 1 - else: - - async def res_getter() -> bool: - return res - - res = await cls.async_short_calc(logic, res_getter, getters[idx]) - idx += 1 - return res - - -def abstractattr(obj: Callable[[Any], T] | None = None) -> T: - """抽象属性 - - 与 `abstractproperty` 相比更灵活,`abstractattr` 不关心你以何种形式定义属性。只要子类在实例化后,该属性存在,即认为合法。 - - 但注意它必须与 :class:`BetterABC` 或 :class:`BetterABCMeta` 或 :class:`.SingletonBetterABCMeta` 配合使用 - - 这意味着可以在类层级、实例层级定义属性,或使用 `property` 定义属性: - - .. code:: python - - class A(BetterABC): - foo: int = abstractattr() # 声明为抽象属性 - - # 或者使用装饰器的形式声明,这与上面是等价的 - @abstractattr - def bar(self) -> int: ... - - # 以下实现都是合法的: - - class B(A): - foo = 2 - bar = 4 - - class C(A): - foo = 3 - def __init__(self) -> None: - self.bar = 5 - - class D(A): - def __init__(self) -> None: - self.foo = 8 - - @property - def bar(self) -> int: - return self.foo + 10 - """ - _obj = cast(Any, obj) - if obj is None: - _obj = BetterABCMeta.DummyAttribute() - setattr(_obj, "__is_abstract_attribute__", True) - return cast(T, _obj) - - -class BetterABCMeta(ABCMeta): - """更好的抽象元类,兼容 `ABCMeta` 的所有功能,但是额外支持 :func:`abstractattr`""" - - class DummyAttribute: ... - - def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: - instance = ABCMeta.__call__(cls, *args, **kwargs) - lack_attrs = set() - for name in dir(instance): - try: - attr = getattr(instance, name) - except Exception: - if not isinstance(getattr(instance.__class__, name), property): - raise - - if getattr(attr, "__is_abstract_attribute__", False): - lack_attrs.add(name) - if inspect.iscoroutine(attr): - attr.close() - - if lack_attrs: - raise NotImplementedError( - "Can't instantiate abstract class {} with" - " abstract attributes: {}".format(cls.__name__, ", ".join(lack_attrs)) - ) - return cast(T, instance) - - -class BetterABC(metaclass=BetterABCMeta): - """更好的抽象类,兼容 `ABC` 的所有功能,但是额外支持 :func:`abstractattr`""" - - __slots__ = () - - -class SingletonMeta(type): - """单例元类 - - 相比单例装饰器,可以自动保证所有子类都为单例 - """ - - __instances__: dict[type, Any] = {} - - def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: - if cls not in SingletonMeta.__instances__: - SingletonMeta.__instances__[cls] = cast( - T, super(SingletonMeta, cls).__call__(*args, **kwargs) # type: ignore[misc] - ) - return cast(T, SingletonMeta.__instances__[cls]) - - -class SingletonBetterABCMeta(BetterABCMeta): - """单例抽象元类 - - 相比普通的抽象元类,还可以自动保证所有子类都为单例 - """ - - __instances__: dict[type, Any] = {} - - def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: - mcls = SingletonBetterABCMeta - if cls not in mcls.__instances__: - mcls.__instances__[cls] = BetterABCMeta.__call__(cls, *args, **kwargs) - return cast(T, mcls.__instances__[cls]) - - -class Markable: - """可标记对象 - - 无需直接实例化,而是用作接口在其他类中继承 - """ - - def __init__(self) -> None: - self._flags: dict[str, dict[str, Any]] = {} - - def flag_mark(self, namespace: str, flag_name: str, val: Any = None) -> None: - """在 `namespace` 命名空间中设置 `flag_name` 标记,值为 `val` - - 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 - - :param namespace: 命名空间 - :param flag_name: 标记名 - :param val: 标记值 - """ - self._flags.setdefault(namespace, {}) - - if flag_name in self._flags[namespace].keys(): - raise ValueError( - f"标记失败。对象的命名空间 {namespace} 中已存在名为 {flag_name} 的标记" - ) - - self._flags[namespace][flag_name] = val - - def flag_check(self, namespace: str, flag_name: str, val: Any = None) -> bool: - """检查 `namespace` 命名空间中 `flag_name` 标记值是否为 `val` - - 注:不同的对象并不共享 `namespace`,`namespace` 只适用于单个对象 - - :param namespace: 命名空间 - :param flag_name: 标记名 - :param val: 标记值 - :return: 是否通过检查 - """ - if self._flags.get(namespace) is None: - return False - if flag_name not in self._flags[namespace].keys(): - return False - flag = self._flags[namespace][flag_name] - - if val is None: - return flag is None - return cast(bool, flag == val) - - -class AttrsReprable: - def __repr__(self) -> str: - attrs = ", ".join( - f"{k}={repr(v)}" for k, v in self.__dict__.items() if not k.startswith("_") - ) - return f"{self.__class__.__name__}({attrs})" - - -class Locatable: - def __new__(cls, *_args: Any, **_kwargs: Any) -> Self: - obj = super().__new__(cls) - obj.__obj_location__ = obj._init_location() # type: ignore[attr-defined] - return obj - - def __init__(self) -> None: - self.__obj_location__: tuple[str, str, int] - - @staticmethod - def _init_location() -> tuple[str, str, int]: - frame = inspect.currentframe() - while frame: - if frame.f_code.co_name == "": - return ( - frame.f_globals["__name__"], - frame.f_globals["__file__"], - frame.f_lineno, - ) - frame = frame.f_back - - return ( - "", - "", - -1, - ) - - @property - def __obj_module__(self) -> str: - return self.__obj_location__[0] - - @property - def __obj_file__(self) -> str: - return self.__obj_location__[1] - - @property - def __obj_line__(self) -> int: - return self.__obj_location__[2] - - -class VoidType(Enum): - """空类型,需要区别于 `None` 时使用 - - .. code:: python - - # 有些时候 `None` 也是合法值,因此需要一个额外的哨兵值: - def foo(val: Any | VoidType = VoidType.VOID) -> None: - ... - """ - - VOID = type("_VOID", (), {}) - - -def deprecate_warn(msg: str) -> None: - # pylint: disable=cyclic-import - from .ctx import LoggerCtx - - if logger := LoggerCtx().try_get(): - logger.warning(msg) - warnings.simplefilter("always", DeprecationWarning) - warnings.warn(msg, category=DeprecationWarning, stacklevel=1) - warnings.simplefilter("default", DeprecationWarning) - - -def deprecated(msg: str) -> Callable[[Callable[P, T]], Callable[P, T]]: - - def decorator(func: Callable[P, T]) -> Callable[P, T]: - - @wraps(func) - def deprecate_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: - deprecate_warn( - f"使用了弃用函数/方法 {func.__module__}.{func.__qualname__}: {msg}" - ) - return func(*args, **kwargs) - - return deprecate_wrapped - - return decorator diff --git a/src/melobot/typ/__init__.py b/src/melobot/typ/__init__.py new file mode 100644 index 00000000..665a0d2a --- /dev/null +++ b/src/melobot/typ/__init__.py @@ -0,0 +1,9 @@ +from ._enum import LogicMode, VoidType +from .base import AsyncCallable, P, SyncOrAsyncCallable, T, T_co, is_type +from .cls import ( + BetterABC, + BetterABCMeta, + SingletonBetterABCMeta, + SingletonMeta, + abstractattr, +) diff --git a/src/melobot/typ/_enum.py b/src/melobot/typ/_enum.py new file mode 100644 index 00000000..7cac01e4 --- /dev/null +++ b/src/melobot/typ/_enum.py @@ -0,0 +1,173 @@ +from enum import Enum + +from typing_extensions import Any, Callable, Sequence + +from .base import AsyncCallable + + +class VoidType(Enum): + """空类型,需要区别于 `None` 时使用 + + .. code:: python + + # 有些时候 `None` 也是合法值,因此需要一个额外的哨兵值: + def foo(val: Any | VoidType = VoidType.VOID) -> None: + ... + """ + + VOID = type("_VOID", (), {}) + + +class LogicMode(Enum): + """逻辑模式枚举类型""" + + AND = 1 + OR = 2 + NOT = 3 + XOR = 4 + + @classmethod + def calc(cls, logic: "LogicMode", v1: Any, v2: Any = None) -> bool: + """将两个值使用指定逻辑模式运算 + + :param logic: 逻辑模式 + :param v1: 值 1 + :param v2: 值 2 + :return: 布尔值 + """ + if logic == LogicMode.AND: + return (v1 and v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] + if logic == LogicMode.OR: + return (v1 or v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] + if logic == LogicMode.NOT: + return not v1 + return (v1 ^ v2) if v2 is not None else bool(v1) # type: ignore[no-any-return] + + @classmethod + def short_calc( + cls, logic: "LogicMode", v1: Callable[[], Any], v2: Callable[[], Any] | None + ) -> bool: + """与 :func:`calc` 功能类似,但运算支持短路 + + :param logic: 逻辑模式 + :param v1: 生成值 1 的可调用对象 + :param v2: 生成值 2 的可调用对象 + :return: 布尔值 + """ + if logic == LogicMode.AND: + return (v1() and v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] + if logic == LogicMode.OR: + return (v1() or v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] + if logic == LogicMode.NOT: + return not v1() + return (v1() ^ v2()) if v2 is not None else bool(v1()) # type: ignore[no-any-return] + + @classmethod + async def async_short_calc( + cls, + logic: "LogicMode", + v1: AsyncCallable[[], Any], + v2: AsyncCallable[[], Any] | None, + ) -> bool: + """与 :func:`short_calc` 功能类似,但运算支持异步 + + :param logic: 逻辑模式 + :param v1: 生成值 1 的异步可调用对象 + :param v2: 生成值 2 的异步可调用对象 + :return: 布尔值 + """ + if logic == LogicMode.AND: + res = (await v1() and await v2()) if v2 is not None else bool(await v1()) + return res # type: ignore[no-any-return] + if logic == LogicMode.OR: + res = (await v1() or await v2()) if v2 is not None else bool(await v1()) + return res # type: ignore[no-any-return] + if logic == LogicMode.NOT: + return not await v1() + res = (await v1() ^ await v2()) if v2 is not None else bool(await v1()) + return res # type: ignore[no-any-return] + + @classmethod + def seq_calc(cls, logic: "LogicMode", values: list[Any]) -> bool: + """使用指定的逻辑模式,对值序列进行运算 + + .. code:: python + + # 操作等价与:True and False and True + LogicMode.seq_calc(LogicMode.AND, [True, False, True]) + + :param logic: 逻辑模式 + :param values: 值序列 + :return: 布尔值 + """ + if len(values) <= 0: + return False + if len(values) <= 1: + return bool(values[0]) + + idx = 0 + res: bool + while idx < len(values): + if idx == 0: + res = cls.calc(logic, values[idx], values[idx + 1]) + idx += 1 + else: + res = cls.calc(logic, res, values[idx]) + idx += 1 + return res + + @classmethod + def short_seq_calc( + cls, logic: "LogicMode", getters: Sequence[Callable[[], Any]] + ) -> bool: + """与 :func:`seq_calc` 功能类似,但运算支持短路 + + :param logic: 逻辑模式 + :param getters: 一组获取值的可调用对象 + :return: 布尔值 + """ + if len(getters) <= 0: + return False + if len(getters) <= 1: + return bool(getters[0]()) + + idx = 0 + res: bool + while idx < len(getters): + if idx == 0: + res = cls.short_calc(logic, getters[idx], getters[idx + 1]) + idx += 1 + else: + res = cls.short_calc(logic, lambda: res, getters[idx]) + idx += 1 + return res + + @classmethod + async def async_short_seq_calc( + cls, logic: "LogicMode", getters: Sequence[AsyncCallable[[], Any]] + ) -> bool: + """与 :func:`short_seq_calc` 功能类似,但运算支持异步 + + :param logic: 逻辑模式 + :param getters: 一组获取值的异步可调用对象 + :return: 布尔值 + """ + if len(getters) <= 0: + return False + if len(getters) <= 1: + return bool(await getters[0]()) + + idx = 0 + res: bool + while idx < len(getters): + if idx == 0: + res = await cls.async_short_calc(logic, getters[idx], getters[idx + 1]) + idx += 1 + else: + + async def res_getter() -> bool: + return res + + res = await cls.async_short_calc(logic, res_getter, getters[idx]) + idx += 1 + return res diff --git a/src/melobot/typ/base.py b/src/melobot/typ/base.py new file mode 100644 index 00000000..8c3b4335 --- /dev/null +++ b/src/melobot/typ/base.py @@ -0,0 +1,49 @@ +from beartype import BeartypeConf as _BeartypeConf +from beartype.door import is_bearable as _is_type +from beartype.door import is_subhint +from typing_extensions import Any, Awaitable, ParamSpec, Protocol, TypeIs, TypeVar + +__all__ = ("AsyncCallable", "P", "T", "T_co", "is_type", "is_subhint") + +#: 泛型 T,无约束 +T = TypeVar("T", default=Any) +#: 泛型 T,无约束 +U = TypeVar("U", default=Any) +#: 泛型 T,无约束 +V = TypeVar("V", default=Any) +#: 泛型 T_co,协变无约束 +T_co = TypeVar("T_co", covariant=True, default=Any) +#: :obj:`~typing.ParamSpec` 泛型 P,无约束 +P = ParamSpec("P", default=Any) + + +class AsyncCallable(Protocol[P, T_co]): + """用法:AsyncCallable[P, T] + + 是该类型的等价形式:Callable[P, Awaitable[T]] + """ + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T_co]: ... + + +class SyncOrAsyncCallable(Protocol[P, T_co]): + """用法:SyncOrAsyncCallable[P, T] + + 是该类型的等价形式:Callable[P, T | Awaitable[T]] + """ + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co | Awaitable[T_co]: ... + + +_DEFAULT_BEARTYPE_CONF = _BeartypeConf(is_pep484_tower=True) + + +def is_type(obj: T, hint: type[Any]) -> TypeIs[T]: + """检查 `obj` 是否是类型注解 `hint` 所表示的类型 + + :param obj: 任意对象 + :param hint: 任意类型注解 + :return: 布尔值 + """ + ret = _is_type(obj, hint, conf=_DEFAULT_BEARTYPE_CONF) + return ret # type: ignore[no-any-return] diff --git a/src/melobot/typ/cls.py b/src/melobot/typ/cls.py new file mode 100644 index 00000000..7faaae5b --- /dev/null +++ b/src/melobot/typ/cls.py @@ -0,0 +1,115 @@ +import inspect +from abc import ABCMeta + +from typing_extensions import Any, Callable, cast + +from .base import T + + +class BetterABCMeta(ABCMeta): + """更好的抽象元类,兼容 `ABCMeta` 的所有功能,但是额外支持 :func:`abstractattr`""" + + class DummyAttribute: ... + + def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: + instance = ABCMeta.__call__(cls, *args, **kwargs) + lack_attrs = set() + for name in dir(instance): + try: + attr = getattr(instance, name) + except Exception: + if not isinstance(getattr(instance.__class__, name), property): + raise + + if getattr(attr, "__is_abstract_attribute__", False): + lack_attrs.add(name) + if inspect.iscoroutine(attr): + attr.close() + + if lack_attrs: + raise NotImplementedError( + "Can't instantiate abstract class {} with" + " abstract attributes: {}".format(cls.__name__, ", ".join(lack_attrs)) + ) + return cast(T, instance) + + +class BetterABC(metaclass=BetterABCMeta): + """更好的抽象类,兼容 `ABC` 的所有功能,但是额外支持 :func:`abstractattr`""" + + __slots__ = () + + +class SingletonMeta(type): + """单例元类 + + 相比单例装饰器,可以自动保证所有子类都为单例 + """ + + __instances__: dict[type, Any] = {} + + def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: + if cls not in SingletonMeta.__instances__: + SingletonMeta.__instances__[cls] = cast( + T, super(SingletonMeta, cls).__call__(*args, **kwargs) # type: ignore[misc] + ) + return cast(T, SingletonMeta.__instances__[cls]) + + +class SingletonBetterABCMeta(BetterABCMeta): + """单例抽象元类 + + 相比普通的抽象元类,还可以自动保证所有子类都为单例 + """ + + __instances__: dict[type, Any] = {} + + def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T: + mcls = SingletonBetterABCMeta + if cls not in mcls.__instances__: + mcls.__instances__[cls] = BetterABCMeta.__call__(cls, *args, **kwargs) + return cast(T, mcls.__instances__[cls]) + + +def abstractattr(obj: Callable[[Any], T] | None = None) -> T: + """抽象属性 + + 与 `abstractproperty` 相比更灵活,`abstractattr` 不关心你以何种形式定义属性。只要子类在实例化后,该属性存在,即认为合法。 + + 但注意它必须与 :class:`BetterABC` 或 :class:`BetterABCMeta` 或 :class:`.SingletonBetterABCMeta` 配合使用 + + 这意味着可以在类层级、实例层级定义属性,或使用 `property` 定义属性: + + .. code:: python + + class A(BetterABC): + foo: int = abstractattr() # 声明为抽象属性 + + # 或者使用装饰器的形式声明,这与上面是等价的 + @abstractattr + def bar(self) -> int: ... + + # 以下实现都是合法的: + + class B(A): + foo = 2 + bar = 4 + + class C(A): + foo = 3 + def __init__(self) -> None: + self.bar = 5 + + class D(A): + def __init__(self) -> None: + self.foo = 8 + + @property + def bar(self) -> int: + return self.foo + 10 + """ + _obj = cast(Any, obj) + if obj is None: + _obj = BetterABCMeta.DummyAttribute() + setattr(_obj, "__is_abstract_attribute__", True) + return cast(T, _obj) diff --git a/src/melobot/utils.py b/src/melobot/utils.py deleted file mode 100644 index b0e67c84..00000000 --- a/src/melobot/utils.py +++ /dev/null @@ -1,700 +0,0 @@ -import asyncio -import base64 -import inspect -import time -from contextlib import asynccontextmanager -from datetime import datetime -from functools import wraps - -from typing_extensions import ( - Any, - AsyncContextManager, - AsyncGenerator, - Awaitable, - Callable, - ContextManager, - Coroutine, - Literal, - TypeVar, - cast, -) - -from .exceptions import ValidateError -from .typ import AsyncCallable, P, T - - -def get_obj_name( - obj: Any, - otype: Literal["callable", "class", "object"] | str = "object", - default: str = "", -) -> str: - """获取一个对象的限定名称或名称,这适用于一些类型较宽的参数。 - - 无法获取有效名称时,产生一个 `default % otype` 字符串 - - 例如某处接受一个 `Callable` 类型的参数,对于一般函数来说,使用 - `__qualname__` 或 `__name__` 可获得名称,但某些可调用对象这些值可能为 `None` - 或不存在。使用此方法可保证一定返回字符串 - - .. code:: python - - def _(a: Callable) -> None: - valid_str: str = get_obj_name(a, otype="callable") - - def _(a: type) -> None: - valid_str: str = get_obj_name(a, otype="class") - - def _(a: Any) -> None: - valid_str: str = get_obj_name(a, otype="type of a, only for str concat") - - - :param obj: 对象 - :param otype: 预期的对象类型 - :param default: 无法获取任何有效名称时的默认字符串 - :return: 对象名称或默认字符串 - """ - if hasattr(obj, "__qualname__"): - return cast(str, obj.__qualname__) - - if hasattr(obj, "__name__"): - return cast(str, obj.__name__) - - return default % otype - - -def singleton(cls: Callable[P, T]) -> Callable[P, T]: - """单例装饰器 - - :param cls: 需要被单例化的可调用对象 - :return: 需要被单例化的可调用对象 - """ - obj_map = {} - - @wraps(cls) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: - if cls not in obj_map: - obj_map[cls] = cls(*args, **kwargs) - return obj_map[cls] - - return wrapped - - -class RWContext: - """异步读写上下文 - - 提供异步安全的读写上下文。在读取时可以多读,同时读写互斥。 - - 使用方法: - - .. code:: python - - rwc = RWContext() - # 读时使用此控制器的安全读上下文: - async with rwc.read(): - ... - # 写时使用此控制器的安全写上下文: - async with rwc.write(): - ... - """ - - def __init__(self, read_limit: int | None = None) -> None: - """初始化异步读写上下文 - - :param read_limit: 读取的数量限制,为空则不限制 - """ - self.write_semaphore = asyncio.Semaphore(1) - self.read_semaphore = asyncio.Semaphore(read_limit) if read_limit else None - self.read_num = 0 - self.read_num_lock = asyncio.Lock() - - @asynccontextmanager - async def read(self) -> AsyncGenerator[None, None]: - """上下文管理器,展开一个关于该对象的安全异步读上下文""" - if self.read_semaphore: - await self.read_semaphore.acquire() - - async with self.read_num_lock: - if self.read_num == 0: - await self.write_semaphore.acquire() - self.read_num += 1 - - try: - yield - finally: - async with self.read_num_lock: - self.read_num -= 1 - if self.read_num == 0: - self.write_semaphore.release() - if self.read_semaphore: - self.read_semaphore.release() - - @asynccontextmanager - async def write(self) -> AsyncGenerator[None, None]: - """上下文管理器,展开一个关于该对象的安全异步写上下文""" - await self.write_semaphore.acquire() - try: - yield - finally: - self.write_semaphore.release() - - -class SnowFlakeIdWorker: - def __init__(self, datacenter_id: int, worker_id: int, sequence: int = 0) -> None: - self.max_worker_id = -1 ^ (-1 << 3) - self.max_datacenter_id = -1 ^ (-1 << 5) - self.worker_id_shift = 12 - self.datacenter_id_shift = 12 + 3 - self.timestamp_left_shift = 12 + 3 + 5 - self.sequence_mask = -1 ^ (-1 << 12) - self.startepoch = int(datetime(2022, 12, 11, 12, 8, 45).timestamp() * 1000) - - if worker_id > self.max_worker_id or worker_id < 0: - raise ValueError("worker_id 值越界") - if datacenter_id > self.max_datacenter_id or datacenter_id < 0: - raise ValueError("datacenter_id 值越界") - self.worker_id = worker_id - self.datacenter_id = datacenter_id - self.sequence = sequence - self.last_timestamp = -1 - - def _gen_timestamp(self) -> int: - return int(time.time() * 1000) - - def get_id(self) -> int: - timestamp = self._gen_timestamp() - - if timestamp < self.last_timestamp: - raise ValueError(f"时钟回拨,{self.last_timestamp} 前拒绝 id 生成请求") - if timestamp == self.last_timestamp: - self.sequence = (self.sequence + 1) & self.sequence_mask - if self.sequence == 0: - timestamp = self._until_next_millis(self.last_timestamp) - else: - self.sequence = 0 - self.last_timestamp = timestamp - new_id = ( - ((timestamp - self.startepoch) << self.timestamp_left_shift) - | (self.datacenter_id << self.datacenter_id_shift) - | (self.worker_id << self.worker_id_shift) - | self.sequence - ) - return new_id - - def get_b64_id(self, trim_pad: bool = True) -> str: - id = base64.urlsafe_b64encode( - self.get_id().to_bytes(8, byteorder="little") - ).decode() - if trim_pad: - id = id.rstrip("=") - return id - - def _until_next_millis(self, last_time: int) -> int: - timestamp = self._gen_timestamp() - while timestamp <= last_time: - timestamp = self._gen_timestamp() - return timestamp - - -_DEFAULT_ID_WORKER = SnowFlakeIdWorker(1, 1, 0) - - -def get_id() -> str: - """从 melobot 内部 id 获取器获得一个 id 值,不保证线程安全。算法使用雪花算法 - - :return: id 值 - """ - return _DEFAULT_ID_WORKER.get_b64_id() - - -def to_async( - obj: Callable[P, T] | AsyncCallable[P, T] | Awaitable[T] -) -> Callable[P, Coroutine[Any, Any, T]]: - """异步包装函数 - - 将一个可调用对象或可等待对象装饰为异步函数 - - :param obj: 需要转换的可调用对象或可等待对象 - :return: 异步函数 - """ - if inspect.iscoroutinefunction(obj): - return obj - - async def async_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: - if not inspect.isawaitable(obj): - ret = obj(*args, **kwargs) - else: - ret = obj - if inspect.isawaitable(ret): - return await ret - return ret - - if not inspect.isawaitable(obj): - async_wrapped = wraps(obj)(async_wrapped) - return async_wrapped - - -def to_coro( - obj: Callable[P, T] | AsyncCallable[P, T] | Awaitable[T], - *args: Any, - **kwargs: Any, -) -> Coroutine[Any, Any, T]: - """协程包装函数 - - 将一个可调用对象或可等待对象装饰为异步函数,并返回对应的协程 - - :param obj: 需要包装的可调用对象或可等待对象 - :param args: 需要使用的位置参数 - :param kwargs: 需要使用的关键字参数 - :return: 协程 - """ - if inspect.iscoroutine(obj): - return obj - return to_async(obj)(*args, **kwargs) # type: ignore[arg-type] - - -CbRetT = TypeVar("CbRetT", default=Any) -FirstCbRetT = TypeVar("FirstCbRetT", default=Any) -SecondCbRetT = TypeVar("SecondCbRetT", default=Any) -OriginRetT = TypeVar("OriginRetT", default=Any) -CondT = TypeVar("CondT", default=Any) - - -def if_not( - condition: Callable[[], CondT] | AsyncCallable[[], CondT] | CondT, - reject: AsyncCallable[[], None], - give_up: bool = False, - accept: AsyncCallable[[CondT], None] | None = None, -) -> Callable[[AsyncCallable[P, T]], AsyncCallable[P, T | None]]: - """条件判断装饰器 - - :param condition: 用于判断的条件(如果是可调用对象,则先求值再转为 bool 值) - :param reject: 当条件为 `False` 时,执行的回调 - :param give_up: 在条件为 `False` 时,是否放弃执行被装饰函数 - :param accept: 当条件为 `True` 时,执行的回调 - """ - - def deco_func(func: AsyncCallable[P, T]) -> AsyncCallable[P, T | None]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> T | None: - if not callable(condition): - cond = condition - else: - obj = condition() - cond = await obj if inspect.isawaitable(obj) else obj - - if not cond: - await async_guard(reject) - - if cond or not give_up: - if accept is not None: - await async_guard(accept, cond) - - return await async_guard(func, *args, **kwargs) - return None - - return wrapped_func - - return deco_func - - -def unfold_ctx( - getter: Callable[[], ContextManager | AsyncContextManager], -) -> Callable[[AsyncCallable[P, T]], AsyncCallable[P, T]]: - """上下文装饰器 - - 展开一个上下文,供被装饰函数使用。 - 但注意此装饰器不支持获取上下文管理器 `yield` 的值 - - :param getter: 上下文管理器获取方法 - """ - - def deco_func(func: AsyncCallable[P, T]) -> AsyncCallable[P, T]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> T: - manager = getter() - if isinstance(manager, ContextManager): - with manager: - return await async_guard(func, *args, **kwargs) - else: - async with manager: - return await async_guard(func, *args, **kwargs) - - return wrapped_func - - return deco_func - - -def lock( - callback: AsyncCallable[[], CbRetT] | None = None -) -> Callable[[AsyncCallable[P, OriginRetT]], AsyncCallable[P, CbRetT | OriginRetT]]: - """锁装饰器 - - 本方法作为异步函数的装饰器使用,可以为被装饰函数加锁。 - - 在获取锁冲突时,调用 `callback` 获得一个回调并执行。回调执行完毕后直接返回。 - - `callback` 参数为空,只应用 :class:`asyncio.Lock` 的锁功能。 - - 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 - - :param callback: 获取锁冲突时的回调 - """ - alock = asyncio.Lock() - - def deco_func( - func: AsyncCallable[P, OriginRetT] - ) -> AsyncCallable[P, CbRetT | OriginRetT]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> CbRetT | OriginRetT: - if callback is not None and alock.locked(): - return await async_guard(callback) - async with alock: - return await async_guard(func, *args, **kwargs) - - return wrapped_func - - return deco_func - - -def cooldown( - busy_callback: AsyncCallable[[], FirstCbRetT] | None = None, - cd_callback: AsyncCallable[[float], SecondCbRetT] | None = None, - interval: float = 5, -) -> Callable[ - [AsyncCallable[P, OriginRetT]], - AsyncCallable[P, OriginRetT | FirstCbRetT | SecondCbRetT], -]: - """冷却装饰器 - - 本方法作为异步函数的装饰器使用,可以为被装饰函数添加 cd 时间。 - - 如果被装饰函数已有一个在运行,此时调用 `busy_callback` 生成回调并执行。回调执行完毕后直接返回。 - - `busy_callback` 参数为空,则等待已运行的运行完成。随后执行下面的“冷却”处理逻辑。 - - 当被装饰函数没有在运行的,但冷却时间未结束: - - `cd_callback` 不为空:使用 `cd_callback` 生成回调并执行。 - - `cd_callback` 为空,被装饰函数持续等待,直至冷却结束再执行。 - - 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 - - :param busy_callback: 已运行时的回调 - :param cd_callback: 冷却时间未结束的回调 - :param interval: 冷却时间 - """ - alock = asyncio.Lock() - pre_finish_t = time.perf_counter() - interval - 1 - - def deco_func( - func: AsyncCallable[P, OriginRetT] - ) -> AsyncCallable[P, OriginRetT | FirstCbRetT | SecondCbRetT]: - - @wraps(func) - async def wrapped_func( - *args: P.args, **kwargs: P.kwargs - ) -> OriginRetT | FirstCbRetT | SecondCbRetT: - nonlocal pre_finish_t - if busy_callback is not None and alock.locked(): - return await async_guard(busy_callback) - - async with alock: - duration = time.perf_counter() - pre_finish_t - if duration > interval: - ret = await async_guard(func, *args, **kwargs) - pre_finish_t = time.perf_counter() - return ret - - remain_t = interval - duration - if cd_callback is not None: - return await async_guard(cd_callback, remain_t) - - await asyncio.sleep(remain_t) - ret = await async_guard(func, *args, **kwargs) - pre_finish_t = time.perf_counter() - return ret - - return wrapped_func - - return deco_func - - -def semaphore( - callback: AsyncCallable[[], CbRetT] | None = None, value: int = -1 -) -> Callable[[AsyncCallable[P, OriginRetT]], AsyncCallable[P, OriginRetT | CbRetT]]: - """信号量装饰器 - - 本方法作为异步函数的装饰器使用,可以为被装饰函数添加信号量控制。 - - 在信号量无法立刻获取时,将调用 `callback` 获得回调并执行。回调执行完毕后直接返回。 - - `callback` 参数为空,只应用 :class:`asyncio.Semaphore` 的信号量功能。 - - 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 - - :param callback: 信号量无法立即获取的回调 - :param value: 信号量阈值 - """ - a_semaphore = asyncio.Semaphore(value) - - def deco_func( - func: AsyncCallable[P, OriginRetT] - ) -> AsyncCallable[P, OriginRetT | CbRetT]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> OriginRetT | CbRetT: - if callback is not None and a_semaphore.locked(): - return await async_guard(callback) - async with a_semaphore: - return await async_guard(func, *args, **kwargs) - - return wrapped_func - - return deco_func - - -def timelimit( - callback: AsyncCallable[[], CbRetT] | None = None, timeout: float = 5 -) -> Callable[[AsyncCallable[P, OriginRetT]], AsyncCallable[P, OriginRetT | CbRetT]]: - """时间限制装饰器 - - 本方法作为异步函数的装饰器使用,可以为被装饰函数添加超时控制。 - - 超时之后,调用 `callback` 获得回调并执行,同时取消原任务。 - - `callback` 参数为空,如果超时,则抛出 :class:`asyncio.TimeoutError` 异常。 - - 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 - - :param callback: 超时时的回调 - :param timeout: 超时时间 - """ - - def deco_func( - func: AsyncCallable[P, OriginRetT] - ) -> AsyncCallable[P, OriginRetT | CbRetT]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> OriginRetT | CbRetT: - try: - return await asyncio.wait_for(async_guard(func, *args, **kwargs), timeout) - except asyncio.TimeoutError: - if callback is None: - raise TimeoutError("timelimit 所装饰的任务已超时") from None - return await async_guard(callback) - - return wrapped_func - - return deco_func - - -def speedlimit( - callback: AsyncCallable[[], CbRetT] | None = None, - limit: int = 60, - duration: int = 60, -) -> Callable[[AsyncCallable[P, OriginRetT]], AsyncCallable[P, OriginRetT | CbRetT]]: - """流量/速率限制装饰器(使用固定窗口算法) - - 本方法作为异步函数的装饰器使用,可以为被装饰函数添加流量控制:`duration` 秒内只允许 `limit` 次调用。 - - 超出调用速率限制后,调用 `callback` 获得回调并执行,同时取消原任务。 - - `callback` 参数为空,等待直至满足速率控制要求再调用。 - - 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值。 - - :param callback: 超出速率限制时的回调 - :param limit: `duration` 秒内允许调用多少次 - :param duration: 时长区间 - """ - called_num = 0 - min_start = time.perf_counter() - if limit <= 0: - raise ValidateError("speedlimit 装饰器的 limit 参数必须 > 0") - if duration <= 0: - raise ValidateError("speedlimit 装饰器的 duration 参数必须 > 0") - - def deco_func( - func: AsyncCallable[P, OriginRetT] - ) -> AsyncCallable[P, OriginRetT | CbRetT]: - - @wraps(func) - async def wrapped_func(*args: P.args, **kwargs: P.kwargs) -> OriginRetT | CbRetT: - fut = _wrapped_func(func, *args, **kwargs) - fut = cast(asyncio.Future[CbRetT | OriginRetT | Exception], fut) - fut_ret = await fut - if isinstance(fut_ret, Exception): - raise fut_ret - return fut_ret - - return wrapped_func - - def _wrapped_func( - func: AsyncCallable[P, T], - *args: P.args, - **kwargs: P.kwargs, - ) -> asyncio.Future: - # 分离出来定义,方便 result_set 调用形成递归。主要逻辑通过 Future 实现,有利于避免竞争问题。 - nonlocal called_num, min_start - passed_time = time.perf_counter() - min_start - res_fut: Any = asyncio.get_running_loop().create_future() - - if passed_time <= duration: - if called_num < limit: - called_num += 1 - asyncio.create_task(result_set(func, res_fut, -1, *args, **kwargs)) - - elif callback is not None: - asyncio.create_task(result_set(callback, res_fut, -1)) - - else: - asyncio.create_task( - result_set(func, res_fut, duration - passed_time, *args, **kwargs) - ) - else: - called_num, min_start = 0, time.perf_counter() - called_num += 1 - asyncio.create_task(result_set(func, res_fut, -1, *args, **kwargs)) - - return cast(asyncio.Future, res_fut) - - async def result_set( - func: AsyncCallable[P, T], - fut: asyncio.Future, - delay: float, - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - """ - 只有依然在当前 duration 区间内,但超出调用次数限制的,需要等待。 - 随后就是递归调用。delay > 0 为需要递归的分支。 - """ - nonlocal called_num - try: - if delay > 0: - await asyncio.sleep(delay) - res = await _wrapped_func(func, *args, **kwargs) - fut.set_result(res) - return - - res = await async_guard(func, *args, **kwargs) - fut.set_result(res) - - except Exception as e: - fut.set_result(e) - - return deco_func - - -def call_later(callback: Callable[[], None], delay: float) -> asyncio.TimerHandle: - """同步函数延迟调度 - - 在指定的 `delay` 后调度一个 `callback` 执行。`callback` 应该是同步方法。 - - :param callback: 同步函数 - :param delay: 多长时间后调度 - :return: :class:`asyncio.TimerHandle` 对象 - """ - return asyncio.get_running_loop().call_later(delay, callback) - - -def call_at(callback: Callable[[], None], timestamp: float) -> asyncio.TimerHandle: - """同步函数指定时间调度 - - 在指定的时间戳调度一个 `callback` 执行。`callback` 应该是同步方法。`timestamp` <= 当前时刻回调立即执行 - - :param callback: 同步函数 - :param timestamp: 在什么时刻调度 - :return: :class:`asyncio.TimerHandle` 对象 - """ - loop = asyncio.get_running_loop() - if timestamp <= time.time_ns() / 1e9: - return loop.call_later(0, callback) - - return loop.call_later(timestamp - time.time_ns() / 1e9, callback) - - -def async_later(callback: Coroutine[Any, Any, T], delay: float) -> asyncio.Task[T]: - """异步函数延迟调度(可自主选择是否等待) - - 在指定的 `delay` 后调度一个 `callback` 执行。`callback` 是协程。 - - 返回一个 :class:`asyncio.Task` 对象,等待 :class:`asyncio.Task` 即是等待 `callback` 的返回值。 - - :param callback: 协程(可有返回值) - :param delay: 多长时间后调度 - :return: :class:`asyncio.Task` 对象 - """ - - async def _later_task() -> T: - try: - await asyncio.sleep(delay) - res = await callback - return res - except asyncio.CancelledError: - callback.close() - raise - - return asyncio.create_task(_later_task()) - - -def async_at(callback: Coroutine[Any, Any, T], timestamp: float) -> asyncio.Task[T]: - """异步函数指定时间调度(可自主选择是否等待) - - 在指定的时间戳调度一个 `callback` 执行。`callback` 是协程。 - - 返回一个 :class:`asyncio.Task` 对象,等待 :class:`asyncio.Task` 即是等待 `callback` 的返回值。 - - 注意:如果 `callback` 未完成就被取消,需要捕获 :class:`asyncio.CancelledError`。 - - :param callback: 协程(可有返回值) - :param timestamp: 在什么时刻调度 - :return: :class:`asyncio.Task` 对象 - """ - if timestamp <= time.time_ns() / 1e9: - return async_later(callback, 0) - - return async_later(callback, timestamp - time.time_ns() / 1e9) - - -def async_interval( - callback: Callable[[], Coroutine[Any, Any, None]], interval: float -) -> asyncio.Task[None]: - """异步函数间隔调度(类似 JavaScript 的 setInterval) - - 每过时间间隔执行 `callback` 一次。`callback` 是返回协程的可调用对象(异步函数或 lambda 函数等)。 - - 返回一个 :class:`asyncio.Task` 对象,可使用该 task 取消调度过程。 - - :param callback: 异步函数 - :param interval: 调度的间隔 - :return: :class:`asyncio.Task` 对象 - """ - - async def _interval_task() -> None: - try: - while True: - coro = callback() - await asyncio.sleep(interval) - await coro - except asyncio.CancelledError: - coro.close() - raise - - t = asyncio.create_task(_interval_task()) - return t - - -async def async_guard(func: AsyncCallable[..., T], *args: Any, **kwargs: Any) -> T: - """在使用异步可调用对象时,提供用户友好的验证""" - if not callable(func): - raise ValidateError(f"{func} 不是异步可调用对象(返回 Awaitable 的可调用对象)") - - await_obj = func(*args, **kwargs) - if inspect.isawaitable(await_obj): - return await await_obj - raise ValidateError( - f"{func} 应该是异步函数,或其他异步可调用对象(返回 Awaitable 的可调用对象)。但它返回了:{await_obj},因此可能是同步函数" - ) diff --git a/src/melobot/utils/__init__.py b/src/melobot/utils/__init__.py new file mode 100644 index 00000000..75bb47bb --- /dev/null +++ b/src/melobot/utils/__init__.py @@ -0,0 +1,12 @@ +from .atool import async_at, async_interval, async_later, call_at, call_later +from .base import async_guard, to_async, to_coro, to_sync +from .common import ( + DeprecatedLoader, + RWContext, + deprecate_warn, + deprecated, + get_id, + get_obj_name, + singleton, +) +from .deco import cooldown, if_not, lock, semaphore, speedlimit, timelimit, unfold_ctx diff --git a/src/melobot/utils/atool.py b/src/melobot/utils/atool.py new file mode 100644 index 00000000..6f9681fb --- /dev/null +++ b/src/melobot/utils/atool.py @@ -0,0 +1,104 @@ +import asyncio +import time + +from typing_extensions import Any, Callable, Coroutine + +from ..typ.base import T + + +def call_later(callback: Callable[[], None], delay: float) -> asyncio.TimerHandle: + """同步函数延迟调度 + + 在指定的 `delay` 后调度一个 `callback` 执行。`callback` 应该是同步方法。 + + :param callback: 同步函数 + :param delay: 多长时间后调度 + :return: :class:`asyncio.TimerHandle` 对象 + """ + return asyncio.get_running_loop().call_later(delay, callback) + + +def call_at(callback: Callable[[], None], timestamp: float) -> asyncio.TimerHandle: + """同步函数指定时间调度 + + 在指定的时间戳调度一个 `callback` 执行。`callback` 应该是同步方法。`timestamp` <= 当前时刻回调立即执行 + + :param callback: 同步函数 + :param timestamp: 在什么时刻调度 + :return: :class:`asyncio.TimerHandle` 对象 + """ + loop = asyncio.get_running_loop() + if timestamp <= time.time_ns() / 1e9: + return loop.call_later(0, callback) + + return loop.call_later(timestamp - time.time_ns() / 1e9, callback) + + +def async_later(callback: Coroutine[Any, Any, T], delay: float) -> asyncio.Task[T]: + """异步函数延迟调度(可自主选择是否等待) + + 在指定的 `delay` 后调度一个 `callback` 执行。`callback` 是协程。 + + 返回一个 :class:`asyncio.Task` 对象,等待 :class:`asyncio.Task` 即是等待 `callback` 的返回值。 + + :param callback: 协程(可有返回值) + :param delay: 多长时间后调度 + :return: :class:`asyncio.Task` 对象 + """ + + async def _later_task() -> T: + try: + await asyncio.sleep(delay) + res = await callback + return res + except asyncio.CancelledError: + callback.close() + raise + + return asyncio.create_task(_later_task()) + + +def async_at(callback: Coroutine[Any, Any, T], timestamp: float) -> asyncio.Task[T]: + """异步函数指定时间调度(可自主选择是否等待) + + 在指定的时间戳调度一个 `callback` 执行。`callback` 是协程。 + + 返回一个 :class:`asyncio.Task` 对象,等待 :class:`asyncio.Task` 即是等待 `callback` 的返回值。 + + 注意:如果 `callback` 未完成就被取消,需要捕获 :class:`asyncio.CancelledError`。 + + :param callback: 协程(可有返回值) + :param timestamp: 在什么时刻调度 + :return: :class:`asyncio.Task` 对象 + """ + if timestamp <= time.time_ns() / 1e9: + return async_later(callback, 0) + + return async_later(callback, timestamp - time.time_ns() / 1e9) + + +def async_interval( + callback: Callable[[], Coroutine[Any, Any, None]], interval: float +) -> asyncio.Task[None]: + """异步函数间隔调度(类似 JavaScript 的 setInterval) + + 每过时间间隔执行 `callback` 一次。`callback` 是返回协程的可调用对象(异步函数或 lambda 函数等)。 + + 返回一个 :class:`asyncio.Task` 对象,可使用该 task 取消调度过程。 + + :param callback: 异步函数 + :param interval: 调度的间隔 + :return: :class:`asyncio.Task` 对象 + """ + + async def _interval_task() -> None: + try: + while True: + coro = callback() + await asyncio.sleep(interval) + await coro + except asyncio.CancelledError: + coro.close() + raise + + return asyncio.create_task(_interval_task()) diff --git a/src/melobot/utils/base.py b/src/melobot/utils/base.py new file mode 100644 index 00000000..2b1be388 --- /dev/null +++ b/src/melobot/utils/base.py @@ -0,0 +1,92 @@ +import asyncio +import inspect +from functools import wraps + +from typing_extensions import Any, Awaitable, Callable, Coroutine + +from ..exceptions import UtilValidateError +from ..typ.base import AsyncCallable, P, SyncOrAsyncCallable, T + + +def to_async( + obj: SyncOrAsyncCallable[P, T] | Awaitable[T] +) -> Callable[P, Coroutine[Any, Any, T]]: + """异步包装函数 + + 将一个可调用对象或可等待对象装饰为异步函数 + + :param obj: 需要转换的可调用对象或可等待对象 + :return: 异步函数 + """ + if inspect.iscoroutinefunction(obj): + return obj + + async def to_async_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + if not inspect.isawaitable(obj): + ret = obj(*args, **kwargs) + else: + ret = obj + if inspect.isawaitable(ret): + return await ret + return ret + + if not inspect.isawaitable(obj): + to_async_wrapped = wraps(obj)(to_async_wrapped) + return to_async_wrapped + + +def to_coro( + obj: SyncOrAsyncCallable[P, T] | Awaitable[T], *args: Any, **kwargs: Any +) -> Coroutine[Any, Any, T]: + """协程包装函数 + + 将一个可调用对象或可等待对象装饰为异步函数,并返回对应的协程 + + :param obj: 需要包装的可调用对象或可等待对象 + :param args: 需要使用的位置参数 + :param kwargs: 需要使用的关键字参数 + :return: 协程 + """ + if inspect.iscoroutine(obj): + return obj + return to_async(obj)(*args, **kwargs) # type: ignore[arg-type] + + +async def async_guard(func: AsyncCallable[..., T], *args: Any, **kwargs: Any) -> T: + """在使用异步可调用对象时,提供用户友好的验证""" + if not callable(func): + raise UtilValidateError( + f"{func} 不是异步可调用对象(返回 Awaitable 的可调用对象)" + ) + + await_obj = func(*args, **kwargs) + if inspect.isawaitable(await_obj): + return await await_obj + raise UtilValidateError( + f"{func} 应该是异步函数,或其他异步可调用对象(返回 Awaitable 的可调用对象)。但它返回了:{await_obj},因此可能是同步函数" + ) + + +def to_sync(obj: SyncOrAsyncCallable[P, Any] | Awaitable[Any]) -> Callable[P, None]: + """同步包装函数 + + 将一个可调用对象或可等待对象装饰为同步函数,但同步函数无法异步等待,包装后无法获取返回值 + + 因此仅用于接口兼容,如果提供了异步可调用对象,需要自行捕获内部可能的异常 + + :param obj: 需要转换的可调用对象或可等待对象 + :return: 同步函数 + """ + + def to_sync_wrapped(*args: P.args, **kwargs: P.kwargs) -> None: + if inspect.isawaitable(obj): + asyncio.create_task(to_coro(obj)) + return + + res = obj(*args, **kwargs) + if inspect.isawaitable(res): + asyncio.create_task(to_coro(res)) + + if not inspect.isawaitable(obj): + to_sync_wrapped = wraps(obj)(to_sync_wrapped) + return to_sync_wrapped diff --git a/src/melobot/utils/check/__init__.py b/src/melobot/utils/check/__init__.py new file mode 100644 index 00000000..579789f7 --- /dev/null +++ b/src/melobot/utils/check/__init__.py @@ -0,0 +1 @@ +from .base import Checker, WrappedChecker, checker_join diff --git a/src/melobot/utils/check/base.py b/src/melobot/utils/check/base.py new file mode 100644 index 00000000..77805e7b --- /dev/null +++ b/src/melobot/utils/check/base.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from abc import abstractmethod +from functools import partial + +from typing_extensions import Any, Callable, Coroutine, Generic + +from ...adapter.model import EventT +from ...exceptions import UtilValidateError +from ...typ._enum import LogicMode +from ...typ.base import SyncOrAsyncCallable +from ...typ.cls import BetterABC +from ..base import to_async + + +class Checker(Generic[EventT], BetterABC): + """检查器基类""" + + def __init__(self, fail_cb: SyncOrAsyncCallable[[], None] | None = None) -> None: + super().__init__() + self.fail_cb = to_async(fail_cb) if fail_cb is not None else None + + def __and__(self, other: Checker) -> WrappedChecker: + if not isinstance(other, Checker): + raise UtilValidateError( + f"联合检查器定义时出现了非检查器对象,其值为:{other}" + ) + return WrappedChecker(LogicMode.AND, self, other) + + def __or__(self, other: Checker) -> WrappedChecker: + if not isinstance(other, Checker): + raise UtilValidateError( + f"联合检查器定义时出现了非检查器对象,其值为:{other}" + ) + return WrappedChecker(LogicMode.OR, self, other) + + def __invert__(self) -> WrappedChecker: + return WrappedChecker(LogicMode.NOT, self) + + def __xor__(self, other: Checker) -> WrappedChecker: + if not isinstance(other, Checker): + raise UtilValidateError( + f"联合检查器定义时出现了非检查器对象,其值为:{other}" + ) + return WrappedChecker(LogicMode.XOR, self, other) + + @abstractmethod + async def check(self, event: EventT) -> bool: + """检查器检查方法 + + 任何检查器应该实现此抽象方法。 + + :param event: 给定的事件 + :return: 检查是否通过 + """ + raise NotImplementedError + + @staticmethod + def new(func: Callable[[EventT], bool]) -> Checker[EventT]: + return _CustomChecker[EventT](func) + + +class _CustomChecker(Checker[EventT]): + def __init__(self, func: Callable[[EventT], bool]) -> None: + super().__init__() + self.func = func + + async def check(self, event: EventT) -> bool: + return self.func(event) + + +class WrappedChecker(Checker[EventT]): + """合并检查器 + + 在两个 :class:`Checker` 对象间使用 | & ^ ~ 运算符即可返回合并检查器。 + """ + + def __init__( + self, + mode: LogicMode, + checker1: Checker, + checker2: Checker | None = None, + ) -> None: + """初始化一个合并检查器 + + :param mode: 合并检查的逻辑模式 + :param checker1: 检查器1 + :param checker2: 检查器2 + """ + super().__init__() + self.mode = mode + self.c1, self.c2 = checker1, checker2 + + def set_fail_cb(self, fail_cb: SyncOrAsyncCallable[[], None] | None) -> None: + self.fail_cb = to_async(fail_cb) if fail_cb is not None else None + + async def check(self, event: EventT) -> bool: + c2_check: Callable[[], Coroutine[Any, Any, bool]] | None = ( + partial(self.c2.check, event) if self.c2 is not None else None + ) + status = await LogicMode.async_short_calc( + self.mode, partial(self.c1.check, event), c2_check + ) + + if not status and self.fail_cb is not None: + await self.fail_cb() + return status + + +def checker_join(*checkers: Checker | None | Callable[[Any], bool]) -> Checker: + """合并检查器 + + 相比于使用 | & ^ ~ 运算符,此函数可以接受一个检查器序列,并返回一个合并检查器。 + 检查器序列可以为检查器对象,检查函数或空值 + + :return: 合并后的检查器对象 + """ + checker: Checker | None = None + for c in checkers: + if c is None: + continue + if isinstance(c, Checker): + checker = checker & c if checker else c + else: + checker = checker & Checker.new(c) if checker else Checker.new(c) + + if checker is None: + raise ValueError("检查器序列不能全为空") + return checker diff --git a/src/melobot/utils/common.py b/src/melobot/utils/common.py new file mode 100644 index 00000000..4a3826a8 --- /dev/null +++ b/src/melobot/utils/common.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import asyncio +import base64 +import importlib +import time +import warnings +from contextlib import asynccontextmanager +from datetime import datetime +from functools import wraps + +from typing_extensions import Any, AsyncGenerator, Callable, Literal, cast + +from ..typ.base import P, T + + +def get_obj_name( + obj: Any, + otype: Literal["callable", "class", "object"] | str = "object", + default: str = "", +) -> str: + """获取一个对象的限定名称或名称,这适用于一些类型较宽的参数。 + + 无法获取有效名称时,产生一个 `default % otype` 字符串 + + 例如某处接受一个 `Callable` 类型的参数,对于一般函数来说,使用 + `__qualname__` 或 `__name__` 可获得名称,但某些可调用对象这些值可能为 `None` + 或不存在。使用此方法可保证一定返回字符串 + + .. code:: python + + def _(a: Callable) -> None: + valid_str: str = get_obj_name(a, otype="callable") + + def _(a: type) -> None: + valid_str: str = get_obj_name(a, otype="class") + + def _(a: Any) -> None: + valid_str: str = get_obj_name(a, otype="type of a, only for str concat") + + + :param obj: 对象 + :param otype: 预期的对象类型 + :param default: 无法获取任何有效名称时的默认字符串 + :return: 对象名称或默认字符串 + """ + if hasattr(obj, "__qualname__"): + return cast(str, obj.__qualname__) + + if hasattr(obj, "__name__"): + return cast(str, obj.__name__) + + return default % otype + + +def singleton(cls: Callable[P, T]) -> Callable[P, T]: + """单例装饰器 + + :param cls: 需要被单例化的可调用对象 + :return: 需要被单例化的可调用对象 + """ + obj_map = {} + + @wraps(cls) + def singleton_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + if cls not in obj_map: + obj_map[cls] = cls(*args, **kwargs) + return obj_map[cls] + + return singleton_wrapped + + +def deprecate_warn(msg: str, stacklevel: int = 2) -> None: + # pylint: disable=cyclic-import + from ..ctx import LoggerCtx + + if logger := LoggerCtx().try_get(): + logger.warning(msg) + warnings.simplefilter("always", DeprecationWarning) + warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) + warnings.simplefilter("default", DeprecationWarning) + + +def deprecated(msg: str) -> Callable[[Callable[P, T]], Callable[P, T]]: + + def deprecated_wrapper(func: Callable[P, T]) -> Callable[P, T]: + + @wraps(func) + def deprecated_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + deprecate_warn( + f"使用了弃用函数/方法 {func.__module__}.{func.__qualname__}: {msg}", + stacklevel=3, + ) + return func(*args, **kwargs) + + return deprecated_wrapped + + return deprecated_wrapper + + +class DeprecatedLoader: + def __init__(self, mod_name: str, obj_pairs: dict[str, tuple[str, str, str]]) -> None: + self.__depre_mod_name__ = mod_name + self.__deprecations__ = obj_pairs + + def get(self, name: str) -> Any: + if name not in self.__deprecations__: + raise AttributeError( + f"module {self.__depre_mod_name__!r} has no attribute {name!r}" + ) + location, varname, ver = self.__deprecations__[name] + deprecate_warn( + f"{self.__depre_mod_name__}.{name} 现以弃用," + f"将于 {ver} 版本移除,使用 {location}.{varname} 代替", + stacklevel=4, + ) + return getattr(importlib.import_module(location), varname) + + @staticmethod + def merge(name: str, *loaders: DeprecatedLoader) -> DeprecatedLoader: + dic: dict[str, tuple[str, str, str]] = {} + for loader in loaders: + dic |= loader.__deprecations__ + return DeprecatedLoader(name, dic) + + +class RWContext: + """异步读写上下文 + + 提供异步安全的读写上下文。在读取时可以多读,同时读写互斥。 + + 使用方法: + + .. code:: python + + rwc = RWContext() + # 读时使用此控制器的安全读上下文: + async with rwc.read(): + ... + # 写时使用此控制器的安全写上下文: + async with rwc.write(): + ... + """ + + def __init__(self, read_limit: int | None = None) -> None: + """初始化异步读写上下文 + + :param read_limit: 读取的数量限制,为空则不限制 + """ + self.write_semaphore = asyncio.Semaphore(1) + self.read_semaphore = asyncio.Semaphore(read_limit) if read_limit else None + self.read_num = 0 + self.read_num_lock = asyncio.Lock() + + @asynccontextmanager + async def read(self) -> AsyncGenerator[None, None]: + """上下文管理器,展开一个关于该对象的安全异步读上下文""" + if self.read_semaphore: + await self.read_semaphore.acquire() + + async with self.read_num_lock: + if self.read_num == 0: + await self.write_semaphore.acquire() + self.read_num += 1 + + try: + yield + finally: + async with self.read_num_lock: + self.read_num -= 1 + if self.read_num == 0: + self.write_semaphore.release() + if self.read_semaphore: + self.read_semaphore.release() + + @asynccontextmanager + async def write(self) -> AsyncGenerator[None, None]: + """上下文管理器,展开一个关于该对象的安全异步写上下文""" + await self.write_semaphore.acquire() + try: + yield + finally: + self.write_semaphore.release() + + +class SnowFlakeIdWorker: + def __init__(self, datacenter_id: int, worker_id: int, sequence: int = 0) -> None: + self.max_worker_id = -1 ^ (-1 << 3) + self.max_datacenter_id = -1 ^ (-1 << 5) + self.worker_id_shift = 12 + self.datacenter_id_shift = 12 + 3 + self.timestamp_left_shift = 12 + 3 + 5 + self.sequence_mask = -1 ^ (-1 << 12) + self.startepoch = int(datetime(2022, 12, 11, 12, 8, 45).timestamp() * 1000) + + if worker_id > self.max_worker_id or worker_id < 0: + raise ValueError("worker_id 值越界") + if datacenter_id > self.max_datacenter_id or datacenter_id < 0: + raise ValueError("datacenter_id 值越界") + self.worker_id = worker_id + self.datacenter_id = datacenter_id + self.sequence = sequence + self.last_timestamp = -1 + + def _gen_timestamp(self) -> int: + return int(time.time() * 1000) + + def get_id(self) -> int: + timestamp = self._gen_timestamp() + + if timestamp < self.last_timestamp: + raise ValueError(f"时钟回拨,{self.last_timestamp} 前拒绝 id 生成请求") + if timestamp == self.last_timestamp: + self.sequence = (self.sequence + 1) & self.sequence_mask + if self.sequence == 0: + timestamp = self._until_next_millis(self.last_timestamp) + else: + self.sequence = 0 + self.last_timestamp = timestamp + new_id = ( + ((timestamp - self.startepoch) << self.timestamp_left_shift) + | (self.datacenter_id << self.datacenter_id_shift) + | (self.worker_id << self.worker_id_shift) + | self.sequence + ) + return new_id + + def get_b64_id(self, trim_pad: bool = True) -> str: + id = base64.urlsafe_b64encode( + self.get_id().to_bytes(8, byteorder="little") + ).decode() + if trim_pad: + id = id.rstrip("=") + return id + + def _until_next_millis(self, last_time: int) -> int: + timestamp = self._gen_timestamp() + while timestamp <= last_time: + timestamp = self._gen_timestamp() + return timestamp + + +_DEFAULT_ID_WORKER = SnowFlakeIdWorker(1, 1, 0) + + +def get_id() -> str: + """从 melobot 内部 id 获取器获得一个 id 值,不保证线程安全。算法使用雪花算法 + + :return: id 值 + """ + return _DEFAULT_ID_WORKER.get_b64_id() diff --git a/src/melobot/utils/deco.py b/src/melobot/utils/deco.py new file mode 100644 index 00000000..6ec729ae --- /dev/null +++ b/src/melobot/utils/deco.py @@ -0,0 +1,360 @@ +import asyncio +import inspect +import time +from functools import wraps + +from typing_extensions import Any, AsyncContextManager, Callable, ContextManager, cast + +from ..exceptions import UtilValidateError +from ..typ.base import AsyncCallable, P, SyncOrAsyncCallable, T, U, V +from .base import to_async + + +def if_not( + condition: SyncOrAsyncCallable[[], U] | U, + reject: SyncOrAsyncCallable[[], None], + give_up: bool = False, + accept: SyncOrAsyncCallable[[U], None] | None = None, +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | None]]: + """条件判断装饰器 + + :param condition: 用于判断的条件(如果是可调用对象,则先求值再转为 bool 值) + :param reject: 当条件为 `False` 时,执行的回调 + :param give_up: 在条件为 `False` 时,是否放弃执行被装饰函数 + :param accept: 当条件为 `True` 时,执行的回调 + """ + _condition = to_async(condition) if callable(condition) else condition + _reject = to_async(reject) + _accept = to_async(accept) if accept is not None else accept + + def if_not_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | None]: + _func = to_async(func) + + @wraps(func) + async def if_not_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | None: + if not callable(_condition): + cond = _condition + else: + obj = _condition() + cond = await obj if inspect.isawaitable(obj) else obj + + if not cond: + await _reject() + + if cond or not give_up: + if _accept is not None: + await _accept(cond) + + return await _func(*args, **kwargs) + return None + + return if_not_wrapped + + return if_not_wrapper + + +def unfold_ctx( + getter: SyncOrAsyncCallable[[], ContextManager | AsyncContextManager], +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T]]: + """上下文装饰器 + + 展开一个上下文,供被装饰函数使用。 + 但注意此装饰器不支持获取上下文管理器 `yield` 的值 + + :param getter: 上下文管理器或上下文管理器获取方法 + """ + + _getter = to_async(getter) + + def unfold_ctx_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T]: + _func = to_async(func) + + @wraps(func) + async def unfold_ctx_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + try: + manager = await _getter() + except Exception as e: + raise UtilValidateError( + f"{unfold_ctx.__name__} 的 getter 参数为:{getter},调用它获取上下文管理器失败:{e}" + ) from e + + if isinstance(manager, ContextManager): + with manager: + return await _func(*args, **kwargs) + elif isinstance(manager, AsyncContextManager): + async with manager: + return await _func(*args, **kwargs) + else: + raise UtilValidateError( + f"{unfold_ctx.__name__} 的 getter 参数为:{getter},调用它返回了无效的上下文管理器" + ) + + return unfold_ctx_wrapped + + return unfold_ctx_wrapper + + +def lock( + callback: SyncOrAsyncCallable[[], U] | None = None +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | U]]: + """锁装饰器 + + 本方法作为异步函数的装饰器使用,可以为被装饰函数加锁。 + + 在获取锁冲突时,调用 `callback` 获得一个回调并执行。回调执行完毕后直接返回。 + + `callback` 参数为空,只应用 :class:`asyncio.Lock` 的锁功能。 + + 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 + + :param callback: 获取锁冲突时的回调 + """ + alock = asyncio.Lock() + _callback = to_async(callback) if callback is not None else None + + def lock_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | U]: + _func = to_async(func) + + @wraps(func) + async def lock_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | U: + if _callback is not None and alock.locked(): + return await _callback() + async with alock: + return await _func(*args, **kwargs) + + return lock_wrapped + + return lock_wrapper + + +def cooldown( + busy_callback: SyncOrAsyncCallable[[], U] | None = None, + cd_callback: SyncOrAsyncCallable[[float], V] | None = None, + interval: float = 5, +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | U | V]]: + """冷却装饰器 + + 本方法作为异步函数的装饰器使用,可以为被装饰函数添加 cd 时间。 + + 如果被装饰函数已有一个在运行,此时调用 `busy_callback` 生成回调并执行。回调执行完毕后直接返回。 + + `busy_callback` 参数为空,则等待已运行的运行完成。随后执行下面的“冷却”处理逻辑。 + + 当被装饰函数没有在运行的,但冷却时间未结束: + - `cd_callback` 不为空:使用 `cd_callback` 生成回调并执行。 + - `cd_callback` 为空,被装饰函数持续等待,直至冷却结束再执行。 + + 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 + + :param busy_callback: 已运行时的回调 + :param cd_callback: 冷却时间未结束的回调 + :param interval: 冷却时间 + """ + alock = asyncio.Lock() + pre_finish_t = time.perf_counter() - interval - 1 + + _busy_callback = to_async(busy_callback) if busy_callback is not None else None + _cd_callback = to_async(cd_callback) if cd_callback is not None else None + + def cooldown_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | U | V]: + _func = to_async(func) + + @wraps(func) + async def cooldown_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | U | V: + nonlocal pre_finish_t + if _busy_callback is not None and alock.locked(): + return await _busy_callback() + + async with alock: + duration = time.perf_counter() - pre_finish_t + if duration > interval: + ret = await _func(*args, **kwargs) + pre_finish_t = time.perf_counter() + return ret + + remain_t = interval - duration + if _cd_callback is not None: + return await _cd_callback(remain_t) + + await asyncio.sleep(remain_t) + ret = await _func(*args, **kwargs) + pre_finish_t = time.perf_counter() + return ret + + return cooldown_wrapped + + return cooldown_wrapper + + +def semaphore( + callback: SyncOrAsyncCallable[[], U] | None = None, value: int = -1 +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | U]]: + """信号量装饰器 + + 本方法作为异步函数的装饰器使用,可以为被装饰函数添加信号量控制。 + + 在信号量无法立刻获取时,将调用 `callback` 获得回调并执行。回调执行完毕后直接返回。 + + `callback` 参数为空,只应用 :class:`asyncio.Semaphore` 的信号量功能。 + + 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 + + :param callback: 信号量无法立即获取的回调 + :param value: 信号量阈值 + """ + a_semaphore = asyncio.Semaphore(value) + _callback = to_async(callback) if callback is not None else None + + def semaphore_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | U]: + _func = to_async(func) + + @wraps(func) + async def semaphore_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | U: + if _callback is not None and a_semaphore.locked(): + return await _callback() + async with a_semaphore: + return await _func(*args, **kwargs) + + return semaphore_wrapped + + return semaphore_wrapper + + +def timelimit( + callback: SyncOrAsyncCallable[[], U] | None = None, timeout: float = 5 +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | U]]: + """时间限制装饰器 + + 本方法作为异步函数的装饰器使用,可以为被装饰函数添加超时控制。 + + 超时之后,调用 `callback` 获得回调并执行,同时取消原任务。 + + `callback` 参数为空,如果超时,则抛出 :class:`asyncio.TimeoutError` 异常。 + + 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值 + + :param callback: 超时时的回调 + :param timeout: 超时时间 + """ + _callback = to_async(callback) if callback is not None else None + + def timelimit_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | U]: + _func = to_async(func) + + @wraps(func) + async def timelimit_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | U: + try: + return await asyncio.wait_for(_func(*args, **kwargs), timeout) + except asyncio.TimeoutError: + if _callback is None: + raise TimeoutError("timelimit 所装饰的任务已超时") from None + return await _callback() + + return timelimit_wrapped + + return timelimit_wrapper + + +def speedlimit( + callback: SyncOrAsyncCallable[[], U] | None = None, + limit: int = 60, + duration: int = 60, +) -> Callable[[SyncOrAsyncCallable[P, T]], AsyncCallable[P, T | U]]: + """流量/速率限制装饰器(使用固定窗口算法) + + 本方法作为异步函数的装饰器使用,可以为被装饰函数添加流量控制:`duration` 秒内只允许 `limit` 次调用。 + + 超出调用速率限制后,调用 `callback` 获得回调并执行,同时取消原任务。 + + `callback` 参数为空,等待直至满足速率控制要求再调用。 + + 被装饰函数的返回值:被装饰函数被执行 -> 被装饰函数返回值;执行任何回调 -> 那个回调的返回值。 + + :param callback: 超出速率限制时的回调 + :param limit: `duration` 秒内允许调用多少次 + :param duration: 时长区间 + """ + called_num = 0 + min_start = time.perf_counter() + if limit <= 0: + raise UtilValidateError("speedlimit 装饰器的 limit 参数必须 > 0") + if duration <= 0: + raise UtilValidateError("speedlimit 装饰器的 duration 参数必须 > 0") + + _callback = to_async(callback) if callback is not None else None + + def speedlimit_wrapper(func: SyncOrAsyncCallable[P, T]) -> AsyncCallable[P, T | U]: + _func = to_async(func) + + @wraps(func) + async def speedlimit_wrapped(*args: P.args, **kwargs: P.kwargs) -> T | U: + fut = _speedlimit_wrapped(_func, *args, **kwargs) + fut = cast(asyncio.Future[T | U | Exception], fut) + fut_ret = await fut + if isinstance(fut_ret, Exception): + raise fut_ret + return fut_ret + + return speedlimit_wrapped + + def _speedlimit_wrapped( + func: AsyncCallable[P, T], + *args: P.args, + **kwargs: P.kwargs, + ) -> asyncio.Future: + # 分离出来定义,方便 result_set 调用形成递归。主要逻辑通过 Future 实现,有利于避免竞争问题。 + nonlocal called_num, min_start + passed_time = time.perf_counter() - min_start + res_fut: Any = asyncio.get_running_loop().create_future() + + if passed_time <= duration: + if called_num < limit: + called_num += 1 + asyncio.create_task( + _speedlimit_set_result(func, res_fut, -1, *args, **kwargs) + ) + + elif _callback is not None: + asyncio.create_task(_speedlimit_set_result(_callback, res_fut, -1)) + + else: + asyncio.create_task( + _speedlimit_set_result( + func, res_fut, duration - passed_time, *args, **kwargs + ) + ) + else: + called_num, min_start = 0, time.perf_counter() + called_num += 1 + asyncio.create_task( + _speedlimit_set_result(func, res_fut, -1, *args, **kwargs) + ) + + return cast(asyncio.Future, res_fut) + + async def _speedlimit_set_result( + func: AsyncCallable[P, T], + fut: asyncio.Future, + delay: float, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """ + 只有依然在当前 duration 区间内,但超出调用次数限制的,需要等待。 + 随后就是递归调用。delay > 0 为需要递归的分支。 + """ + nonlocal called_num + try: + if delay > 0: + await asyncio.sleep(delay) + res = await _speedlimit_wrapped(func, *args, **kwargs) + fut.set_result(res) + return + + res = await func(*args, **kwargs) + fut.set_result(res) + + except Exception as e: + fut.set_result(e) + + return speedlimit_wrapper diff --git a/src/melobot/utils/match/__init__.py b/src/melobot/utils/match/__init__.py new file mode 100644 index 00000000..51fd3a33 --- /dev/null +++ b/src/melobot/utils/match/__init__.py @@ -0,0 +1,9 @@ +from .base import ( + ContainMatcher, + EndMatcher, + FullMatcher, + Matcher, + RegexMatcher, + StartMatcher, + WrappedMatcher, +) diff --git a/src/melobot/protocols/onebot/v11/utils/match.py b/src/melobot/utils/match/base.py similarity index 61% rename from src/melobot/protocols/onebot/v11/utils/match.py rename to src/melobot/utils/match/base.py index 15f6d3c8..5bc75c21 100644 --- a/src/melobot/protocols/onebot/v11/utils/match.py +++ b/src/melobot/utils/match/base.py @@ -1,10 +1,86 @@ +from __future__ import annotations + import re +from abc import abstractmethod +from functools import partial + +from typing_extensions import Any, Callable, Coroutine + +from melobot.exceptions import UtilValidateError +from melobot.typ import BetterABC, LogicMode + + +class Matcher(BetterABC): + """匹配器基类""" -from typing_extensions import Any + def __init__(self) -> None: + super().__init__() + + def __and__(self, other: Matcher) -> WrappedMatcher: + if not isinstance(other, Matcher): + raise UtilValidateError( + f"联合匹配器定义时出现了非匹配器对象,其值为:{other}" + ) + return WrappedMatcher(LogicMode.AND, self, other) + + def __or__(self, other: Matcher) -> WrappedMatcher: + if not isinstance(other, Matcher): + raise UtilValidateError( + f"联合匹配器定义时出现了非匹配器对象,其值为:{other}" + ) + return WrappedMatcher(LogicMode.OR, self, other) + + def __invert__(self) -> WrappedMatcher: + return WrappedMatcher(LogicMode.NOT, self) + + def __xor__(self, other: Matcher) -> WrappedMatcher: + if not isinstance(other, Matcher): + raise UtilValidateError( + f"联合匹配器定义时出现了非匹配器对象,其值为:{other}" + ) + return WrappedMatcher(LogicMode.XOR, self, other) + + @abstractmethod + async def match(self, text: str) -> bool: + """匹配器匹配方法 -from melobot.typ import LogicMode + 任何匹配器应该实现此抽象方法。 + + :param text: 消息事件的文本内容 + :return: 是否匹配 + """ + raise NotImplementedError -from .abc import Matcher + +class WrappedMatcher(Matcher): + """合并匹配器 + + 在两个 :class:`Matcher` 对象间使用 | & ^ ~ 运算符即可返回合并匹配器 + """ + + def __init__( + self, + mode: LogicMode, + matcher1: Matcher, + matcher2: Matcher | None = None, + ) -> None: + """初始化一个合并匹配器 + + :param mode: 合并匹配的逻辑模式 + :param matcher1: 匹配器1 + :param matcher2: 匹配器2 + """ + super().__init__() + self.mode = mode + self.m1, self.m2 = matcher1, matcher2 + + async def match(self, text: str) -> bool: + m2_match: Callable[[], Coroutine[Any, Any, bool]] | None = ( + partial(self.m2.match, text) if self.m2 is not None else None + ) + return await LogicMode.async_short_calc( + self.mode, partial(self.m1.match, text), m2_match + ) class StartMatcher(Matcher): diff --git a/src/melobot/utils/parse/__init__.py b/src/melobot/utils/parse/__init__.py new file mode 100644 index 00000000..a7d0a3ab --- /dev/null +++ b/src/melobot/utils/parse/__init__.py @@ -0,0 +1,2 @@ +from .base import AbstractParseArgs, Parser +from .cmd import CmdArgFormatInfo, CmdArgFormatter, CmdArgs, CmdParser, CmdParserFactory diff --git a/src/melobot/utils/parse/base.py b/src/melobot/utils/parse/base.py new file mode 100644 index 00000000..fe6292be --- /dev/null +++ b/src/melobot/utils/parse/base.py @@ -0,0 +1,40 @@ +from abc import abstractmethod + +from typing_extensions import Any + +from ...typ.cls import BetterABC, abstractattr + + +class AbstractParseArgs: + """解析参数抽象类 + + 子类需要把以下属性按 :func:`.abstractattr` 的要求实现 + """ + + vals: Any = abstractattr() + """解析值 + + :meta hide-value: + """ + + +class Parser(BetterABC): + """解析器基类 + + 解析器一般用作从消息文本中按规则批量提取参数 + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + async def parse(self, text: str) -> AbstractParseArgs | None: + """解析方法 + + 任何解析器应该实现此抽象方法 + + :param text: 消息文本内容 + :return: 解析结果,为空代表没有有效的解析参数 + + """ + raise NotImplementedError diff --git a/src/melobot/protocols/onebot/v11/utils/parse.py b/src/melobot/utils/parse/cmd.py similarity index 83% rename from src/melobot/protocols/onebot/v11/utils/parse.py rename to src/melobot/utils/parse/cmd.py index 86f84075..c66b34c2 100644 --- a/src/melobot/protocols/onebot/v11/utils/parse.py +++ b/src/melobot/utils/parse/cmd.py @@ -1,20 +1,22 @@ import re +from dataclasses import dataclass from functools import lru_cache from types import TracebackType from typing_extensions import Any, Callable, Iterator, Optional -from melobot.exceptions import BotException +from melobot.exceptions import UtilError from melobot.log import get_logger -from melobot.typ import AsyncCallable, VoidType +from melobot.typ import SyncOrAsyncCallable, VoidType +from melobot.utils import to_async -from .abc import ParseArgs, Parser +from .base import AbstractParseArgs, Parser -class ParseError(BotException): ... +class CmdParseError(UtilError): ... -class FormatError(BotException): ... +class FormatError(CmdParseError): ... class ArgValidateFailed(FormatError): ... @@ -23,7 +25,14 @@ class ArgValidateFailed(FormatError): ... class ArgLackError(FormatError): ... -class FormatInfo: +@dataclass +class CmdArgs(AbstractParseArgs): + name: str + tag: str | None + vals: list[Any] + + +class CmdArgFormatInfo: """命令参数格式化信息对象 用于在命令参数格式化异常时传递信息。 @@ -69,9 +78,9 @@ def __init__( src_expect: Optional[str] = None, default: Any = VoidType.VOID, default_replace_flag: Optional[str] = None, - convert_fail: Optional[AsyncCallable[[FormatInfo], None]] = None, - validate_fail: Optional[AsyncCallable[[FormatInfo], None]] = None, - arg_lack: Optional[AsyncCallable[[FormatInfo], None]] = None, + convert_fail: Optional[SyncOrAsyncCallable[[CmdArgFormatInfo], None]] = None, + validate_fail: Optional[SyncOrAsyncCallable[[CmdArgFormatInfo], None]] = None, + arg_lack: Optional[SyncOrAsyncCallable[[CmdArgFormatInfo], None]] = None, ) -> None: """初始化一个命令参数格式化器 @@ -93,15 +102,17 @@ def __init__( self.default = default self.default_replace_flag = default_replace_flag if self.default is VoidType.VOID and self.default_replace_flag is not None: - raise ParseError( + raise CmdParseError( "初始化参数格式化器时,使用“默认值替换标记”必须同时设置默认值" ) - self.convert_fail = convert_fail - self.validate_fail = validate_fail - self.arg_lack = arg_lack + self.convert_fail = to_async(convert_fail) if convert_fail is not None else None + self.validate_fail = ( + to_async(validate_fail) if validate_fail is not None else None + ) + self.arg_lack = to_async(arg_lack) if arg_lack is not None else None - def _get_val(self, args: ParseArgs, idx: int) -> Any: + def _get_val(self, args: CmdArgs, idx: int) -> Any: if self.default is VoidType.VOID: if len(args.vals) < idx + 1: raise ArgLackError @@ -112,12 +123,7 @@ def _get_val(self, args: ParseArgs, idx: int) -> Any: return args.vals[idx] - async def format( - self, - group_id: str, - args: ParseArgs, - idx: int, - ) -> bool: + async def format(self, group_id: str, args: CmdArgs, idx: int) -> bool: # 格式化参数为对应类型的变量 try: src = self._get_val(args, idx) @@ -133,7 +139,7 @@ async def format( return True except ArgValidateFailed as e: - info = FormatInfo( + info = CmdArgFormatInfo( src, self.src_desc, self.src_expect, idx, e, e.__traceback__, group_id ) if self.validate_fail: @@ -143,7 +149,7 @@ async def format( return False except ArgLackError as e: - info = FormatInfo( + info = CmdArgFormatInfo( VoidType.VOID, self.src_desc, self.src_expect, @@ -159,7 +165,7 @@ async def format( return False except Exception as e: - info = FormatInfo( + info = CmdArgFormatInfo( src, self.src_desc, self.src_expect, idx, e, e.__traceback__, group_id ) if self.convert_fail: @@ -168,7 +174,7 @@ async def format( await self._convert_fail_default(info) return False - async def _convert_fail_default(self, info: FormatInfo) -> None: + async def _convert_fail_default(self, info: CmdArgFormatInfo) -> None: e_class = f"{info.exc.__class__.__module__}.{info.exc.__class__.__qualname__}" src = repr(info.src) if isinstance(info.src, str) else info.src @@ -184,7 +190,7 @@ async def _convert_fail_default(self, info: FormatInfo) -> None: tip = f"命令 {info.name} 参数格式化失败:\n{tip}" get_logger().warning(tip) - async def _validate_fail_default(self, info: FormatInfo) -> None: + async def _validate_fail_default(self, info: CmdArgFormatInfo) -> None: src = repr(info.src) if isinstance(info.src, str) else info.src tip = f"第 {info.idx + 1} 个参数" @@ -198,7 +204,7 @@ async def _validate_fail_default(self, info: FormatInfo) -> None: tip = f"命令 {info.name} 参数格式化失败:\n{tip}" get_logger().warning(tip) - async def _arglack_default(self, info: FormatInfo) -> None: + async def _arglack_default(self, info: CmdArgFormatInfo) -> None: tip = f"第 {info.idx + 1} 个参数" tip += f"({info.src_desc})缺失。" if info.src_desc else "缺失。" tip += f"参数要求:{info.src_expect}。" if info.src_expect else "" @@ -248,6 +254,7 @@ def __init__( cmd_sep: str | list[str], targets: str | list[str], fmtters: Optional[list[Optional[CmdArgFormatter]]] = None, + tag: str | None = None, ) -> None: """初始化一个命令解析器 @@ -261,9 +268,11 @@ def __init__( :param cmd_sep: 命令间隔符(可以是字符串或字符串列表) :param targets: 匹配的命令名 :param formatters: 格式化器列表(列表可以包含空值,即此位置的参数无格式化) + :param tag: 标签,此标签将被填充给本解析器产生的 :class:`.CmdArgs` 对象的 `tag` 属性 """ super().__init__() self.targets = targets if isinstance(targets, list) else [targets] + assert len(self.targets) >= 1, "命令解析器至少需要一个目标命令名" self.fmtters = fmtters self.start_tokens = cmd_start if isinstance(cmd_start, list) else [cmd_start] @@ -274,9 +283,10 @@ def __init__( self.cmd_sep: list[str] self.start_regex: re.Pattern[str] self.sep_regex: re.Pattern[str] + self.arg_tag = tag if tag is not None else self.targets[0] if self.ban_regex.findall(f"{''.join(cmd_start)}{''.join(cmd_sep)}"): - raise ParseError("存在命令解析器不支持的命令起始符,或命令间隔符") + raise CmdParseError("存在命令解析器不支持的命令起始符,或命令间隔符") _regex = re.compile( r"([\`\-\=\~\!\@\#\$\%\^\&\*\(\)\_\+\[\]\{\}\|\:\,\.\/\<\>\?])" @@ -288,12 +298,13 @@ def __init__( self.sep_regex = re.compile(rf"{'|'.join(self.cmd_sep)}") self.start_regex = re.compile(rf"{'|'.join(self.cmd_start)}") else: - raise ParseError("命令解析器起始符不能和间隔符重合") + raise CmdParseError("命令解析器起始符不能和间隔符重合") - async def parse(self, text: str) -> Optional[ParseArgs]: + async def parse(self, text: str) -> Optional[CmdArgs]: cmd_dict = _cmd_parse(text, self.start_regex, self.sep_regex) args_dict = { - cmd_name: ParseArgs(cmd_name, vals) for cmd_name, vals in cmd_dict.items() + cmd_name: CmdArgs(cmd_name, self.arg_tag, vals) + for cmd_name, vals in cmd_dict.items() } for group_id in self.targets: @@ -342,10 +353,12 @@ def get( self, targets: str | list[str], formatters: Optional[list[Optional[CmdArgFormatter]]] = None, + tag: str | None = None, ) -> CmdParser: """生成匹配指定命令名的命令解析器 :param targets: 匹配的命令名 :param formatters: 格式化器列表(列表可以包含空值,即此位置的参数无格式化选项) + :param tag: 标签,此标签将被填充给解析器产生的 :class:`.CmdArgs` 对象的 `tag` 属性 """ - return CmdParser(self.cmd_start, self.cmd_sep, targets, formatters) + return CmdParser(self.cmd_start, self.cmd_sep, targets, formatters, tag) diff --git a/tests/onebot/v11/test_adapter_base.py b/tests/onebot/v11/test_adapter_base.py index 46cb251e..7d44ec71 100644 --- a/tests/onebot/v11/test_adapter_base.py +++ b/tests/onebot/v11/test_adapter_base.py @@ -8,7 +8,7 @@ from melobot.plugin import PluginPlanner from melobot.protocols.onebot.v11.adapter.base import Adapter from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11.io.base import BaseIO +from melobot.protocols.onebot.v11.io.base import BaseIOSource from melobot.protocols.onebot.v11.io.packet import EchoPacket, InPacket, OutPacket from tests.base import * @@ -41,7 +41,7 @@ _SUCCESS_SIGNAL = asyncio.Event() -class TempIO(BaseIO): +class TempIO(BaseIOSource): def __init__(self) -> None: super().__init__(1) self.queue = Queue() diff --git a/tests/onebot/v11/test_adapter_echo.py b/tests/onebot/v11/test_adapter_echo.py index 29d66912..992f0966 100644 --- a/tests/onebot/v11/test_adapter_echo.py +++ b/tests/onebot/v11/test_adapter_echo.py @@ -1,4 +1,5 @@ from melobot.protocols.onebot.v11.adapter import echo +from melobot.protocols.onebot.v11.const import ACTION_TYPE_KEY_NAME from tests.base import * @@ -10,24 +11,32 @@ def ec(**kv_pairs): return head() | {"data": kv_pairs if len(kv_pairs) else None} +def fake_ec(**kv_pairs): + return ec(**kv_pairs) | {"action_type": ""} + + def li_ec(lis): return head() | {"data": lis} +def fake_li_ec(lis): + return li_ec(lis) | {"action_type": ""} + + async def test_empty(): echo.EmptyEcho(**ec(), action_type="") async def test_other(): assert isinstance( - echo.Echo.resolve({**ec(message_id=123), "action_type": "send_msg"}), + echo.Echo.resolve({**ec(message_id=123), ACTION_TYPE_KEY_NAME: "send_msg"}), echo.SendMsgEcho, ) assert isinstance( echo.Echo.resolve( { **ec(message_id=123, forward_id="abc"), - "action_type": "send_private_forward_msg", + ACTION_TYPE_KEY_NAME: "send_private_forward_msg", } ), echo.SendForwardMsgEcho, @@ -43,18 +52,23 @@ async def test_other(): sender={"user_id": 789, "nickname": "melody", "sex": "male", "age": 18}, message="123[45[CQ:node,user_id=10001000,nickname=某人,content=[CQ:face,id=123]哈喽~]12345", ), - "action_type": "get_msg", + ACTION_TYPE_KEY_NAME: "get_msg", } ).data["message"][0].data == {"text": "123[45"} assert not ( echo.GetMsgEcho( - **ec( + **fake_ec( time=123, message_type="private", message_id=123, real_id=456, - sender={"user_id": 789, "nickname": "melody", "sex": "male", "age": 18}, + sender={ + "user_id": 789, + "nickname": "melody", + "sex": "male", + "age": 18, + }, message=[ {"type": "text", "data": {"text": "12345"}}, { @@ -69,7 +83,7 @@ async def test_other(): }, }, ], - ) + ), ) .data["sender"] .is_group_admin() @@ -77,7 +91,7 @@ async def test_other(): assert ( echo.GetForwardMsgEcho( - **ec( + **fake_ec( message=[ { "type": "node", @@ -99,10 +113,12 @@ async def test_other(): == 123 ) - echo.GetLoginInfoEcho(**ec(user_id=123, nickname="melody")) - echo.GetStrangerInfoEcho(**ec(user_id=123, nickname="melody", sex="male", age=18)) + echo.GetLoginInfoEcho(**fake_ec(user_id=123, nickname="melody")) + echo.GetStrangerInfoEcho( + **fake_ec(user_id=123, nickname="melody", sex="male", age=18) + ) echo.GetFriendListEcho( - **li_ec( + **fake_li_ec( [ {"user_id": 123, "nickname": "melody", "remark": "123"}, {"user_id": 456, "nickname": "jack", "remark": "123456"}, @@ -110,10 +126,10 @@ async def test_other(): ) ) echo.GetGroupInfoEcho( - **ec(group_id=123, group_name="test", member_count=12, max_member_count=100) + **fake_ec(group_id=123, group_name="test", member_count=12, max_member_count=100) ) echo.GetGroupListEcho( - **li_ec( + **fake_li_ec( [ { "group_id": 123, @@ -148,8 +164,8 @@ async def test_other(): "title_expire_time": 1234567890, "card_changeable": True, } - echo.GetGroupMemberInfoEcho(**ec(**base_info)) - echo.GetGroupMemberListEcho(**li_ec([base_info, base_info])) + echo.GetGroupMemberInfoEcho(**fake_ec(**base_info)) + echo.GetGroupMemberListEcho(**fake_li_ec([base_info, base_info])) base_info = { "user_id": 123, @@ -158,7 +174,7 @@ async def test_other(): "description": "123", } echo.GetGroupHonorInfoEcho( - **ec( + **fake_ec( group_id=123, current_talkative={ "user_id": 123, @@ -174,17 +190,17 @@ async def test_other(): ) ) - echo.GetCookiesEcho(**ec(cookies="abc123")) - echo.GetCsrfTokenEcho(**ec(token=123)) - echo.GetCredentialsEcho(**ec(csrf_token=123, cookies="abc123")) - echo.GetRecordEcho(**ec(file="123")) - echo.GetImageEcho(**ec(file="123")) - echo.CanSendRecordEcho(**ec(yes=True)) - echo.CanSendImageEcho(**ec(yes=False)) - assert echo.GetStatusEcho(**ec(online=True, good=True, nice=True)).data["nice"] + echo.GetCookiesEcho(**fake_ec(cookies="abc123")) + echo.GetCsrfTokenEcho(**fake_ec(token=123)) + echo.GetCredentialsEcho(**fake_ec(csrf_token=123, cookies="abc123")) + echo.GetRecordEcho(**fake_ec(file="123")) + echo.GetImageEcho(**fake_ec(file="123")) + echo.CanSendRecordEcho(**fake_ec(yes=True)) + echo.CanSendImageEcho(**fake_ec(yes=False)) + assert echo.GetStatusEcho(**fake_ec(online=True, good=True, nice=True)).data["nice"] assert ( echo.GetVersionInfoEcho( - **ec( + **fake_ec( app_name="lgr", app_version="1.0.0", platform="linux", diff --git a/tests/onebot/v11/test_handle.py b/tests/onebot/v11/test_handle.py index bea62a12..a1e7922b 100644 --- a/tests/onebot/v11/test_handle.py +++ b/tests/onebot/v11/test_handle.py @@ -2,19 +2,15 @@ from asyncio import Queue, create_task from melobot.bot import Bot +from melobot.handle import GetParseArgs, on_start_match from melobot.log import GenericLogger from melobot.plugin import PluginPlanner -from melobot.protocols.onebot.v11 import handle from melobot.protocols.onebot.v11.adapter.base import Adapter from melobot.protocols.onebot.v11.adapter.event import MessageEvent -from melobot.protocols.onebot.v11.io.base import BaseIO +from melobot.protocols.onebot.v11.io.base import BaseIOSource from melobot.protocols.onebot.v11.io.packet import EchoPacket, InPacket, OutPacket -from melobot.protocols.onebot.v11.utils import ( - CmdParser, - GroupMsgChecker, - LevelRole, - ParseArgs, -) +from melobot.protocols.onebot.v11.utils import GroupMsgChecker, LevelRole +from melobot.utils.parse import CmdArgs, CmdParser from tests.base import * _GRUOP_EVENT_DICT = { @@ -44,7 +40,7 @@ } -h = handle.on_start_match( +h = on_start_match( ["123", "456"], checker=GroupMsgChecker( role=LevelRole.WHITE, @@ -59,18 +55,18 @@ @h -async def test_this( +async def _flow( bot: Bot, event: MessageEvent, logger: GenericLogger, - args: ParseArgs = handle.GetParseArgs(), + args: CmdArgs = GetParseArgs(), ) -> None: logger.info(args) await bot.close() _SUCCESS_SIGNAL.set() -class TempIO(BaseIO): +class TempIO(BaseIOSource): def __init__(self) -> None: super().__init__(1) self.queue = Queue() @@ -100,7 +96,7 @@ async def test_adapter_base(): mbot = Bot("test_handle") mbot.add_io(TempIO()) mbot.add_adapter(Adapter()) - mbot.load_plugin(PluginPlanner("1.0.0", flows=[test_this])) + mbot.load_plugin(PluginPlanner("1.0.0", flows=[_flow])) create_task(mbot.core_run()) await mbot._rip_signal.wait() await _SUCCESS_SIGNAL.wait() diff --git a/tests/test_handle_process.py b/tests/test_handle_process.py index dde665d4..a2d517d0 100644 --- a/tests/test_handle_process.py +++ b/tests/test_handle_process.py @@ -1,4 +1,4 @@ -from melobot.handle.process import Flow, FlowNode +from melobot.handle.base import Flow, FlowNode from tests.base import * diff --git a/tests/test_utils.py b/tests/test_utils.py index 980ccbf2..e12aaba0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,9 +3,12 @@ # @Time : 2024/08/26 20:53:04 # @Author : Kariko Lin +import asyncio from enum import Enum from random import choice, randint +from typing_extensions import Any, Coroutine + from melobot.utils import * from tests.base import * @@ -82,7 +85,7 @@ async def test_rwcontrol(cls) -> None: idx = choice(range(0, len(r_seq) + len(rw_seq) - 1)) seq = rw_seq[:idx] + r_seq + rw_seq[idx:] - await aio.wait(map(lambda c: aio.create_task(c), seq)) + await aio.wait(map(aio.create_task, seq)) assert cls.ASYNC_READED diff --git a/tests/onebot/v11/test_utils_match.py b/tests/test_utils_match.py similarity index 95% rename from tests/onebot/v11/test_utils_match.py rename to tests/test_utils_match.py index 037a9703..b6a0766a 100644 --- a/tests/onebot/v11/test_utils_match.py +++ b/tests/test_utils_match.py @@ -1,11 +1,11 @@ -from melobot.protocols.onebot.v11.utils import ( +from melobot.typ import LogicMode +from melobot.utils.match import ( ContainMatcher, EndMatcher, FullMatcher, RegexMatcher, StartMatcher, ) -from melobot.typ import LogicMode from tests.base import * diff --git a/tests/onebot/v11/test_utils_parse.py b/tests/test_utils_parse.py similarity index 94% rename from tests/onebot/v11/test_utils_parse.py rename to tests/test_utils_parse.py index 06c8ebf6..231830cd 100644 --- a/tests/onebot/v11/test_utils_parse.py +++ b/tests/test_utils_parse.py @@ -2,8 +2,8 @@ from melobot.ctx import LoggerCtx from melobot.log import Logger -from melobot.protocols.onebot.v11.utils import CmdArgFormatter as Fmtter -from melobot.protocols.onebot.v11.utils import CmdParserFactory +from melobot.utils.parse import CmdArgFormatter as Fmtter +from melobot.utils.parse import CmdParserFactory from tests.base import * _OUT_BUF = Queue()