Skip to content

Commit

Permalink
feat(#3): add @Inject decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
sepgh committed Oct 23, 2023
1 parent 379b391 commit 1dabc9e
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 41 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ _There is no published version yet_. Written for Django 4.2 using python 3.9. Ot

Once Rhazes `ApplicationContext` is initialized it will scan for classes marked with `@bean` decorator under packages listed in `settings.INSTALLED_APPS` or `settings.RHAZES_PACKAGES` (preferably).

Afterwards, it creates a graph of these classes and their dependencies to each other and starts to create objects for each class and register them as beans under `ApplicationContext().beans`.
Afterwards, it creates a graph of these classes and their dependencies to each other and starts to create objects for each class and register them as beans under `ApplicationContext`.

If everything works perfectly, you can access the beans using `ApplicationContext().beans.get_bean(CLASS)` for a class.
If everything works perfectly, you can access the beans using `ApplicationContext.get_bean(CLASS)` for a class.


## Example
Expand Down Expand Up @@ -68,7 +68,7 @@ from rhazes.context import ApplicationContext
from somepackage import UserStorage, DatabaseUserStorage, CacheUserStorage, ProductManager


application_context = ApplicationContext()
application_context = ApplicationContext
application_context.initialize()

product_manager: ProductManager = application_context.get_bean(ProductManager)
Expand Down
61 changes: 28 additions & 33 deletions rhazes/context.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
from typing import Optional

from django.utils.functional import SimpleLazyObject

from rhazes.dependency import DependencyResolver
from rhazes.protocol import BeanProtocol, BeanFactory, DependencyNodeMetadata
from rhazes.protocol import BeanProtocol, BeanFactory
from rhazes.scanner import ModuleScanner, class_scanner


class ApplicationContext:
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(ApplicationContext, cls).__new__(cls)
return cls.instance

def __init__(self):
self._initialized = False
self._module_scanner = ModuleScanner()
self._builder_registry = {}
_initialized = False
_builder_registry = {}

def _initialize_beans(self):
@classmethod
def _initialize_beans(cls):
beans = set()
bean_factories = set()
modules = self._module_scanner.scan()
modules = ModuleScanner().scan()
for module in modules:
scanned_classes = class_scanner(module)
for scanned_class in scanned_classes:
Expand All @@ -28,29 +24,28 @@ def _initialize_beans(self):
elif issubclass(scanned_class, (BeanFactory,)):
bean_factories.add(scanned_class)

for cls, obj in DependencyResolver(beans, bean_factories).resolve().items():
self.register_bean(cls, obj)
for clazz, obj in DependencyResolver(beans, bean_factories).resolve().items():
cls.register_bean(clazz, obj)

def initialize(self):
if self._initialized:
@classmethod
def initialize(cls):
if cls._initialized:
return
self.register_bean(
ApplicationContext,
DependencyNodeMetadata(None, None, None, is_singleton=True),
lambda context: self,
)
self._initialize_beans()
self._initialized = True

def register_bean(self, cls, builder, override=False):
if cls not in self._builder_registry or override:
self._builder_registry[cls] = builder

def get_bean(self, of: type) -> Optional:
builder = self._builder_registry.get(of)
cls._initialize_beans()
cls._initialized = True

@classmethod
def register_bean(cls, clazz, builder, override=False):
if clazz not in cls._builder_registry or override:
cls._builder_registry[clazz] = builder

@classmethod
def get_bean(cls, of: type) -> Optional:
builder = cls._builder_registry.get(of)
if builder is None:
return None
return builder(self)
return builder(cls)

def get_lazy_bean(self, of: type) -> Optional:
return SimpleLazyObject(lambda: self.get_bean(of))
@classmethod
def get_lazy_bean(cls, of: type) -> Optional:
return SimpleLazyObject(lambda: cls.get_bean(of))
12 changes: 8 additions & 4 deletions rhazes/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def inject_kwargs(injections, configuration, func, kwargs: dict):
):
lazy = configuration.get(v.annotation, {}).get("lazy", False)
if lazy:
kwargs[k] = ApplicationContext().get_lazy_bean(v.annotation)
dep = ApplicationContext.get_lazy_bean(v.annotation)
else:
kwargs[k] = ApplicationContext().get_bean(v.annotation)
dep = ApplicationContext.get_bean(v.annotation)
if dep is not None:
kwargs[k] = dep
return kwargs


Expand All @@ -70,9 +72,11 @@ def __int__(self, *args, **kwargs):

elif callable(obj_of_func):

def proxy(*args, **kwargs):
def proxy(**kwargs):
inject_kwargs(injections, configuration, obj_of_func, kwargs)
return obj_of_func(*args, **kwargs)
return obj_of_func(**kwargs)

return proxy

else:
# Input is neither class or function
Expand Down
2 changes: 1 addition & 1 deletion tests/test_application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@override_settings(RHAZES_PACKAGES=["tests.data.di.context", "tests.data.di.factory"])
class ApplicationContextTestCase(TestCase):
def setUp(self) -> None:
self.application_context = ApplicationContext()
self.application_context = ApplicationContext
self.application_context.initialize()

def test_bean_context(self):
Expand Down
45 changes: 45 additions & 0 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from django.test import TestCase, override_settings

from rhazes.context import ApplicationContext
from rhazes.decorator import inject
from tests.data.di.context.di_context import DepD, DepE


@override_settings(RHAZES_PACKAGES=["tests.data.di.context", "tests.data.di.factory"])
class InjectionTestCase(TestCase):
def setUp(self) -> None:
self.application_context = ApplicationContext
self.application_context.initialize()
self.dep_d: DepD = self.application_context.get_bean(DepD)
self.assertIsNotNone(self.dep_d)

def test_single_bean_injection(self):
@inject()
def echo(dep_d: DepD, inp: str):
return f"{inp}+{dep_d.name()}"

result = echo(inp="test")
self.assertEqual(f"test+{self.dep_d.name()}", result)

def test_unknown_bean_injection(self):
class Unkown:
pass

@inject()
def echo(unkown: Unkown, inp: str):
return inp

with self.assertRaises(TypeError):
echo(inp="test")

def test_optional_dependency(self):
@inject(injections=[DepD])
def echo(dep_d: DepD, dep_e: DepE, inp: str):
return f"{inp}+{dep_e.dep_d.name()}"

with self.assertRaises(TypeError):
echo(inp="test")

dep_e = self.application_context.get_bean(DepE)
result = echo(dep_e=dep_e, inp="test")
self.assertEqual(f"test+{self.dep_d.name()}", result)

0 comments on commit 1dabc9e

Please sign in to comment.