Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat (#38): add pydantic config option #42

Merged
merged 4 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ wtf
coverage.json
site
wtf.py
CODE_OF_CONDUCT.md
.DS_Store
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.5"
__version__ = "2.2.6"
40 changes: 25 additions & 15 deletions fast_depends/_compat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from typing import Any
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION

__all__ = (
"BaseModel",
"FieldInfo",
"create_model",
"evaluate_forwardref",
"PYDANTIC_V2",
"get_config_base",
"get_model_fields",
"ConfigDict",
)


PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

default_pydantic_config = {"arbitrary_types_allowed": True}

evaluate_forwardref: Any
# isort: off
if PYDANTIC_V2:
Expand All @@ -14,23 +28,19 @@
)
from pydantic.fields import FieldInfo

class CreateBaseModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item]

def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]:
return model.model_fields

else:
from pydantic.fields import ModelField as FieldInfo # type: ignore
from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
from pydantic.config import get_config, ConfigDict, BaseConfig

class CreateBaseModel(BaseModel): # type: ignore[no-redef]
class Config:
arbitrary_types_allowed = True

def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc]
return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item]

__all__ = (
"BaseModel",
"CreateBaseModel",
"FieldInfo",
"create_model",
"evaluate_forwardref",
"PYDANTIC_V2",
)
def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]:
return model.__fields__ # type: ignore[return-value]
12 changes: 8 additions & 4 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
get_origin,
)

from fast_depends._compat import CreateBaseModel, create_model
from fast_depends._compat import ConfigDict, create_model, get_config_base
from fast_depends.core.model import CallModel, ResponseModel
from fast_depends.dependencies import Depends
from fast_depends.library import CustomField
Expand All @@ -44,6 +44,7 @@ def build_call_model(
use_cache: bool = True,
is_sync: Optional[bool] = None,
extra_dependencies: Sequence[Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
) -> CallModel[P, T]:
name = getattr(call, "__name__", type(call).__name__)

Expand Down Expand Up @@ -131,6 +132,7 @@ def build_call_model(
cast=dep.cast,
use_cache=dep.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)

if dep.cast is True:
Expand Down Expand Up @@ -162,14 +164,15 @@ def build_call_model(

func_model = create_model( # type: ignore[call-overload]
name,
__base__=(CreateBaseModel,),
__config__=get_config_base(pydantic_config),
**class_fields,
)

response_model: Optional[Type[ResponseModel[T]]]
if cast and return_annotation and return_annotation is not inspect._empty:
response_model: Optional[Type[ResponseModel[T]]] = create_model(
response_model = create_model( # type: ignore[assignment]
"ResponseModel",
__base__=(CreateBaseModel,), # type: ignore[arg-type]
__config__=get_config_base(pydantic_config),
response=(return_annotation, ...),
)
else:
Expand All @@ -192,6 +195,7 @@ def build_call_model(
cast=d.cast,
use_cache=d.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)
for d in extra_dependencies
],
Expand Down
116 changes: 66 additions & 50 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from contextlib import AsyncExitStack, ExitStack
from inspect import _empty, unwrap
from typing import (
Any,
Awaitable,
Expand All @@ -17,7 +17,7 @@

from typing_extensions import ParamSpec, TypeVar, assert_never

from fast_depends._compat import PYDANTIC_V2, BaseModel, FieldInfo
from fast_depends._compat import BaseModel, FieldInfo, get_model_fields
from fast_depends.library import CustomField
from fast_depends.utils import (
async_map,
Expand Down Expand Up @@ -48,7 +48,7 @@ class CallModel(Generic[P, T]):
response_model: Optional[Type[ResponseModel[T]]]

params: Dict[str, FieldInfo]
alias_arguments: List[str]
alias_arguments: Tuple[str, ...]

dependencies: Dict[str, "CallModel[..., Any]"]
extra_dependencies: Iterable["CallModel[..., Any]"]
Expand Down Expand Up @@ -79,7 +79,8 @@ class CallModel(Generic[P, T]):

@property
def call_name(self) -> str:
return getattr(self.call, "__name__", type(self.call).__name__)
call = unwrap(self.call)
return getattr(call, "__name__", type(call).__name__)

@property
def real_params(self) -> Dict[str, FieldInfo]:
Expand Down Expand Up @@ -117,17 +118,13 @@ def __init__(
self.model = model
self.response_model = response_model

fields: Dict[str, FieldInfo]
if PYDANTIC_V2:
fields = self.model.model_fields
else:
fields = self.model.__fields__ # type: ignore
fields: Dict[str, FieldInfo] = get_model_fields(model)

self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or []
self.custom_fields = custom_fields or {}

self.alias_arguments = [f.alias or name for name, f in fields.items()]
self.alias_arguments = tuple(f.alias or name for name, f in fields.items())
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())

Expand Down Expand Up @@ -168,85 +165,104 @@ def _solve(
]
] = None,
**kwargs: P.kwargs,
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]:
) -> Generator[
Tuple[
Iterable[Any],
Dict[str, Any],
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
],
Any,
T,
]:
if dependency_overrides:
self.call = dependency_overrides.get(self.call, self.call)
call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(
self.call
call
), f"You cannot use async dependency `{self.call_name}` at sync main"
else:
call = self.call

if self.use_cache and self.call in cache_dependencies:
return cache_dependencies[self.call]

kw = {}

for arg in self.keyword_args:
v = kwargs.pop(arg, inspect._empty)
if v is not inspect._empty:
if (v := kwargs.pop(arg, _empty)) is not _empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
kw["kwargs"] = kwargs

else:
kw.update(kwargs)

has_args = "args" in self.alias_arguments

for arg in self.positional_args:
if args:
kw[arg], args = args[0], args[1:]
else:
break

if has_args:
if has_args := "args" in self.alias_arguments:
kw["args"] = args
keyword_args = self.keyword_args

else:
keyword_args = self.keyword_args + self.positional_args
for arg in self.keyword_args:
if args:
kw[arg], args = args[0], args[1:]
else:
break

solved_kw: Dict[str, Any]
solved_kw = yield (), kw
solved_kw = yield (), kw, call

args_: Iterable[Any]

casted_model: object
if self.cast:
casted_model = self.model(**solved_kw)
else:
casted_model = object()

kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in (
self.keyword_args + self.positional_args
if not has_args
else self.keyword_args
)
}
kwargs_.update(getattr(casted_model, "kwargs", {}))

