Skip to content

Commit

Permalink
feat(WIP): untyped deps
Browse files Browse the repository at this point in the history
  • Loading branch information
raceychan committed Feb 5, 2025
1 parent 259d13c commit 882cb53
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 57 deletions.
6 changes: 4 additions & 2 deletions ididi/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ._type_resolve import (
IDIDI_IGNORE_PARAM_MARK,
IDIDI_UNTYPE_DEP_MARK,
IDIDI_USE_FACTORY_MARK,
FactoryType,
ResolveOrder,
Expand Down Expand Up @@ -56,7 +57,7 @@
from .utils.param_utils import MISSING, Maybe, is_provided
from .utils.typing_utils import P, T

# ============== Ididi special hooks ===========
# ============== Ididi marks ===========

Ignore = Annotated[T, IDIDI_IGNORE_PARAM_MARK]

Expand All @@ -75,12 +76,13 @@ def func(service: UserService = use(factory)): ...
def func(service: Annotated[UserService, use(factory)]): ...
```
"""
# TODO: support untyped, untyped deps are nodes with dependent being functions.
node = DependentNode[T].from_node(factory, config=NodeConfig(**iconfig))
annt = Annotated[node.dependent_type, node, IDIDI_USE_FACTORY_MARK]
return cast(T, annt)


# ============== Ididi special hooks ===========
# ============== Ididi marks ===========


def search_meta(meta: list[Any]) -> Union["DependentNode[Any]", None]:
Expand Down
1 change: 1 addition & 0 deletions ididi/_type_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

IDIDI_USE_FACTORY_MARK = "__ididi_use_factory__"
IDIDI_IGNORE_PARAM_MARK = "__ididi_ignore_param__"
IDIDI_UNTYPE_DEP_MARK = "__ididi_untyped_dep__"

FactoryType = Literal["default", "function", "resource", "aresource"]
# carry this information in node so that resolve does not have to do
Expand Down
119 changes: 64 additions & 55 deletions ididi/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,27 +133,6 @@ def __init__(
self._registered_singletons = registered_singletons
self._resolved_singletons = resolved_singletons

def is_registered_singleton(self, dependent_type: type) -> bool:
return dependent_type in self._registered_singletons

@lru_cache(CacheMax)
def should_be_scoped(self, dep_type: INode[P, T]) -> bool:
"Recursively check if a dependent type contains any resource dependency"
if not (resolved_node := self._resolved_nodes.get(dep_type)):
resolved_node = self.analyze(dep_type)

if self.is_registered_singleton(resolved_node.dependent_type):
return False

if resolved_node.is_resource:
return True

unsolved_params = resolved_node.unsolved_params(ignore=self._ignore)
contain_resource = any(
self.should_be_scoped(param_type) for _, param_type in unsolved_params
)
return contain_resource

def _remove_node(self, node: DependentNode[Any]) -> None:
"""
Remove a node from the graph and clean up all its references.
Expand Down Expand Up @@ -189,6 +168,29 @@ def _resolve_concrete_node(self, dependent: type[T]) -> DependentNode[Any]:
concrete_node.check_for_implementations()
return concrete_node

# ================= Public =================

def is_registered_singleton(self, dependent_type: type) -> bool:
return dependent_type in self._registered_singletons

@lru_cache(CacheMax)
def should_be_scoped(self, dep_type: INode[P, T]) -> bool:
"Recursively check if a dependent type contains any resource dependency"
if not (resolved_node := self._resolved_nodes.get(dep_type)):
resolved_node = self.analyze(dep_type)

if self.is_registered_singleton(resolved_node.dependent_type):
return False

if resolved_node.is_resource:
return True

unsolved_params = resolved_node.unsolved_params(ignore=self._ignore)
contain_resource = any(
self.should_be_scoped(param_type) for _, param_type in unsolved_params
)
return contain_resource

def check_param_conflict(self, param_type: type, current_path: list[type]):
if param_type in current_path:
i = current_path.index(param_type)
Expand Down Expand Up @@ -425,6 +427,45 @@ async def aresolve(
is_reuse=node.config.reuse,
)

def analyze_params(
self, func: Callable[P, T], config: NodeConfig = DefaultConfig
) -> tuple[bool, list[tuple[str, type]]]:
deps = Dependencies.from_signature(
signature=get_typed_signature(func), factory=func, config=config
)
depends_on_resource: bool = False
unresolved: list[tuple[str, type]] = []

for name, dep in deps.filter_ignore(self._ignore):
param_type = dep.param_type

if is_unsolvable_type(param_type):
continue

if inject_node := (resolve_use(param_type) or resolve_use(dep.default)):
self._register_node(inject_node)
self._resolved_nodes[param_type] = inject_node
param_type = inject_node.dependent_type

self.analyze(param_type, config=config)
depends_on_resource = depends_on_resource or self.should_be_scoped(
param_type
)
unresolved.append((name, param_type))

return depends_on_resource, unresolved

def resolve_untyped(self, function: Callable[P, T], **overrides: Any) -> T:
"""
recursively resolve params of function and call it
async def validate_admin(user: Annotated[User, use(get_user)]):
...
"""
...

# async def aresolve_function(self, function: Callable[P, T]) -> T: ...

def _node(
self, dependent: INode[P, T], config: NodeConfig = DefaultConfig
) -> DependentNode[T]:
Expand Down Expand Up @@ -778,37 +819,6 @@ def use_scope(
return scope.get_scope(name)
return scope

def analyze_params(
self, func: Callable[P, T], **iconfig: Unpack[INodeConfig]
) -> tuple[bool, list[tuple[str, type]]]:
config = NodeConfig(**iconfig)

deps = Dependencies.from_signature(
signature=get_typed_signature(func), factory=func, config=config
)

depends_on_resource: bool = False
unresolved: list[tuple[str, type]] = []

for name, dep in deps.filter_ignore(self._ignore):
param_type = dep.param_type

if is_unsolvable_type(param_type):
continue

if inject_node := (resolve_use(param_type) or resolve_use(dep.default)):
self._register_node(inject_node)
self._resolved_nodes[param_type] = inject_node
param_type = inject_node.dependent_type

self.analyze(param_type, config=config)
depends_on_resource = depends_on_resource or self.should_be_scoped(
param_type
)
unresolved.append((name, param_type))

return depends_on_resource, unresolved

@overload
def entry(self, **iconfig: Unpack[INodeConfig]) -> TEntryDecor: ...

Expand Down Expand Up @@ -856,7 +866,8 @@ async def func(email_sender: EmailSender, /):
configured = cast(TEntryDecor, partial(self.entry, **iconfig))
return configured

require_scope, unresolved = self.analyze_params(func, **iconfig)
config = NodeConfig(**iconfig)
require_scope, unresolved = self.analyze_params(func, config=config)

def replace(
before: Maybe[type[T]] = MISSING,
Expand Down Expand Up @@ -885,7 +896,6 @@ async def _async_scoped_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if param_name in kwargs:
continue
kwargs[param_name] = await scope.resolve(param_type)

r = await func(*args, **kwargs)
return r

Expand Down Expand Up @@ -914,7 +924,6 @@ def _sync_scoped_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if param_name in kwargs:
continue
kwargs[param_name] = scope.resolve(param_type)

r = sync_func(*args, **kwargs)
return r

Expand Down
10 changes: 10 additions & 0 deletions tests/features/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ def __init__(self, env: str = "prod"):
self.env = env


class MessageQueue:
def __init__(self, config: Config):
self.config = config


class Database:
def __init__(self, config: Config):
self.config = config
Expand All @@ -16,3 +21,8 @@ def __init__(self, db: Database):
class UserService:
def __init__(self, repo: UserRepository):
self.repo = repo


class ProductService:
def __init__(self, mq: MessageQueue):
self.mq = mq
23 changes: 23 additions & 0 deletions tests/features/test_resolve_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Annotated

from ididi import Graph


class User:
def __init__(self, name: str, role: str):
self.name = name
self.role = role


def get_user():
return User("user", "admin")


def validate_admin(user: Annotated[User, get_user]):
assert user.role == "admin"
return "ok"


# def test_dg_resolve_params():
# dg = Graph()
# assert dg.resolve_function(validate_admin) == "ok"

0 comments on commit 882cb53

Please sign in to comment.