From 492e6b434fc3da85b1a0f10583022b1d0c855687 Mon Sep 17 00:00:00 2001 From: Evgeny Seregin Date: Sat, 24 Aug 2024 12:21:36 +0400 Subject: [PATCH] fix: Allow to inject async functions into functions wrapped with asynccontextmanager (#125) --- fast_depends/use.py | 4 ++-- tests/async/test_depends.py | 15 +++++++++++++++ tests/sync/test_depends.py | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/fast_depends/use.py b/fast_depends/use.py index 972a9c6..0e88ede 100644 --- a/fast_depends/use.py +++ b/fast_depends/use.py @@ -184,7 +184,7 @@ def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: class solve_async_gen: - _iter: Optional[AsyncIterator[Any]] + _iter: Optional[AsyncIterator[Any]] = None def __init__( self, @@ -231,7 +231,7 @@ async def __anext__(self) -> Any: class solve_gen: - _iter: Optional[Iterator[Any]] + _iter: Optional[Iterator[Any]] = None def __init__( self, diff --git a/tests/async/test_depends.py b/tests/async/test_depends.py index 592436b..ee12e9e 100644 --- a/tests/async/test_depends.py +++ b/tests/async/test_depends.py @@ -1,4 +1,5 @@ import logging +from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial from unittest.mock import Mock @@ -469,3 +470,17 @@ async def func(a=Depends(dep)): return a assert await func() == "a" + + +@pytest.mark.anyio +async def test_asynccontextmanager(): + async def dep(a: str): + return a + + @asynccontextmanager + @inject + async def func(a: str, b: str = Depends(dep)): + yield a == b + + async with func("a") as is_equal: + assert is_equal diff --git a/tests/sync/test_depends.py b/tests/sync/test_depends.py index ef56dfc..c1a2fe2 100644 --- a/tests/sync/test_depends.py +++ b/tests/sync/test_depends.py @@ -1,4 +1,5 @@ import logging +from contextlib import contextmanager from dataclasses import dataclass from functools import partial from unittest.mock import Mock @@ -352,3 +353,16 @@ def func(a=Depends(dep)): return a assert func() == "a" + + +def test_contextmanager(): + def dep(a: str): + return a + + @contextmanager + @inject + def func(a: str, b: str = Depends(dep)): + yield a == b + + with func("a") as is_equal: + assert is_equal