diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py index 1f2632789..0a9748f4b 100644 --- a/monkey/common/di_container.py +++ b/monkey/common/di_container.py @@ -8,6 +8,10 @@ class UnregisteredTypeError(ValueError): pass +class UnregisteredConventionError(ValueError): + pass + + class DIContainer: """ A dependency injection (DI) container that uses type annotations to resolve and inject @@ -17,6 +21,7 @@ class DIContainer: def __init__(self): self._type_registry = {} self._instance_registry = {} + self._convention_registry = {} def register(self, interface: Type[T], concrete_type: Type[T]): """ @@ -56,12 +61,45 @@ class DIContainer: self._instance_registry[interface] = instance DIContainer._del_key(self._type_registry, interface) + def register_convention(self, type_: Type[T], name: str, instance: T): + """ + Register an instance as a convention + + At times — particularly when dealing with primitive types — it can be useful to define a + convention for how dependencies should be resolved. For example, you might want any class + that specifies `hostname: str` in its constructor to receive the hostname of the system it's + running on. Registering a convention allows you to assign an object instance to a type, name + pair. + + Example: + class TestClass: + def __init__(self, hostname: str): + self.hostname = hostname + + di_container = DIContainer() + di_container.register_convention(str, "hostname", "my_hostname.domain") + + test = di_container.resolve(TestClass) + assert test.hostname == "my_hostname.domain" + + :param **type_**: The `type` (class) of the dependency + :param name: The name of the dependency parameter + :param instance: An instance (object) of `type_` that will be injected into constructors + that specify `[name]: [type_]` as parameters + """ + self._convention_registry[(type_, name)] = instance + def resolve(self, type_: Type[T]) -> T: """ Resolves all dependencies and returns a new instance of `type_` using constructor dependency injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and default values are ignored. + Dependencies are resolved with the following precedence + + 1. Conventions + 2. Types, Instances + :param **type_**: A `type` (class) to construct :return: An instance of **type_** """ @@ -79,17 +117,32 @@ class DIContainer: that correspond `type_`'s dependencies. Note that only positional arguments are resolved. Varargs, keyword-only args, and default values are ignored. + See resolve() for information about dependency resolution precedence. + :param **type_**: A type (class) to resolve dependencies for :return: An Sequence of dependencies to be injected into `type_`'s constructor """ args = [] - for arg_type in inspect.getfullargspec(type_).annotations.values(): - instance = self._resolve_type(arg_type) + for arg_name, arg_type in inspect.getfullargspec(type_).annotations.items(): + try: + instance = self._resolve_convention(arg_type, arg_name) + except UnregisteredConventionError: + instance = self._resolve_type(arg_type) + args.append(instance) return tuple(args) + def _resolve_convention(self, type_: Type[T], name: str) -> T: + convention_identifier = (type_, name) + try: + return self._convention_registry[convention_identifier] + except KeyError: + raise UnregisteredConventionError( + f"Failed to resolve unregistered convention {convention_identifier}" + ) + def _resolve_type(self, type_: Type[T]) -> T: if type_ in self._type_registry: return self._construct_new_instance(type_) @@ -111,20 +164,30 @@ class DIContainer: def release(self, interface: Type[T]): """ - Deregister's an interface + Deregister an interface :param interface: The interface to release """ DIContainer._del_key(self._type_registry, interface) DIContainer._del_key(self._instance_registry, interface) + def release_convention(self, type_: Type[T], name: str): + """ + Deregister a convention + + :param **type_**: The `type` (class) of the dependency + :param name: The name of the dependency parameter + """ + convention_identifier = (type_, name) + DIContainer._del_key(self._convention_registry, convention_identifier) + @staticmethod def _del_key(mapping: MutableMapping[T, Any], key: T): """ Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError if the key does not exist. - :param MutableMapping: A mapping from which a key will be deleted + :param mapping: A mapping from which a key will be deleted :param key: A key to delete from `mapping` """ try: diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index 8722db760..a5d283cb3 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -27,10 +27,7 @@ AGENT_BINARIES_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "binaries" def initialize_services(data_dir: Path) -> DIContainer: container = DIContainer() - # TODO: everything that is build with DI and expects Path in the constructor - # will use the same data_dir. Come up with a better way to inject - # the data_dir in the things that needed - container.register_instance(Path, data_dir) + container.register_convention(Path, "data_dir", data_dir) container.register_instance(AWSInstance, AWSInstance()) container.register_instance( diff --git a/monkey/tests/unit_tests/common/test_di_container.py b/monkey/tests/unit_tests/common/test_di_container.py index f7ea5534c..857168f82 100644 --- a/monkey/tests/unit_tests/common/test_di_container.py +++ b/monkey/tests/unit_tests/common/test_di_container.py @@ -112,8 +112,7 @@ def test_register_mixed_instance_and_type(container): assert id(test_2.service_b) != id(test_3.service_b) -def test_unregistered_type(): - container = DIContainer() +def test_unregistered_type(container): with pytest.raises(ValueError): container.resolve(TestClass1) @@ -281,3 +280,97 @@ def test_register_instance_with_conflicting_type(container): service_b_instance = ServiceB() with pytest.raises(TypeError): container.register_instance(IServiceA, service_b_instance) + + +class TestClass6: + __test__ = False + + def __init__(self, my_str: str): + self.my_str = my_str + + +def test_register_convention(container): + my_str = "test_string" + container.register_convention(str, "my_str", my_str) + + test_6 = container.resolve(TestClass6) + + assert test_6.my_str == my_str + + +class TestClass7: + __test__ = False + + def __init__(self, my_str1: str, my_str2: str): + self.my_str1 = my_str1 + self.my_str2 = my_str2 + + +def test_register_convention__multiple_parameters_same_type(container): + my_str1 = "s1" + my_str2 = "s2" + container.register_convention(str, "my_str2", my_str2) + container.register_convention(str, "my_str1", my_str1) + + test_7 = container.resolve(TestClass7) + assert test_7.my_str1 == my_str1 + assert test_7.my_str2 == my_str2 + + +class TestClass8: + __test__ = False + + def __init__(self, my_str: str, my_int: int): + self.my_str = my_str + self.my_int = my_int + + +def test_register_convention__multiple_parameters_different_types(container): + my_str = "test_string" + my_int = 42 + container.register_convention(str, "my_str", my_str) + container.register_convention(int, "my_int", my_int) + + test_8 = container.resolve(TestClass8) + assert test_8.my_str == my_str + assert test_8.my_int == my_int + + +class TestClass9: + __test__ = False + + def __init__(self, service_a: IServiceA, my_str: str): + self.service_a = service_a + self.my_str = my_str + + +def test_register_convention__type_properly_resolved(container): + my_str = "test_string" + + container.register(IServiceA, ServiceA) + container.register_convention(str, "my_str", my_str) + test_9 = container.resolve(TestClass9) + + assert isinstance(test_9.service_a, ServiceA) + assert test_9.my_str == my_str + + +def test_register_convention__instance_properly_resolved(container): + service_a_instance = ServiceA() + my_str = "test_string" + + container.register_instance(IServiceA, service_a_instance) + container.register_convention(str, "my_str", my_str) + test_9 = container.resolve(TestClass9) + + assert id(service_a_instance) == id(test_9.service_a) + assert test_9.my_str == my_str + + +def test_release_convention(container): + my_str = "test_string" + container.register_convention(str, "my_str", my_str) + + with pytest.raises(ValueError): + container.release_convention(str, "my_str") + container.resolve(TestClass6) diff --git a/vulture_allowlist.py b/vulture_allowlist.py index ab09127b3..d451f367d 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -182,6 +182,7 @@ GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57) architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25) response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26) +release_convention # unused method (monkey/common/di_container.py:174) # TODO DELETE AFTER RESOURCE REFACTORING NetworkMap