Common: Add recursive resolution to DIContainer

This commit is contained in:
Mike Salvatore 2022-04-25 15:56:59 -04:00
parent 379a71d8e2
commit 435b619a5d
2 changed files with 78 additions and 14 deletions

View File

@ -44,21 +44,32 @@ class DIContainer:
""" """
args = [] args = []
# TODO: Need to handle keyword-only arguments, defaults, varars, etc. # TODO: Need to handle keyword-only arguments, defaults, varargs, etc.
for arg_type in inspect.getfullargspec(type_).annotations.values(): for arg_type in inspect.getfullargspec(type_).annotations.values():
new_instance = self._resolve_instance(arg_type) instance = self._resolve_arg_type(arg_type)
args.append(new_instance) args.append(instance)
return type_(*args) return type_(*args)
def _resolve_instance(self, arg_type: Type): def _resolve_arg_type(self, arg_type: Type[T]) -> T:
if arg_type in self._type_registry: if arg_type in self._type_registry:
return self._type_registry[arg_type]() return self._resolve_type(arg_type)
elif arg_type in self._instance_registry: elif arg_type in self._instance_registry:
return self._instance_registry[arg_type] return self._resolve_instance(arg_type)
raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"') raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"')
def _resolve_type(self, arg_type: Type[T]) -> T:
try:
return self._type_registry[arg_type]()
except TypeError:
# arg_type has dependencies that must be resolved. Recursively call resolve() to
# construct an instance of arg_type with all of the requesite dependencies injected.
return self.resolve(self._type_registry[arg_type])
def _resolve_instance(self, arg_type: Type[T]) -> T:
return self._instance_registry[arg_type]
def release(self, interface: Type[T]): def release(self, interface: Type[T]):
""" """
Deregister's an interface Deregister's an interface

View File

@ -1,13 +1,15 @@
import abc
import pytest import pytest
from common import DIContainer from common import DIContainer
class IServiceA: class IServiceA(metaclass=abc.ABCMeta):
pass pass
class IServiceB: class IServiceB(metaclass=abc.ABCMeta):
pass pass
@ -157,9 +159,7 @@ def test_instance_overrides_type(container):
assert id(test_1.service_a) == id(service_a_instance) assert id(test_1.service_a) == id(service_a_instance)
def test_release_type(): def test_release_type(container):
container = DIContainer()
container.register(IServiceA, ServiceA) container.register(IServiceA, ServiceA)
container.release(IServiceA) container.release(IServiceA)
@ -167,12 +167,65 @@ def test_release_type():
container.resolve(TestClass1) container.resolve(TestClass1)
def test_release_instance(): def test_release_instance(container):
container = DIContainer()
service_a_instance = ServiceA() service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance) container.register_instance(IServiceA, service_a_instance)
container.release(IServiceA) container.release(IServiceA)
with pytest.raises(ValueError): with pytest.raises(ValueError):
container.resolve(TestClass1) container.resolve(TestClass1)
class IServiceC(metaclass=abc.ABCMeta):
pass
class ServiceC(IServiceC):
def __init__(self, service_a: IServiceA):
self.service_a = service_a
class TestClass4:
def __init__(self, service_c: IServiceC):
self.service_c = service_c
def test_recursive_resolution__depth_2(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
container.register(IServiceC, ServiceC)
test4 = container.resolve(TestClass4)
assert isinstance(test4.service_c, ServiceC)
assert id(test4.service_c.service_a) == id(service_a_instance)
class IServiceD(metaclass=abc.ABCMeta):
pass
class ServiceD(IServiceD):
def __init__(self, service_c: IServiceC, service_b: IServiceB):
self.service_b = service_b
self.service_c = service_c
class TestClass5:
def __init__(self, service_d: IServiceD):
self.service_d = service_d
def test_recursive_resolution__depth_3(container):
container.register(IServiceA, ServiceA)
container.register(IServiceB, ServiceB)
container.register(IServiceC, ServiceC)
container.register(IServiceD, ServiceD)
test5 = container.resolve(TestClass5)
assert isinstance(test5.service_d, ServiceD)
assert isinstance(test5.service_d.service_b, ServiceB)
assert isinstance(test5.service_d.service_c, ServiceC)
assert isinstance(test5.service_d.service_c.service_a, ServiceA)