Agent: Add IPScanResults dataclass

This commit is contained in:
Mike Salvatore 2021-12-13 13:28:40 -05:00
parent 8067dc9ff8
commit d51af8a583
6 changed files with 52 additions and 43 deletions

View File

@ -1,3 +1,4 @@
from .ip_scan_results import IPScanResults
from .ip_scanner import IPScanner from .ip_scanner import IPScanner
from .propagator import Propagator from .propagator import Propagator
from .automated_master import AutomatedMaster from .automated_master import AutomatedMaster

View File

@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Dict
from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData
Port = int
FingerprinterName = str
@dataclass
class IPScanResults:
ping_scan_data: PingScanData
port_scan_data: Dict[Port, PortScanData]
fingerprint_data: Dict[FingerprinterName, FingerprintData]

View File

@ -5,24 +5,15 @@ from queue import Queue
from threading import Event from threading import Event
from typing import Callable, Dict, List from typing import Callable, Dict, List
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import FingerprintData, IPuppet, PortScanData, PortStatus
FingerprintData,
IPuppet,
PingScanData,
PortScanData,
PortStatus,
)
from . import IPScanResults
from .threading_utils import create_daemon_thread from .threading_utils import create_daemon_thread
logger = logging.getLogger() logger = logging.getLogger()
IP = str IP = str
Port = int Callback = Callable[[IP, IPScanResults], None]
FingerprinterName = str
Callback = Callable[
[IP, PingScanData, Dict[Port, PortScanData], Dict[FingerprinterName, FingerprintData]], None
]
class IPScanner: class IPScanner:
@ -67,7 +58,8 @@ class IPScanner:
fingerprinters = options["fingerprinters"] fingerprinters = options["fingerprinters"]
fingerprint_data = self._run_fingerprinters(ip, fingerprinters, stop) fingerprint_data = self._run_fingerprinters(ip, fingerprinters, stop)
results_callback(ip, ping_scan_data, port_scan_data, fingerprint_data) scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
results_callback(ip, scan_results)
logger.debug( logger.debug(
f"Detected the stop signal, scanning thread {threading.get_ident()} exiting" f"Detected the stop signal, scanning thread {threading.get_ident()} exiting"
@ -99,7 +91,9 @@ class IPScanner:
return False return False
def _run_fingerprinters(self, ip: str, fingerprinters: List[str], stop: Event): def _run_fingerprinters(
self, ip: str, fingerprinters: List[str], stop: Event
) -> Dict[str, FingerprintData]:
fingerprint_data = {} fingerprint_data = {}
for f in fingerprinters: for f in fingerprinters:

View File

@ -8,7 +8,7 @@ from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem from infection_monkey.telemetry.scan_telem import ScanTelem
from . import IPScanner from . import IPScanner, IPScanResults
from .threading_utils import create_daemon_thread from .threading_utils import create_daemon_thread
logger = logging.getLogger() logger = logging.getLogger()
@ -51,18 +51,14 @@ class Propagator:
logger.info("Finished network scan") logger.info("Finished network scan")
def _process_scan_results( def _process_scan_results(self, ip: str, scan_results: IPScanResults):
self,
ip: str,
ping_scan_data: PingScanData,
port_scan_data: Dict[int, PortScanData],
fingerprint_data: Dict[str, FingerprintData],
):
victim_host = VictimHost(ip) victim_host = VictimHost(ip)
Propagator._process_ping_scan_results(victim_host, ping_scan_data) Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data)
has_open_port = Propagator._process_tcp_scan_results(victim_host, port_scan_data) has_open_port = Propagator._process_tcp_scan_results(
Propagator._process_fingerprinter_results(victim_host, fingerprint_data) victim_host, scan_results.port_scan_data
)
Propagator._process_fingerprinter_results(victim_host, scan_results.fingerprint_data)
if has_open_port: if has_open_port:
self._hosts_to_exploit.put(victim_host) self._hosts_to_exploit.put(victim_host)

View File

@ -51,7 +51,11 @@ def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
assert psd.status == PortStatus.CLOSED assert psd.status == PortStatus.CLOSED
def assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results(ip, scan_results):
ping_scan_data = scan_results.ping_scan_data
port_scan_data = scan_results.port_scan_data
fingerprint_data = scan_results.fingerprint_data
if ip == "10.0.0.1": if ip == "10.0.0.1":
assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data)
elif ip == "10.0.0.3": elif ip == "10.0.0.3":
@ -149,8 +153,8 @@ def test_scan_single_ip(callback, scan_config, stop):
callback.assert_called_once() callback.assert_called_once()
(ip, ping_scan_data, port_scan_data, fingerprint_data) = callback.call_args_list[0][0] (ip, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results(ip, scan_results)
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
@ -161,17 +165,17 @@ def test_scan_multiple_ips(callback, scan_config, stop):
assert callback.call_count == 4 assert callback.call_count == 4
(ip, ping_scan_data, port_scan_data, fingerprint_data) = callback.call_args_list[0][0] (ip, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results(ip, scan_results)
(ip, ping_scan_data, port_scan_data, fingerprint_data) = callback.call_args_list[1][0] (ip, scan_results) = callback.call_args_list[1][0]
assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results(ip, scan_results)
(ip, ping_scan_data, port_scan_data, fingerprint_data) = callback.call_args_list[2][0] (ip, scan_results) = callback.call_args_list[2][0]
assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results(ip, scan_results)
(ip, ping_scan_data, port_scan_data, fingerprint_data) = callback.call_args_list[3][0] (ip, scan_results) = callback.call_args_list[3][0]
assert_scan_results(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results(ip, scan_results)
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):

View File

@ -1,11 +1,11 @@
from threading import Event from threading import Event
from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus
from infection_monkey.master import Propagator from infection_monkey.master import IPScanResults, Propagator
empty_fingerprint_data = FingerprintData(None, None, {}) empty_fingerprint_data = FingerprintData(None, None, {})
dot_1_results = ( dot_1_results = IPScanResults(
PingScanData(True, "windows"), PingScanData(True, "windows"),
{ {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, None, None),
@ -19,7 +19,7 @@ dot_1_results = (
}, },
) )
dot_3_results = ( dot_3_results = IPScanResults(
PingScanData(True, "linux"), PingScanData(True, "linux"),
{ {
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"), 22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
@ -42,7 +42,7 @@ dot_3_results = (
}, },
) )
dead_host_results = ( dead_host_results = IPScanResults(
PingScanData(False, None), PingScanData(False, None),
{ {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, None, None),
@ -79,11 +79,11 @@ class MockIPScanner:
def scan(self, ips_to_scan, _, results_callback, stop): def scan(self, ips_to_scan, _, results_callback, stop):
for ip in ips_to_scan: for ip in ips_to_scan:
if ip.endswith(".1"): if ip.endswith(".1"):
results_callback(ip, *dot_1_results) results_callback(ip, dot_1_results)
elif ip.endswith(".3"): elif ip.endswith(".3"):
results_callback(ip, *dot_3_results) results_callback(ip, dot_3_results)
else: else:
results_callback(ip, *dead_host_results) results_callback(ip, dead_host_results)
def test_scan_result_processing(telemetry_messenger_spy): def test_scan_result_processing(telemetry_messenger_spy):