diff --git a/monkey/infection_monkey/master/__init__.py b/monkey/infection_monkey/master/__init__.py index fda536194..98ed6db0b 100644 --- a/monkey/infection_monkey/master/__init__.py +++ b/monkey/infection_monkey/master/__init__.py @@ -1,4 +1,5 @@ from .ip_scan_results import IPScanResults from .ip_scanner import IPScanner +from .exploiter import Exploiter from .propagator import Propagator from .automated_master import AutomatedMaster diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index 57b8f52b2..ff6af8b43 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -12,13 +12,14 @@ from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.utils.timer import Timer -from . import IPScanner, Propagator +from . import Exploiter, IPScanner, Propagator from .threading_utils import create_daemon_thread CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5 SHUTDOWN_TIMEOUT = 5 NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads +NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads logger = logging.getLogger() @@ -36,7 +37,10 @@ class AutomatedMaster(IMaster): self._control_channel = control_channel ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS) - self._propagator = Propagator(self._telemetry_messenger, ip_scanner, victim_host_factory) + exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS) + self._propagator = Propagator( + self._telemetry_messenger, ip_scanner, exploiter, victim_host_factory + ) self._stop = threading.Event() self._master_thread = create_daemon_thread(target=self._run_master_thread) diff --git a/monkey/infection_monkey/master/exploiter.py b/monkey/infection_monkey/master/exploiter.py new file mode 100644 index 000000000..3f732ffa3 --- /dev/null +++ b/monkey/infection_monkey/master/exploiter.py @@ -0,0 +1,107 @@ +import logging +import queue +import threading +from queue import Queue +from threading import Event +from typing import Callable, Dict, List + +from infection_monkey.i_puppet import ExploiterResultData, IPuppet +from infection_monkey.model import VictimHost + +from .threading_utils import create_daemon_thread + +QUEUE_TIMEOUT = 2 + +logger = logging.getLogger() + +ExploiterName = str +Callback = Callable[[VictimHost, ExploiterName, ExploiterResultData], None] + + +class Exploiter: + def __init__(self, puppet: IPuppet, num_workers: int): + self._puppet = puppet + self._num_workers = num_workers + + def exploit_hosts( + self, + exploiter_config: Dict, + hosts_to_exploit: Queue, + results_callback: Callback, + scan_completed: Event, + stop: Event, + ): + # Run vulnerability exploiters before brute force exploiters to minimize the effect of + # account lockout due to invalid credentials + exploiters_to_run = exploiter_config["vulnerability"] + exploiter_config["brute_force"] + logger.debug( + "Agent is configured to run the following exploiters in order: " + f"{','.join([e['name'] for e in exploiters_to_run])}" + ) + + exploit_args = (exploiters_to_run, hosts_to_exploit, results_callback, scan_completed, stop) + + # TODO: This functionality is also used in IPScanner and can be generalized. Extract it. + exploiter_threads = [] + for i in range(0, self._num_workers): + t = create_daemon_thread(target=self._exploit_hosts_on_queue, args=exploit_args) + t.start() + exploiter_threads.append(t) + + for t in exploiter_threads: + t.join() + + def _exploit_hosts_on_queue( + self, + exploiters_to_run: List[Dict], + hosts_to_exploit: Queue, + results_callback: Callback, + scan_completed: Event, + stop: Event, + ): + logger.debug(f"Starting exploiter thread -- Thread ID: {threading.get_ident()}") + + while not stop.is_set(): + try: + victim_host = hosts_to_exploit.get(timeout=QUEUE_TIMEOUT) + self._run_all_exploiters(exploiters_to_run, victim_host, results_callback, stop) + except queue.Empty: + if ( + _all_hosts_have_been_processed(scan_completed, hosts_to_exploit) + or stop.is_set() + ): + break + + logger.debug( + f"Exiting exploiter thread -- Thread ID: {threading.get_ident()} -- " + f"stop.is_set(): {stop.is_set()} -- network_scan_completed: " + f"{scan_completed.is_set()}" + ) + + def _run_all_exploiters( + self, + exploiters_to_run: List[Dict], + victim_host: VictimHost, + results_callback: Callback, + stop: Event, + ): + for exploiter in exploiters_to_run: + if stop.is_set(): + break + + exploiter_name = exploiter["name"] + exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop) + results_callback(exploiter_name, victim_host, exploiter_results) + + if exploiter["propagator"] and exploiter_results.success: + break + + def _run_exploiter( + self, exploiter_name: str, victim_host: VictimHost, stop: Event + ) -> ExploiterResultData: + logger.debug(f"Attempting to use {exploiter_name} on {victim_host}") + return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, {}, stop) + + +def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue): + return scan_completed.is_set() and hosts_to_exploit.empty() diff --git a/monkey/infection_monkey/master/propagator.py b/monkey/infection_monkey/master/propagator.py index 78e08a98d..24d5fb8f0 100644 --- a/monkey/infection_monkey/master/propagator.py +++ b/monkey/infection_monkey/master/propagator.py @@ -3,12 +3,19 @@ from queue import Queue from threading import Event, Thread from typing import Dict -from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus +from infection_monkey.i_puppet import ( + ExploiterResultData, + FingerprintData, + PingScanData, + PortScanData, + PortStatus, +) from infection_monkey.model import VictimHost, VictimHostFactory +from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.scan_telem import ScanTelem -from . import IPScanner, IPScanResults +from . import Exploiter, IPScanner, IPScanResults from .threading_utils import create_daemon_thread logger = logging.getLogger() @@ -19,29 +26,35 @@ class Propagator: self, telemetry_messenger: ITelemetryMessenger, ip_scanner: IPScanner, + exploiter: Exploiter, victim_host_factory: VictimHostFactory, ): self._telemetry_messenger = telemetry_messenger self._ip_scanner = ip_scanner + self._exploiter = exploiter self._victim_host_factory = victim_host_factory self._hosts_to_exploit = None def propagate(self, propagation_config: Dict, stop: Event): logger.info("Attempting to propagate") + network_scan_completed = Event() self._hosts_to_exploit = Queue() scan_thread = create_daemon_thread( target=self._scan_network, args=(propagation_config, stop) ) exploit_thread = create_daemon_thread( - target=self._exploit_targets, args=(scan_thread, stop) + target=self._exploit_hosts, + args=(scan_thread, propagation_config, network_scan_completed, stop), ) scan_thread.start() exploit_thread.start() scan_thread.join() + network_scan_completed.set() + exploit_thread.join() logger.info("Finished attempting to propagate") @@ -101,5 +114,34 @@ class Propagator: for service, details in fd.services.items(): victim_host.services.setdefault(service, {}).update(details) - def _exploit_targets(self, scan_thread: Thread, stop: Event): - pass + def _exploit_hosts( + self, + scan_thread: Thread, + propagation_config: Dict, + network_scan_completed: Event, + stop: Event, + ): + logger.info("Exploiting victims") + + exploiter_config = propagation_config["exploiters"] + self._exploiter.exploit_hosts( + self._hosts_to_exploit, + exploiter_config, + self._process_exploit_attempts, + network_scan_completed, + stop, + ) + + logger.info("Finished exploiting victims") + + def _process_exploit_attempts( + self, exploiter_name: str, host: VictimHost, result: ExploiterResultData + ): + if result.success: + logger.info("Successfully propagated to {host} using {exploiter_name}") + else: + logger.info(result.error_message) + + self._telemetry_messenger.send_telemetry( + ExploitTelem(exploiter_name, host, result.success, result.info, result.attempts) + ) diff --git a/monkey/infection_monkey/puppet/mock_puppet.py b/monkey/infection_monkey/puppet/mock_puppet.py index fe21f4cb0..64c247170 100644 --- a/monkey/infection_monkey/puppet/mock_puppet.py +++ b/monkey/infection_monkey/puppet/mock_puppet.py @@ -281,10 +281,16 @@ class MockPuppet(IPuppet): } successful_exploiters = { DOT_1: { - "PowerShellExploiter": ExploiterResultData(True, info_powershell, attempts, None) + "PowerShellExploiter": ExploiterResultData(True, info_powershell, attempts, None), + "ZerologonExploiter": ExploiterResultData(False, {}, [], "Zerologon failed"), + "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting"), }, DOT_3: { - "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting") + "PowerShellExploiter": ExploiterResultData( + False, info_powershell, attempts, "PowerShell Exploiter Failed" + ), + "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting"), + "ZerologonExploiter": ExploiterResultData(True, {}, [], None), }, } diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py new file mode 100644 index 000000000..5b9297fe6 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py @@ -0,0 +1,102 @@ +import logging +from queue import Queue +from threading import Barrier, Event +from unittest.mock import MagicMock + +import pytest + +from infection_monkey.master import Exploiter +from infection_monkey.model import VictimHost +from infection_monkey.puppet.mock_puppet import MockPuppet + +logger = logging.getLogger() + + +@pytest.fixture(autouse=True) +def patch_queue_timeout(monkeypatch): + monkeypatch.setattr("infection_monkey.master.exploiter.QUEUE_TIMEOUT", 0.001) + + +@pytest.fixture +def scan_completed(): + return Event() + + +@pytest.fixture +def stop(): + return Event() + + +@pytest.fixture +def callback(): + return MagicMock() + + +@pytest.fixture +def exploiter_config(): + return { + "brute_force": [ + {"name": "PowerShellExploiter", "propagator": True}, + {"name": "SSHExploiter", "propagator": True}, + ], + "vulnerability": [ + {"name": "ZerologonExploiter", "propagator": False}, + ], + } + + +@pytest.fixture +def hosts(): + return [VictimHost("10.0.0.1"), VictimHost("10.0.0.3")] + + +@pytest.fixture +def hosts_to_exploit(hosts): + q = Queue() + q.put(hosts[0]) + q.put(hosts[1]) + + return q + + +def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit): + # Set this so that Exploiter() exits once it has processed all victims + scan_completed.set() + + e = Exploiter(MockPuppet(), 2) + e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop) + + assert callback.call_count == 5 + host_exploit_combos = set() + + for i in range(0, 5): + victim_host = callback.call_args_list[i][0][0] + exploiter_name = callback.call_args_list[i][0][1] + host_exploit_combos.add((victim_host, exploiter_name)) + + assert ("ZerologonExploiter", hosts[0]) in host_exploit_combos + assert ("PowerShellExploiter", hosts[0]) in host_exploit_combos + assert ("ZerologonExploiter", hosts[1]) in host_exploit_combos + assert ("PowerShellExploiter", hosts[1]) in host_exploit_combos + assert ("SSHExploiter", hosts[1]) in host_exploit_combos + + +def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit): + callback_barrier_count = 2 + + def _callback(*_): + # Block all threads here until 2 threads reach this barrier, then set stop + # and test that neither thread continues to scan. + _callback.barrier.wait() + stop.set() + + _callback.barrier = Barrier(callback_barrier_count) + + stoppable_callback = MagicMock(side_effect=_callback) + + # Intentionally NOT setting scan_completed.set(); _callback() will set stop + + e = Exploiter(MockPuppet(), callback_barrier_count + 2) + e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop) + + assert stoppable_callback.call_count == 2 diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py index 941f17a6c..de44f40f4 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py @@ -1,12 +1,19 @@ from threading import Event -from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus +from infection_monkey.i_puppet import ( + ExploiterResultData, + FingerprintData, + PingScanData, + PortScanData, + PortStatus, +) from infection_monkey.master import IPScanResults, Propagator from infection_monkey.model import VictimHostFactory +from infection_monkey.telemetry.exploit_telem import ExploitTelem empty_fingerprint_data = FingerprintData(None, None, {}) -dot_1_results = IPScanResults( +dot_1_scan_results = IPScanResults( PingScanData(True, "windows"), { 22: PortScanData(22, PortStatus.CLOSED, None, None), @@ -20,7 +27,7 @@ dot_1_results = IPScanResults( }, ) -dot_3_results = IPScanResults( +dot_3_scan_results = IPScanResults( PingScanData(True, "linux"), { 22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"), @@ -43,7 +50,7 @@ dot_3_results = IPScanResults( }, ) -dead_host_results = IPScanResults( +dead_host_scan_results = IPScanResults( PingScanData(False, None), { 22: PortScanData(22, PortStatus.CLOSED, None, None), @@ -80,19 +87,27 @@ class MockIPScanner: def scan(self, ips_to_scan, _, results_callback, stop): for ip in ips_to_scan: if ip.endswith(".1"): - results_callback(ip, dot_1_results) + results_callback(ip, dot_1_scan_results) elif ip.endswith(".3"): - results_callback(ip, dot_3_results) + results_callback(ip, dot_3_scan_results) else: - results_callback(ip, dead_host_results) + results_callback(ip, dead_host_scan_results) + + +class StubExploiter: + def exploit_hosts( + self, hosts_to_exploit, exploiter_config, results_callback, scan_completed, stop + ): + pass def test_scan_result_processing(telemetry_messenger_spy): - p = Propagator(telemetry_messenger_spy, MockIPScanner(), VictimHostFactory()) + p = Propagator(telemetry_messenger_spy, MockIPScanner(), StubExploiter(), VictimHostFactory()) p.propagate( { "targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, - "network_scan": {}, + "network_scan": {}, # This is empty since MockIPscanner ignores it + "exploiters": {}, # This is empty since StubExploiter ignores it }, Event(), ) @@ -120,3 +135,79 @@ def test_scan_result_processing(telemetry_messenger_spy): assert data["machine"]["os"] == {} assert data["machine"]["services"] == {} assert data["machine"]["icmp"] is False + + +class MockExploiter: + def exploit_hosts( + self, hosts_to_exploit, exploiter_config, results_callback, scan_completed, stop + ): + hte = [] + for _ in range(0, 2): + hte.append(hosts_to_exploit.get()) + + for host in hte: + if host.ip_addr.endswith(".1"): + results_callback( + "PowerShellExploiter", + host, + ExploiterResultData(True, {}, {}, None), + ) + results_callback( + "SSHExploiter", + host, + ExploiterResultData(False, {}, {}, "SSH FAILED for .1"), + ) + if host.ip_addr.endswith(".2"): + results_callback( + "PowerShellExploiter", + host, + ExploiterResultData(False, {}, {}, "POWERSHELL FAILED for .2"), + ) + results_callback( + "SSHExploiter", + host, + ExploiterResultData(False, {}, {}, "SSH FAILED for .2"), + ) + if host.ip_addr.endswith(".3"): + results_callback( + "PowerShellExploiter", + host, + ExploiterResultData(False, {}, {}, "POWERSHELL FAILED for .3"), + ) + results_callback( + "SSHExploiter", + host, + ExploiterResultData(True, {}, {}, None), + ) + + +def test_exploiter_result_processing(telemetry_messenger_spy): + p = Propagator(telemetry_messenger_spy, MockIPScanner(), MockExploiter(), VictimHostFactory()) + p.propagate( + { + "targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, + "network_scan": {}, # This is empty since MockIPscanner ignores it + "exploiters": {}, # This is empty since MockExploiter ignores it + }, + Event(), + ) + + exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)] + assert len(exploit_telems) == 4 + + for t in exploit_telems: + data = t.get_data() + ip = data["machine"]["ip_addr"] + + assert ip.endswith(".1") or ip.endswith(".3") + + if ip.endswith(".1"): + if data["exploiter"].startswith("PowerShell"): + assert data["result"] + else: + assert not data["result"] + elif ip.endswith(".3"): + if data["exploiter"].startswith("PowerShell"): + assert not data["result"] + else: + assert data["result"]