Agent: Integrate scan_target_generator with AutomatedMaster

This commit is contained in:
Mike Salvatore 2021-12-16 11:07:35 -05:00
parent 8e0efb1993
commit 332649d5d1
8 changed files with 205 additions and 72 deletions

View File

@ -7,6 +7,7 @@ from infection_monkey.i_control_channel import IControlChannel, IslandCommunicat
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet from infection_monkey.i_puppet import IPuppet
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
@ -33,6 +34,7 @@ class AutomatedMaster(IMaster):
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
victim_host_factory: VictimHostFactory, victim_host_factory: VictimHostFactory,
control_channel: IControlChannel, control_channel: IControlChannel,
local_network_interfaces: List[NetworkInterface],
): ):
self._puppet = puppet self._puppet = puppet
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
@ -41,7 +43,11 @@ class AutomatedMaster(IMaster):
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS) ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS) exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
self._propagator = Propagator( self._propagator = Propagator(
self._telemetry_messenger, ip_scanner, exploiter, victim_host_factory self._telemetry_messenger,
ip_scanner,
exploiter,
victim_host_factory,
local_network_interfaces,
) )
self._stop = threading.Event() self._stop = threading.Event()

View File

@ -12,14 +12,14 @@ from infection_monkey.i_puppet import (
PortScanData, PortScanData,
PortStatus, PortStatus,
) )
from infection_monkey.network import NetworkAddress
from . import IPScanResults from . import IPScanResults
from .threading_utils import run_worker_threads from .threading_utils import run_worker_threads
logger = logging.getLogger() logger = logging.getLogger()
IP = str Callback = Callable[[NetworkAddress, IPScanResults], None]
Callback = Callable[[IP, IPScanResults], None]
class IPScanner: class IPScanner:
@ -27,22 +27,33 @@ class IPScanner:
self._puppet = puppet self._puppet = puppet
self._num_workers = num_workers self._num_workers = num_workers
def scan(self, ips_to_scan: List[str], options: Dict, results_callback: Callback, stop: Event): def scan(
self,
addresses_to_scan: List[NetworkAddress],
options: Dict,
results_callback: Callback,
stop: Event,
):
# Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the # Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the
# queue is empty. # queue is empty.
ips = Queue() addresses = Queue()
for ip in ips_to_scan: for address in addresses_to_scan:
ips.put(ip) addresses.put(address)
scan_ips_args = (ips, options, results_callback, stop) scan_ips_args = (addresses, options, results_callback, stop)
run_worker_threads(target=self._scan_ips, args=scan_ips_args, num_workers=self._num_workers) run_worker_threads(
target=self._scan_addresses, args=scan_ips_args, num_workers=self._num_workers
)
def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event): def _scan_addresses(
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
try: try:
while not stop.is_set(): while not stop.is_set():
ip = ips.get_nowait() address = addresses.get_nowait()
ip = address.ip
logger.info(f"Scanning {ip}") logger.info(f"Scanning {ip}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000 icmp_timeout = options["icmp"]["timeout_ms"] / 1000
@ -60,7 +71,7 @@ class IPScanner:
) )
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data) scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
results_callback(ip, scan_results) results_callback(address, scan_results)
logger.debug( logger.debug(
f"Detected the stop signal, scanning thread {threading.get_ident()} exiting" f"Detected the stop signal, scanning thread {threading.get_ident()} exiting"

View File

@ -1,7 +1,7 @@
import logging import logging
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import Dict from typing import Dict, List
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
@ -11,6 +11,8 @@ from infection_monkey.i_puppet import (
PortStatus, PortStatus,
) )
from infection_monkey.model import VictimHost, VictimHostFactory from infection_monkey.model import VictimHost, VictimHostFactory
from infection_monkey.network import NetworkAddress, NetworkInterface
from infection_monkey.network.scan_target_generator import compile_scan_target_list
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem from infection_monkey.telemetry.scan_telem import ScanTelem
@ -28,11 +30,13 @@ class Propagator:
ip_scanner: IPScanner, ip_scanner: IPScanner,
exploiter: Exploiter, exploiter: Exploiter,
victim_host_factory: VictimHostFactory, victim_host_factory: VictimHostFactory,
local_network_interfaces: List[NetworkInterface],
): ):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._ip_scanner = ip_scanner self._ip_scanner = ip_scanner
self._exploiter = exploiter self._exploiter = exploiter
self._victim_host_factory = victim_host_factory self._victim_host_factory = victim_host_factory
self._local_network_interfaces = local_network_interfaces
self._hosts_to_exploit = None self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, stop: Event): def propagate(self, propagation_config: Dict, stop: Event):
@ -62,16 +66,30 @@ class Propagator:
def _scan_network(self, propagation_config: Dict, stop: Event): def _scan_network(self, propagation_config: Dict, stop: Event):
logger.info("Starting network scan") logger.info("Starting network scan")
# TODO: Generate list of IPs to scan from propagation targets config target_config = propagation_config["targets"]
ips_to_scan = propagation_config["targets"]["subnet_scan_list"]
scan_config = propagation_config["network_scan"] scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, stop)
addresses_to_scan = self._compile_scan_target_list(target_config)
self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan") logger.info("Finished network scan")
def _process_scan_results(self, ip: str, scan_results: IPScanResults): def _compile_scan_target_list(self, target_config: Dict) -> List[NetworkAddress]:
victim_host = self._victim_host_factory.build_victim_host(ip) ranges_to_scan = target_config["subnet_scan_list"]
inaccessible_subnets = target_config["inaccessible_subnets"]
blocklisted_ips = target_config["blocked_ips"]
enable_local_network_scan = target_config["local_network_scan"]
return compile_scan_target_list(
self._local_network_interfaces,
ranges_to_scan,
inaccessible_subnets,
blocklisted_ips,
enable_local_network_scan,
)
def _process_scan_results(self, address: NetworkAddress, scan_results: IPScanResults):
victim_host = self._victim_host_factory.build_victim_host(address.ip, address.domain)
Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data) Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data)
Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data) Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data)

