Agent: Move VictimHost construction to AutomatedMaster

This commit is contained in:
Mike Salvatore 2021-12-10 10:34:19 -05:00
parent b3c520f272
commit 037d63c9f3
3 changed files with 100 additions and 60 deletions

View File

@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple
from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet
from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData, PortStatus
from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
@ -185,13 +185,34 @@ class AutomatedMaster(IMaster):
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._handle_scanned_host, self._stop)
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, self._stop)
logger.info("Finished network scan")
def _handle_scanned_host(self, host: VictimHost):
self._hosts_to_exploit.put(host)
self._telemetry_messenger.send_telemetry(ScanTelem(host))
def _process_scan_results(
self, ip: str, ping_scan_data: PingScanData, port_scan_data: PortScanData
):
victim_host = VictimHost(ip)
has_open_port = False
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
has_open_port = True
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
if has_open_port:
self._hosts_to_exploit.put(victim_host)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
def _exploit_targets(self, scan_thread: Thread):
pass

View File

@ -5,14 +5,13 @@ from queue import Queue
from threading import Event
from typing import Callable, Dict, List
from infection_monkey.i_puppet import IPuppet, PortStatus
from infection_monkey.model.host import VictimHost
from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
Callback = Callable[[VictimHost], None]
Callback = Callable[[str, PingScanData, Dict[int, PortScanData]], None]
class IPScanner:
@ -45,12 +44,10 @@ class IPScanner:
ip = ips.get_nowait()
logger.info(f"Scanning {ip}")
victim_host = VictimHost(ip)
ping_scan_data = self._puppet.ping(ip, options["icmp"])
port_scan_data = self._scan_tcp_ports(ip, options["tcp"], stop)
self._ping_ip(ip, victim_host, options["icmp"])
self._scan_tcp_ports(ip, victim_host, options["tcp"], stop)
results_callback(victim_host)
results_callback(ip, ping_scan_data, port_scan_data)
except queue.Empty:
logger.debug(
@ -60,22 +57,12 @@ class IPScanner:
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
ping_scan_data = self._puppet.ping(ip, options)
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event):
def _scan_tcp_ports(self, ip: str, options: Dict, stop: Event):
port_scan_data = {}
for p in options["ports"]:
if stop.is_set():
break
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
if port_scan_data.status == PortStatus.OPEN:
victim_host.services[port_scan_data.service] = {}
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
if port_scan_data.banner is not None:
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
return port_scan_data

View File

@ -1,14 +1,15 @@
from threading import Barrier, Event
from typing import Set
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import PortScanData
from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.master import IPScanner
from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = {"type": "windows"}
LINUX_OS = {"type": "linux"}
WINDOWS_OS = "windows"
LINUX_OS = "linux"
class MockPuppet(MockPuppet):
@ -47,35 +48,65 @@ def callback():
return MagicMock()
def assert_dot_1(victim_host):
assert victim_host.icmp is True
assert victim_host.os == WINDOWS_OS
assert len(victim_host.services.keys()) == 2
assert "tcp-445" in victim_host.services
assert victim_host.services["tcp-445"]["port"] == 445
assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER"
assert "tcp-3389" in victim_host.services
assert victim_host.services["tcp-3389"]["port"] == 3389
def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
for psd in port_scan_data.values():
if psd.port in expected_open_ports:
assert psd.status == PortStatus.OPEN
else:
assert psd.status == PortStatus.CLOSED
def assert_dot_3(victim_host):
assert victim_host.icmp is True
assert victim_host.os == LINUX_OS
def assert_dot_1(ip, ping_scan_data, port_scan_data):
assert ip == "10.0.0.1"
assert len(victim_host.services.keys()) == 2
assert "tcp-22" in victim_host.services
assert victim_host.services["tcp-22"]["port"] == 22
assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER"
assert ping_scan_data.response_received is True
assert ping_scan_data.os == WINDOWS_OS
assert "tcp-443" in victim_host.services
assert victim_host.services["tcp-443"]["port"] == 443
assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER"
assert len(port_scan_data.keys()) == 6
psd_445 = port_scan_data[445]
psd_3389 = port_scan_data[3389]
assert psd_445.status == PortStatus.OPEN
assert psd_445.port == 445
assert psd_445.banner == "SMB BANNER"
assert psd_445.service == "tcp-445"
assert psd_3389.status == PortStatus.OPEN
assert psd_3389.port == 3389
assert psd_3389.banner == ""
assert psd_3389.service == "tcp-3389"
assert_port_status(port_scan_data, {445, 3389})
def assert_host_down(victim_host):
assert victim_host.icmp is False
assert len(victim_host.services.keys()) == 0
def assert_dot_3(ip, ping_scan_data, port_scan_data):
assert ip == "10.0.0.3"
assert ping_scan_data.response_received is True
assert ping_scan_data.os == LINUX_OS
assert len(port_scan_data.keys()) == 6
psd_443 = port_scan_data[443]
psd_22 = port_scan_data[22]
assert psd_443.port == 443
assert psd_443.banner == "HTTPS BANNER"
assert psd_443.service == "tcp-443"
assert psd_22.port == 22
assert psd_22.banner == "SSH BANNER"
assert psd_22.service == "tcp-22"
assert_port_status(port_scan_data, {22, 443})
def assert_host_down(ip, ping_scan_data, port_scan_data):
assert ip not in {"10.0.0.1", "10.0.0.3"}
assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6
assert_port_status(port_scan_data, {})
def test_scan_single_ip(callback, scan_config, stop):
@ -86,7 +117,8 @@ def test_scan_single_ip(callback, scan_config, stop):
callback.assert_called_once()
assert_dot_1(callback.call_args_list[0][0][0])
print(type(callback.call_args_list[0][0]))
assert_dot_1(*(callback.call_args_list[0][0]))
def test_scan_multiple_ips(callback, scan_config, stop):
@ -97,10 +129,10 @@ def test_scan_multiple_ips(callback, scan_config, stop):
assert callback.call_count == 4
assert_dot_1(callback.call_args_list[0][0][0])
assert_host_down(callback.call_args_list[1][0][0])
assert_dot_3(callback.call_args_list[2][0][0])
assert_host_down(callback.call_args_list[3][0][0])
assert_dot_1(*(callback.call_args_list[0][0]))
assert_host_down(*(callback.call_args_list[1][0]))
assert_dot_3(*(callback.call_args_list[2][0]))
assert_host_down(*(callback.call_args_list[3][0]))
def test_scan_lots_of_ips(callback, scan_config, stop):
@ -113,7 +145,7 @@ def test_scan_lots_of_ips(callback, scan_config, stop):
def test_stop_after_callback(scan_config, stop):
def _callback(_):
def _callback(_, __, ___):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread continues to scan.
_callback.barrier.wait()