Common: Resolve registered instances and types directly

This commit is contained in:
Mike Salvatore 2022-04-26 01:21:36 -04:00
parent 435b619a5d
commit 7a62434364
2 changed files with 38 additions and 12 deletions

View File

@ -42,24 +42,29 @@ class DIContainer:
:param type_: A type (class) to construct.
:return: An instance of `type_`
"""
try:
return self._resolve_type(type_)
except ValueError:
pass
args = []
# TODO: Need to handle keyword-only arguments, defaults, varargs, etc.
for arg_type in inspect.getfullargspec(type_).annotations.values():
instance = self._resolve_arg_type(arg_type)
instance = self._resolve_type(arg_type)
args.append(instance)
return type_(*args)
def _resolve_arg_type(self, arg_type: Type[T]) -> T:
if arg_type in self._type_registry:
return self._resolve_type(arg_type)
elif arg_type in self._instance_registry:
return self._resolve_instance(arg_type)
def _resolve_type(self, type_: Type[T]) -> T:
if type_ in self._type_registry:
return self._construct_new_instance(type_)
elif type_ in self._instance_registry:
return self._retrieve_registered_instance(type_)
raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"')
raise ValueError(f'Failed to resolve unknown type "{type_.__name__}"')
def _resolve_type(self, arg_type: Type[T]) -> T:
def _construct_new_instance(self, arg_type: Type[T]) -> T:
try:
return self._type_registry[arg_type]()
except TypeError:
@ -67,7 +72,7 @@ class DIContainer:
# construct an instance of arg_type with all of the requesite dependencies injected.
return self.resolve(self._type_registry[arg_type])
def _resolve_instance(self, arg_type: Type[T]) -> T:
def _retrieve_registered_instance(self, arg_type: Type[T]) -> T:
return self._instance_registry[arg_type]
def release(self, interface: Type[T]):

View File

@ -6,6 +6,8 @@ from common import DIContainer
class IServiceA(metaclass=abc.ABCMeta):
@abc.abstractmethod
def do_something(self):
pass
@ -14,6 +16,7 @@ class IServiceB(metaclass=abc.ABCMeta):
class ServiceA(IServiceA):
def do_something(self):
pass
@ -117,6 +120,7 @@ def test_unregistered_type():
def test_type_registration_overwritten(container):
class ServiceA2(IServiceA):
def do_something(self):
pass
container.register(IServiceA, ServiceA)
@ -229,3 +233,20 @@ def test_recursive_resolution__depth_3(container):
assert isinstance(test5.service_d.service_b, ServiceB)
assert isinstance(test5.service_d.service_c, ServiceC)
assert isinstance(test5.service_d.service_c.service_a, ServiceA)
def test_resolve_registered_interface(container):
container.register(IServiceA, ServiceA)
resolved_instance = container.resolve(IServiceA)
assert isinstance(resolved_instance, ServiceA)
def test_resolve_registered_instance(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
service_a_actual_instance = container.resolve(IServiceA)
assert id(service_a_actual_instance) == id(service_a_instance)