Skip to content

Commit

Permalink
implement state iter with StateManagerRedis
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed Aug 9, 2024
1 parent ca9df0f commit c5cbbdf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
15 changes: 3 additions & 12 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
RouterData,
State,
StateManager,
StateManagerMemory,
StateUpdate,
_substate_key,
code_uses_state_contexts,
Expand Down Expand Up @@ -1128,19 +1127,11 @@ async def modify_states(
Yields:
The states to modify.
Raises:
NotImplementedError: If the state manager is not StateManagerMemory
"""
# TODO: Implement for StateManagerRedis
if not isinstance(self.state_manager, StateManagerMemory):
raise NotImplementedError

for token in self.state_manager.states:
async for token in self.state_manager.iter_state_tokens():
# avoid deadlock when calling from event handler/background task
if (
from_state is not None
and from_state.router.session.client_token == token
if from_state is not None and token.startswith(
from_state.router.session.client_token
):
state = from_state
continue
Expand Down
49 changes: 49 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.utils.string import remove_prefix
from reflex.utils.types import override
from reflex.vars import BaseVar, ComputedVar, Var, computed_var

if TYPE_CHECKING:
Expand Down Expand Up @@ -2364,6 +2365,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""
yield self.state()

@abstractmethod
def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.
Args:
substate_cls: The subclass of BaseState to filter by.
Raises:
NotImplementedError: Always, because this method must be implemented by subclasses.
"""
raise NotImplementedError


class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
Expand Down Expand Up @@ -2430,6 +2445,21 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state)

@override
async def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.
Args:
substate_cls: The subclass of BaseState to filter by.
Yields:
The state names.
"""
for token in self.states:
yield token


# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType):
Expand Down Expand Up @@ -2748,6 +2778,25 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state, lock_id)

@override
async def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.
Args:
substate_cls: The subclass of BaseState to filter by.
Yields:
The state names.
"""
if substate_cls is None:
substate_cls = self.state
async for token in self.redis.scan_iter(
match=f"*_{substate_cls.get_name()}", _type="STRING"
):
yield token.decode()

@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.
Expand Down

0 comments on commit c5cbbdf

Please sign in to comment.