Skip to content

Commit

Permalink
fix: using factory production as dependency in other beans
Browse files Browse the repository at this point in the history
  • Loading branch information
sepgh committed Dec 21, 2023
1 parent e2de320 commit d6340b0
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 6 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ class SomeBeanFactory(BeanFactory):

```

**Note**: Factory beans don't obey `primary` keyword. Assure that you have single factory for a class or an interface, or the behaviour may be nondeterministic.

You can also use `_for` keyword of `@bean` instead of implementing `produces(cls)` method.
In face the default implementation of `produces` method checks for `_for` keyword.
Quickest usage however is to implement the method.

### Singleton

Expand Down
13 changes: 11 additions & 2 deletions rhazes/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from rhazes.bean_builder import DefaultBeanBuilderStrategy
from rhazes.context import ApplicationContext
from rhazes.protocol import InjectionConfiguration, BeanDetails, BeanBuilderStrategy
from rhazes.protocol import (
InjectionConfiguration,
BeanDetails,
BeanBuilderStrategy,
BeanFactory,
)


def bean(
Expand All @@ -13,7 +18,11 @@ def bean(
lazy_dependencies: Optional[List[Union[type, str]]] = None,
):
def decorator(cls):
if _for is not None and not issubclass(cls, _for):
if (
_for is not None
and not issubclass(cls, BeanFactory)
and not issubclass(cls, _for)
):
raise Exception(
f"{cls} bean is meant to be registered for interface {_for} "
f"but its not a subclass of that interface"
Expand Down
21 changes: 19 additions & 2 deletions rhazes/dependency.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Set, Type
from typing import Set, Type, List

from rhazes.collections.stack import UniqueStack
from rhazes.exceptions import DependencyCycleException
from rhazes.protocol import (
BeanProtocol,
DependencyNodeMetadata,
DependencyNode,
BeanFactory,
)


Expand All @@ -15,6 +16,7 @@ def __init__(self, bean_classes: Set[Type[BeanProtocol]] = None):
bean_classes if bean_classes is not None else set()
)
self.bean_interface_map = {}
self.factory_beans = self.get_factory_beans()
self.fill_bean_interface_map()
self.builders = {}
self.node_registry = {}
Expand All @@ -25,6 +27,17 @@ def __add_or_init_interface_mapping(self, interface, cls):
mapping.append(cls)
self.bean_interface_map[interface] = mapping

def get_factory_beans(self) -> List[Type]:
"""
:return: list of production from beans that are factory
"""
return [
bean.produces()
for bean in filter(
lambda bean: issubclass(bean, BeanFactory), self.bean_classes
)
]

def fill_bean_interface_map(self):
# Map list of implementations (value) for each bean interface (value)
for bean_class in self.bean_classes:
Expand Down Expand Up @@ -68,9 +81,13 @@ def register_metadata(self, cls) -> DependencyNodeMetadata:
:return generated DependencyNodeMetadata
"""
metadata = DependencyNodeMetadata.generate(
cls, self.bean_classes, self.bean_interface_map
cls, self.bean_classes, self.bean_interface_map, self.factory_beans
)
self.node_metadata_registry[cls] = metadata

if metadata.is_factory:
self.node_metadata_registry[metadata.bean_for] = metadata

return metadata

def resolve(self) -> dict:
Expand Down
4 changes: 3 additions & 1 deletion rhazes/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ def generate(
cls,
bean_classes: Iterable[BeanProtocol],
bean_interface_mapping: Dict[Type, Type],
beans_with_factory: Iterable[Type],
):
"""
Generates DependencyNodeMetadata instance for a class (cls) after validating its constructor dependencies
:param cls: class to generate DependencyNodeMetadata for
:param bean_classes: other bean classes, possible to depend on
:param bean_interface_mapping: possible classes to depend on
:param beans_with_factory: beans that are created by factories
:return: generated DependencyNodeMetadata
"""
args = []
Expand Down Expand Up @@ -107,7 +109,7 @@ def generate(
else:
clazz = v.annotation

if clazz in bean_classes:
if clazz in bean_classes or clazz in beans_with_factory:
dependencies.append(clazz)
args.append(None)
dependency_position[clazz] = i
Expand Down
9 changes: 9 additions & 0 deletions tests/data/di/factory/di_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ def name(self):
return factory.tsg.get_string()

return SomeInterfaceImpl()


@bean()
class SomeInterfaceUsage:
def __init__(self, interface: SomeInterface):
self.interface = interface

def get_name(self):
return self.interface.name()
15 changes: 14 additions & 1 deletion tests/test_application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
DepD,
DepE,
)
from tests.data.di.factory.di_factory import SomeInterface, TestStringGeneratorBean
from tests.data.di.factory.di_factory import (
SomeInterface,
TestStringGeneratorBean,
SomeInterfaceUsage,
)


class ApplicationContextTestCase(TestCase):
Expand Down Expand Up @@ -63,6 +67,15 @@ def test_factory_context(self):
si: SomeInterface = self.application_context.get_bean(SomeInterface)
self.assertEqual(si.name(), test_string_generator.get_string())

def test_factory_product_as_dependency(self):
self.assertTrue(self.application_context._initialized)
self.assertIsNotNone(self.application_context.get_bean(SomeInterfaceUsage))
usage: SomeInterfaceUsage = self.application_context.get_bean(
SomeInterfaceUsage
)
self.assertIsNotNone(usage.get_name())
# print(usage.get_name())


class TemporaryContextTestCase(TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit d6340b0

Please sign in to comment.