Skip to content

Commit

Permalink
feat(WIP): resolve function, a working version
Browse files Browse the repository at this point in the history
  • Loading branch information
raceychan committed Feb 5, 2025
1 parent 882cb53 commit 1407e92
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 126 deletions.
39 changes: 32 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -989,15 +989,29 @@ def test_ignore_dependences():
## version 1.4.2
resolve function
Function as Dependency
```graph
Graph.resolve_function
```python
def get_user(session: Session, token: Token) -> Ignore[Any]:
...
```
where we treat function as a dependent.
graph._nodes[get_current_user] = DependentNode(dependent_type=get_current_user)
```python
def validate_admin(user: Annotated[User, get_user]):
...
```
<!--
# TODO: are FnDep nodes?
if so, we need to define new type of node
```python
class FnDepNode:
factory: Callable[P, ...]
factory_type: Literal["function", "resource"]
dependencies: Dependencies
config: NodeConfig
```
```
Expand Down Expand Up @@ -1029,8 +1043,19 @@ or
we introduce `FuncReturn` Mark
```python
def get_user(session: Session, token: Token) -> FuncReturn[User]:
def get_user(session: Session, token: Token) -> Fndep[User]:
...
dg.resolve(get_user)
```
```
FuncReturn can be Any
The biggest difference between Funcdep and factory is that
factory is associated with a type where Fndep is not.
in short,
dg.resolve(User) won't call get_user -->
12 changes: 6 additions & 6 deletions ididi/_ds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from types import MappingProxyType
from typing import Any, Callable, Union
from typing import Any, Callable, Union, Hashable

from ._node import DependentNode
from ._type_resolve import get_bases
from .utils.typing_utils import T

GraphNodes = dict[type, DependentNode[Any]]
GraphNodes = dict[Callable[..., T], DependentNode[T]]
"""
### mapping a type to its corresponding node
"""
Expand All @@ -33,13 +33,13 @@ class TypeRegistry:
def __init__(self):
self._mappings: TypeMappings[Any] = defaultdict(list)

def __getitem__(self, dependent_type: type[T]) -> list[type[T]]:
def __getitem__(self, dependent_type: Callable[..., T]) -> list[type[T]]:
return self._mappings[dependent_type].copy()

def __len__(self) -> int:
return len(self._mappings)

def __contains__(self, dependent_type: type) -> bool:
def __contains__(self, dependent_type: Hashable) -> bool:
return dependent_type in self._mappings

def update(self, other: "TypeRegistry"):
Expand All @@ -64,12 +64,12 @@ def clear(self) -> None:
class Visitor:
__slots__ = ("_nodes",)

def __init__(self, nodes: GraphNodes):
def __init__(self, nodes: GraphNodes[Any]):
self._nodes = nodes

def _visit(
self,
start_types: Union[list[type], type],
start_types: Union[list[Any], type],
pre_visit: Union[Callable[[type], None], None] = None,
post_visit: Union[Callable[[type], None], None] = None,
) -> None:
Expand Down
29 changes: 26 additions & 3 deletions ididi/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from ._type_resolve import (
IDIDI_IGNORE_PARAM_MARK,
IDIDI_UNTYPE_DEP_MARK,
IDIDI_USE_FACTORY_MARK,
FactoryType,
ResolveOrder,
Expand All @@ -33,6 +32,7 @@
is_class_or_method,
is_class_with_empty_init,
is_ctxmgr_cls,
is_function,
is_unsolvable_type,
isasyncgenfunction,
isgeneratorfunction,
Expand Down Expand Up @@ -76,7 +76,7 @@ def func(service: UserService = use(factory)): ...
def func(service: Annotated[UserService, use(factory)]): ...
```
"""
# TODO: support untyped, untyped deps are nodes with dependent being functions.

node = DependentNode[T].from_node(factory, config=NodeConfig(**iconfig))
annt = Annotated[node.dependent_type, node, IDIDI_USE_FACTORY_MARK]
return cast(T, annt)
Expand Down Expand Up @@ -298,7 +298,10 @@ def from_signature(
continue

if IDIDI_USE_FACTORY_MARK not in annotate_meta:
param_type, *_ = get_args(param_type)
param_type, *rest = get_args(param_type)
for r in rest:
if is_function(r):
param_type = r

if param_type is Unpack:
dependencies.update(unpack_to_deps(param_annotation))
Expand Down Expand Up @@ -480,6 +483,11 @@ def _from_factory(

signature = get_typed_signature(f, check_return=True)
dependent: type[T] = resolve_annotation(signature.return_annotation)
# if get_origin(dependent) is Annotated:
# metas = flatten_annotated(dependent)
# if IDIDI_IGNORE_PARAM_MARK in metas:
# return cls._from_function(f, config=config)

node = cls.create(
dependent_type=dependent,
factory=cast(INodeFactory[P, T], f),
Expand Down Expand Up @@ -534,3 +542,18 @@ def from_node(
else:
factory = cast(INodeFactory[P, T], factory_or_class)
return cls._from_factory(factory=factory, config=config)

@classmethod
def _from_function(cls, function: Any, *, config: NodeConfig = DefaultConfig):
deps = Dependencies.from_signature(
function, get_typed_signature(function), config
)

node = DependentNode(
dependent_type=function,
factory=function,
factory_type="function",
dependencies=deps,
config=config,
)
return node
11 changes: 8 additions & 3 deletions ididi/_type_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from inspect import Parameter, Signature
from inspect import isasyncgenfunction as isasyncgenfunction
from inspect import isgeneratorfunction as isgeneratorfunction
from types import GenericAlias, MethodType
from types import FunctionType, GenericAlias, MethodType
from typing import (
Annotated,
Any,
Expand Down Expand Up @@ -50,7 +50,7 @@

IDIDI_USE_FACTORY_MARK = "__ididi_use_factory__"
IDIDI_IGNORE_PARAM_MARK = "__ididi_ignore_param__"
IDIDI_UNTYPE_DEP_MARK = "__ididi_untyped_dep__"
# IDIDI_UNTYPE_DEP_MARK = "__ididi_untyped_dep__"

FactoryType = Literal["default", "function", "resource", "aresource"]
# carry this information in node so that resolve does not have to do
Expand Down Expand Up @@ -118,6 +118,7 @@ def get_typed_signature(

if isinstance(typed_return, ForwardRef):
raise ForwardReferenceNotFoundError(typed_return)

if check_return and is_unsolvable_type(typed_return):
raise UnsolvableReturnTypeError(call, typed_return)

Expand Down Expand Up @@ -196,6 +197,10 @@ def is_class_or_method(obj: Any) -> bool:
return isinstance(obj, (type, MethodType, classmethod))


def is_function(obj: Any):
return isinstance(obj, FunctionType)


def is_class(
obj: Union[type[T], Callable[..., Union[T, Awaitable[T]]]]
) -> TypeGuard[type[T]]:
Expand All @@ -205,7 +210,7 @@ def is_class(
origin = get_origin(obj) or obj
is_type = isinstance(origin, type)
is_generic_alias = isinstance(obj, GenericAlias)
return is_type or is_generic_alias or origin is Annotated
return is_type or is_generic_alias


def is_class_with_empty_init(cls: type) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions ididi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def __init__(
*,
dep_name: str,
factory: Union[Callable[..., Any], type],
dependent_type: type,
dependency_type: type,
dependent_type: Callable[..., Any],
dependency_type: Callable[..., Any],
):
type_repr = getattr(dependency_type, "__name__", str(dependency_type))
param_repr = f" * {dependent_type.__name__}({dep_name}: {type_repr}) \n value of `{dep_name}` must be provided"
Expand Down Expand Up @@ -206,18 +206,18 @@ class GraphResolveError(GraphError):
class CircularDependencyDetectedError(GraphResolveError):
"""Raised when a circular dependency is detected in the dependency graph."""

def __init__(self, cycle_path: list[type]):
def __init__(self, cycle_path: list[Callable[..., Any]]):
cycle_str = " -> ".join(t.__name__ for t in cycle_path)
self._cycle_path = cycle_path
super().__init__(f"Circular dependency detected: {cycle_str}")

@property
def cycle_path(self) -> list[type]:
def cycle_path(self) -> list[Callable[..., Any]]:
return self._cycle_path


class ReusabilityConflictError(GraphResolveError):
def __init__(self, path: list[type], nonreuse: type):
def __init__(self, path: list[Callable[..., Any]], nonreuse: type):
conflict_str = " -> ".join(t.__name__ for t in path)
msg = f"""Transient dependency `{nonreuse.__name__}` with reuse dependents \
\n make sure each of {conflict_str} is configured as `reuse=False` \
Expand Down
Loading

0 comments on commit 1407e92

Please sign in to comment.