Agent, UT: Change puppet interface to use scan_tcp_ports

Instead of using scan_tcp_port and scan each port seperately
we can use scan_tcp_ports which will recieve list of ports
for the specific host and return dictionary of port:PortScanData
items. There was no point of scanning each port seperately.
This commit is contained in:
Ilija Lazoroski 2022-02-07 15:50:39 +01:00 committed by Mike Salvatore
parent f07c876d31
commit 0dae58baaf
6 changed files with 40 additions and 41 deletions

View File

@ -2,7 +2,7 @@ import abc
import threading import threading
from collections import namedtuple from collections import namedtuple
from enum import Enum from enum import Enum
from typing import Dict from typing import Dict, List
from . import PluginType from . import PluginType
@ -64,14 +64,16 @@ class IPuppet(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData: def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
""" """
Scans a TCP port on a remote host Scans a list of TCP ports on a remote host
:param str host: The domain name or IP address of a host :param str host: The domain name or IP address of a host
:param int port: A TCP port number to scan :param int ports: List of TCP port numbers to scan
:param float timeout: The maximum amount of time (in seconds) to wait for a response :param float timeout: The maximum amount of time (in seconds) to wait for a response
:return: The data collected by scanning the provided host:port combination :return: The data collected by scanning the provided host:ports combination
:rtype: PortScanData :rtype: Dict[int, PortScanData]
""" """
@abc.abstractmethod @abc.abstractmethod

View File

@ -59,7 +59,7 @@ class IPScanner:
logger.info(f"Scanning {address.ip}") logger.info(f"Scanning {address.ip}")
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout) ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop) port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout)
fingerprint_data = {} fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data): if IPScanner.port_scan_found_open_port(port_scan_data):
@ -80,16 +80,6 @@ class IPScanner:
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting" f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
) )
def _scan_tcp_ports(
self, ip: str, ports: List[int], timeout: float, stop: Event
) -> Dict[int, PortScanData]:
port_scan_data = {}
for p in interruptable_iter(ports, stop):
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data
@staticmethod @staticmethod
def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]): def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]):
return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values()) return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values())

View File

@ -70,14 +70,16 @@ class MockMaster(IMaster):
if ping_scan_data.os is not None: if ping_scan_data.os is not None:
h.os["type"] = ping_scan_data.os h.os["type"] = ping_scan_data.os
for p in ports: ports_scan_data = self._puppet.scan_tcp_ports(ip, ports)
port_scan_data = self._puppet.scan_tcp_port(ip, p)
if port_scan_data.status == PortStatus.OPEN: for psd in ports_scan_data.values():
h.services[port_scan_data.service] = {} logger.debug(f"The port {psd.port} is {psd.status}")
h.services[port_scan_data.service]["display_name"] = "unknown(TCP)" if psd.status == PortStatus.OPEN:
h.services[port_scan_data.service]["port"] = port_scan_data.port h.services[psd.service] = {}
if port_scan_data.banner is not None: h.services[psd.service]["display_name"] = "unknown(TCP)"
h.services[port_scan_data.service]["banner"] = port_scan_data.banner h.services[psd.service]["port"] = psd.port
if psd.banner is not None:
h.services[psd.service]["banner"] = psd.banner
self._telemetry_messenger.send_telemetry(ScanTelem(h)) self._telemetry_messenger.send_telemetry(ScanTelem(h))
logger.info("Finished scanning network for potential victims") logger.info("Finished scanning network for potential victims")

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict from typing import Dict, List
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
@ -177,8 +177,10 @@ class MockPuppet(IPuppet):
return PingScanData(False, None) return PingScanData(False, None)
def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData: def scan_tcp_ports(
logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})") self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})")
dot_1_results = { dot_1_results = {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, None, None),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"), 445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
@ -191,12 +193,12 @@ class MockPuppet(IPuppet):
} }
if host == DOT_1: if host == DOT_1:
return dot_1_results.get(port, _get_empty_results(port)) return {port: dot_1_results.get(port, _get_empty_results(port)) for port in ports}
if host == DOT_3: if host == DOT_3:
return dot_3_results.get(port, _get_empty_results(port)) return {port: dot_3_results.get(port, _get_empty_results(port)) for port in ports}
return _get_empty_results(port) return {port: _get_empty_results(port) for port in ports}
def fingerprint( def fingerprint(
self, self,

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict from typing import Dict, List
from infection_monkey import network from infection_monkey import network
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
@ -36,8 +36,10 @@ class Puppet(IPuppet):
def ping(self, host: str, timeout: float = 1) -> PingScanData: def ping(self, host: str, timeout: float = 1) -> PingScanData:
return network.ping(host, timeout) return network.ping(host, timeout)
def scan_tcp_port(self, host: str, port: int, timeout: float = 3) -> PortScanData: def scan_tcp_ports(
return self._mock_puppet.scan_tcp_port(host, port, timeout) self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
return self._mock_puppet.scan_tcp_ports(host, ports, timeout)
def fingerprint( def fingerprint(
self, self,

View File

@ -222,19 +222,20 @@ def test_stop_after_callback(scan_config, stop):
assert stoppable_callback.call_count == 2 assert stoppable_callback.call_count == 2
def test_interrupt_port_scanning(callback, scan_config, stop): def test_interrupt_before_fingerprinting(callback, scan_config, stop):
def stoppable_scan_tcp_port(port, *_): def stoppable_scan_tcp_ports(port, *_):
# Block all threads here until 2 threads reach this barrier, then set stop # Block all threads here until 2 threads reach this barrier, then set stop
# and test that neither thread scans any more ports # and test that neither thread scans any more ports
stoppable_scan_tcp_port.barrier.wait() stoppable_scan_tcp_ports.barrier.wait()
stop.set() stop.set()
return PortScanData(port, False, None, None) return {port: PortScanData(port, False, None, None)}
stoppable_scan_tcp_port.barrier = Barrier(2) stoppable_scan_tcp_ports.barrier = Barrier(2)
puppet = MockPuppet() puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port) puppet.scan_tcp_ports = MagicMock(side_effect=stoppable_scan_tcp_ports)
puppet.fingerprint = MagicMock()
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", None),
@ -246,7 +247,7 @@ def test_interrupt_port_scanning(callback, scan_config, stop):
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(addresses, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2 puppet.fingerprint.assert_not_called()
def test_interrupt_fingerprinting(callback, scan_config, stop): def test_interrupt_fingerprinting(callback, scan_config, stop):