diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index 1f0410d5b..28994d673 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, Iterable, List, Tuple from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_master import IMaster @@ -11,10 +11,10 @@ from infection_monkey.network import NetworkInterface from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem +from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter from infection_monkey.utils.timer import Timer from . import Exploiter, IPScanner, Propagator -from .threading_utils import create_daemon_thread CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5 @@ -182,7 +182,7 @@ class AutomatedMaster(IMaster): command, result = self._puppet.run_pba(name, options) self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result)) - def _can_propagate(self): + def _can_propagate(self) -> bool: return True def _run_payload(self, payload: Tuple[str, Dict]): @@ -191,15 +191,14 @@ class AutomatedMaster(IMaster): self._puppet.run_payload(name, options, self._stop) - def _run_plugins(self, plugin: List[Any], plugin_type: str, callback: Callable[[Any], None]): + def _run_plugins( + self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None] + ): logger.info(f"Running {plugin_type}s") - logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run") - - for p in plugin: - if self._stop.is_set(): - logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s") - return + logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run") + interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s" + for p in interruptable_iter(plugins, self._stop, interrupted_message): callback(p) logger.info(f"Finished running {plugin_type}s") diff --git a/monkey/infection_monkey/master/exploiter.py b/monkey/infection_monkey/master/exploiter.py index f1a804ba7..09f6ebf4b 100644 --- a/monkey/infection_monkey/master/exploiter.py +++ b/monkey/infection_monkey/master/exploiter.py @@ -7,8 +7,7 @@ from typing import Callable, Dict, List from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.model import VictimHost - -from .threading_utils import run_worker_threads +from infection_monkey.utils.threading import interruptable_iter, run_worker_threads QUEUE_TIMEOUT = 2 @@ -75,10 +74,7 @@ class Exploiter: results_callback: Callback, stop: Event, ): - for exploiter in exploiters_to_run: - if stop.is_set(): - break - + for exploiter in interruptable_iter(exploiters_to_run, stop): exploiter_name = exploiter["name"] exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop) results_callback(exploiter_name, victim_host, exploiter_results) diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 0cd2b021f..0f7132a27 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -13,9 +13,9 @@ from infection_monkey.i_puppet import ( PortStatus, ) from infection_monkey.network import NetworkAddress +from infection_monkey.utils.threading import interruptable_iter, run_worker_threads from . import IPScanResults -from .threading_utils import run_worker_threads logger = logging.getLogger() @@ -49,25 +49,23 @@ class IPScanner: self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event ): logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") + icmp_timeout = options["icmp"]["timeout_ms"] / 1000 + tcp_timeout = options["tcp"]["timeout_ms"] / 1000 + tcp_ports = options["tcp"]["ports"] try: while not stop.is_set(): address = addresses.get_nowait() - ip = address.ip - logger.info(f"Scanning {ip}") + logger.info(f"Scanning {address.ip}") - icmp_timeout = options["icmp"]["timeout_ms"] / 1000 - ping_scan_data = self._puppet.ping(ip, icmp_timeout) - - tcp_timeout = options["tcp"]["timeout_ms"] / 1000 - tcp_ports = options["tcp"]["ports"] - port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop) + ping_scan_data = self._puppet.ping(address.ip, icmp_timeout) + port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop) fingerprint_data = {} if IPScanner.port_scan_found_open_port(port_scan_data): fingerprinters = options["fingerprinters"] fingerprint_data = self._run_fingerprinters( - ip, fingerprinters, ping_scan_data, port_scan_data, stop + address.ip, fingerprinters, ping_scan_data, port_scan_data, stop ) scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data) @@ -87,10 +85,7 @@ class IPScanner: ) -> Dict[int, PortScanData]: port_scan_data = {} - for p in ports: - if stop.is_set(): - break - + for p in interruptable_iter(ports, stop): port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout) return port_scan_data @@ -109,10 +104,7 @@ class IPScanner: ) -> Dict[str, FingerprintData]: fingerprint_data = {} - for f in fingerprinters: - if stop.is_set(): - break - + for f in interruptable_iter(fingerprinters, stop): fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data) return fingerprint_data diff --git a/monkey/infection_monkey/master/propagator.py b/monkey/infection_monkey/master/propagator.py index b3eb7faf9..87f9a1896 100644 --- a/monkey/infection_monkey/master/propagator.py +++ b/monkey/infection_monkey/master/propagator.py @@ -16,9 +16,9 @@ from infection_monkey.network.scan_target_generator import compile_scan_target_l from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.scan_telem import ScanTelem +from infection_monkey.utils.threading import create_daemon_thread from . import Exploiter, IPScanner, IPScanResults -from .threading_utils import create_daemon_thread logger = logging.getLogger() @@ -107,14 +107,13 @@ class Propagator: victim_host.os["type"] = ping_scan_data.os @staticmethod - def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData) -> bool: - for psd in port_scan_data.values(): - if psd.status == PortStatus.OPEN: - victim_host.services[psd.service] = {} - victim_host.services[psd.service]["display_name"] = "unknown(TCP)" - victim_host.services[psd.service]["port"] = psd.port - if psd.banner is not None: - victim_host.services[psd.service]["banner"] = psd.banner + def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData): + for psd in filter(lambda psd: psd.status == PortStatus.OPEN, port_scan_data.values()): + victim_host.services[psd.service] = {} + victim_host.services[psd.service]["display_name"] = "unknown(TCP)" + victim_host.services[psd.service]["port"] = psd.port + if psd.banner is not None: + victim_host.services[psd.service]["banner"] = psd.banner @staticmethod def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData): diff --git a/monkey/infection_monkey/master/threading_utils.py b/monkey/infection_monkey/master/threading_utils.py deleted file mode 100644 index dbcc67984..000000000 --- a/monkey/infection_monkey/master/threading_utils.py +++ /dev/null @@ -1,17 +0,0 @@ -from threading import Thread -from typing import Callable, Tuple - - -def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): - worker_threads = [] - for i in range(0, num_workers): - t = create_daemon_thread(target=target, args=args) - t.start() - worker_threads.append(t) - - for t in worker_threads: - t.join() - - -def create_daemon_thread(target: Callable[..., None], args: Tuple = ()): - return Thread(target=target, args=args, daemon=True) diff --git a/monkey/infection_monkey/payload/ransomware/in_place_file_encryptor.py b/monkey/infection_monkey/payload/ransomware/in_place_file_encryptor.py index f4bcaf3aa..fc5523352 100644 --- a/monkey/infection_monkey/payload/ransomware/in_place_file_encryptor.py +++ b/monkey/infection_monkey/payload/ransomware/in_place_file_encryptor.py @@ -28,17 +28,12 @@ class InPlaceFileEncryptor: def _encrypt_file(self, filepath: Path): with open(filepath, "rb+") as f: - data = f.read(self._chunk_size) - while data: - num_bytes_read = len(data) - + for data in iter(lambda: f.read(self._chunk_size), b""): encrypted_data = self._encrypt_bytes(data) - f.seek(-num_bytes_read, 1) + f.seek(-len(encrypted_data), 1) f.write(encrypted_data) - data = f.read(self._chunk_size) - def _add_extension(self, filepath: Path): new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}") filepath.rename(new_filepath) diff --git a/monkey/infection_monkey/payload/ransomware/ransomware.py b/monkey/infection_monkey/payload/ransomware/ransomware.py index 003112cc3..c4351acaf 100644 --- a/monkey/infection_monkey/payload/ransomware/ransomware.py +++ b/monkey/infection_monkey/payload/ransomware/ransomware.py @@ -5,6 +5,7 @@ from typing import Callable, List from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger +from infection_monkey.utils.threading import interruptable_iter from .consts import README_FILE_NAME, README_SRC from .ransomware_options import RansomwareOptions @@ -53,13 +54,10 @@ class Ransomware: def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event): logger.info(f"Encrypting files in {self._target_directory}") - for filepath in file_list: - if interrupt.is_set(): - logger.debug( - "Received a stop signal, skipping remaining files for encryption of " - "ransomware payload" - ) - return + interrupted_message = ( + "Received a stop signal, skipping remaining files for encryption of ransomware payload" + ) + for filepath in interruptable_iter(file_list, interrupt, interrupted_message): try: logger.debug(f"Encrypting {filepath}") self._encrypt_file(filepath) @@ -73,11 +71,11 @@ class Ransomware: self._telemetry_messenger.send_telemetry(encryption_attempt) def _leave_readme_in_target_directory(self, interrupt: threading.Event): - try: - if interrupt.is_set(): - logger.debug("Received a stop signal, skipping leave readme") - return + if interrupt.is_set(): + logger.debug("Received a stop signal, skipping leave readme") + return + try: self._leave_readme(README_SRC, self._readme_file_path) except Exception as ex: logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}") diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index 704556335..2fd29af9e 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -6,7 +6,9 @@ def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]: return filter_files(dir_path.iterdir(), [lambda f: f.is_file()]) -def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool]]): +def filter_files( + files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]] +) -> List[Path]: filtered_files = files for file_filter in file_filters: filtered_files = [f for f in filtered_files if file_filter(f)] @@ -14,16 +16,16 @@ def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool return filtered_files -def file_extension_filter(file_extensions: Set): - def inner_filter(f: Path): +def file_extension_filter(file_extensions: Set) -> Callable[[Path], bool]: + def inner_filter(f: Path) -> bool: return f.suffix in file_extensions return inner_filter -def is_not_symlink_filter(f: Path): +def is_not_symlink_filter(f: Path) -> bool: return not f.is_symlink() -def is_not_shortcut_filter(f: Path): +def is_not_shortcut_filter(f: Path) -> bool: return f.suffix != ".lnk" diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py new file mode 100644 index 000000000..9ceec895f --- /dev/null +++ b/monkey/infection_monkey/utils/threading.py @@ -0,0 +1,44 @@ +import logging +from threading import Event, Thread +from typing import Any, Callable, Iterable, Tuple + +logger = logging.getLogger(__name__) + + +def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): + worker_threads = [] + for i in range(0, num_workers): + t = create_daemon_thread(target=target, args=args) + t.start() + worker_threads.append(t) + + for t in worker_threads: + t.join() + + +def create_daemon_thread(target: Callable[..., None], args: Tuple = ()): + return Thread(target=target, args=args, daemon=True) + + +def interruptable_iter( + iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG +) -> Any: + """ + Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This + is a convinient way to make loops interruptable and avoids the need to add an `if` to each and + every loop. + :param Iterable iterator: An iterator that will be made interruptable. + :param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the + iterator's items from being processed. + :param str log_message: A message to be logged if the iterator is interrupted. If `log_message` + is `None` (default), then no message is logged. + :param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`. + """ + for i in iterator: + if interrupt.is_set(): + if log_message: + logger.log(log_level, log_message) + + break + + yield i diff --git a/monkey/monkey_island/cc/setup/island_config_options.py b/monkey/monkey_island/cc/setup/island_config_options.py index 763474c83..b3408ad86 100644 --- a/monkey/monkey_island/cc/setup/island_config_options.py +++ b/monkey/monkey_island/cc/setup/island_config_options.py @@ -1,5 +1,8 @@ from __future__ import annotations +from types import MappingProxyType as ImmutableMapping +from typing import Mapping + from common.utils.file_utils import expand_path from monkey_island.cc.server_utils.consts import ( DEFAULT_CRT_PATH, @@ -19,10 +22,7 @@ _LOG_LEVEL = "log_level" class IslandConfigOptions: - def __init__(self, config_contents: dict = None): - if config_contents is None: - config_contents = {} - + def __init__(self, config_contents: Mapping[str, Mapping] = ImmutableMapping({})): self.data_dir = DEFAULT_DATA_DIR self.log_level = DEFAULT_LOG_LEVEL self.start_mongodb = DEFAULT_START_MONGO_DB @@ -33,7 +33,7 @@ class IslandConfigOptions: self.update(config_contents) - def update(self, config_contents: dict): + def update(self, config_contents: Mapping[str, Mapping]): self.data_dir = config_contents.get(_DATA_DIR, self.data_dir) self.log_level = config_contents.get(_LOG_LEVEL, self.log_level) diff --git a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py index adffe6f88..365f9fecd 100644 --- a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py +++ b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py @@ -117,7 +117,12 @@ def test_interrupt_while_encrypting( mfe.assert_any_call(ransomware_test_data / HELLO_TXT) -def test_no_readme_after_interrupt(ransomware, interrupt, mock_leave_readme): +def test_no_readme_after_interrupt( + ransomware_options, build_ransomware, interrupt, mock_leave_readme +): + ransomware_options.readme_enabled = True + ransomware = build_ransomware(ransomware_options) + interrupt.set() ransomware.run(interrupt) diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py new file mode 100644 index 000000000..659fc7205 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -0,0 +1,47 @@ +import logging +from threading import Event + +from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter + + +def test_create_daemon_thread(): + thread = create_daemon_thread(lambda: None) + assert thread.daemon + + +def test_interruptable_iter(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted") + + for i in test_iterator: + items_from_iterator.append(i) + if i == 3: + interrupt.set() + + assert items_from_iterator == [0, 1, 2, 3] + + +def test_interruptable_iter_not_interrupted(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted") + + for i in test_iterator: + items_from_iterator.append(i) + + assert items_from_iterator == [0, 1, 2, 3, 4] + + +def test_interruptable_iter_interrupted_before_used(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter( + range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO + ) + + interrupt.set() + for i in test_iterator: + items_from_iterator.append(i) + + assert not items_from_iterator