diff --git a/reflex/__init__.py b/reflex/__init__.py index 5625244160..0be568b1a7 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -331,7 +331,7 @@ "SessionStorage", ], "middleware": ["middleware", "Middleware"], - "model": ["session", "Model"], + "model": ["asession", "session", "Model"], "state": [ "var", "ComponentState", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 6f61435e60..6bdfa11a02 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model +from .model import asession as asession from .model import session as session from .page import page as page from .state import ComponentState as ComponentState diff --git a/reflex/config.py b/reflex/config.py index 1d952f4f02..c40abcb39a 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -512,6 +512,9 @@ class EnvironmentVariables: # Whether to print the SQL queries if the log level is INFO or lower. SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False) + # Whether to check db connections before using them. + SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True) + # Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration. REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False) @@ -568,6 +571,10 @@ class EnvironmentVariables: environment = EnvironmentVariables() +# These vars are not logged because they may contain sensitive information. +_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"} + + class Config(Base): """The config defines runtime settings for the app. @@ -621,6 +628,9 @@ class Config: # The database url used by rx.Model. db_url: Optional[str] = "sqlite:///reflex.db" + # The async database url used by rx.Model. + async_db_url: Optional[str] = None + # The redis url redis_url: Optional[str] = None @@ -748,18 +758,20 @@ def update_from_env(self) -> dict[str, Any]: # If the env var is set, override the config value. if env_var is not None: - if key.upper() != "DB_URL": - console.info( - f"Overriding config value {key} with env var {key.upper()}={env_var}", - dedupe=True, - ) - # Interpret the value. value = interpret_env_var_value(env_var, field.outer_type_, field.name) # Set the value. updated_values[key] = value + if key.upper() in _sensitive_env_vars: + env_var = "***" + + console.info( + f"Overriding config value {key} with env var {key.upper()}={env_var}", + dedupe=True, + ) + return updated_values def get_event_namespace(self) -> str: diff --git a/reflex/model.py b/reflex/model.py index 4b070ec678..03b1cc8280 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from collections import defaultdict from typing import Any, ClassVar, Optional, Type, Union @@ -14,6 +15,7 @@ import alembic.util import sqlalchemy import sqlalchemy.exc +import sqlalchemy.ext.asyncio import sqlalchemy.orm from reflex.base import Base @@ -21,6 +23,48 @@ from reflex.utils import console from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key +_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} +_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {} +_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {} + +# Import AsyncSession _after_ reflex.utils.compat +from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402 + + +def _safe_db_url_for_logging(url: str) -> str: + """Remove username and password from the database URL for logging. + + Args: + url: The database URL. + + Returns: + The database URL with the username and password removed. + """ + return re.sub(r"://[^@]+@", "://:@", url) + + +def get_engine_args(url: str | None = None) -> dict[str, Any]: + """Get the database engine arguments. + + Args: + url: The database url. + + Returns: + The database engine arguments as a dict. + """ + kwargs: dict[str, Any] = dict( + # Print the SQL queries if the log level is INFO or lower. + echo=environment.SQLALCHEMY_ECHO.get(), + # Check connections before returning them. + pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(), + ) + conf = get_config() + url = url or conf.db_url + if url is not None and url.startswith("sqlite"): + # Needed for the admin dash on sqlite. + kwargs["connect_args"] = {"check_same_thread": False} + return kwargs + def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: """Get the database engine. @@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: url = url or conf.db_url if url is None: raise ValueError("No database url configured") + + global _ENGINE + if url in _ENGINE: + return _ENGINE[url] + if not environment.ALEMBIC_CONFIG.get().exists(): console.warn( "Database is not initialized, run [bold]reflex db init[/bold] first." ) - # Print the SQL queries if the log level is INFO or lower. - echo_db_query = environment.SQLALCHEMY_ECHO.get() - # Needed for the admin dash on sqlite. - connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {} - return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) + _ENGINE[url] = sqlmodel.create_engine( + url, + **get_engine_args(url), + ) + return _ENGINE[url] + + +def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine: + """Get the async database engine. + + Args: + url: The database url. + + Returns: + The async database engine. + + Raises: + ValueError: If the async database url is None. + """ + if url is None: + conf = get_config() + url = conf.async_db_url + if url is not None and conf.db_url is not None: + async_db_url_tail = url.partition("://")[2] + db_url_tail = conf.db_url.partition("://")[2] + if async_db_url_tail != db_url_tail: + console.warn( + f"async_db_url `{_safe_db_url_for_logging(url)}` " + "should reference the same database as " + f"db_url `{_safe_db_url_for_logging(conf.db_url)}`." + ) + if url is None: + raise ValueError("No async database url configured") + + global _ASYNC_ENGINE + if url in _ASYNC_ENGINE: + return _ASYNC_ENGINE[url] + + if not environment.ALEMBIC_CONFIG.get().exists(): + console.warn( + "Database is not initialized, run [bold]reflex db init[/bold] first." + ) + _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine( + url, + **get_engine_args(url), + ) + return _ASYNC_ENGINE[url] async def get_db_status() -> bool: @@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session: return sqlmodel.Session(get_engine(url)) +def asession(url: str | None = None) -> AsyncSession: + """Get an async sqlmodel session to interact with the database. + + async with rx.asession() as asession: + ... + + Most operations against the `asession` must be awaited. + + Args: + url: The database url. + + Returns: + An async database session. + """ + global _AsyncSessionLocal + if url not in _AsyncSessionLocal: + _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker( + bind=get_async_engine(url), + class_=AsyncSession, + autocommit=False, + autoflush=False, + ) + return _AsyncSessionLocal[url]() + + def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session: """Get a bare sqlalchemy session to interact with the database.