Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseState.get_var_value helper to get a value from a Var #4553

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
StateSchemaMismatchError,
StateSerializationError,
StateTooLargeError,
UnretrievableVarValueError,
)
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import serializer
Expand Down Expand Up @@ -143,6 +144,9 @@
ValueError,
)

# For BaseState.get_var_value
VAR_TYPE = TypeVar("VAR_TYPE")


def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
Expand Down Expand Up @@ -1596,6 +1600,42 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
# Slow case - fetch missing parent states from redis.
return await self._get_state_from_redis(state_cls)

async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE:
"""Get the value of an rx.Var from another state.

Args:
var: The var to get the value for.

Returns:
The value of the var.

Raises:
UnretrievableVarValueError: If the var does not have a literal value
or associated state.
"""
# Oopsie case: you didn't give me a Var... so get what you give.
if not isinstance(var, Var):
return var # type: ignore

# Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"):
return var._var_value

var_data = var._get_all_var_data()
if var_data is None or not var_data.state:
raise UnretrievableVarValueError(
f"Unable to retrieve value for {var._js_expr}: not associated with any state."
)
# Fastish case: this var belongs to this state
if var_data.state == self.get_full_name():
return getattr(self, var_data.field_name)

# Slow case: this var belongs to another state
other_state = await self.get_state(
self._get_root_state().get_class_substate(var_data.state)
)
return getattr(other_state, var_data.field_name)

def _get_event_handler(
self, event: Event
) -> tuple[BaseState | StateProxy, EventHandler]:
Expand Down
4 changes: 4 additions & 0 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:

class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided."""


class UnretrievableVarValueError(ReflexError):
"""Raised when the value of a var is not retrievable."""
36 changes: 33 additions & 3 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
ReflexRuntimeError,
SetUndefinedStateVarError,
StateSerializationError,
UnretrievableVarValueError,
)
from reflex.utils.format import json_dumps
from reflex.vars.base import Var, computed_var
Expand Down Expand Up @@ -115,7 +116,7 @@ class TestState(BaseState):
# Set this class as not test one
__test__ = False

num1: int
num1: rx.Field[int]
num2: float = 3.14
key: str
map_key: str = "a"
Expand Down Expand Up @@ -163,7 +164,7 @@ class ChildState(TestState):
"""A child state fixture."""

value: str
count: int = 23
count: rx.Field[int] = rx.field(23)

def change_both(self, value: str, count: int):
"""Change both the value and count.
Expand Down Expand Up @@ -1663,7 +1664,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]:


@pytest.fixture()
def substate_token(state_manager, token):
def substate_token(state_manager, token) -> str:
"""A token + substate name for looking up in state manager.

Args:
Expand Down Expand Up @@ -3764,3 +3765,32 @@ async def test_upcast_event_handler_arg(handler, payload):
state = UpcastState()
async for update in state._process_event(handler, state, payload):
assert update.delta == {UpcastState.get_full_name(): {"passed": True}}


@pytest.mark.asyncio
async def test_get_var_value(state_manager: StateManager, substate_token: str):
"""Test that get_var_value works correctly.

Args:
state_manager: The state manager to use.
substate_token: Token for the substate used by state_manager.
"""
state = await state_manager.get_state(substate_token)

# State Var from same state
assert await state.get_var_value(TestState.num1) == 0
state.num1 = 42
assert await state.get_var_value(TestState.num1) == 42

# State Var from another state
child_state = await state.get_state(ChildState)
assert await state.get_var_value(ChildState.count) == 23
child_state.count = 66
assert await state.get_var_value(ChildState.count) == 66

# LiteralVar with known value
assert await state.get_var_value(rx.Var.create([1, 2, 3])) == [1, 2, 3]

# Generic Var with no state
with pytest.raises(UnretrievableVarValueError):
await state.get_var_value(rx.Var("undefined"))
Loading