diff --git a/monkey/infection_monkey/exploit/__init__.py b/monkey/infection_monkey/exploit/__init__.py index e69de29bb..42d8d18bf 100644 --- a/monkey/infection_monkey/exploit/__init__.py +++ b/monkey/infection_monkey/exploit/__init__.py @@ -0,0 +1 @@ +from .exploiter_wrapper import ExploiterWrapper diff --git a/monkey/infection_monkey/exploit/exploiter_wrapper.py b/monkey/infection_monkey/exploit/exploiter_wrapper.py new file mode 100644 index 000000000..444c89b31 --- /dev/null +++ b/monkey/infection_monkey/exploit/exploiter_wrapper.py @@ -0,0 +1,32 @@ +from typing import Dict, Type + +from infection_monkey.model import VictimHost +from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger + +from .HostExploiter import HostExploiter + + +class ExploiterWrapper: + """ + This class is a temporary measure to allow existing exploiters to play nicely within the + confines of the IPuppet interface. It keeps a reference to an ITelemetryMessenger that is passed + to all exploiters. Additionally, it constructs a new instance of the exploiter for each call to + exploit_host(). When exploiters are refactored into plugins, this class will likely go away. + """ + + class Inner: + def __init__( + self, exploit_class: Type[HostExploiter], telemetry_messenger: ITelemetryMessenger + ): + self._exploit_class = exploit_class + self._telemetry_messenger = telemetry_messenger + + def exploit_host(self, host: VictimHost, options: Dict): + exploiter = self._exploit_class() + return exploiter.exploit_host(host, self._telemetry_messenger, options) + + def __init__(self, telemetry_messenger: ITelemetryMessenger): + self._telemetry_messenger = telemetry_messenger + + def wrap(self, exploit_class: Type[HostExploiter]): + return ExploiterWrapper.Inner(exploit_class, self._telemetry_messenger) diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 087fa9959..fc52290bb 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -16,6 +16,7 @@ from infection_monkey.credential_collectors import ( MimikatzCredentialCollector, SSHCredentialCollector, ) +from infection_monkey.exploit import ExploiterWrapper from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.i_puppet import IPuppet, PluginType from infection_monkey.master import AutomatedMaster @@ -195,7 +196,7 @@ class InfectionMonkey: return local_network_interfaces def _build_puppet(self) -> IPuppet: - puppet = Puppet(self.telemetry_messenger) + puppet = Puppet() puppet.load_plugin( "MimikatzCollector", @@ -214,7 +215,13 @@ class InfectionMonkey: puppet.load_plugin("smb", SMBFingerprinter(), PluginType.FINGERPRINTER) puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER) - puppet.load_plugin("SSHExploiter", SSHExploiter(), PluginType.EXPLOITER) + exploit_wrapper = ExploiterWrapper(self.telemetry_messenger) + + puppet.load_plugin( + "SSHExploiter", + exploit_wrapper.wrap(SSHExploiter), + PluginType.EXPLOITER, + ) puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD) diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 1df3df885..e10695993 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -15,7 +15,6 @@ from infection_monkey.i_puppet import ( ) from infection_monkey.model import VictimHost -from ..telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from .mock_puppet import MockPuppet from .plugin_registry import PluginRegistry @@ -23,10 +22,9 @@ logger = logging.getLogger() class Puppet(IPuppet): - def __init__(self, telemetry_messenger: ITelemetryMessenger) -> None: + def __init__(self) -> None: self._mock_puppet = MockPuppet() self._plugin_registry = PluginRegistry() - self._telemetry_messenger = telemetry_messenger def load_plugin(self, plugin_name: str, plugin: object, plugin_type: PluginType) -> None: self._plugin_registry.load_plugin(plugin_name, plugin, plugin_type) @@ -63,7 +61,7 @@ class Puppet(IPuppet): self, name: str, host: VictimHost, options: Dict, interrupt: threading.Event ) -> ExploiterResultData: exploiter = self._plugin_registry.get_plugin(name, PluginType.EXPLOITER) - return exploiter.exploit_host(host, self._telemetry_messenger, options) + return exploiter.exploit_host(host, options) def run_payload(self, name: str, options: Dict, interrupt: threading.Event): payload = self._plugin_registry.get_plugin(name, PluginType.PAYLOAD) diff --git a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py index 54b9275ae..70c98d252 100644 --- a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py +++ b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py @@ -1,19 +1,12 @@ import threading from unittest.mock import MagicMock -import pytest - from infection_monkey.i_puppet import PluginType from infection_monkey.puppet.puppet import Puppet -@pytest.fixture -def mock_telemetry_messenger(): - return MagicMock() - - -def test_puppet_run_payload_success(monkeypatch, mock_telemetry_messenger): - p = Puppet(mock_telemetry_messenger) +def test_puppet_run_payload_success(): + p = Puppet() payload = MagicMock() payload_name = "PayloadOne" @@ -24,8 +17,8 @@ def test_puppet_run_payload_success(monkeypatch, mock_telemetry_messenger): payload.run.assert_called_once() -def test_puppet_run_multiple_payloads(monkeypatch, mock_telemetry_messenger): - p = Puppet(mock_telemetry_messenger) +def test_puppet_run_multiple_payloads(): + p = Puppet() payload_1 = MagicMock() payload1_name = "PayloadOne"