From ae8f14c6cc241212f93b85420540c245e7f64f86 Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Tue, 27 Feb 2024 17:04:03 +0300 Subject: [PATCH] fix (#75): correct main empty body processing (#76) --- fast_depends/__about__.py | 2 +- fast_depends/core/model.py | 14 +++++++++----- tests/async/test_depends.py | 13 +++++++++++++ tests/sync/test_depends.py | 12 ++++++++++++ 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 24611d6..513e73f 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.4.1" +__version__ = "2.4.2" diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index 53228ce..2e9d41f 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -258,15 +258,15 @@ def _solve( keyword_args = self.keyword_args else: - keyword_args = self.keyword_args + self.positional_args - for arg in self.keyword_args: + keyword_args = set(self.keyword_args + self.positional_args) + for arg in keyword_args - set(self.dependencies.keys()): if args: kw[arg], args = args[0], args[1:] else: break solved_kw: Dict[str, Any] - solved_kw = yield (), kw, call + solved_kw = yield args, kw, call args_: Sequence[Any] if self.cast: @@ -347,7 +347,7 @@ def solve( **kwargs, ) try: - _, kwargs, _ = next(cast_gen) + args, kwargs, _ = next(cast_gen) except StopIteration as e: cached_value: T = e.value return cached_value @@ -355,6 +355,7 @@ def solve( # Heat cache and solve extra dependencies for dep, _ in self.sorted_dependencies: dep.solve( + *args, stack=stack, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, @@ -365,6 +366,7 @@ def solve( # Always get from cache for dep in self.extra_dependencies: dep.solve( + *args, stack=stack, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, @@ -447,7 +449,7 @@ async def asolve( **kwargs, ) try: - _, kwargs, _ = next(cast_gen) + args, kwargs, _ = next(cast_gen) except StopIteration as e: cached_value: T = e.value return cached_value @@ -459,6 +461,7 @@ async def asolve( for dep, subdep in self.sorted_dependencies: solve = partial( dep.asolve, + *args, stack=stack, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, @@ -479,6 +482,7 @@ async def asolve( # Always get from cache for dep in self.extra_dependencies: await dep.asolve( + *args, stack=stack, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, diff --git a/tests/async/test_depends.py b/tests/async/test_depends.py index 67fec4e..753c0a0 100644 --- a/tests/async/test_depends.py +++ b/tests/async/test_depends.py @@ -23,6 +23,19 @@ async def some_func(b: int, c=Depends(dep_func)) -> int: assert (await some_func("2")) == 7 +@pytest.mark.anyio +async def test_empty_main_body(): + async def dep_func(a: int) -> float: + return a + + @inject + async def some_func(c=Depends(dep_func)): + assert isinstance(c, float) + assert c == 1.0 + + await some_func("1") + + @pytest.mark.anyio async def test_sync_depends(): def sync_dep_func(a: int) -> float: diff --git a/tests/sync/test_depends.py b/tests/sync/test_depends.py index 2b6ea5b..8b0aca0 100644 --- a/tests/sync/test_depends.py +++ b/tests/sync/test_depends.py @@ -22,6 +22,18 @@ def some_func(b: int, c=Depends(dep_func)) -> int: assert some_func("2") == 7 +def test_empty_main_body(): + def dep_func(a: int) -> float: + return a + + @inject + def some_func(c=Depends(dep_func)): + assert isinstance(c, float) + assert c == 1.0 + + some_func("1") + + def test_depends_error(): def dep_func(b: dict, a: int = 3) -> float: # pragma: no cover return a + b