Common: Add recursive resolution to DIContainer
This commit is contained in:
parent
379a71d8e2
commit
435b619a5d
|
@ -44,21 +44,32 @@ class DIContainer:
|
||||||
"""
|
"""
|
||||||
args = []
|
args = []
|
||||||
|
|
||||||
# TODO: Need to handle keyword-only arguments, defaults, varars, etc.
|
# TODO: Need to handle keyword-only arguments, defaults, varargs, etc.
|
||||||
for arg_type in inspect.getfullargspec(type_).annotations.values():
|
for arg_type in inspect.getfullargspec(type_).annotations.values():
|
||||||
new_instance = self._resolve_instance(arg_type)
|
instance = self._resolve_arg_type(arg_type)
|
||||||
args.append(new_instance)
|
args.append(instance)
|
||||||
|
|
||||||
return type_(*args)
|
return type_(*args)
|
||||||
|
|
||||||
def _resolve_instance(self, arg_type: Type):
|
def _resolve_arg_type(self, arg_type: Type[T]) -> T:
|
||||||
if arg_type in self._type_registry:
|
if arg_type in self._type_registry:
|
||||||
return self._type_registry[arg_type]()
|
return self._resolve_type(arg_type)
|
||||||
elif arg_type in self._instance_registry:
|
elif arg_type in self._instance_registry:
|
||||||
return self._instance_registry[arg_type]
|
return self._resolve_instance(arg_type)
|
||||||
|
|
||||||
raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"')
|
raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"')
|
||||||
|
|
||||||
|
def _resolve_type(self, arg_type: Type[T]) -> T:
|
||||||
|
try:
|
||||||
|
return self._type_registry[arg_type]()
|
||||||
|
except TypeError:
|
||||||
|
# arg_type has dependencies that must be resolved. Recursively call resolve() to
|
||||||
|
# construct an instance of arg_type with all of the requesite dependencies injected.
|
||||||
|
return self.resolve(self._type_registry[arg_type])
|
||||||
|
|
||||||
|
def _resolve_instance(self, arg_type: Type[T]) -> T:
|
||||||
|
return self._instance_registry[arg_type]
|
||||||
|
|
||||||
def release(self, interface: Type[T]):
|
def release(self, interface: Type[T]):
|
||||||
"""
|
"""
|
||||||
Deregister's an interface
|
Deregister's an interface
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
|
import abc
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from common import DIContainer
|
from common import DIContainer
|
||||||
|
|
||||||
|
|
||||||
class IServiceA:
|
class IServiceA(metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class IServiceB:
|
class IServiceB(metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,9 +159,7 @@ def test_instance_overrides_type(container):
|
||||||
assert id(test_1.service_a) == id(service_a_instance)
|
assert id(test_1.service_a) == id(service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
def test_release_type():
|
def test_release_type(container):
|
||||||
container = DIContainer()
|
|
||||||
|
|
||||||
container.register(IServiceA, ServiceA)
|
container.register(IServiceA, ServiceA)
|
||||||
container.release(IServiceA)
|
container.release(IServiceA)
|
||||||
|
|
||||||
|
@ -167,12 +167,65 @@ def test_release_type():
|
||||||
container.resolve(TestClass1)
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
|
||||||
def test_release_instance():
|
def test_release_instance(container):
|
||||||
container = DIContainer()
|
|
||||||
service_a_instance = ServiceA()
|
service_a_instance = ServiceA()
|
||||||
|
|
||||||
container.register_instance(IServiceA, service_a_instance)
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
|
||||||
container.release(IServiceA)
|
container.release(IServiceA)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
container.resolve(TestClass1)
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceC(metaclass=abc.ABCMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceC(IServiceC):
|
||||||
|
def __init__(self, service_a: IServiceA):
|
||||||
|
self.service_a = service_a
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass4:
|
||||||
|
def __init__(self, service_c: IServiceC):
|
||||||
|
self.service_c = service_c
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_resolution__depth_2(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register(IServiceC, ServiceC)
|
||||||
|
|
||||||
|
test4 = container.resolve(TestClass4)
|
||||||
|
|
||||||
|
assert isinstance(test4.service_c, ServiceC)
|
||||||
|
assert id(test4.service_c.service_a) == id(service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceD(metaclass=abc.ABCMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceD(IServiceD):
|
||||||
|
def __init__(self, service_c: IServiceC, service_b: IServiceB):
|
||||||
|
self.service_b = service_b
|
||||||
|
self.service_c = service_c
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass5:
|
||||||
|
def __init__(self, service_d: IServiceD):
|
||||||
|
self.service_d = service_d
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_resolution__depth_3(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
container.register(IServiceC, ServiceC)
|
||||||
|
container.register(IServiceD, ServiceD)
|
||||||
|
|
||||||
|
test5 = container.resolve(TestClass5)
|
||||||
|
|
||||||
|
assert isinstance(test5.service_d, ServiceD)
|
||||||
|
assert isinstance(test5.service_d.service_b, ServiceB)
|
||||||
|
assert isinstance(test5.service_d.service_c, ServiceC)
|
||||||
|
assert isinstance(test5.service_d.service_c.service_a, ServiceA)
|
||||||
|
|
Loading…
Reference in New Issue