Skip to content

Commit

Permalink
[REF-1356] Track changes applied to Base subclass via helper method. (
Browse files Browse the repository at this point in the history
  • Loading branch information
masenf authored May 31, 2024
1 parent 5995b32 commit b04e3a6
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
21 changes: 20 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2867,6 +2867,11 @@ class MutableProxy(wrapt.ObjectProxy):
]
)

# These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
__never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
pydantic.BaseModel.__dict__
)

__mutable_types__ = (list, dict, set, Base)

def __init__(self, wrapped: Any, state: BaseState, field_name: str):
Expand Down Expand Up @@ -2916,7 +2921,10 @@ def _wrap_recursive(self, value: Any) -> Any:
Returns:
The wrapped value.
"""
if isinstance(value, self.__mutable_types__):
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
return type(self)(
wrapped=value,
state=self._self_state,
Expand Down Expand Up @@ -2963,6 +2971,17 @@ def __getattr__(self, __name: str) -> Any:
self._wrap_recursive_decorator,
)

if (
isinstance(self.__wrapped__, Base)
and __name not in self.__never_wrap_base_attrs__
and hasattr(value, "__func__")
):
# Wrap methods called on Base subclasses, which might do _anything_
return wrapt.FunctionWrapper(
functools.partial(value.__func__, self),
self._wrap_recursive_decorator,
)

if isinstance(value, self.__mutable_types__) and __name not in (
"__wrapped__",
"_self_state",
Expand Down
65 changes: 65 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,13 +2392,37 @@ class Custom1(Base):

foo: str

def set_foo(self, val: str):
"""Set the attribute foo.
Args:
val: The value to set.
"""
self.foo = val

def double_foo(self) -> str:
"""Concantenate foo with foo.
Returns:
foo + foo
"""
return self.foo + self.foo


class Custom2(Base):
"""A custom class with a Custom1 field."""

c1: Optional[Custom1] = None
c1r: Custom1

def set_c1r_foo(self, val: str):
"""Set the foo attribute of the c1 field.
Args:
val: The value to set.
"""
self.c1r.set_foo(val)


class Custom3(Base):
"""A custom class with a Custom2 field."""
Expand Down Expand Up @@ -2436,6 +2460,47 @@ class UnionState(BaseState):
assert types.is_union(UnionState.int_float._var_type) # type: ignore


def test_set_base_field_via_setter():
"""When calling a setter on a Base instance, also track changes."""

class BaseFieldSetterState(BaseState):
c1: Custom1 = Custom1(foo="")
c2: Custom2 = Custom2(c1r=Custom1(foo=""))

bfss = BaseFieldSetterState()
assert "c1" not in bfss.dirty_vars

# Non-mutating function, not dirty
bfss.c1.double_foo()
assert "c1" not in bfss.dirty_vars

# Mutating function, dirty
bfss.c1.set_foo("bar")
assert "c1" in bfss.dirty_vars
bfss.dirty_vars.clear()
assert "c1" not in bfss.dirty_vars

# Mutating function from Base, dirty
bfss.c1.set(foo="bar")
assert "c1" in bfss.dirty_vars
bfss.dirty_vars.clear()
assert "c1" not in bfss.dirty_vars

# Assert identity of MutableProxy
mp = bfss.c1
assert isinstance(mp, MutableProxy)
mp2 = mp.set()
assert mp is mp2
mp3 = bfss.c1.set()
assert mp is not mp3
# Since none of these set calls had values, the state should not be dirty
assert not bfss.dirty_vars

# Chained Mutating function, dirty
bfss.c2.set_c1r_foo("baz")
assert "c2" in bfss.dirty_vars


def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Expand Down

0 comments on commit b04e3a6

Please sign in to comment.