Agent, UT: Change puppet interface to use scan_tcp_ports
Instead of using scan_tcp_port and scan each port seperately we can use scan_tcp_ports which will recieve list of ports for the specific host and return dictionary of port:PortScanData items. There was no point of scanning each port seperately.
This commit is contained in:
parent
f07c876d31
commit
0dae58baaf
|
@ -2,7 +2,7 @@ import abc
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from . import PluginType
|
from . import PluginType
|
||||||
|
|
||||||
|
@ -64,14 +64,16 @@ class IPuppet(metaclass=abc.ABCMeta):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData:
|
def scan_tcp_ports(
|
||||||
|
self, host: str, ports: List[int], timeout: float = 3
|
||||||
|
) -> Dict[int, PortScanData]:
|
||||||
"""
|
"""
|
||||||
Scans a TCP port on a remote host
|
Scans a list of TCP ports on a remote host
|
||||||
:param str host: The domain name or IP address of a host
|
:param str host: The domain name or IP address of a host
|
||||||
:param int port: A TCP port number to scan
|
:param int ports: List of TCP port numbers to scan
|
||||||
:param float timeout: The maximum amount of time (in seconds) to wait for a response
|
:param float timeout: The maximum amount of time (in seconds) to wait for a response
|
||||||
:return: The data collected by scanning the provided host:port combination
|
:return: The data collected by scanning the provided host:ports combination
|
||||||
:rtype: PortScanData
|
:rtype: Dict[int, PortScanData]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|
|
@ -59,7 +59,7 @@ class IPScanner:
|
||||||
logger.info(f"Scanning {address.ip}")
|
logger.info(f"Scanning {address.ip}")
|
||||||
|
|
||||||
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
|
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
|
||||||
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
|
port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout)
|
||||||
|
|
||||||
fingerprint_data = {}
|
fingerprint_data = {}
|
||||||
if IPScanner.port_scan_found_open_port(port_scan_data):
|
if IPScanner.port_scan_found_open_port(port_scan_data):
|
||||||
|
@ -80,16 +80,6 @@ class IPScanner:
|
||||||
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _scan_tcp_ports(
|
|
||||||
self, ip: str, ports: List[int], timeout: float, stop: Event
|
|
||||||
) -> Dict[int, PortScanData]:
|
|
||||||
port_scan_data = {}
|
|
||||||
|
|
||||||
for p in interruptable_iter(ports, stop):
|
|
||||||
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
|
|
||||||
|
|
||||||
return port_scan_data
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]):
|
def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]):
|
||||||
return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values())
|
return any(psd.status == PortStatus.OPEN for psd in port_scan_data.values())
|
||||||
|
|
|
@ -70,14 +70,16 @@ class MockMaster(IMaster):
|
||||||
if ping_scan_data.os is not None:
|
if ping_scan_data.os is not None:
|
||||||
h.os["type"] = ping_scan_data.os
|
h.os["type"] = ping_scan_data.os
|
||||||
|
|
||||||
for p in ports:
|
ports_scan_data = self._puppet.scan_tcp_ports(ip, ports)
|
||||||
port_scan_data = self._puppet.scan_tcp_port(ip, p)
|
|
||||||
if port_scan_data.status == PortStatus.OPEN:
|
for psd in ports_scan_data.values():
|
||||||
h.services[port_scan_data.service] = {}
|
logger.debug(f"The port {psd.port} is {psd.status}")
|
||||||
h.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
|
if psd.status == PortStatus.OPEN:
|
||||||
h.services[port_scan_data.service]["port"] = port_scan_data.port
|
h.services[psd.service] = {}
|
||||||
if port_scan_data.banner is not None:
|
h.services[psd.service]["display_name"] = "unknown(TCP)"
|
||||||
h.services[port_scan_data.service]["banner"] = port_scan_data.banner
|
h.services[psd.service]["port"] = psd.port
|
||||||
|
if psd.banner is not None:
|
||||||
|
h.services[psd.service]["banner"] = psd.banner
|
||||||
|
|
||||||
self._telemetry_messenger.send_telemetry(ScanTelem(h))
|
self._telemetry_messenger.send_telemetry(ScanTelem(h))
|
||||||
logger.info("Finished scanning network for potential victims")
|
logger.info("Finished scanning network for potential victims")
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from infection_monkey.i_puppet import (
|
from infection_monkey.i_puppet import (
|
||||||
ExploiterResultData,
|
ExploiterResultData,
|
||||||
|
@ -177,8 +177,10 @@ class MockPuppet(IPuppet):
|
||||||
|
|
||||||
return PingScanData(False, None)
|
return PingScanData(False, None)
|
||||||
|
|
||||||
def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData:
|
def scan_tcp_ports(
|
||||||
logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})")
|
self, host: str, ports: List[int], timeout: float = 3
|
||||||
|
) -> Dict[int, PortScanData]:
|
||||||
|
logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})")
|
||||||
dot_1_results = {
|
dot_1_results = {
|
||||||
22: PortScanData(22, PortStatus.CLOSED, None, None),
|
22: PortScanData(22, PortStatus.CLOSED, None, None),
|
||||||
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
|
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
|
||||||
|
@ -191,12 +193,12 @@ class MockPuppet(IPuppet):
|
||||||
}
|
}
|
||||||
|
|
||||||
if host == DOT_1:
|
if host == DOT_1:
|
||||||
return dot_1_results.get(port, _get_empty_results(port))
|
return {port: dot_1_results.get(port, _get_empty_results(port)) for port in ports}
|
||||||
|
|
||||||
if host == DOT_3:
|
if host == DOT_3:
|
||||||
return dot_3_results.get(port, _get_empty_results(port))
|
return {port: dot_3_results.get(port, _get_empty_results(port)) for port in ports}
|
||||||
|
|
||||||
return _get_empty_results(port)
|
return {port: _get_empty_results(port) for port in ports}
|
||||||
|
|
||||||
def fingerprint(
|
def fingerprint(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from infection_monkey import network
|
from infection_monkey import network
|
||||||
from infection_monkey.i_puppet import (
|
from infection_monkey.i_puppet import (
|
||||||
|
@ -36,8 +36,10 @@ class Puppet(IPuppet):
|
||||||
def ping(self, host: str, timeout: float = 1) -> PingScanData:
|
def ping(self, host: str, timeout: float = 1) -> PingScanData:
|
||||||
return network.ping(host, timeout)
|
return network.ping(host, timeout)
|
||||||
|
|
||||||
def scan_tcp_port(self, host: str, port: int, timeout: float = 3) -> PortScanData:
|
def scan_tcp_ports(
|
||||||
return self._mock_puppet.scan_tcp_port(host, port, timeout)
|
self, host: str, ports: List[int], timeout: float = 3
|
||||||
|
) -> Dict[int, PortScanData]:
|
||||||
|
return self._mock_puppet.scan_tcp_ports(host, ports, timeout)
|
||||||
|
|
||||||
def fingerprint(
|
def fingerprint(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -222,19 +222,20 @@ def test_stop_after_callback(scan_config, stop):
|
||||||
assert stoppable_callback.call_count == 2
|
assert stoppable_callback.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
def test_interrupt_port_scanning(callback, scan_config, stop):
|
def test_interrupt_before_fingerprinting(callback, scan_config, stop):
|
||||||
def stoppable_scan_tcp_port(port, *_):
|
def stoppable_scan_tcp_ports(port, *_):
|
||||||
# Block all threads here until 2 threads reach this barrier, then set stop
|
# Block all threads here until 2 threads reach this barrier, then set stop
|
||||||
# and test that neither thread scans any more ports
|
# and test that neither thread scans any more ports
|
||||||
stoppable_scan_tcp_port.barrier.wait()
|
stoppable_scan_tcp_ports.barrier.wait()
|
||||||
stop.set()
|
stop.set()
|
||||||
|
|
||||||
return PortScanData(port, False, None, None)
|
return {port: PortScanData(port, False, None, None)}
|
||||||
|
|
||||||
stoppable_scan_tcp_port.barrier = Barrier(2)
|
stoppable_scan_tcp_ports.barrier = Barrier(2)
|
||||||
|
|
||||||
puppet = MockPuppet()
|
puppet = MockPuppet()
|
||||||
puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port)
|
puppet.scan_tcp_ports = MagicMock(side_effect=stoppable_scan_tcp_ports)
|
||||||
|
puppet.fingerprint = MagicMock()
|
||||||
|
|
||||||
addresses = [
|
addresses = [
|
||||||
NetworkAddress("10.0.0.1", None),
|
NetworkAddress("10.0.0.1", None),
|
||||||
|
@ -246,7 +247,7 @@ def test_interrupt_port_scanning(callback, scan_config, stop):
|
||||||
ns = IPScanner(puppet, num_workers=2)
|
ns = IPScanner(puppet, num_workers=2)
|
||||||
ns.scan(addresses, scan_config, callback, stop)
|
ns.scan(addresses, scan_config, callback, stop)
|
||||||
|
|
||||||
assert puppet.scan_tcp_port.call_count == 2
|
puppet.fingerprint.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_interrupt_fingerprinting(callback, scan_config, stop):
|
def test_interrupt_fingerprinting(callback, scan_config, stop):
|
||||||
|
|
Loading…
Reference in New Issue