From 592db8cdcad53c016c2ce418c9c27646ec095c5d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 17 Dec 2024 17:57:25 -0800 Subject: [PATCH 1/3] BaseState.get_var_value helper to get a value from a Var When given a state Var or a LiteralVar, retrieve the actual value associated with the Var. For state Vars, the returned value is directly tied to the associated state and can be modified. Modifying LiteralVar values or ComputedVar values will have no useful effect. --- reflex/state.py | 31 +++++++++++++++++++++++++++++++ reflex/utils/exceptions.py | 4 ++++ tests/units/test_state.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/reflex/state.py b/reflex/state.py index e7e6bcf326..bd656b3882 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -107,6 +107,7 @@ StateSchemaMismatchError, StateSerializationError, StateTooLargeError, + UnretrievableVarValueError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -1596,6 +1597,36 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState: # Slow case - fetch missing parent states from redis. return await self._get_state_from_redis(state_cls) + async def get_var_value(self, var: Var) -> Any: + """Get the value of an rx.Var from another state. + + Args: + var: The var to get the value for. + + Returns: + The value of the var. + + Raises: + UnretrievableVarValueError: If the var does not have a literal value + or associated state. + """ + # Fast case: this is a literal var and the value is known. + if hasattr(var, "_var_value"): + return var._var_value + var_data = var._get_all_var_data() + if var_data is None or not var_data.state: + raise UnretrievableVarValueError( + f"Unable to retrieve value for {var._js_expr}: not associated with any state." + ) + # Fastish case: this var belongs to this state + if var_data.state == self.get_full_name(): + return getattr(self, var_data.field_name) + # Slow case: this var belongs to another state + other_state = await self.get_state( + self._get_root_state().get_class_substate(var_data.state) + ) + return getattr(other_state, var_data.field_name) + def _get_event_handler( self, event: Event ) -> tuple[BaseState | StateProxy, EventHandler]: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index ae5ec01683..bceadc977e 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -187,3 +187,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn: class InvalidLockWarningThresholdError(ReflexError): """Raised when an invalid lock warning threshold is provided.""" + + +class UnretrievableVarValueError(ReflexError): + """Raised when the value of a var is not retrievable.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 912d72f4f1..2176e828f0 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -60,6 +60,7 @@ ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, + UnretrievableVarValueError, ) from reflex.utils.format import json_dumps from reflex.vars.base import Var, computed_var @@ -3764,3 +3765,32 @@ async def test_upcast_event_handler_arg(handler, payload): state = UpcastState() async for update in state._process_event(handler, state, payload): assert update.delta == {UpcastState.get_full_name(): {"passed": True}} + + +@pytest.mark.asyncio +async def test_get_var_value(state_manager, token): + """Test that get_var_value works correctly. + + Args: + state_manager: The state manager to use. + token: A token. + """ + state = await state_manager.get_state(_substate_key(token, TestState)) + + # State Var from same state + assert await state.get_var_value(TestState.num1) == 0 + state.num1 = 42 + assert await state.get_var_value(TestState.num1) == 42 + + # State Var from another state + child_state = await state.get_state(ChildState) + assert await state.get_var_value(ChildState.count) == 23 + child_state.count = 66 + assert await state.get_var_value(ChildState.count) == 66 + + # LiteralVar with known value + assert await state.get_var_value(rx.Var.create([1, 2, 3])) == [1, 2, 3] + + # Generic Var with no state + with pytest.raises(UnretrievableVarValueError): + await state.get_var_value(rx.Var("undefined")) From a52479290a29309fed7b91d90af302c54391f15e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 18 Dec 2024 17:07:31 -0800 Subject: [PATCH 2/3] Use Var[VAR_TYPE] annotation to take advantage of generics This requires rx.Field to pass typing where used. --- reflex/state.py | 5 ++++- tests/units/test_state.py | 12 ++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index bd656b3882..714674e65e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -144,6 +144,9 @@ ValueError, ) +# For BaseState.get_var_value +VAR_TYPE = TypeVar("VAR_TYPE", bound=Var) + def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -1597,7 +1600,7 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState: # Slow case - fetch missing parent states from redis. return await self._get_state_from_redis(state_cls) - async def get_var_value(self, var: Var) -> Any: + async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: """Get the value of an rx.Var from another state. Args: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2176e828f0..c1780b4f04 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -116,7 +116,7 @@ class TestState(BaseState): # Set this class as not test one __test__ = False - num1: int + num1: rx.Field[int] num2: float = 3.14 key: str map_key: str = "a" @@ -164,7 +164,7 @@ class ChildState(TestState): """A child state fixture.""" value: str - count: int = 23 + count: rx.Field[int] = rx.field(23) def change_both(self, value: str, count: int): """Change both the value and count. @@ -1664,7 +1664,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: @pytest.fixture() -def substate_token(state_manager, token): +def substate_token(state_manager, token) -> str: """A token + substate name for looking up in state manager. Args: @@ -3768,14 +3768,14 @@ async def test_upcast_event_handler_arg(handler, payload): @pytest.mark.asyncio -async def test_get_var_value(state_manager, token): +async def test_get_var_value(state_manager: StateManager, substate_token: str): """Test that get_var_value works correctly. Args: state_manager: The state manager to use. - token: A token. + substate_token: Token for the substate used by state_manager. """ - state = await state_manager.get_state(_substate_key(token, TestState)) + state = await state_manager.get_state(substate_token) # State Var from same state assert await state.get_var_value(TestState.num1) == 0 From c93119a901e9b4737b144d9f742b4164db2fc736 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 19 Dec 2024 16:45:04 -0800 Subject: [PATCH 3/3] Add case where get_var_value gets something that's not a var --- reflex/state.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 714674e65e..65bc71d904 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -145,7 +145,7 @@ ) # For BaseState.get_var_value -VAR_TYPE = TypeVar("VAR_TYPE", bound=Var) +VAR_TYPE = TypeVar("VAR_TYPE") def _no_chain_background_task( @@ -1613,9 +1613,14 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: UnretrievableVarValueError: If the var does not have a literal value or associated state. """ + # Oopsie case: you didn't give me a Var... so get what you give. + if not isinstance(var, Var): + return var # type: ignore + # Fast case: this is a literal var and the value is known. if hasattr(var, "_var_value"): return var._var_value + var_data = var._get_all_var_data() if var_data is None or not var_data.state: raise UnretrievableVarValueError( @@ -1624,6 +1629,7 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: # Fastish case: this var belongs to this state if var_data.state == self.get_full_name(): return getattr(self, var_data.field_name) + # Slow case: this var belongs to another state other_state = await self.get_state( self._get_root_state().get_class_substate(var_data.state)