Skip to content

Commit

Permalink
add typed get_substate variant, improve modify_states typing
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed Aug 9, 2024
1 parent c93441b commit f4295df
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 21 deletions.
22 changes: 18 additions & 4 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
Optional,
Set,
Type,
TypeVar,
Union,
get_args,
get_type_hints,
overload,
)

from fastapi import FastAPI, HTTPException, Request, UploadFile
Expand Down Expand Up @@ -1101,11 +1103,23 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
sid=state.router.session.session_id,
)

S = TypeVar("S", bound=BaseState)

@overload
async def modify_states(
self, substate_cls: Type[S], from_state: None
) -> AsyncIterator[S]: ...

@overload
async def modify_states(
self, substate_cls: None, from_state: BaseState
) -> AsyncIterator[BaseState]: ...

async def modify_states(
self,
substate_cls: Type[BaseState] | None = None,
substate_cls: Type[S] | Type[BaseState] | None = None,
from_state: BaseState | None = None,
) -> AsyncIterator[BaseState]:
) -> AsyncIterator[S] | AsyncIterator[BaseState]:
"""Iterate over the states.
Args:
Expand All @@ -1128,11 +1142,11 @@ async def modify_states(
from_state is not None
and from_state.router.session.client_token == token
):
yield from_state
state = from_state
continue
async with self.modify_state(token) as state:
if substate_cls is not None:
state = state.get_substate(substate_cls.get_name())
state = state.get_substate(substate_cls)
yield state

def _process_background(
Expand Down
21 changes: 18 additions & 3 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
Sequence,
Set,
Type,
TypeVar,
Union,
cast,
overload,
)

import dill
Expand Down Expand Up @@ -291,6 +293,9 @@ def __call__(self, *args: Any) -> EventSpec:
return super().__call__(*args)


S = TypeVar("S", bound="BaseState")


class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""

Expand Down Expand Up @@ -1151,18 +1156,28 @@ def _reset_client_storage(self):
for substate in self.substates.values():
substate._reset_client_storage()

def get_substate(self, path: Sequence[str]) -> BaseState:
@overload
def get_substate(self, path: Sequence[str]) -> BaseState: ...

@overload
def get_substate(self, path: Type[S]) -> S: ...

def get_substate(self, path: Sequence[str] | Type[S]) -> BaseState | S:
"""Get the substate.
Args:
path: The path to the substate.
path: The path to the substate or the class of the substate.
Returns:
The substate.
Raises:
ValueError: If the substate is not found.
"""
if isinstance(path, type):
path = (
path.get_full_name().removeprefix(f"{self.get_full_name()}.").split(".")
)
if len(path) == 0:
return self
if path[0] == self.get_name():
Expand Down Expand Up @@ -1295,7 +1310,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
root_state = self
else:
root_state = self._get_parent_states()[-1][1]
return root_state.get_substate(state_cls.get_full_name().split("."))
return root_state.get_substate(state_cls)

async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
"""Get a state instance from redis.
Expand Down
3 changes: 0 additions & 3 deletions reflex/utils/prerequisites.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,6 @@ def get_app(reload: bool = False) -> App:
Returns:
The app based on the default config.
Raises:
RuntimeError: If the app name is not set in the config.
"""
return getattr(get_app_module(reload=reload), constants.CompileVars.APP)

Expand Down
20 changes: 9 additions & 11 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def child_state(test_state) -> ChildState:
Returns:
A test child state.
"""
child_state = test_state.get_substate([ChildState.get_name()])
child_state = test_state.get_substate(ChildState)
assert child_state is not None
return child_state

Expand All @@ -233,7 +233,7 @@ def child_state2(test_state) -> ChildState2:
Returns:
A second test child state.
"""
child_state2 = test_state.get_substate([ChildState2.get_name()])
child_state2 = test_state.get_substate(ChildState2)
assert child_state2 is not None
return child_state2

Expand All @@ -248,7 +248,7 @@ def grandchild_state(child_state) -> GrandchildState:
Returns:
A test state.
"""
grandchild_state = child_state.get_substate([GrandchildState.get_name()])
grandchild_state = child_state.get_substate(GrandchildState)
assert grandchild_state is not None
return grandchild_state

Expand Down Expand Up @@ -1183,7 +1183,7 @@ def set_v4(self, v: int):
assert ms.v == 2

# ensure handler can be called from substate (referencing grandparent handler)
ms.get_substate(tuple(SubSubState.get_full_name().split("."))).set_v4(3)
ms.get_substate(SubSubState).set_v4(3)
assert ms.v == 3


Expand Down Expand Up @@ -2854,7 +2854,7 @@ async def test_get_state(mock_app: rx.App, token: str):
)

# Get the child_state2 directly.
child_state2_direct = test_state.get_substate([ChildState2.get_name()])
child_state2_direct = test_state.get_substate(ChildState2)
child_state2_get_state = await test_state.get_state(ChildState2)
# These should be the same object.
assert child_state2_direct is child_state2_get_state
Expand All @@ -2871,15 +2871,13 @@ async def test_get_state(mock_app: rx.App, token: str):
)

# ChildState should be retrievable
child_state_direct = test_state.get_substate([ChildState.get_name()])
child_state_direct = test_state.get_substate(ChildState)
child_state_get_state = await test_state.get_state(ChildState)
# These should be the same object.
assert child_state_direct is child_state_get_state

# GrandchildState instance should be the same as the one retrieved from the child_state2.
assert grandchild_state is child_state_direct.get_substate(
[GrandchildState.get_name()]
)
assert grandchild_state is child_state_direct.get_substate(GrandchildState)
grandchild_state.value2 = "set_value"

assert test_state.get_delta() == {
Expand Down Expand Up @@ -2920,7 +2918,7 @@ async def test_get_state(mock_app: rx.App, token: str):
)

# Set a value on child_state2, should update cached var in grandchild_state2
child_state2 = new_test_state.get_substate((ChildState2.get_name(),))
child_state2 = new_test_state.get_substate(ChildState2)
child_state2.value = "set_c2_value"

assert new_test_state.get_delta() == {
Expand Down Expand Up @@ -3015,7 +3013,7 @@ class GreatGrandchild3(Grandchild3):
assert Child3.get_name() in root.substates # (due to @rx.var)

# Get the unconnected sibling state, which will be used to `get_state` other instances.
child = root.get_substate(Child.get_full_name().split("."))
child = root.get_substate(Child)

# Get an uncached child state.
child2 = await child.get_state(Child2)
Expand Down

0 comments on commit f4295df

Please sign in to comment.