diff --git a/integration/test_component_state.py b/integration/test_component_state.py index e903a1b7425..77b8b3fa183 100644 --- a/integration/test_component_state.py +++ b/integration/test_component_state.py @@ -79,8 +79,7 @@ async def test_component_state_app(component_state_app: AppHarness): driver = component_state_app.frontend() ss = utils.SessionStorage(driver) - token = AppHarness._poll_for(lambda: ss.get("token") is not None) - assert token is not None + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") diff --git a/integration/test_lifespan.py b/integration/test_lifespan.py new file mode 100644 index 00000000000..cb384a51165 --- /dev/null +++ b/integration/test_lifespan.py @@ -0,0 +1,120 @@ +"""Test cases for the FastAPI lifespan integration.""" +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness + +from .utils import SessionStorage + + +def LifespanApp(): + """App with lifespan tasks and context.""" + import asyncio + from contextlib import asynccontextmanager + + import reflex as rx + + lifespan_task_global = 0 + lifespan_context_global = 0 + + @asynccontextmanager + async def lifespan_context(app, inc: int = 1): + global lifespan_context_global + print(f"Lifespan context entered: {app}.") + lifespan_context_global += inc # pyright: ignore[reportUnboundVariable] + try: + yield + finally: + print("Lifespan context exited.") + lifespan_context_global += inc + + async def lifespan_task(inc: int = 1): + global lifespan_task_global + print("Lifespan global started.") + try: + while True: + lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] + await asyncio.sleep(0.1) + except asyncio.CancelledError as ce: + print(f"Lifespan global cancelled: {ce}.") + lifespan_task_global = 0 + + class LifespanState(rx.State): + @rx.var + def task_global(self) -> int: + return lifespan_task_global + + @rx.var + def context_global(self) -> int: + return lifespan_context_global + + def tick(self, date): + pass + + def index(): + return rx.vstack( + rx.text(LifespanState.task_global, id="task_global"), + rx.text(LifespanState.context_global, id="context_global"), + rx.moment(interval=100, on_change=LifespanState.tick), + ) + + app = rx.App() + app.register_lifespan_task(lifespan_task) + app.register_lifespan_task(lifespan_context, inc=2) + app.add_page(index) + + +@pytest.fixture() +def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]: + """Start LifespanApp app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path, + app_source=LifespanApp, # type: ignore + ) as harness: + yield harness + + +@pytest.mark.asyncio +async def test_lifespan(lifespan_app: AppHarness): + """Test the lifespan integration. + + Args: + lifespan_app: harness for LifespanApp app + """ + assert lifespan_app.app_module is not None, "app module is not found" + assert lifespan_app.app_instance is not None, "app is not running" + driver = lifespan_app.frontend() + + ss = SessionStorage(driver) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + + context_global = driver.find_element(By.ID, "context_global") + task_global = driver.find_element(By.ID, "task_global") + + assert context_global.text == "2" + assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore + + original_task_global_text = task_global.text + original_task_global_value = int(original_task_global_text) + lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text) + assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore + assert int(task_global.text) > original_task_global_value + + # Kill the backend + assert lifespan_app.backend is not None + lifespan_app.backend.should_exit = True + if lifespan_app.backend_thread is not None: + lifespan_app.backend_thread.join() + + # Check that the lifespan tasks have been cancelled + assert lifespan_app.app_module.lifespan_task_global == 0 + assert lifespan_app.app_module.lifespan_context_global == 4 diff --git a/integration/test_navigation.py b/integration/test_navigation.py index 2c288552f61..f5785a6c480 100644 --- a/integration/test_navigation.py +++ b/integration/test_navigation.py @@ -67,8 +67,7 @@ async def test_navigation_app(navigation_app: AppHarness): driver = navigation_app.frontend() ss = SessionStorage(driver) - token = AppHarness._poll_for(lambda: ss.get("token") is not None) - assert token is not None + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" internal_link = driver.find_element(By.ID, "internal") diff --git a/integration/utils.py b/integration/utils.py index bcbd6c497a0..5a5dbae8137 100644 --- a/integration/utils.py +++ b/integration/utils.py @@ -54,7 +54,9 @@ def __len__(self) -> int: Returns: The number of items in local storage. """ - return int(self.driver.execute_script("return window.localStorage.length;")) + return int( + self.driver.execute_script(f"return window.{self.storage_key}.length;") + ) def items(self) -> dict[str, str]: """Get all items in local storage. @@ -63,7 +65,7 @@ def items(self) -> dict[str, str]: A dict mapping keys to values. """ return self.driver.execute_script( - "var ls = window.localStorage, items = {}; " + f"var ls = window.{self.storage_key}, items = {{}}; " "for (var i = 0, k; i < ls.length; ++i) " " items[k = ls.key(i)] = ls.getItem(k); " "return items; " @@ -76,7 +78,7 @@ def keys(self) -> list[str]: A list of keys. """ return self.driver.execute_script( - "var ls = window.localStorage, keys = []; " + f"var ls = window.{self.storage_key}, keys = []; " "for (var i = 0; i < ls.length; ++i) " " keys[i] = ls.key(i); " "return keys; " @@ -92,7 +94,7 @@ def get(self, key) -> str: The value of the key. """ return self.driver.execute_script( - "return window.localStorage.getItem(arguments[0]);", key + f"return window.{self.storage_key}.getItem(arguments[0]);", key ) def set(self, key, value) -> None: @@ -103,7 +105,9 @@ def set(self, key, value) -> None: value: The value to set the key to. """ self.driver.execute_script( - "window.localStorage.setItem(arguments[0], arguments[1]);", key, value + f"window.{self.storage_key}.setItem(arguments[0], arguments[1]);", + key, + value, ) def has(self, key) -> bool: @@ -123,11 +127,13 @@ def remove(self, key) -> None: Args: key: The key to remove. """ - self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key) + self.driver.execute_script( + f"window.{self.storage_key}.removeItem(arguments[0]);", key + ) def clear(self) -> None: """Clear all local storage.""" - self.driver.execute_script("window.localStorage.clear();") + self.driver.execute_script(f"window.{self.storage_key}.clear();") def __getitem__(self, key) -> str: """Get a key from local storage. diff --git a/reflex/app.py b/reflex/app.py index 72a09462a1a..9481204cb9c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -7,10 +7,12 @@ import contextlib import copy import functools +import inspect import io import multiprocessing import os import platform +import sys from typing import ( Any, AsyncIterator, @@ -100,7 +102,50 @@ class OverlayFragment(Fragment): pass -class App(Base): +class LifespanMixin(Base): + """A Mixin that allow tasks to run during the whole app lifespan.""" + + # Lifespan tasks that are planned to run. + lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set() + + @contextlib.asynccontextmanager + async def _run_lifespan_tasks(self, app: FastAPI): + running_tasks = [] + try: + async with contextlib.AsyncExitStack() as stack: + for task in self.lifespan_tasks: + if isinstance(task, asyncio.Task): + running_tasks.append(task) + else: + signature = inspect.signature(task) + if "app" in signature.parameters: + task = functools.partial(task, app=app) + _t = task() + if isinstance(_t, contextlib._AsyncGeneratorContextManager): + await stack.enter_async_context(_t) + elif isinstance(_t, Coroutine): + running_tasks.append(asyncio.create_task(_t)) + yield + finally: + cancel_kwargs = ( + {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {} + ) + for task in running_tasks: + task.cancel(**cancel_kwargs) + + def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): + """Register a task to run during the lifespan of the app. + + Args: + task: The task to register. + task_kwargs: The kwargs of the task. + """ + if task_kwargs: + task = functools.partial(task, **task_kwargs) # type: ignore + self.lifespan_tasks.add(task) # type: ignore + + +class App(LifespanMixin, Base): """The main Reflex app that encapsulates the backend and frontend. Every Reflex app needs an app defined in its main module. @@ -203,7 +248,7 @@ def __init__(self, **kwargs): self.middleware.append(HydrateMiddleware()) # Set up the API. - self.api = FastAPI() + self.api = FastAPI(lifespan=self._run_lifespan_tasks) self._add_cors() self._add_default_endpoints()