Agent: Replace if checks with interruptable_iter() in for loops

This commit is contained in:
Mike Salvatore 2022-01-24 08:56:04 -05:00
parent 0c877833c5
commit fae0c8ded2
4 changed files with 13 additions and 27 deletions

View File

@ -11,7 +11,7 @@ from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.threading import create_daemon_thread from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
from infection_monkey.utils.timer import Timer from infection_monkey.utils.timer import Timer
from . import Exploiter, IPScanner, Propagator from . import Exploiter, IPScanner, Propagator
@ -195,11 +195,8 @@ class AutomatedMaster(IMaster):
logger.info(f"Running {plugin_type}s") logger.info(f"Running {plugin_type}s")
logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run") logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run")
for p in plugin: interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
if self._stop.is_set(): for p in interruptable_iter(plugin, self._stop, interrupted_message):
logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s")
return
callback(p) callback(p)
logger.info(f"Finished running {plugin_type}s") logger.info(f"Finished running {plugin_type}s")

View File

@ -7,7 +7,7 @@ from typing import Callable, Dict, List
from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
from infection_monkey.utils.threading import run_worker_threads from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
QUEUE_TIMEOUT = 2 QUEUE_TIMEOUT = 2
@ -74,10 +74,7 @@ class Exploiter:
results_callback: Callback, results_callback: Callback,
stop: Event, stop: Event,
): ):
for exploiter in exploiters_to_run: for exploiter in interruptable_iter(exploiters_to_run, stop):
if stop.is_set():
break
exploiter_name = exploiter["name"] exploiter_name = exploiter["name"]
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop) exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
results_callback(exploiter_name, victim_host, exploiter_results) results_callback(exploiter_name, victim_host, exploiter_results)

View File

@ -13,7 +13,7 @@ from infection_monkey.i_puppet import (
PortStatus, PortStatus,
) )
from infection_monkey.network import NetworkAddress from infection_monkey.network import NetworkAddress
from infection_monkey.utils.threading import run_worker_threads from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
from . import IPScanResults from . import IPScanResults
@ -85,10 +85,7 @@ class IPScanner:
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
port_scan_data = {} port_scan_data = {}
for p in ports: for p in interruptable_iter(ports, stop):
if stop.is_set():
break
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout) port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data return port_scan_data
@ -107,10 +104,7 @@ class IPScanner:
) -> Dict[str, FingerprintData]: ) -> Dict[str, FingerprintData]:
fingerprint_data = {} fingerprint_data = {}
for f in fingerprinters: for f in interruptable_iter(fingerprinters, stop):
if stop.is_set():
break
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data) fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
return fingerprint_data return fingerprint_data

View File

@ -5,6 +5,7 @@ from typing import Callable, List
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.threading import interruptable_iter
from .consts import README_FILE_NAME, README_SRC from .consts import README_FILE_NAME, README_SRC
from .ransomware_options import RansomwareOptions from .ransomware_options import RansomwareOptions
@ -53,13 +54,10 @@ class Ransomware:
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event): def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
logger.info(f"Encrypting files in {self._target_directory}") logger.info(f"Encrypting files in {self._target_directory}")
for filepath in file_list: interrupted_message = (
if interrupt.is_set(): "Received a stop signal, skipping remaining files for encryption of ransomware payload"
logger.debug(
"Received a stop signal, skipping remaining files for encryption of "
"ransomware payload"
) )
return for filepath in interruptable_iter(file_list, interrupt, interrupted_message):
try: try:
logger.debug(f"Encrypting {filepath}") logger.debug(f"Encrypting {filepath}")
self._encrypt_file(filepath) self._encrypt_file(filepath)