Skip to content

Commit

Permalink
fix: Allow to inject async functions into ones wrapped with asynccont…
Browse files Browse the repository at this point in the history
…extmanager
  • Loading branch information
Zhenay committed Aug 24, 2024
1 parent f9146d3 commit ed90432
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fast_depends/use.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:


class solve_async_gen:
_iter: Optional[AsyncIterator[Any]]
_iter: Optional[AsyncIterator[Any]] = None

def __init__(
self,
Expand Down Expand Up @@ -231,7 +231,7 @@ async def __anext__(self) -> Any:


class solve_gen:
_iter: Optional[Iterator[Any]]
_iter: Optional[Iterator[Any]] = None

def __init__(
self,
Expand Down
15 changes: 15 additions & 0 deletions tests/async/test_depends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from unittest.mock import Mock
Expand Down Expand Up @@ -469,3 +470,17 @@ async def func(a=Depends(dep)):
return a

assert await func() == "a"


@pytest.mark.anyio
async def test_asynccontextmanager():
async def dep(a: str = "a"):
return a

@asynccontextmanager
@inject
async def func(a: str, b: str = Depends(dep)):
yield a == b

async with func("a") as is_equal:
assert is_equal
14 changes: 14 additions & 0 deletions tests/sync/test_depends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from unittest.mock import Mock
Expand Down Expand Up @@ -352,3 +353,16 @@ def func(a=Depends(dep)):
return a

assert func() == "a"


def test_contextmanager():
def dep(a: str = "a"):
return a

@contextmanager
@inject
def func(a: str, b: str = Depends(dep)):
yield a == b

with func("a") as is_equal:
assert is_equal

0 comments on commit ed90432

Please sign in to comment.