diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index 154aedc79..360366308 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -65,7 +65,7 @@ class Pipeline: _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] _defaults: dict[str, Node[Any]] - _components: dict[str, PipelineFunction[Any]] + _components: dict[str, PipelineFunction[Any] | Component[Any]] _hash: str | None = None _last: Node[Any] | None = None _anon_nodes: set[str] diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index 7167ef4fe..8d5d3bdb7 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -18,6 +18,7 @@ from typing import ( Any, Callable, + Generic, Mapping, ParamSpec, Protocol, @@ -35,7 +36,6 @@ P = ParamSpec("P") T = TypeVar("T") -Cfg = TypeVar("Cfg") # COut is only return, so Component[U] can be assigned to Component[T] if U ≼ T. COut = TypeVar("COut", covariant=True) PipelineFunction: TypeAlias = Callable[..., COut] @@ -130,7 +130,7 @@ def load_params(self, params: dict[str, object]) -> None: raise NotImplementedError() -class Component: +class Component(Generic[COut]): """ Base class for pipeline component objects. Any component that is not just a function should extend this class. @@ -260,7 +260,7 @@ def __repr__(self) -> str: def instantiate_component( - comp: str | type | FunctionType, config: dict[str, Any] | None + comp: str | type | FunctionType, config: Mapping[str, Any] | None ) -> Callable[..., object]: """ Utility function to instantiate a component given its class, function, or @@ -281,7 +281,7 @@ def instantiate_component( return comp elif issubclass(comp, Component): cfg = comp.validate_config(config) - return comp(cfg) + return comp(cfg) # type: ignore else: # pragma: nocover return comp() # type: ignore