diff --git a/ididi/_ds.py b/ididi/_ds.py index cc002cd4..cad3a420 100644 --- a/ididi/_ds.py +++ b/ididi/_ds.py @@ -1,6 +1,6 @@ from collections import defaultdict -from types import MappingProxyType -from typing import Any, Callable, Hashable, Union +from types import FunctionType, MappingProxyType +from typing import Any, Callable, Hashable, Union, cast from ._node import DependentNode from ._type_resolve import get_bases @@ -69,9 +69,9 @@ def __init__(self, nodes: GraphNodes[Any]): def _visit( self, - start_types: Union[list[Any], type], - pre_visit: Union[Callable[[type], None], None] = None, - post_visit: Union[Callable[[type], None], None] = None, + start_types: Union[list[Callable[..., Any]], Callable[..., Any]], + pre_visit: Union[Callable[[Callable[..., Any]], None], None] = None, + post_visit: Union[Callable[[Callable[..., Any]], None], None] = None, ) -> None: """Generic DFS traversal with customizable visit callbacks. @@ -80,12 +80,12 @@ def _visit( pre_visit: Called before visiting node's dependencies post_visit: Called after visiting node's dependencies """ - if isinstance(start_types, type): + if isinstance(start_types, type) or isinstance(start_types, FunctionType): start_types = [start_types] - visited = set[type]() + visited = set[Callable[..., Any]]() - def _get_deps(node_type: type) -> list[type]: + def _get_deps(node_type: Callable[..., Any]) -> list[Callable[..., Any]]: node = self._nodes[node_type] return [ p.param_type @@ -93,7 +93,7 @@ def _get_deps(node_type: type) -> list[type]: if p.param_type in self._nodes ] - def dfs(node_type: type): + def dfs(node_type: Callable[..., Any]): if node_type in visited: return visited.add(node_type) @@ -107,13 +107,15 @@ def dfs(node_type: type): if post_visit: post_visit(node_type) - for node_type in start_types: + for node_type in cast(list[Callable[..., Any]], start_types): dfs(node_type) - def get_dependents(self, dependency: type) -> list[type]: - dependents: list[type] = [] + def get_dependents( + self, dependency: Callable[..., Any] + ) -> list[Callable[..., Any]]: + dependents: list[Callable[..., Any]] = [] - def collect_dependent(node_type: type): + def collect_dependent(node_type: Callable[..., Any]): node = self._nodes[node_type] if any(p.param_type is dependency for _, p in node.dependencies): dependents.append(node_type) @@ -121,23 +123,25 @@ def collect_dependent(node_type: type): self._visit(list(self._nodes), pre_visit=collect_dependent) return dependents - def get_dependencies(self, dependent: type, recursive: bool = False) -> list[type]: + def get_dependencies( + self, dependent: Callable[..., Any], recursive: bool = False + ) -> list[Callable[..., Any]]: if not recursive: return [p.param_type for _, p in self._nodes[dependent].dependencies] - def collect_dependencies(t: type): + def collect_dependencies(t: Callable[..., Any]): if t != dependent: dependencies.append(t) - dependencies: list[type] = [] + dependencies: list[Callable[..., Any]] = [] self._visit( dependent, post_visit=collect_dependencies, ) return dependencies - def top_sorted_dependencies(self) -> list[type]: + def top_sorted_dependencies(self) -> list[Callable[..., Any]]: "Sort the whole graph, from lowest dependencies to toppest dependents" - order: list[type] = [] + order: list[Callable[..., Any]] = [] self._visit(list(self._nodes), post_visit=order.append) return order diff --git a/ididi/_node.py b/ididi/_node.py index e15e2f41..053abca8 100644 --- a/ididi/_node.py +++ b/ididi/_node.py @@ -371,6 +371,7 @@ def auth_service_factory() -> AuthService: "dependent", "factory", "factory_type", + "function_dependent", "dependencies", "config", ) @@ -381,6 +382,7 @@ def __init__( dependent: Callable[..., T], factory: INodeAnyFactory[T], factory_type: FactoryType, + function_dependent: bool = False, dependencies: Dependencies, config: NodeConfig, ): @@ -388,6 +390,7 @@ def __init__( self.dependent = dependent self.factory = factory self.factory_type: FactoryType = factory_type + self.function_dependent = function_dependent self.dependencies = dependencies self.config = config @@ -484,10 +487,11 @@ def _from_factory( signature = get_typed_signature(f, check_return=True) dependent: type[T] = resolve_annotation(signature.return_annotation) + if get_origin(dependent) is Annotated: metas = flatten_annotated(dependent) if IDIDI_IGNORE_PARAM_MARK in metas: - node = cls._from_function(f, config=config) + node = cls._from_function(f, factory_type=factory_type, config=config) return node node = cls.create( @@ -546,7 +550,13 @@ def from_node( return cls._from_factory(factory=factory, config=config) @classmethod - def _from_function(cls, function: Any, *, config: NodeConfig = DefaultConfig): + def _from_function( + cls, + function: Any, + *, + factory_type: FactoryType, + config: NodeConfig = DefaultConfig, + ): deps = Dependencies.from_signature( function, get_typed_signature(function), config ) @@ -554,7 +564,8 @@ def _from_function(cls, function: Any, *, config: NodeConfig = DefaultConfig): node = DependentNode( dependent=function, factory=function, - factory_type="function", + factory_type=factory_type, + function_dependent=True, dependencies=deps, config=config, ) diff --git a/ididi/graph.py b/ididi/graph.py index 148f9432..257bae09 100644 --- a/ididi/graph.py +++ b/ididi/graph.py @@ -487,7 +487,7 @@ def _node( self, dependent: INode[P, T], config: NodeConfig = DefaultConfig ) -> DependentNode[T]: node = DependentNode[T].from_node(dependent, config=config) - if is_function(node.dependent): + if node.function_dependent: return node if ori_node := self._nodes.get(node.dependent): diff --git a/tests/features/test_resolve_function.py b/tests/features/test_resolve_function.py index 0856e3cf..66878b62 100644 --- a/tests/features/test_resolve_function.py +++ b/tests/features/test_resolve_function.py @@ -1,23 +1,37 @@ -# from typing import Annotated +from typing import Annotated -# from ididi import Graph, Ignore +from ididi import Graph, Ignore +from ..test_data import Config, UserService -# class User: -# def __init__(self, name: str, role: str): -# self.name = name -# self.role = role +class User: + def __init__(self, name: str, role: str): + self.name = name + self.role = role -# def get_user() -> Ignore[User]: -# return User("user", "admin") +def get_user(config: Config) -> Ignore[User]: + assert isinstance(config, Config) + return User("user", "admin") -# def validate_admin(user: Annotated[User, get_user]): -# assert user.role == "admin" -# return "ok" +def validate_admin( + user: Annotated[User, get_user], service: UserService +) -> Ignore[str]: + assert user.role == "admin" + assert isinstance(service, UserService) + return "ok" -# def test_dg_resolve_params(): -# dg = Graph() -# assert dg.resolve(get_user) + +async def test_resolve_function(): + dg = Graph() + + user = dg.resolve(get_user) + assert isinstance(user, User) + + +def test_dg_resolve_params(): + dg = Graph() + + assert dg.resolve(validate_admin) == "ok" diff --git a/tests/test_feat.py b/tests/test_feat.py index 95974d21..00819a14 100644 --- a/tests/test_feat.py +++ b/tests/test_feat.py @@ -4,36 +4,3 @@ run test with: make feat """ -from typing import Annotated - -from ididi import Graph, Ignore - - -class User: - def __init__(self, name: str, role: str): - self.name = name - self.role = role - - -def get_user() -> Ignore[User]: - return User("user", "admin") - - -def validate_admin(user: Annotated[User, get_user]) -> Ignore[str]: - assert user.role == "admin" - return "ok" - - -async def test_resolve_function(): - dg = Graph() - - user = dg.resolve(get_user) - assert isinstance(user, User) - - -def test_dg_resolve_params(): - dg = Graph() - - node = dg.analyze(validate_admin) - - assert dg.resolve(validate_admin) == "ok"