Skip to content

Commit

Permalink
rename get_app to get_app_module, add get_app helper
Browse files Browse the repository at this point in the history
prepare state manager iter
  • Loading branch information
benedikt-bartscher committed Aug 9, 2024
1 parent 911c2af commit c93441b
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 14 deletions.
35 changes: 35 additions & 0 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
RouterData,
State,
StateManager,
StateManagerMemory,
StateUpdate,
_substate_key,
code_uses_state_contexts,
Expand Down Expand Up @@ -1100,6 +1101,40 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
sid=state.router.session.session_id,
)

async def modify_states(
self,
substate_cls: Type[BaseState] | None = None,
from_state: BaseState | None = None,
) -> AsyncIterator[BaseState]:
"""Iterate over the states.
Args:
substate_cls: The substate class to iterate over.
from_state: The state from which this method is called.
Yields:
The states to modify.
Raises:
NotImplementedError: If the state manager is not StateManagerMemory
"""
# TODO: Implement for StateManagerRedis
if not isinstance(self.state_manager, StateManagerMemory):
raise NotImplementedError

for token in self.state_manager.states:
# avoid deadlock when calling from event handler/background task
if (
from_state is not None
and from_state.router.session.client_token == token
):
yield from_state
continue
async with self.modify_state(token) as state:
if substate_cls is not None:
state = state.get_substate(substate_cls.get_name())
yield state

def _process_background(
self, state: BaseState, event: Event
) -> asyncio.Task | None:
Expand Down
6 changes: 3 additions & 3 deletions reflex/app_module_for_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from reflex import constants
from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app
from reflex.utils.prerequisites import get_app_module

if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'")

telemetry.send("compile")
app_module = get_app(reload=False)
app_module = get_app_module(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.8 and py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
Expand All @@ -30,7 +30,7 @@
# ensure only "app" is exposed.
del app_module
del compile_future
del get_app
del get_app_module
del is_prod_mode
del telemetry
del constants
Expand Down
12 changes: 6 additions & 6 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ def _as_state_update(
except Exception as ex:
state._clean()

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()

event_specs = app_instance.backend_exception_handler(ex)

Expand Down Expand Up @@ -1541,7 +1541,7 @@ async def _process_event(
except Exception as ex:
telemetry.send_error(ex, context="backend")

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()

event_specs = app_instance.backend_exception_handler(ex)

Expand Down Expand Up @@ -1862,7 +1862,7 @@ def handle_frontend_exception(self, stack: str) -> None:
stack: The stack trace of the exception.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()
app_instance.frontend_exception_handler(Exception(stack))


Expand Down Expand Up @@ -1901,7 +1901,7 @@ def on_load_internal(self) -> list[Event | EventSpec] | None:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app = prerequisites.get_app()
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
Expand Down Expand Up @@ -2044,7 +2044,7 @@ def __init__(
"""
super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
self._self_app = prerequisites.get_app()
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None
self._self_mutable = False
Expand Down Expand Up @@ -2857,7 +2857,7 @@ def get_state_manager() -> StateManager:
Returns:
The state manager.
"""
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app = prerequisites.get_app()
return app.state_manager


Expand Down
24 changes: 21 additions & 3 deletions reflex/utils/prerequisites.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from fileinput import FileInput
from pathlib import Path
from types import ModuleType
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional

import httpx
import typer
Expand All @@ -39,6 +39,9 @@
from reflex.utils import console, path_ops, processes
from reflex.utils.format import format_library_name

if TYPE_CHECKING:
from reflex.app import App

CURRENTLY_INSTALLING_NODE = False


Expand Down Expand Up @@ -235,7 +238,7 @@ def windows_npm_escape_hatch() -> bool:
return os.environ.get("REFLEX_USE_NPM", "").lower() in ["true", "1", "yes"]


def get_app(reload: bool = False) -> ModuleType:
def get_app_module(reload: bool = False) -> ModuleType:
"""Get the app module based on the default config.
Args:
Expand Down Expand Up @@ -276,6 +279,21 @@ def get_app(reload: bool = False) -> ModuleType:
raise


def get_app(reload: bool = False) -> App:
"""Get the app based on the default config.
Args:
reload: Re-import the app module from disk
Returns:
The app based on the default config.
Raises:
RuntimeError: If the app name is not set in the config.
"""
return getattr(get_app_module(reload=reload), constants.CompileVars.APP)


def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
"""Get the app module based on the default config after first compiling it.
Expand All @@ -286,7 +304,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns:
The compiled app based on the default config.
"""
app_module = get_app(reload=reload)
app_module = get_app_module(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.8 and py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def app_module_mock(monkeypatch) -> mock.Mock:
"""
app_module_mock = mock.Mock()
get_app_mock = mock.Mock(return_value=app_module_mock)
monkeypatch.setattr(prerequisites, "get_app", get_app_mock)
monkeypatch.setattr(prerequisites, "get_app_module", get_app_mock)
return app_module_mock


Expand Down
2 changes: 1 addition & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,7 +1817,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
def _mock_get_app(*args, **kwargs):
return app_module

monkeypatch.setattr(prerequisites, "get_app", _mock_get_app)
monkeypatch.setattr(prerequisites, "get_app_module", _mock_get_app)
return app


Expand Down

0 comments on commit c93441b

Please sign in to comment.