From fd43a7f23622d4126de927cd3cc1034a46bcdc4f Mon Sep 17 00:00:00 2001 From: raceychan Date: Sat, 30 Nov 2024 23:19:45 +0800 Subject: [PATCH] fix: use dg with entry function and builtin type --- CHANGELOG.md | 10 +++- README.md | 2 +- docs/features.md | 2 +- docs/tutorial.md | 4 +- ididi/_itypes.py | 24 ++++++++- ididi/errors.py | 7 +++ ididi/graph.py | 32 +++++++++--- tests/features/test_graph_build_node.py | 1 - tests/features/test_node_ignore.py | 29 +++++++++++ tests/regression/test_static_resolve_all.py | 58 +++++++++++++++++++++ tests/test_graph.py | 1 + 11 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 tests/features/test_node_ignore.py create mode 100644 tests/regression/test_static_resolve_all.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b00a1029..d0e7fd2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -282,4 +282,12 @@ make sure each of AuthService -> Repository is configured as `reuse=False` ## version 1.0.9 - remove `import typing as ty`, `import typing_extensions as tyex` to reduce global lookup -- fix a potential bug where when resolve dependent return None or False it could be re-resolved \ No newline at end of file +- fix a potential bug where when resolve dependent return None or False it could be re-resolved + +TODO: add a ignore part to ignore + +```py +@dg.entry(ignore=[Query]) +async def create_user(q: Query(max_length=5)): + ... +``` diff --git a/README.md b/README.md index 0df5512d..19b62ae4 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ async def service_factory(): yield service @app.get("users") -async def get_user(service: Service = Depends(dg.factory(service_factory))) +async def get_user(service: Service = Depends(service_factory)) await service.create_user(...) ``` diff --git a/docs/features.md b/docs/features.md index 6c9a7022..a65b4d30 100644 --- a/docs/features.md +++ b/docs/features.md @@ -79,7 +79,7 @@ async def service_factory(): yield service @app.get("users") -async def get_user(service: Service = Depends(dg.factory(service_factory))) +async def get_user(service: Service = Depends(service_factory)) await service.create_user(...) ``` diff --git a/docs/tutorial.md b/docs/tutorial.md index 33e674e5..3e6bcad6 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -50,11 +50,11 @@ from ididi import DependencyGraph app = FastAPI() dg = DependencyGraph() -def auth_service_factory(db: DataBase) -> AuthService: +def auth_service_factory() -> AuthService: async with dg.scope() as scope yield dg.resolve(AuthService) -Service = ty.Annotated[AuthService, Depends(dg.factory(auth_service_factory))] +Service = ty.Annotated[AuthService, Depends(auth_service_factory)] @app.get("/") def get_service(service: Service): diff --git a/ididi/_itypes.py b/ididi/_itypes.py index c5923f02..ca1efd1c 100644 --- a/ididi/_itypes.py +++ b/ididi/_itypes.py @@ -67,22 +67,42 @@ class INodeConfig(TypedDict, total=False): lazy: bool --- whether the resolved instance should be a `lazy` dependent, meaning that its dependencies would not be resolved untill the attribute is accessed. + + partial + --- + whether to ignore bulitin types when statically resolve + + ignore + --- + types or names to ignore """ reuse: bool lazy: bool partial: bool + ignore: tuple[Union[str, type], ...] class NodeConfig: - __slots__ = ("reuse", "lazy", "partial") + __slots__ = ("reuse", "lazy", "partial", "ignore") + + ignore: tuple[Union[str, type], ...] def __init__( - self, *, reuse: bool = True, lazy: bool = False, partial: bool = False + self, + *, + reuse: bool = True, + lazy: bool = False, + partial: bool = False, + ignore: Union[tuple[Union[str, type], ...], None] = None, ): self.reuse = reuse self.lazy = lazy self.partial = partial + self.ignore = ignore or () + + def __repr__(self): + return f"{self.__class__.__name__}({self.reuse=}, {self.lazy=}, {self.partial=}, {self.ignore=})" class GraphConfig: diff --git a/ididi/errors.py b/ididi/errors.py index 6754f2dc..66974a78 100644 --- a/ididi/errors.py +++ b/ididi/errors.py @@ -162,6 +162,13 @@ def __init__(self, generic_type: Union[type, TypeVar]): ) +class BuiltinTypeFactoryError(NodeError): + def __init__(self, factory: Callable[..., Any], return_type: Any): + super().__init__( + f"factory {factory} is returning a unresolvable type {return_type}, were you trying to use `entry`?" + ) + + # =============== Graph Errors =============== class GraphError(IDIDIError): """ diff --git a/ididi/graph.py b/ididi/graph.py index baa20873..ced70eb6 100644 --- a/ididi/graph.py +++ b/ididi/graph.py @@ -46,6 +46,7 @@ ) from .errors import ( AsyncResourceInSyncError, + BuiltinTypeFactoryError, CircularDependencyDetectedError, MergeWithScopeStartedError, MissingImplementationError, @@ -486,10 +487,7 @@ def static_resolve( current_path: list[type] = [] - def dfs( - dependent_factory: Union[type, Callable[P, T]], - node_config: Maybe[NodeConfig], - ) -> DependentNode[T]: + def dfs(dependent_factory: Union[type, Callable[P, T]]) -> DependentNode[T]: if is_function(dependent_factory): sig = get_typed_signature(dependent_factory) dependent: type = sig.return_annotation @@ -506,8 +504,14 @@ def dfs( current_path.append(dependent) + # TODO: we might let node return what params should be resolved + # param with default, or set to be ignored should not be checked here at all + for param in node.signature.actualized_params(): param_type = param.param_type + ignore_params = node.config.ignore + if param.name in ignore_params or param_type in ignore_params: + continue if param_type in current_path: i = current_path.index(param_type) cycle = current_path[i:] + [param_type] @@ -522,6 +526,7 @@ def dfs( if is_provided(param.default): continue if param.unresolvable: + # NOTE: node.config and note_config might not be the same if node.config.partial: continue raise UnsolvableDependencyError( @@ -531,7 +536,7 @@ def dfs( ) try: - dep_node = dfs(param_type, node_config) + dep_node = dfs(param_type) except UnsolvableNodeError as une: une.add_context( node.dependent_type, param.name, param.param_annotation @@ -546,7 +551,7 @@ def dfs( current_path.pop() return node - return dfs(dependent, node_config) + return dfs(dependent) @overload def use_scope( @@ -765,6 +770,15 @@ def entry(self, func: IFactory[P, T]) -> IAnyFactory[T]: ... def entry( self, func: Union[IFactory[P, T], IAsyncFactory[P, T]] ) -> Union[IAnyFactory[T], IAnyAsyncFactory[T]]: + """ + TODO: + 1. add more generic vars to func + checkout https://github.com/dbrattli/Expression/blob/main/expression/core/pipe.py + TypeVarTuple + Unpack[TypeVarTuple] + 2. allow ignore params, to support hook param like fastapi.Query + """ + def func_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: with self.scope() as scope: r = self.resolve(dep, scope, *args, **kwargs) @@ -842,9 +856,15 @@ def auth_service_factory() -> AuthService: ... node_config = NodeConfig(**config) + if is_unresolved_type(factory_or_class): + raise TopLevelBulitinTypeError(factory_or_class) + sig = get_typed_signature(factory_or_class) return_type: Union[type[T], ForwardRef] = sig.return_annotation + if is_unresolved_type(return_type): + raise BuiltinTypeFactoryError(factory_or_class, return_type) + if isinstance(return_type, ForwardRef): return_type = resolve_forwardref(factory_or_class, return_type) diff --git a/tests/features/test_graph_build_node.py b/tests/features/test_graph_build_node.py index ff033428..23df00a7 100644 --- a/tests/features/test_graph_build_node.py +++ b/tests/features/test_graph_build_node.py @@ -1,6 +1,5 @@ import pytest -from ididi.errors import UnsolvableDependencyError from ididi.graph import DependencyGraph diff --git a/tests/features/test_node_ignore.py b/tests/features/test_node_ignore.py new file mode 100644 index 00000000..a8d4bebc --- /dev/null +++ b/tests/features/test_node_ignore.py @@ -0,0 +1,29 @@ +import pytest + +from ididi import DependencyGraph +from ididi.errors import UnsolvableDependencyError + + +class IgnoreNode: + def __init__(self, name: str, age: int): ... + + +def test_direct_resolve_fail(): + dg = DependencyGraph() + dg.node(IgnoreNode) + with pytest.raises(UnsolvableDependencyError): + dg.static_resolve(IgnoreNode) + + +def test_resolve_fail_with_partial_ignore(): + dg = DependencyGraph() + + dg.node(ignore=("name",))(IgnoreNode) + with pytest.raises(UnsolvableDependencyError): + dg.static_resolve(IgnoreNode) + + +def test_resolve_with_ignore(): + dg = DependencyGraph() + dg.node(ignore=("name", int))(IgnoreNode) + dg.static_resolve(IgnoreNode) diff --git a/tests/regression/test_static_resolve_all.py b/tests/regression/test_static_resolve_all.py new file mode 100644 index 00000000..27aaa73b --- /dev/null +++ b/tests/regression/test_static_resolve_all.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass + +import pytest + +from ididi import DependencyGraph +from ididi.errors import BuiltinTypeFactoryError, TopLevelBulitinTypeError + + +@dataclass +class UserCommand: + user_id: str + + +@dataclass +class CreateUser(UserCommand): + user_name: str + + +@dataclass +class RemoveUser(UserCommand): + user_name: str + + +@dataclass +class UpdateUser(UserCommand): + old_name: str + new_name: str + + +def update_user(cmd: UpdateUser) -> str: + return cmd.new_name + + +class UserService: + def __init__(self): + self.name = "name" + + def create_user(self, cmd: CreateUser) -> str: + return "hello" + + def remove_user(self, cmd: RemoveUser) -> str: + return "goodbye" + + +def test_static_resolve_all(): + dg = DependencyGraph() + dg.node(UserService) + with pytest.raises(BuiltinTypeFactoryError): + dg.node(update_user) + + dg.static_resolve_all() + + +def test_dg_node_builtin(): + dg = DependencyGraph() + + with pytest.raises(TopLevelBulitinTypeError): + dg.node(int) diff --git a/tests/test_graph.py b/tests/test_graph.py index cd1ac536..93a93722 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -662,6 +662,7 @@ def test_graph_static_resolved(): c = dg.resolve(ComplianceChecker) dg2.static_resolve(DatabaseConfig) d = dg2.resolve(DatabaseConfig) + repr(dg.nodes[ComplianceChecker].config) dg.merge(dg2) assert ComplianceChecker in dg and DatabaseConfig in dg