Skip to content

Commit

Permalink
[HOS-333] Send a "reload" message to the frontend after state expiry (r…
Browse files Browse the repository at this point in the history
…eflex-dev#4442)

* Unit test updates

* test_client_storage: simulate backend state expiry

* [HOS-333] Send a "reload" message to the frontend after state expiry

1. a state instance expires on the backing store
2. frontend attempts to process an event against the expired token and gets a
   fresh instance of the state without router_data set
3. backend sends a "reload" message on the websocket containing the event and
   immediately stops processing
4. in response to the "reload" message, frontend sends
   [hydrate, update client storage, on_load, <previous_event>]

This allows the frontend and backend to re-syncronize on the state of the app
before continuing to process regular events.

If the event in (2) is a special hydrate event, then it is processed normally
by the middleware and the "reload" logic is skipped since this indicates an
initial load or a browser refresh.

* unit tests working with redis
  • Loading branch information
masenf authored Nov 28, 2024
1 parent 24ff29f commit 39cdce6
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 6 deletions.
4 changes: 4 additions & 0 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,10 @@ export const connect = async (
queueEvents(update.events, socket);
}
});
socket.current.on("reload", async (event) => {
event_processing = false;
queueEvents([...initialEvents(), JSON5.parse(event)], socket);
})

document.addEventListener("visibilitychange", checkVisibility);
};
Expand Down
16 changes: 16 additions & 0 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
EventSpec,
EventType,
IndividualEventType,
get_hydrate_event,
window_alert,
)
from reflex.model import Model, get_db_status
Expand Down Expand Up @@ -1259,6 +1260,21 @@ async def process(
)
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.substate_token) as state:
# When this is a brand new instance of the state, signal the
# frontend to reload before processing it.
if (
not state.router_data
and event.name != get_hydrate_event(state)
and app.event_namespace is not None
):
await asyncio.create_task(
app.event_namespace.emit(
"reload",
data=format.json_dumps(event),
to=sid,
)
)
return
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
Expand Down
3 changes: 3 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,9 @@ def _update_was_touched(self):
if var in self.base_vars or var in self._backend_vars:
self._was_touched = True
break
if var == constants.ROUTER_DATA and self.parent_state is None:
self._was_touched = True
break

def _get_was_touched(self) -> bool:
"""Check current dirty_vars and flag to determine if state instance was modified.
Expand Down
113 changes: 112 additions & 1 deletion tests/integration/test_client_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver

from reflex.state import (
State,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
_substate_key,
)
from reflex.testing import AppHarness

from . import utils
Expand Down Expand Up @@ -74,7 +81,7 @@ def index():
return rx.fragment(
rx.input(
value=ClientSideState.router.session.client_token,
is_read_only=True,
read_only=True,
id="token",
),
rx.input(
Expand Down Expand Up @@ -604,6 +611,110 @@ def set_sub_sub(var: str, value: str):
assert s2.text == "s2 value"
assert s3.text == "s3 value"

# Simulate state expiration
if isinstance(client_side.state_manager, StateManagerRedis):
await client_side.state_manager.redis.delete(
_substate_key(token, State.get_full_name())
)
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
await client_side.state_manager.redis.delete(
_substate_key(token, sub_state_name)
)
await client_side.state_manager.redis.delete(
_substate_key(token, sub_sub_state_name)
)
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
del client_side.state_manager.states[token]
if isinstance(client_side.state_manager, StateManagerDisk):
client_side.state_manager.token_expiration = 0
client_side.state_manager._purge_expired_states()

# Ensure the state is gone (not hydrated)
async def poll_for_not_hydrated():
state = await client_side.get_state(_substate_key(token or "", state_name))
return not state.is_hydrated

assert await AppHarness._poll_for_async(poll_for_not_hydrated)

# Trigger event to get a new instance of the state since the old was expired.
state_var_input = driver.find_element(By.ID, "state_var")
state_var_input.send_keys("re-triggering")

# get new references to all cookie and local storage elements (again)
c1 = driver.find_element(By.ID, "c1")
c2 = driver.find_element(By.ID, "c2")
c3 = driver.find_element(By.ID, "c3")
c4 = driver.find_element(By.ID, "c4")
c5 = driver.find_element(By.ID, "c5")
c6 = driver.find_element(By.ID, "c6")
c7 = driver.find_element(By.ID, "c7")
l1 = driver.find_element(By.ID, "l1")
l2 = driver.find_element(By.ID, "l2")
l3 = driver.find_element(By.ID, "l3")
l4 = driver.find_element(By.ID, "l4")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
c1s = driver.find_element(By.ID, "c1s")
l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")

assert c1.text == "c1 value"
assert c2.text == "c2 value"
assert c3.text == "" # temporary cookie expired after reset state!
assert c4.text == "c4 value"
assert c5.text == "c5 value"
assert c6.text == "c6 value"
assert c7.text == "c7 value"
assert l1.text == "l1 value"
assert l2.text == "l2 value"
assert l3.text == "l3 value"
assert l4.text == "l4 value"
assert s1.text == "s1 value"
assert s2.text == "s2 value"
assert s3.text == "s3 value"
assert c1s.text == "c1s value"
assert l1s.text == "l1s value"
assert s1s.text == "s1s value"

# Get the backend state and ensure the values are still set
async def get_sub_state():
root_state = await client_side.get_state(
_substate_key(token or "", sub_state_name)
)
state = root_state.substates[client_side.get_state_name("_client_side_state")]
sub_state = state.substates[
client_side.get_state_name("_client_side_sub_state")
]
return sub_state

async def poll_for_c1_set():
sub_state = await get_sub_state()
return sub_state.c1 == "c1 value"

assert await AppHarness._poll_for_async(poll_for_c1_set)
sub_state = await get_sub_state()
assert sub_state.c1 == "c1 value"
assert sub_state.c2 == "c2 value"
assert sub_state.c3 == ""
assert sub_state.c4 == "c4 value"
assert sub_state.c5 == "c5 value"
assert sub_state.c6 == "c6 value"
assert sub_state.c7 == "c7 value"
assert sub_state.l1 == "l1 value"
assert sub_state.l2 == "l2 value"
assert sub_state.l3 == "l3 value"
assert sub_state.l4 == "l4 value"
assert sub_state.s1 == "s1 value"
assert sub_state.s2 == "s2 value"
assert sub_state.s3 == "s3 value"
sub_sub_state = sub_state.substates[
client_side.get_state_name("_client_side_sub_sub_state")
]
assert sub_sub_state.c1s == "c1s value"
assert sub_sub_state.l1s == "l1s value"
assert sub_sub_state.s1s == "s1s value"

# clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies()
local_storage.clear()
Expand Down
8 changes: 6 additions & 2 deletions tests/units/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
client_ip = "127.0.0.1"
state = await app.state_manager.get_state(substate_token)
assert state.dynamic == ""
async with app.state_manager.modify_state(substate_token) as state:
state.router_data = {"simulate": "hydrated"}
assert state.dynamic == ""
exp_vals = ["foo", "foobar", "baz"]

def _event(name, val, **kwargs):
Expand Down Expand Up @@ -1180,13 +1181,16 @@ async def test_process_events(mocker, token: str):
"ip": "127.0.0.1",
}
app = App(state=GenState)

mocker.patch.object(app, "_postprocess", AsyncMock())
event = Event(
token=token,
name=f"{GenState.get_name()}.go",
payload={"c": 5},
router_data=router_data,
)
async with app.state_manager.modify_state(event.substate_token) as state:
state.router_data = {"simulate": "hydrated"}

async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
pass
Expand Down
12 changes: 9 additions & 3 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,10 @@ class BackgroundTaskState(BaseState):
order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}

def __init__(self, **kwargs): # noqa: D107
super().__init__(**kwargs)
self.router_data = {"simulate": "hydrate"}

@rx.var
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
Expand Down Expand Up @@ -2732,7 +2736,7 @@ class BaseFieldSetterState(BaseState):
assert "c2" in bfss.dirty_vars


def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args:
Expand Down Expand Up @@ -2811,7 +2815,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]}
)
state = State()
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}

updates = []
async for update in rx.app.process(
Expand Down Expand Up @@ -2858,7 +2863,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
state=State,
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
)
state = State()
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}

updates = []
async for update in rx.app.process(
Expand Down

0 comments on commit 39cdce6

Please sign in to comment.