Skip to content

Commit

Permalink
fix: support custom fields for classes injection
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Dec 14, 2023
1 parent 5fdb947 commit d7eb0b0
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 91 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ from fast_depends import inject
from fast_depends.library import CustomField

class Header(CustomField):
def use(self, **kwargs: AnyDict) -> AnyDict:
def use(self, /, **kwargs: AnyDict) -> AnyDict:
kwargs = super().use(**kwargs)
kwargs[self.param_name] = kwargs["headers"][self.param_name]
return kwargs
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ from fast_depends import inject
from fast_depends.library import CustomField

class Header(CustomField):
def use(self, **kwargs: AnyDict) -> AnyDict:
def use(self, /, **kwargs: AnyDict) -> AnyDict:
kwargs = super().use(**kwargs)
kwargs[self.param_name] = kwargs["headers"][self.param_name]
return kwargs
Expand Down
2 changes: 1 addition & 1 deletion docs/docs_src/advanced/custom/class_declaration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fast_depends.library import CustomField

class Header(CustomField):
def use(self, **kwargs):
def use(self, /, **kwargs):
kwargs = super().use(**kwargs)
kwargs[self.param_name] = kwargs["headers"][self.param_name]
return kwargs
2 changes: 1 addition & 1 deletion docs/docs_src/advanced/custom/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette.routing import Route

class Path(CustomField):
def use(self, *, request, **kwargs):
def use(self, /, *, request, **kwargs):
return {
**super().use(request=request, **kwargs),
self.param_name: request.path_params.get(self.param_name)
Expand Down
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.4"
__version__ = "2.2.5"
150 changes: 77 additions & 73 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,26 @@ class CallModel(Generic[P, T]):
)

@property
def call_name(self__) -> str:
return getattr(self__.call, "__name__", type(self__.call).__name__)
def call_name(self) -> str:
return getattr(self.call, "__name__", type(self.call).__name__)

@property
def real_params(self__) -> Dict[str, FieldInfo]:
params = self__.params.copy()
for name in self__.custom_fields.keys():
def real_params(self) -> Dict[str, FieldInfo]:
params = self.params.copy()
for name in self.custom_fields.keys():
params.pop(name, None)
return params

@property
def flat_params(self__) -> Dict[str, FieldInfo]:
params = self__.real_params
for d in self__.dependencies.values():
def flat_params(self) -> Dict[str, FieldInfo]:
params = self.real_params
for d in self.dependencies.values():
params.update(d.flat_params)
return params

