Skip to content

Commit

Permalink
use covariant type variables for providers (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
stkrizh authored Oct 1, 2024
1 parent 4704255 commit 78a6202
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 82 deletions.
8 changes: 4 additions & 4 deletions that_depends/providers/attr_getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from that_depends.providers.base import AbstractProvider


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


Expand All @@ -14,16 +14,16 @@ def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.


class AttrGetter(
AbstractProvider[T],
AbstractProvider[T_co],
):
__slots__ = "_provider", "_attrs"

def __init__(self, provider: AbstractProvider[T], attr_name: str) -> None:
def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None:
super().__init__()
self._provider = provider
self._attrs = [attr_name]

def __getattr__(self, attr: str) -> "AttrGetter[T]":
def __getattr__(self, attr: str) -> "AttrGetter[T_co]":
if attr.startswith("_"):
msg = f"'{type(self)}' object has no attribute '{attr}'"
raise AttributeError(msg)
Expand Down
37 changes: 18 additions & 19 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from contextlib import contextmanager


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
R = typing.TypeVar("R")
P = typing.ParamSpec("P")
T_co = typing.TypeVar("T_co", covariant=True)


class AbstractProvider(typing.Generic[T_co], abc.ABC):
Expand Down Expand Up @@ -119,10 +118,10 @@ def sync_tear_down(self) -> None:
raise RuntimeError(msg)


class AbstractResource(AbstractProvider[T], abc.ABC):
class AbstractResource(AbstractProvider[T_co], abc.ABC):
def __init__(
self,
creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]],
creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand All @@ -141,21 +140,21 @@ def __init__(
self._override = None

def _is_creator_async(
self, _: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]]
) -> typing.TypeGuard[typing.Callable[P, typing.AsyncIterator[T]]]:
self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]]
) -> typing.TypeGuard[typing.Callable[P, typing.AsyncIterator[T_co]]]:
return self._is_async

def _is_creator_sync(
self, _: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]]
) -> typing.TypeGuard[typing.Callable[P, typing.Iterator[T]]]:
self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]]
) -> typing.TypeGuard[typing.Callable[P, typing.Iterator[T_co]]]:
return not self._is_async

@abc.abstractmethod
def _fetch_context(self) -> ResourceContext[T]: ...
def _fetch_context(self) -> ResourceContext[T_co]: ...

async def async_resolve(self) -> T:
async def async_resolve(self) -> T_co:
if self._override:
return typing.cast(T, self._override)
return typing.cast(T_co, self._override)

context = self._fetch_context()

Expand All @@ -172,7 +171,7 @@ async def async_resolve(self) -> T:
if self._is_creator_async(self._creator):
context.context_stack = contextlib.AsyncExitStack()
context.instance = typing.cast(
T,
T_co,
await context.context_stack.enter_async_context(
contextlib.asynccontextmanager(self._creator)(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
Expand All @@ -194,11 +193,11 @@ async def async_resolve(self) -> T:
},
),
)
return typing.cast(T, context.instance)
return typing.cast(T_co, context.instance)

def sync_resolve(self) -> T:
def sync_resolve(self) -> T_co:
if self._override:
return typing.cast(T, self._override)
return typing.cast(T_co, self._override)

context = self._fetch_context()
if context.instance is not None:
Expand All @@ -216,16 +215,16 @@ def sync_resolve(self) -> T:
**{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
),
)
return typing.cast(T, context.instance)
return typing.cast(T_co, context.instance)


class AbstractFactory(AbstractProvider[T], abc.ABC):
class AbstractFactory(AbstractProvider[T_co], abc.ABC):
"""Abstract Factory Class."""

@property
def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T]]:
def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T_co]]:
return self.async_resolve

@property
def sync_provider(self) -> typing.Callable[[], T]:
def sync_provider(self) -> typing.Callable[[], T_co]:
return self.sync_resolve
20 changes: 10 additions & 10 deletions that_depends/providers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@
from that_depends.providers.base import AbstractProvider


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)


class List(AbstractProvider[list[T]]):
class List(AbstractProvider[list[T_co]]):
__slots__ = ("_providers",)

def __init__(self, *providers: AbstractProvider[T]) -> None:
def __init__(self, *providers: AbstractProvider[T_co]) -> None:
super().__init__()
self._providers: typing.Final = providers

async def async_resolve(self) -> list[T]:
async def async_resolve(self) -> list[T_co]:
return [await x.async_resolve() for x in self._providers]

def sync_resolve(self) -> list[T]:
def sync_resolve(self) -> list[T_co]:
return [x.sync_resolve() for x in self._providers]

async def __call__(self) -> list[T]:
async def __call__(self) -> list[T_co]:
return await self.async_resolve()


class Dict(AbstractProvider[dict[str, T]]):
class Dict(AbstractProvider[dict[str, T_co]]):
__slots__ = ("_providers",)

def __init__(self, **providers: AbstractProvider[T]) -> None:
def __init__(self, **providers: AbstractProvider[T_co]) -> None:
super().__init__()
self._providers: typing.Final = providers

async def async_resolve(self) -> dict[str, T]:
async def async_resolve(self) -> dict[str, T_co]:
return {key: await provider.async_resolve() for key, provider in self._providers.items()}

