Skip to content

Commit

Permalink
feat (#38): add pydantic config option (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik authored Dec 15, 2023
1 parent 9270d65 commit d525e8a
Show file tree
Hide file tree
Showing 14 changed files with 262 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ wtf
coverage.json
site
wtf.py
CODE_OF_CONDUCT.md
.DS_Store
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.5"
__version__ = "2.2.6"
40 changes: 25 additions & 15 deletions fast_depends/_compat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from typing import Any
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION

__all__ = (
"BaseModel",
"FieldInfo",
"create_model",
"evaluate_forwardref",
"PYDANTIC_V2",
"get_config_base",
"get_model_fields",
"ConfigDict",
)


PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

default_pydantic_config = {"arbitrary_types_allowed": True}

evaluate_forwardref: Any
# isort: off
if PYDANTIC_V2:
Expand All @@ -14,23 +28,19 @@
)
from pydantic.fields import FieldInfo

class CreateBaseModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict:
return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item]

def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]:
return model.model_fields

else:
from pydantic.fields import ModelField as FieldInfo # type: ignore
from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
from pydantic.config import get_config, ConfigDict, BaseConfig

class CreateBaseModel(BaseModel): # type: ignore[no-redef]
class Config:
arbitrary_types_allowed = True

def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc]
return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item]

__all__ = (
"BaseModel",
"CreateBaseModel",
"FieldInfo",
"create_model",
"evaluate_forwardref",
"PYDANTIC_V2",
)
def get_model_fields(model: Type[BaseModel]) -> Dict[str, FieldInfo]:
return model.__fields__ # type: ignore[return-value]
12 changes: 8 additions & 4 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
get_origin,
)

from fast_depends._compat import CreateBaseModel, create_model
from fast_depends._compat import ConfigDict, create_model, get_config_base
from fast_depends.core.model import CallModel, ResponseModel
from fast_depends.dependencies import Depends
from fast_depends.library import CustomField
Expand All @@ -44,6 +44,7 @@ def build_call_model(
use_cache: bool = True,
is_sync: Optional[bool] = None,
extra_dependencies: Sequence[Depends] = (),
pydantic_config: Optional[ConfigDict] = None,
) -> CallModel[P, T]:
name = getattr(call, "__name__", type(call).__name__)

Expand Down Expand Up @@ -131,6 +132,7 @@ def build_call_model(
cast=dep.cast,
use_cache=dep.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)

if dep.cast is True:
Expand Down Expand Up @@ -162,14 +164,15 @@ def build_call_model(

func_model = create_model( # type: ignore[call-overload]
name,
__base__=(CreateBaseModel,),
__config__=get_config_base(pydantic_config),
**class_fields,
)

response_model: Optional[Type[ResponseModel[T]]]
if cast and return_annotation and return_annotation is not inspect._empty:
response_model: Optional[Type[ResponseModel[T]]] = create_model(
response_model = create_model( # type: ignore[assignment]
"ResponseModel",
__base__=(CreateBaseModel,), # type: ignore[arg-type]
__config__=get_config_base(pydantic_config),
response=(return_annotation, ...),
)
else:
Expand All @@ -192,6 +195,7 @@ def build_call_model(
cast=d.cast,
use_cache=d.use_cache,
is_sync=is_sync,
pydantic_config=pydantic_config,
)
for d in extra_dependencies
],
Expand Down
116 changes: 66 additions & 50 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from contextlib import AsyncExitStack, ExitStack
from inspect import _empty, unwrap
from typing import (
Any,
Awaitable,
Expand All @@ -17,7 +17,7 @@

from typing_extensions import ParamSpec, TypeVar, assert_never

from fast_depends._compat import PYDANTIC_V2, BaseModel, FieldInfo
from fast_depends._compat import BaseModel, FieldInfo, get_model_fields
from fast_depends.library import CustomField
from fast_depends.utils import (
async_map,
Expand Down Expand Up @@ -48,7 +48,7 @@ class CallModel(Generic[P, T]):
response_model: Optional[Type[ResponseModel[T]]]

params: Dict[str, FieldInfo]
alias_arguments: List[str]
alias_arguments: Tuple[str, ...]

dependencies: Dict[str, "CallModel[..., Any]"]
extra_dependencies: Iterable["CallModel[..., Any]"]
Expand Down Expand Up @@ -79,7 +79,8 @@ class CallModel(Generic[P, T]):

@property
def call_name(self) -> str:
return getattr(self.call, "__name__", type(self.call).__name__)
call = unwrap(self.call)
return getattr(call, "__name__", type(call).__name__)

@property
def real_params(self) -> Dict[str, FieldInfo]:
Expand Down Expand Up @@ -117,17 +118,13 @@ def __init__(
self.model = model
self.response_model = response_model

fields: Dict[str, FieldInfo]
if PYDANTIC_V2:
fields = self.model.model_fields
else:
fields = self.model.__fields__ # type: ignore
fields: Dict[str, FieldInfo] = get_model_fields(model)

self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or []
self.custom_fields = custom_fields or {}

self.alias_arguments = [f.alias or name for name, f in fields.items()]
self.alias_arguments = tuple(f.alias or name for name, f in fields.items())
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())

Expand Down Expand Up @@ -168,85 +165,104 @@ def _solve(
]
] = None,
**kwargs: P.kwargs,
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]:
) -> Generator[
Tuple[
Iterable[Any],
Dict[str, Any],
Union[
Callable[P, T],
Callable[P, Awaitable[T]],
],
],
Any,
T,
]:
if dependency_overrides:
self.call = dependency_overrides.get(self.call, self.call)
call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(
self.call
call
), f"You cannot use async dependency `{self.call_name}` at sync main"
else:
call = self.call

