Merge pull request #1652 from guardicore/1597-implement-propagation-scanning

1597 implement propagation scanning
This commit is contained in:
Mike Salvatore 2021-12-13 09:33:16 -05:00 committed by GitHub
commit f2e95daa56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 631 additions and 53 deletions

View File

@ -36,7 +36,6 @@
], ],
"finger_classes": [ "finger_classes": [
"SSHFinger", "SSHFinger",
"PingScanner",
"HTTPFinger", "HTTPFinger",
"SMBFinger", "SMBFinger",
"MySQLFinger", "MySQLFinger",

View File

@ -2,7 +2,7 @@ import abc
import threading import threading
from collections import namedtuple from collections import namedtuple
from enum import Enum from enum import Enum
from typing import Dict, Optional, Tuple from typing import Dict
class PortStatus(Enum): class PortStatus(Enum):
@ -11,6 +11,7 @@ class PortStatus(Enum):
ExploiterResultData = namedtuple("ExploiterResultData", ["result", "info", "attempts"]) ExploiterResultData = namedtuple("ExploiterResultData", ["result", "info", "attempts"])
PingScanData = namedtuple("PingScanData", ["response_received", "os"])
PortScanData = namedtuple("PortScanData", ["port", "status", "banner", "service"]) PortScanData = namedtuple("PortScanData", ["port", "status", "banner", "service"])
PostBreachData = namedtuple("PostBreachData", ["command", "result"]) PostBreachData = namedtuple("PostBreachData", ["command", "result"])
@ -35,22 +36,22 @@ class IPuppet(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def ping(self, host: str) -> Tuple[bool, Optional[str]]: def ping(self, host: str, timeout: float) -> PingScanData:
""" """
Sends a ping (ICMP packet) to a remote host Sends a ping (ICMP packet) to a remote host
:param str host: The domain name or IP address of a host :param str host: The domain name or IP address of a host
:return: A tuple that contains whether or not the host responded and the host's inferred :param float timeout: The maximum amount of time (in seconds) to wait for a response
operating system :return: The data collected by attempting to ping the target host
:rtype: Tuple[bool, Optional[str]] :rtype: PingScanData
""" """
@abc.abstractmethod @abc.abstractmethod
def scan_tcp_port(self, host: str, port: int, timeout: int) -> PortScanData: def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData:
""" """
Scans a TCP port on a remote host Scans a TCP port on a remote host
:param str host: The domain name or IP address of a host :param str host: The domain name or IP address of a host
:param int port: A TCP port number to scan :param int port: A TCP port number to scan
:param int timeout: The maximum amount of time (in seconds) to wait for a response :param float timeout: The maximum amount of time (in seconds) to wait for a response
:return: The data collected by scanning the provided host:port combination :return: The data collected by scanning the provided host:port combination
:rtype: PortScanData :rtype: PortScanData
""" """

View File

@ -0,0 +1,3 @@
from .ip_scanner import IPScanner
from .propagator import Propagator
from .automated_master import AutomatedMaster

View File

@ -11,9 +11,13 @@ from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.timer import Timer from infection_monkey.utils.timer import Timer
from . import IPScanner, Propagator
from .threading_utils import create_daemon_thread
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5 CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
SHUTDOWN_TIMEOUT = 5 SHUTDOWN_TIMEOUT = 5
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
logger = logging.getLogger() logger = logging.getLogger()
@ -29,9 +33,12 @@ class AutomatedMaster(IMaster):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._control_channel = control_channel self._control_channel = control_channel
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
self._propagator = Propagator(self._telemetry_messenger, ip_scanner)
self._stop = threading.Event() self._stop = threading.Event()
self._master_thread = _create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(target=self._run_master_thread)
self._simulation_thread = _create_daemon_thread(target=self._run_simulation) self._simulation_thread = create_daemon_thread(target=self._run_simulation)
def start(self): def start(self):
logger.info("Starting automated breach and attack simulation") logger.info("Starting automated breach and attack simulation")
@ -87,7 +94,7 @@ class AutomatedMaster(IMaster):
def _run_simulation(self): def _run_simulation(self):
config = self._control_channel.get_config() config = self._control_channel.get_config()
system_info_collector_thread = _create_daemon_thread( system_info_collector_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
args=( args=(
config["system_info_collector_classes"], config["system_info_collector_classes"],
@ -95,7 +102,7 @@ class AutomatedMaster(IMaster):
self._collect_system_info, self._collect_system_info,
), ),
) )
pba_thread = _create_daemon_thread( pba_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
args=(config["post_breach_actions"].items(), "post-breach action", self._run_pba), args=(config["post_breach_actions"].items(), "post-breach action", self._run_pba),
) )
@ -110,11 +117,9 @@ class AutomatedMaster(IMaster):
system_info_collector_thread.join() system_info_collector_thread.join()
if self._can_propagate(): if self._can_propagate():
propagation_thread = _create_daemon_thread(target=self._propagate, args=(config,)) self._propagator.propagate(config["propagation"], self._stop)
propagation_thread.start()
propagation_thread.join()
payload_thread = _create_daemon_thread( payload_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
args=(config["payloads"].items(), "payload", self._run_payload), args=(config["payloads"].items(), "payload", self._run_payload),
) )
@ -148,9 +153,6 @@ class AutomatedMaster(IMaster):
def _can_propagate(self): def _can_propagate(self):
return True return True
def _propagate(self, config: Dict):
pass
def _run_payload(self, payload: Tuple[str, Dict]): def _run_payload(self, payload: Tuple[str, Dict]):
name = payload[0] name = payload[0]
options = payload[1] options = payload[1]
@ -172,7 +174,3 @@ class AutomatedMaster(IMaster):
def cleanup(self): def cleanup(self):
pass pass
def _create_daemon_thread(target: Callable[[Any], None], args: Tuple[Any] = ()):
return threading.Thread(target=target, args=args, daemon=True)

