diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 0fac17a7e..132898536 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -21,7 +21,7 @@ class Puppet(IPuppet): self._plugin_registry = PluginRegistry() def load_plugin(self, plugin_name: str, plugin: object, plugin_type: PluginType) -> None: - self._plugin_registry.load_plugin(plugin, plugin_name, plugin_type) + self._plugin_registry.load_plugin(plugin_name, plugin, plugin_type) def run_sys_info_collector(self, name: str) -> Dict: pass @@ -50,7 +50,8 @@ class Puppet(IPuppet): pass def run_payload(self, name: str, options: Dict, interrupt: threading.Event): - pass + payload = self._plugin_registry.get_plugin(name, PluginType.PAYLOAD) + payload.run(options, interrupt) def cleanup(self) -> None: pass diff --git a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py new file mode 100644 index 000000000..c0c4a1f19 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py @@ -0,0 +1,43 @@ +import threading +from unittest.mock import MagicMock + +from infection_monkey.puppet.plugin_type import PluginType +from infection_monkey.puppet.puppet import Puppet + + +def test_puppet_run_payload_success(monkeypatch): + p = Puppet() + + payload = MagicMock() + payload_name = "PayloadOne" + + p.load_plugin(payload_name, payload, PluginType.PAYLOAD) + p.run_payload(payload_name, {}, threading.Event()) + + payload.run.assert_called_once() + + +def test_puppet_run_multiple_payloads(monkeypatch): + p = Puppet() + + payload_1 = MagicMock() + payload1_name = "PayloadOne" + + payload_2 = MagicMock() + payload2_name = "PayloadTwo" + + payload_3 = MagicMock() + payload3_name = "PayloadThree" + + p.load_plugin(payload1_name, payload_1, PluginType.PAYLOAD) + p.load_plugin(payload2_name, payload_2, PluginType.PAYLOAD) + p.load_plugin(payload3_name, payload_3, PluginType.PAYLOAD) + + p.run_payload(payload1_name, {}, threading.Event()) + payload_1.run.assert_called_once() + + p.run_payload(payload2_name, {}, threading.Event()) + payload_2.run.assert_called_once() + + p.run_payload(payload3_name, {}, threading.Event()) + payload_3.run.assert_called_once()