View File

@ -5,8 +5,8 @@ class VictimHostFactory:
def __init__(self): def __init__(self):
pass pass
def build_victim_host(self, ip: str): def build_victim_host(self, ip: str, domain: str):
victim_host = VictimHost(ip) victim_host = VictimHost(ip, domain)
# TODO: Reimplement the below logic from the old monkey.py # TODO: Reimplement the below logic from the old monkey.py
""" """

View File

@ -0,0 +1 @@
from .scan_target_generator import NetworkAddress, NetworkInterface

View File

@ -14,7 +14,7 @@ INTERVAL = 0.001
def test_terminate_without_start(): def test_terminate_without_start():
m = AutomatedMaster(None, None, None, None) m = AutomatedMaster(None, None, None, None, [])
# Test that call to terminate does not raise exception # Test that call to terminate does not raise exception
m.terminate() m.terminate()
@ -34,7 +34,7 @@ def test_stop_if_cant_get_config_from_island(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc) m = AutomatedMaster(None, None, None, cc, [])
m.start() m.start()
assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT
@ -73,7 +73,7 @@ def test_stop_if_cant_get_stop_signal_from_island(monkeypatch, sleep_and_return_
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc) m = AutomatedMaster(None, None, None, cc, [])
m.start() m.start()
assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT

View File