View File

@ -6,7 +6,7 @@ import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from infection_monkey.config import WormConfiguration from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from monkey.infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
@ -23,8 +23,12 @@ class ControlChannel(IControlChannel):
logger.error("Agent should stop because it can't connect to the C&C server.") logger.error("Agent should stop because it can't connect to the C&C server.")
return True return True
try: try:
url = (
f"https://{self._control_channel_server}/api/monkey_control"
f"/needs-to-stop/{self._agent_id}"
)
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/monkey_control/needs-to-stop/{self._agent_id}", url,
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=ControlClient.proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,

View File

@ -0,0 +1,76 @@
import logging
import queue
import threading
from queue import Queue
from threading import Event
from typing import Callable, Dict, List
from infection_monkey.i_puppet import IPuppet, PingScanData, PortScanData
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
IP = str
Port = int
Callback = Callable[[IP, PingScanData, Dict[Port, PortScanData]], None]
class IPScanner:
def __init__(self, puppet: IPuppet, num_workers: int):
self._puppet = puppet
self._num_workers = num_workers
def scan(self, ips_to_scan: List[str], options: Dict, results_callback: Callback, stop: Event):
# Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the
# queue is empty.
ips = Queue()
for ip in ips_to_scan:
ips.put(ip)
scan_ips_args = (ips, options, results_callback, stop)
scan_threads = []
for i in range(0, self._num_workers):
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
t.start()
scan_threads.append(t)
for t in scan_threads:
t.join()
def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
try:
while not stop.is_set():
ip = ips.get_nowait()
logger.info(f"Scanning {ip}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
ping_scan_data = self._puppet.ping(ip, icmp_timeout)
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop)
results_callback(ip, ping_scan_data, port_scan_data)
logger.debug(
f"Detected the stop signal, scanning thread {threading.get_ident()} exiting"
)
except queue.Empty:
logger.debug(
f"ips_to_scan queue is empty, scanning thread {threading.get_ident()} exiting"
)
def _scan_tcp_ports(self, ip: str, ports: List[int], timeout: float, stop: Event):
port_scan_data = {}
for p in ports:
if stop.is_set():
break
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data

View File

