Skip to content

Commit

Permalink
feat(wip): type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
raceychan committed Feb 5, 2025
1 parent 95ad087 commit 1d69329
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 69 deletions.
40 changes: 22 additions & 18 deletions ididi/_ds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from types import MappingProxyType
from typing import Any, Callable, Hashable, Union
from types import FunctionType, MappingProxyType
from typing import Any, Callable, Hashable, Union, cast

from ._node import DependentNode
from ._type_resolve import get_bases
Expand Down Expand Up @@ -69,9 +69,9 @@ def __init__(self, nodes: GraphNodes[Any]):

def _visit(
self,
start_types: Union[list[Any], type],
pre_visit: Union[Callable[[type], None], None] = None,
post_visit: Union[Callable[[type], None], None] = None,
start_types: Union[list[Callable[..., Any]], Callable[..., Any]],
pre_visit: Union[Callable[[Callable[..., Any]], None], None] = None,
post_visit: Union[Callable[[Callable[..., Any]], None], None] = None,
) -> None:
"""Generic DFS traversal with customizable visit callbacks.
Expand All @@ -80,20 +80,20 @@ def _visit(
pre_visit: Called before visiting node's dependencies
post_visit: Called after visiting node's dependencies
"""
if isinstance(start_types, type):
if isinstance(start_types, type) or isinstance(start_types, FunctionType):
start_types = [start_types]

visited = set[type]()
visited = set[Callable[..., Any]]()

def _get_deps(node_type: type) -> list[type]:
def _get_deps(node_type: Callable[..., Any]) -> list[Callable[..., Any]]:
node = self._nodes[node_type]
return [
p.param_type
for _, p in node.dependencies
if p.param_type in self._nodes
]

def dfs(node_type: type):
def dfs(node_type: Callable[..., Any]):
if node_type in visited:
return
visited.add(node_type)
Expand All @@ -107,37 +107,41 @@ def dfs(node_type: type):
if post_visit:
post_visit(node_type)

for node_type in start_types:
for node_type in cast(list[Callable[..., Any]], start_types):
dfs(node_type)

def get_dependents(self, dependency: type) -> list[type]:
dependents: list[type] = []
def get_dependents(
self, dependency: Callable[..., Any]
) -> list[Callable[..., Any]]:
dependents: list[Callable[..., Any]] = []

def collect_dependent(node_type: type):
def collect_dependent(node_type: Callable[..., Any]):
node = self._nodes[node_type]
if any(p.param_type is dependency for _, p in node.dependencies):
dependents.append(node_type)

self._visit(list(self._nodes), pre_visit=collect_dependent)
return dependents

def get_dependencies(self, dependent: type, recursive: bool = False) -> list[type]:
def get_dependencies(
self, dependent: Callable[..., Any], recursive: bool = False
) -> list[Callable[..., Any]]:
if not recursive:
return [p.param_type for _, p in self._nodes[dependent].dependencies]

def collect_dependencies(t: type):
def collect_dependencies(t: Callable[..., Any]):
if t != dependent:
dependencies.append(t)

dependencies: list[type] = []
dependencies: list[Callable[..., Any]] = []
self._visit(
dependent,
post_visit=collect_dependencies,
)
return dependencies

def top_sorted_dependencies(self) -> list[type]:
def top_sorted_dependencies(self) -> list[Callable[..., Any]]:
"Sort the whole graph, from lowest dependencies to toppest dependents"
order: list[type] = []
order: list[Callable[..., Any]] = []
self._visit(list(self._nodes), post_visit=order.append)
return order
17 changes: 14 additions & 3 deletions ididi/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def auth_service_factory() -> AuthService:
"dependent",
"factory",
"factory_type",
"function_dependent",
"dependencies",
"config",
)
Expand All @@ -381,13 +382,15 @@ def __init__(
dependent: Callable[..., T],
factory: INodeAnyFactory[T],
factory_type: FactoryType,
function_dependent: bool = False,
dependencies: Dependencies,
config: NodeConfig,
):

self.dependent = dependent
self.factory = factory
self.factory_type: FactoryType = factory_type
self.function_dependent = function_dependent
self.dependencies = dependencies
self.config = config

Expand Down Expand Up @@ -484,10 +487,11 @@ def _from_factory(

signature = get_typed_signature(f, check_return=True)
dependent: type[T] = resolve_annotation(signature.return_annotation)

if get_origin(dependent) is Annotated:
metas = flatten_annotated(dependent)
if IDIDI_IGNORE_PARAM_MARK in metas:
node = cls._from_function(f, config=config)
node = cls._from_function(f, factory_type=factory_type, config=config)
return node

node = cls.create(
Expand Down Expand Up @@ -546,15 +550,22 @@ def from_node(
return cls._from_factory(factory=factory, config=config)

@classmethod
def _from_function(cls, function: Any, *, config: NodeConfig = DefaultConfig):
def _from_function(
cls,
function: Any,
*,
factory_type: FactoryType,
config: NodeConfig = DefaultConfig,
):
deps = Dependencies.from_signature(
function, get_typed_signature(function), config
)

node = DependentNode(
dependent=function,
factory=function,
factory_type="function",
factory_type=factory_type,
function_dependent=True,
dependencies=deps,
config=config,
)
Expand Down
2 changes: 1 addition & 1 deletion ididi/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def _node(
self, dependent: INode[P, T], config: NodeConfig = DefaultConfig
) -> DependentNode[T]:
node = DependentNode[T].from_node(dependent, config=config)
if is_function(node.dependent):
if node.function_dependent:
return node

if ori_node := self._nodes.get(node.dependent):
Expand Down
42 changes: 28 additions & 14 deletions tests/features/test_resolve_function.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
# from typing import Annotated
from typing import Annotated

# from ididi import Graph, Ignore
from ididi import Graph, Ignore

from ..test_data import Config, UserService

# class User:
# def __init__(self, name: str, role: str):
# self.name = name
# self.role = role

class User:
def __init__(self, name: str, role: str):
self.name = name
self.role = role

# def get_user() -> Ignore[User]:
# return User("user", "admin")

def get_user(config: Config) -> Ignore[User]:
assert isinstance(config, Config)
return User("user", "admin")

# def validate_admin(user: Annotated[User, get_user]):
# assert user.role == "admin"
# return "ok"

def validate_admin(
user: Annotated[User, get_user], service: UserService
) -> Ignore[str]:
assert user.role == "admin"
assert isinstance(service, UserService)
return "ok"

# def test_dg_resolve_params():
# dg = Graph()
# assert dg.resolve(get_user)

async def test_resolve_function():
dg = Graph()

user = dg.resolve(get_user)
assert isinstance(user, User)


def test_dg_resolve_params():
dg = Graph()

assert dg.resolve(validate_admin) == "ok"
33 changes: 0 additions & 33 deletions tests/test_feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,3 @@
run test with: make feat
"""

from typing import Annotated

from ididi import Graph, Ignore


class User:
def __init__(self, name: str, role: str):
self.name = name
self.role = role


def get_user() -> Ignore[User]:
return User("user", "admin")


def validate_admin(user: Annotated[User, get_user]) -> Ignore[str]:
assert user.role == "admin"
return "ok"


async def test_resolve_function():
dg = Graph()

user = dg.resolve(get_user)
assert isinstance(user, User)


def test_dg_resolve_params():
dg = Graph()

node = dg.analyze(validate_admin)

assert dg.resolve(validate_admin) == "ok"

0 comments on commit 1d69329

Please sign in to comment.