diff --git a/monkey/common/__init__.py b/monkey/common/__init__.py index e69de29bb..5f2cba2f1 100644 --- a/monkey/common/__init__.py +++ b/monkey/common/__init__.py @@ -0,0 +1 @@ +from .di_container import DIContainer diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py new file mode 100644 index 000000000..1f5a2e03c --- /dev/null +++ b/monkey/common/di_container.py @@ -0,0 +1,83 @@ +import inspect +from typing import Any, MutableMapping, Type, TypeVar + +T = TypeVar("T") + + +class DIContainer: + """ + A dependency injection (DI) container that uses type annotations to resolve and inject + dependencies. + """ + + def __init__(self): + self._type_registry = {} + self._instance_registry = {} + + def register(self, interface: Type[T], concrete_type: Type[T]): + """ + Register a concrete type that satisfies a given interface. + + :param interface: An interface or abstract base class that other classes depend upon + :param concrete_type: A type (class) that implements `interface` + """ + self._type_registry[interface] = concrete_type + DIContainer._del_key(self._instance_registry, interface) + + def register_instance(self, interface: Type[T], instance: T): + """ + Register a concrete instance that satisfies a given interface. + + :param interface: An interface or abstract base class that other classes depend upon + :param instance: An instance (object) of a type that implements `interface` + """ + self._instance_registry[interface] = instance + DIContainer._del_key(self._type_registry, interface) + + def resolve(self, type_: Type[T]) -> T: + """ + Resolves all dependencies and returns a new instance of `type_` using constructor dependency + injection. + + :param type_: A type (class) to construct. + :return: An instance of `type_` + """ + args = [] + + # TODO: Need to handle keyword-only arguments, defaults, varars, etc. + for arg_type in inspect.getfullargspec(type_).annotations.values(): + new_instance = self._resolve_instance(arg_type) + args.append(new_instance) + + return type_(*args) + + def _resolve_instance(self, arg_type: Type): + if arg_type in self._type_registry: + return self._type_registry[arg_type]() + elif arg_type in self._instance_registry: + return self._instance_registry[arg_type] + + raise ValueError(f'Failed to resolve unknown type "{arg_type.__name__}"') + + def release(self, interface: Type[T]): + """ + Deregister's an interface + + :param interface: The interface to release + """ + DIContainer._del_key(self._type_registry, interface) + DIContainer._del_key(self._instance_registry, interface) + + @staticmethod + def _del_key(mapping: MutableMapping[T, Any], key: T): + """ + Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError + if the key does not exist. + + :param MutableMapping: A mapping from which a key will be deleted + :param key: A key to delete from `mapping` + """ + try: + del mapping[key] + except KeyError: + pass diff --git a/monkey/tests/unit_tests/common/test_di_container.py b/monkey/tests/unit_tests/common/test_di_container.py new file mode 100644 index 000000000..d3b6abf57 --- /dev/null +++ b/monkey/tests/unit_tests/common/test_di_container.py @@ -0,0 +1,178 @@ +import pytest + +from common import DIContainer + + +class IServiceA: + pass + + +class IServiceB: + pass + + +class ServiceA(IServiceA): + pass + + +class ServiceB(IServiceB): + pass + + +class TestClass1: + __test__ = False + + def __init__(self, service_a: IServiceA): + self.service_a = service_a + + +class TestClass2: + __test__ = False + + def __init__(self, service_b: IServiceB): + self.service_b = service_b + + +class TestClass3: + __test__ = False + + def __init__(self, service_a: IServiceA, service_b: IServiceB): + self.service_a = service_a + self.service_b = service_b + + +@pytest.fixture +def container(): + return DIContainer() + + +def test_register_resolve(container): + container.register(IServiceA, ServiceA) + test_1 = container.resolve(TestClass1) + + assert isinstance(test_1.service_a, ServiceA) + + +def test_correct_instance_type_injected(container): + container.register(IServiceA, ServiceA) + container.register(IServiceB, ServiceB) + test_1 = container.resolve(TestClass1) + test_2 = container.resolve(TestClass2) + + assert isinstance(test_1.service_a, ServiceA) + assert isinstance(test_2.service_b, ServiceB) + + +def test_multiple_correct_instance_types_injected(container): + container.register(IServiceA, ServiceA) + container.register(IServiceB, ServiceB) + test_3 = container.resolve(TestClass3) + + assert isinstance(test_3.service_a, ServiceA) + assert isinstance(test_3.service_b, ServiceB) + + +def test_register_instance(container): + service_a_instance = ServiceA() + + container.register_instance(IServiceA, service_a_instance) + test_1 = container.resolve(TestClass1) + + assert id(service_a_instance) == id(test_1.service_a) + + +def test_register_multiple_instances(container): + service_a_instance = ServiceA() + service_b_instance = ServiceB() + + container.register_instance(IServiceA, service_a_instance) + container.register_instance(IServiceB, service_b_instance) + test_3 = container.resolve(TestClass3) + + assert id(service_a_instance) == id(test_3.service_a) + assert id(service_b_instance) == id(test_3.service_b) + + +def test_register_mixed_instance_and_type(container): + service_a_instance = ServiceA() + + container.register_instance(IServiceA, service_a_instance) + container.register(IServiceB, ServiceB) + test_2 = container.resolve(TestClass2) + test_3 = container.resolve(TestClass3) + + assert id(service_a_instance) == id(test_3.service_a) + assert isinstance(test_2.service_b, ServiceB) + assert isinstance(test_3.service_b, ServiceB) + assert id(test_2.service_b) != id(test_3.service_b) + + +def test_unregistered_type(): + container = DIContainer() + with pytest.raises(ValueError): + container.resolve(TestClass1) + + +def test_type_registration_overwritten(container): + class ServiceA2(IServiceA): + pass + + container.register(IServiceA, ServiceA) + container.register(IServiceA, ServiceA2) + test_1 = container.resolve(TestClass1) + + assert isinstance(test_1.service_a, ServiceA2) + + +def test_instance_registration_overwritten(container): + service_a_instance_1 = ServiceA() + service_a_instance_2 = ServiceA() + + container.register_instance(IServiceA, service_a_instance_1) + container.register_instance(IServiceA, service_a_instance_2) + test_1 = container.resolve(TestClass1) + + assert id(test_1.service_a) != id(service_a_instance_1) + assert id(test_1.service_a) == id(service_a_instance_2) + + +def test_type_overrides_instance(container): + service_a_instance = ServiceA() + + container.register_instance(IServiceA, service_a_instance) + container.register(IServiceA, ServiceA) + test_1 = container.resolve(TestClass1) + + assert id(test_1.service_a) != id(service_a_instance) + assert isinstance(test_1.service_a, ServiceA) + + +def test_instance_overrides_type(container): + service_a_instance = ServiceA() + + container.register(IServiceA, ServiceA) + container.register_instance(IServiceA, service_a_instance) + test_1 = container.resolve(TestClass1) + + assert id(test_1.service_a) == id(service_a_instance) + + +def test_release_type(): + container = DIContainer() + + container.register(IServiceA, ServiceA) + container.release(IServiceA) + + with pytest.raises(ValueError): + container.resolve(TestClass1) + + +def test_release_instance(): + container = DIContainer() + service_a_instance = ServiceA() + + container.register_instance(IServiceA, service_a_instance) + container.release(IServiceA) + + with pytest.raises(ValueError): + container.resolve(TestClass1)