From d525e8a6d70e7461291751f37538029bfb651fc6 Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Fri, 15 Dec 2023 19:18:17 +0300 Subject: [PATCH] feat (#38): add pydantic config option (#42) --- .gitignore | 2 +- fast_depends/__about__.py | 2 +- fast_depends/_compat.py | 40 +++++---- fast_depends/core/build.py | 12 ++- fast_depends/core/model.py | 116 +++++++++++++++----------- fast_depends/dependencies/provider.py | 25 ++++-- fast_depends/use.py | 22 +++-- tests/async/test_cast.py | 16 ++++ tests/async/test_config.py | 27 ++++++ tests/async/test_depends.py | 13 +++ tests/sync/test_cast.py | 15 ++++ tests/sync/test_config.py | 26 ++++++ tests/sync/test_depends.py | 12 +++ tests/test_overrides.py | 17 ++++ 14 files changed, 262 insertions(+), 83 deletions(-) create mode 100644 tests/async/test_config.py create mode 100644 tests/sync/test_config.py diff --git a/.gitignore b/.gitignore index e3de911..9850df1 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,4 @@ wtf coverage.json site wtf.py -CODE_OF_CONDUCT.md \ No newline at end of file +.DS_Store \ No newline at end of file diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 7dc15cd..3e5cdab 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.2.5" +__version__ = "2.2.6" diff --git a/fast_depends/_compat.py b/fast_depends/_compat.py index 58b3853..9876e9a 100644 --- a/fast_depends/_compat.py +++ b/fast_depends/_compat.py @@ -1,10 +1,24 @@ -from typing import Any +from typing import Any, Dict, Optional, Type from pydantic import BaseModel, create_model from pydantic.version import VERSION as PYDANTIC_VERSION +__all__ = ( + "BaseModel", + "FieldInfo", + "create_model", + "evaluate_forwardref", + "PYDANTIC_V2", + "get_config_base", + "get_model_fields", + "ConfigDict", +) + + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +default_pydantic_config = {"arbitrary_types_allowed": True} + evaluate_forwardref: Any # isort: off if PYDANTIC_V2: @@ -14,23 +28,19 @@ ) from pydantic.fields import FieldInfo - class CreateBaseModel(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) + def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict: + return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item] + + def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]: + return model.model_fields else: from pydantic.fields import ModelField as FieldInfo # type: ignore from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef] + from pydantic.config import get_config, ConfigDict, BaseConfig - class CreateBaseModel(BaseModel): # type: ignore[no-redef] - class Config: - arbitrary_types_allowed = True - + def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc] + return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item] -__all__ = ( - "BaseModel", - "CreateBaseModel", - "FieldInfo", - "create_model", - "evaluate_forwardref", - "PYDANTIC_V2", -) + def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]: + return model.__fields__ # type: ignore[return-value] diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py index 811641b..1c31912 100644 --- a/fast_depends/core/build.py +++ b/fast_depends/core/build.py @@ -21,7 +21,7 @@ get_origin, ) -from fast_depends._compat import CreateBaseModel, create_model +from fast_depends._compat import ConfigDict, create_model, get_config_base from fast_depends.core.model import CallModel, ResponseModel from fast_depends.dependencies import Depends from fast_depends.library import CustomField @@ -44,6 +44,7 @@ def build_call_model( use_cache: bool = True, is_sync: Optional[bool] = None, extra_dependencies: Sequence[Depends] = (), + pydantic_config: Optional[ConfigDict] = None, ) -> CallModel[P, T]: name = getattr(call, "__name__", type(call).__name__) @@ -131,6 +132,7 @@ def build_call_model( cast=dep.cast, use_cache=dep.use_cache, is_sync=is_sync, + pydantic_config=pydantic_config, ) if dep.cast is True: @@ -162,14 +164,15 @@ def build_call_model( func_model = create_model( # type: ignore[call-overload] name, - __base__=(CreateBaseModel,), + __config__=get_config_base(pydantic_config), **class_fields, ) + response_model: Optional[Type[ResponseModel[T]]] if cast and return_annotation and return_annotation is not inspect._empty: - response_model: Optional[Type[ResponseModel[T]]] = create_model( + response_model = create_model( # type: ignore[assignment] "ResponseModel", - __base__=(CreateBaseModel,), # type: ignore[arg-type] + __config__=get_config_base(pydantic_config), response=(return_annotation, ...), ) else: @@ -192,6 +195,7 @@ def build_call_model( cast=d.cast, use_cache=d.use_cache, is_sync=is_sync, + pydantic_config=pydantic_config, ) for d in extra_dependencies ], diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index d9e2ec1..724e92c 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -1,5 +1,5 @@ -import inspect from contextlib import AsyncExitStack, ExitStack +from inspect import _empty, unwrap from typing import ( Any, Awaitable, @@ -17,7 +17,7 @@ from typing_extensions import ParamSpec, TypeVar, assert_never -from fast_depends._compat import PYDANTIC_V2, BaseModel, FieldInfo +from fast_depends._compat import BaseModel, FieldInfo, get_model_fields from fast_depends.library import CustomField from fast_depends.utils import ( async_map, @@ -48,7 +48,7 @@ class CallModel(Generic[P, T]): response_model: Optional[Type[ResponseModel[T]]] params: Dict[str, FieldInfo] - alias_arguments: List[str] + alias_arguments: Tuple[str, ...] dependencies: Dict[str, "CallModel[..., Any]"] extra_dependencies: Iterable["CallModel[..., Any]"] @@ -79,7 +79,8 @@ class CallModel(Generic[P, T]): @property def call_name(self) -> str: - return getattr(self.call, "__name__", type(self.call).__name__) + call = unwrap(self.call) + return getattr(call, "__name__", type(call).__name__) @property def real_params(self) -> Dict[str, FieldInfo]: @@ -117,17 +118,13 @@ def __init__( self.model = model self.response_model = response_model - fields: Dict[str, FieldInfo] - if PYDANTIC_V2: - fields = self.model.model_fields - else: - fields = self.model.__fields__ # type: ignore + fields: Dict[str, FieldInfo] = get_model_fields(model) self.dependencies = dependencies or {} self.extra_dependencies = extra_dependencies or [] self.custom_fields = custom_fields or {} - self.alias_arguments = [f.alias or name for name, f in fields.items()] + self.alias_arguments = tuple(f.alias or name for name, f in fields.items()) self.keyword_args = tuple(keyword_args or ()) self.positional_args = tuple(positional_args or ()) @@ -168,12 +165,25 @@ def _solve( ] ] = None, **kwargs: P.kwargs, - ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]: + ) -> Generator[ + Tuple[ + Iterable[Any], + Dict[str, Any], + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + ], + Any, + T, + ]: if dependency_overrides: - self.call = dependency_overrides.get(self.call, self.call) + call = dependency_overrides.get(self.call, self.call) assert self.is_async or not is_coroutine_callable( - self.call + call ), f"You cannot use async dependency `{self.call_name}` at sync main" + else: + call = self.call if self.use_cache and self.call in cache_dependencies: return cache_dependencies[self.call] @@ -181,72 +191,78 @@ def _solve( kw = {} for arg in self.keyword_args: - v = kwargs.pop(arg, inspect._empty) - if v is not inspect._empty: + if (v := kwargs.pop(arg, _empty)) is not _empty: kw[arg] = v if "kwargs" in self.alias_arguments: kw["kwargs"] = kwargs - else: kw.update(kwargs) - has_args = "args" in self.alias_arguments - for arg in self.positional_args: if args: kw[arg], args = args[0], args[1:] + else: + break - if has_args: + if has_args := "args" in self.alias_arguments: kw["args"] = args + keyword_args = self.keyword_args else: + keyword_args = self.keyword_args + self.positional_args for arg in self.keyword_args: if args: kw[arg], args = args[0], args[1:] + else: + break solved_kw: Dict[str, Any] - solved_kw = yield (), kw + solved_kw = yield (), kw, call + + args_: Iterable[Any] - casted_model: object if self.cast: casted_model = self.model(**solved_kw) - else: - casted_model = object() - - kwargs_ = { - arg: getattr(casted_model, arg, solved_kw.get(arg)) - for arg in ( - self.keyword_args + self.positional_args - if not has_args - else self.keyword_args - ) - } - kwargs_.update(getattr(casted_model, "kwargs", {})) - args_: Iterable[Any] - if has_args: - args_ = [ - getattr(casted_model, arg, solved_kw.get(arg)) - for arg in self.positional_args - ] - args_.extend(getattr(casted_model, "args", ())) + kwargs_ = { + arg: getattr(casted_model, arg, solved_kw.get(arg)) + for arg in keyword_args + } + kwargs_.update(getattr(casted_model, "kwargs", solved_kw.get("kwargs", {}))) + + if has_args: + args_ = [ + getattr(casted_model, arg, solved_kw.get(arg)) + for arg in self.positional_args + ] + args_.extend(getattr(casted_model, "args", solved_kw.get("args", ()))) + else: + args_ = () + else: - args_ = () + kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args} + kwargs_.update(solved_kw.get("kwargs", {})) + + if has_args: + args_ = [solved_kw.get(arg) for arg in self.positional_args] + args_.extend(solved_kw.get("args", ())) + else: + args_ = () response: T - response = yield args_, kwargs_ + response = yield args_, kwargs_, call if self.cast and not self.is_generator: response = self._cast_response(response) if self.use_cache: # pragma: no branch - cache_dependencies[self.call] = response + cache_dependencies[call] = response return response def _cast_response(self, /, value: Any) -> Any: - if self.response_model is not None and self.cast: + if self.response_model is not None: return self.response_model(response=value).response else: return value @@ -285,7 +301,7 @@ def solve( **kwargs, ) try: - _, kwargs = next(cast_gen) + _, kwargs, _ = next(cast_gen) except StopIteration as e: cached_value: T = e.value return cached_value @@ -311,7 +327,7 @@ def solve( for custom in self.custom_fields.values(): kwargs = custom.use(**kwargs) - final_args, final_kwargs = cast_gen.send(kwargs) + final_args, final_kwargs, call = cast_gen.send(kwargs) if self.is_generator and nested: response = solve_generator_sync( @@ -322,7 +338,7 @@ def solve( ) else: - response = self.call(*final_args, **final_kwargs) + response = call(*final_args, **final_kwargs) try: cast_gen.send(response) @@ -371,7 +387,7 @@ async def asolve( **kwargs, ) try: - _, kwargs = next(cast_gen) + _, kwargs, _ = next(cast_gen) except StopIteration as e: cached_value: T = e.value return cached_value @@ -397,7 +413,7 @@ async def asolve( for custom in self.custom_fields.values(): kwargs = await run_async(custom.use, **kwargs) - final_args, final_kwargs = cast_gen.send(kwargs) + final_args, final_kwargs, call = cast_gen.send(kwargs) if self.is_generator and nested: response = await solve_generator_async( @@ -407,7 +423,7 @@ async def asolve( **final_kwargs, ) else: - response = await run_async(self.call, *final_args, **final_kwargs) + response = await run_async(call, *final_args, **final_kwargs) try: cast_gen.send(response) diff --git a/fast_depends/dependencies/provider.py b/fast_depends/dependencies/provider.py index c9355d2..f3e2f0f 100644 --- a/fast_depends/dependencies/provider.py +++ b/fast_depends/dependencies/provider.py @@ -1,17 +1,32 @@ -from typing import Any, Callable, Dict +from contextlib import contextmanager +from typing import Any, Callable, Dict, Iterator class Provider: + dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] + def __init__(self) -> None: - self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {} + self.dependency_overrides = {} + + def clear(self) -> None: + self.dependency_overrides = {} def override( - self, original: Callable[..., Any], override: Callable[..., Any] + self, + original: Callable[..., Any], + override: Callable[..., Any], ) -> None: self.dependency_overrides[original] = override - def clear(self) -> None: - self.dependency_overrides = {} + @contextmanager + def scope( + self, + original: Callable[..., Any], + override: Callable[..., Any], + ) -> Iterator[None]: + self.dependency_overrides[original] = override + yield + self.dependency_overrides.pop(original, None) dependency_provider = Provider() diff --git a/fast_depends/use.py b/fast_depends/use.py index fb82dcc..4985088 100644 --- a/fast_depends/use.py +++ b/fast_depends/use.py @@ -13,6 +13,7 @@ from typing_extensions import ParamSpec, Protocol, TypeVar +from fast_depends._compat import ConfigDict from fast_depends.core import CallModel, build_call_model from fast_depends.dependencies import dependency_provider, model @@ -46,10 +47,11 @@ def __call__( def inject( # pragma: no cover func: None, *, - dependency_overrides_provider: Optional[Any] = dependency_provider, + cast: bool = True, extra_dependencies: Sequence[model.Depends] = (), + pydantic_config: Optional[ConfigDict] = None, + dependency_overrides_provider: Optional[Any] = dependency_provider, wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, - cast: bool = True, ) -> _InjectWrapper[P, T]: ... @@ -58,10 +60,11 @@ def inject( # pragma: no cover def inject( # pragma: no cover func: Callable[P, T], *, - dependency_overrides_provider: Optional[Any] = dependency_provider, + cast: bool = True, extra_dependencies: Sequence[model.Depends] = (), + pydantic_config: Optional[ConfigDict] = None, + dependency_overrides_provider: Optional[Any] = dependency_provider, wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, - cast: bool = True, ) -> Callable[P, T]: ... @@ -69,16 +72,18 @@ def inject( # pragma: no cover def inject( func: Optional[Callable[P, T]] = None, *, - dependency_overrides_provider: Optional[Any] = dependency_provider, + cast: bool = True, extra_dependencies: Sequence[model.Depends] = (), + pydantic_config: Optional[ConfigDict] = None, + dependency_overrides_provider: Optional[Any] = dependency_provider, wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, - cast: bool = True, ) -> Union[Callable[P, T], _InjectWrapper[P, T],]: decorator = _wrap_inject( dependency_overrides_provider=dependency_overrides_provider, wrap_model=wrap_model, extra_dependencies=extra_dependencies, cast=cast, + pydantic_config=pydantic_config, ) if func is None: @@ -96,6 +101,7 @@ def _wrap_inject( ], extra_dependencies: Sequence[model.Depends], cast: bool, + pydantic_config: Optional[ConfigDict], ) -> _InjectWrapper[P, T]: if ( dependency_overrides_provider @@ -113,9 +119,10 @@ def func_wrapper( if model is None: real_model = wrap_model( build_call_model( - func, + call=func, extra_dependencies=extra_dependencies, cast=cast, + pydantic_config=pydantic_config, ) ) else: @@ -162,6 +169,7 @@ def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: **kwargs, ) return r + raise AssertionError("unreachable") return injected_wrapper diff --git a/tests/async/test_cast.py b/tests/async/test_cast.py index a3007c4..2fb958e 100644 --- a/tests/async/test_cast.py +++ b/tests/async/test_cast.py @@ -189,3 +189,19 @@ async def simple_func(a: str) -> int: async for i in simple_func("1"): assert i == 1 + + +@pytest.mark.anyio +async def test_args_kwargs_without_cast(): + @inject(cast=False) + async def simple_func( + a: int, + *args: Tuple[float, ...], + b: int, + **kwargs: Dict[str, int], + ): + return a, args, b, kwargs + + assert (1.0, (2.0, 3), 3.0, {"key": 1.0}) == await simple_func( + 1.0, 2.0, 3, b=3.0, key=1.0 + ) diff --git a/tests/async/test_config.py b/tests/async/test_config.py new file mode 100644 index 0000000..25bfd4b --- /dev/null +++ b/tests/async/test_config.py @@ -0,0 +1,27 @@ +import pytest +from pydantic import ValidationError + +from fast_depends import Depends, inject +from fast_depends._compat import PYDANTIC_V2 + + +async def dep(a: str): + return a + + +@inject(pydantic_config={"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1}) +async def limited_str(a=Depends(dep)): + ... + + +@inject() +async def regular(a=Depends(dep)): + return a + + +@pytest.mark.anyio +async def test_config(): + await regular("123") + + with pytest.raises(ValidationError): + await limited_str("123") diff --git a/tests/async/test_depends.py b/tests/async/test_depends.py index 72a1b75..04d557c 100644 --- a/tests/async/test_depends.py +++ b/tests/async/test_depends.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from functools import partial from unittest.mock import Mock import pytest @@ -350,3 +351,15 @@ async def simple_func(a: str, d=Depends(func)) -> int: assert i == 1 mock.end.assert_called_once() + + +@pytest.mark.anyio +async def test_partial(): + async def dep(a): + return a + + @inject + async def func(a=Depends(partial(dep, 10))): + return a + + assert await func() == 10 diff --git a/tests/sync/test_cast.py b/tests/sync/test_cast.py index 5eced31..e2d8364 100644 --- a/tests/sync/test_cast.py +++ b/tests/sync/test_cast.py @@ -184,3 +184,18 @@ def simple_func(a: str) -> int: for i in simple_func("1"): assert i == 1 + + +def test_args_kwargs_without_cast(): + @inject(cast=False) + def simple_func( + a: int, + *args: Tuple[float, ...], + b: int, + **kwargs: Dict[str, int], + ): + return a, args, b, kwargs + + assert (1.0, (2.0, 3), 3.0, {"key": 1.0}) == simple_func( + 1.0, 2.0, 3, b=3.0, key=1.0 + ) diff --git a/tests/sync/test_config.py b/tests/sync/test_config.py new file mode 100644 index 0000000..eba08a3 --- /dev/null +++ b/tests/sync/test_config.py @@ -0,0 +1,26 @@ +import pytest +from pydantic import ValidationError + +from fast_depends import Depends, inject +from fast_depends._compat import PYDANTIC_V2 + + +def dep(a: str): + return a + + +@inject(pydantic_config={"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1}) +def limited_str(a=Depends(dep)): + ... + + +@inject() +def regular(a=Depends(dep)): + return a + + +def test_config(): + regular("123") + + with pytest.raises(ValidationError): + limited_str("123") diff --git a/tests/sync/test_depends.py b/tests/sync/test_depends.py index aed57c2..b3949ff 100644 --- a/tests/sync/test_depends.py +++ b/tests/sync/test_depends.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from functools import partial from unittest.mock import Mock import pytest @@ -244,3 +245,14 @@ def simple_func(a: str, d=Depends(func)) -> int: assert i == 1 mock.end.assert_called_once() + + +def test_partial(): + def dep(a): + return a + + @inject + def func(a=Depends(partial(dep, 10))): + return a + + assert func() == 10 diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 45a7ed8..da3f899 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -50,6 +50,23 @@ def func(d=Depends(base_dep)): assert not mock.original.called +def test_override_context(provider): + def base_dep(): + return 1 + + def override_dep(): + return 2 + + @inject + def func(d=Depends(base_dep)): + return d + + with provider.scope(base_dep, override_dep): + assert func() == 2 + + assert func() == 1 + + def test_sync_by_async_override(provider): def base_dep(): # pragma: no cover return 1