Skip to content

Commit

Permalink
[melobot] Fix errors and bugs in session process
Browse files Browse the repository at this point in the history
  • Loading branch information
aicorein committed Nov 27, 2024
1 parent 2c7aad6 commit 12c0b48
Showing 1 changed file with 27 additions and 35 deletions.
62 changes: 27 additions & 35 deletions src/melobot/session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,10 @@ async def suspend(self, timeout: float | None) -> bool:
return True

except asyncio.TimeoutError:
try:
await self.session.wakeup(self.session.event)
return False
except SessionStateFailed:
if self.session.is_state(WorkingSessionState):
return True
self.session.__to_state__(WorkingSessionState)
return False

async def expire(self) -> None:
self.session.__to_state__(ExpireSessionState)
Expand All @@ -99,6 +98,8 @@ async def expire(self) -> None:
class SuspendSessionState(SessionState):

async def wakeup(self, event: Event) -> None:
if self.session.is_state(WorkingSessionState):
return
self.session.event = event
cond = self.session.wakeup_cond
self.session.__to_state__(WorkingSessionState)
Expand Down Expand Up @@ -133,33 +134,30 @@ def __init__(self, event: Event, rule: Rule | None, keep: bool = False) -> None:
self.keep = keep

self._state: SessionState = WorkingSessionState(self)
self._state_lock = Lock()

def __to_state__(self, state_class: type[SessionState]) -> None:
self._state = state_class(self)

def is_state(self, state_class: type[SessionState]) -> bool:
return isinstance(self._state, state_class)

async def work(self, event: Event) -> None:
async with self._state_lock:
await self._state.work(event)
def mark_expire(self) -> None:
self.keep = False

async def rest(self) -> None:
async with self._state_lock:
await self._state.rest()
async def __work__(self, event: Event) -> None:
await self._state.work(event)

async def suspend(self, timeout: float | None = None) -> bool:
async with self._state_lock:
return await self._state.suspend(timeout)
async def __rest__(self) -> None:
await self._state.rest()

async def wakeup(self, event: Event) -> None:
async with self._state_lock:
await self._state.wakeup(event)
async def __suspend__(self, timeout: float | None = None) -> bool:
return await self._state.suspend(timeout)

async def expire(self) -> None:
async with self._state_lock:
await self._state.expire()
async def __wakeup__(self, event: Event) -> None:
await self._state.wakeup(event)

async def __expire__(self) -> None:
await self._state.expire()

@classmethod
async def get(
Expand All @@ -183,16 +181,13 @@ async def get(
suspends = filter(lambda s: s.is_state(SuspendSessionState), _set)
for session in suspends:
if await rule.compare(session.event, event):
try:
await session.wakeup(event)
except SessionStateFailed:
pass
await session.__wakeup__(event)
return None

spares = filter(lambda s: s.is_state(SpareSessionState), _set)
for session in spares:
if await rule.compare(session.event, event):
await session.work(event)
await session.__work__(event)
session.keep = keep
return session

Expand All @@ -214,14 +209,11 @@ async def get(
pass

elif session.is_state(SuspendSessionState):
try:
await session.wakeup(event)
except SessionStateFailed:
pass
await session.__wakeup__(event)
return None

else:
await session.work(event)
await session.__work__(event)
session.keep = keep
return session

Expand All @@ -230,7 +222,7 @@ async def get(
return session

finally:
expires = filter(lambda s: s.is_state(ExpireSessionState), _set)
expires = tuple(filter(lambda s: s.is_state(ExpireSessionState), _set))
for session in expires:
Session.__instances__[rule].remove(session)

Expand Down Expand Up @@ -258,12 +250,12 @@ async def enter(
yield session
except asyncio.CancelledError:
if session.is_state(SuspendSessionState):
await session.wakeup(session.event)
await session.__wakeup__(session.event)
finally:
if session.keep:
await session.rest()
await session.__rest__()
else:
await session.expire()
await session.__expire__()


async def suspend(timeout: float | None = None) -> bool:
Expand All @@ -272,7 +264,7 @@ async def suspend(timeout: float | None = None) -> bool:
:param timeout: 挂起后再唤醒的超时时间, 为空则永不超时
:return: 如果为 `False` 则表明唤醒超时
"""
return await SessionCtx().get().suspend(timeout)
return await SessionCtx().get().__suspend__(timeout)


def enter_session(
Expand Down

0 comments on commit 12c0b48

Please sign in to comment.