Agent: Extract network scanner into its own class

This commit is contained in:
Mike Salvatore 2021-12-09 20:59:42 -05:00
parent 3f7dbbccc2
commit 81d4afab52
4 changed files with 270 additions and 66 deletions

View File

@ -1 +1,2 @@
from .ip_scanner import IPScanner
from .automated_master import AutomatedMaster from .automated_master import AutomatedMaster

View File

@ -1,5 +1,4 @@
import logging import logging
import queue
import threading import threading
import time import time
from queue import Queue from queue import Queue
@ -8,7 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet, PortStatus from infection_monkey.i_puppet import IPuppet
from infection_monkey.model.host import VictimHost from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
@ -16,6 +15,7 @@ from infection_monkey.telemetry.scan_telem import ScanTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.timer import Timer from infection_monkey.utils.timer import Timer
from . import IPScanner
from .threading_utils import create_daemon_thread from .threading_utils import create_daemon_thread
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
@ -37,6 +37,9 @@ class AutomatedMaster(IMaster):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._control_channel = control_channel self._control_channel = control_channel
self._ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
self._hosts_to_exploit = None
self._stop = threading.Event() self._stop = threading.Event()
self._master_thread = create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(target=self._run_master_thread)
self._simulation_thread = create_daemon_thread(target=self._run_simulation) self._simulation_thread = create_daemon_thread(target=self._run_simulation)
@ -156,17 +159,16 @@ class AutomatedMaster(IMaster):
def _can_propagate(self): def _can_propagate(self):
return True return True
# TODO: Refactor propagation into its own class
def _propagate(self, config: Dict): def _propagate(self, config: Dict):
logger.info("Attempting to propagate") logger.info("Attempting to propagate")
hosts_to_exploit = Queue() self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread( scan_thread = create_daemon_thread(
target=self._scan_network, args=(config["network_scan"], hosts_to_exploit) target=self._scan_network, args=(config["network_scan"],)
)
exploit_thread = create_daemon_thread(
target=self._exploit_targets, args=(hosts_to_exploit, scan_thread)
) )
exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,))
scan_thread.start() scan_thread.start()
exploit_thread.start() exploit_thread.start()
@ -176,73 +178,28 @@ class AutomatedMaster(IMaster):
logger.info("Finished attempting to propagate") logger.info("Finished attempting to propagate")
def _exploit_targets(self, hosts_to_exploit: Queue, scan_thread: Thread): def _scan_network(self, scan_config: Dict):
pass
# TODO: Refactor this into its own class
def _scan_network(self, scan_config: Dict, hosts_to_exploit: Queue):
logger.info("Starting network scan") logger.info("Starting network scan")
# TODO: Generate list of IPs to scan # TODO: Generate list of IPs to scan
ips_to_scan = Queue() ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
for i in range(1, 255):
ips_to_scan.put(f"10.0.0.{i}")
scan_threads = [] self._ip_scanner.scan(
for i in range(0, NUM_SCAN_THREADS): ips_to_scan,
t = create_daemon_thread( scan_config["icmp"],
target=self._scan_ips, args=(ips_to_scan, scan_config, hosts_to_exploit) scan_config["tcp"],
self._handle_scanned_host,
self._stop,
) )
t.start()
scan_threads.append(t)
for t in scan_threads:
t.join()
logger.info("Finished network scan") logger.info("Finished network scan")
def _scan_ips(self, ips_to_scan: Queue, scan_config: Dict, hosts_to_exploit: Queue): def _handle_scanned_host(self, host: VictimHost):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") self._hosts_to_exploit.put(host)
try: self._telemetry_messenger.send_telemetry(ScanTelem(host))
while not self._stop.is_set():
ip = ips_to_scan.get_nowait()
logger.info(f"Scanning {ip}")
victim_host = VictimHost(ip) def _exploit_targets(self, scan_thread: Thread):
pass
self._ping_ip(ip, victim_host, scan_config["icmp"])
self._scan_tcp_ports(ip, victim_host, scan_config["tcp"])
hosts_to_exploit.put(hosts_to_exploit)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
except queue.Empty:
logger.debug(
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
)
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
(response_received, os) = self._puppet.ping(ip, options)
victim_host.icmp = response_received
if os is not None:
victim_host.os["type"] = os
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict):
for p in options["ports"]:
if self._stop.is_set():
break
# TODO: check units of timeout
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
if port_scan_data.status == PortStatus.OPEN:
victim_host.services[port_scan_data.service] = {}
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
if port_scan_data.banner is not None:
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner
def _run_payload(self, payload: Tuple[str, Dict]): def _run_payload(self, payload: Tuple[str, Dict]):
name = payload[0] name = payload[0]

View File

