diff --git a/reflex/app.py b/reflex/app.py index 7e40a95bf7..c5f1a3fa79 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -79,6 +79,7 @@ RouterData, State, StateManager, + StateManagerMemory, StateUpdate, _substate_key, code_uses_state_contexts, @@ -1100,6 +1101,40 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: sid=state.router.session.session_id, ) + async def modify_states( + self, + substate_cls: Type[BaseState] | None = None, + from_state: BaseState | None = None, + ) -> AsyncIterator[BaseState]: + """Iterate over the states. + + Args: + substate_cls: The substate class to iterate over. + from_state: The state from which this method is called. + + Yields: + The states to modify. + + Raises: + NotImplementedError: If the state manager is not StateManagerMemory + """ + # TODO: Implement for StateManagerRedis + if not isinstance(self.state_manager, StateManagerMemory): + raise NotImplementedError + + for token in self.state_manager.states: + # avoid deadlock when calling from event handler/background task + if ( + from_state is not None + and from_state.router.session.client_token == token + ): + yield from_state + continue + async with self.modify_state(token) as state: + if substate_cls is not None: + state = state.get_substate(substate_cls.get_name()) + yield state + def _process_background( self, state: BaseState, event: Event ) -> asyncio.Task | None: diff --git a/reflex/app_module_for_backend.py b/reflex/app_module_for_backend.py index cae136354b..6a5d930a8f 100644 --- a/reflex/app_module_for_backend.py +++ b/reflex/app_module_for_backend.py @@ -7,13 +7,13 @@ from reflex import constants from reflex.utils import telemetry from reflex.utils.exec import is_prod_mode -from reflex.utils.prerequisites import get_app +from reflex.utils.prerequisites import get_app_module if constants.CompileVars.APP != "app": raise AssertionError("unexpected variable name for 'app'") telemetry.send("compile") -app_module = get_app(reload=False) +app_module = get_app_module(reload=False) app = getattr(app_module, constants.CompileVars.APP) # For py3.8 and py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). @@ -30,7 +30,7 @@ # ensure only "app" is exposed. del app_module del compile_future -del get_app +del get_app_module del is_prod_mode del telemetry del constants diff --git a/reflex/state.py b/reflex/state.py index b0c6646ce9..bf1a454a1d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1467,7 +1467,7 @@ def _as_state_update( except Exception as ex: state._clean() - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance = prerequisites.get_app() event_specs = app_instance.backend_exception_handler(ex) @@ -1541,7 +1541,7 @@ async def _process_event( except Exception as ex: telemetry.send_error(ex, context="backend") - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance = prerequisites.get_app() event_specs = app_instance.backend_exception_handler(ex) @@ -1862,7 +1862,7 @@ def handle_frontend_exception(self, stack: str) -> None: stack: The stack trace of the exception. """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance = prerequisites.get_app() app_instance.frontend_exception_handler(Exception(stack)) @@ -1901,7 +1901,7 @@ def on_load_internal(self) -> list[Event | EventSpec] | None: The list of events to queue for on load handling. """ # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app = prerequisites.get_app() load_events = app.get_load_events(self.router.page.path) if not load_events: self.is_hydrated = True @@ -2044,7 +2044,7 @@ def __init__( """ super().__init__(state_instance) # compile is not relevant to backend logic - self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + self._self_app = prerequisites.get_app() self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False @@ -2857,7 +2857,7 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app = prerequisites.get_app() return app.state_manager diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 72eb75b90c..474a41a7d7 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -22,7 +22,7 @@ from fileinput import FileInput from pathlib import Path from types import ModuleType -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional import httpx import typer @@ -39,6 +39,9 @@ from reflex.utils import console, path_ops, processes from reflex.utils.format import format_library_name +if TYPE_CHECKING: + from reflex.app import App + CURRENTLY_INSTALLING_NODE = False @@ -235,7 +238,7 @@ def windows_npm_escape_hatch() -> bool: return os.environ.get("REFLEX_USE_NPM", "").lower() in ["true", "1", "yes"] -def get_app(reload: bool = False) -> ModuleType: +def get_app_module(reload: bool = False) -> ModuleType: """Get the app module based on the default config. Args: @@ -276,6 +279,21 @@ def get_app(reload: bool = False) -> ModuleType: raise +def get_app(reload: bool = False) -> App: + """Get the app based on the default config. + + Args: + reload: Re-import the app module from disk + + Returns: + The app based on the default config. + + Raises: + RuntimeError: If the app name is not set in the config. + """ + return getattr(get_app_module(reload=reload), constants.CompileVars.APP) + + def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: """Get the app module based on the default config after first compiling it. @@ -286,7 +304,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: Returns: The compiled app based on the default config. """ - app_module = get_app(reload=reload) + app_module = get_app_module(reload=reload) app = getattr(app_module, constants.CompileVars.APP) # For py3.8 and py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). diff --git a/tests/conftest.py b/tests/conftest.py index 589d35cd71..c291acc1cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,7 +50,7 @@ def app_module_mock(monkeypatch) -> mock.Mock: """ app_module_mock = mock.Mock() get_app_mock = mock.Mock(return_value=app_module_mock) - monkeypatch.setattr(prerequisites, "get_app", get_app_mock) + monkeypatch.setattr(prerequisites, "get_app_module", get_app_mock) return app_module_mock diff --git a/tests/test_state.py b/tests/test_state.py index aa5705b09a..0e98769e0f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1817,7 +1817,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: def _mock_get_app(*args, **kwargs): return app_module - monkeypatch.setattr(prerequisites, "get_app", _mock_get_app) + monkeypatch.setattr(prerequisites, "get_app_module", _mock_get_app) return app