Skip to content

Commit

Permalink
[melobot] Fix session state error after suspend got timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
aicorein committed Nov 26, 2024
1 parent fb68a9a commit 2c7aad6
Showing 1 changed file with 91 additions and 74 deletions.
165 changes: 91 additions & 74 deletions src/melobot/session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from asyncio import Condition, Lock
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from typing import Any, AsyncGenerator, cast
from typing import Any, AsyncGenerator

from ..adapter.model import Event
from ..ctx import FlowCtx, SessionCtx
Expand All @@ -15,38 +15,36 @@
_SESSION_CTX = SessionCtx()


class SessionStateError(BotException):
def __init__(
self,
cur_state: type[SessionState],
meth: str | None = None,
text: str | None = None,
) -> None:
if text is not None:
super().__init__(f"{text}(当前会话状态:{cur_state.__name__})")
return
class SessionError(BotException): ...

super().__init__(f"当前会话状态 {cur_state.__name__} 不支持的操作:{meth}")

class SessionStateFailed(SessionError):
def __init__(self, cur_state: str, meth: str) -> None:
self.cur_state = cur_state
super().__init__(f"当前会话状态 {cur_state} 不支持的操作:{meth}")


class SessionRuleLacked(SessionError): ...


class SessionState:
def __init__(self, session: "Session") -> None:
self.session = session

async def work(self, event: Event) -> None:
raise SessionStateError(self.__class__, meth=SessionState.work.__name__)
raise SessionStateFailed(self.__class__.__name__, SessionState.work.__name__)

async def rest(self) -> None:
raise SessionStateError(self.__class__, meth=SessionState.rest.__name__)
raise SessionStateFailed(self.__class__.__name__, SessionState.rest.__name__)

async def suspend(self, timeout: float | None) -> bool:
raise SessionStateError(self.__class__, meth=SessionState.suspend.__name__)
raise SessionStateFailed(self.__class__.__name__, SessionState.suspend.__name__)

async def wakeup(self, event: Event) -> None:
raise SessionStateError(self.__class__, meth=SessionState.wakeup.__name__)
raise SessionStateFailed(self.__class__.__name__, SessionState.wakeup.__name__)

async def expire(self) -> None:
raise SessionStateError(self.__class__, meth=SessionState.expire.__name__)
raise SessionStateFailed(self.__class__.__name__, SessionState.expire.__name__)


class SpareSessionState(SessionState):
Expand All @@ -58,52 +56,54 @@ async def work(self, event: Event) -> None:
class WorkingSessionState(SessionState):
async def rest(self) -> None:
if self.session.rule is None:
raise SessionStateError(
WorkingSessionState, text="缺少会话规则,会话无法从“运行态”转为“空闲态”"
)
raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“空闲态”")

cond = self.session.refresh_cond
self.session.__to_state__(SpareSessionState)
async with cond:
cond.notify()
self.session.__to_state__(SpareSessionState)

async def suspend(self, timeout: float | None) -> bool:
if self.session.rule is None:
raise SessionStateError(
WorkingSessionState, text="缺少会话规则,会话无法从“运行态”转为“挂起态”"
)
raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“挂起态”")

cond = self.session.refresh_cond
self.session.__to_state__(SuspendSessionState)
async with cond:
cond.notify()
self.session.__to_state__(SuspendSessionState)

async with self.session.wakeup_cond:
if timeout is None:
await self.session.wakeup_cond.wait()
return True

try:
await asyncio.wait_for(self.session.wakeup_cond.wait(), timeout=timeout)
return True

except asyncio.TimeoutError:
return False
try:
await self.session.wakeup(self.session.event)
return False
except SessionStateFailed:
return True

async def expire(self) -> None:
self.session.__to_state__(ExpireSessionState)
if self.session.rule is not None:
cond = self.session.refresh_cond
async with cond:
cond.notify()
self.session.__to_state__(ExpireSessionState)


class SuspendSessionState(SessionState):

async def wakeup(self, event: Event) -> None:
self.session.event = event
cond = self.session.wakeup_cond
self.session.__to_state__(WorkingSessionState)
async with cond:
cond.notify()
self.session.__to_state__(WorkingSessionState)