@ -0,0 +1,100 @@
import logging
import queue
import threading
from queue import Queue
from threading import Event
from typing import Callable, Dict, List
from infection_monkey.i_puppet import IPuppet, PortStatus
from infection_monkey.model.host import VictimHost
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
Callback = Callable[[VictimHost], None]
class IPScanner:
def __init__(self, puppet: IPuppet, num_workers: int):
self._puppet = puppet
self._num_workers = num_workers
def scan(
self,
ips: List[str],
icmp_config: Dict,
tcp_config: Dict,
report_results_callback: Callback,
stop: Event,
):
# Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty.
ips_to_scan = Queue()
for ip in ips:
ips_to_scan.put(ip)
scan_ips_args = (
ips_to_scan,
icmp_config,
tcp_config,
report_results_callback,
stop,
)
scan_threads = []
for i in range(0, self._num_workers):
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
t.start()
scan_threads.append(t)
for t in scan_threads:
t.join()
def _scan_ips(
self,
ips_to_scan: Queue,
icmp_config: Dict,
tcp_config: Dict,
report_results_callback: Callback,
stop: Event,
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
try:
while not stop.is_set():
ip = ips_to_scan.get_nowait()
logger.info(f"Scanning {ip}")
victim_host = VictimHost(ip)
self._ping_ip(ip, victim_host, icmp_config)
self._scan_tcp_ports(ip, victim_host, tcp_config, stop)
report_results_callback(victim_host)
except queue.Empty:
logger.debug(
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
)
return
logger.debug(f"Detected the stop signal, scanning thread {threading.get_ident()} exiting")
def _ping_ip(self, ip: str, victim_host: VictimHost, options: Dict):
(response_received, os) = self._puppet.ping(ip, options)
victim_host.icmp = response_received
if os is not None:
victim_host.os["type"] = os
def _scan_tcp_ports(self, ip: str, victim_host: VictimHost, options: Dict, stop: Event):
for p in options["ports"]:
if stop.is_set():
break
port_scan_data = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"])
if port_scan_data.status == PortStatus.OPEN:
victim_host.services[port_scan_data.service] = {}
victim_host.services[port_scan_data.service]["display_name"] = "unknown(TCP)"
victim_host.services[port_scan_data.service]["port"] = port_scan_data.port
if port_scan_data.banner is not None:
victim_host.services[port_scan_data.service]["banner"] = port_scan_data.banner

View File

@ -0,0 +1,146 @@
from threading import Barrier, Event
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import PortScanData
from infection_monkey.master import IPScanner
from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = {"type": "windows"}
LINUX_OS = {"type": "linux"}
class MockPuppet(MockPuppet):
def __init__(self):
self.ping = MagicMock(side_effect=super().ping)
self.scan_tcp_port = MagicMock(side_effect=super().scan_tcp_port)
@pytest.fixture
def tcp_scan_config():
return {
"timeout_ms": 3000,
"ports": [
22,
445,
3389,
443,
8008,
3306,
],
}
@pytest.fixture
def icmp_scan_config():
return {
"timeout_ms": 1000,
}
@pytest.fixture
def stop():
return Event()
@pytest.fixture
def callback():
return MagicMock()
def assert_dot_1(victim_host):
assert victim_host.icmp is True
assert victim_host.os == WINDOWS_OS
assert len(victim_host.services.keys()) == 2
assert "tcp-445" in victim_host.services
assert victim_host.services["tcp-445"]["port"] == 445
assert victim_host.services["tcp-445"]["banner"] == "SMB BANNER"
assert "tcp-3389" in victim_host.services
assert victim_host.services["tcp-3389"]["port"] == 3389
def assert_dot_3(victim_host):
assert victim_host.icmp is True
assert victim_host.os == LINUX_OS
assert len(victim_host.services.keys()) == 2
assert "tcp-22" in victim_host.services
assert victim_host.services["tcp-22"]["port"] == 22
assert victim_host.services["tcp-22"]["banner"] == "SSH BANNER"
assert "tcp-443" in victim_host.services
assert victim_host.services["tcp-443"]["port"] == 443
assert victim_host.services["tcp-443"]["banner"] == "HTTPS BANNER"
def assert_host_down(victim_host):
assert victim_host.icmp is False
assert len(victim_host.services.keys()) == 0
def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop):
ips = ["10.0.0.1"]
ns = IPScanner(MockPuppet(), num_workers=1)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
callback.assert_called_once()
assert_dot_1(callback.call_args_list[0][0][0])
def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
assert callback.call_count == 4
assert_dot_1(callback.call_args_list[0][0][0])
assert_host_down(callback.call_args_list[1][0][0])
assert_dot_3(callback.call_args_list[2][0][0])
assert_host_down(callback.call_args_list[3][0][0])
def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop):
def _callback(_):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread continues to scan.
_callback.barrier.wait()
stop.set()
_callback.barrier = Barrier(2)
stopable_callback = MagicMock(side_effect=_callback)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=2)
ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop)
assert stopable_callback.call_count == 2
def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop):
def stopable_scan_tcp_port(port, _, __):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread scans any more ports
stopable_scan_tcp_port.barrier.wait()
stop.set()
return PortScanData(port, False, None, None)
stopable_scan_tcp_port.barrier = Barrier(2)
puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stopable_scan_tcp_port)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2