Skip to content

Commit

Permalink
feat (#34): inject to classes (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik authored Dec 13, 2023
1 parent 133955e commit 13ba7b7
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 74 deletions.
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.3"
__version__ = "2.2.4"
146 changes: 73 additions & 73 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ class CallModel(Generic[P, T]):
)

@property
def call_name(self) -> str:
return getattr(self.call, "__name__", type(self.call).__name__)
def call_name(self__) -> str:
return getattr(self__.call, "__name__", type(self__.call).__name__)

@property
def real_params(self) -> Dict[str, FieldInfo]:
params = self.params.copy()
for name in self.custom_fields.keys():
def real_params(self__) -> Dict[str, FieldInfo]:
params = self__.params.copy()
for name in self__.custom_fields.keys():
params.pop(name, None)
return params

@property
def flat_params(self) -> Dict[str, FieldInfo]:
params = self.real_params
for d in self.dependencies.values():
def flat_params(self__) -> Dict[str, FieldInfo]:
params = self__.real_params
for d in self__.dependencies.values():
params.update(d.flat_params)
return params

def __init__(
self,
self__,
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
Expand All @@ -112,39 +112,39 @@ def __init__(
positional_args: Optional[List[str]] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self.call = call
self.model = model
self.response_model = response_model
self__.call = call
self__.model = model
self__.response_model = response_model

fields: Dict[str, FieldInfo]
if PYDANTIC_V2:
fields = self.model.model_fields
fields = self__.model.model_fields
else:
fields = self.model.__fields__ # type: ignore
fields = self__.model.__fields__ # type: ignore

self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or []
self.custom_fields = custom_fields or {}
self__.dependencies = dependencies or {}
self__.extra_dependencies = extra_dependencies or []
self__.custom_fields = custom_fields or {}

self.alias_arguments = [f.alias or name for name, f in fields.items()]
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())
self__.alias_arguments = [f.alias or name for name, f in fields.items()]
self__.keyword_args = tuple(keyword_args or ())
self__.positional_args = tuple(positional_args or ())

self.params = fields.copy()
for name in self.dependencies.keys():
self.params.pop(name, None)
self__.params = fields.copy()
for name in self__.dependencies.keys():
self__.params.pop(name, None)

self.use_cache = use_cache
self.cast = cast
self.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self.call)
self__.use_cache = use_cache
self__.cast = cast
self__.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self__.call)
)
self.is_generator = is_gen_callable(self.call) or is_async_gen_callable(
self.call
self__.is_generator = is_gen_callable(self__.call) or is_async_gen_callable(
self__.call
)

def _solve(
self,
self__,
*args: P.args,
cache_dependencies: Dict[
Union[
Expand All @@ -168,56 +168,56 @@ def _solve(
**kwargs: P.kwargs,
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]:
if dependency_overrides:
self.call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(
self.call
), f"You cannot use async dependency `{self.call_name}` at sync main"
self__.call = dependency_overrides.get(self__.call, self__.call)
assert self__.is_async or not is_coroutine_callable(
self__.call
), f"You cannot use async dependency `{self__.call_name}` at sync main"

if self.use_cache and self.call in cache_dependencies:
return cache_dependencies[self.call]
if self__.use_cache and self__.call in cache_dependencies:
return cache_dependencies[self__.call]

kw = {}

for arg in self.keyword_args:
for arg in self__.keyword_args:
v = kwargs.pop(arg, inspect._empty)
if v is not inspect._empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
if "kwargs" in self__.alias_arguments:
kw["kwargs"] = kwargs

else:
kw.update(kwargs)

has_args = "args" in self.alias_arguments
has_args = "args" in self__.alias_arguments

for arg in self.positional_args:
for arg in self__.positional_args:
if args:
kw[arg], args = args[0], args[1:]

if has_args:
kw["args"] = args

else:
for arg in self.keyword_args:
for arg in self__.keyword_args:
if args:
kw[arg], args = args[0], args[1:]

solved_kw: Dict[str, Any]
solved_kw = yield (), kw

casted_model: object
if self.cast:
casted_model = self.model(**solved_kw)
if self__.cast:
casted_model = self__.model(**solved_kw)
else:
casted_model = object()

kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in (
self.keyword_args + self.positional_args
self__.keyword_args + self__.positional_args
if not has_args
else self.keyword_args
else self__.keyword_args
)
}
kwargs_.update(getattr(casted_model, "kwargs", {}))
Expand All @@ -226,7 +226,7 @@ def _solve(
if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
for arg in self__.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
else:
Expand All @@ -235,22 +235,22 @@ def _solve(
response: T
response = yield args_, kwargs_

if self.cast and not self.is_generator:
response = self._cast_response(response)
if self__.cast and not self__.is_generator:
response = self__._cast_response(response)

if self.use_cache: # pragma: no branch
cache_dependencies[self.call] = response
if self__.use_cache: # pragma: no branch
cache_dependencies[self__.call] = response

return response

def _cast_response(self, value: Any) -> Any:
if self.response_model is not None and self.cast:
return self.response_model(response=value).response
def _cast_response(self__, value: Any) -> Any:
if self__.response_model is not None and self__.cast:
return self__.response_model(response=value).response
else:
return value

def solve(
self,
self__,
*args: P.args,
stack: ExitStack,
cache_dependencies: Dict[
Expand All @@ -275,7 +275,7 @@ def solve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self._solve(
cast_gen = self__._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -287,7 +287,7 @@ def solve(
cached_value: T = e.value
return cached_value

for dep in self.extra_dependencies:
for dep in self__.extra_dependencies:
dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -296,7 +296,7 @@ def solve(
**kwargs,
)

for dep_arg, dep in self.dependencies.items():
for dep_arg, dep in self__.dependencies.items():
kwargs[dep_arg] = dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -305,37 +305,37 @@ def solve(
**kwargs,
)

for custom in self.custom_fields.values():
for custom in self__.custom_fields.values():
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self.is_generator and nested:
if self__.is_generator and nested:
response = solve_generator_sync(
*final_args,
call=self.call,
call=self__.call,
stack=stack,
**final_kwargs,
)

else:
response = self.call(*final_args, **final_kwargs)
response = self__.call(*final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self.cast or nested or not self.is_generator:
if not self__.cast or nested or not self__.is_generator:
return value

else:
return map(self._cast_response, value) # type: ignore[no-any-return, call-overload]
return map(self__._cast_response, value) # type: ignore[no-any-return, call-overload]

assert_never(response) # pragma: no cover

async def asolve(
self,
self__,
*args: P.args,
stack: AsyncExitStack,
cache_dependencies: Dict[
Expand All @@ -360,7 +360,7 @@ async def asolve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self._solve(
cast_gen = self__._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -372,7 +372,7 @@ async def asolve(
cached_value: T = e.value
return cached_value

for dep in self.extra_dependencies:
for dep in self__.extra_dependencies:
await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -381,7 +381,7 @@ async def asolve(
**kwargs,
)

for dep_arg, dep in self.dependencies.items():
for dep_arg, dep in self__.dependencies.items():
kwargs[dep_arg] = await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -390,30 +390,30 @@ async def asolve(
**kwargs,
)

for custom in self.custom_fields.values():
for custom in self__.custom_fields.values():
kwargs = await run_async(custom.use, **kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self.is_generator and nested:
if self__.is_generator and nested:
response = await solve_generator_async(
*final_args,
call=self.call,
call=self__.call,
stack=stack,
**final_kwargs,
)
else:
response = await run_async(self.call, *final_args, **final_kwargs)
response = await run_async(self__.call, *final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self.cast or nested or not self.is_generator:
if not self__.cast or nested or not self__.is_generator:
return value

else:
return async_map(self._cast_response, value) # type: ignore[return-value, arg-type]
return async_map(self__._cast_response, value) # type: ignore[return-value, arg-type]

assert_never(response) # pragma: no cover
22 changes: 22 additions & 0 deletions tests/async/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from fast_depends import Depends, inject


def _get_var():
return 1


class Class:
@inject
def __init__(self, a = Depends(_get_var)) -> None:
self.a = a

@inject
async def calc(self, a = Depends(_get_var)) -> int:
return a + self.a


@pytest.mark.anyio
async def test_class():
assert await Class().calc() == 2
19 changes: 19 additions & 0 deletions tests/sync/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fast_depends import Depends, inject


def _get_var():
return 1


class Class:
@inject
def __init__(self, a = Depends(_get_var)) -> None:
self.a = a

@inject
def calc(self, a = Depends(_get_var)) -> int:
return a + self.a


def test_class():
assert Class().calc() == 2

0 comments on commit 13ba7b7

Please sign in to comment.