diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index b31c21550..721c1a243 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_master import IMaster -from infection_monkey.i_puppet import IPuppet +from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData, PortStatus from infection_monkey.model.host import VictimHost from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.post_breach_telem import PostBreachTelem @@ -185,13 +185,34 @@ class AutomatedMaster(IMaster): ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)] scan_config = propagation_config["network_scan"] - self._ip_scanner.scan(ips_to_scan, scan_config, self._handle_scanned_host, self._stop) + self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, self._stop) logger.info("Finished network scan") - def _handle_scanned_host(self, host: VictimHost): - self._hosts_to_exploit.put(host) - self._telemetry_messenger.send_telemetry(ScanTelem(host)) + def _process_scan_results( + self, ip: str, ping_scan_data: PingScanData, port_scan_data: PortScanData + ): + victim_host = VictimHost(ip) + has_open_port = False + + victim_host.icmp = ping_scan_data.response_received + if ping_scan_data.os is not None: + victim_host.os["type"] = ping_scan_data.os + + for psd in port_scan_data.values(): + if psd.status == PortStatus.OPEN: + has_open_port = True + + victim_host.services[psd.service] = {} + victim_host.services[psd.service]["display_name"] = "unknown(TCP)" + victim_host.services[psd.service]["port"] = psd.port + if psd.banner is not None: + victim_host.services[psd.service]["banner"] = psd.banner + + if has_open_port: + self._hosts_to_exploit.put(victim_host) + + self._telemetry_messenger.send_telemetry(ScanTelem(victim_host)) def _exploit_targets(self, scan_thread: Thread): pass diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 419931064..8073abad3 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -5,14 +5,13 @@ from queue import Queue from threading import Event from typing import Callable, Dict, List -from infection_monkey.i_puppet import IPuppet, PortStatus -from infection_monkey.model.host import VictimHost +from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData from .threading_utils import create_daemon_thread logger = logging.getLogger() -Callback = Callable[[VictimHost], None] +Callback = Callable[[str, PingScanData, Dict[int, PortScanData]], None] class IPScanner: @@ -45,12 +44,10 @@ class IPScanner: ip = ips.get_nowait() logger.info(f"Scanning {ip}") - victim_host = VictimHost(ip) + ping_scan_data = self._puppet.ping(ip, options["icmp"]) + port_scan_data = self._scan_tcp_ports(ip, options["tcp"], stop) - self._ping_ip(ip, victim_host, options["icmp"]) - self._scan_tcp_ports(ip, victim_host, options["tcp"], stop) - - results_callback(victim_host) + results_callback(ip, ping_scan_data, port_scan_data) except queue.Empty: logger.debug( @@ -60,22 +57,12 @@ class IPScanner: logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting") - def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict): - ping_scan_data = self._puppet.ping(ip, options) - - victim_host.icmp = ping_scan_data.response_received - if ping_scan_data.os is not None: - victim_host.os["type"] = ping_scan_data.os - - def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event): + def _scan_tcp_ports(self, ip: str, options: Dict, stop: Event): + port_scan_data = {} for p in options["ports"]: if stop.is_set(): break - port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) - if port_scan_data.status == PortStatus.OPEN: - victim_host.services[port_scan_data.service] = {} - victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)" - victim_host.services[port_scan_data.service]["port"] = port_scan_data.port - if port_scan_data.banner is not None: - victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner + port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) + + return port_scan_data diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py index 186d85be1..078a47593 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py @@ -1,14 +1,15 @@ from threading import Barrier, Event +from typing import Set from unittest.mock import MagicMock import pytest -from infection_monkey.i_puppet import PortScanData +from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.master import IPScanner from infection_monkey.puppet.mock_puppet import MockPuppet -WINDOWS_OS = {"type": "windows"} -LINUX_OS = {"type": "linux"} +WINDOWS_OS = "windows" +LINUX_OS = "linux" class MockPuppet(MockPuppet): @@ -47,35 +48,65 @@ def callback(): return MagicMock() -def assert_dot_1(victim_host): - assert victim_host.icmp is True - assert victim_host.os == WINDOWS_OS - - assert len(victim_host.services.keys()) == 2 - assert "tcp-445" in victim_host.services - assert victim_host.services["tcp-445"]["port"] == 445 - assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER" - assert "tcp-3389" in victim_host.services - assert victim_host.services["tcp-3389"]["port"] == 3389 +def assert_port_status(port_scan_data, expected_open_ports: Set[int]): + for psd in port_scan_data.values(): + if psd.port in expected_open_ports: + assert psd.status == PortStatus.OPEN + else: + assert psd.status == PortStatus.CLOSED -def assert_dot_3(victim_host): - assert victim_host.icmp is True - assert victim_host.os == LINUX_OS +def assert_dot_1(ip, ping_scan_data, port_scan_data): + assert ip == "10.0.0.1" - assert len(victim_host.services.keys()) == 2 - assert "tcp-22" in victim_host.services - assert victim_host.services["tcp-22"]["port"] == 22 - assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER" + assert ping_scan_data.response_received is True + assert ping_scan_data.os == WINDOWS_OS - assert "tcp-443" in victim_host.services - assert victim_host.services["tcp-443"]["port"] == 443 - assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER" + assert len(port_scan_data.keys()) == 6 + + psd_445 = port_scan_data[445] + psd_3389 = port_scan_data[3389] + + assert psd_445.status == PortStatus.OPEN + assert psd_445.port == 445 + assert psd_445.banner == "SMB BANNER" + assert psd_445.service == "tcp-445" + + assert psd_3389.status == PortStatus.OPEN + assert psd_3389.port == 3389 + assert psd_3389.banner == "" + assert psd_3389.service == "tcp-3389" + + assert_port_status(port_scan_data, {445, 3389}) -def assert_host_down(victim_host): - assert victim_host.icmp is False - assert len(victim_host.services.keys()) == 0 +def assert_dot_3(ip, ping_scan_data, port_scan_data): + assert ip == "10.0.0.3" + + assert ping_scan_data.response_received is True + assert ping_scan_data.os == LINUX_OS + assert len(port_scan_data.keys()) == 6 + + psd_443 = port_scan_data[443] + psd_22 = port_scan_data[22] + + assert psd_443.port == 443 + assert psd_443.banner == "HTTPS BANNER" + assert psd_443.service == "tcp-443" + + assert psd_22.port == 22 + assert psd_22.banner == "SSH BANNER" + assert psd_22.service == "tcp-22" + + assert_port_status(port_scan_data, {22, 443}) + + +def assert_host_down(ip, ping_scan_data, port_scan_data): + assert ip not in {"10.0.0.1", "10.0.0.3"} + + assert ping_scan_data.response_received is False + assert len(port_scan_data.keys()) == 6 + assert_port_status(port_scan_data, {}) def test_scan_single_ip(callback, scan_config, stop): @@ -86,7 +117,8 @@ def test_scan_single_ip(callback, scan_config, stop): callback.assert_called_once() - assert_dot_1(callback.call_args_list[0][0][0]) + print(type(callback.call_args_list[0][0])) + assert_dot_1(*(callback.call_args_list[0][0])) def test_scan_multiple_ips(callback, scan_config, stop): @@ -97,10 +129,10 @@ def test_scan_multiple_ips(callback, scan_config, stop): assert callback.call_count == 4 - assert_dot_1(callback.call_args_list[0][0][0]) - assert_host_down(callback.call_args_list[1][0][0]) - assert_dot_3(callback.call_args_list[2][0][0]) - assert_host_down(callback.call_args_list[3][0][0]) + assert_dot_1(*(callback.call_args_list[0][0])) + assert_host_down(*(callback.call_args_list[1][0])) + assert_dot_3(*(callback.call_args_list[2][0])) + assert_host_down(*(callback.call_args_list[3][0])) def test_scan_lots_of_ips(callback, scan_config, stop): @@ -113,7 +145,7 @@ def test_scan_lots_of_ips(callback, scan_config, stop): def test_stop_after_callback(scan_config, stop): - def _callback(_): + def _callback(_, __, ___): # Block all threads here until 2 threads reach this barrier, then set stop # and test that niether thread continues to scan. _callback.barrier.wait()