args_: Iterable[Any]
if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in keyword_args
}
kwargs_.update(getattr(casted_model, "kwargs", solved_kw.get("kwargs", {})))

if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", solved_kw.get("args", ())))
else:
args_ = ()

else:
args_ = ()
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}
kwargs_.update(solved_kw.get("kwargs", {}))

if has_args:
args_ = [solved_kw.get(arg) for arg in self.positional_args]
args_.extend(solved_kw.get("args", ()))
else:
args_ = ()

response: T
response = yield args_, kwargs_
response = yield args_, kwargs_, call

if self.cast and not self.is_generator:
response = self._cast_response(response)

if self.use_cache: # pragma: no branch
cache_dependencies[self.call] = response
cache_dependencies[call] = response

return response

def _cast_response(self, /, value: Any) -> Any:
if self.response_model is not None and self.cast:
if self.response_model is not None:
return self.response_model(response=value).response
else:
return value
Expand Down Expand Up @@ -285,7 +301,7 @@ def solve(
**kwargs,
)
try:
_, kwargs = next(cast_gen)
_, kwargs, _ = next(cast_gen)
except StopIteration as e:
cached_value: T = e.value
return cached_value
Expand All @@ -311,7 +327,7 @@ def solve(
for custom in self.custom_fields.values():
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)

if self.is_generator and nested:
response = solve_generator_sync(
Expand All @@ -322,7 +338,7 @@ def solve(
)

else:
response = self.call(*final_args, **final_kwargs)
response = call(*final_args, **final_kwargs)

try:
cast_gen.send(response)
Expand Down Expand Up @@ -371,7 +387,7 @@ async def asolve(
**kwargs,
)
try:
_, kwargs = next(cast_gen)
_, kwargs, _ = next(cast_gen)
except StopIteration as e:
cached_value: T = e.value
return cached_value
Expand All @@ -397,7 +413,7 @@ async def asolve(
for custom in self.custom_fields.values():
kwargs = await run_async(custom.use, **kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)

if self.is_generator and nested:
response = await solve_generator_async(
Expand All @@ -407,7 +423,7 @@ async def asolve(
**final_kwargs,
)
else:
response = await run_async(self.call, *final_args, **final_kwargs)
response = await run_async(call, *final_args, **final_kwargs)

try:
cast_gen.send(response)
Expand Down
Loading