Agent: Implement Puppet.run_payload()

This commit is contained in:
Ilija Lazoroski 2021-12-16 16:28:46 +01:00 committed by Mike Salvatore
parent b798255249
commit 0a4ff25843
2 changed files with 46 additions and 2 deletions

View File

@ -21,7 +21,7 @@ class Puppet(IPuppet):
self._plugin_registry = PluginRegistry() self._plugin_registry = PluginRegistry()
def load_plugin(self, plugin_name: str, plugin: object, plugin_type: PluginType) -> None: def load_plugin(self, plugin_name: str, plugin: object, plugin_type: PluginType) -> None:
self._plugin_registry.load_plugin(plugin, plugin_name, plugin_type) self._plugin_registry.load_plugin(plugin_name, plugin, plugin_type)
def run_sys_info_collector(self, name: str) -> Dict: def run_sys_info_collector(self, name: str) -> Dict:
pass pass
@ -50,7 +50,8 @@ class Puppet(IPuppet):
pass pass
def run_payload(self, name: str, options: Dict, interrupt: threading.Event): def run_payload(self, name: str, options: Dict, interrupt: threading.Event):
pass payload = self._plugin_registry.get_plugin(name, PluginType.PAYLOAD)
payload.run(options, interrupt)
def cleanup(self) -> None: def cleanup(self) -> None:
pass pass

View File

@ -0,0 +1,43 @@
import threading
from unittest.mock import MagicMock
from infection_monkey.puppet.plugin_type import PluginType
from infection_monkey.puppet.puppet import Puppet
def test_puppet_run_payload_success(monkeypatch):
p = Puppet()
payload = MagicMock()
payload_name = "PayloadOne"
p.load_plugin(payload_name, payload, PluginType.PAYLOAD)
p.run_payload(payload_name, {}, threading.Event())
payload.run.assert_called_once()
def test_puppet_run_multiple_payloads(monkeypatch):
p = Puppet()
payload_1 = MagicMock()
payload1_name = "PayloadOne"
payload_2 = MagicMock()
payload2_name = "PayloadTwo"
payload_3 = MagicMock()
payload3_name = "PayloadThree"
p.load_plugin(payload1_name, payload_1, PluginType.PAYLOAD)
p.load_plugin(payload2_name, payload_2, PluginType.PAYLOAD)
p.load_plugin(payload3_name, payload_3, PluginType.PAYLOAD)
p.run_payload(payload1_name, {}, threading.Event())
payload_1.run.assert_called_once()
p.run_payload(payload2_name, {}, threading.Event())
payload_2.run.assert_called_once()
p.run_payload(payload3_name, {}, threading.Event())
payload_3.run.assert_called_once()