From 882cb53374807678318da1fea03647b0355b15c0 Mon Sep 17 00:00:00 2001 From: raceychan Date: Wed, 5 Feb 2025 13:34:05 +0800 Subject: [PATCH] feat(WIP): untyped deps --- ididi/_node.py | 6 +- ididi/_type_resolve.py | 1 + ididi/graph.py | 119 +++++++++++++----------- tests/features/services.py | 10 ++ tests/features/test_resolve_function.py | 23 +++++ 5 files changed, 102 insertions(+), 57 deletions(-) create mode 100644 tests/features/test_resolve_function.py diff --git a/ididi/_node.py b/ididi/_node.py index 6ae1420b..108db745 100644 --- a/ididi/_node.py +++ b/ididi/_node.py @@ -20,6 +20,7 @@ from ._type_resolve import ( IDIDI_IGNORE_PARAM_MARK, + IDIDI_UNTYPE_DEP_MARK, IDIDI_USE_FACTORY_MARK, FactoryType, ResolveOrder, @@ -56,7 +57,7 @@ from .utils.param_utils import MISSING, Maybe, is_provided from .utils.typing_utils import P, T -# ============== Ididi special hooks =========== +# ============== Ididi marks =========== Ignore = Annotated[T, IDIDI_IGNORE_PARAM_MARK] @@ -75,12 +76,13 @@ def func(service: UserService = use(factory)): ... def func(service: Annotated[UserService, use(factory)]): ... ``` """ + # TODO: support untyped, untyped deps are nodes with dependent being functions. node = DependentNode[T].from_node(factory, config=NodeConfig(**iconfig)) annt = Annotated[node.dependent_type, node, IDIDI_USE_FACTORY_MARK] return cast(T, annt) -# ============== Ididi special hooks =========== +# ============== Ididi marks =========== def search_meta(meta: list[Any]) -> Union["DependentNode[Any]", None]: diff --git a/ididi/_type_resolve.py b/ididi/_type_resolve.py index 9ca3fd14..8c85075e 100644 --- a/ididi/_type_resolve.py +++ b/ididi/_type_resolve.py @@ -50,6 +50,7 @@ IDIDI_USE_FACTORY_MARK = "__ididi_use_factory__" IDIDI_IGNORE_PARAM_MARK = "__ididi_ignore_param__" +IDIDI_UNTYPE_DEP_MARK = "__ididi_untyped_dep__" FactoryType = Literal["default", "function", "resource", "aresource"] # carry this information in node so that resolve does not have to do diff --git a/ididi/graph.py b/ididi/graph.py index 670b6b5b..42de278f 100644 --- a/ididi/graph.py +++ b/ididi/graph.py @@ -133,27 +133,6 @@ def __init__( self._registered_singletons = registered_singletons self._resolved_singletons = resolved_singletons - def is_registered_singleton(self, dependent_type: type) -> bool: - return dependent_type in self._registered_singletons - - @lru_cache(CacheMax) - def should_be_scoped(self, dep_type: INode[P, T]) -> bool: - "Recursively check if a dependent type contains any resource dependency" - if not (resolved_node := self._resolved_nodes.get(dep_type)): - resolved_node = self.analyze(dep_type) - - if self.is_registered_singleton(resolved_node.dependent_type): - return False - - if resolved_node.is_resource: - return True - - unsolved_params = resolved_node.unsolved_params(ignore=self._ignore) - contain_resource = any( - self.should_be_scoped(param_type) for _, param_type in unsolved_params - ) - return contain_resource - def _remove_node(self, node: DependentNode[Any]) -> None: """ Remove a node from the graph and clean up all its references. @@ -189,6 +168,29 @@ def _resolve_concrete_node(self, dependent: type[T]) -> DependentNode[Any]: concrete_node.check_for_implementations() return concrete_node + # ================= Public ================= + + def is_registered_singleton(self, dependent_type: type) -> bool: + return dependent_type in self._registered_singletons + + @lru_cache(CacheMax) + def should_be_scoped(self, dep_type: INode[P, T]) -> bool: + "Recursively check if a dependent type contains any resource dependency" + if not (resolved_node := self._resolved_nodes.get(dep_type)): + resolved_node = self.analyze(dep_type) + + if self.is_registered_singleton(resolved_node.dependent_type): + return False + + if resolved_node.is_resource: + return True + + unsolved_params = resolved_node.unsolved_params(ignore=self._ignore) + contain_resource = any( + self.should_be_scoped(param_type) for _, param_type in unsolved_params + ) + return contain_resource + def check_param_conflict(self, param_type: type, current_path: list[type]): if param_type in current_path: i = current_path.index(param_type) @@ -425,6 +427,45 @@ async def aresolve( is_reuse=node.config.reuse, ) + def analyze_params( + self, func: Callable[P, T], config: NodeConfig = DefaultConfig + ) -> tuple[bool, list[tuple[str, type]]]: + deps = Dependencies.from_signature( + signature=get_typed_signature(func), factory=func, config=config + ) + depends_on_resource: bool = False + unresolved: list[tuple[str, type]] = [] + + for name, dep in deps.filter_ignore(self._ignore): + param_type = dep.param_type + + if is_unsolvable_type(param_type): + continue + + if inject_node := (resolve_use(param_type) or resolve_use(dep.default)): + self._register_node(inject_node) + self._resolved_nodes[param_type] = inject_node + param_type = inject_node.dependent_type + + self.analyze(param_type, config=config) + depends_on_resource = depends_on_resource or self.should_be_scoped( + param_type + ) + unresolved.append((name, param_type)) + + return depends_on_resource, unresolved + + def resolve_untyped(self, function: Callable[P, T], **overrides: Any) -> T: + """ + recursively resolve params of function and call it + + async def validate_admin(user: Annotated[User, use(get_user)]): + ... + """ + ... + + # async def aresolve_function(self, function: Callable[P, T]) -> T: ... + def _node( self, dependent: INode[P, T], config: NodeConfig = DefaultConfig ) -> DependentNode[T]: @@ -778,37 +819,6 @@ def use_scope( return scope.get_scope(name) return scope - def analyze_params( - self, func: Callable[P, T], **iconfig: Unpack[INodeConfig] - ) -> tuple[bool, list[tuple[str, type]]]: - config = NodeConfig(**iconfig) - - deps = Dependencies.from_signature( - signature=get_typed_signature(func), factory=func, config=config - ) - - depends_on_resource: bool = False - unresolved: list[tuple[str, type]] = [] - - for name, dep in deps.filter_ignore(self._ignore): - param_type = dep.param_type - - if is_unsolvable_type(param_type): - continue - - if inject_node := (resolve_use(param_type) or resolve_use(dep.default)): - self._register_node(inject_node) - self._resolved_nodes[param_type] = inject_node - param_type = inject_node.dependent_type - - self.analyze(param_type, config=config) - depends_on_resource = depends_on_resource or self.should_be_scoped( - param_type - ) - unresolved.append((name, param_type)) - - return depends_on_resource, unresolved - @overload def entry(self, **iconfig: Unpack[INodeConfig]) -> TEntryDecor: ... @@ -856,7 +866,8 @@ async def func(email_sender: EmailSender, /): configured = cast(TEntryDecor, partial(self.entry, **iconfig)) return configured - require_scope, unresolved = self.analyze_params(func, **iconfig) + config = NodeConfig(**iconfig) + require_scope, unresolved = self.analyze_params(func, config=config) def replace( before: Maybe[type[T]] = MISSING, @@ -885,7 +896,6 @@ async def _async_scoped_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if param_name in kwargs: continue kwargs[param_name] = await scope.resolve(param_type) - r = await func(*args, **kwargs) return r @@ -914,7 +924,6 @@ def _sync_scoped_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if param_name in kwargs: continue kwargs[param_name] = scope.resolve(param_type) - r = sync_func(*args, **kwargs) return r diff --git a/tests/features/services.py b/tests/features/services.py index 7af9ffc9..51d5fc2c 100644 --- a/tests/features/services.py +++ b/tests/features/services.py @@ -3,6 +3,11 @@ def __init__(self, env: str = "prod"): self.env = env +class MessageQueue: + def __init__(self, config: Config): + self.config = config + + class Database: def __init__(self, config: Config): self.config = config @@ -16,3 +21,8 @@ def __init__(self, db: Database): class UserService: def __init__(self, repo: UserRepository): self.repo = repo + + +class ProductService: + def __init__(self, mq: MessageQueue): + self.mq = mq diff --git a/tests/features/test_resolve_function.py b/tests/features/test_resolve_function.py new file mode 100644 index 00000000..d0ff4e66 --- /dev/null +++ b/tests/features/test_resolve_function.py @@ -0,0 +1,23 @@ +from typing import Annotated + +from ididi import Graph + + +class User: + def __init__(self, name: str, role: str): + self.name = name + self.role = role + + +def get_user(): + return User("user", "admin") + + +def validate_admin(user: Annotated[User, get_user]): + assert user.role == "admin" + return "ok" + + +# def test_dg_resolve_params(): +# dg = Graph() +# assert dg.resolve_function(validate_admin) == "ok"