From 7a62434364e7dd9cda92aba81c64b003e281e8a2 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 26 Apr 2022 01:21:36 -0400 Subject: [PATCH] Common: Resolve registered instances and types directly --- monkey/common/di_container.py | 23 +++++++++------- .../unit_tests/common/test_di_container.py | 27 ++++++++++++++++--- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py index c36de1eba..8ee7fcb1f 100644 --- a/monkey/common/di_container.py +++ b/monkey/common/di_container.py @@ -42,24 +42,29 @@ class DIContainer: :param type_: A type (class) to construct. :return: An instance of `type_` """ + try: + return self._resolve_type(type_) + except ValueError: + pass + args = [] # TODO: Need to handle keyword-only arguments, defaults, varargs, etc. for arg_type in inspect.getfullargspec(type_).annotations.values(): - instance = self._resolve_arg_type(arg_type) + instance = self._resolve_type(arg_type) args.append(instance) return type_(*args) - def _resolve_arg_type(self, arg_type: Type[T]) -> T: - if arg_type in self._type_registry: - return self._resolve_type(arg_type) - elif arg_type in self._instance_registry: - return self._resolve_instance(arg_type) + def _resolve_type(self, type_: Type[T]) -> T: + if type_ in self._type_registry: + return self._construct_new_instance(type_) + elif type_ in self._instance_registry: + return self._retrieve_registered_instance(type_) - raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"') + raise ValueError(f'Failed to resolve unknown type "{type_.__name__}"') - def _resolve_type(self, arg_type: Type[T]) -> T: + def _construct_new_instance(self, arg_type: Type[T]) -> T: try: return self._type_registry[arg_type]() except TypeError: @@ -67,7 +72,7 @@ class DIContainer: # construct an instance of arg_type with all of the requesite dependencies injected. return self.resolve(self._type_registry[arg_type]) - def _resolve_instance(self, arg_type: Type[T]) -> T: + def _retrieve_registered_instance(self, arg_type: Type[T]) -> T: return self._instance_registry[arg_type] def release(self, interface: Type[T]): diff --git a/monkey/tests/unit_tests/common/test_di_container.py b/monkey/tests/unit_tests/common/test_di_container.py index 63cfcfd13..dab1a8f29 100644 --- a/monkey/tests/unit_tests/common/test_di_container.py +++ b/monkey/tests/unit_tests/common/test_di_container.py @@ -6,7 +6,9 @@ from common import DIContainer class IServiceA(metaclass=abc.ABCMeta): - pass + @abc.abstractmethod + def do_something(self): + pass class IServiceB(metaclass=abc.ABCMeta): @@ -14,7 +16,8 @@ class IServiceB(metaclass=abc.ABCMeta): class ServiceA(IServiceA): - pass + def do_something(self): + pass class ServiceB(IServiceB): @@ -117,7 +120,8 @@ def test_unregistered_type(): def test_type_registration_overwritten(container): class ServiceA2(IServiceA): - pass + def do_something(self): + pass container.register(IServiceA, ServiceA) container.register(IServiceA, ServiceA2) @@ -229,3 +233,20 @@ def test_recursive_resolution__depth_3(container): assert isinstance(test5.service_d.service_b, ServiceB) assert isinstance(test5.service_d.service_c, ServiceC) assert isinstance(test5.service_d.service_c.service_a, ServiceA) + + +def test_resolve_registered_interface(container): + container.register(IServiceA, ServiceA) + + resolved_instance = container.resolve(IServiceA) + + assert isinstance(resolved_instance, ServiceA) + + +def test_resolve_registered_instance(container): + service_a_instance = ServiceA() + container.register_instance(IServiceA, service_a_instance) + + service_a_actual_instance = container.resolve(IServiceA) + + assert id(service_a_actual_instance) == id(service_a_instance)