def sync_resolve(self) -> dict[str, T]:
def sync_resolve(self) -> dict[str, T_co]:
return {key: provider.sync_resolve() for key, provider in self._providers.items()}
22 changes: 11 additions & 11 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


logger: typing.Final = logging.getLogger(__name__)
T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")
_CONTAINER_CONTEXT: typing.Final[ContextVar[dict[str, typing.Any]]] = ContextVar("CONTAINER_CONTEXT")
AppType = typing.TypeVar("AppType")
Expand Down Expand Up @@ -90,18 +90,18 @@ async def __aexit__(
finally:
_CONTAINER_CONTEXT.reset(self._context_token)

def __call__(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]:
def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]:
if inspect.iscoroutinefunction(func):

@wraps(func)
async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T:
async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co:
async with container_context(self._initial_context):
return await func(*args, **kwargs) # type: ignore[no-any-return]

return typing.cast(typing.Callable[P, T], _async_inner)
return typing.cast(typing.Callable[P, T_co], _async_inner)

@wraps(func)
def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T:
def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co:
with container_context(self._initial_context):
return func(*args, **kwargs)

Expand Down Expand Up @@ -138,7 +138,7 @@ def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # n
return _get_container_context().get(key, default)


class ContextResource(AbstractResource[T]):
class ContextResource(AbstractResource[T_co]):
__slots__ = (
"_is_async",
"_creator",
Expand All @@ -150,27 +150,27 @@ class ContextResource(AbstractResource[T]):

def __init__(
self,
creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]],
creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
super().__init__(creator, *args, **kwargs)
self._internal_name: typing.Final = f"{creator.__name__}-{uuid.uuid4()}"

def _fetch_context(self) -> ResourceContext[T]:
def _fetch_context(self) -> ResourceContext[T_co]:
container_context = _get_container_context()
if resource_context := container_context.get(self._internal_name):
return typing.cast(ResourceContext[T], resource_context)
return typing.cast(ResourceContext[T_co], resource_context)

resource_context = ResourceContext(is_async=_is_container_context_async())
container_context[self._internal_name] = resource_context
return resource_context


class AsyncContextResource(ContextResource[T]):
class AsyncContextResource(ContextResource[T_co]):
def __init__(
self,
creator: typing.Callable[P, typing.AsyncIterator[T]],
creator: typing.Callable[P, typing.AsyncIterator[T_co]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
22 changes: 11 additions & 11 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,51 @@
from that_depends.providers.base import AbstractFactory, AbstractProvider


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Factory(AbstractFactory[T]):
class Factory(AbstractFactory[T_co]):
__slots__ = "_factory", "_args", "_kwargs", "_override"

def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
super().__init__()
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None

async def async_resolve(self) -> T:
async def async_resolve(self) -> T_co:
if self._override:
return typing.cast(T, self._override)
return typing.cast(T_co, self._override)

return self._factory(
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)

def sync_resolve(self) -> T:
def sync_resolve(self) -> T_co:
if self._override:
return typing.cast(T, self._override)
return typing.cast(T_co, self._override)

return self._factory(
*[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)


class AsyncFactory(AbstractFactory[T]):
class AsyncFactory(AbstractFactory[T_co]):
__slots__ = "_factory", "_args", "_kwargs", "_override"

def __init__(self, factory: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> None:
def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._override = None

async def async_resolve(self) -> T:
async def async_resolve(self) -> T_co:
if self._override:
return typing.cast(T, self._override)
return typing.cast(T_co, self._override)

return await self._factory(
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
Expand Down
10 changes: 5 additions & 5 deletions that_depends/providers/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
from that_depends.providers.base import AbstractProvider


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Object(AbstractProvider[T]):
class Object(AbstractProvider[T_co]):
__slots__ = ("_obj",)

def __init__(self, obj: T) -> None:
def __init__(self, obj: T_co) -> None:
super().__init__()
self._obj: typing.Final = obj

async def async_resolve(self) -> T:
async def async_resolve(self) -> T_co:
return self._obj

def sync_resolve(self) -> T:
def sync_resolve(self) -> T_co:
return self._obj
14 changes: 7 additions & 7 deletions that_depends/providers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from that_depends.providers.base import AbstractResource, ResourceContext


T = typing.TypeVar("T")
T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Resource(AbstractResource[T]):
class Resource(AbstractResource[T_co]):
__slots__ = (
"_is_async",
"_creator",
Expand All @@ -20,24 +20,24 @@ class Resource(AbstractResource[T]):

def __init__(
self,
creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]],
creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
super().__init__(creator, *args, **kwargs)
self._context: typing.Final[ResourceContext[T]] = ResourceContext(is_async=self._is_async)
self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self._is_async)

def _fetch_context(self) -> ResourceContext[T]:
def _fetch_context(self) -> ResourceContext[T_co]:
return self._context

async def tear_down(self) -> None:
await self._fetch_context().tear_down()


class AsyncResource(Resource[T]):
class AsyncResource(Resource[T_co]):
def __init__(
self,
creator: typing.Callable[P, typing.AsyncIterator[T]],
creator: typing.Callable[P, typing.AsyncIterator[T_co]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
Loading

0 comments on commit 78a6202

Please sign in to comment.