Agent, UT: Change puppet interface to use scan_tcp_ports

Instead of using scan_tcp_port and scan each port seperately
we can use scan_tcp_ports which will recieve list of ports
for the specific host and return dictionary of port:PortScanData
items. There was no point of scanning each port seperately.
This commit is contained in:
Ilija Lazoroski 2022-02-07 15:50:39 +01:00 committed by Mike Salvatore
parent f07c876d31
commit 0dae58baaf
6 changed files with 40 additions and 41 deletions

View File

@ -2,7 +2,7 @@ import abc
import threading
from collections import namedtuple
from enum import Enum
from typing import Dict
from typing import Dict, List
from . import PluginType
@ -64,14 +64,16 @@ class IPuppet(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData:
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
"""
Scans a TCP port on a remote host
Scans a list of TCP ports on a remote host
:param str host: The domain name or IP address of a host
:param int port: A TCP port number to scan
:param int ports: List of TCP port numbers to scan
:param float timeout: The maximum amount of time (in seconds) to wait for a response
:return: The data collected by scanning the provided host:port combination
:rtype: PortScanData
:return: The data collected by scanning the provided host:ports combination
:rtype: Dict[int, PortScanData]
"""
@abc.abstractmethod

View File

@ -59,7 +59,7 @@ class IPScanner:
logger.info(f"Scanning {address.ip}")
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout)
fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data):
@ -80,16 +80,6 @@ class IPScanner:
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
)
def _scan_tcp_ports(
self, ip: str, ports: List[int], timeout: float, stop: Event
) -> Dict[int, PortScanData]:
port_scan_data = {}
for p in interruptable_iter(ports, stop):
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data
@staticmethod
def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]):
return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values())

View File

@ -70,14 +70,16 @@ class MockMaster(IMaster):
if ping_scan_data.os is not None:
h.os["type"] = ping_scan_data.os
for p in ports:
port_scan_data = self._puppet.scan_tcp_port(ip, p)
if port_scan_data.status == PortStatus.OPEN:
h.services[port_scan_data.service] = {}
h.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
h.services[port_scan_data.service]["port"] = port_scan_data.port
if port_scan_data.banner is not None:
h.services[port_scan_data.service]["banner"] = port_scan_data.banner
ports_scan_data = self._puppet.scan_tcp_ports(ip, ports)
for psd in ports_scan_data.values():
logger.debug(f"The port {psd.port} is {psd.status}")
if psd.status == PortStatus.OPEN:
h.services[psd.service] = {}
h.services[psd.service]["display_name"] = "unknown(TCP)"
h.services[psd.service]["port"] = psd.port
if psd.banner is not None:
h.services[psd.service]["banner"] = psd.banner
self._telemetry_messenger.send_telemetry(ScanTelem(h))
logger.info("Finished scanning network for potential victims")

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict
from typing import Dict, List
from infection_monkey.i_puppet import (
ExploiterResultData,
@ -177,8 +177,10 @@ class MockPuppet(IPuppet):
return PingScanData(False, None)
def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData:
logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})")
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})")
dot_1_results = {
22: PortScanData(22, PortStatus.CLOSED, None, None),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
@ -191,12 +193,12 @@ class MockPuppet(IPuppet):
}
if host == DOT_1:
return dot_1_results.get(port, _get_empty_results(port))
return {port: dot_1_results.get(port, _get_empty_results(port)) for port in ports}
if host == DOT_3:
return dot_3_results.get(port, _get_empty_results(port))
return {port: dot_3_results.get(port, _get_empty_results(port)) for port in ports}
return _get_empty_results(port)
return {port: _get_empty_results(port) for port in ports}
def fingerprint(
self,

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict
from typing import Dict, List
from infection_monkey import network
from infection_monkey.i_puppet import (
@ -36,8 +36,10 @@ class Puppet(IPuppet):
def ping(self, host: str, timeout: float = 1) -> PingScanData:
return network.ping(host, timeout)
def scan_tcp_port(self, host: str, port: int, timeout: float = 3) -> PortScanData:
return self._mock_puppet.scan_tcp_port(host, port, timeout)
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
return self._mock_puppet.scan_tcp_ports(host, ports, timeout)
def fingerprint(
self,

View File

@ -222,19 +222,20 @@ def test_stop_after_callback(scan_config, stop):
assert stoppable_callback.call_count == 2
def test_interrupt_port_scanning(callback, scan_config, stop):
def stoppable_scan_tcp_port(port, *_):
def test_interrupt_before_fingerprinting(callback, scan_config, stop):
def stoppable_scan_tcp_ports(port, *_):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that neither thread scans any more ports
stoppable_scan_tcp_port.barrier.wait()
stoppable_scan_tcp_ports.barrier.wait()
stop.set()
return PortScanData(port, False, None, None)
return {port: PortScanData(port, False, None, None)}
stoppable_scan_tcp_port.barrier = Barrier(2)
stoppable_scan_tcp_ports.barrier = Barrier(2)
puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port)
puppet.scan_tcp_ports = MagicMock(side_effect=stoppable_scan_tcp_ports)
puppet.fingerprint = MagicMock()
addresses = [
NetworkAddress("10.0.0.1", None),
@ -246,7 +247,7 @@ def test_interrupt_port_scanning(callback, scan_config, stop):
ns = IPScanner(puppet, num_workers=2)
ns.scan(addresses, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2
puppet.fingerprint.assert_not_called()
def test_interrupt_fingerprinting(callback, scan_config, stop):