@ -6,6 +6,7 @@ import pytest
from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus
from infection_monkey.master import IPScanner from infection_monkey.master import IPScanner
from infection_monkey.network import NetworkAddress
from infection_monkey.puppet.mock_puppet import MockPuppet from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = "windows" WINDOWS_OS = "windows"
@ -51,20 +52,21 @@ def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
assert psd.status == PortStatus.CLOSED assert psd.status == PortStatus.CLOSED
def assert_scan_results(ip, scan_results): def assert_scan_results(address, scan_results):
ping_scan_data = scan_results.ping_scan_data ping_scan_data = scan_results.ping_scan_data
port_scan_data = scan_results.port_scan_data port_scan_data = scan_results.port_scan_data
fingerprint_data = scan_results.fingerprint_data fingerprint_data = scan_results.fingerprint_data
if ip == "10.0.0.1": if address.ip == "10.0.0.1":
assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_no_1(address.domain, ping_scan_data, port_scan_data, fingerprint_data)
elif ip == "10.0.0.3": elif address.ip == "10.0.0.3":
assert_scan_results_no_3(ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_no_3(address.domain, ping_scan_data, port_scan_data, fingerprint_data)
else: else:
assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data)
def assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_no_1(domain, ping_scan_data, port_scan_data, fingerprint_data):
assert domain == "d1"
assert ping_scan_data.response_received is True assert ping_scan_data.response_received is True
assert ping_scan_data.os == WINDOWS_OS assert ping_scan_data.os == WINDOWS_OS
@ -97,7 +99,9 @@ def assert_fingerprint_results_no_1(fingerprint_data):
assert fingerprint_data["SMBFinger"].services["tcp-445"]["name"] == "smb_service_name" assert fingerprint_data["SMBFinger"].services["tcp-445"]["name"] == "smb_service_name"
def assert_scan_results_no_3(ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_no_3(domain, ping_scan_data, port_scan_data, fingerprint_data):
assert domain == "d3"
assert ping_scan_data.response_received is True assert ping_scan_data.response_received is True
assert ping_scan_data.os == LINUX_OS assert ping_scan_data.os == LINUX_OS
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -135,8 +139,9 @@ def assert_fingerprint_results_no_3(fingerprint_data):
assert fingerprint_data["HTTPFinger"].services["tcp-443"]["data"] == ("SERVER_HEADERS_2", True) assert fingerprint_data["HTTPFinger"].services["tcp-443"]["data"] == ("SERVER_HEADERS_2", True)
def assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert ip not in {"10.0.0.1", "10.0.0.3"} assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None
assert ping_scan_data.response_received is False assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -146,44 +151,49 @@ def assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprin
def test_scan_single_ip(callback, scan_config, stop): def test_scan_single_ip(callback, scan_config, stop):
ips = ["10.0.0.1"] addresses = [NetworkAddress("10.0.0.1", "d1")]
ns = IPScanner(MockPuppet(), num_workers=1) ns = IPScanner(MockPuppet(), num_workers=1)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
callback.assert_called_once() callback.assert_called_once()
(ip, scan_results) = callback.call_args_list[0][0] (address, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert callback.call_count == 4 assert callback.call_count == 4
(ip, scan_results) = callback.call_args_list[0][0] (address, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[1][0] (address, scan_results) = callback.call_args_list[1][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[2][0] (address, scan_results) = callback.call_args_list[2][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[3][0] (address, scan_results) = callback.call_args_list[3][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
@pytest.mark.slow @pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
ips = [f"10.0.0.{i}" for i in range(0, 255)] addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert callback.call_count == 255 assert callback.call_count == 255
@ -199,10 +209,15 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback) stoppable_callback = MagicMock(side_effect=_callback)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(MockPuppet(), num_workers=2) ns = IPScanner(MockPuppet(), num_workers=2)
ns.scan(ips, scan_config, stoppable_callback, stop) ns.scan(addresses, scan_config, stoppable_callback, stop)
assert stoppable_callback.call_count == 2 assert stoppable_callback.call_count == 2
@ -221,10 +236,15 @@ def test_interrupt_port_scanning(callback, scan_config, stop):
puppet = MockPuppet() puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port) puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2 assert puppet.scan_tcp_port.call_count == 2
@ -243,9 +263,14 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet = MockPuppet() puppet = MockPuppet()
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint) puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert puppet.fingerprint.call_count == 2 assert puppet.fingerprint.call_count == 2

View File

@ -1,4 +1,7 @@
from threading import Event from threading import Event
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
@ -9,6 +12,7 @@ from infection_monkey.i_puppet import (
) )
from infection_monkey.master import IPScanResults, Propagator from infection_monkey.master import IPScanResults, Propagator
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
empty_fingerprint_data = FingerprintData(None, None, {}) empty_fingerprint_data = FingerprintData(None, None, {})
@ -83,15 +87,21 @@ dot_3_services = {
} }
class MockIPScanner: @pytest.fixture
def scan(self, ips_to_scan, _, results_callback, stop): def mock_ip_scanner():
for ip in ips_to_scan: def scan(adresses_to_scan, _, results_callback, stop):
if ip.endswith(".1"): for address in adresses_to_scan:
results_callback(ip, dot_1_scan_results) if address.ip.endswith(".1"):
elif ip.endswith(".3"): results_callback(address, dot_1_scan_results)
results_callback(ip, dot_3_scan_results) elif address.ip.endswith(".3"):
results_callback(address, dot_3_scan_results)
else: else:
results_callback(ip, dead_host_scan_results) results_callback(address, dead_host_scan_results)
ip_scanner = MagicMock()
ip_scanner.scan = MagicMock(side_effect=scan)
return ip_scanner
class StubExploiter: class StubExploiter:
@ -101,11 +111,18 @@ class StubExploiter:
pass pass
def test_scan_result_processing(telemetry_messenger_spy): def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner):
p = Propagator(telemetry_messenger_spy, MockIPScanner(), StubExploiter(), VictimHostFactory()) p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), VictimHostFactory(), []
)
p.propagate( p.propagate(
{ {
"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it "network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since StubExploiter ignores it "exploiters": {}, # This is empty since StubExploiter ignores it
}, },
@ -141,10 +158,13 @@ class MockExploiter:
def exploit_hosts( def exploit_hosts(
self, exploiter_config, hosts_to_exploit, results_callback, scan_completed, stop self, exploiter_config, hosts_to_exploit, results_callback, scan_completed, stop
): ):
scan_completed.wait()
hte = [] hte = []
for _ in range(0, 2): for _ in range(0, 2):
hte.append(hosts_to_exploit.get()) hte.append(hosts_to_exploit.get())
assert hosts_to_exploit.empty()
for host in hte: for host in hte:
if host.ip_addr.endswith(".1"): if host.ip_addr.endswith(".1"):
results_callback( results_callback(
@ -157,7 +177,7 @@ class MockExploiter:
host, host,
ExploiterResultData(False, {}, {}, "SSH FAILED for .1"), ExploiterResultData(False, {}, {}, "SSH FAILED for .1"),
) )
if host.ip_addr.endswith(".2"): elif host.ip_addr.endswith(".2"):
results_callback( results_callback(
"PowerShellExploiter", "PowerShellExploiter",
host, host,
@ -168,7 +188,7 @@ class MockExploiter:
host, host,
ExploiterResultData(False, {}, {}, "SSH FAILED for .2"), ExploiterResultData(False, {}, {}, "SSH FAILED for .2"),
) )
if host.ip_addr.endswith(".3"): elif host.ip_addr.endswith(".3"):
results_callback( results_callback(
"PowerShellExploiter", "PowerShellExploiter",
host, host,
@ -181,11 +201,18 @@ class MockExploiter:
) )
def test_exploiter_result_processing(telemetry_messenger_spy): def test_exploiter_result_processing(telemetry_messenger_spy, mock_ip_scanner):
p = Propagator(telemetry_messenger_spy, MockIPScanner(), MockExploiter(), VictimHostFactory()) p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), VictimHostFactory(), []
)
p.propagate( p.propagate(
{ {
"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it "network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it "exploiters": {}, # This is empty since MockExploiter ignores it
}, },
@ -211,3 +238,48 @@ def test_exploiter_result_processing(telemetry_messenger_spy):
assert not data["result"] assert not data["result"]
else: else:
assert data["result"] assert data["result"]
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner):
local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")]
p = Propagator(
telemetry_messenger_spy,
mock_ip_scanner,
StubExploiter(),
VictimHostFactory(),
local_network_interfaces,
)
p.propagate(
{
"targets": {
"subnet_scan_list": ["10.0.0.0/29", "172.10.20.30"],
"local_network_scan": True,
"blocked_ips": ["10.0.0.3"],
"inaccessible_subnets": ["10.0.0.128/30", "10.0.0.8/29"],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
Event(),
)
expected_ip_scan_list = [
"10.0.0.0",
"10.0.0.1",
"10.0.0.2",
"10.0.0.4",
"10.0.0.5",
"10.0.0.6",
"10.0.0.8",
"10.0.0.10",
"10.0.0.11",
"10.0.0.12",
"10.0.0.13",
"10.0.0.14",
"10.0.0.128",
"10.0.0.129",
"10.0.0.130",
"172.10.20.30",
]
actual_ip_scan_list = [address.ip for address in mock_ip_scanner.scan.call_args_list[0][0][0]]
assert actual_ip_scan_list == expected_ip_scan_list