From 81d4afab5248845f7bdb0c79700b6f873636d85c Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 9 Dec 2021 20:59:42 -0500 Subject: [PATCH] Agent: Extract network scanner into its own class --- monkey/infection_monkey/master/__init__.py | 1 + .../master/automated_master.py | 89 +++-------- monkey/infection_monkey/master/ip_scanner.py | 100 ++++++++++++ .../master/test_network_scanner.py | 146 ++++++++++++++++++ 4 files changed, 270 insertions(+), 66 deletions(-) create mode 100644 monkey/infection_monkey/master/ip_scanner.py create mode 100644 monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py diff --git a/monkey/infection_monkey/master/__init__.py b/monkey/infection_monkey/master/__init__.py index 6d3942abd..bf8e1775c 100644 --- a/monkey/infection_monkey/master/__init__.py +++ b/monkey/infection_monkey/master/__init__.py @@ -1 +1,2 @@ +from .ip_scanner import IPScanner from .automated_master import AutomatedMaster diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index 7a72faca0..bc304d2d8 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -1,5 +1,4 @@ import logging -import queue import threading import time from queue import Queue @@ -8,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_master import IMaster -from infection_monkey.i_puppet import IPuppet, PortStatus +from infection_monkey.i_puppet import IPuppet from infection_monkey.model.host import VictimHost from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.post_breach_telem import PostBreachTelem @@ -16,6 +15,7 @@ from infection_monkey.telemetry.scan_telem import ScanTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.utils.timer import Timer +from . import IPScanner from .threading_utils import create_daemon_thread CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 @@ -37,6 +37,9 @@ class AutomatedMaster(IMaster): self._telemetry_messenger = telemetry_messenger self._control_channel = control_channel + self._ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS) + self._hosts_to_exploit = None + self._stop = threading.Event() self._master_thread = create_daemon_thread(target=self._run_master_thread) self._simulation_thread = create_daemon_thread(target=self._run_simulation) @@ -156,17 +159,16 @@ class AutomatedMaster(IMaster): def _can_propagate(self): return True + # TODO: Refactor propagation into its own class def _propagate(self, config: Dict): logger.info("Attempting to propagate") - hosts_to_exploit = Queue() + self._hosts_to_exploit = Queue() scan_thread = create_daemon_thread( - target=self._scan_network, args=(config["network_scan"], hosts_to_exploit) - ) - exploit_thread = create_daemon_thread( - target=self._exploit_targets, args=(hosts_to_exploit, scan_thread) + target=self._scan_network, args=(config["network_scan"],) ) + exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,)) scan_thread.start() exploit_thread.start() @@ -176,73 +178,28 @@ class AutomatedMaster(IMaster): logger.info("Finished attempting to propagate") - def _exploit_targets(self, hosts_to_exploit: Queue, scan_thread: Thread): - pass - - # TODO: Refactor this into its own class - def _scan_network(self, scan_config: Dict, hosts_to_exploit: Queue): + def _scan_network(self, scan_config: Dict): logger.info("Starting network scan") # TODO: Generate list of IPs to scan - ips_to_scan = Queue() - for i in range(1, 255): - ips_to_scan.put(f"10.0.0.{i}") + ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)] - scan_threads = [] - for i in range(0, NUM_SCAN_THREADS): - t = create_daemon_thread( - target=self._scan_ips, args=(ips_to_scan, scan_config, hosts_to_exploit) - ) - t.start() - scan_threads.append(t) - - for t in scan_threads: - t.join() + self._ip_scanner.scan( + ips_to_scan, + scan_config["icmp"], + scan_config["tcp"], + self._handle_scanned_host, + self._stop, + ) logger.info("Finished network scan") - def _scan_ips(self, ips_to_scan: Queue, scan_config: Dict, hosts_to_exploit: Queue): - logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") - try: - while not self._stop.is_set(): - ip = ips_to_scan.get_nowait() - logger.info(f"Scanning {ip}") + def _handle_scanned_host(self, host: VictimHost): + self._hosts_to_exploit.put(host) + self._telemetry_messenger.send_telemetry(ScanTelem(host)) - victim_host = VictimHost(ip) - - self._ping_ip(ip, victim_host, scan_config["icmp"]) - self._scan_tcp_ports(ip, victim_host, scan_config["tcp"]) - - hosts_to_exploit.put(hosts_to_exploit) - self._telemetry_messenger.send_telemetry(ScanTelem(victim_host)) - - except queue.Empty: - logger.debug( - f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting" - ) - - logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting") - - def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict): - (response_received, os) = self._puppet.ping(ip, options) - - victim_host.icmp = response_received - if os is not None: - victim_host.os["type"] = os - - def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict): - for p in options["ports"]: - if self._stop.is_set(): - break - - # TODO: check units of timeout - port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) - if port_scan_data.status == PortStatus.OPEN: - victim_host.services[port_scan_data.service] = {} - victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)" - victim_host.services[port_scan_data.service]["port"] = port_scan_data.port - if port_scan_data.banner is not None: - victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner + def _exploit_targets(self, scan_thread: Thread): + pass def _run_payload(self, payload: Tuple[str, Dict]): name = payload[0] diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py new file mode 100644 index 000000000..61329ef5d --- /dev/null +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -0,0 +1,100 @@ +import logging +import queue +import threading +from queue import Queue +from threading import Event +from typing import Callable, Dict, List + +from infection_monkey.i_puppet import IPuppet, PortStatus +from infection_monkey.model.host import VictimHost + +from .threading_utils import create_daemon_thread + +logger = logging.getLogger() + +Callback = Callable[[VictimHost], None] + + +class IPScanner: + def __init__(self, puppet: IPuppet, num_workers: int): + self._puppet = puppet + self._num_workers = num_workers + + def scan( + self, + ips: List[str], + icmp_config: Dict, + tcp_config: Dict, + report_results_callback: Callback, + stop: Event, + ): + # Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty. + ips_to_scan = Queue() + for ip in ips: + ips_to_scan.put(ip) + + scan_ips_args = ( + ips_to_scan, + icmp_config, + tcp_config, + report_results_callback, + stop, + ) + scan_threads = [] + for i in range(0, self._num_workers): + t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args) + t.start() + scan_threads.append(t) + + for t in scan_threads: + t.join() + + def _scan_ips( + self, + ips_to_scan: Queue, + icmp_config: Dict, + tcp_config: Dict, + report_results_callback: Callback, + stop: Event, + ): + logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") + + try: + while not stop.is_set(): + ip = ips_to_scan.get_nowait() + logger.info(f"Scanning {ip}") + + victim_host = VictimHost(ip) + + self._ping_ip(ip, victim_host, icmp_config) + self._scan_tcp_ports(ip, victim_host, tcp_config, stop) + + report_results_callback(victim_host) + + except queue.Empty: + logger.debug( + f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting" + ) + return + + logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting") + + def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict): + (response_received, os) = self._puppet.ping(ip, options) + + victim_host.icmp = response_received + if os is not None: + victim_host.os["type"] = os + + def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event): + for p in options["ports"]: + if stop.is_set(): + break + + port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) + if port_scan_data.status == PortStatus.OPEN: + victim_host.services[port_scan_data.service] = {} + victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)" + victim_host.services[port_scan_data.service]["port"] = port_scan_data.port + if port_scan_data.banner is not None: + victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py new file mode 100644 index 000000000..f73b5f39a --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py @@ -0,0 +1,146 @@ +from threading import Barrier, Event +from unittest.mock import MagicMock + +import pytest + +from infection_monkey.i_puppet import PortScanData +from infection_monkey.master import IPScanner +from infection_monkey.puppet.mock_puppet import MockPuppet + +WINDOWS_OS = {"type": "windows"} +LINUX_OS = {"type": "linux"} + + +class MockPuppet(MockPuppet): + def __init__(self): + self.ping = MagicMock(side_effect=super().ping) + self.scan_tcp_port = MagicMock(side_effect=super().scan_tcp_port) + + +@pytest.fixture +def tcp_scan_config(): + return { + "timeout_ms": 3000, + "ports": [ + 22, + 445, + 3389, + 443, + 8008, + 3306, + ], + } + + +@pytest.fixture +def icmp_scan_config(): + return { + "timeout_ms": 1000, + } + + +@pytest.fixture +def stop(): + return Event() + + +@pytest.fixture +def callback(): + return MagicMock() + + +def assert_dot_1(victim_host): + assert victim_host.icmp is True + assert victim_host.os == WINDOWS_OS + + assert len(victim_host.services.keys()) == 2 + assert "tcp-445" in victim_host.services + assert victim_host.services["tcp-445"]["port"] == 445 + assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER" + assert "tcp-3389" in victim_host.services + assert victim_host.services["tcp-3389"]["port"] == 3389 + + +def assert_dot_3(victim_host): + assert victim_host.icmp is True + assert victim_host.os == LINUX_OS + + assert len(victim_host.services.keys()) == 2 + assert "tcp-22" in victim_host.services + assert victim_host.services["tcp-22"]["port"] == 22 + assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER" + + assert "tcp-443" in victim_host.services + assert victim_host.services["tcp-443"]["port"] == 443 + assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER" + + +def assert_host_down(victim_host): + assert victim_host.icmp is False + assert len(victim_host.services.keys()) == 0 + + +def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop): + ips = ["10.0.0.1"] + + ns = IPScanner(MockPuppet(), num_workers=1) + ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + + callback.assert_called_once() + + assert_dot_1(callback.call_args_list[0][0][0]) + + +def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop): + ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] + + ns = IPScanner(MockPuppet(), num_workers=4) + ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + + assert callback.call_count == 4 + + assert_dot_1(callback.call_args_list[0][0][0]) + assert_host_down(callback.call_args_list[1][0][0]) + assert_dot_3(callback.call_args_list[2][0][0]) + assert_host_down(callback.call_args_list[3][0][0]) + + +def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop): + def _callback(_): + # Block all threads here until 2 threads reach this barrier, then set stop + # and test that niether thread continues to scan. + _callback.barrier.wait() + stop.set() + + _callback.barrier = Barrier(2) + + stopable_callback = MagicMock(side_effect=_callback) + + ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] + + ns = IPScanner(MockPuppet(), num_workers=2) + ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop) + + assert stopable_callback.call_count == 2 + + +def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop): + def stopable_scan_tcp_port(port, _, __): + # Block all threads here until 2 threads reach this barrier, then set stop + # and test that niether thread scans any more ports + stopable_scan_tcp_port.barrier.wait() + stop.set() + + return PortScanData(port, False, None, None) + + stopable_scan_tcp_port.barrier = Barrier(2) + + puppet = MockPuppet() + puppet.scan_tcp_port = MagicMock(side_effect=stopable_scan_tcp_port) + + ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] + + ns = IPScanner(puppet, num_workers=2) + ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + + assert puppet.scan_tcp_port.call_count == 2