From d1a4018d5fc77d0feb33162a2859aa1b37e1aed2 Mon Sep 17 00:00:00 2001 From: vakarisz Date: Wed, 16 Mar 2022 10:33:46 +0200 Subject: [PATCH] Agent: Pass interrupt event to HostExploiter --- monkey/infection_monkey/exploit/HostExploiter.py | 9 +++++++++ monkey/infection_monkey/exploit/exploiter_wrapper.py | 12 ++++++++++-- monkey/infection_monkey/puppet/plugin_registry.py | 3 ++- monkey/infection_monkey/puppet/puppet.py | 2 +- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/monkey/infection_monkey/exploit/HostExploiter.py b/monkey/infection_monkey/exploit/HostExploiter.py index 3bda4b0d7..d8afcf97b 100644 --- a/monkey/infection_monkey/exploit/HostExploiter.py +++ b/monkey/infection_monkey/exploit/HostExploiter.py @@ -1,4 +1,5 @@ import logging +import threading from abc import abstractmethod from datetime import datetime from typing import Dict @@ -66,12 +67,14 @@ class HostExploiter: telemetry_messenger: ITelemetryMessenger, agent_repository: IAgentRepository, options: Dict, + interrupt: threading.Event, ): self.host = host self.current_depth = current_depth self.telemetry_messenger = telemetry_messenger self.agent_repository = agent_repository self.options = options + self.interrupt = interrupt self.pre_exploit() try: @@ -91,6 +94,12 @@ class HostExploiter: ) self.set_start_time() + def is_interrupted(self): + # This method should be refactored to raise an exception to reduce duplication in the + # "if is_interrupted: return self.exploitation_results" + # Ideally the user should only do "check_for_interrupt()" + return self.interrupt.is_set() + def post_exploit(self): self.set_finish_time() diff --git a/monkey/infection_monkey/exploit/exploiter_wrapper.py b/monkey/infection_monkey/exploit/exploiter_wrapper.py index 5e855ff22..540e0b4a4 100644 --- a/monkey/infection_monkey/exploit/exploiter_wrapper.py +++ b/monkey/infection_monkey/exploit/exploiter_wrapper.py @@ -1,3 +1,4 @@ +import threading from typing import Dict, Type from infection_monkey.model import VictimHost @@ -26,10 +27,17 @@ class ExploiterWrapper: self._telemetry_messenger = telemetry_messenger self._agent_repository = agent_repository - def exploit_host(self, host: VictimHost, current_depth: int, options: Dict): + def exploit_host( + self, host: VictimHost, current_depth: int, options: Dict, interrupt: threading.Event + ): exploiter = self._exploit_class() return exploiter.exploit_host( - host, current_depth, self._telemetry_messenger, self._agent_repository, options + host, + current_depth, + self._telemetry_messenger, + self._agent_repository, + options, + interrupt, ) def __init__( diff --git a/monkey/infection_monkey/puppet/plugin_registry.py b/monkey/infection_monkey/puppet/plugin_registry.py index 2ec1e3900..1fdca5bd5 100644 --- a/monkey/infection_monkey/puppet/plugin_registry.py +++ b/monkey/infection_monkey/puppet/plugin_registry.py @@ -1,4 +1,5 @@ import logging +from typing import Any from infection_monkey.i_puppet import PluginType, UnknownPluginError @@ -27,7 +28,7 @@ class PluginRegistry: logger.debug(f"Plugin '{plugin_name}' loaded") - def get_plugin(self, plugin_name: str, plugin_type: PluginType) -> object: + def get_plugin(self, plugin_name: str, plugin_type: PluginType) -> Any: try: plugin = self._registry[plugin_type][plugin_name] except KeyError: diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 95e72533f..d8bc8e0eb 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -66,7 +66,7 @@ class Puppet(IPuppet): interrupt: threading.Event, ) -> ExploiterResultData: exploiter = self._plugin_registry.get_plugin(name, PluginType.EXPLOITER) - return exploiter.exploit_host(host, current_depth, options) + return exploiter.exploit_host(host, current_depth, options, interrupt) def run_payload(self, name: str, options: Dict, interrupt: threading.Event): payload = self._plugin_registry.get_plugin(name, PluginType.PAYLOAD)