From 0dae58baaf5400a309eda47f5268682947a0090a Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Mon, 7 Feb 2022 15:50:39 +0100 Subject: [PATCH] Agent, UT: Change puppet interface to use scan_tcp_ports Instead of using scan_tcp_port and scan each port seperately we can use scan_tcp_ports which will recieve list of ports for the specific host and return dictionary of port:PortScanData items. There was no point of scanning each port seperately. --- monkey/infection_monkey/i_puppet/i_puppet.py | 14 ++++++++------ monkey/infection_monkey/master/ip_scanner.py | 12 +----------- monkey/infection_monkey/master/mock_master.py | 18 ++++++++++-------- monkey/infection_monkey/puppet/mock_puppet.py | 14 ++++++++------ monkey/infection_monkey/puppet/puppet.py | 8 +++++--- .../infection_monkey/master/test_ip_scanner.py | 15 ++++++++------- 6 files changed, 40 insertions(+), 41 deletions(-) diff --git a/monkey/infection_monkey/i_puppet/i_puppet.py b/monkey/infection_monkey/i_puppet/i_puppet.py index 3fa2aabd9..c0a42d95c 100644 --- a/monkey/infection_monkey/i_puppet/i_puppet.py +++ b/monkey/infection_monkey/i_puppet/i_puppet.py @@ -2,7 +2,7 @@ import abc import threading from collections import namedtuple from enum import Enum -from typing import Dict +from typing import Dict, List from . import PluginType @@ -64,14 +64,16 @@ class IPuppet(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData: + def scan_tcp_ports( + self, host: str, ports: List[int], timeout: float = 3 + ) -> Dict[int, PortScanData]: """ - Scans a TCP port on a remote host + Scans a list of TCP ports on a remote host :param str host: The domain name or IP address of a host - :param int port: A TCP port number to scan + :param int ports: List of TCP port numbers to scan :param float timeout: The maximum amount of time (in seconds) to wait for a response - :return: The data collected by scanning the provided host:port combination - :rtype: PortScanData + :return: The data collected by scanning the provided host:ports combination + :rtype: Dict[int, PortScanData] """ @abc.abstractmethod diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 0f7132a27..135f79c94 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -59,7 +59,7 @@ class IPScanner: logger.info(f"Scanning {address.ip}") ping_scan_data = self._puppet.ping(address.ip, icmp_timeout) - port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop) + port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout) fingerprint_data = {} if IPScanner.port_scan_found_open_port(port_scan_data): @@ -80,16 +80,6 @@ class IPScanner: f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting" ) - def _scan_tcp_ports( - self, ip: str, ports: List[int], timeout: float, stop: Event - ) -> Dict[int, PortScanData]: - port_scan_data = {} - - for p in interruptable_iter(ports, stop): - port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout) - - return port_scan_data - @staticmethod def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]): return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values()) diff --git a/monkey/infection_monkey/master/mock_master.py b/monkey/infection_monkey/master/mock_master.py index 31d4d83a7..274f960f8 100644 --- a/monkey/infection_monkey/master/mock_master.py +++ b/monkey/infection_monkey/master/mock_master.py @@ -70,14 +70,16 @@ class MockMaster(IMaster): if ping_scan_data.os is not None: h.os["type"] = ping_scan_data.os - for p in ports: - port_scan_data = self._puppet.scan_tcp_port(ip, p) - if port_scan_data.status == PortStatus.OPEN: - h.services[port_scan_data.service] = {} - h.services[port_scan_data.service]["display_name"] = "unknown(TCP)" - h.services[port_scan_data.service]["port"] = port_scan_data.port - if port_scan_data.banner is not None: - h.services[port_scan_data.service]["banner"] = port_scan_data.banner + ports_scan_data = self._puppet.scan_tcp_ports(ip, ports) + + for psd in ports_scan_data.values(): + logger.debug(f"The port {psd.port} is {psd.status}") + if psd.status == PortStatus.OPEN: + h.services[psd.service] = {} + h.services[psd.service]["display_name"] = "unknown(TCP)" + h.services[psd.service]["port"] = psd.port + if psd.banner is not None: + h.services[psd.service]["banner"] = psd.banner self._telemetry_messenger.send_telemetry(ScanTelem(h)) logger.info("Finished scanning network for potential victims") diff --git a/monkey/infection_monkey/puppet/mock_puppet.py b/monkey/infection_monkey/puppet/mock_puppet.py index 182ebe55e..d35ec2cbb 100644 --- a/monkey/infection_monkey/puppet/mock_puppet.py +++ b/monkey/infection_monkey/puppet/mock_puppet.py @@ -1,6 +1,6 @@ import logging import threading -from typing import Dict +from typing import Dict, List from infection_monkey.i_puppet import ( ExploiterResultData, @@ -177,8 +177,10 @@ class MockPuppet(IPuppet): return PingScanData(False, None) - def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData: - logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})") + def scan_tcp_ports( + self, host: str, ports: List[int], timeout: float = 3 + ) -> Dict[int, PortScanData]: + logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})") dot_1_results = { 22: PortScanData(22, PortStatus.CLOSED, None, None), 445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"), @@ -191,12 +193,12 @@ class MockPuppet(IPuppet): } if host == DOT_1: - return dot_1_results.get(port, _get_empty_results(port)) + return {port: dot_1_results.get(port, _get_empty_results(port)) for port in ports} if host == DOT_3: - return dot_3_results.get(port, _get_empty_results(port)) + return {port: dot_3_results.get(port, _get_empty_results(port)) for port in ports} - return _get_empty_results(port) + return {port: _get_empty_results(port) for port in ports} def fingerprint( self, diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index af550e4cb..175a3f0eb 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -1,6 +1,6 @@ import logging import threading -from typing import Dict +from typing import Dict, List from infection_monkey import network from infection_monkey.i_puppet import ( @@ -36,8 +36,10 @@ class Puppet(IPuppet): def ping(self, host: str, timeout: float = 1) -> PingScanData: return network.ping(host, timeout) - def scan_tcp_port(self, host: str, port: int, timeout: float = 3) -> PortScanData: - return self._mock_puppet.scan_tcp_port(host, port, timeout) + def scan_tcp_ports( + self, host: str, ports: List[int], timeout: float = 3 + ) -> Dict[int, PortScanData]: + return self._mock_puppet.scan_tcp_ports(host, ports, timeout) def fingerprint( self, diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py index 93762e44e..59bb6bf77 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py @@ -222,19 +222,20 @@ def test_stop_after_callback(scan_config, stop): assert stoppable_callback.call_count == 2 -def test_interrupt_port_scanning(callback, scan_config, stop): - def stoppable_scan_tcp_port(port, *_): +def test_interrupt_before_fingerprinting(callback, scan_config, stop): + def stoppable_scan_tcp_ports(port, *_): # Block all threads here until 2 threads reach this barrier, then set stop # and test that neither thread scans any more ports - stoppable_scan_tcp_port.barrier.wait() + stoppable_scan_tcp_ports.barrier.wait() stop.set() - return PortScanData(port, False, None, None) + return {port: PortScanData(port, False, None, None)} - stoppable_scan_tcp_port.barrier = Barrier(2) + stoppable_scan_tcp_ports.barrier = Barrier(2) puppet = MockPuppet() - puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port) + puppet.scan_tcp_ports = MagicMock(side_effect=stoppable_scan_tcp_ports) + puppet.fingerprint = MagicMock() addresses = [ NetworkAddress("10.0.0.1", None), @@ -246,7 +247,7 @@ def test_interrupt_port_scanning(callback, scan_config, stop): ns = IPScanner(puppet, num_workers=2) ns.scan(addresses, scan_config, callback, stop) - assert puppet.scan_tcp_port.call_count == 2 + puppet.fingerprint.assert_not_called() def test_interrupt_fingerprinting(callback, scan_config, stop):