@ -66,10 +66,10 @@ class MockMaster(IMaster):
for ip in ips: for ip in ips:
h = self._hosts[ip] h = self._hosts[ip]
(response_received, os) = self._puppet.ping(ip) ping_scan_data = self._puppet.ping(ip, 1)
h.icmp = response_received h.icmp = ping_scan_data.response_received
if os is not None: if ping_scan_data.os is not None:
h.os["type"] = os h.os["type"] = ping_scan_data.os
for p in ports: for p in ports:
port_scan_data = self._puppet.scan_tcp_port(ip, p) port_scan_data = self._puppet.scan_tcp_port(ip, p)

View File

@ -0,0 +1,80 @@
import logging
from queue import Queue
from threading import Event, Thread
from typing import Dict
from infection_monkey.i_puppet import PingScanData, PortScanData, PortStatus
from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem
from . import IPScanner
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
class Propagator:
def __init__(self, telemetry_messenger: ITelemetryMessenger, ip_scanner: IPScanner):
self._telemetry_messenger = telemetry_messenger
self._ip_scanner = ip_scanner
self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, stop: Event):
logger.info("Attempting to propagate")
self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread(
target=self._scan_network, args=(propagation_config, stop)
)
exploit_thread = create_daemon_thread(
target=self._exploit_targets, args=(scan_thread, stop)
)
scan_thread.start()
exploit_thread.start()
scan_thread.join()
exploit_thread.join()
logger.info("Finished attempting to propagate")
def _scan_network(self, propagation_config: Dict, stop: Event):
logger.info("Starting network scan")
# TODO: Generate list of IPs to scan from propagation targets config
ips_to_scan = propagation_config["targets"]["subnet_scan_list"]
scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan")
def _process_scan_results(
self, ip: str, ping_scan_data: PingScanData, port_scan_data: Dict[int, PortScanData]
):
victim_host = VictimHost(ip)
has_open_port = False
victim_host.icmp = ping_scan_data.response_received
if ping_scan_data.os is not None:
victim_host.os["type"] = ping_scan_data.os
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
has_open_port = True
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
if has_open_port:
self._hosts_to_exploit.put(victim_host)
self._telemetry_messenger.send_telemetry(ScanTelem(victim_host))
def _exploit_targets(self, scan_thread: Thread, stop: Event):
pass

View File

@ -0,0 +1,6 @@
from threading import Thread
from typing import Callable, Tuple
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)

View File

@ -1,10 +1,11 @@
import logging import logging
import threading import threading
from typing import Dict, Optional, Tuple from typing import Dict, Tuple
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
IPuppet, IPuppet,
PingScanData,
PortScanData, PortScanData,
PortStatus, PortStatus,
PostBreachData, PostBreachData,
@ -155,21 +156,21 @@ class MockPuppet(IPuppet):
else: else:
return PostBreachData("pba command 2", ["pba result 2", False]) return PostBreachData("pba command 2", ["pba result 2", False])
def ping(self, host: str) -> Tuple[bool, Optional[str]]: def ping(self, host: str, timeout: float = 1) -> PingScanData:
logger.debug(f"run_ping({host})") logger.debug(f"run_ping({host}, {timeout})")
if host == DOT_1: if host == DOT_1:
return (True, "windows") return PingScanData(True, "windows")
if host == DOT_2: if host == DOT_2:
return (False, None) return PingScanData(False, None)
if host == DOT_3: if host == DOT_3:
return (True, "linux") return PingScanData(True, "linux")
if host == DOT_4: if host == DOT_4:
return (False, None) return PingScanData(False, None)
return (False, None) return PingScanData(False, None)
def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData: def scan_tcp_port(self, host: str, port: int, timeout: int = 3) -> PortScanData:
logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})") logger.debug(f"run_scan_tcp_port({host}, {port}, {timeout})")
@ -278,4 +279,4 @@ class MockPuppet(IPuppet):
def _get_empty_results(port: int): def _get_empty_results(port: int):
return PortScanData(port, False, None, None) return PortScanData(port, PortStatus.CLOSED, None, None)

View File

