forked from p15670423/monkey
Merge pull request #2027 from guardicore/register-di-conventions
Register di conventions
This commit is contained in:
commit
e8001d8cf7
|
@ -8,6 +8,10 @@ class UnregisteredTypeError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnregisteredConventionError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DIContainer:
|
class DIContainer:
|
||||||
"""
|
"""
|
||||||
A dependency injection (DI) container that uses type annotations to resolve and inject
|
A dependency injection (DI) container that uses type annotations to resolve and inject
|
||||||
|
@ -17,6 +21,7 @@ class DIContainer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._type_registry = {}
|
self._type_registry = {}
|
||||||
self._instance_registry = {}
|
self._instance_registry = {}
|
||||||
|
self._convention_registry = {}
|
||||||
|
|
||||||
def register(self, interface: Type[T], concrete_type: Type[T]):
|
def register(self, interface: Type[T], concrete_type: Type[T]):
|
||||||
"""
|
"""
|
||||||
|
@ -56,12 +61,45 @@ class DIContainer:
|
||||||
self._instance_registry[interface] = instance
|
self._instance_registry[interface] = instance
|
||||||
DIContainer._del_key(self._type_registry, interface)
|
DIContainer._del_key(self._type_registry, interface)
|
||||||
|
|
||||||
|
def register_convention(self, type_: Type[T], name: str, instance: T):
|
||||||
|
"""
|
||||||
|
Register an instance as a convention
|
||||||
|
|
||||||
|
At times — particularly when dealing with primitive types — it can be useful to define a
|
||||||
|
convention for how dependencies should be resolved. For example, you might want any class
|
||||||
|
that specifies `hostname: str` in its constructor to receive the hostname of the system it's
|
||||||
|
running on. Registering a convention allows you to assign an object instance to a type, name
|
||||||
|
pair.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class TestClass:
|
||||||
|
def __init__(self, hostname: str):
|
||||||
|
self.hostname = hostname
|
||||||
|
|
||||||
|
di_container = DIContainer()
|
||||||
|
di_container.register_convention(str, "hostname", "my_hostname.domain")
|
||||||
|
|
||||||
|
test = di_container.resolve(TestClass)
|
||||||
|
assert test.hostname == "my_hostname.domain"
|
||||||
|
|
||||||
|
:param **type_**: The `type` (class) of the dependency
|
||||||
|
:param name: The name of the dependency parameter
|
||||||
|
:param instance: An instance (object) of `type_` that will be injected into constructors
|
||||||
|
that specify `[name]: [type_]` as parameters
|
||||||
|
"""
|
||||||
|
self._convention_registry[(type_, name)] = instance
|
||||||
|
|
||||||
def resolve(self, type_: Type[T]) -> T:
|
def resolve(self, type_: Type[T]) -> T:
|
||||||
"""
|
"""
|
||||||
Resolves all dependencies and returns a new instance of `type_` using constructor dependency
|
Resolves all dependencies and returns a new instance of `type_` using constructor dependency
|
||||||
injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and
|
injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and
|
||||||
default values are ignored.
|
default values are ignored.
|
||||||
|
|
||||||
|
Dependencies are resolved with the following precedence
|
||||||
|
|
||||||
|
1. Conventions
|
||||||
|
2. Types, Instances
|
||||||
|
|
||||||
:param **type_**: A `type` (class) to construct
|
:param **type_**: A `type` (class) to construct
|
||||||
:return: An instance of **type_**
|
:return: An instance of **type_**
|
||||||
"""
|
"""
|
||||||
|
@ -79,17 +117,32 @@ class DIContainer:
|
||||||
that correspond `type_`'s dependencies. Note that only positional
|
that correspond `type_`'s dependencies. Note that only positional
|
||||||
arguments are resolved. Varargs, keyword-only args, and default values are ignored.
|
arguments are resolved. Varargs, keyword-only args, and default values are ignored.
|
||||||
|
|
||||||
|
See resolve() for information about dependency resolution precedence.
|
||||||
|
|
||||||
:param **type_**: A type (class) to resolve dependencies for
|
:param **type_**: A type (class) to resolve dependencies for
|
||||||
:return: An Sequence of dependencies to be injected into `type_`'s constructor
|
:return: An Sequence of dependencies to be injected into `type_`'s constructor
|
||||||
"""
|
"""
|
||||||
args = []
|
args = []
|
||||||
|
|
||||||
for arg_type in inspect.getfullargspec(type_).annotations.values():
|
for arg_name, arg_type in inspect.getfullargspec(type_).annotations.items():
|
||||||
instance = self._resolve_type(arg_type)
|
try:
|
||||||
|
instance = self._resolve_convention(arg_type, arg_name)
|
||||||
|
except UnregisteredConventionError:
|
||||||
|
instance = self._resolve_type(arg_type)
|
||||||
|
|
||||||
args.append(instance)
|
args.append(instance)
|
||||||
|
|
||||||
return tuple(args)
|
return tuple(args)
|
||||||
|
|
||||||
|
def _resolve_convention(self, type_: Type[T], name: str) -> T:
|
||||||
|
convention_identifier = (type_, name)
|
||||||
|
try:
|
||||||
|
return self._convention_registry[convention_identifier]
|
||||||
|
except KeyError:
|
||||||
|
raise UnregisteredConventionError(
|
||||||
|
f"Failed to resolve unregistered convention {convention_identifier}"
|
||||||
|
)
|
||||||
|
|
||||||
def _resolve_type(self, type_: Type[T]) -> T:
|
def _resolve_type(self, type_: Type[T]) -> T:
|
||||||
if type_ in self._type_registry:
|
if type_ in self._type_registry:
|
||||||
return self._construct_new_instance(type_)
|
return self._construct_new_instance(type_)
|
||||||
|
@ -111,20 +164,30 @@ class DIContainer:
|
||||||
|
|
||||||
def release(self, interface: Type[T]):
|
def release(self, interface: Type[T]):
|
||||||
"""
|
"""
|
||||||
Deregister's an interface
|
Deregister an interface
|
||||||
|
|
||||||
:param interface: The interface to release
|
:param interface: The interface to release
|
||||||
"""
|
"""
|
||||||
DIContainer._del_key(self._type_registry, interface)
|
DIContainer._del_key(self._type_registry, interface)
|
||||||
DIContainer._del_key(self._instance_registry, interface)
|
DIContainer._del_key(self._instance_registry, interface)
|
||||||
|
|
||||||
|
def release_convention(self, type_: Type[T], name: str):
|
||||||
|
"""
|
||||||
|
Deregister a convention
|
||||||
|
|
||||||
|
:param **type_**: The `type` (class) of the dependency
|
||||||
|
:param name: The name of the dependency parameter
|
||||||
|
"""
|
||||||
|
convention_identifier = (type_, name)
|
||||||
|
DIContainer._del_key(self._convention_registry, convention_identifier)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _del_key(mapping: MutableMapping[T, Any], key: T):
|
def _del_key(mapping: MutableMapping[T, Any], key: T):
|
||||||
"""
|
"""
|
||||||
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
|
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
|
||||||
if the key does not exist.
|
if the key does not exist.
|
||||||
|
|
||||||
:param MutableMapping: A mapping from which a key will be deleted
|
:param mapping: A mapping from which a key will be deleted
|
||||||
:param key: A key to delete from `mapping`
|
:param key: A key to delete from `mapping`
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -27,10 +27,7 @@ AGENT_BINARIES_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "binaries"
|
||||||
def initialize_services(data_dir: Path) -> DIContainer:
|
def initialize_services(data_dir: Path) -> DIContainer:
|
||||||
container = DIContainer()
|
container = DIContainer()
|
||||||
|
|
||||||
# TODO: everything that is build with DI and expects Path in the constructor
|
container.register_convention(Path, "data_dir", data_dir)
|
||||||
# will use the same data_dir. Come up with a better way to inject
|
|
||||||
# the data_dir in the things that needed
|
|
||||||
container.register_instance(Path, data_dir)
|
|
||||||
container.register_instance(AWSInstance, AWSInstance())
|
container.register_instance(AWSInstance, AWSInstance())
|
||||||
|
|
||||||
container.register_instance(
|
container.register_instance(
|
||||||
|
|
|
@ -112,8 +112,7 @@ def test_register_mixed_instance_and_type(container):
|
||||||
assert id(test_2.service_b) != id(test_3.service_b)
|
assert id(test_2.service_b) != id(test_3.service_b)
|
||||||
|
|
||||||
|
|
||||||
def test_unregistered_type():
|
def test_unregistered_type(container):
|
||||||
container = DIContainer()
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
container.resolve(TestClass1)
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
@ -281,3 +280,97 @@ def test_register_instance_with_conflicting_type(container):
|
||||||
service_b_instance = ServiceB()
|
service_b_instance = ServiceB()
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
container.register_instance(IServiceA, service_b_instance)
|
container.register_instance(IServiceA, service_b_instance)
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass6:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, my_str: str):
|
||||||
|
self.my_str = my_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_convention(container):
|
||||||
|
my_str = "test_string"
|
||||||
|
container.register_convention(str, "my_str", my_str)
|
||||||
|
|
||||||
|
test_6 = container.resolve(TestClass6)
|
||||||
|
|
||||||
|
assert test_6.my_str == my_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass7:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, my_str1: str, my_str2: str):
|
||||||
|
self.my_str1 = my_str1
|
||||||
|
self.my_str2 = my_str2
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_convention__multiple_parameters_same_type(container):
|
||||||
|
my_str1 = "s1"
|
||||||
|
my_str2 = "s2"
|
||||||
|
container.register_convention(str, "my_str2", my_str2)
|
||||||
|
container.register_convention(str, "my_str1", my_str1)
|
||||||
|
|
||||||
|
test_7 = container.resolve(TestClass7)
|
||||||
|
assert test_7.my_str1 == my_str1
|
||||||
|
assert test_7.my_str2 == my_str2
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass8:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, my_str: str, my_int: int):
|
||||||
|
self.my_str = my_str
|
||||||
|
self.my_int = my_int
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_convention__multiple_parameters_different_types(container):
|
||||||
|
my_str = "test_string"
|
||||||
|
my_int = 42
|
||||||
|
container.register_convention(str, "my_str", my_str)
|
||||||
|
container.register_convention(int, "my_int", my_int)
|
||||||
|
|
||||||
|
test_8 = container.resolve(TestClass8)
|
||||||
|
assert test_8.my_str == my_str
|
||||||
|
assert test_8.my_int == my_int
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass9:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_a: IServiceA, my_str: str):
|
||||||
|
self.service_a = service_a
|
||||||
|
self.my_str = my_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_convention__type_properly_resolved(container):
|
||||||
|
my_str = "test_string"
|
||||||
|
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register_convention(str, "my_str", my_str)
|
||||||
|
test_9 = container.resolve(TestClass9)
|
||||||
|
|
||||||
|
assert isinstance(test_9.service_a, ServiceA)
|
||||||
|
assert test_9.my_str == my_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_convention__instance_properly_resolved(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
my_str = "test_string"
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register_convention(str, "my_str", my_str)
|
||||||
|
test_9 = container.resolve(TestClass9)
|
||||||
|
|
||||||
|
assert id(service_a_instance) == id(test_9.service_a)
|
||||||
|
assert test_9.my_str == my_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_release_convention(container):
|
||||||
|
my_str = "test_string"
|
||||||
|
container.register_convention(str, "my_str", my_str)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
container.release_convention(str, "my_str")
|
||||||
|
container.resolve(TestClass6)
|
||||||
|
|
|
@ -182,6 +182,7 @@ GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
|
||||||
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
|
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
|
||||||
|
|
||||||
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)
|
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)
|
||||||
|
release_convention # unused method (monkey/common/di_container.py:174)
|
||||||
|
|
||||||
# TODO DELETE AFTER RESOURCE REFACTORING
|
# TODO DELETE AFTER RESOURCE REFACTORING
|
||||||
NetworkMap
|
NetworkMap
|
||||||
|
|
Loading…
Reference in New Issue