Agent: Extract propagation logic into Propagator class

This commit is contained in:
Mike Salvatore 2021-12-10 12:48:45 -05:00
parent abec851ed0
commit 6147d635d6
4 changed files with 168 additions and 69 deletions

View File

@ -1,2 +1,3 @@
from .ip_scanner import IPScanner from .ip_scanner import IPScanner
from .propagator import Propagator
from .automated_master import AutomatedMaster from .automated_master import AutomatedMaster

View File

@ -1,21 +1,17 @@
import logging import logging
import threading import threading
import time import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, List, Tuple
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData, PortStatus from infection_monkey.i_puppet import IPuppet
from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.scan_telem import ScanTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.timer import Timer from infection_monkey.utils.timer import Timer
from . import IPScanner from . import IPScanner, Propagator
from .threading_utils import create_daemon_thread from .threading_utils import create_daemon_thread
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
@ -37,8 +33,8 @@ class AutomatedMaster(IMaster):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._control_channel = control_channel self._control_channel = control_channel
self._ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS) ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
self._hosts_to_exploit = None self._propagator = Propagator(self._telemetry_messenger, ip_scanner)
self._stop = threading.Event() self._stop = threading.Event()
self._master_thread = create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(target=self._run_master_thread)
@ -121,11 +117,7 @@ class AutomatedMaster(IMaster):
system_info_collector_thread.join() system_info_collector_thread.join()
if self._can_propagate(): if self._can_propagate():
propagation_thread = create_daemon_thread( self._propagator.propagate(config["propagation"], self._stop)
target=self._propagate, args=(config["propagation"],)
)
propagation_thread.start()
propagation_thread.join()
payload_thread = create_daemon_thread( payload_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
@ -161,62 +153,6 @@ class AutomatedMaster(IMaster):
def _can_propagate(self): def _can_propagate(self):
return True return True
# TODO: Refactor propagation into its own class
def _propagate(self, propagation_config: Dict):
logger.info("Attempting to propagate")
self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread(target=self._scan_network, args=(propagation_config,))
exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,))
scan_thread.start()
exploit_thread.start()
scan_thread.join()
exploit_thread.join()
logger.info("Finished attempting to propagate")
def _scan_network(self, propagation_config: Dict):
logger.info("Starting network scan")
# TODO: Generate list of IPs to scan
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, self._stop)
logger.info("Finished network scan")
def _process_scan_results(
self, ip: str, ping_scan_data: PingScanData, port_scan_data: PortScanData
):
victim_host = VictimHost(ip)
has_open_port = False
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
has_open_port = True
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
if has_open_port:
self._hosts_to_exploit.put(victim_host)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
def _exploit_targets(self, scan_thread: Thread):
pass
def _run_payload(self, payload: Tuple[str, Dict]): def _run_payload(self, payload: Tuple[str, Dict]):
name = payload[0] name = payload[0]
options = payload[1] options = payload[1]

View File

@ -0,0 +1,80 @@
import logging
from queue import Queue
from threading import Event, Thread
from typing import Dict
from infection_monkey.i_puppet import PingScanData, PortScanData, PortStatus
from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem
from . import IPScanner
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
class Propagator:
def __init__(self, telemetry_messenger: ITelemetryMessenger, ip_scanner: IPScanner):
self._telemetry_messenger = telemetry_messenger
self._ip_scanner = ip_scanner
self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, stop: Event):
logger.info("Attempting to propagate")
self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread(
target=self._scan_network, args=(propagation_config, stop)
)
exploit_thread = create_daemon_thread(
target=self._exploit_targets, args=(scan_thread, stop)
)
scan_thread.start()
exploit_thread.start()
scan_thread.join()
exploit_thread.join()
logger.info("Finished attempting to propagate")
def _scan_network(self, propagation_config: Dict, stop: Event):
logger.info("Starting network scan")
# TODO: Generate list of IPs to scan from propagation targets config
ips_to_scan = propagation_config["targets"]["subnet_scan_list"]
scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan")
def _process_scan_results(
self, ip: str, ping_scan_data: PingScanData, port_scan_data: PortScanData
):
victim_host = VictimHost(ip)
has_open_port = False
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
has_open_port = True
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
if has_open_port:
self._hosts_to_exploit.put(victim_host)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
def _exploit_targets(self, scan_thread: Thread, stop: Event):
pass

View File

@ -0,0 +1,82 @@
from threading import Event
from infection_monkey.i_puppet import PingScanData, PortScanData, PortStatus
from infection_monkey.master import Propagator
dot_1_results = (
PingScanData(True, "windows"),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
},
)
dot_3_results = (
PingScanData(True, "linux"),
{
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
},
)
dead_host_results = (
PingScanData(False, None),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
443: PortScanData(443, PortStatus.CLOSED, None, None),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
},
)
dot_1_services = {
"tcp-445": {"display_name": "unknown(TCP)", "port": 445, "banner": "SMB BANNER"},
"tcp-3389": {"display_name": "unknown(TCP)", "port": 3389, "banner": ""},
}
dot_3_services = {
"tcp-22": {"display_name": "unknown(TCP)", "port": 22, "banner": "SSH BANNER"},
"tcp-443": {"display_name": "unknown(TCP)", "port": 443, "banner": "HTTPS BANNER"},
}
class MockIPScanner:
def scan(self, ips_to_scan, options, results_callback, stop):
for ip in ips_to_scan:
if ip.endswith(".1"):
results_callback(ip, *dot_1_results)
elif ip.endswith(".3"):
results_callback(ip, *dot_3_results)
else:
results_callback(ip, *dead_host_results)
def test_scan_result_processing(telemetry_messenger_spy):
p = Propagator(telemetry_messenger_spy, MockIPScanner())
p.propagate(
{"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "network_scan": {}},
Event(),
)
assert len(telemetry_messenger_spy.telemetries) == 3
for t in telemetry_messenger_spy.telemetries:
data = t.get_data()
ip = data["machine"]["ip_addr"]
if ip.endswith(".1"):
assert data["service_count"] == 2
assert data["machine"]["os"]["type"] == "windows"
assert data["machine"]["services"] == dot_1_services
assert data["machine"]["icmp"] is True
elif ip.endswith(".3"):
assert data["service_count"] == 2
assert data["machine"]["os"]["type"] == "linux"
assert data["machine"]["services"] == dot_3_services
assert data["machine"]["icmp"] is True
else:
assert data["service_count"] == 0
assert data["machine"]["os"] == {}
assert data["machine"]["services"] == {}
assert data["machine"]["icmp"] is False