def __init__(
self__,
self,
/,
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
Expand All @@ -112,39 +113,40 @@ def __init__(
positional_args: Optional[List[str]] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self__.call = call
self__.model = model
self__.response_model = response_model
self.call = call
self.model = model
self.response_model = response_model

fields: Dict[str, FieldInfo]
if PYDANTIC_V2:
fields = self__.model.model_fields
fields = self.model.model_fields
else:
fields = self__.model.__fields__ # type: ignore
fields = self.model.__fields__ # type: ignore

self__.dependencies = dependencies or {}
self__.extra_dependencies = extra_dependencies or []
self__.custom_fields = custom_fields or {}
self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or []
self.custom_fields = custom_fields or {}

self__.alias_arguments = [f.alias or name for name, f in fields.items()]
self__.keyword_args = tuple(keyword_args or ())
self__.positional_args = tuple(positional_args or ())
self.alias_arguments = [f.alias or name for name, f in fields.items()]
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())

self__.params = fields.copy()
for name in self__.dependencies.keys():
self__.params.pop(name, None)
self.params = fields.copy()
for name in self.dependencies.keys():
self.params.pop(name, None)

self__.use_cache = use_cache
self__.cast = cast
self__.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self__.call)
self.use_cache = use_cache
self.cast = cast
self.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self.call)
)
self__.is_generator = is_gen_callable(self__.call) or is_async_gen_callable(
self__.call
self.is_generator = is_gen_callable(self.call) or is_async_gen_callable(
self.call
)

def _solve(
self__,
self,
/,
*args: P.args,
cache_dependencies: Dict[
Union[
Expand All @@ -168,56 +170,56 @@ def _solve(
**kwargs: P.kwargs,
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]:
if dependency_overrides:
self__.call = dependency_overrides.get(self__.call, self__.call)
assert self__.is_async or not is_coroutine_callable(
self__.call
), f"You cannot use async dependency `{self__.call_name}` at sync main"
self.call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(
self.call
), f"You cannot use async dependency `{self.call_name}` at sync main"

if self__.use_cache and self__.call in cache_dependencies:
return cache_dependencies[self__.call]
if self.use_cache and self.call in cache_dependencies:
return cache_dependencies[self.call]

kw = {}

for arg in self__.keyword_args:
for arg in self.keyword_args:
v = kwargs.pop(arg, inspect._empty)
if v is not inspect._empty:
kw[arg] = v

if "kwargs" in self__.alias_arguments:
if "kwargs" in self.alias_arguments:
kw["kwargs"] = kwargs

else:
kw.update(kwargs)

has_args = "args" in self__.alias_arguments
has_args = "args" in self.alias_arguments

for arg in self__.positional_args:
for arg in self.positional_args:
if args:
kw[arg], args = args[0], args[1:]

if has_args:
kw["args"] = args

else:
for arg in self__.keyword_args:
for arg in self.keyword_args:
if args:
kw[arg], args = args[0], args[1:]

solved_kw: Dict[str, Any]
solved_kw = yield (), kw

casted_model: object
if self__.cast:
casted_model = self__.model(**solved_kw)
if self.cast:
casted_model = self.model(**solved_kw)
else:
casted_model = object()

kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in (
self__.keyword_args + self__.positional_args
self.keyword_args + self.positional_args
if not has_args
else self__.keyword_args
else self.keyword_args
)
}
kwargs_.update(getattr(casted_model, "kwargs", {}))
Expand All @@ -226,7 +228,7 @@ def _solve(
if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self__.positional_args
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
else:
Expand All @@ -235,22 +237,23 @@ def _solve(
response: T
response = yield args_, kwargs_

if self__.cast and not self__.is_generator:
response = self__._cast_response(response)
if self.cast and not self.is_generator:
response = self._cast_response(response)

if self__.use_cache: # pragma: no branch
cache_dependencies[self__.call] = response
if self.use_cache: # pragma: no branch
cache_dependencies[self.call] = response

return response

def _cast_response(self__, value: Any) -> Any:
if self__.response_model is not None and self__.cast:
return self__.response_model(response=value).response
def _cast_response(self, /, value: Any) -> Any:
if self.response_model is not None and self.cast:
return self.response_model(response=value).response
else:
return value

def solve(
self__,
self,
/,
*args: P.args,
stack: ExitStack,
cache_dependencies: Dict[
Expand All @@ -275,7 +278,7 @@ def solve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self__._solve(
cast_gen = self._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -287,7 +290,7 @@ def solve(
cached_value: T = e.value
return cached_value

for dep in self__.extra_dependencies:
for dep in self.extra_dependencies:
dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -296,7 +299,7 @@ def solve(
**kwargs,
)

for dep_arg, dep in self__.dependencies.items():
for dep_arg, dep in self.dependencies.items():
kwargs[dep_arg] = dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -305,37 +308,38 @@ def solve(
**kwargs,
)

for custom in self__.custom_fields.values():
for custom in self.custom_fields.values():
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self__.is_generator and nested:
if self.is_generator and nested:
response = solve_generator_sync(
*final_args,
call=self__.call,
call=self.call,
stack=stack,
**final_kwargs,
)

else:
response = self__.call(*final_args, **final_kwargs)
response = self.call(*final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self__.cast or nested or not self__.is_generator:
if not self.cast or nested or not self.is_generator:
return value

else:
return map(self__._cast_response, value) # type: ignore[no-any-return, call-overload]
return map(self._cast_response, value) # type: ignore[no-any-return, call-overload]

assert_never(response) # pragma: no cover

async def asolve(
self__,
self,
/,
*args: P.args,
stack: AsyncExitStack,
cache_dependencies: Dict[
Expand All @@ -360,7 +364,7 @@ async def asolve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self__._solve(
cast_gen = self._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -372,7 +376,7 @@ async def asolve(
cached_value: T = e.value
return cached_value

for dep in self__.extra_dependencies:
for dep in self.extra_dependencies:
await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -381,7 +385,7 @@ async def asolve(
**kwargs,
)

for dep_arg, dep in self__.dependencies.items():
for dep_arg, dep in self.dependencies.items():
kwargs[dep_arg] = await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -390,30 +394,30 @@ async def asolve(
**kwargs,
)

for custom in self__.custom_fields.values():
for custom in self.custom_fields.values():
kwargs = await run_async(custom.use, **kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self__.is_generator and nested:
if self.is_generator and nested:
response = await solve_generator_async(
*final_args,
call=self__.call,
call=self.call,
stack=stack,
**final_kwargs,
)
else:
response = await run_async(self__.call, *final_args, **final_kwargs)
response = await run_async(self.call, *final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self__.cast or nested or not self__.is_generator:
if not self.cast or nested or not self.is_generator:
return value

else:
return async_map(self__._cast_response, value) # type: ignore[return-value, arg-type]
return async_map(self._cast_response, value) # type: ignore[return-value, arg-type]

assert_never(response) # pragma: no cover
Loading

0 comments on commit d7eb0b0

Please sign in to comment.