Skip to content

Commit

Permalink
Reuse the sqlalchemy engine once it's created (#4493)
Browse files Browse the repository at this point in the history
* Reuse the sqlalchemy engine once it's created

* Implement `rx.asession` for async database support

Requires setting `async_db_url` to the same database as `db_url`, except using
an async driver; this will vary by database.

* resolve the url first, so the key into _ENGINE is correct

* Ping db connections before returning them from pool

Move connect engine kwargs to a separate function

* the param is `echo`

* sanity check that config db_url and async_db_url are the same

throw a warning if the part following the `://` differs between these two

* create_async_engine: use sqlalchemy async API

update types

* redact ASYNC_DB_URL similarly to DB_URL when overridden in config

* update rx.asession docstring

* use async_sessionmaker

* Redact sensitive env vars instead of hiding them
  • Loading branch information
masenf authored Dec 10, 2024
1 parent 3ef7106 commit 4922f7b
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 12 deletions.
2 changes: 1 addition & 1 deletion reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@
"SessionStorage",
],
"middleware": ["middleware", "Middleware"],
"model": ["session", "Model"],
"model": ["asession", "session", "Model"],
"state": [
"var",
"ComponentState",
Expand Down
1 change: 1 addition & 0 deletions reflex/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
from .middleware import Middleware as Middleware
from .middleware import middleware as middleware
from .model import Model as Model
from .model import asession as asession
from .model import session as session
from .page import page as page
from .state import ComponentState as ComponentState
Expand Down
24 changes: 18 additions & 6 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,9 @@ class EnvironmentVariables:
# Whether to print the SQL queries if the log level is INFO or lower.
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)

# Whether to check db connections before using them.
SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)

# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)

Expand Down Expand Up @@ -568,6 +571,10 @@ class EnvironmentVariables:
environment = EnvironmentVariables()


# These vars are not logged because they may contain sensitive information.
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}


class Config(Base):
"""The config defines runtime settings for the app.
Expand Down Expand Up @@ -621,6 +628,9 @@ class Config:
# The database url used by rx.Model.
db_url: Optional[str] = "sqlite:///reflex.db"

# The async database url used by rx.Model.
async_db_url: Optional[str] = None

# The redis url
redis_url: Optional[str] = None

Expand Down Expand Up @@ -748,18 +758,20 @@ def update_from_env(self) -> dict[str, Any]:

# If the env var is set, override the config value.
if env_var is not None:
if key.upper() != "DB_URL":
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)

# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)

# Set the value.
updated_values[key] = value

if key.upper() in _sensitive_env_vars:
env_var = "***"

console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)

return updated_values

def get_event_namespace(self) -> str:
Expand Down
126 changes: 121 additions & 5 deletions reflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import re
from collections import defaultdict
from typing import Any, ClassVar, Optional, Type, Union

Expand All @@ -14,13 +15,56 @@
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm

from reflex.base import Base
from reflex.config import environment, get_config
from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key

_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}

# Import AsyncSession _after_ reflex.utils.compat
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402


def _safe_db_url_for_logging(url: str) -> str:
"""Remove username and password from the database URL for logging.
Args:
url: The database URL.
Returns:
The database URL with the username and password removed.
"""
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)


def get_engine_args(url: str | None = None) -> dict[str, Any]:
"""Get the database engine arguments.
Args:
url: The database url.
Returns:
The database engine arguments as a dict.
"""
kwargs: dict[str, Any] = dict(
# Print the SQL queries if the log level is INFO or lower.
echo=environment.SQLALCHEMY_ECHO.get(),
# Check connections before returning them.
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
)
conf = get_config()
url = url or conf.db_url
if url is not None and url.startswith("sqlite"):
# Needed for the admin dash on sqlite.
kwargs["connect_args"] = {"check_same_thread": False}
return kwargs


def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
"""Get the database engine.
Expand All @@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
url = url or conf.db_url
if url is None:
raise ValueError("No database url configured")

global _ENGINE
if url in _ENGINE:
return _ENGINE[url]

if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
# Print the SQL queries if the log level is INFO or lower.
echo_db_query = environment.SQLALCHEMY_ECHO.get()
# Needed for the admin dash on sqlite.
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
_ENGINE[url] = sqlmodel.create_engine(
url,
**get_engine_args(url),
)
return _ENGINE[url]


def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Get the async database engine.
Args:
url: The database url.
Returns:
The async database engine.
Raises:
ValueError: If the async database url is None.
"""
if url is None:
conf = get_config()
url = conf.async_db_url
if url is not None and conf.db_url is not None:
async_db_url_tail = url.partition("://")[2]
db_url_tail = conf.db_url.partition("://")[2]
if async_db_url_tail != db_url_tail:
console.warn(
f"async_db_url `{_safe_db_url_for_logging(url)}` "
"should reference the same database as "
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
)
if url is None:
raise ValueError("No async database url configured")

global _ASYNC_ENGINE
if url in _ASYNC_ENGINE:
return _ASYNC_ENGINE[url]

if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
url,
**get_engine_args(url),
)
return _ASYNC_ENGINE[url]


async def get_db_status() -> bool:
Expand Down Expand Up @@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
return sqlmodel.Session(get_engine(url))


def asession(url: str | None = None) -> AsyncSession:
"""Get an async sqlmodel session to interact with the database.
async with rx.asession() as asession:
...
Most operations against the `asession` must be awaited.
Args:
url: The database url.
Returns:
An async database session.
"""
global _AsyncSessionLocal
if url not in _AsyncSessionLocal:
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
bind=get_async_engine(url),
class_=AsyncSession,
autocommit=False,
autoflush=False,
)
return _AsyncSessionLocal[url]()


def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.
Expand Down

0 comments on commit 4922f7b

Please sign in to comment.