Commit
feat: performance boost
raceychan committed Feb 9, 2025
1 parent 1543fd3 commit 14582e9
Showing 4 changed files with 42 additions and 48 deletions.
6 changes: 1 addition & 5 deletions ididi/errors.py
@@ -187,11 +187,7 @@ class GraphError(IDIDIError):

class OutOfScopeError(GraphError):
def __init__(self, name: Hashable = ""):
-        if name:
-            msg = f"scope with {name=} not found in current context"
-        else:
-            msg = "scope not found in current context"
-
+        msg = f"scope with {name=} not found in current context"
super().__init__(msg)


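For illustration only (not part of the commit): the `{name=}` f-string spec renders both the parameter name and its repr, so the single remaining branch still produces a readable message when the name is empty.

# Illustrative sketch -- plain f-string semantics, no ididi imports needed.
name = ""
print(f"scope with {name=} not found in current context")
# scope with name='' not found in current context

name = "request"
print(f"scope with {name=} not found in current context")
# scope with name='request' not found in current context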
68 changes: 31 additions & 37 deletions ididi/graph.py
@@ -78,9 +78,9 @@
from .utils.typing_utils import P, T

Stack = TypeVar("Stack", ExitStack, AsyncExitStack)
-ScopeContext = ContextVar[Union["SyncScope", "AsyncScope"]]
AnyScope = Union["SyncScope", "AsyncScope"]
ScopeToken = Token[AnyScope]
+ScopeContext = ContextVar[AnyScope]
Resolved = Maybe[dict[str, Any]]


@@ -102,21 +102,21 @@ def register_dependent(
def resolve_dfs(
graph: "Resolver",
nodes: GraphNodes[Any],
-    # cache: ResolvedSingletons[Any],
+    cache: ResolvedSingletons[Any],
ptype: IDependent[T],
overrides: dict[str, Any],
) -> T:
-    if resolution := graph.get_resolve_cache(ptype):
+    if resolution := cache.get(ptype):
return resolution

pnode = nodes.get(ptype) or graph.analyze(ptype)

params = overrides
for name, param in pnode.dependencies:
if name not in params:
-            params[name] = resolve_dfs(graph, nodes, param.param_type, {})
+            params[name] = resolve_dfs(graph, nodes, cache, param.param_type, {})

-    instance = pnode.factory(**(params))
+    instance = pnode.factory(**params)

result = graph.resolve_callback(
resolved=instance,
@@ -130,19 +130,19 @@ def resolve_dfs(
async def aresolve_dfs(
graph: "Resolver",
nodes: GraphNodes[Any],
-    # cache: ResolvedSingletons[Any],
+    cache: ResolvedSingletons[Any],
ptype: IDependent[T],
overrides: dict[str, Any],
) -> T:
-    if resolution := graph.get_resolve_cache(ptype):
+    if resolution := cache.get(ptype):
return resolution

pnode = nodes.get(ptype) or graph.analyze(ptype)

params = overrides
for name, param in pnode.dependencies:
if name not in params:
-            params[name] = await aresolve_dfs(graph, nodes, param.param_type, {})
+            params[name] = await aresolve_dfs(graph, nodes, cache, param.param_type, {})

instance = pnode.factory(**(params | overrides))
resolved = await graph.aresolve_callback(
@@ -478,12 +478,14 @@ def resolve(
if args:
raise PositionalOverrideError(args)

-        if is_provided(resolution := self.get_resolve_cache(dependent)):
+        if is_provided(resolution := self._resolved_singletons.get(dependent, MISSING)):
return resolution

provided_params = frozenset(overrides)
node: DependentNode[T] = self.analyze(dependent, ignore=provided_params)
-        return resolve_dfs(self, self._nodes, node.dependent, overrides)
+        return resolve_dfs(
+            self, self._nodes, self._resolved_singletons, node.dependent, overrides
+        )

async def aresolve(
self,
@@ -498,12 +500,14 @@ async def aresolve(
if args:
raise PositionalOverrideError(args)

-        if is_provided(resolution := self.get_resolve_cache(dependent)):
+        if is_provided(resolution := self._resolved_singletons.get(dependent, MISSING)):
return resolution

provided_params = frozenset(overrides)
node: DependentNode[T] = self.analyze(dependent, ignore=provided_params)
-        return await aresolve_dfs(self, self._nodes, node.dependent, overrides)
+        return await aresolve_dfs(
+            self, self._nodes, self._resolved_singletons, node.dependent, overrides
+        )

def _node(
self, dependent: INode[P, T], config: NodeConfig = DefaultConfig
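The pattern behind the resolve/aresolve changes above is to thread the singleton cache through the depth-first resolution as a plain dict, so a cache hit costs one dict lookup instead of a method call per node. A minimal standalone sketch of that idea (illustrative names only, not ididi's actual API):

from typing import Any, Callable

Factory = tuple[Callable[..., Any], list[str]]  # (constructor, dependency names)

def resolve(name: str, factories: dict[str, Factory], cache: dict[str, Any]) -> Any:
    # Hit path: a single dict lookup, no attribute/method dispatch on a resolver object.
    if (hit := cache.get(name)) is not None:
        return hit
    factory, deps = factories[name]
    instance = factory(*(resolve(dep, factories, cache) for dep in deps))
    cache[name] = instance  # memoize so shared dependencies stay singletons
    return instance

factories: dict[str, Factory] = {
    "db": (lambda: object(), []),
    "repo": (lambda db: ("repo", db), ["db"]),
}
cache: dict[str, Any] = {}
assert resolve("repo", factories, cache)[1] is cache["db"]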
@@ -825,7 +829,7 @@ def __init__(
if self_inject:
self.register_singleton(self)

-        self._default_scope = self.create_scope(DefaultScopeName)
+        self._default_scope = self._create_scope(DefaultScopeName)
self._token = self._scope_context.set(self._default_scope)

def __repr__(self) -> str:
@@ -904,13 +908,9 @@ def merge(self, other: Union["Graph", Sequence["Graph"]]):
others = other

for other in others:
-            try:
-                other_scope = other._scope_context.get()
-            except LookupError:
-                pass
-            else:
-                if other_scope.name is not DefaultScopeName:
-                    raise MergeWithScopeStartedError()
+            other_scope = other._scope_context.get()
+            if other_scope.name is not DefaultScopeName:
+                raise MergeWithScopeStartedError()

self._merge_nodes(other)

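A hedged usage sketch of the tightened check above (Graph, merge, and scope are taken from this diff; the failure mode is inferred from the new code, not asserted by the commit): merging a graph whose current scope is no longer the default one should raise MergeWithScopeStartedError.

# Illustrative sketch, not part of the commit.
from ididi import Graph

main, extra = Graph(), Graph()
main.merge(extra)              # extra is still in its default scope: fine

with extra.scope("request"):
    main.merge(extra)          # expected to raise MergeWithScopeStartedError here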
@@ -960,7 +960,7 @@ def search_node(self, dep_name: str) -> Union[DependentNode[Any], None]:
if dep.__name__ == dep_name:
return node

-    def create_scope(
+    def _create_scope(
self, name: Hashable = "", pre: Maybe[AnyScope] = MISSING
) -> SyncScope:
shared_data = SharedData(
@@ -978,7 +978,7 @@ def create_scope(
)
return scope

-    def create_ascope(
+    def _create_ascope(
self, name: Hashable = "", pre: Maybe[AnyScope] = MISSING
) -> AsyncScope:
shared_data = SharedData(
@@ -999,7 +999,7 @@ def create_ascope(
@contextmanager
def scope(self, name: Hashable = ""):
previous_scope = self._scope_context.get()
-        scope = self.create_scope(name, previous_scope)
+        scope = self._create_scope(name, previous_scope)
token = self._scope_context.set(scope)

exc_type, exc, tb = None, None, None
@@ -1011,13 +1011,12 @@ def scope(self, name: Hashable = ""):
raise
finally:
scope.__exit__(exc_type, exc, tb)
-            if is_provided(self._token):
-                self._scope_context.reset(token)
+            self._scope_context.reset(token)

@asynccontextmanager
async def ascope(self, name: Hashable = ""):
previous_scope = self._scope_context.get()
-        ascope = self.create_ascope(name, previous_scope)
+        ascope = self._create_ascope(name, previous_scope)
token = self._scope_context.set(ascope)

exc_type, exc, tb = None, None, None
@@ -1029,28 +1028,27 @@ async def ascope(self, name: Hashable = ""):
raise
finally:
await ascope.__aexit__(exc_type, exc, tb)
-            if is_provided(self._token):
-                self._scope_context.reset(token)
+            self._scope_context.reset(token)

@overload
def use_scope(
self,
-        name: Maybe[Hashable] = MISSING,
+        name: Hashable = "",
*,
as_async: Literal[False] = False,
) -> "SyncScope": ...

@overload
def use_scope(
self,
-        name: Maybe[Hashable] = MISSING,
+        name: Hashable = "",
*,
as_async: Literal[True],
) -> "AsyncScope": ...

def use_scope(
self,
-        name: Maybe[Hashable] = MISSING,
+        name: Hashable = "",
*,
as_async: bool = False,
) -> Union["SyncScope", "AsyncScope"]:
@@ -1060,12 +1058,8 @@ def use_scope(
as_async: bool
pure typing helper, ignored at runtime
"""
-        try:
-            scope = self._scope_context.get()
-        except LookupError as le:
-            raise OutOfScopeError() from le
-
-        if is_provided(name):
+        scope = self._scope_context.get()
+        if name != scope.name:
return scope.get_scope(name)
return scope

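Taken together: create_scope/create_ascope become private (_create_scope/_create_ascope), scopes are entered through the scope()/ascope() context managers backed by a ContextVar whose token is always reset on exit, and use_scope() resolves a name against the current scope. A hedged sketch of the resulting usage, with the yielded value inferred from the updated tests below:

# Illustrative sketch, not part of the commit.
from ididi import Graph

dg = Graph()

with dg.scope("request") as current:
    # while the block is active, the named scope is reachable via the ContextVar
    assert dg.use_scope("request") is current
# on exit the token is reset, so the graph falls back to its default scope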
4 changes: 2 additions & 2 deletions tests/regression/test_slots.py
@@ -9,12 +9,12 @@ async def test_graph_ds_slots():
with pytest.raises(AttributeError):
dg.__dict__

-    sc = dg.create_scope()
+    sc = dg.scope().__enter__()

with pytest.raises(AttributeError):
sc.__dict__

-    asc = dg.create_ascope()
+    asc = await dg.ascope().__aenter__()

with pytest.raises(AttributeError):
asc.__dict__
12 changes: 8 additions & 4 deletions tests/test_scope.py
@@ -2,7 +2,8 @@

import pytest

-from ididi import AsyncResource, Graph
+from ididi import Graph
+from ididi.config import DefaultScopeName
from ididi.errors import (
AsyncResourceInSyncError,
OutOfScopeError,
@@ -390,7 +391,10 @@ def test_two():

test_two()

-    dg.use_scope()
+    with pytest.raises(OutOfScopeError):
+        dg.use_scope()
+
+    assert dg.use_scope(DefaultScopeName)


async def test_db_exec():
@@ -447,10 +451,10 @@ def func2():
func2()


-async def test_use_scope_create_on_miss():
+async def test_dg_create_default_scope():
dg = Graph()

-    dg.use_scope()
+    assert dg.use_scope(DefaultScopeName)


async def test_share_single_pattern():
