Skip to content
Draft
22 changes: 19 additions & 3 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]

# Track if computed vars have changed since last serialization
_changed_computed_vars: Set[str] = set()

# Track which computed vars have already been computed
_ready_computed_vars: Set[str] = set()

def __init__(
self,
parent_state: BaseState | None = None,
Expand Down Expand Up @@ -1850,11 +1856,12 @@ def _mark_dirty_computed_vars(self) -> None:
while dirty_vars:
calc_vars, dirty_vars = dirty_vars, set()
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
self.dirty_vars.add(cvar)
dirty_vars.add(cvar)
actual_var = self.computed_vars.get(cvar)
if actual_var is not None:
assert actual_var is not None
if actual_var.has_changed(instance=self):
actual_var.mark_dirty(instance=self)
self.dirty_vars.add(cvar)
dirty_vars.add(cvar)

def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Expand Down Expand Up @@ -2134,6 +2141,10 @@ def __getstate__(self):
state["__dict__"].pop("parent_state", None)
state["__dict__"].pop("substates", None)
state["__dict__"].pop("_was_touched", None)
state["__dict__"].pop("_changed_computed_vars", None)
state["__dict__"].pop("_ready_computed_vars", None)
state["__fields_set__"].discard("_changed_computed_vars")
state["__fields_set__"].discard("_ready_computed_vars")
# Remove all inherited vars.
for inherited_var_name in self.inherited_vars:
state["__dict__"].pop(inherited_var_name, None)
Expand All @@ -2150,6 +2161,9 @@ def __setstate__(self, state: dict[str, Any]):
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
super().__setstate__(state)
self._was_touched = False
self._changed_computed_vars = set()
self._ready_computed_vars = set()

def _check_state_size(
self,
Expand Down Expand Up @@ -3131,6 +3145,8 @@ async def get_state(
root_state = self.states.get(client_token)
if root_state is not None:
# Retrieved state from memory.
root_state._changed_computed_vars = set()
root_state._ready_computed_vars = set()
return root_state

# Deserialize root state from disk.
Expand Down
2 changes: 2 additions & 0 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def override(func: Callable) -> Callable:
"_abc_impl",
"_backend_vars",
"_was_touched",
"_changed_computed_vars",
"_ready_computed_vars",
}

if sys.version_info >= (3, 11):
Expand Down
76 changes: 61 additions & 15 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,18 +2022,7 @@ def __get__(self, instance: BaseState | None, owner):
existing_var=self,
)

if not self._cache:
value = self.fget(instance)
else:
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
value = self.get_value(instance)

if not _isinstance(value, self._var_type):
console.deprecate(
Expand Down Expand Up @@ -2158,14 +2147,71 @@ def _deps(
self_is_top_of_stack = False
return d

def mark_dirty(self, instance) -> None:
def mark_dirty(self, instance: BaseState) -> None:
"""Mark this ComputedVar as dirty.

Args:
instance: the state instance that needs to recompute the value.
"""
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
instance._ready_computed_vars.discard(self._js_expr)

def already_computed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar has already been computed.

Args:
instance: the state instance that needs to recompute the value.

Returns:
True if the ComputedVar has already been computed, False otherwise.
"""
if self.needs_update(instance):
return False
return self._js_expr in instance._ready_computed_vars

def get_value(self, instance: BaseState) -> RETURN_TYPE:
"""Get the value of the ComputedVar.

Args:
instance: the state instance that needs to recompute the value.

Returns:
The value of the ComputedVar.
"""
if not self._cache:
instance._was_touched = True
new = self.fget(instance)
return new

has_cache = hasattr(instance, self._cache_attr)

if self.already_computed(instance) and has_cache:
return getattr(instance, self._cache_attr)

cache_value = getattr(instance, self._cache_attr, None)
instance._ready_computed_vars.add(self._js_expr)
setattr(instance, self._last_updated_attr, datetime.datetime.now())
new_value = self.fget(instance)
if cache_value != new_value:
instance._changed_computed_vars.add(self._js_expr)
instance._was_touched = True
setattr(instance, self._cache_attr, new_value)
return new_value

def has_changed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar value has changed.

Args:
instance: the state instance that needs to recompute the value.

Returns:
True if the value has changed, False otherwise.
"""
if not self._cache:
return True
if self._js_expr in instance._changed_computed_vars:
return True
# TODO: prime the cache if it's not already? creates side effects and breaks order of computed var execution
return self._js_expr in instance._changed_computed_vars

def _determine_var_type(self) -> Type:
"""Get the type of the var.
Expand Down
27 changes: 27 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3563,6 +3563,33 @@ class DillState(BaseState):
_ = state3._serialize()


def test_pickle():
class PickleState(BaseState):
pass

state = PickleState(_reflex_internal_init=True) # type: ignore

# test computed var cache is persisted
setattr(state, "__cvcached", 1)
state = PickleState._deserialize(state._serialize())
assert getattr(state, "__cvcached", None) == 1

# test ready computed vars set is not persisted
state._ready_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._ready_computed_vars

# test that changed computed vars set is not persisted
state._changed_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._changed_computed_vars

# test was_touched is not persisted
state._was_touched = True
state = PickleState._deserialize(state._serialize())
assert not state._was_touched


def test_typed_state() -> None:
class TypedState(rx.State):
field: rx.Field[str] = rx.field("")
Expand Down
Loading