class ExpireSessionState(SessionState): ...
Expand Down Expand Up @@ -133,6 +133,7 @@ def __init__(self, event: Event, rule: Rule | None, keep: bool = False) -> None:
self.keep = keep

self._state: SessionState = WorkingSessionState(self)
self._state_lock = Lock()

def __to_state__(self, state_class: type[SessionState]) -> None:
self._state = state_class(self)
Expand All @@ -141,19 +142,24 @@ def is_state(self, state_class: type[SessionState]) -> bool:
return isinstance(self._state, state_class)

async def work(self, event: Event) -> None:
await self._state.work(event)
async with self._state_lock:
await self._state.work(event)

async def rest(self) -> None:
await self._state.rest()
async with self._state_lock:
await self._state.rest()

async def suspend(self, timeout: float | None = None) -> bool:
return await self._state.suspend(timeout)
async with self._state_lock:
return await self._state.suspend(timeout)

async def wakeup(self, event: Event) -> None:
await self._state.wakeup(event)
async with self._state_lock:
await self._state.wakeup(event)

async def expire(self) -> None:
await self._state.expire()
async with self._state_lock:
await self._state.expire()

@classmethod
async def get(
Expand All @@ -171,55 +177,66 @@ async def get(
cls.__instance_locks__.setdefault(rule, Lock())

async with cls.__instance_locks__[rule]:
_set = cls.__instances__.setdefault(rule, set())

suspends = filter(lambda s: s.is_state(SuspendSessionState), _set)
for session in suspends:
if await rule.compare(session.event, event):
await session.wakeup(event)
return None

spares = filter(lambda s: s.is_state(SpareSessionState), _set)
for session in spares:
if await rule.compare(session.event, event):
await session.work(event)
session.keep = keep
return session

workings = filter(lambda s: s.is_state(WorkingSessionState), _set)
expires = list(filter(lambda s: s.is_state(ExpireSessionState), _set))
for session in workings:
if not await rule.compare(session.event, event):
continue

if not wait:
if nowait_cb is not None:
await nowait_cb()
return None

cond = session.refresh_cond
async with cond:
await cond.wait()
if session.is_state(ExpireSessionState):
expires.append(session)
elif session.is_state(SuspendSessionState):
await session.wakeup(event)
try:
_set = cls.__instances__.setdefault(rule, set())

suspends = filter(lambda s: s.is_state(SuspendSessionState), _set)
for session in suspends:
if await rule.compare(session.event, event):
try:
await session.wakeup(event)
except SessionStateFailed:
pass
return None
else:

spares = filter(lambda s: s.is_state(SpareSessionState), _set)
for session in spares:
if await rule.compare(session.event, event):
await session.work(event)
session.keep = keep
return session

for session in expires:
Session.__instances__[cast(Rule, session.rule)].remove(session)
workings = filter(lambda s: s.is_state(WorkingSessionState), _set)
for session in workings:
if not await rule.compare(session.event, event):
continue

if not wait:
if nowait_cb is not None:
await nowait_cb()
return None

cond = session.refresh_cond
async with cond:
await cond.wait()

session = Session(event, rule=rule, keep=keep)
Session.__instances__[rule].add(session)
return session
if session.is_state(ExpireSessionState):
pass

elif session.is_state(SuspendSessionState):
try:
await session.wakeup(event)
except SessionStateFailed:
pass
return None

else:
await session.work(event)
session.keep = keep
return session

session = Session(event, rule=rule, keep=keep)
Session.__instances__[rule].add(session)
return session

finally:
expires = filter(lambda s: s.is_state(ExpireSessionState), _set)
for session in expires:
Session.__instances__[rule].remove(session)

@classmethod
@asynccontextmanager
async def enter_ctx(
async def enter(
cls,
rule: Rule,
wait: bool = True,
Expand Down Expand Up @@ -272,4 +289,4 @@ def enter_session(
:param keep: 会话在退出会话上下文后是否继续保持
:yield: 会话对象
"""
return Session.enter_ctx(rule, wait, nowait_cb, keep)
return Session.enter(rule, wait, nowait_cb, keep)

0 comments on commit 2c7aad6

Please sign in to comment.