From d6340b005b897fbafa05f9c108f31b69a4234dd4 Mon Sep 17 00:00:00 2001 From: sepgh Date: Thu, 21 Dec 2023 19:04:08 +0330 Subject: [PATCH] fix: using factory production as dependency in other beans --- README.md | 5 +++++ rhazes/decorator.py | 13 +++++++++++-- rhazes/dependency.py | 21 +++++++++++++++++++-- rhazes/protocol.py | 4 +++- tests/data/di/factory/di_factory.py | 9 +++++++++ tests/test_application_context.py | 15 ++++++++++++++- 6 files changed, 61 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 170df05..687e6c5 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,11 @@ class SomeBeanFactory(BeanFactory): ``` +**Note**: Factory beans don't obey `primary` keyword. Assure that you have single factory for a class or an interface, or the behaviour may be nondeterministic. + +You can also use `_for` keyword of `@bean` instead of implementing `produces(cls)` method. +In face the default implementation of `produces` method checks for `_for` keyword. +Quickest usage however is to implement the method. ### Singleton diff --git a/rhazes/decorator.py b/rhazes/decorator.py index 628625c..668386c 100644 --- a/rhazes/decorator.py +++ b/rhazes/decorator.py @@ -3,7 +3,12 @@ from rhazes.bean_builder import DefaultBeanBuilderStrategy from rhazes.context import ApplicationContext -from rhazes.protocol import InjectionConfiguration, BeanDetails, BeanBuilderStrategy +from rhazes.protocol import ( + InjectionConfiguration, + BeanDetails, + BeanBuilderStrategy, + BeanFactory, +) def bean( @@ -13,7 +18,11 @@ def bean( lazy_dependencies: Optional[List[Union[type, str]]] = None, ): def decorator(cls): - if _for is not None and not issubclass(cls, _for): + if ( + _for is not None + and not issubclass(cls, BeanFactory) + and not issubclass(cls, _for) + ): raise Exception( f"{cls} bean is meant to be registered for interface {_for} " f"but its not a subclass of that interface" diff --git a/rhazes/dependency.py b/rhazes/dependency.py index 89f37a9..4d3dc26 100644 --- a/rhazes/dependency.py +++ b/rhazes/dependency.py @@ -1,4 +1,4 @@ -from typing import Set, Type +from typing import Set, Type, List from rhazes.collections.stack import UniqueStack from rhazes.exceptions import DependencyCycleException @@ -6,6 +6,7 @@ BeanProtocol, DependencyNodeMetadata, DependencyNode, + BeanFactory, ) @@ -15,6 +16,7 @@ def __init__(self, bean_classes: Set[Type[BeanProtocol]] = None): bean_classes if bean_classes is not None else set() ) self.bean_interface_map = {} + self.factory_beans = self.get_factory_beans() self.fill_bean_interface_map() self.builders = {} self.node_registry = {} @@ -25,6 +27,17 @@ def __add_or_init_interface_mapping(self, interface, cls): mapping.append(cls) self.bean_interface_map[interface] = mapping + def get_factory_beans(self) -> List[Type]: + """ + :return: list of production from beans that are factory + """ + return [ + bean.produces() + for bean in filter( + lambda bean: issubclass(bean, BeanFactory), self.bean_classes + ) + ] + def fill_bean_interface_map(self): # Map list of implementations (value) for each bean interface (value) for bean_class in self.bean_classes: @@ -68,9 +81,13 @@ def register_metadata(self, cls) -> DependencyNodeMetadata: :return generated DependencyNodeMetadata """ metadata = DependencyNodeMetadata.generate( - cls, self.bean_classes, self.bean_interface_map + cls, self.bean_classes, self.bean_interface_map, self.factory_beans ) self.node_metadata_registry[cls] = metadata + + if metadata.is_factory: + self.node_metadata_registry[metadata.bean_for] = metadata + return metadata def resolve(self) -> dict: diff --git a/rhazes/protocol.py b/rhazes/protocol.py index dbfa01e..e70670f 100644 --- a/rhazes/protocol.py +++ b/rhazes/protocol.py @@ -74,12 +74,14 @@ def generate( cls, bean_classes: Iterable[BeanProtocol], bean_interface_mapping: Dict[Type, Type], + beans_with_factory: Iterable[Type], ): """ Generates DependencyNodeMetadata instance for a class (cls) after validating its constructor dependencies :param cls: class to generate DependencyNodeMetadata for :param bean_classes: other bean classes, possible to depend on :param bean_interface_mapping: possible classes to depend on + :param beans_with_factory: beans that are created by factories :return: generated DependencyNodeMetadata """ args = [] @@ -107,7 +109,7 @@ def generate( else: clazz = v.annotation - if clazz in bean_classes: + if clazz in bean_classes or clazz in beans_with_factory: dependencies.append(clazz) args.append(None) dependency_position[clazz] = i diff --git a/tests/data/di/factory/di_factory.py b/tests/data/di/factory/di_factory.py index 1eccdd3..aa5e62d 100644 --- a/tests/data/di/factory/di_factory.py +++ b/tests/data/di/factory/di_factory.py @@ -33,3 +33,12 @@ def name(self): return factory.tsg.get_string() return SomeInterfaceImpl() + + +@bean() +class SomeInterfaceUsage: + def __init__(self, interface: SomeInterface): + self.interface = interface + + def get_name(self): + return self.interface.name() diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 0ae2e83..37ebb6e 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -13,7 +13,11 @@ DepD, DepE, ) -from tests.data.di.factory.di_factory import SomeInterface, TestStringGeneratorBean +from tests.data.di.factory.di_factory import ( + SomeInterface, + TestStringGeneratorBean, + SomeInterfaceUsage, +) class ApplicationContextTestCase(TestCase): @@ -63,6 +67,15 @@ def test_factory_context(self): si: SomeInterface = self.application_context.get_bean(SomeInterface) self.assertEqual(si.name(), test_string_generator.get_string()) + def test_factory_product_as_dependency(self): + self.assertTrue(self.application_context._initialized) + self.assertIsNotNone(self.application_context.get_bean(SomeInterfaceUsage)) + usage: SomeInterfaceUsage = self.application_context.get_bean( + SomeInterfaceUsage + ) + self.assertIsNotNone(usage.get_name()) + # print(usage.get_name()) + class TemporaryContextTestCase(TestCase): def setUp(self) -> None: