From 3394629cb275e6cfa361b7e109a44d41dc2bf54f Mon Sep 17 00:00:00 2001
From: Mike Salvatore <mike.s.salvatore@gmail.com>
Date: Tue, 14 Dec 2021 14:22:46 -0500
Subject: [PATCH] Agent: Run exploiters from AutomatedMaster

---
 monkey/infection_monkey/master/__init__.py    |   1 +
 .../master/automated_master.py                |   8 +-
 monkey/infection_monkey/master/exploiter.py   | 107 +++++++++++++++++
 monkey/infection_monkey/master/propagator.py  |  52 ++++++++-
 monkey/infection_monkey/puppet/mock_puppet.py |  10 +-
 .../infection_monkey/master/test_exploiter.py | 102 ++++++++++++++++
 .../master/test_propagator.py                 | 109 ++++++++++++++++--
 7 files changed, 371 insertions(+), 18 deletions(-)
 create mode 100644 monkey/infection_monkey/master/exploiter.py
 create mode 100644 monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py

diff --git a/monkey/infection_monkey/master/__init__.py b/monkey/infection_monkey/master/__init__.py
index fda536194..98ed6db0b 100644
--- a/monkey/infection_monkey/master/__init__.py
+++ b/monkey/infection_monkey/master/__init__.py
@@ -1,4 +1,5 @@
 from .ip_scan_results import IPScanResults
 from .ip_scanner import IPScanner
+from .exploiter import Exploiter
 from .propagator import Propagator
 from .automated_master import AutomatedMaster
diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py
index 57b8f52b2..ff6af8b43 100644
--- a/monkey/infection_monkey/master/automated_master.py
+++ b/monkey/infection_monkey/master/automated_master.py
@@ -12,13 +12,14 @@ from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
 from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
 from infection_monkey.utils.timer import Timer
 
-from . import IPScanner, Propagator
+from . import Exploiter, IPScanner, Propagator
 from .threading_utils import create_daemon_thread
 
 CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
 CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
 SHUTDOWN_TIMEOUT = 5
 NUM_SCAN_THREADS = 16  # TODO: Adjust this to the optimal number of scan threads
+NUM_EXPLOIT_THREADS = 4  # TODO: Adjust this to the optimal number of exploit threads
 
 logger = logging.getLogger()
 
@@ -36,7 +37,10 @@ class AutomatedMaster(IMaster):
         self._control_channel = control_channel
 
         ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
-        self._propagator = Propagator(self._telemetry_messenger, ip_scanner, victim_host_factory)
+        exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
+        self._propagator = Propagator(
+            self._telemetry_messenger, ip_scanner, exploiter, victim_host_factory
+        )
 
         self._stop = threading.Event()
         self._master_thread = create_daemon_thread(target=self._run_master_thread)
diff --git a/monkey/infection_monkey/master/exploiter.py b/monkey/infection_monkey/master/exploiter.py
new file mode 100644
index 000000000..3f732ffa3
--- /dev/null
+++ b/monkey/infection_monkey/master/exploiter.py
@@ -0,0 +1,107 @@
+import logging
+import queue
+import threading
+from queue import Queue
+from threading import Event
+from typing import Callable, Dict, List
+
+from infection_monkey.i_puppet import ExploiterResultData, IPuppet
+from infection_monkey.model import VictimHost
+
+from .threading_utils import create_daemon_thread
+
+QUEUE_TIMEOUT = 2
+
+logger = logging.getLogger()
+
+ExploiterName = str
+Callback = Callable[[VictimHost, ExploiterName, ExploiterResultData], None]
+
+
+class Exploiter:
+    def __init__(self, puppet: IPuppet, num_workers: int):
+        self._puppet = puppet
+        self._num_workers = num_workers
+
+    def exploit_hosts(
+        self,
+        exploiter_config: Dict,
+        hosts_to_exploit: Queue,
+        results_callback: Callback,
+        scan_completed: Event,
+        stop: Event,
+    ):
+        # Run vulnerability exploiters before brute force exploiters to minimize the effect of
+        # account lockout due to invalid credentials
+        exploiters_to_run = exploiter_config["vulnerability"] + exploiter_config["brute_force"]
+        logger.debug(
+            "Agent is configured to run the following exploiters in order: "
+            f"{','.join([e['name'] for e in exploiters_to_run])}"
+        )
+
+        exploit_args = (exploiters_to_run, hosts_to_exploit, results_callback, scan_completed, stop)
+
+        # TODO: This functionality is also used in IPScanner and can be generalized. Extract it.
+        exploiter_threads = []
+        for i in range(0, self._num_workers):
+            t = create_daemon_thread(target=self._exploit_hosts_on_queue, args=exploit_args)
+            t.start()
+            exploiter_threads.append(t)
+
+        for t in exploiter_threads:
+            t.join()
+
+    def _exploit_hosts_on_queue(
+        self,
+        exploiters_to_run: List[Dict],
+        hosts_to_exploit: Queue,
+        results_callback: Callback,
+        scan_completed: Event,
+        stop: Event,
+    ):
+        logger.debug(f"Starting exploiter thread -- Thread ID: {threading.get_ident()}")
+
+        while not stop.is_set():
+            try:
+                victim_host = hosts_to_exploit.get(timeout=QUEUE_TIMEOUT)
+                self._run_all_exploiters(exploiters_to_run, victim_host, results_callback, stop)
+            except queue.Empty:
+                if (
+                    _all_hosts_have_been_processed(scan_completed, hosts_to_exploit)
+                    or stop.is_set()
+                ):
+                    break
+
+        logger.debug(
+            f"Exiting exploiter thread -- Thread ID: {threading.get_ident()} -- "
+            f"stop.is_set(): {stop.is_set()} -- network_scan_completed: "
+            f"{scan_completed.is_set()}"
+        )
+
+    def _run_all_exploiters(
+        self,
+        exploiters_to_run: List[Dict],
+        victim_host: VictimHost,
+        results_callback: Callback,
+        stop: Event,
+    ):
+        for exploiter in exploiters_to_run:
+            if stop.is_set():
+                break
+
+            exploiter_name = exploiter["name"]
+            exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
+            results_callback(exploiter_name, victim_host, exploiter_results)
+
+            if exploiter["propagator"] and exploiter_results.success:
+                break
+
+    def _run_exploiter(
+        self, exploiter_name: str, victim_host: VictimHost, stop: Event
+    ) -> ExploiterResultData:
+        logger.debug(f"Attempting to use {exploiter_name} on {victim_host}")
+        return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, {}, stop)
+
+
+def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue):
+    return scan_completed.is_set() and hosts_to_exploit.empty()
diff --git a/monkey/infection_monkey/master/propagator.py b/monkey/infection_monkey/master/propagator.py
index 78e08a98d..24d5fb8f0 100644
--- a/monkey/infection_monkey/master/propagator.py
+++ b/monkey/infection_monkey/master/propagator.py
@@ -3,12 +3,19 @@ from queue import Queue
 from threading import Event, Thread
 from typing import Dict
 
-from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus
+from infection_monkey.i_puppet import (
+    ExploiterResultData,
+    FingerprintData,
+    PingScanData,
+    PortScanData,
+    PortStatus,
+)
 from infection_monkey.model import VictimHost, VictimHostFactory
+from infection_monkey.telemetry.exploit_telem import ExploitTelem
 from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
 from infection_monkey.telemetry.scan_telem import ScanTelem
 
-from . import IPScanner, IPScanResults
+from . import Exploiter, IPScanner, IPScanResults
 from .threading_utils import create_daemon_thread
 
 logger = logging.getLogger()
@@ -19,29 +26,35 @@ class Propagator:
         self,
         telemetry_messenger: ITelemetryMessenger,
         ip_scanner: IPScanner,
+        exploiter: Exploiter,
         victim_host_factory: VictimHostFactory,
     ):
         self._telemetry_messenger = telemetry_messenger
         self._ip_scanner = ip_scanner
+        self._exploiter = exploiter
         self._victim_host_factory = victim_host_factory
         self._hosts_to_exploit = None
 
     def propagate(self, propagation_config: Dict, stop: Event):
         logger.info("Attempting to propagate")
 
+        network_scan_completed = Event()
         self._hosts_to_exploit = Queue()
 
         scan_thread = create_daemon_thread(
             target=self._scan_network, args=(propagation_config, stop)
         )
         exploit_thread = create_daemon_thread(
-            target=self._exploit_targets, args=(scan_thread, stop)
+            target=self._exploit_hosts,
+            args=(scan_thread, propagation_config, network_scan_completed, stop),
         )
 
         scan_thread.start()
         exploit_thread.start()
 
         scan_thread.join()
+        network_scan_completed.set()
+
         exploit_thread.join()
 
         logger.info("Finished attempting to propagate")
@@ -101,5 +114,34 @@ class Propagator:
             for service, details in fd.services.items():
                 victim_host.services.setdefault(service, {}).update(details)
 
-    def _exploit_targets(self, scan_thread: Thread, stop: Event):
-        pass
+    def _exploit_hosts(
+        self,
+        scan_thread: Thread,
+        propagation_config: Dict,
+        network_scan_completed: Event,
+        stop: Event,
+    ):
+        logger.info("Exploiting victims")
+
+        exploiter_config = propagation_config["exploiters"]
+        self._exploiter.exploit_hosts(
+            self._hosts_to_exploit,
+            exploiter_config,
+            self._process_exploit_attempts,
+            network_scan_completed,
+            stop,
+        )
+
+        logger.info("Finished exploiting victims")
+
+    def _process_exploit_attempts(
+        self, exploiter_name: str, host: VictimHost, result: ExploiterResultData
+    ):
+        if result.success:
+            logger.info("Successfully propagated to {host} using {exploiter_name}")
+        else:
+            logger.info(result.error_message)
+
+        self._telemetry_messenger.send_telemetry(
+            ExploitTelem(exploiter_name, host, result.success, result.info, result.attempts)
+        )
diff --git a/monkey/infection_monkey/puppet/mock_puppet.py b/monkey/infection_monkey/puppet/mock_puppet.py
index fe21f4cb0..64c247170 100644
--- a/monkey/infection_monkey/puppet/mock_puppet.py
+++ b/monkey/infection_monkey/puppet/mock_puppet.py
@@ -281,10 +281,16 @@ class MockPuppet(IPuppet):
         }
         successful_exploiters = {
             DOT_1: {
-                "PowerShellExploiter": ExploiterResultData(True, info_powershell, attempts, None)
+                "PowerShellExploiter": ExploiterResultData(True, info_powershell, attempts, None),
+                "ZerologonExploiter": ExploiterResultData(False, {}, [], "Zerologon failed"),
+                "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting"),
             },
             DOT_3: {
-                "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting")
+                "PowerShellExploiter": ExploiterResultData(
+                    False, info_powershell, attempts, "PowerShell Exploiter Failed"
+                ),
+                "SSHExploiter": ExploiterResultData(False, info_ssh, attempts, "Failed exploiting"),
+                "ZerologonExploiter": ExploiterResultData(True, {}, [], None),
             },
         }
 
diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py
new file mode 100644
index 000000000..5b9297fe6
--- /dev/null
+++ b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py
@@ -0,0 +1,102 @@
+import logging
+from queue import Queue
+from threading import Barrier, Event
+from unittest.mock import MagicMock
+
+import pytest
+
+from infection_monkey.master import Exploiter
+from infection_monkey.model import VictimHost
+from infection_monkey.puppet.mock_puppet import MockPuppet
+
+logger = logging.getLogger()
+
+
+@pytest.fixture(autouse=True)
+def patch_queue_timeout(monkeypatch):
+    monkeypatch.setattr("infection_monkey.master.exploiter.QUEUE_TIMEOUT", 0.001)
+
+
+@pytest.fixture
+def scan_completed():
+    return Event()
+
+
+@pytest.fixture
+def stop():
+    return Event()
+
+
+@pytest.fixture
+def callback():
+    return MagicMock()
+
+
+@pytest.fixture
+def exploiter_config():
+    return {
+        "brute_force": [
+            {"name": "PowerShellExploiter", "propagator": True},
+            {"name": "SSHExploiter", "propagator": True},
+        ],
+        "vulnerability": [
+            {"name": "ZerologonExploiter", "propagator": False},
+        ],
+    }
+
+
+@pytest.fixture
+def hosts():
+    return [VictimHost("10.0.0.1"), VictimHost("10.0.0.3")]
+
+
+@pytest.fixture
+def hosts_to_exploit(hosts):
+    q = Queue()
+    q.put(hosts[0])
+    q.put(hosts[1])
+
+    return q
+
+
+def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit):
+    # Set this so that Exploiter() exits once it has processed all victims
+    scan_completed.set()
+
+    e = Exploiter(MockPuppet(), 2)
+    e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)
+
+    assert callback.call_count == 5
+    host_exploit_combos = set()
+
+    for i in range(0, 5):
+        victim_host = callback.call_args_list[i][0][0]
+        exploiter_name = callback.call_args_list[i][0][1]
+        host_exploit_combos.add((victim_host, exploiter_name))
+
+    assert ("ZerologonExploiter", hosts[0]) in host_exploit_combos
+    assert ("PowerShellExploiter", hosts[0]) in host_exploit_combos
+    assert ("ZerologonExploiter", hosts[1]) in host_exploit_combos
+    assert ("PowerShellExploiter", hosts[1]) in host_exploit_combos
+    assert ("SSHExploiter", hosts[1]) in host_exploit_combos
+
+
+def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit):
+    callback_barrier_count = 2
+
+    def _callback(*_):
+        # Block all threads here until 2 threads reach this barrier, then set stop
+        # and test that neither thread continues to scan.
+        _callback.barrier.wait()
+        stop.set()
+
+    _callback.barrier = Barrier(callback_barrier_count)
+
+    stoppable_callback = MagicMock(side_effect=_callback)
+
+    # Intentionally NOT setting scan_completed.set(); _callback() will set stop
+
+    e = Exploiter(MockPuppet(), callback_barrier_count + 2)
+    e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop)
+
+    assert stoppable_callback.call_count == 2
diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py
index 941f17a6c..de44f40f4 100644
--- a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py
+++ b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py
@@ -1,12 +1,19 @@
 from threading import Event
 
-from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus
+from infection_monkey.i_puppet import (
+    ExploiterResultData,
+    FingerprintData,
+    PingScanData,
+    PortScanData,
+    PortStatus,
+)
 from infection_monkey.master import IPScanResults, Propagator
 from infection_monkey.model import VictimHostFactory
+from infection_monkey.telemetry.exploit_telem import ExploitTelem
 
 empty_fingerprint_data = FingerprintData(None, None, {})
 