@ -2,7 +2,7 @@ import collections
import copy import copy
import functools import functools
import logging import logging
from typing import Dict from typing import Dict, List
from jsonschema import Draft4Validator, validators from jsonschema import Draft4Validator, validators
@ -419,6 +419,7 @@ class ConfigService:
ConfigService._remove_credentials_from_flat_config(config) ConfigService._remove_credentials_from_flat_config(config)
ConfigService._format_payloads_from_flat_config(config) ConfigService._format_payloads_from_flat_config(config)
ConfigService._format_pbas_from_flat_config(config) ConfigService._format_pbas_from_flat_config(config)
ConfigService._format_propagation_from_flat_config(config)
@staticmethod @staticmethod
def _remove_credentials_from_flat_config(config: Dict): def _remove_credentials_from_flat_config(config: Dict):
@ -462,3 +463,95 @@ class ConfigService:
config.pop(flat_linux_filename_field, None) config.pop(flat_linux_filename_field, None)
config.pop(flat_windows_command_field, None) config.pop(flat_windows_command_field, None)
config.pop(flat_windows_filename_field, None) config.pop(flat_windows_filename_field, None)
@staticmethod
def _format_propagation_from_flat_config(config: Dict):
formatted_propagation_config = {"network_scan": {}, "targets": {}}
formatted_propagation_config[
"network_scan"
] = ConfigService._format_network_scan_from_flat_config(config)
formatted_propagation_config["targets"] = ConfigService._format_targets_from_flat_config(
config
)
config["propagation"] = formatted_propagation_config
@staticmethod
def _format_network_scan_from_flat_config(config: Dict):
formatted_network_scan_config = {"tcp": {}, "icmp": {}}
formatted_network_scan_config["tcp"] = ConfigService._format_tcp_scan_from_flat_config(
config
)
formatted_network_scan_config["icmp"] = ConfigService._format_icmp_scan_from_flat_config(
config
)
return formatted_network_scan_config
@staticmethod
def _format_tcp_scan_from_flat_config(config: Dict):
flat_http_ports_field = "HTTP_PORTS"
flat_tcp_timeout_field = "tcp_scan_timeout"
flat_tcp_ports_field = "tcp_target_ports"
formatted_tcp_scan_config = {}
formatted_tcp_scan_config["timeout_ms"] = config[flat_tcp_timeout_field]
ports = ConfigService._union_tcp_and_http_ports(
config[flat_tcp_ports_field], config[flat_http_ports_field]
)
formatted_tcp_scan_config["ports"] = ports
# Do not remove HTTP_PORTS field. Other components besides scanning need it.
config.pop(flat_tcp_timeout_field, None)
config.pop(flat_tcp_ports_field, None)
return formatted_tcp_scan_config
@staticmethod
def _union_tcp_and_http_ports(tcp_ports: List[int], http_ports: List[int]) -> List[int]:
combined_ports = list(set(tcp_ports) | set(http_ports))
return sorted(combined_ports)
@staticmethod
def _format_icmp_scan_from_flat_config(config: Dict):
flat_ping_timeout_field = "ping_scan_timeout"
formatted_icmp_scan_config = {}
formatted_icmp_scan_config["timeout_ms"] = config[flat_ping_timeout_field]
config.pop(flat_ping_timeout_field, None)
return formatted_icmp_scan_config
@staticmethod
def _format_targets_from_flat_config(config: Dict):
flat_blocked_ips_field = "blocked_ips"
flat_inaccessible_subnets_field = "inaccessible_subnets"
flat_local_network_scan_field = "local_network_scan"
flat_subnet_scan_list_field = "subnet_scan_list"
formatted_scan_targets_config = {}
formatted_scan_targets_config[flat_blocked_ips_field] = config[flat_blocked_ips_field]
formatted_scan_targets_config[flat_inaccessible_subnets_field] = config[
flat_inaccessible_subnets_field
]
formatted_scan_targets_config[flat_local_network_scan_field] = config[
flat_local_network_scan_field
]
formatted_scan_targets_config[flat_subnet_scan_list_field] = config[
flat_subnet_scan_list_field
]
config.pop(flat_blocked_ips_field, None)
config.pop(flat_inaccessible_subnets_field, None)
config.pop(flat_local_network_scan_field, None)
config.pop(flat_subnet_scan_list_field, None)
return formatted_scan_targets_config

