Skip to content

Commit

Permalink
add support for lifespan tasks (reflex-dev#3312)
Browse files Browse the repository at this point in the history
* add support for lifespan tasks

* allow passing args to lifespan task

* add message to the cancel call

* allow asynccontextmanager as lifespan tasks

* Fix integration.utils.SessionStorage

Previously the SessionStorage util was just looking in localStorage, but the
tests didn't catch it because they were asserting the token was not None,
rather than asserting it was truthy.

Fixed here, because I'm using this structure in the new lifespan test.

* If the lifespan task or context takes "app" parameter, pass the FastAPI instance.

* test_lifespan: end to end test for register_lifespan_task

* In py3.8, Task.cancel takes no args

* test_lifespan: use polling to make the test more robust

Fix CI failure

* Do not allow task_args for better composability

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
  • Loading branch information
2 people authored and benedikt-bartscher committed Jun 3, 2024
1 parent 440bdef commit 85f68aa
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 13 deletions.
3 changes: 1 addition & 2 deletions integration/test_component_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ async def test_component_state_app(component_state_app: AppHarness):
driver = component_state_app.frontend()

ss = utils.SessionStorage(driver)
token = AppHarness._poll_for(lambda: ss.get("token") is not None)
assert token is not None
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"

count_a = driver.find_element(By.ID, "count-a")
count_b = driver.find_element(By.ID, "count-b")
Expand Down
120 changes: 120 additions & 0 deletions integration/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Test cases for the FastAPI lifespan integration."""
from typing import Generator

import pytest
from selenium.webdriver.common.by import By

from reflex.testing import AppHarness

from .utils import SessionStorage


def LifespanApp():
"""App with lifespan tasks and context."""
import asyncio
from contextlib import asynccontextmanager

import reflex as rx

lifespan_task_global = 0
lifespan_context_global = 0

@asynccontextmanager
async def lifespan_context(app, inc: int = 1):
global lifespan_context_global
print(f"Lifespan context entered: {app}.")
lifespan_context_global += inc # pyright: ignore[reportUnboundVariable]
try:
yield
finally:
print("Lifespan context exited.")
lifespan_context_global += inc

async def lifespan_task(inc: int = 1):
global lifespan_task_global
print("Lifespan global started.")
try:
while True:
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable]
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")
lifespan_task_global = 0

class LifespanState(rx.State):
@rx.var
def task_global(self) -> int:
return lifespan_task_global

@rx.var
def context_global(self) -> int:
return lifespan_context_global

def tick(self, date):
pass

def index():
return rx.vstack(
rx.text(LifespanState.task_global, id="task_global"),
rx.text(LifespanState.context_global, id="context_global"),
rx.moment(interval=100, on_change=LifespanState.tick),
)

app = rx.App()
app.register_lifespan_task(lifespan_task)
app.register_lifespan_task(lifespan_context, inc=2)
app.add_page(index)


@pytest.fixture()
def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
"""Start LifespanApp app at tmp_path via AppHarness.
Args:
tmp_path: pytest tmp_path fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path,
app_source=LifespanApp, # type: ignore
) as harness:
yield harness


@pytest.mark.asyncio
async def test_lifespan(lifespan_app: AppHarness):
"""Test the lifespan integration.
Args:
lifespan_app: harness for LifespanApp app
"""
assert lifespan_app.app_module is not None, "app module is not found"
assert lifespan_app.app_instance is not None, "app is not running"
driver = lifespan_app.frontend()

ss = SessionStorage(driver)
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"

context_global = driver.find_element(By.ID, "context_global")
task_global = driver.find_element(By.ID, "task_global")

assert context_global.text == "2"
assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore

original_task_global_text = task_global.text
original_task_global_value = int(original_task_global_text)
lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore
assert int(task_global.text) > original_task_global_value

# Kill the backend
assert lifespan_app.backend is not None
lifespan_app.backend.should_exit = True
if lifespan_app.backend_thread is not None:
lifespan_app.backend_thread.join()