if self.use_cache and self.call in cache_dependencies:
return cache_dependencies[self.call]

kw = {}

for arg in self.keyword_args:
v = kwargs.pop(arg, inspect._empty)
if v is not inspect._empty:
if (v := kwargs.pop(arg, _empty)) is not _empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
kw["kwargs"] = kwargs

else:
kw.update(kwargs)

has_args = "args" in self.alias_arguments

for arg in self.positional_args:
if args:
kw[arg], args = args[0], args[1:]
else:
break

if has_args:
if has_args := "args" in self.alias_arguments:
kw["args"] = args
keyword_args = self.keyword_args

else:
keyword_args = self.keyword_args + self.positional_args
for arg in self.keyword_args:
if args:
kw[arg], args = args[0], args[1:]
else:
break

solved_kw: Dict[str, Any]
solved_kw = yield (), kw
solved_kw = yield (), kw, call

args_: Iterable[Any]

casted_model: object
if self.cast:
casted_model = self.model(**solved_kw)
else:
casted_model = object()

kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in (
self.keyword_args + self.positional_args
if not has_args
else self.keyword_args
)
}
kwargs_.update(getattr(casted_model, "kwargs", {}))

args_: Iterable[Any]
if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in keyword_args
}
kwargs_.update(getattr(casted_model, "kwargs", solved_kw.get("kwargs", {})))

if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", solved_kw.get("args", ())))
else:
args_ = ()

else:
args_ = ()
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}
kwargs_.update(solved_kw.get("kwargs", {}))

if has_args:
args_ = [solved_kw.get(arg) for arg in self.positional_args]
args_.extend(solved_kw.get("args", ()))
else:
args_ = ()

response: T
response = yield args_, kwargs_
response = yield args_, kwargs_, call

if self.cast and not self.is_generator:
response = self._cast_response(response)

if self.use_cache: # pragma: no branch
cache_dependencies[self.call] = response
cache_dependencies[call] = response

return response

def _cast_response(self, /, value: Any) -> Any:
if self.response_model is not None and self.cast:
if self.response_model is not None:
return self.response_model(response=value).response
else:
return value
Expand Down Expand Up @@ -285,7 +301,7 @@ def solve(
**kwargs,
)
try:
_, kwargs = next(cast_gen)
_, kwargs, _ = next(cast_gen)
except StopIteration as e:
cached_value: T = e.value
return cached_value
Expand All @@ -311,7 +327,7 @@ def solve(
for custom in self.custom_fields.values():
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)

if self.is_generator and nested:
response = solve_generator_sync(
Expand All @@ -322,7 +338,7 @@ def solve(
)

else:
response = self.call(*final_args, **final_kwargs)
response = call(*final_args, **final_kwargs)

try:
cast_gen.send(response)
Expand Down Expand Up @@ -371,7 +387,7 @@ async def asolve(
**kwargs,
)
try:
_, kwargs = next(cast_gen)
_, kwargs, _ = next(cast_gen)
except StopIteration as e:
cached_value: T = e.value
return cached_value
Expand All @@ -397,7 +413,7 @@ async def asolve(
for custom in self.custom_fields.values():
kwargs = await run_async(custom.use, **kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)
final_args, final_kwargs, call = cast_gen.send(kwargs)

if self.is_generator and nested:
response = await solve_generator_async(
Expand All @@ -407,7 +423,7 @@ async def asolve(
**final_kwargs,
)
else:
response = await run_async(self.call, *final_args, **final_kwargs)
response = await run_async(call, *final_args, **final_kwargs)

try:
cast_gen.send(response)
Expand Down
Loading

0 comments on commit d525e8a

Please sign in to comment.