Skip to content

Commit

Permalink
fix: use dg with entry function and builtin type
Browse files Browse the repository at this point in the history
  • Loading branch information
raceychan committed Nov 30, 2024
1 parent 886b4e7 commit fd43a7f
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 14 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -282,4 +282,12 @@ make sure each of AuthService -> Repository is configured as `reuse=False`
## version 1.0.9
- remove `import typing as ty`, `import typing_extensions as tyex` to reduce global lookup
- fix a potential bug where when resolve dependent return None or False it could be re-resolved
- fix a potential bug where when resolve dependent return None or False it could be re-resolved
TODO: add a ignore part to ignore
```py
@dg.entry(ignore=[Query])
async def create_user(q: Query(max_length=5)):
...
```
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def service_factory():
yield service

@app.get("users")
async def get_user(service: Service = Depends(dg.factory(service_factory)))
async def get_user(service: Service = Depends(service_factory))
await service.create_user(...)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def service_factory():
yield service

@app.get("users")
async def get_user(service: Service = Depends(dg.factory(service_factory)))
async def get_user(service: Service = Depends(service_factory))
await service.create_user(...)
```

Expand Down
4 changes: 2 additions & 2 deletions docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ from ididi import DependencyGraph
app = FastAPI()
dg = DependencyGraph()

def auth_service_factory(db: DataBase) -> AuthService:
def auth_service_factory() -> AuthService:
async with dg.scope() as scope
yield dg.resolve(AuthService)

Service = ty.Annotated[AuthService, Depends(dg.factory(auth_service_factory))]
Service = ty.Annotated[AuthService, Depends(auth_service_factory)]

@app.get("/")
def get_service(service: Service):
Expand Down
24 changes: 22 additions & 2 deletions ididi/_itypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,42 @@ class INodeConfig(TypedDict, total=False):
lazy: bool
---
whether the resolved instance should be a `lazy` dependent, meaning that its dependencies would not be resolved untill the attribute is accessed.
partial
---
whether to ignore bulitin types when statically resolve
ignore
---
types or names to ignore
"""

reuse: bool
lazy: bool
partial: bool
ignore: tuple[Union[str, type], ...]


class NodeConfig:
__slots__ = ("reuse", "lazy", "partial")
__slots__ = ("reuse", "lazy", "partial", "ignore")

ignore: tuple[Union[str, type], ...]

def __init__(
self, *, reuse: bool = True, lazy: bool = False, partial: bool = False
self,
*,
reuse: bool = True,
lazy: bool = False,
partial: bool = False,
ignore: Union[tuple[Union[str, type], ...], None] = None,
):
self.reuse = reuse
self.lazy = lazy
self.partial = partial
self.ignore = ignore or ()

def __repr__(self):
return f"{self.__class__.__name__}({self.reuse=}, {self.lazy=}, {self.partial=}, {self.ignore=})"


class GraphConfig:
Expand Down
7 changes: 7 additions & 0 deletions ididi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def __init__(self, generic_type: Union[type, TypeVar]):
)


class BuiltinTypeFactoryError(NodeError):
def __init__(self, factory: Callable[..., Any], return_type: Any):
super().__init__(
f"factory {factory} is returning a unresolvable type {return_type}, were you trying to use `entry`?"
)


# =============== Graph Errors ===============
class GraphError(IDIDIError):
"""
Expand Down
32 changes: 26 additions & 6 deletions ididi/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .errors import (
AsyncResourceInSyncError,
BuiltinTypeFactoryError,
CircularDependencyDetectedError,
MergeWithScopeStartedError,
MissingImplementationError,
Expand Down Expand Up @@ -486,10 +487,7 @@ def static_resolve(

current_path: list[type] = []

def dfs(
dependent_factory: Union[type, Callable[P, T]],
node_config: Maybe[NodeConfig],
) -> DependentNode[T]:
def dfs(dependent_factory: Union[type, Callable[P, T]]) -> DependentNode[T]:
if is_function(dependent_factory):
sig = get_typed_signature(dependent_factory)
dependent: type = sig.return_annotation
Expand All @@ -506,8 +504,14 @@ def dfs(

current_path.append(dependent)

# TODO: we might let node return what params should be resolved
# param with default, or set to be ignored should not be checked here at all

for param in node.signature.actualized_params():
param_type = param.param_type
ignore_params = node.config.ignore
if param.name in ignore_params or param_type in ignore_params:
continue
if param_type in current_path:
i = current_path.index(param_type)
cycle = current_path[i:] + [param_type]
Expand All @@ -522,6 +526,7 @@ def dfs(
if is_provided(param.default):
continue
if param.unresolvable:
# NOTE: node.config and note_config might not be the same
if node.config.partial:
continue
raise UnsolvableDependencyError(
Expand All @@ -531,7 +536,7 @@ def dfs(
)

try:
dep_node = dfs(param_type, node_config)
dep_node = dfs(param_type)
except UnsolvableNodeError as une:
une.add_context(
node.dependent_type, param.name, param.param_annotation
Expand All @@ -546,7 +551,7 @@ def dfs(
current_path.pop()
return node

return dfs(dependent, node_config)
return dfs(dependent)

@overload
def use_scope(
Expand Down Expand Up @@ -765,6 +770,15 @@ def entry(self, func: IFactory[P, T]) -> IAnyFactory[T]: ...
def entry(
self, func: Union[IFactory[P, T], IAsyncFactory[P, T]]
) -> Union[IAnyFactory[T], IAnyAsyncFactory[T]]:
"""
TODO:
1. add more generic vars to func
checkout https://github.com/dbrattli/Expression/blob/main/expression/core/pipe.py
TypeVarTuple
Unpack[TypeVarTuple]
2. allow ignore params, to support hook param like fastapi.Query
"""

def func_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with self.scope() as scope:
r = self.resolve(dep, scope, *args, **kwargs)
Expand Down Expand Up @@ -842,9 +856,15 @@ def auth_service_factory() -> AuthService: ...

node_config = NodeConfig(**config)

if is_unresolved_type(factory_or_class):
raise TopLevelBulitinTypeError(factory_or_class)

sig = get_typed_signature(factory_or_class)
return_type: Union[type[T], ForwardRef] = sig.return_annotation

if is_unresolved_type(return_type):
raise BuiltinTypeFactoryError(factory_or_class, return_type)

if isinstance(return_type, ForwardRef):
return_type = resolve_forwardref(factory_or_class, return_type)

Expand Down
1 change: 0 additions & 1 deletion tests/features/test_graph_build_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from ididi.errors import UnsolvableDependencyError
from ididi.graph import DependencyGraph


Expand Down
29 changes: 29 additions & 0 deletions tests/features/test_node_ignore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from ididi import DependencyGraph
from ididi.errors import UnsolvableDependencyError


class IgnoreNode:
def __init__(self, name: str, age: int): ...


def test_direct_resolve_fail():
dg = DependencyGraph()
dg.node(IgnoreNode)
with pytest.raises(UnsolvableDependencyError):
dg.static_resolve(IgnoreNode)


def test_resolve_fail_with_partial_ignore():
dg = DependencyGraph()

dg.node(ignore=("name",))(IgnoreNode)
with pytest.raises(UnsolvableDependencyError):
dg.static_resolve(IgnoreNode)


def test_resolve_with_ignore():
dg = DependencyGraph()
dg.node(ignore=("name", int))(IgnoreNode)
dg.static_resolve(IgnoreNode)
58 changes: 58 additions & 0 deletions tests/regression/test_static_resolve_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import pytest

from ididi import DependencyGraph
from ididi.errors import BuiltinTypeFactoryError, TopLevelBulitinTypeError


@dataclass
class UserCommand:
user_id: str


@dataclass
class CreateUser(UserCommand):
user_name: str


@dataclass
class RemoveUser(UserCommand):
user_name: str


@dataclass
class UpdateUser(UserCommand):
old_name: str
new_name: str


def update_user(cmd: UpdateUser) -> str:
return cmd.new_name


class UserService:
def __init__(self):
self.name = "name"

def create_user(self, cmd: CreateUser) -> str:
return "hello"

def remove_user(self, cmd: RemoveUser) -> str:
return "goodbye"


def test_static_resolve_all():
dg = DependencyGraph()
dg.node(UserService)
with pytest.raises(BuiltinTypeFactoryError):
dg.node(update_user)

dg.static_resolve_all()


def test_dg_node_builtin():
dg = DependencyGraph()

with pytest.raises(TopLevelBulitinTypeError):
dg.node(int)
1 change: 1 addition & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def test_graph_static_resolved():
c = dg.resolve(ComplianceChecker)
dg2.static_resolve(DatabaseConfig)
d = dg2.resolve(DatabaseConfig)
repr(dg.nodes[ComplianceChecker].config)

dg.merge(dg2)
assert ComplianceChecker in dg and DatabaseConfig in dg
Expand Down

0 comments on commit fd43a7f

Please sign in to comment.