From 78a6202972062258c1bba568782c841f32c56a6c Mon Sep 17 00:00:00 2001 From: Stanislav Kiriukhin <44553725+stkrizh@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:28:36 +0400 Subject: [PATCH] use covariant type variables for providers (#96) --- that_depends/providers/attr_getter.py | 8 ++--- that_depends/providers/base.py | 37 ++++++++++----------- that_depends/providers/collections.py | 20 +++++------ that_depends/providers/context_resources.py | 22 ++++++------ that_depends/providers/factories.py | 22 ++++++------ that_depends/providers/object.py | 10 +++--- that_depends/providers/resources.py | 14 ++++---- that_depends/providers/selector.py | 14 ++++---- that_depends/providers/singleton.py | 16 ++++----- 9 files changed, 81 insertions(+), 82 deletions(-) diff --git a/that_depends/providers/attr_getter.py b/that_depends/providers/attr_getter.py index edaf450..3529651 100644 --- a/that_depends/providers/attr_getter.py +++ b/that_depends/providers/attr_getter.py @@ -4,7 +4,7 @@ from that_depends.providers.base import AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") @@ -14,16 +14,16 @@ def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing. class AttrGetter( - AbstractProvider[T], + AbstractProvider[T_co], ): __slots__ = "_provider", "_attrs" - def __init__(self, provider: AbstractProvider[T], attr_name: str) -> None: + def __init__(self, provider: AbstractProvider[T_co], attr_name: str) -> None: super().__init__() self._provider = provider self._attrs = [attr_name] - def __getattr__(self, attr: str) -> "AttrGetter[T]": + def __getattr__(self, attr: str) -> "AttrGetter[T_co]": if attr.startswith("_"): msg = f"'{type(self)}' object has no attribute '{attr}'" raise AttributeError(msg) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index bd8bfa9..04c7d32 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -6,10 +6,9 @@ from contextlib import contextmanager -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) R = typing.TypeVar("R") P = typing.ParamSpec("P") -T_co = typing.TypeVar("T_co", covariant=True) class AbstractProvider(typing.Generic[T_co], abc.ABC): @@ -119,10 +118,10 @@ def sync_tear_down(self) -> None: raise RuntimeError(msg) -class AbstractResource(AbstractProvider[T], abc.ABC): +class AbstractResource(AbstractProvider[T_co], abc.ABC): def __init__( self, - creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]], + creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -141,21 +140,21 @@ def __init__( self._override = None def _is_creator_async( - self, _: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]] - ) -> typing.TypeGuard[typing.Callable[P, typing.AsyncIterator[T]]]: + self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]] + ) -> typing.TypeGuard[typing.Callable[P, typing.AsyncIterator[T_co]]]: return self._is_async def _is_creator_sync( - self, _: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]] - ) -> typing.TypeGuard[typing.Callable[P, typing.Iterator[T]]]: + self, _: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]] + ) -> typing.TypeGuard[typing.Callable[P, typing.Iterator[T_co]]]: return not self._is_async @abc.abstractmethod - def _fetch_context(self) -> ResourceContext[T]: ... + def _fetch_context(self) -> ResourceContext[T_co]: ... - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) context = self._fetch_context() @@ -172,7 +171,7 @@ async def async_resolve(self) -> T: if self._is_creator_async(self._creator): context.context_stack = contextlib.AsyncExitStack() context.instance = typing.cast( - T, + T_co, await context.context_stack.enter_async_context( contextlib.asynccontextmanager(self._creator)( *[await x() if isinstance(x, AbstractProvider) else x for x in self._args], @@ -194,11 +193,11 @@ async def async_resolve(self) -> T: }, ), ) - return typing.cast(T, context.instance) + return typing.cast(T_co, context.instance) - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) context = self._fetch_context() if context.instance is not None: @@ -216,16 +215,16 @@ def sync_resolve(self) -> T: **{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ), ) - return typing.cast(T, context.instance) + return typing.cast(T_co, context.instance) -class AbstractFactory(AbstractProvider[T], abc.ABC): +class AbstractFactory(AbstractProvider[T_co], abc.ABC): """Abstract Factory Class.""" @property - def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T]]: + def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T_co]]: return self.async_resolve @property - def sync_provider(self) -> typing.Callable[[], T]: + def sync_provider(self) -> typing.Callable[[], T_co]: return self.sync_resolve diff --git a/that_depends/providers/collections.py b/that_depends/providers/collections.py index d5511cb..0e3c7f3 100644 --- a/that_depends/providers/collections.py +++ b/that_depends/providers/collections.py @@ -3,35 +3,35 @@ from that_depends.providers.base import AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) -class List(AbstractProvider[list[T]]): +class List(AbstractProvider[list[T_co]]): __slots__ = ("_providers",) - def __init__(self, *providers: AbstractProvider[T]) -> None: + def __init__(self, *providers: AbstractProvider[T_co]) -> None: super().__init__() self._providers: typing.Final = providers - async def async_resolve(self) -> list[T]: + async def async_resolve(self) -> list[T_co]: return [await x.async_resolve() for x in self._providers] - def sync_resolve(self) -> list[T]: + def sync_resolve(self) -> list[T_co]: return [x.sync_resolve() for x in self._providers] - async def __call__(self) -> list[T]: + async def __call__(self) -> list[T_co]: return await self.async_resolve() -class Dict(AbstractProvider[dict[str, T]]): +class Dict(AbstractProvider[dict[str, T_co]]): __slots__ = ("_providers",) - def __init__(self, **providers: AbstractProvider[T]) -> None: + def __init__(self, **providers: AbstractProvider[T_co]) -> None: super().__init__() self._providers: typing.Final = providers - async def async_resolve(self) -> dict[str, T]: + async def async_resolve(self) -> dict[str, T_co]: return {key: await provider.async_resolve() for key, provider in self._providers.items()} - def sync_resolve(self) -> dict[str, T]: + def sync_resolve(self) -> dict[str, T_co]: return {key: provider.sync_resolve() for key, provider in self._providers.items()} diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 3ed7687..69687a7 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -12,7 +12,7 @@ logger: typing.Final = logging.getLogger(__name__) -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") _CONTAINER_CONTEXT: typing.Final[ContextVar[dict[str, typing.Any]]] = ContextVar("CONTAINER_CONTEXT") AppType = typing.TypeVar("AppType") @@ -90,18 +90,18 @@ async def __aexit__( finally: _CONTAINER_CONTEXT.reset(self._context_token) - def __call__(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: if inspect.iscoroutinefunction(func): @wraps(func) - async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T: + async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: async with container_context(self._initial_context): return await func(*args, **kwargs) # type: ignore[no-any-return] - return typing.cast(typing.Callable[P, T], _async_inner) + return typing.cast(typing.Callable[P, T_co], _async_inner) @wraps(func) - def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T: + def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: with container_context(self._initial_context): return func(*args, **kwargs) @@ -138,7 +138,7 @@ def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # n return _get_container_context().get(key, default) -class ContextResource(AbstractResource[T]): +class ContextResource(AbstractResource[T_co]): __slots__ = ( "_is_async", "_creator", @@ -150,27 +150,27 @@ class ContextResource(AbstractResource[T]): def __init__( self, - creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]], + creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], *args: P.args, **kwargs: P.kwargs, ) -> None: super().__init__(creator, *args, **kwargs) self._internal_name: typing.Final = f"{creator.__name__}-{uuid.uuid4()}" - def _fetch_context(self) -> ResourceContext[T]: + def _fetch_context(self) -> ResourceContext[T_co]: container_context = _get_container_context() if resource_context := container_context.get(self._internal_name): - return typing.cast(ResourceContext[T], resource_context) + return typing.cast(ResourceContext[T_co], resource_context) resource_context = ResourceContext(is_async=_is_container_context_async()) container_context[self._internal_name] = resource_context return resource_context -class AsyncContextResource(ContextResource[T]): +class AsyncContextResource(ContextResource[T_co]): def __init__( self, - creator: typing.Callable[P, typing.AsyncIterator[T]], + creator: typing.Callable[P, typing.AsyncIterator[T_co]], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index dd6840e..28774c0 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -3,32 +3,32 @@ from that_depends.providers.base import AbstractFactory, AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") -class Factory(AbstractFactory[T]): +class Factory(AbstractFactory[T_co]): __slots__ = "_factory", "_args", "_kwargs", "_override" - def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: super().__init__() self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs self._override = None - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) return self._factory( *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], **{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) return self._factory( *[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], @@ -36,18 +36,18 @@ def sync_resolve(self) -> T: ) -class AsyncFactory(AbstractFactory[T]): +class AsyncFactory(AbstractFactory[T_co]): __slots__ = "_factory", "_args", "_kwargs", "_override" - def __init__(self, factory: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None: self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs self._override = None - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) return await self._factory( *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py index b7bd21b..a0095c7 100644 --- a/that_depends/providers/object.py +++ b/that_depends/providers/object.py @@ -3,19 +3,19 @@ from that_depends.providers.base import AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") -class Object(AbstractProvider[T]): +class Object(AbstractProvider[T_co]): __slots__ = ("_obj",) - def __init__(self, obj: T) -> None: + def __init__(self, obj: T_co) -> None: super().__init__() self._obj: typing.Final = obj - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: return self._obj - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: return self._obj diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index 5144cd3..9dfe1c5 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -4,11 +4,11 @@ from that_depends.providers.base import AbstractResource, ResourceContext -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") -class Resource(AbstractResource[T]): +class Resource(AbstractResource[T_co]): __slots__ = ( "_is_async", "_creator", @@ -20,24 +20,24 @@ class Resource(AbstractResource[T]): def __init__( self, - creator: typing.Callable[P, typing.Iterator[T] | typing.AsyncIterator[T]], + creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], *args: P.args, **kwargs: P.kwargs, ) -> None: super().__init__(creator, *args, **kwargs) - self._context: typing.Final[ResourceContext[T]] = ResourceContext(is_async=self._is_async) + self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self._is_async) - def _fetch_context(self) -> ResourceContext[T]: + def _fetch_context(self) -> ResourceContext[T_co]: return self._context async def tear_down(self) -> None: await self._fetch_context().tear_down() -class AsyncResource(Resource[T]): +class AsyncResource(Resource[T_co]): def __init__( self, - creator: typing.Callable[P, typing.AsyncIterator[T]], + creator: typing.Callable[P, typing.AsyncIterator[T_co]], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/that_depends/providers/selector.py b/that_depends/providers/selector.py index 55e3e18..e2a6aeb 100644 --- a/that_depends/providers/selector.py +++ b/that_depends/providers/selector.py @@ -3,21 +3,21 @@ from that_depends.providers.base import AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) -class Selector(AbstractProvider[T]): +class Selector(AbstractProvider[T_co]): __slots__ = "_selector", "_providers", "_override" - def __init__(self, selector: typing.Callable[[], str], **providers: AbstractProvider[T]) -> None: + def __init__(self, selector: typing.Callable[[], str], **providers: AbstractProvider[T_co]) -> None: super().__init__() self._selector: typing.Final = selector self._providers: typing.Final = providers self._override = None - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) selected_key: typing.Final = self._selector() if selected_key not in self._providers: @@ -25,9 +25,9 @@ async def async_resolve(self) -> T: raise RuntimeError(msg) return await self._providers[selected_key].async_resolve() - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: if self._override: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) selected_key: typing.Final = self._selector() if selected_key not in self._providers: diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 6f9bc2e..11c86c7 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -5,20 +5,20 @@ from that_depends.providers.base import AbstractProvider -T = typing.TypeVar("T") +T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") -class Singleton(AbstractProvider[T]): +class Singleton(AbstractProvider[T_co]): __slots__ = "_factory", "_args", "_kwargs", "_override", "_instance", "_resolving_lock" - def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, factory: type[T_co] | typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: super().__init__() self._factory: typing.Final = factory self._args: typing.Final = args self._kwargs: typing.Final = kwargs self._override = None - self._instance: T | None = None + self._instance: T_co | None = None self._resolving_lock: typing.Final = asyncio.Lock() def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 @@ -27,9 +27,9 @@ def __getattr__(self, attr_name: str) -> typing.Any: # noqa: ANN401 raise AttributeError(msg) return AttrGetter(provider=self, attr_name=attr_name) - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: if self._override is not None: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) if self._instance is not None: return self._instance @@ -46,9 +46,9 @@ async def async_resolve(self) -> T: ) return self._instance - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: if self._override is not None: - return typing.cast(T, self._override) + return typing.cast(T_co, self._override) if self._instance is None: self._instance = self._factory(