-dot_1_results = IPScanResults(
+dot_1_scan_results = IPScanResults(
     PingScanData(True, "windows"),
     {
         22: PortScanData(22, PortStatus.CLOSED, None, None),
@@ -20,7 +27,7 @@ dot_1_results = IPScanResults(
     },
 )
 
-dot_3_results = IPScanResults(
+dot_3_scan_results = IPScanResults(
     PingScanData(True, "linux"),
     {
         22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
@@ -43,7 +50,7 @@ dot_3_results = IPScanResults(
     },
 )
 
-dead_host_results = IPScanResults(
+dead_host_scan_results = IPScanResults(
     PingScanData(False, None),
     {
         22: PortScanData(22, PortStatus.CLOSED, None, None),
@@ -80,19 +87,27 @@ class MockIPScanner:
     def scan(self, ips_to_scan, _, results_callback, stop):
         for ip in ips_to_scan:
             if ip.endswith(".1"):
-                results_callback(ip, dot_1_results)
+                results_callback(ip, dot_1_scan_results)
             elif ip.endswith(".3"):
-                results_callback(ip, dot_3_results)
+                results_callback(ip, dot_3_scan_results)
             else:
-                results_callback(ip, dead_host_results)
+                results_callback(ip, dead_host_scan_results)
+
+
+class StubExploiter:
+    def exploit_hosts(
+        self, hosts_to_exploit, exploiter_config, results_callback, scan_completed, stop
+    ):
+        pass
 
 
 def test_scan_result_processing(telemetry_messenger_spy):
-    p = Propagator(telemetry_messenger_spy, MockIPScanner(), VictimHostFactory())
+    p = Propagator(telemetry_messenger_spy, MockIPScanner(), StubExploiter(), VictimHostFactory())
     p.propagate(
         {
             "targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]},
-            "network_scan": {},
+            "network_scan": {},  # This is empty since MockIPscanner ignores it
+            "exploiters": {},  # This is empty since StubExploiter ignores it
         },
         Event(),
     )
@@ -120,3 +135,79 @@ def test_scan_result_processing(telemetry_messenger_spy):
             assert data["machine"]["os"] == {}
             assert data["machine"]["services"] == {}
             assert data["machine"]["icmp"] is False
+
+
+class MockExploiter:
+    def exploit_hosts(
+        self, hosts_to_exploit, exploiter_config, results_callback, scan_completed, stop
+    ):
+        hte = []
+        for _ in range(0, 2):
+            hte.append(hosts_to_exploit.get())
+
+        for host in hte:
+            if host.ip_addr.endswith(".1"):
+                results_callback(
+                    "PowerShellExploiter",
+                    host,
+                    ExploiterResultData(True, {}, {}, None),
+                )
+                results_callback(
+                    "SSHExploiter",
+                    host,
+                    ExploiterResultData(False, {}, {}, "SSH FAILED for .1"),
+                )
+            if host.ip_addr.endswith(".2"):
+                results_callback(
+                    "PowerShellExploiter",
+                    host,
+                    ExploiterResultData(False, {}, {}, "POWERSHELL FAILED for .2"),
+                )
+                results_callback(
+                    "SSHExploiter",
+                    host,
+                    ExploiterResultData(False, {}, {}, "SSH FAILED for .2"),
+                )
+            if host.ip_addr.endswith(".3"):
+                results_callback(
+                    "PowerShellExploiter",
+                    host,
+                    ExploiterResultData(False, {}, {}, "POWERSHELL FAILED for .3"),
+                )
+                results_callback(
+                    "SSHExploiter",
+                    host,
+                    ExploiterResultData(True, {}, {}, None),
+                )
+
+
+def test_exploiter_result_processing(telemetry_messenger_spy):
+    p = Propagator(telemetry_messenger_spy, MockIPScanner(), MockExploiter(), VictimHostFactory())
+    p.propagate(
+        {
+            "targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]},
+            "network_scan": {},  # This is empty since MockIPscanner ignores it
+            "exploiters": {},  # This is empty since MockExploiter ignores it
+        },
+        Event(),
+    )
+
+    exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)]
+    assert len(exploit_telems) == 4
+
+    for t in exploit_telems:
+        data = t.get_data()
+        ip = data["machine"]["ip_addr"]
+
+        assert ip.endswith(".1") or ip.endswith(".3")
+
+        if ip.endswith(".1"):
+            if data["exploiter"].startswith("PowerShell"):
+                assert data["result"]
+            else:
+                assert not data["result"]
+        elif ip.endswith(".3"):
+            if data["exploiter"].startswith("PowerShell"):
+                assert not data["result"]
+            else:
+                assert data["result"]