# Check that the lifespan tasks have been cancelled
assert lifespan_app.app_module.lifespan_task_global == 0
assert lifespan_app.app_module.lifespan_context_global == 4
3 changes: 1 addition & 2 deletions integration/test_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ async def test_navigation_app(navigation_app: AppHarness):
driver = navigation_app.frontend()

ss = SessionStorage(driver)
token = AppHarness._poll_for(lambda: ss.get("token") is not None)
assert token is not None
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"

internal_link = driver.find_element(By.ID, "internal")

Expand Down
20 changes: 13 additions & 7 deletions integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __len__(self) -> int:
Returns:
The number of items in local storage.
"""
return int(self.driver.execute_script("return window.localStorage.length;"))
return int(
self.driver.execute_script(f"return window.{self.storage_key}.length;")
)

def items(self) -> dict[str, str]:
"""Get all items in local storage.
Expand All @@ -63,7 +65,7 @@ def items(self) -> dict[str, str]:
A dict mapping keys to values.
"""
return self.driver.execute_script(
"var ls = window.localStorage, items = {}; "
f"var ls = window.{self.storage_key}, items = {{}}; "
"for (var i = 0, k; i < ls.length; ++i) "
" items[k = ls.key(i)] = ls.getItem(k); "
"return items; "
Expand All @@ -76,7 +78,7 @@ def keys(self) -> list[str]:
A list of keys.
"""
return self.driver.execute_script(
"var ls = window.localStorage, keys = []; "
f"var ls = window.{self.storage_key}, keys = []; "
"for (var i = 0; i < ls.length; ++i) "
" keys[i] = ls.key(i); "
"return keys; "
Expand All @@ -92,7 +94,7 @@ def get(self, key) -> str:
The value of the key.
"""
return self.driver.execute_script(
"return window.localStorage.getItem(arguments[0]);", key
f"return window.{self.storage_key}.getItem(arguments[0]);", key
)

def set(self, key, value) -> None:
Expand All @@ -103,7 +105,9 @@ def set(self, key, value) -> None:
value: The value to set the key to.
"""
self.driver.execute_script(
"window.localStorage.setItem(arguments[0], arguments[1]);", key, value
f"window.{self.storage_key}.setItem(arguments[0], arguments[1]);",
key,
value,
)

def has(self, key) -> bool:
Expand All @@ -123,11 +127,13 @@ def remove(self, key) -> None:
Args:
key: The key to remove.
"""
self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key)
self.driver.execute_script(
f"window.{self.storage_key}.removeItem(arguments[0]);", key
)

def clear(self) -> None:
"""Clear all local storage."""
self.driver.execute_script("window.localStorage.clear();")
self.driver.execute_script(f"window.{self.storage_key}.clear();")

def __getitem__(self, key) -> str:
"""Get a key from local storage.
Expand Down
49 changes: 47 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import contextlib
import copy
import functools
import inspect
import io
import multiprocessing
import os
import platform
import sys
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -100,7 +102,50 @@ class OverlayFragment(Fragment):
pass


class App(Base):
class LifespanMixin(Base):
"""A Mixin that allow tasks to run during the whole app lifespan."""

# Lifespan tasks that are planned to run.
lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()

@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
_t = task()
if isinstance(_t, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(_t)
elif isinstance(_t, Coroutine):
running_tasks.append(asyncio.create_task(_t))
yield
finally:
cancel_kwargs = (
{"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
)
for task in running_tasks:
task.cancel(**cancel_kwargs)

def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
"""
if task_kwargs:
task = functools.partial(task, **task_kwargs) # type: ignore
self.lifespan_tasks.add(task) # type: ignore


class App(LifespanMixin, Base):
"""The main Reflex app that encapsulates the backend and frontend.
Every Reflex app needs an app defined in its main module.
Expand Down Expand Up @@ -203,7 +248,7 @@ def __init__(self, **kwargs):
self.middleware.append(HydrateMiddleware())

# Set up the API.
self.api = FastAPI()
self.api = FastAPI(lifespan=self._run_lifespan_tasks)
self._add_cors()
self._add_default_endpoints()

Expand Down

0 comments on commit 85f68aa

Please sign in to comment.