Agent: Move VictimHost construction to AutomatedMaster

This commit is contained in:
Mike Salvatore 2021-12-10 10:34:19 -05:00
parent b3c520f272
commit 037d63c9f3
3 changed files with 100 additions and 60 deletions

View File

@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData, PortStatus
from infection_monkey.model.host import VictimHost from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
@ -185,13 +185,34 @@ class AutomatedMaster(IMaster):
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)] ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
scan_config = propagation_config["network_scan"] scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._handle_scanned_host, self._stop) self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, self._stop)
logger.info("Finished network scan") logger.info("Finished network scan")
def _handle_scanned_host(self, host: VictimHost): def _process_scan_results(
self._hosts_to_exploit.put(host) self, ip: str, ping_scan_data: PingScanData, port_scan_data: PortScanData
self._telemetry_messenger.send_telemetry(ScanTelem(host)) ):
victim_host = VictimHost(ip)
has_open_port = False
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
has_open_port = True
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
if has_open_port:
self._hosts_to_exploit.put(victim_host)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
def _exploit_targets(self, scan_thread: Thread): def _exploit_targets(self, scan_thread: Thread):
pass pass

View File

@ -5,14 +5,13 @@ from queue import Queue
from threading import Event from threading import Event
from typing import Callable, Dict, List from typing import Callable, Dict, List
from infection_monkey.i_puppet import IPuppet, PortStatus from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData
from infection_monkey.model.host import VictimHost
from .threading_utils import create_daemon_thread from .threading_utils import create_daemon_thread
logger = logging.getLogger() logger = logging.getLogger()
Callback = Callable[[VictimHost], None] Callback = Callable[[str, PingScanData, Dict[int, PortScanData]], None]
class IPScanner: class IPScanner:
@ -45,12 +44,10 @@ class IPScanner:
ip = ips.get_nowait() ip = ips.get_nowait()
logger.info(f"Scanning {ip}") logger.info(f"Scanning {ip}")
victim_host = VictimHost(ip) ping_scan_data = self._puppet.ping(ip, options["icmp"])
port_scan_data = self._scan_tcp_ports(ip, options["tcp"], stop)
self._ping_ip(ip, victim_host, options["icmp"]) results_callback(ip, ping_scan_data, port_scan_data)
self._scan_tcp_ports(ip, victim_host, options["tcp"], stop)
results_callback(victim_host)
except queue.Empty: except queue.Empty:
logger.debug( logger.debug(
@ -60,22 +57,12 @@ class IPScanner:
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting") logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict): def _scan_tcp_ports(self, ip: str, options: Dict, stop: Event):
ping_scan_data = self._puppet.ping(ip, options) port_scan_data = {}
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event):
for p in options["ports"]: for p in options["ports"]:
if stop.is_set(): if stop.is_set():
break break
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
if port_scan_data.status == PortStatus.OPEN:
victim_host.services[port_scan_data.service] = {} return port_scan_data
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
if port_scan_data.banner is not None:
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner

View File

@ -1,14 +1,15 @@
from threading import Barrier, Event from threading import Barrier, Event
from typing import Set
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.i_puppet import PortScanData from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.master import IPScanner from infection_monkey.master import IPScanner
from infection_monkey.puppet.mock_puppet import MockPuppet from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = {"type": "windows"} WINDOWS_OS = "windows"
LINUX_OS = {"type": "linux"} LINUX_OS = "linux"
class MockPuppet(MockPuppet): class MockPuppet(MockPuppet):
@ -47,35 +48,65 @@ def callback():
return MagicMock() return MagicMock()
def assert_dot_1(victim_host): def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
assert victim_host.icmp is True for psd in port_scan_data.values():
assert victim_host.os == WINDOWS_OS if psd.port in expected_open_ports:
assert psd.status == PortStatus.OPEN
assert len(victim_host.services.keys()) == 2 else:
assert "tcp-445" in victim_host.services assert psd.status == PortStatus.CLOSED
assert victim_host.services["tcp-445"]["port"] == 445
assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER"
assert "tcp-3389" in victim_host.services
assert victim_host.services["tcp-3389"]["port"] == 3389
def assert_dot_3(victim_host): def assert_dot_1(ip, ping_scan_data, port_scan_data):
assert victim_host.icmp is True assert ip == "10.0.0.1"
assert victim_host.os == LINUX_OS
assert len(victim_host.services.keys()) == 2 assert ping_scan_data.response_received is True
assert "tcp-22" in victim_host.services assert ping_scan_data.os == WINDOWS_OS
assert victim_host.services["tcp-22"]["port"] == 22
assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER"
assert "tcp-443" in victim_host.services assert len(port_scan_data.keys()) == 6
assert victim_host.services["tcp-443"]["port"] == 443
assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER" psd_445 = port_scan_data[445]
psd_3389 = port_scan_data[3389]
assert psd_445.status == PortStatus.OPEN
assert psd_445.port == 445
assert psd_445.banner == "SMB BANNER"
assert psd_445.service == "tcp-445"
assert psd_3389.status == PortStatus.OPEN
assert psd_3389.port == 3389
assert psd_3389.banner == ""
assert psd_3389.service == "tcp-3389"
assert_port_status(port_scan_data, {445, 3389})
def assert_host_down(victim_host): def assert_dot_3(ip, ping_scan_data, port_scan_data):
assert victim_host.icmp is False assert ip == "10.0.0.3"
assert len(victim_host.services.keys()) == 0
assert ping_scan_data.response_received is True
assert ping_scan_data.os == LINUX_OS
assert len(port_scan_data.keys()) == 6
psd_443 = port_scan_data[443]
psd_22 = port_scan_data[22]
assert psd_443.port == 443
assert psd_443.banner == "HTTPS BANNER"
assert psd_443.service == "tcp-443"
assert psd_22.port == 22
assert psd_22.banner == "SSH BANNER"
assert psd_22.service == "tcp-22"
assert_port_status(port_scan_data, {22, 443})
def assert_host_down(ip, ping_scan_data, port_scan_data):
assert ip not in {"10.0.0.1", "10.0.0.3"}
assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6
assert_port_status(port_scan_data, {})
def test_scan_single_ip(callback, scan_config, stop): def test_scan_single_ip(callback, scan_config, stop):
@ -86,7 +117,8 @@ def test_scan_single_ip(callback, scan_config, stop):
callback.assert_called_once() callback.assert_called_once()
assert_dot_1(callback.call_args_list[0][0][0]) print(type(callback.call_args_list[0][0]))
assert_dot_1(*(callback.call_args_list[0][0]))
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
@ -97,10 +129,10 @@ def test_scan_multiple_ips(callback, scan_config, stop):
assert callback.call_count == 4 assert callback.call_count == 4
assert_dot_1(callback.call_args_list[0][0][0]) assert_dot_1(*(callback.call_args_list[0][0]))
assert_host_down(callback.call_args_list[1][0][0]) assert_host_down(*(callback.call_args_list[1][0]))
assert_dot_3(callback.call_args_list[2][0][0]) assert_dot_3(*(callback.call_args_list[2][0]))
assert_host_down(callback.call_args_list[3][0][0]) assert_host_down(*(callback.call_args_list[3][0]))
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
@ -113,7 +145,7 @@ def test_scan_lots_of_ips(callback, scan_config, stop):
def test_stop_after_callback(scan_config, stop): def test_stop_after_callback(scan_config, stop):
def _callback(_): def _callback(_, __, ___):
# Block all threads here until 2 threads reach this barrier, then set stop # Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread continues to scan. # and test that niether thread continues to scan.
_callback.barrier.wait() _callback.barrier.wait()