diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py index 2212a8eb2..0a735136e 100644 --- a/monkey/common/di_container.py +++ b/monkey/common/di_container.py @@ -21,6 +21,7 @@ class DIContainer: def __init__(self): self._type_registry = {} self._instance_registry = {} + self._convention_registry = {} def register(self, interface: Type[T], concrete_type: Type[T]): """ @@ -60,6 +61,34 @@ class DIContainer: self._instance_registry[interface] = instance DIContainer._del_key(self._type_registry, interface) + def register_convention(self, type_: Type[T], name: str, instance: T): + """ + Register an instance as a convention + + At times — particularly when dealing with primative types — it can be useful to define a + convention for how dependencies should be resolved. For example, you might want any class + that specifies `hostname: str` in its constructor to receive the hostname of the system it's + running on. Registering a convention allows you to assign an object instance to a type, name + pair. + + Example: + class TestClass: + def __init__(self, hostname: str): + self.hostname = hostname + + di_container = DIContainer() + di_container.register_convention(str, "hostname", "my_hostname.domain") + + test = di_container.resolve(TestClass) + assert test.hostname == "my_hostname.domain" + + :param **type_**: The `type` (class) of the dependency + :param name: The name of the dependency parameter + :param instance: An instance (object) of `type_` that will be injected into constructors + that specify `[name]: [type_]` as parameters + """ + self._convention_registry[(type_, name)] = instance + def resolve(self, type_: Type[T]) -> T: """ Resolves all dependencies and returns a new instance of `type_` using constructor dependency @@ -88,12 +117,25 @@ class DIContainer: """ args = [] - for arg_type in inspect.getfullargspec(type_).annotations.values(): - instance = self._resolve_type(arg_type) + for arg_name, arg_type in inspect.getfullargspec(type_).annotations.items(): + try: + instance = self._resolve_convention(arg_type, arg_name) + except UnregisteredConventionError: + instance = self._resolve_type(arg_type) + args.append(instance) return tuple(args) + def _resolve_convention(self, type_: Type[T], name: str) -> T: + convention_identifier = (type_, name) + try: + return self._convention_registry[convention_identifier] + except KeyError: + raise UnregisteredConventionError( + f"Failed to resolve unregistered convention {convention_identifier}" + ) + def _resolve_type(self, type_: Type[T]) -> T: if type_ in self._type_registry: return self._construct_new_instance(type_) diff --git a/monkey/tests/unit_tests/common/test_di_container.py b/monkey/tests/unit_tests/common/test_di_container.py index f7ea5534c..81e0220b4 100644 --- a/monkey/tests/unit_tests/common/test_di_container.py +++ b/monkey/tests/unit_tests/common/test_di_container.py @@ -281,3 +281,88 @@ def test_register_instance_with_conflicting_type(container): service_b_instance = ServiceB() with pytest.raises(TypeError): container.register_instance(IServiceA, service_b_instance) + + +class TestClass6: + __test__ = False + + def __init__(self, my_str: str): + self.my_str = my_str + + +def test_register_convention(container): + my_str = "test_string" + container.register_convention(str, "my_str", my_str) + + test_6 = container.resolve(TestClass6) + + assert test_6.my_str == my_str + + +class TestClass7: + __test__ = False + + def __init__(self, my_str1: str, my_str2: str): + self.my_str1 = my_str1 + self.my_str2 = my_str2 + + +def test_register_convention__multiple_parameters_same_type(container): + my_str1 = "s1" + my_str2 = "s2" + container.register_convention(str, "my_str2", my_str2) + container.register_convention(str, "my_str1", my_str1) + + test_7 = container.resolve(TestClass7) + assert test_7.my_str1 == my_str1 + assert test_7.my_str2 == my_str2 + + +class TestClass8: + __test__ = False + + def __init__(self, my_str: str, my_int: int): + self.my_str = my_str + self.my_int = my_int + + +def test_register_convention__multiple_parameters_different_types(container): + my_str = "test_string" + my_int = 42 + container.register_convention(str, "my_str", my_str) + container.register_convention(int, "my_int", my_int) + + test_8 = container.resolve(TestClass8) + assert test_8.my_str == my_str + assert test_8.my_int == my_int + + +class TestClass9: + __test__ = False + + def __init__(self, service_a: IServiceA, my_str: str): + self.service_a = service_a + self.my_str = my_str + + +def test_register_convention__type_properly_resolved(container): + my_str = "test_string" + + container.register(IServiceA, ServiceA) + container.register_convention(str, "my_str", my_str) + test_9 = container.resolve(TestClass9) + + assert isinstance(test_9.service_a, ServiceA) + assert test_9.my_str == my_str + + +def test_register_convention__instance_properly_resolved(container): + service_a_instance = ServiceA() + my_str = "test_string" + + container.register_instance(IServiceA, service_a_instance) + container.register_convention(str, "my_str", my_str) + test_9 = container.resolve(TestClass9) + + assert id(service_a_instance) == id(test_9.service_a) + assert test_9.my_str == my_str