View File

@ -20,13 +20,6 @@ FINGER_CLASSES = {
"info": "Figures out if SSH is running.", "info": "Figures out if SSH is running.",
"attack_techniques": ["T1210"], "attack_techniques": ["T1210"],
}, },
{
"type": "string",
"enum": ["PingScanner"],
"title": "Ping Scanner",
"safe": True,
"info": "Tries to identify if host is alive and which OS it's running by ping scan.",
},
{ {
"type": "string", "type": "string",
"enum": ["HTTPFinger"], "enum": ["HTTPFinger"],

View File

@ -165,7 +165,6 @@ INTERNAL = {
"default": [ "default": [
"SMBFinger", "SMBFinger",
"SSHFinger", "SSHFinger",
"PingScanner",
"HTTPFinger", "HTTPFinger",
"MySQLFinger", "MySQLFinger",
"MSSQLFinger", "MSSQLFinger",

View File

@ -13,7 +13,7 @@
"aws_access_key_id": "", "aws_access_key_id": "",
"aws_secret_access_key": "", "aws_secret_access_key": "",
"aws_session_token": "", "aws_session_token": "",
"blocked_ips": [], "blocked_ips": ["192.168.1.1", "192.168.1.100"],
"command_servers": [ "command_servers": [
"10.197.94.72:5000" "10.197.94.72:5000"
], ],
@ -65,13 +65,12 @@
"finger_classes": [ "finger_classes": [
"SMBFinger", "SMBFinger",
"SSHFinger", "SSHFinger",
"PingScanner",
"HTTPFinger", "HTTPFinger",
"MySQLFinger", "MySQLFinger",
"MSSQLFinger", "MSSQLFinger",
"ElasticFinger" "ElasticFinger"
], ],
"inaccessible_subnets": [], "inaccessible_subnets": ["10.0.0.0/24", "10.0.10.0/24"],
"keep_tunnel_open_time": 60, "keep_tunnel_open_time": 60,
"local_network_scan": true, "local_network_scan": true,
"max_depth": null, "max_depth": null,
@ -101,7 +100,7 @@
"skip_exploit_if_file_exist": false, "skip_exploit_if_file_exist": false,
"smb_download_timeout": 300, "smb_download_timeout": 300,
"smb_service_name": "InfectionMonkey", "smb_service_name": "InfectionMonkey",
"subnet_scan_list": [], "subnet_scan_list": ["192.168.1.50", "192.168.56.0/24", "10.0.33.0/30"],
"system_info_collector_classes": [ "system_info_collector_classes": [
"AwsCollector", "AwsCollector",
"ProcessListCollector", "ProcessListCollector",

View File

@ -100,7 +100,6 @@
"finger_classes": [ "finger_classes": [
"SMBFinger", "SMBFinger",
"SSHFinger", "SSHFinger",
"PingScanner",
"HTTPFinger", "HTTPFinger",
"MySQLFinger", "MySQLFinger",
"MSSQLFinger", "MSSQLFinger",

View File

@ -1,4 +1,5 @@
from infection_monkey.master.automated_master import AutomatedMaster from infection_monkey.master import AutomatedMaster
def test_terminate_without_start(): def test_terminate_without_start():
m = AutomatedMaster(None, None, None) m = AutomatedMaster(None, None, None)

View File

@ -0,0 +1,184 @@
from threading import Barrier, Event
from typing import Set
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.master import IPScanner
from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = "windows"
LINUX_OS = "linux"
@pytest.fixture
def scan_config():
return {
"tcp": {
"timeout_ms": 3000,
"ports": [
22,
445,
3389,
443,
8008,
3306,
],
},
"icmp": {
"timeout_ms": 1000,
},
}
@pytest.fixture
def stop():
return Event()
@pytest.fixture
def callback():
return MagicMock()
def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
for psd in port_scan_data.values():
if psd.port in expected_open_ports:
assert psd.status == PortStatus.OPEN
else:
assert psd.status == PortStatus.CLOSED
def assert_scan_results_no_1(ip, ping_scan_data, port_scan_data):
assert ip == "10.0.0.1"
assert ping_scan_data.response_received is True
assert ping_scan_data.os == WINDOWS_OS
assert len(port_scan_data.keys()) == 6
psd_445 = port_scan_data[445]
psd_3389 = port_scan_data[3389]
assert psd_445.port == 445
assert psd_445.banner == "SMB BANNER"
assert psd_445.service == "tcp-445"
assert psd_3389.port == 3389
assert psd_3389.banner == ""
assert psd_3389.service == "tcp-3389"
assert_port_status(port_scan_data, {445, 3389})
def assert_scan_results_no_3(ip, ping_scan_data, port_scan_data):
assert ip == "10.0.0.3"
assert ping_scan_data.response_received is True
assert ping_scan_data.os == LINUX_OS
assert len(port_scan_data.keys()) == 6
psd_443 = port_scan_data[443]
psd_22 = port_scan_data[22]
assert psd_443.port == 443
assert psd_443.banner == "HTTPS BANNER"
assert psd_443.service == "tcp-443"
assert psd_22.port == 22
assert psd_22.banner == "SSH BANNER"
assert psd_22.service == "tcp-22"
assert_port_status(port_scan_data, {22, 443})
def assert_scan_results_host_down(ip, ping_scan_data, port_scan_data):
assert ip not in {"10.0.0.1", "10.0.0.3"}
assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6
assert_port_status(port_scan_data, set())
def test_scan_single_ip(callback, scan_config, stop):
ips = ["10.0.0.1"]
ns = IPScanner(MockPuppet(), num_workers=1)
ns.scan(ips, scan_config, callback, stop)
callback.assert_called_once()
(ip, ping_scan_data, port_scan_data) = callback.call_args_list[0][0]
assert_scan_results_no_1(ip, ping_scan_data, port_scan_data)
def test_scan_multiple_ips(callback, scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop)
assert callback.call_count == 4
(ip, ping_scan_data, port_scan_data) = callback.call_args_list[0][0]
assert_scan_results_no_1(ip, ping_scan_data, port_scan_data)
(ip, ping_scan_data, port_scan_data) = callback.call_args_list[1][0]
assert_scan_results_host_down(ip, ping_scan_data, port_scan_data)
(ip, ping_scan_data, port_scan_data) = callback.call_args_list[2][0]
assert_scan_results_no_3(ip, ping_scan_data, port_scan_data)
(ip, ping_scan_data, port_scan_data) = callback.call_args_list[3][0]
assert_scan_results_host_down(ip, ping_scan_data, port_scan_data)
def test_scan_lots_of_ips(callback, scan_config, stop):
ips = [f"10.0.0.{i}" for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop)
assert callback.call_count == 255
def test_stop_after_callback(scan_config, stop):
def _callback(*_):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that neither thread continues to scan.
_callback.barrier.wait()
stop.set()
_callback.barrier = Barrier(2)
stopable_callback = MagicMock(side_effect=_callback)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=2)
ns.scan(ips, scan_config, stopable_callback, stop)
assert stopable_callback.call_count == 2
def test_interrupt_port_scanning(callback, scan_config, stop):
def stopable_scan_tcp_port(port, *_):
# Block all threads here until 2 threads reach this barrier, then set stop
# and test that neither thread scans any more ports
stopable_scan_tcp_port.barrier.wait()
stop.set()
return PortScanData(port, False, None, None)
stopable_scan_tcp_port.barrier = Barrier(2)
puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stopable_scan_tcp_port)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2

View File

@ -0,0 +1,82 @@
from threading import Event
from infection_monkey.i_puppet import PingScanData, PortScanData, PortStatus
from infection_monkey.master import Propagator
dot_1_results = (
PingScanData(True, "windows"),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
},
)
dot_3_results = (
PingScanData(True, "linux"),
{
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
},
)
dead_host_results = (
PingScanData(False, None),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
443: PortScanData(443, PortStatus.CLOSED, None, None),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
},
)
dot_1_services = {
"tcp-445": {"display_name": "unknown(TCP)", "port": 445, "banner": "SMB BANNER"},
"tcp-3389": {"display_name": "unknown(TCP)", "port": 3389, "banner": ""},
}
dot_3_services = {
"tcp-22": {"display_name": "unknown(TCP)", "port": 22, "banner": "SSH BANNER"},
"tcp-443": {"display_name": "unknown(TCP)", "port": 443, "banner": "HTTPS BANNER"},
}
class MockIPScanner:
def scan(self, ips_to_scan, options, results_callback, stop):
for ip in ips_to_scan:
if ip.endswith(".1"):
results_callback(ip, *dot_1_results)
elif ip.endswith(".3"):
results_callback(ip, *dot_3_results)
else:
results_callback(ip, *dead_host_results)
def test_scan_result_processing(telemetry_messenger_spy):
p = Propagator(telemetry_messenger_spy, MockIPScanner())
p.propagate(
{"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "network_scan": {}},
Event(),
)
assert len(telemetry_messenger_spy.telemetries) == 3
for t in telemetry_messenger_spy.telemetries:
data = t.get_data()
ip = data["machine"]["ip_addr"]
if ip.endswith(".1"):
assert data["service_count"] == 2
assert data["machine"]["os"]["type"] == "windows"
assert data["machine"]["services"] == dot_1_services
assert data["machine"]["icmp"] is True
elif ip.endswith(".3"):
assert data["service_count"] == 2
assert data["machine"]["os"]["type"] == "linux"
assert data["machine"]["services"] == dot_3_services
assert data["machine"]["icmp"] is True
else:
assert data["service_count"] == 0
assert data["machine"]["os"] == {}
assert data["machine"]["services"] == {}
assert data["machine"]["icmp"] is False

View File

@ -93,3 +93,63 @@ def test_get_config_propagation_credentials_from_flat_config(flat_monkey_config)
creds = ConfigService.get_config_propagation_credentials_from_flat_config(flat_monkey_config) creds = ConfigService.get_config_propagation_credentials_from_flat_config(flat_monkey_config)
assert creds == expected_creds assert creds == expected_creds
def test_format_config_for_agent__propagation(flat_monkey_config):
ConfigService.format_flat_config_for_agent(flat_monkey_config)
assert "propagation" in flat_monkey_config
assert "network_scan" in flat_monkey_config["propagation"]
assert "targets" in flat_monkey_config["propagation"]
def test_format_config_for_agent__propagation_targets(flat_monkey_config):
expected_targets = {
"blocked_ips": ["192.168.1.1", "192.168.1.100"],
"inaccessible_subnets": ["10.0.0.0/24", "10.0.10.0/24"],
"local_network_scan": True,
"subnet_scan_list": ["192.168.1.50", "192.168.56.0/24", "10.0.33.0/30"],
}
ConfigService.format_flat_config_for_agent(flat_monkey_config)
assert flat_monkey_config["propagation"]["targets"] == expected_targets
assert "blocked_ips" not in flat_monkey_config
assert "inaccessible_subnets" not in flat_monkey_config
assert "local_network_scan" not in flat_monkey_config
assert "subnet_scan_list" not in flat_monkey_config
def test_format_config_for_agent__network_scan(flat_monkey_config):
expected_network_scan_config = {
"tcp": {
"timeout_ms": 3000,
"ports": [
22,
80,
135,
443,
445,
2222,
3306,
3389,
7001,
8008,
8080,
8088,
9200,
],
},
"icmp": {
"timeout_ms": 1000,
},
}
ConfigService.format_flat_config_for_agent(flat_monkey_config)
assert "propagation" in flat_monkey_config
assert "network_scan" in flat_monkey_config["propagation"]
assert flat_monkey_config["propagation"]["network_scan"] == expected_network_scan_config
assert "tcp_scan_timeout" not in flat_monkey_config
assert "tcp_target_ports" not in flat_monkey_config
assert "ping_scan_timeout" not in flat_monkey_config