diff --git a/config.example.py b/config.example.py index eb9db8e..f56fb4e 100644 --- a/config.example.py +++ b/config.example.py @@ -113,3 +113,7 @@ "client_secret": None, }, } + +# Use fast_depends.inject() when route function is called instead of when it is created. Speeds up Yepcord launch. +# May slow down first request to every route by ~50ms. +LAZY_INJECT = False diff --git a/yepcord/rest_api/y_blueprint.py b/yepcord/rest_api/y_blueprint.py index fb08752..b44221d 100644 --- a/yepcord/rest_api/y_blueprint.py +++ b/yepcord/rest_api/y_blueprint.py @@ -17,72 +17,97 @@ """ from functools import wraps -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Awaitable from fast_depends import inject from flask.sansio.scaffold import T_route, setupmethod from quart import Blueprint, g from quart_schema import validate_request, validate_querystring +from yepcord.yepcord.config import Config + validate_funcs = {"body": validate_request, "qs": validate_querystring} -def apply_validator(func: T_route, type_: str, cls: Optional[type], source=None) -> T_route: - applied = getattr(func, "_patches", set()) +def apply_validator(src_func: T_route, type_: str, cls: Optional[type], source=None) -> T_route: + applied = getattr(src_func, "_patches", set()) if cls is None or f"validate_{type_}" in applied or type_ not in validate_funcs: - return func + return src_func kw = {} if source is None else {"source": source} - func = validate_funcs[type_](cls, **kw)(func) + func = validate_funcs[type_](cls, **kw)(src_func) applied.add(f"validate_{type_}") setattr(func, "_patches", applied) + if len(applied) > 1: # pragma: no cover + delattr(src_func, "_patches") return func -def apply_inject(func: T_route) -> T_route: - applied = getattr(func, "_patches", set()) +def apply_inject(src_func: T_route) -> T_route: + applied = getattr(src_func, "_patches", set()) if "fastdepends_inject" in applied: - return func + return src_func + + if Config.LAZY_INJECT: # pragma: no cover + injected_func = None + + @wraps(src_func) + async def func(*args, **kwargs): + nonlocal injected_func + + if injected_func is None: + injected_func = inject(src_func) + + return await injected_func(*args, **kwargs) + else: + func = inject(src_func) - func = inject(func) applied.add("fastdepends_inject") setattr(func, "_patches", applied) + if len(applied) > 1: # pragma: no cover + delattr(src_func, "_patches") return func -def apply_allow_bots(func: T_route) -> T_route: - applied = getattr(func, "_patches", set()) +def apply_allow_bots(src_func: T_route) -> T_route: + applied = getattr(src_func, "_patches", set()) if "allow_bots" in applied: - return func + return src_func - @wraps(func) + @wraps(src_func) async def wrapped(*args, **kwargs): g.bots_allowed = True - return await func(*args, **kwargs) + return await src_func(*args, **kwargs) applied.add("allow_bots") - setattr(func, "_patches", applied) + setattr(wrapped, "_patches", applied) + if len(applied) > 1: # pragma: no cover + delattr(src_func, "_patches") + return wrapped -def apply_oauth(func: T_route, scopes: list[str]) -> T_route: - applied = getattr(func, "_patches", set()) +def apply_oauth(src_func: T_route, scopes: list[str]) -> T_route: + applied = getattr(src_func, "_patches", set()) if "oauth" in applied: - return func + return src_func - @wraps(func) + @wraps(src_func) async def wrapped(*args, **kwargs): g.oauth_allowed = True g.oauth_scopes = set(scopes) - return await func(*args, **kwargs) + return await src_func(*args, **kwargs) applied.add("oauth") - setattr(func, "_patches", applied) + setattr(wrapped, "_patches", applied) + if len(applied) > 1: # pragma: no cover + delattr(src_func, "_patches") + return wrapped diff --git a/yepcord/yepcord/config.py b/yepcord/yepcord/config.py index 0e45bf0..c0350f0 100644 --- a/yepcord/yepcord/config.py +++ b/yepcord/yepcord/config.py @@ -133,6 +133,7 @@ class ConfigModel(BaseModel): BCRYPT_ROUNDS: int = 15 CAPTCHA: ConfigCaptcha = Field(default_factory=ConfigCaptcha) CONNECTIONS: ConfigConnections = Field(default_factory=ConfigConnections) + LAZY_INJECT: bool = False @field_validator("KEY") def validate_key(cls, value: str) -> str: @@ -167,6 +168,23 @@ class Config: class _Config(Singleton): + DB_CONNECT_STRING: str + MAIL_CONNECT_STRING: str + MIGRATIONS_DIR: str + KEY: str + PUBLIC_HOST: str + GATEWAY_HOST: str + CDN_HOST: str + STORAGE: dict + TENOR_KEY: Optional[str] + MESSAGE_BROKER: dict + REDIS_URL: Optional[str] + GATEWAY_KEEP_ALIVE_DELAY: int + BCRYPT_ROUNDS: int + CAPTCHA: dict + CONNECTIONS: dict + LAZY_INJECT: bool + def update(self, variables: dict) -> _Config: self.__dict__.update(variables) return self