From 35a0a333acfb417e032700f8b9c6b75841fac96a Mon Sep 17 00:00:00 2001 From: Lendemor Date: Thu, 16 May 2024 15:36:56 +0200 Subject: [PATCH 01/10] add support for lifespan tasks --- reflex/app.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 4d99d6949ec..0754c2c61ee 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -100,7 +100,35 @@ class OverlayFragment(Fragment): pass -class App(Base): +class LifespanMixin: + """A Mixin that allow tasks to run during the whole app lifespan.""" + + # Lifespan tasks that are planned to run. + lifespan_tasks: Set[asyncio.Task] = set() + + @contextlib.asynccontextmanager + async def _run_lifespan_tasks(self, app: FastAPI): + running_tasks = [] + try: + running_tasks = [ + task if isinstance(task, asyncio.Task) else asyncio.create_task(task()) + for task in self.lifespan_tasks + ] + yield + finally: + for task in running_tasks: + task.cancel() + + def register_lifespan_task(self, task: Callable | asyncio.Task): + """Register a task to run during the lifespan of the app. + + Args: + task: The task to register. + """ + self.lifespan_tasks.add(task) # type: ignore + + +class App(LifespanMixin, Base): """The main Reflex app that encapsulates the backend and frontend. Every Reflex app needs an app defined in its main module. @@ -203,7 +231,7 @@ def __init__(self, **kwargs): self.middleware.append(HydrateMiddleware()) # Set up the API. - self.api = FastAPI() + self.api = FastAPI(lifespan=self._run_lifespan_tasks) self._add_cors() self._add_default_endpoints() From d7d4fc33aeceddf0017d80a1b14e9831c5162479 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Thu, 16 May 2024 15:55:19 +0200 Subject: [PATCH 02/10] allow passing args to lifespan task --- reflex/app.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/reflex/app.py b/reflex/app.py index 0754c2c61ee..5fe48b2e28d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -119,12 +119,18 @@ async def _run_lifespan_tasks(self, app: FastAPI): for task in running_tasks: task.cancel() - def register_lifespan_task(self, task: Callable | asyncio.Task): + def register_lifespan_task( + self, task: Callable | asyncio.Task, *task_args, **task_kwargs + ): """Register a task to run during the lifespan of the app. Args: task: The task to register. + task_args: The args of the task. + task_kwargs: The kwargs of the task. """ + if task_args or task_kwargs: + task = functools.partial(task, *task_args, **task_kwargs) # type: ignore self.lifespan_tasks.add(task) # type: ignore From bfc43e45d32e1f80adf350fbd430147d7d389407 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Thu, 16 May 2024 18:38:21 +0200 Subject: [PATCH 03/10] add message to the cancel call --- reflex/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/app.py b/reflex/app.py index 5fe48b2e28d..a00323627f0 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -117,7 +117,7 @@ async def _run_lifespan_tasks(self, app: FastAPI): yield finally: for task in running_tasks: - task.cancel() + task.cancel("lifespan_cleanup") def register_lifespan_task( self, task: Callable | asyncio.Task, *task_args, **task_kwargs From cba8d0325dcf9ff34e76cf05c4a9975f15066278 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Fri, 17 May 2024 03:04:32 +0200 Subject: [PATCH 04/10] allow asynccontextmanager as lifespan tasks --- reflex/app.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index a00323627f0..6d714b87e0c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -110,11 +110,17 @@ class LifespanMixin: async def _run_lifespan_tasks(self, app: FastAPI): running_tasks = [] try: - running_tasks = [ - task if isinstance(task, asyncio.Task) else asyncio.create_task(task()) - for task in self.lifespan_tasks - ] - yield + async with contextlib.AsyncExitStack() as stack: + for task in self.lifespan_tasks: + if isinstance(task, asyncio.Task): + running_tasks.append(task) + else: + _t = task() + if isinstance(_t, contextlib._AsyncGeneratorContextManager): + await stack.enter_async_context(_t) + elif isinstance(_t, Coroutine): + running_tasks.append(asyncio.create_task(_t)) + yield finally: for task in running_tasks: task.cancel("lifespan_cleanup") From 0b7bfbf87427c7a2b96b2e2e64f428d27bd9581d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 May 2024 13:05:47 -0700 Subject: [PATCH 05/10] Fix integration.utils.SessionStorage Previously the SessionStorage util was just looking in localStorage, but the tests didn't catch it because they were asserting the token was not None, rather than asserting it was truthy. Fixed here, because I'm using this structure in the new lifespan test. --- integration/test_component_state.py | 3 +-- integration/test_navigation.py | 3 +-- integration/utils.py | 20 +++++++++++++------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/integration/test_component_state.py b/integration/test_component_state.py index e903a1b7425..77b8b3fa183 100644 --- a/integration/test_component_state.py +++ b/integration/test_component_state.py @@ -79,8 +79,7 @@ async def test_component_state_app(component_state_app: AppHarness): driver = component_state_app.frontend() ss = utils.SessionStorage(driver) - token = AppHarness._poll_for(lambda: ss.get("token") is not None) - assert token is not None + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") diff --git a/integration/test_navigation.py b/integration/test_navigation.py index 2c288552f61..f5785a6c480 100644 --- a/integration/test_navigation.py +++ b/integration/test_navigation.py @@ -67,8 +67,7 @@ async def test_navigation_app(navigation_app: AppHarness): driver = navigation_app.frontend() ss = SessionStorage(driver) - token = AppHarness._poll_for(lambda: ss.get("token") is not None) - assert token is not None + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" internal_link = driver.find_element(By.ID, "internal") diff --git a/integration/utils.py b/integration/utils.py index bcbd6c497a0..5a5dbae8137 100644 --- a/integration/utils.py +++ b/integration/utils.py @@ -54,7 +54,9 @@ def __len__(self) -> int: Returns: The number of items in local storage. """ - return int(self.driver.execute_script("return window.localStorage.length;")) + return int( + self.driver.execute_script(f"return window.{self.storage_key}.length;") + ) def items(self) -> dict[str, str]: """Get all items in local storage. @@ -63,7 +65,7 @@ def items(self) -> dict[str, str]: A dict mapping keys to values. """ return self.driver.execute_script( - "var ls = window.localStorage, items = {}; " + f"var ls = window.{self.storage_key}, items = {{}}; " "for (var i = 0, k; i < ls.length; ++i) " " items[k = ls.key(i)] = ls.getItem(k); " "return items; " @@ -76,7 +78,7 @@ def keys(self) -> list[str]: A list of keys. """ return self.driver.execute_script( - "var ls = window.localStorage, keys = []; " + f"var ls = window.{self.storage_key}, keys = []; " "for (var i = 0; i < ls.length; ++i) " " keys[i] = ls.key(i); " "return keys; " @@ -92,7 +94,7 @@ def get(self, key) -> str: The value of the key. """ return self.driver.execute_script( - "return window.localStorage.getItem(arguments[0]);", key + f"return window.{self.storage_key}.getItem(arguments[0]);", key ) def set(self, key, value) -> None: @@ -103,7 +105,9 @@ def set(self, key, value) -> None: value: The value to set the key to. """ self.driver.execute_script( - "window.localStorage.setItem(arguments[0], arguments[1]);", key, value + f"window.{self.storage_key}.setItem(arguments[0], arguments[1]);", + key, + value, ) def has(self, key) -> bool: @@ -123,11 +127,13 @@ def remove(self, key) -> None: Args: key: The key to remove. """ - self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key) + self.driver.execute_script( + f"window.{self.storage_key}.removeItem(arguments[0]);", key + ) def clear(self) -> None: """Clear all local storage.""" - self.driver.execute_script("window.localStorage.clear();") + self.driver.execute_script(f"window.{self.storage_key}.clear();") def __getitem__(self, key) -> str: """Get a key from local storage. From 409e881750dcb443916d3b316aabed7bb4d749ee Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 May 2024 13:07:18 -0700 Subject: [PATCH 06/10] If the lifespan task or context takes "app" parameter, pass the FastAPI instance. --- reflex/app.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index c3edb662eed..e034df5c106 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -7,6 +7,7 @@ import contextlib import copy import functools +import inspect import io import multiprocessing import os @@ -100,11 +101,11 @@ class OverlayFragment(Fragment): pass -class LifespanMixin: +class LifespanMixin(Base): """A Mixin that allow tasks to run during the whole app lifespan.""" # Lifespan tasks that are planned to run. - lifespan_tasks: Set[asyncio.Task] = set() + lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set() @contextlib.asynccontextmanager async def _run_lifespan_tasks(self, app: FastAPI): @@ -115,6 +116,9 @@ async def _run_lifespan_tasks(self, app: FastAPI): if isinstance(task, asyncio.Task): running_tasks.append(task) else: + signature = inspect.signature(task) + if "app" in signature.parameters: + task = functools.partial(task, app=app) _t = task() if isinstance(_t, contextlib._AsyncGeneratorContextManager): await stack.enter_async_context(_t) From 92bf7551b5b09ab58a07bef6f4e2332354ef1c8d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 May 2024 13:10:43 -0700 Subject: [PATCH 07/10] test_lifespan: end to end test for register_lifespan_task --- integration/test_lifespan.py | 120 +++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 integration/test_lifespan.py diff --git a/integration/test_lifespan.py b/integration/test_lifespan.py new file mode 100644 index 00000000000..4dccff2aada --- /dev/null +++ b/integration/test_lifespan.py @@ -0,0 +1,120 @@ +"""Test cases for the FastAPI lifespan integration.""" +import asyncio +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness + +from .utils import SessionStorage + + +def LifespanApp(): + """App with lifespan tasks and context.""" + import asyncio + from contextlib import asynccontextmanager + + import reflex as rx + + lifespan_task_global = 0 + lifespan_context_global = 0 + + @asynccontextmanager + async def lifespan_context(app, inc: int = 1): + global lifespan_context_global + print(f"Lifespan context entered: {app}.") + lifespan_context_global += inc # pyright: ignore[reportUnboundVariable] + try: + yield + finally: + print("Lifespan context exited.") + lifespan_context_global += inc + + async def lifespan_task(inc: int = 1): + global lifespan_task_global + print("Lifespan global started.") + try: + while True: + lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] + await asyncio.sleep(0.1) + except asyncio.CancelledError: + print("Lifespan global cancelled.") + lifespan_task_global = 0 + + class LifespanState(rx.State): + @rx.var + def task_global(self) -> int: + return lifespan_task_global + + @rx.var + def context_global(self) -> int: + return lifespan_context_global + + def tick(self, date): + pass + + def index(): + return rx.vstack( + rx.text(LifespanState.task_global, id="task_global"), + rx.text(LifespanState.context_global, id="context_global"), + rx.moment(interval=100, on_change=LifespanState.tick), + ) + + app = rx.App() + app.register_lifespan_task(lifespan_task) + app.register_lifespan_task(lifespan_context, inc=2) + app.add_page(index) + + +@pytest.fixture() +def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]: + """Start LifespanApp app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path, + app_source=LifespanApp, # type: ignore + ) as harness: + yield harness + + +@pytest.mark.asyncio +async def test_lifespan(lifespan_app: AppHarness): + """Test the lifespan integration. + + Args: + lifespan_app: harness for LifespanApp app + """ + assert lifespan_app.app_module is not None, "app module is not found" + assert lifespan_app.app_instance is not None, "app is not running" + driver = lifespan_app.frontend() + + ss = SessionStorage(driver) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + + context_global = driver.find_element(By.ID, "context_global") + task_global = driver.find_element(By.ID, "task_global") + + assert context_global.text == "2" + assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore + + original_task_global_value = int(task_global.text) + await asyncio.sleep(0.3) + assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore + assert int(task_global.text) > original_task_global_value + + # Kill the backend + assert lifespan_app.backend is not None + lifespan_app.backend.should_exit = True + if lifespan_app.backend_thread is not None: + lifespan_app.backend_thread.join() + + # Check that the lifespan tasks have been cancelled + assert lifespan_app.app_module.lifespan_task_global == 0 + assert lifespan_app.app_module.lifespan_context_global == 4 From 000318ade4c80e0627d3c8f125887e87afd33586 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 May 2024 13:40:46 -0700 Subject: [PATCH 08/10] In py3.8, Task.cancel takes no args --- reflex/app.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/reflex/app.py b/reflex/app.py index e034df5c106..2de8b0aee4d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -12,6 +12,7 @@ import multiprocessing import os import platform +import sys from typing import ( Any, AsyncIterator, @@ -126,8 +127,11 @@ async def _run_lifespan_tasks(self, app: FastAPI): running_tasks.append(asyncio.create_task(_t)) yield finally: + cancel_kwargs = ( + {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {} + ) for task in running_tasks: - task.cancel("lifespan_cleanup") + task.cancel(**cancel_kwargs) def register_lifespan_task( self, task: Callable | asyncio.Task, *task_args, **task_kwargs From ee89548b7778bbed1e2e47e3cc309cd882b665ab Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 May 2024 13:41:41 -0700 Subject: [PATCH 09/10] test_lifespan: use polling to make the test more robust Fix CI failure --- integration/test_lifespan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integration/test_lifespan.py b/integration/test_lifespan.py index 4dccff2aada..cb384a51165 100644 --- a/integration/test_lifespan.py +++ b/integration/test_lifespan.py @@ -1,5 +1,4 @@ """Test cases for the FastAPI lifespan integration.""" -import asyncio from typing import Generator import pytest @@ -38,8 +37,8 @@ async def lifespan_task(inc: int = 1): while True: lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] await asyncio.sleep(0.1) - except asyncio.CancelledError: - print("Lifespan global cancelled.") + except asyncio.CancelledError as ce: + print(f"Lifespan global cancelled: {ce}.") lifespan_task_global = 0 class LifespanState(rx.State): @@ -104,8 +103,9 @@ async def test_lifespan(lifespan_app: AppHarness): assert context_global.text == "2" assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore - original_task_global_value = int(task_global.text) - await asyncio.sleep(0.3) + original_task_global_text = task_global.text + original_task_global_value = int(original_task_global_text) + lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text) assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore assert int(task_global.text) > original_task_global_value From 90db4627db1930bcc2242987d1837621e57f4691 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 21 May 2024 12:12:03 -0700 Subject: [PATCH 10/10] Do not allow task_args for better composability --- reflex/app.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 518e9beebc8..9481204cb9c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -133,18 +133,15 @@ async def _run_lifespan_tasks(self, app: FastAPI): for task in running_tasks: task.cancel(**cancel_kwargs) - def register_lifespan_task( - self, task: Callable | asyncio.Task, *task_args, **task_kwargs - ): + def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): """Register a task to run during the lifespan of the app. Args: task: The task to register. - task_args: The args of the task. task_kwargs: The kwargs of the task. """ - if task_args or task_kwargs: - task = functools.partial(task, *task_args, **task_kwargs) # type: ignore + if task_kwargs: + task = functools.partial(task, **task_kwargs) # type: ignore self.lifespan_tasks.add(task) # type: ignore