Agent: Pass interrupt event to HostExploiter

This commit is contained in:
vakarisz 2022-03-16 10:33:46 +02:00 committed by Mike Salvatore
parent ed5e686b04
commit d1a4018d5f
4 changed files with 22 additions and 4 deletions

View File

@ -1,4 +1,5 @@
import logging
import threading
from abc import abstractmethod
from datetime import datetime
from typing import Dict
@ -66,12 +67,14 @@ class HostExploiter:
telemetry_messenger: ITelemetryMessenger,
agent_repository: IAgentRepository,
options: Dict,
interrupt: threading.Event,
):
self.host = host
self.current_depth = current_depth
self.telemetry_messenger = telemetry_messenger
self.agent_repository = agent_repository
self.options = options
self.interrupt = interrupt
self.pre_exploit()
try:
@ -91,6 +94,12 @@ class HostExploiter:
)
self.set_start_time()
def is_interrupted(self):
# This method should be refactored to raise an exception to reduce duplication in the
# "if is_interrupted: return self.exploitation_results"
# Ideally the user should only do "check_for_interrupt()"
return self.interrupt.is_set()
def post_exploit(self):
self.set_finish_time()

View File

@ -1,3 +1,4 @@
import threading
from typing import Dict, Type
from infection_monkey.model import VictimHost
@ -26,10 +27,17 @@ class ExploiterWrapper:
self._telemetry_messenger = telemetry_messenger
self._agent_repository = agent_repository
def exploit_host(self, host: VictimHost, current_depth: int, options: Dict):
def exploit_host(
self, host: VictimHost, current_depth: int, options: Dict, interrupt: threading.Event
):
exploiter = self._exploit_class()
return exploiter.exploit_host(
host, current_depth, self._telemetry_messenger, self._agent_repository, options
host,
current_depth,
self._telemetry_messenger,
self._agent_repository,
options,
interrupt,
)
def __init__(

View File

@ -1,4 +1,5 @@
import logging
from typing import Any
from infection_monkey.i_puppet import PluginType, UnknownPluginError
@ -27,7 +28,7 @@ class PluginRegistry:
logger.debug(f"Plugin '{plugin_name}' loaded")
def get_plugin(self, plugin_name: str, plugin_type: PluginType) -> object:
def get_plugin(self, plugin_name: str, plugin_type: PluginType) -> Any:
try:
plugin = self._registry[plugin_type][plugin_name]
except KeyError:

View File

@ -66,7 +66,7 @@ class Puppet(IPuppet):
interrupt: threading.Event,
) -> ExploiterResultData:
exploiter = self._plugin_registry.get_plugin(name, PluginType.EXPLOITER)
return exploiter.exploit_host(host, current_depth, options)
return exploiter.exploit_host(host, current_depth, options, interrupt)
def run_payload(self, name: str, options: Dict, interrupt: threading.Event):
payload = self._plugin_registry.get_plugin(name, PluginType.PAYLOAD)