Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy fast-depends inject #232

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@
"client_secret": None,
},
}

# Use fast_depends.inject() when route function is called instead of when it is created. Speeds up Yepcord launch.
# May slow down first request to every route by ~50ms.
LAZY_INJECT = False
67 changes: 46 additions & 21 deletions yepcord/rest_api/y_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,72 +17,97 @@
"""

from functools import wraps
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Awaitable

from fast_depends import inject
from flask.sansio.scaffold import T_route, setupmethod
from quart import Blueprint, g
from quart_schema import validate_request, validate_querystring

from yepcord.yepcord.config import Config

validate_funcs = {"body": validate_request, "qs": validate_querystring}


def apply_validator(func: T_route, type_: str, cls: Optional[type], source=None) -> T_route:
applied = getattr(func, "_patches", set())
def apply_validator(src_func: T_route, type_: str, cls: Optional[type], source=None) -> T_route:
applied = getattr(src_func, "_patches", set())

if cls is None or f"validate_{type_}" in applied or type_ not in validate_funcs:
return func
return src_func

kw = {} if source is None else {"source": source}
func = validate_funcs[type_](cls, **kw)(func)
func = validate_funcs[type_](cls, **kw)(src_func)

applied.add(f"validate_{type_}")
setattr(func, "_patches", applied)
if len(applied) > 1: # pragma: no cover
delattr(src_func, "_patches")

return func


def apply_inject(func: T_route) -> T_route:
applied = getattr(func, "_patches", set())
def apply_inject(src_func: T_route) -> T_route:
applied = getattr(src_func, "_patches", set())

if "fastdepends_inject" in applied:
return func
return src_func

if Config.LAZY_INJECT: # pragma: no cover
injected_func = None

@wraps(src_func)
async def func(*args, **kwargs):
nonlocal injected_func

if injected_func is None:
injected_func = inject(src_func)

return await injected_func(*args, **kwargs)
else:
func = inject(src_func)

func = inject(func)
applied.add("fastdepends_inject")
setattr(func, "_patches", applied)
if len(applied) > 1: # pragma: no cover
delattr(src_func, "_patches")

return func


def apply_allow_bots(func: T_route) -> T_route:
applied = getattr(func, "_patches", set())
def apply_allow_bots(src_func: T_route) -> T_route:
applied = getattr(src_func, "_patches", set())
if "allow_bots" in applied:
return func
return src_func

@wraps(func)
@wraps(src_func)
async def wrapped(*args, **kwargs):
g.bots_allowed = True
return await func(*args, **kwargs)
return await src_func(*args, **kwargs)

applied.add("allow_bots")
setattr(func, "_patches", applied)
setattr(wrapped, "_patches", applied)
if len(applied) > 1: # pragma: no cover
delattr(src_func, "_patches")

return wrapped


def apply_oauth(func: T_route, scopes: list[str]) -> T_route:
applied = getattr(func, "_patches", set())
def apply_oauth(src_func: T_route, scopes: list[str]) -> T_route:
applied = getattr(src_func, "_patches", set())
if "oauth" in applied:
return func
return src_func

@wraps(func)
@wraps(src_func)
async def wrapped(*args, **kwargs):
g.oauth_allowed = True
g.oauth_scopes = set(scopes)
return await func(*args, **kwargs)
return await src_func(*args, **kwargs)

applied.add("oauth")
setattr(func, "_patches", applied)
setattr(wrapped, "_patches", applied)
if len(applied) > 1: # pragma: no cover
delattr(src_func, "_patches")

return wrapped


Expand Down
18 changes: 18 additions & 0 deletions yepcord/yepcord/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class ConfigModel(BaseModel):
BCRYPT_ROUNDS: int = 15
CAPTCHA: ConfigCaptcha = Field(default_factory=ConfigCaptcha)
CONNECTIONS: ConfigConnections = Field(default_factory=ConfigConnections)
LAZY_INJECT: bool = False

@field_validator("KEY")
def validate_key(cls, value: str) -> str:
Expand Down Expand Up @@ -167,6 +168,23 @@ class Config:


class _Config(Singleton):
DB_CONNECT_STRING: str
MAIL_CONNECT_STRING: str
MIGRATIONS_DIR: str
KEY: str
PUBLIC_HOST: str
GATEWAY_HOST: str
CDN_HOST: str
STORAGE: dict
TENOR_KEY: Optional[str]
MESSAGE_BROKER: dict
REDIS_URL: Optional[str]
GATEWAY_KEEP_ALIVE_DELAY: int
BCRYPT_ROUNDS: int
CAPTCHA: dict
CONNECTIONS: dict
LAZY_INJECT: bool

def update(self, variables: dict) -> _Config:
self.__dict__.update(variables)
return self
Expand Down
Loading