forked from p15670423/monkey
Agent: Extract network scanner into its own class
This commit is contained in:
parent
3f7dbbccc2
commit
81d4afab52
|
@ -1 +1,2 @@
|
||||||
|
from .ip_scanner import IPScanner
|
||||||
from .automated_master import AutomatedMaster
|
from .automated_master import AutomatedMaster
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import queue
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
@ -8,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
from infection_monkey.i_control_channel import IControlChannel
|
from infection_monkey.i_control_channel import IControlChannel
|
||||||
from infection_monkey.i_master import IMaster
|
from infection_monkey.i_master import IMaster
|
||||||
from infection_monkey.i_puppet import IPuppet, PortStatus
|
from infection_monkey.i_puppet import IPuppet
|
||||||
from infection_monkey.model.host import VictimHost
|
from infection_monkey.model.host import VictimHost
|
||||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
||||||
|
@ -16,6 +15,7 @@ from infection_monkey.telemetry.scan_telem import ScanTelem
|
||||||
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
||||||
from infection_monkey.utils.timer import Timer
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
|
from . import IPScanner
|
||||||
from .threading_utils import create_daemon_thread
|
from .threading_utils import create_daemon_thread
|
||||||
|
|
||||||
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
||||||
|
@ -37,6 +37,9 @@ class AutomatedMaster(IMaster):
|
||||||
self._telemetry_messenger = telemetry_messenger
|
self._telemetry_messenger = telemetry_messenger
|
||||||
self._control_channel = control_channel
|
self._control_channel = control_channel
|
||||||
|
|
||||||
|
self._ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
|
||||||
|
self._hosts_to_exploit = None
|
||||||
|
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
||||||
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
||||||
|
@ -156,17 +159,16 @@ class AutomatedMaster(IMaster):
|
||||||
def _can_propagate(self):
|
def _can_propagate(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# TODO: Refactor propagation into its own class
|
||||||
def _propagate(self, config: Dict):
|
def _propagate(self, config: Dict):
|
||||||
logger.info("Attempting to propagate")
|
logger.info("Attempting to propagate")
|
||||||
|
|
||||||
hosts_to_exploit = Queue()
|
self._hosts_to_exploit = Queue()
|
||||||
|
|
||||||
scan_thread = create_daemon_thread(
|
scan_thread = create_daemon_thread(
|
||||||
target=self._scan_network, args=(config["network_scan"], hosts_to_exploit)
|
target=self._scan_network, args=(config["network_scan"],)
|
||||||
)
|
|
||||||
exploit_thread = create_daemon_thread(
|
|
||||||
target=self._exploit_targets, args=(hosts_to_exploit, scan_thread)
|
|
||||||
)
|
)
|
||||||
|
exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,))
|
||||||
|
|
||||||
scan_thread.start()
|
scan_thread.start()
|
||||||
exploit_thread.start()
|
exploit_thread.start()
|
||||||
|
@ -176,73 +178,28 @@ class AutomatedMaster(IMaster):
|
||||||
|
|
||||||
logger.info("Finished attempting to propagate")
|
logger.info("Finished attempting to propagate")
|
||||||
|
|
||||||
def _exploit_targets(self, hosts_to_exploit: Queue, scan_thread: Thread):
|
def _scan_network(self, scan_config: Dict):
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: Refactor this into its own class
|
|
||||||
def _scan_network(self, scan_config: Dict, hosts_to_exploit: Queue):
|
|
||||||
logger.info("Starting network scan")
|
logger.info("Starting network scan")
|
||||||
|
|
||||||
# TODO: Generate list of IPs to scan
|
# TODO: Generate list of IPs to scan
|
||||||
ips_to_scan = Queue()
|
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
|
||||||
for i in range(1, 255):
|
|
||||||
ips_to_scan.put(f"10.0.0.{i}")
|
|
||||||
|
|
||||||
scan_threads = []
|
self._ip_scanner.scan(
|
||||||
for i in range(0, NUM_SCAN_THREADS):
|
ips_to_scan,
|
||||||
t = create_daemon_thread(
|
scan_config["icmp"],
|
||||||
target=self._scan_ips, args=(ips_to_scan, scan_config, hosts_to_exploit)
|
scan_config["tcp"],
|
||||||
)
|
self._handle_scanned_host,
|
||||||
t.start()
|
self._stop,
|
||||||
scan_threads.append(t)
|
)
|
||||||
|
|
||||||
for t in scan_threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
logger.info("Finished network scan")
|
logger.info("Finished network scan")
|
||||||
|
|
||||||
def _scan_ips(self, ips_to_scan: Queue, scan_config: Dict, hosts_to_exploit: Queue):
|
def _handle_scanned_host(self, host: VictimHost):
|
||||||
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
self._hosts_to_exploit.put(host)
|
||||||
try:
|
self._telemetry_messenger.send_telemetry(ScanTelem(host))
|
||||||
while not self._stop.is_set():
|
|
||||||
ip = ips_to_scan.get_nowait()
|
|
||||||
logger.info(f"Scanning {ip}")
|
|
||||||
|
|
||||||
victim_host = VictimHost(ip)
|
def _exploit_targets(self, scan_thread: Thread):
|
||||||
|
pass
|
||||||
self._ping_ip(ip, victim_host, scan_config["icmp"])
|
|
||||||
self._scan_tcp_ports(ip, victim_host, scan_config["tcp"])
|
|
||||||
|
|
||||||
hosts_to_exploit.put(hosts_to_exploit)
|
|
||||||
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
|
|
||||||
|
|
||||||
except queue.Empty:
|
|
||||||
logger.debug(
|
|
||||||
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
|
|
||||||
|
|
||||||
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
|
|
||||||
(response_received, os) = self._puppet.ping(ip, options)
|
|
||||||
|
|
||||||
victim_host.icmp = response_received
|
|
||||||
if os is not None:
|
|
||||||
victim_host.os["type"] = os
|
|
||||||
|
|
||||||
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict):
|
|
||||||
for p in options["ports"]:
|
|
||||||
if self._stop.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
# TODO: check units of timeout
|
|
||||||
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
|
|
||||||
if port_scan_data.status == PortStatus.OPEN:
|
|
||||||
victim_host.services[port_scan_data.service] = {}
|
|
||||||
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
|
|
||||||
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
|
|
||||||
if port_scan_data.banner is not None:
|
|
||||||
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
|
|
||||||
|
|
||||||
def _run_payload(self, payload: Tuple[str, Dict]):
|
def _run_payload(self, payload: Tuple[str, Dict]):
|
||||||
name = payload[0]
|
name = payload[0]
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Event
|
||||||
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
|
from infection_monkey.i_puppet import IPuppet, PortStatus
|
||||||
|
from infection_monkey.model.host import VictimHost
|
||||||
|
|
||||||
|
from .threading_utils import create_daemon_thread
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
Callback = Callable[[VictimHost], None]
|
||||||
|
|
||||||
|
|
||||||
|
class IPScanner:
|
||||||
|
def __init__(self, puppet: IPuppet, num_workers: int):
|
||||||
|
self._puppet = puppet
|
||||||
|
self._num_workers = num_workers
|
||||||
|
|
||||||
|
def scan(
|
||||||
|
self,
|
||||||
|
ips: List[str],
|
||||||
|
icmp_config: Dict,
|
||||||
|
tcp_config: Dict,
|
||||||
|
report_results_callback: Callback,
|
||||||
|
stop: Event,
|
||||||
|
):
|
||||||
|
# Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty.
|
||||||
|
ips_to_scan = Queue()
|
||||||
|
for ip in ips:
|
||||||
|
ips_to_scan.put(ip)
|
||||||
|
|
||||||
|
scan_ips_args = (
|
||||||
|
ips_to_scan,
|
||||||
|
icmp_config,
|
||||||
|
tcp_config,
|
||||||
|
report_results_callback,
|
||||||
|
stop,
|
||||||
|
)
|
||||||
|
scan_threads = []
|
||||||
|
for i in range(0, self._num_workers):
|
||||||
|
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
|
||||||
|
t.start()
|
||||||
|
scan_threads.append(t)
|
||||||
|
|
||||||
|
for t in scan_threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
def _scan_ips(
|
||||||
|
self,
|
||||||
|
ips_to_scan: Queue,
|
||||||
|
icmp_config: Dict,
|
||||||
|
tcp_config: Dict,
|
||||||
|
report_results_callback: Callback,
|
||||||
|
stop: Event,
|
||||||
|
):
|
||||||
|
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not stop.is_set():
|
||||||
|
ip = ips_to_scan.get_nowait()
|
||||||
|
logger.info(f"Scanning {ip}")
|
||||||
|
|
||||||
|
victim_host = VictimHost(ip)
|
||||||
|
|
||||||
|
self._ping_ip(ip, victim_host, icmp_config)
|
||||||
|
self._scan_tcp_ports(ip, victim_host, tcp_config, stop)
|
||||||
|
|
||||||
|
report_results_callback(victim_host)
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
logger.debug(
|
||||||
|
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
|
||||||
|
|
||||||
|
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
|
||||||
|
(response_received, os) = self._puppet.ping(ip, options)
|
||||||
|
|
||||||
|
victim_host.icmp = response_received
|
||||||
|
if os is not None:
|
||||||
|
victim_host.os["type"] = os
|
||||||
|
|
||||||
|
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event):
|
||||||
|
for p in options["ports"]:
|
||||||
|
if stop.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
|
||||||
|
if port_scan_data.status == PortStatus.OPEN:
|
||||||
|
victim_host.services[port_scan_data.service] = {}
|
||||||
|
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
|
||||||
|
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
|
||||||
|
if port_scan_data.banner is not None:
|
||||||
|
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
|
|
@ -0,0 +1,146 @@
|
||||||
|
from threading import Barrier, Event
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.i_puppet import PortScanData
|
||||||
|
from infection_monkey.master import IPScanner
|
||||||
|
from infection_monkey.puppet.mock_puppet import MockPuppet
|
||||||
|
|
||||||
|
WINDOWS_OS = {"type": "windows"}
|
||||||
|
LINUX_OS = {"type": "linux"}
|
||||||
|
|
||||||
|
|
||||||
|
class MockPuppet(MockPuppet):
|
||||||
|
def __init__(self):
|
||||||
|
self.ping = MagicMock(side_effect=super().ping)
|
||||||
|
self.scan_tcp_port = MagicMock(side_effect=super().scan_tcp_port)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tcp_scan_config():
|
||||||
|
return {
|
||||||
|
"timeout_ms": 3000,
|
||||||
|
"ports": [
|
||||||
|
22,
|
||||||
|
445,
|
||||||
|
3389,
|
||||||
|
443,
|
||||||
|
8008,
|
||||||
|
3306,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def icmp_scan_config():
|
||||||
|
return {
|
||||||
|
"timeout_ms": 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def stop():
|
||||||
|
return Event()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def callback():
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_dot_1(victim_host):
|
||||||
|
assert victim_host.icmp is True
|
||||||
|
assert victim_host.os == WINDOWS_OS
|
||||||
|
|
||||||
|
assert len(victim_host.services.keys()) == 2
|
||||||
|
assert "tcp-445" in victim_host.services
|
||||||
|
assert victim_host.services["tcp-445"]["port"] == 445
|
||||||
|
assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER"
|
||||||
|
assert "tcp-3389" in victim_host.services
|
||||||
|
assert victim_host.services["tcp-3389"]["port"] == 3389
|
||||||
|
|
||||||
|
|
||||||
|
def assert_dot_3(victim_host):
|
||||||
|
assert victim_host.icmp is True
|
||||||
|
assert victim_host.os == LINUX_OS
|
||||||
|
|
||||||
|
assert len(victim_host.services.keys()) == 2
|
||||||
|
assert "tcp-22" in victim_host.services
|
||||||
|
assert victim_host.services["tcp-22"]["port"] == 22
|
||||||
|
assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER"
|
||||||
|
|
||||||
|
assert "tcp-443" in victim_host.services
|
||||||
|
assert victim_host.services["tcp-443"]["port"] == 443
|
||||||
|
assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_host_down(victim_host):
|
||||||
|
assert victim_host.icmp is False
|
||||||
|
assert len(victim_host.services.keys()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||||
|
ips = ["10.0.0.1"]
|
||||||
|
|
||||||
|
ns = IPScanner(MockPuppet(), num_workers=1)
|
||||||
|
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||||
|
|
||||||
|
callback.assert_called_once()
|
||||||
|
|
||||||
|
assert_dot_1(callback.call_args_list[0][0][0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||||
|
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||||
|
|
||||||
|
ns = IPScanner(MockPuppet(), num_workers=4)
|
||||||
|
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||||
|
|
||||||
|
assert callback.call_count == 4
|
||||||
|
|
||||||
|
assert_dot_1(callback.call_args_list[0][0][0])
|
||||||
|
assert_host_down(callback.call_args_list[1][0][0])
|
||||||
|
assert_dot_3(callback.call_args_list[2][0][0])
|
||||||
|
assert_host_down(callback.call_args_list[3][0][0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop):
|
||||||
|
def _callback(_):
|
||||||
|
# Block all threads here until 2 threads reach this barrier, then set stop
|
||||||
|
# and test that niether thread continues to scan.
|
||||||
|
_callback.barrier.wait()
|
||||||
|
stop.set()
|
||||||
|
|
||||||
|
_callback.barrier = Barrier(2)
|
||||||
|
|
||||||
|
stopable_callback = MagicMock(side_effect=_callback)
|
||||||
|
|
||||||
|
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||||
|
|
||||||
|
ns = IPScanner(MockPuppet(), num_workers=2)
|
||||||
|
ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop)
|
||||||
|
|
||||||
|
assert stopable_callback.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop):
|
||||||
|
def stopable_scan_tcp_port(port, _, __):
|
||||||
|
# Block all threads here until 2 threads reach this barrier, then set stop
|
||||||
|
# and test that niether thread scans any more ports
|
||||||
|
stopable_scan_tcp_port.barrier.wait()
|
||||||
|
stop.set()
|
||||||
|
|
||||||
|
return PortScanData(port, False, None, None)
|
||||||
|
|
||||||
|
stopable_scan_tcp_port.barrier = Barrier(2)
|
||||||
|
|
||||||
|
puppet = MockPuppet()
|
||||||
|
puppet.scan_tcp_port = MagicMock(side_effect=stopable_scan_tcp_port)
|
||||||
|
|
||||||
|
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
|
||||||
|
|
||||||
|
ns = IPScanner(puppet, num_workers=2)
|
||||||
|
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
|
||||||
|
|
||||||
|
assert puppet.scan_tcp_port.call_count == 2
|
Loading…
Reference in New Issue