Agent: Extract network scanner into its own class
This commit is contained in:
parent
3f7dbbccc2
commit
81d4afab52
|
@ -1 +1,2 @@
|
|||
from .ip_scanner import IPScanner
|
||||
from .automated_master import AutomatedMaster
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
|
@ -8,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple
|
|||
|
||||
from infection_monkey.i_control_channel import IControlChannel
|
||||
from infection_monkey.i_master import IMaster
|
||||
from infection_monkey.i_puppet import IPuppet, PortStatus
|
||||
from infection_monkey.i_puppet import IPuppet
|
||||
from infection_monkey.model.host import VictimHost
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
||||
|
@ -16,6 +15,7 @@ from infection_monkey.telemetry.scan_telem import ScanTelem
|
|||
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
from . import IPScanner
|
||||
from .threading_utils import create_daemon_thread
|
||||
|
||||
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
||||
|
@ -37,6 +37,9 @@ class AutomatedMaster(IMaster):
|
|||
self._telemetry_messenger = telemetry_messenger
|
||||
self._control_channel = control_channel
|
||||
|
||||
self._ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
|
||||
self._hosts_to_exploit = None
|
||||
|
||||
self._stop = threading.Event()
|
||||
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
||||
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
||||
|
@ -156,17 +159,16 @@ class AutomatedMaster(IMaster):
|
|||
def _can_propagate(self):
|
||||
return True
|
||||
|
||||
# TODO: Refactor propagation into its own class
|
||||
def _propagate(self, config: Dict):
|
||||
logger.info("Attempting to propagate")
|
||||
|
||||
hosts_to_exploit = Queue()
|
||||
self._hosts_to_exploit = Queue()
|
||||
|
||||
scan_thread = create_daemon_thread(
|
||||
target=self._scan_network, args=(config["network_scan"], hosts_to_exploit)
|
||||
)
|
||||
exploit_thread = create_daemon_thread(
|
||||
target=self._exploit_targets, args=(hosts_to_exploit, scan_thread)
|
||||
target=self._scan_network, args=(config["network_scan"],)
|
||||
)
|
||||
exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,))
|
||||
|
||||
scan_thread.start()
|
||||
exploit_thread.start()
|
||||
|
@ -176,73 +178,28 @@ class AutomatedMaster(IMaster):
|
|||
|
||||
logger.info("Finished attempting to propagate")
|
||||
|
||||
def _exploit_targets(self, hosts_to_exploit: Queue, scan_thread: Thread):
|
||||
pass
|
||||
|
||||
# TODO: Refactor this into its own class
|
||||
def _scan_network(self, scan_config: Dict, hosts_to_exploit: Queue):
|
||||
def _scan_network(self, scan_config: Dict):
|
||||
logger.info("Starting network scan")
|
||||
|
||||
# TODO: Generate list of IPs to scan
|
||||
ips_to_scan = Queue()
|
||||
for i in range(1, 255):
|
||||
ips_to_scan.put(f"10.0.0.{i}")
|
||||
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
|
||||
|
||||
scan_threads = []
|
||||
for i in range(0, NUM_SCAN_THREADS):
|
||||
t = create_daemon_thread(
|
||||
target=self._scan_ips, args=(ips_to_scan, scan_config, hosts_to_exploit)
|
||||
)
|
||||
t.start()
|
||||
scan_threads.append(t)
|
||||
|
||||
for t in scan_threads:
|
||||
t.join()
|
||||
self._ip_scanner.scan(
|
||||
ips_to_scan,
|
||||
scan_config["icmp"],
|
||||
scan_config["tcp"],
|
||||
self._handle_scanned_host,
|
||||
self._stop,
|
||||
)
|
||||
|
||||
logger.info("Finished network scan")
|
||||
|
||||
def _scan_ips(self, ips_to_scan: Queue, scan_config: Dict, hosts_to_exploit: Queue):
|
||||
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
ip = ips_to_scan.get_nowait()
|
||||
logger.info(f"Scanning {ip}")
|
||||
def _handle_scanned_host(self, host: VictimHost):
|
||||
self._hosts_to_exploit.put(host)
|
||||
self._telemetry_messenger.send_telemetry(ScanTelem(host))
|
||||
|
||||
victim_host = VictimHost(ip)
|
||||
|
||||
self._ping_ip(ip, victim_host, scan_config["icmp"])
|
||||
self._scan_tcp_ports(ip, victim_host, scan_config["tcp"])
|
||||
|
||||
hosts_to_exploit.put(hosts_to_exploit)
|
||||
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
|
||||
|
||||
except queue.Empty:
|
||||
logger.debug(
|
||||
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
||||
)
|
||||
|
||||
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
|
||||
|
||||
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
|
||||
(response_received, os) = self._puppet.ping(ip, options)
|
||||
|
||||
victim_host.icmp = response_received
|
||||
if os is not None:
|
||||
victim_host.os["type"] = os
|
||||
|
||||
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict):
|
||||
for p in options["ports"]:
|
||||
if self._stop.is_set():
|
||||
break
|
||||
|
||||
# TODO: check units of timeout
|
||||
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
|
||||
if port_scan_data.status == PortStatus.OPEN:
|
||||
victim_host.services[port_scan_data.service] = {}
|
||||
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
|
||||
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
|
||||
if port_scan_data.banner is not None:
|
||||
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
|
||||
def _exploit_targets(self, scan_thread: Thread):
|
||||
pass
|
||||
|
||||
def _run_payload(self, payload: Tuple[str, Dict]):
|
||||
name = payload[0]
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from queue import Queue
|
||||
from threading import Event
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from infection_monkey.i_puppet import IPuppet, PortStatus
|
||||
from infection_monkey.model.host import VictimHost
|
||||
|
||||
from .threading_utils import create_daemon_thread
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
Callback = Callable[[VictimHost], None]
|
||||
|
||||
|
||||
class IPScanner:
|
||||
def __init__(self, puppet: IPuppet, num_workers: int):
|
||||
self._puppet = puppet
|
||||
self._num_workers = num_workers
|
||||
|
||||
def scan(
|
||||
self,
|
||||
ips: List[str],
|
||||
icmp_config: Dict,
|
||||
tcp_config: Dict,
|
||||
report_results_callback: Callback,
|
||||
stop: Event,
|
||||
):
|
||||
# Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty.
|
||||
ips_to_scan = Queue()
|
||||
for ip in ips:
|
||||
ips_to_scan.put(ip)
|
||||
|
||||
scan_ips_args = (
|
||||
ips_to_scan,
|
||||
icmp_config,
|
||||
tcp_config,
|
||||
report_results_callback,
|
||||
stop,
|
||||
)
|
||||
scan_threads = []
|
||||
for i in range(0, self._num_workers):
|
||||
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
|
||||
t.start()
|
||||
scan_threads.append(t)
|
||||
|
||||
for t in scan_threads:
|
||||
t.join()
|
||||
|
||||
def _scan_ips(
|
||||
self,
|
||||
ips_to_scan: Queue,
|
||||
icmp_config: Dict,
|
||||
tcp_config: Dict,
|
||||
report_results_callback: Callback,
|
||||
stop: Event,
|
||||
):
|
||||
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
||||
|
||||
try:
|
||||
while not stop.is_set():
|
||||
ip = ips_to_scan.get_nowait()
|
||||
logger.info(f"Scanning {ip}")
|
||||
|
||||
victim_host = VictimHost(ip)
|
||||
|
||||
self._ping_ip(ip, victim_host, icmp_config)
|
||||
self._scan_tcp_ports(ip, victim_host, tcp_config, stop)
|
||||
|
||||
report_results_callback(victim_host)
|
||||
|
||||
except queue.Empty:
|
||||
logger.debug(
|
||||
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
|
||||
|
||||
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
|
||||
(response_received, os) = self._puppet.ping(ip, options)
|
||||
|
||||
victim_host.icmp = response_received
|
||||
if os is not None:
|
||||
victim_host.os["type"] = os
|
||||
|
||||
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event):
|
||||
for p in options["ports"]:
|
||||
if stop.is_set():
|
||||
break
|
||||
|
||||
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
|
||||
if port_scan_data.status == PortStatus.OPEN:
|
||||
victim_host.services[port_scan_data.service] = {}
|
||||
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
|
||||
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
|
||||
if port_scan_data.banner is not None:
|
||||
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
|
|
@ -0,0 +1,146 @@
|
|||
from threading import Barrier, Event
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.i_puppet import PortScanData
|
||||
from infection_monkey.master import IPScanner
|
||||
from infection_monkey.puppet.mock_puppet import MockPuppet
|
||||
|
||||
WINDOWS_OS = {"type": "windows"}
|
||||
LINUX_OS = {"type": "linux"}
|
||||
|
||||
|
||||
class MockPuppet(MockPuppet):
|
||||
def __init__(self):
|
||||
self.ping = MagicMock(side_effect=super().ping)
|
||||
self.scan_tcp_port = MagicMock(side_effect=super().scan_tcp_port)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tcp_scan_config():
|
||||
return {
|
||||
"timeout_ms": 3000,
|
||||
"ports": [
|
||||
22,
|
||||
445,
|
||||
3389,
|
||||
443,
|
||||
8008,
|
||||
3306,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def icmp_scan_config():
|
||||
return {
|
||||
"timeout_ms": 1000,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stop():
|
||||
return Event()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def callback():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def assert_dot_1(victim_host):
|
||||
assert victim_host.icmp is True
|
||||
assert victim_host.os == WINDOWS_OS
|
||||
|
||||
assert len(victim_host.services.keys()) == 2
|
||||
assert "tcp-445" in victim_host.services
|
||||
assert victim_host.services["tcp-445"]["port"] == 445
|
||||
assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER"
|
||||
assert "tcp-3389" in victim_host.services
|
||||
assert victim_host.services["tcp-3389"]["port"] == 3389
|
||||
|
||||
|
||||
def assert_dot_3(victim_host):
|
||||
assert victim_host.icmp is True
|
||||
assert victim_host.os == LINUX_OS
|
||||
|
||||
assert len(victim_host.services.keys()) == 2
|
||||
assert "tcp-22" in victim_host.services
|
||||
assert victim_host.services["tcp-22"]["port"] == 22
|
||||
assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER"
|
||||
|
||||
assert "tcp-443" in victim_host.services
|
||||
assert victim_host.services["tcp-443"]["port"] == 443
|
||||
assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER"
|
||||
|
||||
|
||||
def assert_host_down(victim_host):
|
||||
assert victim_host.icmp is False
|
||||
assert len(victim_host.services.keys()) == 0
|
||||
|
||||
|
||||
def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||
ips = ["10.0.0.1"]
|
||||
|
||||
ns = IPScanner(MockPuppet(), num_workers=1)
|
||||
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||
|
||||
callback.assert_called_once()
|
||||
|
||||
assert_dot_1(callback.call_args_list[0][0][0])
|
||||
|
||||
|
||||
def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||
|
||||
ns = IPScanner(MockPuppet(), num_workers=4)
|
||||
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||
|
||||
assert callback.call_count == 4
|
||||
|
||||
assert_dot_1(callback.call_args_list[0][0][0])
|
||||
assert_host_down(callback.call_args_list[1][0][0])
|
||||
assert_dot_3(callback.call_args_list[2][0][0])
|
||||
assert_host_down(callback.call_args_list[3][0][0])
|
||||
|
||||
|
||||
def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop):
|
||||
def _callback(_):
|
||||
# Block all threads here until 2 threads reach this barrier, then set stop
|
||||
# and test that niether thread continues to scan.
|
||||
_callback.barrier.wait()
|
||||
stop.set()
|
||||
|
||||
_callback.barrier = Barrier(2)
|
||||
|
||||
stopable_callback = MagicMock(side_effect=_callback)
|
||||
|
||||
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||
|
||||
ns = IPScanner(MockPuppet(), num_workers=2)
|
||||
ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop)
|
||||
|
||||
assert stopable_callback.call_count == 2
|
||||
|
||||
|
||||
def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||
def stopable_scan_tcp_port(port, _, __):
|
||||
# Block all threads here until 2 threads reach this barrier, then set stop
|
||||
# and test that niether thread scans any more ports
|
||||
stopable_scan_tcp_port.barrier.wait()
|
||||
stop.set()
|
||||
|
||||
return PortScanData(port, False, None, None)
|
||||
|
||||
stopable_scan_tcp_port.barrier = Barrier(2)
|
||||
|
||||
puppet = MockPuppet()
|
||||
puppet.scan_tcp_port = MagicMock(side_effect=stopable_scan_tcp_port)
|
||||
|
||||
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||
|
||||
ns = IPScanner(puppet, num_workers=2)
|
||||
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||
|
||||
assert puppet.scan_tcp_port.call_count == 2
|
Loading…
Reference in New Issue