diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py index 1f5a2e03c..c36de1eba 100644 --- a/monkey/common/di_container.py +++ b/monkey/common/di_container.py @@ -44,21 +44,32 @@ class DIContainer: """ args = [] - # TODO: Need to handle keyword-only arguments, defaults, varars, etc. + # TODO: Need to handle keyword-only arguments, defaults, varargs, etc. for arg_type in inspect.getfullargspec(type_).annotations.values(): - new_instance = self._resolve_instance(arg_type) - args.append(new_instance) + instance = self._resolve_arg_type(arg_type) + args.append(instance) return type_(*args) - def _resolve_instance(self, arg_type: Type): + def _resolve_arg_type(self, arg_type: Type[T]) -> T: if arg_type in self._type_registry: - return self._type_registry[arg_type]() + return self._resolve_type(arg_type) elif arg_type in self._instance_registry: - return self._instance_registry[arg_type] + return self._resolve_instance(arg_type) raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"') + def _resolve_type(self, arg_type: Type[T]) -> T: + try: + return self._type_registry[arg_type]() + except TypeError: + # arg_type has dependencies that must be resolved. Recursively call resolve() to + # construct an instance of arg_type with all of the requesite dependencies injected. + return self.resolve(self._type_registry[arg_type]) + + def _resolve_instance(self, arg_type: Type[T]) -> T: + return self._instance_registry[arg_type] + def release(self, interface: Type[T]): """ Deregister's an interface diff --git a/monkey/tests/unit_tests/common/test_di_container.py b/monkey/tests/unit_tests/common/test_di_container.py index d3b6abf57..63cfcfd13 100644 --- a/monkey/tests/unit_tests/common/test_di_container.py +++ b/monkey/tests/unit_tests/common/test_di_container.py @@ -1,13 +1,15 @@ +import abc + import pytest from common import DIContainer -class IServiceA: +class IServiceA(metaclass=abc.ABCMeta): pass -class IServiceB: +class IServiceB(metaclass=abc.ABCMeta): pass @@ -157,9 +159,7 @@ def test_instance_overrides_type(container): assert id(test_1.service_a) == id(service_a_instance) -def test_release_type(): - container = DIContainer() - +def test_release_type(container): container.register(IServiceA, ServiceA) container.release(IServiceA) @@ -167,12 +167,65 @@ def test_release_type(): container.resolve(TestClass1) -def test_release_instance(): - container = DIContainer() +def test_release_instance(container): service_a_instance = ServiceA() - container.register_instance(IServiceA, service_a_instance) + container.release(IServiceA) with pytest.raises(ValueError): container.resolve(TestClass1) + + +class IServiceC(metaclass=abc.ABCMeta): + pass + + +class ServiceC(IServiceC): + def __init__(self, service_a: IServiceA): + self.service_a = service_a + + +class TestClass4: + def __init__(self, service_c: IServiceC): + self.service_c = service_c + + +def test_recursive_resolution__depth_2(container): + service_a_instance = ServiceA() + container.register_instance(IServiceA, service_a_instance) + container.register(IServiceC, ServiceC) + + test4 = container.resolve(TestClass4) + + assert isinstance(test4.service_c, ServiceC) + assert id(test4.service_c.service_a) == id(service_a_instance) + + +class IServiceD(metaclass=abc.ABCMeta): + pass + + +class ServiceD(IServiceD): + def __init__(self, service_c: IServiceC, service_b: IServiceB): + self.service_b = service_b + self.service_c = service_c + + +class TestClass5: + def __init__(self, service_d: IServiceD): + self.service_d = service_d + + +def test_recursive_resolution__depth_3(container): + container.register(IServiceA, ServiceA) + container.register(IServiceB, ServiceB) + container.register(IServiceC, ServiceC) + container.register(IServiceD, ServiceD) + + test5 = container.resolve(TestClass5) + + assert isinstance(test5.service_d, ServiceD) + assert isinstance(test5.service_d.service_b, ServiceB) + assert isinstance(test5.service_d.service_c, ServiceC) + assert isinstance(test5.service_d.service_c.service_a, ServiceA)