From 80995912e81ecb4a501bb393e92767b1e704ed73 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 12 Jul 2024 20:06:31 +0200 Subject: [PATCH] dynamic state wip --- .../jinja/web/utils/context.js.jinja2 | 31 +- reflex/.templates/web/utils/state.js | 126 +++++--- reflex/app.py | 22 +- reflex/compiler/compiler.py | 3 + reflex/compiler/templates.py | 2 +- reflex/compiler/utils.py | 11 +- reflex/event.py | 13 +- reflex/state.py | 305 +++++++++++++++--- reflex/utils/format.py | 12 +- reflex/vars.py | 4 +- 10 files changed, 406 insertions(+), 123 deletions(-) diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 5f734a0c3b8..51b495c2b46 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -1,19 +1,27 @@ +{% set all_state_names = [] %} import { createContext, useContext, useMemo, useReducer, useState } from "react" -import { applyDelta, Event, hydrateClientStorage, useEventLoop, refs } from "/utils/state.js" +import { applyDelta, Event, hydrateClientStorage, useEventLoop, refs, createDefaultDict } from "/utils/state.js" {% if initial_state %} +{% set all_state_names = initial_state.keys() | list %} export const initialState = {{ initial_state|json_dumps }} {% else %} export const initialState = {} {% endif %} +{% if initial_state_parametrized %} +{% set all_state_names = all_state_names + initial_state_parametrized.keys() | list %} +export const initialStateParametrized = {% raw %}{{% endraw %}{% for state_name, state in initial_state_parametrized.items() %}"{{state_name}}": createDefaultDict(() => ({{state|json_dumps}})){% if not loop.last %},{% endif %}{% endfor %}} +{% else %} +export const initialStateParametrized = {} +{% endif %} export const defaultColorMode = "{{ default_color_mode }}" export const ColorModeContext = createContext(null); export const UploadFilesContext = createContext(null); export const DispatchContext = createContext(null); export const StateContexts = { - {% for state_name in initial_state %} - {{state_name|var_name}}: createContext(null), + {% for state_name in all_state_names %} + {{state_name|format_state_name}}: createContext(null), {% endfor %} } export const EventLoopContext = createContext(null); @@ -98,25 +106,28 @@ export function EventLoopProvider({ children }) { export function StateProvider({ children }) { {% for state_name in initial_state %} - const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"]) + const [{{state_name|format_state_name}}, dispatch_{{state_name|format_state_name}}] = useReducer(applyDelta, initialState["{{state_name}}"]) + {% endfor %} + {% for state_name in initial_state_parametrized %} + const [{{state_name|format_state_name}}, dispatch_{{state_name|format_state_name}}] = useReducer(applyDelta, initialStateParametrized["{{state_name}}"]) {% endfor %} const dispatchers = useMemo(() => { return { - {% for state_name in initial_state %} - "{{state_name}}": dispatch_{{state_name|var_name}}, + {% for state_name in all_state_names %} + "{{state_name}}": dispatch_{{state_name|format_state_name}}, {% endfor %} } }, []) return ( - {% for state_name in initial_state %} - + {% for state_name in all_state_names %} + {% endfor %} {children} - {% for state_name in initial_state|reverse %} - + {% for state_name in all_state_names | reverse %} + {% endfor %} ) } diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index f67ce68581b..ae5757b7d61 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -117,8 +117,8 @@ export const isStateful = () => { if (event_queue.length === 0) { return false; } - return event_queue.some(event => event.name.startsWith("reflex___state")); -} + return event_queue.some((event) => event.name.startsWith("reflex___state")); +}; /** * Apply a delta to the state. @@ -141,7 +141,7 @@ export const queueEventIfSocketExists = async (events, socket) => { return; } await queueEvents(events, socket); -} +}; /** * Handle frontend event or send the event to the backend via Websocket. @@ -208,7 +208,10 @@ export const applyEvent = async (event, socket) => { const a = document.createElement("a"); a.hidden = true; // Special case when linking to uploaded files - a.href = event.payload.url.replace("${getBackendURL(env.UPLOAD)}", getBackendURL(env.UPLOAD)) + a.href = event.payload.url.replace( + "${getBackendURL(env.UPLOAD)}", + getBackendURL(env.UPLOAD), + ); a.download = event.payload.filename; a.click(); a.remove(); @@ -249,7 +252,7 @@ export const applyEvent = async (event, socket) => { } catch (e) { console.log("_call_script", e); if (window && window?.onerror) { - window.onerror(e.message, null, null, null, e) + window.onerror(e.message, null, null, null, e); } } return false; @@ -272,7 +275,7 @@ export const applyEvent = async (event, socket) => { if (socket) { socket.emit( "event", - JSON.stringify(event, (k, v) => (v === undefined ? null : v)) + JSON.stringify(event, (k, v) => (v === undefined ? null : v)), ); return true; } @@ -290,10 +293,9 @@ export const applyEvent = async (event, socket) => { export const applyRestEvent = async (event, socket) => { let eventSent = false; if (event.handler === "uploadFiles") { - if (event.payload.files === undefined || event.payload.files.length === 0) { // Submit the event over the websocket to trigger the event handler. - return await applyEvent(Event(event.name), socket) + return await applyEvent(Event(event.name), socket); } // Start upload, but do not wait for it, which would block other events. @@ -302,7 +304,7 @@ export const applyRestEvent = async (event, socket) => { event.payload.files, event.payload.upload_id, event.payload.on_upload_progress, - socket + socket, ); return false; } @@ -369,7 +371,7 @@ export const connect = async ( dispatch, transports, setConnectErrors, - client_storage = {} + client_storage = {}, ) => { // Get backend URL object from the endpoint. const endpoint = getBackendURL(EVENTURL); @@ -397,7 +399,7 @@ export const connect = async ( console.log("Disconnect backend before bfcache on navigation"); socket.current.disconnect(); } - } + }; // Once the socket is open, hydrate the page. socket.current.on("connect", () => { @@ -447,7 +449,7 @@ export const uploadFiles = async ( files, upload_id, on_upload_progress, - socket + socket, ) => { // return if there's no file to upload if (files === undefined || files.length === 0) { @@ -530,8 +532,8 @@ export const uploadFiles = async ( * @param handler The client handler to process event. * @returns The event object. */ -export const Event = (name, payload = {}, handler = null) => { - return { name, payload, handler }; +export const Event = (name, payload = {}, handler = null, state_key = null) => { + return { name, payload, handler, state_key }; }; /** @@ -556,7 +558,7 @@ export const hydrateClientStorage = (client_storage) => { for (const state_key in client_storage.local_storage) { const options = client_storage.local_storage[state_key]; const local_storage_value = localStorage.getItem( - options.name || state_key + options.name || state_key, ); if (local_storage_value !== null) { client_storage_values[state_key] = local_storage_value; @@ -567,14 +569,18 @@ export const hydrateClientStorage = (client_storage) => { for (const state_key in client_storage.session_storage) { const session_options = client_storage.session_storage[state_key]; const session_storage_value = sessionStorage.getItem( - session_options.name || state_key + session_options.name || state_key, ); if (session_storage_value != null) { client_storage_values[state_key] = session_storage_value; } } } - if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) { + if ( + client_storage.cookies || + client_storage.local_storage || + client_storage.session_storage + ) { return client_storage_values; } return {}; @@ -588,7 +594,7 @@ export const hydrateClientStorage = (client_storage) => { const applyClientStorageDelta = (client_storage, delta) => { // find the main state and check for is_hydrated const unqualified_states = Object.keys(delta).filter( - (key) => key.split(".").length === 1 + (key) => key.split(".").length === 1, ); if (unqualified_states.length === 1) { const main_state = delta[unqualified_states[0]]; @@ -614,15 +620,17 @@ const applyClientStorageDelta = (client_storage, delta) => { ) { const options = client_storage.local_storage[state_key]; localStorage.setItem(options.name || state_key, delta[substate][key]); - } else if( + } else if ( client_storage.session_storage && state_key in client_storage.session_storage && typeof window !== "undefined" ) { const session_options = client_storage.session_storage[state_key]; - sessionStorage.setItem(session_options.name || state_key, delta[substate][key]); + sessionStorage.setItem( + session_options.name || state_key, + delta[substate][key], + ); } - } } }; @@ -640,7 +648,7 @@ const applyClientStorageDelta = (client_storage, delta) => { export const useEventLoop = ( dispatch, initial_events = () => [], - client_storage = {} + client_storage = {}, ) => { const socket = useRef(null); const router = useRouter(); @@ -685,36 +693,38 @@ export const useEventLoop = ( query, asPath, }))(router), - })) + })), ); sentHydrate.current = true; } }, [router.isReady]); - // Handle frontend errors and send them to the backend via websocket. - useEffect(() => { - - if (typeof window === 'undefined') { - return; - } - - window.onerror = function (msg, url, lineNo, columnNo, error) { - addEvents([Event(`${exception_state_name}.handle_frontend_exception`, { + // Handle frontend errors and send them to the backend via websocket. + useEffect(() => { + if (typeof window === "undefined") { + return; + } + + window.onerror = function (msg, url, lineNo, columnNo, error) { + addEvents([ + Event(`${exception_state_name}.handle_frontend_exception`, { stack: error.stack, - })]) - return false; - } + }), + ]); + return false; + }; - //NOTE: Only works in Chrome v49+ - //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events - window.onunhandledrejection = function (event) { - addEvents([Event(`${exception_state_name}.handle_frontend_exception`, { - stack: event.reason.stack, - })]) - return false; - } - - },[]) + //NOTE: Only works in Chrome v49+ + //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events + window.onunhandledrejection = function (event) { + addEvents([ + Event(`${exception_state_name}.handle_frontend_exception`, { + stack: event.reason.stack, + }), + ]); + return false; + }; + }, []); // Main event loop. useEffect(() => { @@ -731,7 +741,7 @@ export const useEventLoop = ( dispatch, ["websocket", "polling"], setConnectErrors, - client_storage + client_storage, ); } (async () => { @@ -764,7 +774,7 @@ export const useEventLoop = ( vars[storage_to_state_map[e.key]] = e.newValue; const event = Event( `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`, - { vars: vars } + { vars: vars }, ); addEvents([event], e); } @@ -777,11 +787,11 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { const change_start = () => { - const main_state_dispatch = dispatch["state"] + const main_state_dispatch = dispatch["state"]; if (main_state_dispatch !== undefined) { - main_state_dispatch({ is_hydrated: false }) + main_state_dispatch({ is_hydrated: false }); } - } + }; const change_complete = () => addEvents(onLoadInternalEvent()); router.events.on("routeChangeStart", change_start); router.events.on("routeChangeComplete", change_complete); @@ -846,7 +856,7 @@ export const getRefValues = (refs) => { return refs.map((ref) => ref.current ? ref.current.value || ref.current.getAttribute("aria-valuenow") - : null + : null, ); }; @@ -865,3 +875,17 @@ export const spreadArraysOrObjects = (first, second) => { throw new Error("Both parameters must be either arrays or objects."); } }; + +export function createDefaultDict(defaultValueFactory) { + return new Proxy( + {}, + { + get: (target, name) => { + if (!(name in target)) { + target[name] = defaultValueFactory(); + } + return target[name]; + }, + }, + ); +} diff --git a/reflex/app.py b/reflex/app.py index 658ba1a1f5c..50561c046a2 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1054,11 +1054,14 @@ def _submit_work(fn, *args, **kwargs): compiler_utils.write_page(output_path, code) @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + async def modify_state( + self, token: str, key: str | None = None + ) -> AsyncIterator[BaseState]: """Modify the state out of band. Args: token: The token to modify the state for. + key: The key for parameterized states. Yields: The state to modify. @@ -1070,7 +1073,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: raise RuntimeError("App has not been initialized yet.") # Get exclusive access to the state. - async with self.state_manager.modify_state(token) as state: + async with self.state_manager.modify_state(token, key=key) as state: # No other event handler can modify the state while in this context. yield state delta = state.get_delta() @@ -1241,6 +1244,7 @@ async def process( """ from reflex.utils import telemetry + print(f"process {event=} {sid=}") try: # Add request data to the state. router_data = event.router_data @@ -1254,7 +1258,9 @@ async def process( } ) # Get the state for the session exclusively. - async with app.state_manager.modify_state(event.substate_token) as state: + async with app.state_manager.modify_state( + event.substate_token, key=event.state_key + ) as state: # re-assign only when the value is different if state.router_data != router_data: # assignment will recurse into substates and force recalculation of @@ -1279,6 +1285,9 @@ async def process( # Process the event synchronously. async for update in state._process(event): # Postprocess the event. + print( + f"creating state update {type(state).__name__=} {update.events=}" + ) update = await app._postprocess(state, event, update) # Yield the update. @@ -1472,7 +1481,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None: self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) ) - async def on_event(self, sid, data): + async def on_event(self, sid: str, data: str): """Event for receiving front-end websocket events. Args: @@ -1481,6 +1490,7 @@ async def on_event(self, sid, data): """ # Get the event. event = Event.parse_raw(data) + print(f"on_event: {sid=} {event=}") self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token @@ -1491,7 +1501,7 @@ async def on_event(self, sid, data): assert environ is not None # Get the client headers. - headers = { + headers: dict[str, str] = { k.decode("utf-8"): v.decode("utf-8") for (k, v) in environ["asgi.scope"]["headers"] } @@ -1507,7 +1517,7 @@ async def on_event(self, sid, data): # Emit the update from processing the event. await self.emit_update(update=update, sid=sid) - async def on_ping(self, sid): + async def on_ping(self, sid: str): """Event for testing the API endpoint. Args: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 4345e244ff4..d6bf32d4f41 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -87,6 +87,9 @@ def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) return ( templates.CONTEXT.render( initial_state=utils.compile_state(state), + initial_state_parametrized=utils.compile_state( + state, only_parametrized=True + ), state_name=state.get_name(), client_storage=utils.compile_client_storage(state), is_dev_mode=not is_prod_mode(), diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index c868a0cbb74..626b56885c2 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -19,7 +19,7 @@ def __init__(self) -> None: ) self.filters["json_dumps"] = json_dumps self.filters["react_setter"] = lambda state: f"set{state.capitalize()}" - self.filters["var_name"] = format_state_name + self.filters["format_state_name"] = format_state_name self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE) self.globals["const"] = { "socket": constants.CompileVars.SOCKET, diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 1b69539ac72..e8436a3590a 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -137,22 +137,27 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) } -def compile_state(state: Type[BaseState]) -> dict: +def compile_state(state: Type[BaseState], only_parametrized: bool = False) -> dict: """Compile the state of the app. Args: state: The app state object. + only_parametrized: Whether to include only parametrized states. Returns: A dictionary of the compiled state. """ try: - initial_state = state(_reflex_internal_init=True).dict(initial=True) + initial_state = state(_reflex_internal_init=True).dict( + initial=True, only_parametrized=only_parametrized + ) except Exception as e: console.warn( f"Failed to compile initial state with computed vars, excluding them: {e}" ) - initial_state = state(_reflex_internal_init=True).dict(include_computed=False) + initial_state = state(_reflex_internal_init=True).dict( + initial=True, only_parametrized=only_parametrized, include_computed=False + ) return format.format_state(initial_state) diff --git a/reflex/event.py b/reflex/event.py index 7e98ade38d6..69d244d4260 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -43,6 +43,9 @@ class Event(Base): # The event payload. payload: Dict[str, Any] = {} + # State key + state_key: str = "" + @property def substate_token(self) -> str: """Get the substate token for the event. @@ -135,12 +138,15 @@ class EventHandler(EventActionsMixin): """An event handler responds to an event to update the state.""" # The function to call in response to the event. - fn: Any + fn: Callable[..., Any] # The full name of the state class this event handler is attached to. # Empty string means this event handler is a server side event. state_full_name: str = "" + # The key for parametrized states. + state_key: Optional[Union[Var[str], str]] = None + class Config: """The Pydantic config.""" @@ -959,6 +965,7 @@ def fix_events( events: list[EventHandler | EventSpec] | None, token: str, router_data: dict[str, Any] | None = None, + state_key: str = "", ) -> list[Event]: """Fix a list of events returned by an event handler. @@ -966,6 +973,7 @@ def fix_events( events: The events to fix. token: The user token. router_data: The optional router data to set in the event. + state_key: The key for parametrized states. Returns: The fixed events. @@ -986,7 +994,7 @@ def fix_events( out.append(e) continue if not isinstance(e, (EventHandler, EventSpec)): - e = EventHandler(fn=e) + e = EventHandler(fn=e, state_key=state_key) # Otherwise, create an event from the event spec. if isinstance(e, EventHandler): e = e() @@ -1007,6 +1015,7 @@ def fix_events( name=name, payload=payload, router_data=event_router_data, + state_key=state_key, ) ) diff --git a/reflex/state.py b/reflex/state.py index 49b5bd4a413..6ac315f8679 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -54,9 +54,11 @@ from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer +from reflex.utils.types import override from reflex.vars import BaseVar, ComputedVar, Var, computed_var if TYPE_CHECKING: + from reflex.app import App from reflex.components.component import Component @@ -333,7 +335,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): parent_state: Optional[BaseState] = None # The substates of the state. - substates: Dict[str, BaseState] = {} + substates: Dict[str, BaseState | Dict[str, BaseState]] = {} # The set of dirty vars. dirty_vars: Set[str] = set() @@ -359,6 +361,31 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A special event handler for setting base vars. setvar: ClassVar[EventHandler] + # dynamic state, allows to initialize the same state multiple times for one token + _key: str = "" + + # Whether the state is parametrized. + parametrized: ClassVar[bool] = False + + @classmethod + def __class_getitem__(cls, item: Var) -> ParametrizedState: + """Parametrize the state. + + Args: + item: The Var to parametrize the state with. + + Returns: + The parametrized state. + """ + cls.parametrized = True + return ParametrizedState(state_cls=cls, key=item) + + # alternative to __class_getitem__ + # @classmethod + # def parametrize(cls, key: str) -> ParametrizedState: + # cls.parametrized = True + # return ParametrizedState(state_cls=cls, key=key) + def __init__( self, *args, @@ -383,6 +410,11 @@ def __init__( """ from reflex.utils.exceptions import ReflexRuntimeError + _key = kwargs.get("_key", "") + print(f"init state {self.__class__.__name__} {init_substates=} {_key=}") + if self.parametrized and not _key and not init_substates: + raise ValueError("Parametrized state must be initialized with a key") + if not _reflex_internal_init and not is_testing_env(): raise ReflexRuntimeError( "State classes should not be instantiated directly in a Reflex app. " @@ -393,11 +425,15 @@ def __init__( # Setup the substates (for memory state manager only). if init_substates: - for substate in self.get_substates(): - self.substates[substate.get_name()] = substate( + for substate_cls in self.get_substates(): + substate = substate_cls( parent_state=self, _reflex_internal_init=True, ) + if substate_cls.parametrized: + # hacky way to include parametrized state in initial state dict + substate = {"": substate} + self.substates[substate_cls.get_name()] = substate # Create a fresh copy of the backend variables for this instance self._backend_vars = copy.deepcopy( @@ -917,7 +953,7 @@ def _set_var(cls, prop: BaseVar): setattr(cls, prop._var_name, prop) @classmethod - def _create_event_handler(cls, fn): + def _create_event_handler(cls, fn: Callable[..., Any]) -> EventHandler: """Create an event handler for the given function. Args: @@ -1130,7 +1166,7 @@ def reset(self): setattr(self, prop_name, default) # Recursively reset the substates. - for substate in self.substates.values(): + for substate in self.all_substates: substate.reset() def _reset_client_storage(self): @@ -1147,14 +1183,15 @@ def _reset_client_storage(self): setattr(self, prop_name, copy.deepcopy(field.default)) # Recursively reset the substate client storage. - for substate in self.substates.values(): + for substate in self.all_substates: substate._reset_client_storage() - def get_substate(self, path: Sequence[str]) -> BaseState: + def get_substate(self, path: Sequence[str], key: str | None = None) -> BaseState: """Get the substate. Args: path: The path to the substate. + key: The key for parametrized states. Returns: The substate. @@ -1162,6 +1199,7 @@ def get_substate(self, path: Sequence[str]) -> BaseState: Raises: ValueError: If the substate is not found. """ + print(f"get_substate {self.__class__.__name__} {path=} {key=}") if len(path) == 0: return self if path[0] == self.get_name(): @@ -1170,7 +1208,12 @@ def get_substate(self, path: Sequence[str]) -> BaseState: path = path[1:] if path[0] not in self.substates: raise ValueError(f"Invalid path: {path}") - return self.substates[path[0]].get_substate(path[1:]) + first_state = self.substates[path[0]] + if not isinstance(first_state, BaseState): + if key is None: + raise ValueError("Substate is a dict, but no key was provided.") + first_state = first_state[key] + return first_state.get_substate(path[1:], first_state._key) @classmethod def _get_common_ancestor(cls, other: Type[BaseState]) -> str: @@ -1271,7 +1314,8 @@ async def _populate_parent_states(self, target_state_cls: Type[BaseState]): pass parent_state = await state_manager.get_state( token=_substate_key( - self.router.session.client_token, parent_state_name + self._token, + parent_state_name, ), top_level=False, get_substates=False, @@ -1281,6 +1325,15 @@ async def _populate_parent_states(self, target_state_cls: Type[BaseState]): # Return the direct parent of target_state_cls for subsequent linking. return parent_state + @property + def _token(self) -> str: + """Get the token of the state. + + Returns: + The token of the state. + """ + return self.router.session.client_token + def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: """Get a state instance from the cache. @@ -1319,10 +1372,11 @@ async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: "(All states should already be available -- this is likely a bug).", ) return await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=_substate_key(self._token, state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, + key=self._key, ) async def get_state(self, state_cls: Type[BaseState]) -> BaseState: @@ -1363,7 +1417,7 @@ def _get_event_handler( # Get the event handler. path = event.name.split(".") path, name = path[:-1], path[-1] - substate = self.get_substate(path) + substate = self.get_substate(path, event.state_key) if not substate: raise ValueError( "The value of state cannot be None when processing an event." @@ -1385,6 +1439,7 @@ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: Yields: The state update after processing the event. """ + print(f"_process {event=} {type(self).__name__=}") # Get the event handler. substate, handler = self._get_event_handler(event) @@ -1445,19 +1500,24 @@ def _as_state_update( """ # get the delta from the root of the state tree state = self + print(f"get the delta from the root of the state tree: {type(state).__name__=}") while state.parent_state is not None: state = state.parent_state + print(f"get parent state: {type(state).__name__=}") - token = self.router.session.client_token + token = self._token # Convert valid EventHandler and EventSpec into Event - fixed_events = fix_events(self._check_valid(handler, events), token) + fixed_events = fix_events( + self._check_valid(handler, events), token, state_key=self._key + ) try: # Get the delta after processing the event. delta = state.get_delta() state._clean() + # TODO: add StateDelta class? or make StateUpdate.delta a dict? return StateUpdate( delta=delta, events=fixed_events, @@ -1481,6 +1541,7 @@ def _as_state_update( event_specs_correct_type, token, router_data=state.router_data, + state_key=self._key, ) return StateUpdate( events=fixed_events, @@ -1502,6 +1563,9 @@ async def _process_event( """ from reflex.utils import telemetry + print( + f"process event: {handler.fn.__qualname__=} {handler.state_full_name=} {handler.state_key=} {type(state).__name__=} {payload=}" + ) # Get the function to process the event. fn = functools.partial(handler.fn, state) @@ -1645,12 +1709,18 @@ def get_delta(self) -> Delta: if not types.is_backend_base_variable(prop, type(self)) } if len(subdelta) > 0: + if self.parametrized: + subdelta = {self._key: subdelta} delta[self.get_full_name()] = subdelta # Recursively find the substate deltas. - substates = self.substates - for substate in self.dirty_substates.union(self._always_dirty_substates): - delta.update(substates[substate].get_delta()) + for substate_name in self.dirty_substates.union(self._always_dirty_substates): + substate = self.substates[substate_name] + if isinstance(substate, BaseState): + delta.update(substate.get_delta()) + else: + for substate_instance in substate.values(): + delta.update(substate_instance.get_delta()) # Format the delta. delta = format.format_state(delta) @@ -1678,13 +1748,17 @@ def _mark_dirty(self): def _mark_dirty_substates(self): """Propagate dirty var / computed var status into substates.""" - substates = self.substates for var in self.dirty_vars: for substate_name in self._substate_var_dependencies[var]: self.dirty_substates.add(substate_name) - substate = substates[substate_name] - substate.dirty_vars.add(var) - substate._mark_dirty() + substate = self.substates[substate_name] + if isinstance(substate, BaseState): + substate.dirty_vars.add(var) + substate._mark_dirty() + else: + for substate_instance in substate.values(): + substate_instance.dirty_vars.add(var) + substate_instance._mark_dirty() def _update_was_touched(self): """Update the _was_touched flag based on dirty_vars.""" @@ -1714,10 +1788,15 @@ def _clean(self): self._update_was_touched() # Recursively clean the substates. - for substate in self.dirty_substates: - if substate not in self.substates: + for substate_name in self.dirty_substates: + if substate_name not in self.substates: continue - self.substates[substate]._clean() + substate = self.substates[substate_name] + if isinstance(substate, BaseState): + substate._clean() + else: + for substate_instance in substate.values(): + substate_instance._clean() # Clean this state. self.dirty_vars = set() @@ -1739,13 +1818,18 @@ def get_value(self, key: str) -> Any: return super().get_value(key) def dict( - self, include_computed: bool = True, initial: bool = False, **kwargs + self, + include_computed: bool = True, + initial: bool = False, + only_parametrized: bool | None = None, + **kwargs, ) -> dict[str, Any]: """Convert the object to a dictionary. Args: include_computed: Whether to include computed vars. initial: Whether to get the initial value of computed vars. + only_parametrized: Whether to include only parametrized states. **kwargs: Kwargs to pass to the pydantic dict method. Returns: @@ -1761,7 +1845,7 @@ def dict( prop_name: self.get_value(getattr(self, prop_name)) for prop_name in self.base_vars } - if initial: + if initial and include_computed: computed_vars = { # Include initial computed vars. prop_name: ( @@ -1783,17 +1867,47 @@ def dict( else: computed_vars = {} variables = {**base_vars, **computed_vars} - d = { - self.get_full_name(): {k: variables[k] for k in sorted(variables)}, - } + d_variables = {k: variables[k] for k in sorted(variables)} + if initial: + if (only_parametrized and self.parametrized) or ( + not only_parametrized and not self.parametrized + ): + d = {self.get_full_name(): d_variables} + else: + d = {} + else: + if self.parametrized: + d_variables = {self._key: d_variables} + d = {self.get_full_name(): d_variables} for substate_d in [ - v.dict(include_computed=include_computed, initial=initial, **kwargs) - for v in self.substates.values() + v.dict( + include_computed=include_computed, + initial=initial, + only_parametrized=only_parametrized, + **kwargs, + ) + for v in self.all_substates ]: d.update(substate_d) return d + @property + def all_substates(self) -> list[BaseState]: + """Get all substates. + + Returns: + A list of all substates. + """ + substates = [] + for substate in self.substates.values(): + if isinstance(substate, BaseState): + substates.append(substate) + else: + for substate_instance in substate.values(): + substates.append(substate_instance) + return substates + async def __aenter__(self) -> BaseState: """Enter the async context manager protocol. @@ -1838,6 +1952,42 @@ def __getstate__(self): return state +class ParametrizedState: + """A proxy class for parametrized state compile time access.""" + + state_cls: Type[BaseState] + + key: Var[str] | str = "" + + def __init__(self, state_cls: Type[BaseState], key: Var[str]) -> None: + """Initialize the ParametrizedState. + + Args: + state_cls: The class of the state. + key: The key of the parametrized state. + """ + self.state_cls = state_cls + self.key = key + + def __getattribute__(self, name: str) -> Any: + """Get the attribute.""" + if name in ["state_cls", "key"]: + return super().__getattribute__(name) + obj = getattr(self.state_cls, name) + if isinstance(obj, Var) and obj._var_data is not None: + new_var_data = obj._var_data.copy() + new_var_data.state = ( + f"{self.state_cls.get_full_name()}[{self.key._var_full_name}]" + ) + obj = obj._replace(_var_data=new_var_data) + print(f"var access trough parametrized state for: {obj._var_name=}") + if isinstance(obj, EventHandler): + obj = EventHandler( + fn=obj.fn, state_full_name=obj.state_full_name, state_key=self.key + ) + return obj + + EventHandlerSetVar.update_forward_refs() @@ -1909,8 +2059,9 @@ def on_load_internal(self) -> list[Event | EventSpec] | None: return [ *fix_events( load_events, - self.router.session.client_token, + self._token, router_data=self.router_data, + state_key=self._key, ), State.set_is_hydrated(True), # type: ignore ] @@ -2028,7 +2179,15 @@ async def bg_increment(self): self.counter += 1 """ - def __init__(self, state_instance): + if TYPE_CHECKING: + _self_app: App + _self_substate_path: list[str] + _self_substate_key: str | None + _self_actx: contextlib._AsyncGeneratorContextManager[BaseState] | None + _self_mutable: bool + _self_actx_lock: asyncio.Lock + + def __init__(self, state_instance: BaseState): """Create a proxy for a state instance. Args: @@ -2038,6 +2197,7 @@ def __init__(self, state_instance): # compile is not relevant to backend logic self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) self._self_substate_path = state_instance.get_full_name().split(".") + self._self_substate_key = state_instance._key self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() @@ -2059,11 +2219,15 @@ async def __aenter__(self) -> StateProxy: token=_substate_key( self.__wrapped__.router.session.client_token, self._self_substate_path, - ) + ), + key=self._self_substate_key, ) mutable_state = await self._self_actx.__aenter__() super().__setattr__( - "__wrapped__", mutable_state.get_substate(self._self_substate_path) + "__wrapped__", + mutable_state.get_substate( + path=self._self_substate_path, key=self._self_substate_key + ), ) self._self_mutable = True return self @@ -2265,11 +2429,12 @@ def create(cls, state: Type[BaseState]): return StateManagerMemory(state=state) @abstractmethod - async def get_state(self, token: str) -> BaseState: + async def get_state(self, token: str, key: str | None = None) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. + key: The key for parameterized states. Returns: The state for the token. @@ -2288,11 +2453,14 @@ async def set_state(self, token: str, state: BaseState): @abstractmethod @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + async def modify_state( + self, token: str, key: str | None = None + ) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: token: The token to modify the state for. + key: The key for parameterized states. Yields: The state for the token. @@ -2319,11 +2487,13 @@ class Config: "_states_locks": {"exclude": True}, } - async def get_state(self, token: str) -> BaseState: + @override + async def get_state(self, token: str, key: str | None = None) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. + key: The key for parameterized states. Returns: The state for the token. @@ -2334,6 +2504,7 @@ async def get_state(self, token: str) -> BaseState: self.states[token] = self.state(_reflex_internal_init=True) return self.states[token] + @override async def set_state(self, token: str, state: BaseState): """Set the state for a token. @@ -2343,12 +2514,16 @@ async def set_state(self, token: str, state: BaseState): """ pass + @override @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + async def modify_state( + self, token: str, key: str | None = None + ) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: token: The token to modify the state for. + key: The key for parameterized states. Yields: The state for the token. @@ -2484,11 +2659,15 @@ async def _populate_substates( state: The state instance to populate substates for. all_substates: Whether to fetch all substates or just required substates. """ + print( + f"populate_substates: {token=}, {type(state).__name__=}, {all_substates=}" + ) client_token, _ = _split_substate_key(token) if all_substates: # All substates are requested. - fetch_substates = state.get_substates() + # fetch_substates = state.get_substates() + fetch_substates = state.substates else: # Only _potentially_dirty_substates need to be fetched to recalc computed vars. fetch_substates = state._potentially_dirty_substates() @@ -2507,11 +2686,19 @@ async def _populate_substates( ) for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task + substate = await substate_task + if substate.parametrized: + d = state.substates.get(substate_name, {}) + d[substate._key] = substate + state.substates[substate_name] = d + else: + state.substates[substate_name] = substate + @override async def get_state( self, token: str, + key: str | None = None, top_level: bool = True, get_substates: bool = True, parent_state: BaseState | None = None, @@ -2520,6 +2707,7 @@ async def get_state( Args: token: The token to get the state for. + key: The key for parameterized states. top_level: If true, return an instance of the top-level state (self.state). get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. @@ -2540,7 +2728,11 @@ async def get_state( "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" ) + if key: + token = f"{token}_{key}" + # Fetch the serialized substate from redis. + print(f"redis get: {token=}") redis_state = await self.redis.get(token) if redis_state is not None: @@ -2552,7 +2744,12 @@ async def get_state( parent_state = await self._get_parent_state(token) # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: - parent_state.substates[state.get_name()] = state + if state.parametrized: + d = parent_state.substates.get(state.get_name(), {}) + d[state._key] = state + parent_state.substates[state.get_name()] = d + else: + parent_state.substates[state.get_name()] = state state.parent_state = parent_state # Populate substates if requested. await self._populate_substates(token, state, all_substates=get_substates) @@ -2572,10 +2769,16 @@ async def get_state( parent_state=parent_state, init_substates=False, _reflex_internal_init=True, + _key=key, ) # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: - parent_state.substates[state.get_name()] = state + if state.parametrized: + d = parent_state.substates.get(state.get_name(), {}) + d[state._key] = state + parent_state.substates[state.get_name()] = d + else: + parent_state.substates[state.get_name()] = state state.parent_state = parent_state # Populate substates for the newly created state. await self._populate_substates(token, state, all_substates=get_substates) @@ -2608,6 +2811,7 @@ def _warn_if_too_large( ) self._warned_about_state_size.add(state_full_name) + @override async def set_state( self, token: str, @@ -2644,7 +2848,7 @@ async def set_state( # Recursively set_state on all known substates. tasks = [] - for substate in state.substates.values(): + for substate in state.all_substates: tasks.append( asyncio.create_task( self.set_state( @@ -2658,8 +2862,12 @@ async def set_state( if state._get_was_touched(): pickle_state = dill.dumps(state, byref=True) self._warn_if_too_large(state, len(pickle_state)) + redis_token = _substate_key(client_token, state) + if state.parametrized: + redis_token = f"{redis_token}_{state._key}" + print(f"redis set: {redis_token=} {type(state).__name__}") await self.redis.set( - _substate_key(client_token, state), + redis_token, pickle_state, ex=self.token_expiration, ) @@ -2668,18 +2876,23 @@ async def set_state( for t in tasks: await t + @override @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + async def modify_state( + self, token: str, key: str | None = None + ) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: token: The token to modify the state for. + key: The key for parameterized states. Yields: The state for the token. """ + print(f"modify_state: {token=} {key=}") async with self._lock(token) as lock_id: - state = await self.get_state(token) + state = await self.get_state(token=token, key=key) yield state await self.set_state(token, state, lock_id) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index e163ebaac6b..ecfe67d4a9f 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -567,7 +567,15 @@ def format_event(event_spec: EventSpec) -> str: event_args.append(wrap(args, "{")) if event_spec.client_handler_name: - event_args.append(wrap(event_spec.client_handler_name, '"')) + client_handler_name = wrap(event_spec.client_handler_name, '"') + event_args.append(client_handler_name) + + if event_spec.handler.state_key is not None: + if not event_spec.client_handler_name: + event_args.append("null") + state_key = event_spec.handler.state_key._var_full_name + event_args.append(state_key) + return f"Event({', '.join(event_args)})" @@ -721,7 +729,7 @@ def format_state(value: Any, key: Optional[str] = None) -> Any: if isinstance(value, dict): return {k: format_state(v, k) for k, v in value.items()} - # Handle lists, sets, typles. + # Handle lists, sets, tuples. if isinstance(value, types.StateIterBases): return [format_state(v) for v in value] diff --git a/reflex/vars.py b/reflex/vars.py index c6ad4eed58e..f229c21a574 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1948,11 +1948,11 @@ def _var_full_name(self) -> str: ) ) - def _var_set_state(self, state: Type[BaseState] | str) -> Any: + def _var_set_state(self, state: Type[BaseState]) -> Any: """Set the state of the var. Args: - state: The state to set or the full name of the state. + state: The state to set. Returns: The var with the set state.