From 2c7aad607fa314678844088db90c3379b7d86bf9 Mon Sep 17 00:00:00 2001 From: aicorein Date: Tue, 26 Nov 2024 20:04:21 +0800 Subject: [PATCH] [melobot] Fix session state error after suspend got timeout --- src/melobot/session/base.py | 165 ++++++++++++++++++++---------------- 1 file changed, 91 insertions(+), 74 deletions(-) diff --git a/src/melobot/session/base.py b/src/melobot/session/base.py index 50a60c3..d168731 100644 --- a/src/melobot/session/base.py +++ b/src/melobot/session/base.py @@ -3,7 +3,7 @@ import asyncio from asyncio import Condition, Lock from contextlib import _AsyncGeneratorContextManager, asynccontextmanager -from typing import Any, AsyncGenerator, cast +from typing import Any, AsyncGenerator from ..adapter.model import Event from ..ctx import FlowCtx, SessionCtx @@ -15,18 +15,16 @@ _SESSION_CTX = SessionCtx() -class SessionStateError(BotException): - def __init__( - self, - cur_state: type[SessionState], - meth: str | None = None, - text: str | None = None, - ) -> None: - if text is not None: - super().__init__(f"{text}(当前会话状态:{cur_state.__name__})") - return +class SessionError(BotException): ... - super().__init__(f"当前会话状态 {cur_state.__name__} 不支持的操作:{meth}") + +class SessionStateFailed(SessionError): + def __init__(self, cur_state: str, meth: str) -> None: + self.cur_state = cur_state + super().__init__(f"当前会话状态 {cur_state} 不支持的操作:{meth}") + + +class SessionRuleLacked(SessionError): ... class SessionState: @@ -34,19 +32,19 @@ def __init__(self, session: "Session") -> None: self.session = session async def work(self, event: Event) -> None: - raise SessionStateError(self.__class__, meth=SessionState.work.__name__) + raise SessionStateFailed(self.__class__.__name__, SessionState.work.__name__) async def rest(self) -> None: - raise SessionStateError(self.__class__, meth=SessionState.rest.__name__) + raise SessionStateFailed(self.__class__.__name__, SessionState.rest.__name__) async def suspend(self, timeout: float | None) -> bool: - raise SessionStateError(self.__class__, meth=SessionState.suspend.__name__) + raise SessionStateFailed(self.__class__.__name__, SessionState.suspend.__name__) async def wakeup(self, event: Event) -> None: - raise SessionStateError(self.__class__, meth=SessionState.wakeup.__name__) + raise SessionStateFailed(self.__class__.__name__, SessionState.wakeup.__name__) async def expire(self) -> None: - raise SessionStateError(self.__class__, meth=SessionState.expire.__name__) + raise SessionStateFailed(self.__class__.__name__, SessionState.expire.__name__) class SpareSessionState(SessionState): @@ -58,42 +56,44 @@ async def work(self, event: Event) -> None: class WorkingSessionState(SessionState): async def rest(self) -> None: if self.session.rule is None: - raise SessionStateError( - WorkingSessionState, text="缺少会话规则,会话无法从“运行态”转为“空闲态”" - ) + raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“空闲态”") cond = self.session.refresh_cond + self.session.__to_state__(SpareSessionState) async with cond: cond.notify() - self.session.__to_state__(SpareSessionState) async def suspend(self, timeout: float | None) -> bool: if self.session.rule is None: - raise SessionStateError( - WorkingSessionState, text="缺少会话规则,会话无法从“运行态”转为“挂起态”" - ) + raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“挂起态”") cond = self.session.refresh_cond + self.session.__to_state__(SuspendSessionState) async with cond: cond.notify() - self.session.__to_state__(SuspendSessionState) async with self.session.wakeup_cond: if timeout is None: await self.session.wakeup_cond.wait() return True + try: await asyncio.wait_for(self.session.wakeup_cond.wait(), timeout=timeout) return True + except asyncio.TimeoutError: - return False + try: + await self.session.wakeup(self.session.event) + return False + except SessionStateFailed: + return True async def expire(self) -> None: + self.session.__to_state__(ExpireSessionState) if self.session.rule is not None: cond = self.session.refresh_cond async with cond: cond.notify() - self.session.__to_state__(ExpireSessionState) class SuspendSessionState(SessionState): @@ -101,9 +101,9 @@ class SuspendSessionState(SessionState): async def wakeup(self, event: Event) -> None: self.session.event = event cond = self.session.wakeup_cond + self.session.__to_state__(WorkingSessionState) async with cond: cond.notify() - self.session.__to_state__(WorkingSessionState) class ExpireSessionState(SessionState): ... @@ -133,6 +133,7 @@ def __init__(self, event: Event, rule: Rule | None, keep: bool = False) -> None: self.keep = keep self._state: SessionState = WorkingSessionState(self) + self._state_lock = Lock() def __to_state__(self, state_class: type[SessionState]) -> None: self._state = state_class(self) @@ -141,19 +142,24 @@ def is_state(self, state_class: type[SessionState]) -> bool: return isinstance(self._state, state_class) async def work(self, event: Event) -> None: - await self._state.work(event) + async with self._state_lock: + await self._state.work(event) async def rest(self) -> None: - await self._state.rest() + async with self._state_lock: + await self._state.rest() async def suspend(self, timeout: float | None = None) -> bool: - return await self._state.suspend(timeout) + async with self._state_lock: + return await self._state.suspend(timeout) async def wakeup(self, event: Event) -> None: - await self._state.wakeup(event) + async with self._state_lock: + await self._state.wakeup(event) async def expire(self) -> None: - await self._state.expire() + async with self._state_lock: + await self._state.expire() @classmethod async def get( @@ -171,55 +177,66 @@ async def get( cls.__instance_locks__.setdefault(rule, Lock()) async with cls.__instance_locks__[rule]: - _set = cls.__instances__.setdefault(rule, set()) - - suspends = filter(lambda s: s.is_state(SuspendSessionState), _set) - for session in suspends: - if await rule.compare(session.event, event): - await session.wakeup(event) - return None - - spares = filter(lambda s: s.is_state(SpareSessionState), _set) - for session in spares: - if await rule.compare(session.event, event): - await session.work(event) - session.keep = keep - return session - - workings = filter(lambda s: s.is_state(WorkingSessionState), _set) - expires = list(filter(lambda s: s.is_state(ExpireSessionState), _set)) - for session in workings: - if not await rule.compare(session.event, event): - continue - - if not wait: - if nowait_cb is not None: - await nowait_cb() - return None - - cond = session.refresh_cond - async with cond: - await cond.wait() - if session.is_state(ExpireSessionState): - expires.append(session) - elif session.is_state(SuspendSessionState): - await session.wakeup(event) + try: + _set = cls.__instances__.setdefault(rule, set()) + + suspends = filter(lambda s: s.is_state(SuspendSessionState), _set) + for session in suspends: + if await rule.compare(session.event, event): + try: + await session.wakeup(event) + except SessionStateFailed: + pass return None - else: + + spares = filter(lambda s: s.is_state(SpareSessionState), _set) + for session in spares: + if await rule.compare(session.event, event): await session.work(event) session.keep = keep return session - for session in expires: - Session.__instances__[cast(Rule, session.rule)].remove(session) + workings = filter(lambda s: s.is_state(WorkingSessionState), _set) + for session in workings: + if not await rule.compare(session.event, event): + continue + + if not wait: + if nowait_cb is not None: + await nowait_cb() + return None + + cond = session.refresh_cond + async with cond: + await cond.wait() - session = Session(event, rule=rule, keep=keep) - Session.__instances__[rule].add(session) - return session + if session.is_state(ExpireSessionState): + pass + + elif session.is_state(SuspendSessionState): + try: + await session.wakeup(event) + except SessionStateFailed: + pass + return None + + else: + await session.work(event) + session.keep = keep + return session + + session = Session(event, rule=rule, keep=keep) + Session.__instances__[rule].add(session) + return session + + finally: + expires = filter(lambda s: s.is_state(ExpireSessionState), _set) + for session in expires: + Session.__instances__[rule].remove(session) @classmethod @asynccontextmanager - async def enter_ctx( + async def enter( cls, rule: Rule, wait: bool = True, @@ -272,4 +289,4 @@ def enter_session( :param keep: 会话在退出会话上下文后是否继续保持 :yield: 会话对象 """ - return Session.enter_ctx(rule, wait, nowait_cb, keep) + return Session.enter(rule, wait, nowait_cb, keep)