diff --git a/monkey/infection_monkey/puppet/plugin_registry.py b/monkey/infection_monkey/puppet/plugin_registry.py new file mode 100644 index 000000000..d86a7491e --- /dev/null +++ b/monkey/infection_monkey/puppet/plugin_registry.py @@ -0,0 +1,42 @@ +import logging +from typing import Optional + +from infection_monkey.i_puppet import UnknownPluginError +from infection_monkey.puppet.plugin_type import PluginType + +logger = logging.getLogger() + + +class PluginRegistry: + def __init__(self): + """ + `self._registry` looks like - + { + PluginType.EXPLOITER: { + "ZerologonExploiter": ZerologonExploiter, + "SMBExploiter": SMBExploiter + }, + PluginType.PBA: { + "CommunicateAsBackdoorUser": CommunicateAsBackdoorUser + } + } + """ + self._registry = {} + + def load_plugin(self, plugin: object, plugin_type: PluginType) -> None: + self._registry.setdefault(plugin_type, {}) + self._registry[plugin_type][plugin.__class__.__name__] = plugin + + logger.debug(f"Plugin '{plugin.__class__.__name__}' loaded") + + def get_plugin(self, plugin_name: str, plugin_type: PluginType) -> Optional[object]: + try: + plugin = self._registry[plugin_type][plugin_name] + except KeyError: + raise UnknownPluginError( + f"Unknown plugin '{plugin_name}' of type '{plugin_type.value}'" + ) + + logger.debug(f"Plugin '{plugin_name}' found") + + return plugin diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 5563a2dfe..0c36435e0 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -10,14 +10,18 @@ from infection_monkey.i_puppet import ( PortScanData, PostBreachData, ) +from infection_monkey.puppet.plugin_registry import PluginRegistry from infection_monkey.puppet.plugin_type import PluginType logger = logging.getLogger() class Puppet(IPuppet): + def __init__(self) -> None: + self._plugin_registry = PluginRegistry() + def load_plugin(self, plugin: object, plugin_type: PluginType) -> None: - pass + self._plugin_registry.load_plugin(plugin, plugin_type) def run_sys_info_collector(self, name: str) -> Dict: pass