Merge branch '1597-integrate-automated-master' into agent-refactor

This commit is contained in:
Mike Salvatore 2021-12-20 06:55:53 -05:00
commit e392915b26
16 changed files with 456 additions and 125 deletions

View File

@ -1,4 +1,5 @@
import re import re
from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
@ -20,3 +21,11 @@ def remove_port(url):
with_port = f"{parsed.scheme}://{parsed.netloc}" with_port = f"{parsed.scheme}://{parsed.netloc}"
without_port = re.sub(":[0-9]+(?=$|/)", "", with_port) without_port = re.sub(":[0-9]+(?=$|/)", "", with_port)
return without_port return without_port
def address_to_ip_port(address: str) -> Tuple[str, Optional[str]]:
if ":" in address:
ip, port = address.split(":")
return ip, port or None
else:
return address, None

View File

@ -7,6 +7,7 @@ from infection_monkey.i_control_channel import IControlChannel, IslandCommunicat
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet from infection_monkey.i_puppet import IPuppet
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
@ -33,6 +34,7 @@ class AutomatedMaster(IMaster):
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
victim_host_factory: VictimHostFactory, victim_host_factory: VictimHostFactory,
control_channel: IControlChannel, control_channel: IControlChannel,
local_network_interfaces: List[NetworkInterface],
): ):
self._puppet = puppet self._puppet = puppet
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
@ -41,7 +43,11 @@ class AutomatedMaster(IMaster):
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS) ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS) exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
self._propagator = Propagator( self._propagator = Propagator(
self._telemetry_messenger, ip_scanner, exploiter, victim_host_factory self._telemetry_messenger,
ip_scanner,
exploiter,
victim_host_factory,
local_network_interfaces,
) )
self._stop = threading.Event() self._stop = threading.Event()

View File

@ -12,14 +12,14 @@ from infection_monkey.i_puppet import (
PortScanData, PortScanData,
PortStatus, PortStatus,
) )
from infection_monkey.network import NetworkAddress
from . import IPScanResults from . import IPScanResults
from .threading_utils import run_worker_threads from .threading_utils import run_worker_threads
logger = logging.getLogger() logger = logging.getLogger()
IP = str Callback = Callable[[NetworkAddress, IPScanResults], None]
Callback = Callable[[IP, IPScanResults], None]
class IPScanner: class IPScanner:
@ -27,22 +27,33 @@ class IPScanner:
self._puppet = puppet self._puppet = puppet
self._num_workers = num_workers self._num_workers = num_workers
def scan(self, ips_to_scan: List[str], options: Dict, results_callback: Callback, stop: Event): def scan(
self,
addresses_to_scan: List[NetworkAddress],
options: Dict,
results_callback: Callback,
stop: Event,
):
# Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the # Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the
# queue is empty. # queue is empty.
ips = Queue() addresses = Queue()
for ip in ips_to_scan: for address in addresses_to_scan:
ips.put(ip) addresses.put(address)
scan_ips_args = (ips, options, results_callback, stop) scan_ips_args = (addresses, options, results_callback, stop)
run_worker_threads(target=self._scan_ips, args=scan_ips_args, num_workers=self._num_workers) run_worker_threads(
target=self._scan_addresses, args=scan_ips_args, num_workers=self._num_workers
)
def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event): def _scan_addresses(
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
try: try:
while not stop.is_set(): while not stop.is_set():
ip = ips.get_nowait() address = addresses.get_nowait()
ip = address.ip
logger.info(f"Scanning {ip}") logger.info(f"Scanning {ip}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000 icmp_timeout = options["icmp"]["timeout_ms"] / 1000
@ -60,7 +71,7 @@ class IPScanner:
) )
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data) scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
results_callback(ip, scan_results) results_callback(address, scan_results)
logger.debug( logger.debug(
f"Detected the stop signal, scanning thread {threading.get_ident()} exiting" f"Detected the stop signal, scanning thread {threading.get_ident()} exiting"

View File

@ -1,7 +1,7 @@
import logging import logging
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import Dict from typing import Dict, List
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
@ -11,6 +11,8 @@ from infection_monkey.i_puppet import (
PortStatus, PortStatus,
) )
from infection_monkey.model import VictimHost, VictimHostFactory from infection_monkey.model import VictimHost, VictimHostFactory
from infection_monkey.network import NetworkAddress, NetworkInterface
from infection_monkey.network.scan_target_generator import compile_scan_target_list
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem from infection_monkey.telemetry.scan_telem import ScanTelem
@ -28,11 +30,13 @@ class Propagator:
ip_scanner: IPScanner, ip_scanner: IPScanner,
exploiter: Exploiter, exploiter: Exploiter,
victim_host_factory: VictimHostFactory, victim_host_factory: VictimHostFactory,
local_network_interfaces: List[NetworkInterface],
): ):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._ip_scanner = ip_scanner self._ip_scanner = ip_scanner
self._exploiter = exploiter self._exploiter = exploiter
self._victim_host_factory = victim_host_factory self._victim_host_factory = victim_host_factory
self._local_network_interfaces = local_network_interfaces
self._hosts_to_exploit = None self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, stop: Event): def propagate(self, propagation_config: Dict, stop: Event):
@ -62,16 +66,30 @@ class Propagator:
def _scan_network(self, propagation_config: Dict, stop: Event): def _scan_network(self, propagation_config: Dict, stop: Event):
logger.info("Starting network scan") logger.info("Starting network scan")
# TODO: Generate list of IPs to scan from propagation targets config target_config = propagation_config["targets"]
ips_to_scan = propagation_config["targets"]["subnet_scan_list"]
scan_config = propagation_config["network_scan"] scan_config = propagation_config["network_scan"]
self._ip_scanner.scan(ips_to_scan, scan_config, self._process_scan_results, stop)
addresses_to_scan = self._compile_scan_target_list(target_config)
self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan") logger.info("Finished network scan")
def _process_scan_results(self, ip: str, scan_results: IPScanResults): def _compile_scan_target_list(self, target_config: Dict) -> List[NetworkAddress]:
victim_host = self._victim_host_factory.build_victim_host(ip) ranges_to_scan = target_config["subnet_scan_list"]
inaccessible_subnets = target_config["inaccessible_subnets"]
blocklisted_ips = target_config["blocked_ips"]
enable_local_network_scan = target_config["local_network_scan"]
return compile_scan_target_list(
self._local_network_interfaces,
ranges_to_scan,
inaccessible_subnets,
blocklisted_ips,
enable_local_network_scan,
)
def _process_scan_results(self, address: NetworkAddress, scan_results: IPScanResults):
victim_host = self._victim_host_factory.build_victim_host(address)
Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data) Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data)
Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data) Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data)

View File

@ -1,5 +1,8 @@
from typing import Optional
class VictimHost(object): class VictimHost(object):
def __init__(self, ip_addr, domain_name=""): def __init__(self, ip_addr: str, domain_name: str = ""):
self.ip_addr = ip_addr self.ip_addr = ip_addr
self.domain_name = str(domain_name) self.domain_name = str(domain_name)
self.os = {} self.os = {}
@ -42,5 +45,5 @@ class VictimHost(object):
victim += "target monkey: %s" % self.monkey_exe victim += "target monkey: %s" % self.monkey_exe
return victim return victim
def set_default_server(self, default_server): def set_island_address(self, ip: str, port: Optional[str]):
self.default_server = default_server self.default_server = f"{ip}:{port}" if port else f"{ip}"

View File

@ -1,28 +1,50 @@
import logging
from typing import Optional, Tuple
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
from infection_monkey.network import NetworkAddress
from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.tunnel import MonkeyTunnel
logger = logging.getLogger(__name__)
class VictimHostFactory: class VictimHostFactory:
def __init__(self): def __init__(
pass self,
tunnel: Optional[MonkeyTunnel],
island_ip: Optional[str],
island_port: Optional[str],
on_island: bool,
):
self.tunnel = tunnel
self.island_ip = island_ip
self.island_port = island_port
self.on_island = on_island
def build_victim_host(self, ip: str): def build_victim_host(self, network_address: NetworkAddress) -> VictimHost:
victim_host = VictimHost(ip) domain = network_address.domain or ""
victim_host = VictimHost(network_address.ip, domain)
# TODO: Reimplement the below logic from the old monkey.py if self.tunnel:
""" victim_host.default_tunnel = self.tunnel.get_tunnel_for_ip(victim_host.ip_addr)
if self._monkey_tunnel:
self._monkey_tunnel.set_tunnel_for_host(machine) if self.island_ip:
if self._default_server: ip, port = self._choose_island_address(victim_host.ip_addr)
if self._network.on_island(self._default_server): victim_host.set_island_address(ip, port)
machine.set_default_server(
get_interface_to_target(machine.ip_addr) logger.debug(f"Default tunnel for {victim_host} set to {victim_host.default_tunnel}")
+ (":" + self._default_server_port if self._default_server_port else "") logger.debug(f"Default server for {victim_host} set to {victim_host.default_server}")
)
else:
machine.set_default_server(self._default_server)
logger.debug(
f"Default server for machine: {machine} set to {machine.default_server}"
)
"""
return victim_host return victim_host
def _choose_island_address(self, victim_ip: str) -> Tuple[str, Optional[str]]:
# Victims need to connect back to the interface they can reach
# On island, choose the right interface to pass to children monkeys
if self.on_island:
default_server_port = self.island_port if self.island_port else None
interface = get_interface_to_target(victim_ip)
return interface, default_server_port
else:
return self.island_ip, self.island_port

View File

@ -4,16 +4,20 @@ import os
import subprocess import subprocess
import sys import sys
import time import time
from typing import List
import infection_monkey.tunnel as tunnel import infection_monkey.tunnel as tunnel
from common.network.network_utils import address_to_ip_port
from common.utils.attack_utils import ScanStatus, UsageEnum from common.utils.attack_utils import ScanStatus, UsageEnum
from common.version import get_version from common.version import get_version
from infection_monkey.config import GUID, WormConfiguration from infection_monkey.config import GUID, WormConfiguration
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.master.mock_master import MockMaster from infection_monkey.model import DELAY_DELETE_CMD, VictimHostFactory
from infection_monkey.model import DELAY_DELETE_CMD from infection_monkey.network import NetworkInterface
from infection_monkey.network.firewall import app as firewall from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_local_network_interfaces
from infection_monkey.puppet.mock_puppet import MockPuppet from infection_monkey.puppet.mock_puppet import MockPuppet
from infection_monkey.system_singleton import SystemSingleton from infection_monkey.system_singleton import SystemSingleton
from infection_monkey.telemetry.attack.t1106_telem import T1106Telem from infection_monkey.telemetry.attack.t1106_telem import T1106Telem
@ -35,11 +39,10 @@ logger = logging.getLogger(__name__)
class InfectionMonkey: class InfectionMonkey:
def __init__(self, args): def __init__(self, args):
logger.info("Monkey is initializing...") logger.info("Monkey is initializing...")
self._master = MockMaster(MockPuppet(), LegacyTelemetryMessengerAdapter())
self._singleton = SystemSingleton() self._singleton = SystemSingleton()
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
# TODO Used in propagation phase to set the default server for the victim self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.server)
self._default_server_port = None self._default_server = self._opts.server
# TODO used in propogation phase # TODO used in propogation phase
self._monkey_inbound_tunnel = None self._monkey_inbound_tunnel = None
@ -52,6 +55,7 @@ class InfectionMonkey:
arg_parser.add_argument("-d", "--depth", type=int) arg_parser.add_argument("-d", "--depth", type=int)
opts, _ = arg_parser.parse_known_args(args) opts, _ = arg_parser.parse_known_args(args)
InfectionMonkey._log_arguments(opts) InfectionMonkey._log_arguments(opts)
return opts return opts
@staticmethod @staticmethod
@ -108,25 +112,22 @@ class InfectionMonkey:
def _connect_to_island(self): def _connect_to_island(self):
# Sets island's IP and port for monkey to communicate to # Sets island's IP and port for monkey to communicate to
if not self._is_default_server_set(): if self._current_server_is_set():
self._default_server = WormConfiguration.current_server
logger.debug("Default server set to: %s" % self._default_server)
else:
raise Exception( raise Exception(
"Monkey couldn't find server with {} default tunnel.".format(self._opts.tunnel) "Monkey couldn't find server with {} default tunnel.".format(self._opts.tunnel)
) )
self._set_default_port()
ControlClient.wakeup(parent=self._opts.parent) ControlClient.wakeup(parent=self._opts.parent)
ControlClient.load_control_config() ControlClient.load_control_config()
def _is_default_server_set(self) -> bool: def _current_server_is_set(self) -> bool:
""" if ControlClient.find_server(default_tunnel=self._opts.tunnel):
Sets the default server for the Monkey to communicate back to. return True
:return
""" return False
if not ControlClient.find_server(default_tunnel=self._opts.tunnel):
return False
self._opts.server = WormConfiguration.current_server
logger.debug("default server set to: %s" % self._opts.server)
return True
@staticmethod @staticmethod
def _is_upgrade_to_64_needed(): def _is_upgrade_to_64_needed():
@ -151,17 +152,48 @@ class InfectionMonkey:
StateTelem(is_done=False, version=get_version()).send() StateTelem(is_done=False, version=get_version()).send()
TunnelTelem().send() TunnelTelem().send()
self._build_master()
register_signal_handlers(self._master) register_signal_handlers(self._master)
@staticmethod
def _get_local_network_interfaces():
local_network_interfaces = get_local_network_interfaces()
for i in local_network_interfaces:
logger.debug(f"Found local interface {i.address}{i.netmask}")
return local_network_interfaces
def _build_master(self):
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
victim_host_factory = self._build_victim_host_factory(local_network_interfaces)
self._master = AutomatedMaster(
MockPuppet(),
LegacyTelemetryMessengerAdapter(),
victim_host_factory,
ControlChannel(self._default_server, GUID),
local_network_interfaces,
)
def _build_victim_host_factory(
self, local_network_interfaces: List[NetworkInterface]
) -> VictimHostFactory:
on_island = self._running_on_island(local_network_interfaces)
logger.debug(f"This agent is running on the island: {on_island}")
return VictimHostFactory(
self._monkey_inbound_tunnel, self._cmd_island_ip, self._cmd_island_port, on_island
)
def _running_on_island(self, local_network_interfaces: List[NetworkInterface]) -> bool:
server_ip, _ = address_to_ip_port(self._default_server)
return server_ip in {interface.address for interface in local_network_interfaces}
def _is_another_monkey_running(self): def _is_another_monkey_running(self):
return not self._singleton.try_lock() return not self._singleton.try_lock()
def _set_default_port(self):
try:
self._default_server_port = self._opts.server.split(":")[1]
except KeyError:
self._default_server_port = ""
def cleanup(self): def cleanup(self):
logger.info("Monkey cleanup started") logger.info("Monkey cleanup started")
self._wait_for_exploited_machine_connection() self._wait_for_exploited_machine_connection()

View File

@ -0,0 +1 @@
from .scan_target_generator import NetworkAddress, NetworkInterface

View File

@ -1,7 +1,9 @@
import itertools import itertools
import socket import socket
import struct import struct
from ipaddress import IPv4Network
from random import randint # noqa: DUO102 from random import randint # noqa: DUO102
from typing import List
import netifaces import netifaces
import psutil import psutil
@ -9,6 +11,8 @@ import psutil
from common.network.network_range import CidrRange from common.network.network_range import CidrRange
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from . import NetworkInterface
# Timeout for monkey connections # Timeout for monkey connections
TIMEOUT = 15 TIMEOUT = 15
LOOPBACK_NAME = b"lo" LOOPBACK_NAME = b"lo"
@ -18,6 +22,16 @@ RTF_UP = 0x0001 # Route usable
RTF_REJECT = 0x0200 RTF_REJECT = 0x0200
def get_local_network_interfaces() -> List[NetworkInterface]:
network_interfaces = []
for i in get_host_subnets():
netmask_bits = IPv4Network(f"{i['addr']}/{i['netmask']}", strict=False).prefixlen
cidr_netmask = f"/{netmask_bits}"
network_interfaces.append(NetworkInterface(i["addr"], cidr_netmask))
return network_interfaces
def get_host_subnets(): def get_host_subnets():
""" """
Returns a list of subnets visible to host (omitting loopback and auto conf networks) Returns a list of subnets visible to host (omitting loopback and auto conf networks)

View File

@ -4,7 +4,6 @@ import struct
import time import time
from threading import Thread from threading import Thread
from infection_monkey.model import VictimHost
from infection_monkey.network.firewall import app as firewall from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
@ -188,14 +187,13 @@ class MonkeyTunnel(Thread):
proxy.stop() proxy.stop()
proxy.join() proxy.join()
def set_tunnel_for_host(self, host): def get_tunnel_for_ip(self, ip: str):
assert isinstance(host, VictimHost)
if not self.local_port: if not self.local_port:
return return
ip_match = get_interface_to_target(host.ip_addr) ip_match = get_interface_to_target(ip)
host.default_tunnel = "%s:%d" % (ip_match, self.local_port) return "%s:%d" % (ip_match, self.local_port)
def stop(self): def stop(self):
self._stopped = True self._stopped = True

View File

@ -1,6 +1,10 @@
from unittest import TestCase from unittest import TestCase
from common.network.network_utils import get_host_from_network_location, remove_port from common.network.network_utils import (
address_to_ip_port,
get_host_from_network_location,
remove_port,
)
class TestNetworkUtils(TestCase): class TestNetworkUtils(TestCase):
@ -15,3 +19,17 @@ class TestNetworkUtils(TestCase):
assert remove_port("https://google.com:80") == "https://google.com" assert remove_port("https://google.com:80") == "https://google.com"
assert remove_port("https://8.8.8.8:65336") == "https://8.8.8.8" assert remove_port("https://8.8.8.8:65336") == "https://8.8.8.8"
assert remove_port("ftp://ftpserver.com:21/hello/world") == "ftp://ftpserver.com" assert remove_port("ftp://ftpserver.com:21/hello/world") == "ftp://ftpserver.com"
def test_address_to_ip_port():
ip, port = address_to_ip_port("192.168.65.1:5000")
assert ip == "192.168.65.1"
assert port == "5000"
def test_address_to_ip_port_no_port():
ip, port = address_to_ip_port("192.168.65.1")
assert port is None
ip, port = address_to_ip_port("192.168.65.1:")
assert port is None

View File

@ -14,7 +14,7 @@ INTERVAL = 0.001
def test_terminate_without_start(): def test_terminate_without_start():
m = AutomatedMaster(None, None, None, None) m = AutomatedMaster(None, None, None, None, [])
# Test that call to terminate does not raise exception # Test that call to terminate does not raise exception
m.terminate() m.terminate()
@ -34,7 +34,7 @@ def test_stop_if_cant_get_config_from_island(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc) m = AutomatedMaster(None, None, None, cc, [])
m.start() m.start()
assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT
@ -73,7 +73,7 @@ def test_stop_if_cant_get_stop_signal_from_island(monkeypatch, sleep_and_return_
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc) m = AutomatedMaster(None, None, None, cc, [])
m.start() m.start()
assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT

View File

@ -6,6 +6,7 @@ import pytest
from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus
from infection_monkey.master import IPScanner from infection_monkey.master import IPScanner
from infection_monkey.network import NetworkAddress
from infection_monkey.puppet.mock_puppet import MockPuppet from infection_monkey.puppet.mock_puppet import MockPuppet
WINDOWS_OS = "windows" WINDOWS_OS = "windows"
@ -51,20 +52,21 @@ def assert_port_status(port_scan_data, expected_open_ports: Set[int]):
assert psd.status == PortStatus.CLOSED assert psd.status == PortStatus.CLOSED
def assert_scan_results(ip, scan_results): def assert_scan_results(address, scan_results):
ping_scan_data = scan_results.ping_scan_data ping_scan_data = scan_results.ping_scan_data
port_scan_data = scan_results.port_scan_data port_scan_data = scan_results.port_scan_data
fingerprint_data = scan_results.fingerprint_data fingerprint_data = scan_results.fingerprint_data
if ip == "10.0.0.1": if address.ip == "10.0.0.1":
assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_no_1(address.domain, ping_scan_data, port_scan_data, fingerprint_data)
elif ip == "10.0.0.3": elif address.ip == "10.0.0.3":
assert_scan_results_no_3(ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_no_3(address.domain, ping_scan_data, port_scan_data, fingerprint_data)
else: else:
assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprint_data) assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data)
def assert_scan_results_no_1(ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_no_1(domain, ping_scan_data, port_scan_data, fingerprint_data):
assert domain == "d1"
assert ping_scan_data.response_received is True assert ping_scan_data.response_received is True
assert ping_scan_data.os == WINDOWS_OS assert ping_scan_data.os == WINDOWS_OS
@ -97,7 +99,9 @@ def assert_fingerprint_results_no_1(fingerprint_data):
assert fingerprint_data["SMBFinger"].services["tcp-445"]["name"] == "smb_service_name" assert fingerprint_data["SMBFinger"].services["tcp-445"]["name"] == "smb_service_name"
def assert_scan_results_no_3(ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_no_3(domain, ping_scan_data, port_scan_data, fingerprint_data):
assert domain == "d3"
assert ping_scan_data.response_received is True assert ping_scan_data.response_received is True
assert ping_scan_data.os == LINUX_OS assert ping_scan_data.os == LINUX_OS
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -135,8 +139,9 @@ def assert_fingerprint_results_no_3(fingerprint_data):
assert fingerprint_data["HTTPFinger"].services["tcp-443"]["data"] == ("SERVER_HEADERS_2", True) assert fingerprint_data["HTTPFinger"].services["tcp-443"]["data"] == ("SERVER_HEADERS_2", True)
def assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert ip not in {"10.0.0.1", "10.0.0.3"} assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None
assert ping_scan_data.response_received is False assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -146,44 +151,49 @@ def assert_scan_results_host_down(ip, ping_scan_data, port_scan_data, fingerprin
def test_scan_single_ip(callback, scan_config, stop): def test_scan_single_ip(callback, scan_config, stop):
ips = ["10.0.0.1"] addresses = [NetworkAddress("10.0.0.1", "d1")]
ns = IPScanner(MockPuppet(), num_workers=1) ns = IPScanner(MockPuppet(), num_workers=1)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
callback.assert_called_once() callback.assert_called_once()
(ip, scan_results) = callback.call_args_list[0][0] (address, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert callback.call_count == 4 assert callback.call_count == 4
(ip, scan_results) = callback.call_args_list[0][0] (address, scan_results) = callback.call_args_list[0][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[1][0] (address, scan_results) = callback.call_args_list[1][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[2][0] (address, scan_results) = callback.call_args_list[2][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
(ip, scan_results) = callback.call_args_list[3][0] (address, scan_results) = callback.call_args_list[3][0]
assert_scan_results(ip, scan_results) assert_scan_results(address, scan_results)
@pytest.mark.slow @pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
ips = [f"10.0.0.{i}" for i in range(0, 255)] addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert callback.call_count == 255 assert callback.call_count == 255
@ -199,10 +209,15 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback) stoppable_callback = MagicMock(side_effect=_callback)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(MockPuppet(), num_workers=2) ns = IPScanner(MockPuppet(), num_workers=2)
ns.scan(ips, scan_config, stoppable_callback, stop) ns.scan(addresses, scan_config, stoppable_callback, stop)
assert stoppable_callback.call_count == 2 assert stoppable_callback.call_count == 2
@ -221,10 +236,15 @@ def test_interrupt_port_scanning(callback, scan_config, stop):
puppet = MockPuppet() puppet = MockPuppet()
puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port) puppet.scan_tcp_port = MagicMock(side_effect=stoppable_scan_tcp_port)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2 assert puppet.scan_tcp_port.call_count == 2
@ -243,9 +263,14 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet = MockPuppet() puppet = MockPuppet()
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint) puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
assert puppet.fingerprint.call_count == 2 assert puppet.fingerprint.call_count == 2

View File

@ -1,4 +1,7 @@
from threading import Event from threading import Event
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
@ -8,8 +11,26 @@ from infection_monkey.i_puppet import (
PortStatus, PortStatus,
) )
from infection_monkey.master import IPScanResults, Propagator from infection_monkey.master import IPScanResults, Propagator
from infection_monkey.model import VictimHostFactory from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.model import VictimHost, VictimHostFactory
from infection_monkey.network import NetworkAddress
@pytest.fixture
def mock_victim_host_factory():
class MockVictimHostFactory(VictimHostFactory):
def __init__(self):
pass
def build_victim_host(self, network_address: NetworkAddress) -> VictimHost:
domain = network_address.domain or ""
return VictimHost(network_address.ip, domain)
return MockVictimHostFactory()
empty_fingerprint_data = FingerprintData(None, None, {}) empty_fingerprint_data = FingerprintData(None, None, {})
@ -83,15 +104,21 @@ dot_3_services = {
} }
class MockIPScanner: @pytest.fixture
def scan(self, ips_to_scan, _, results_callback, stop): def mock_ip_scanner():
for ip in ips_to_scan: def scan(adresses_to_scan, _, results_callback, stop):
if ip.endswith(".1"): for address in adresses_to_scan:
results_callback(ip, dot_1_scan_results) if address.ip.endswith(".1"):
elif ip.endswith(".3"): results_callback(address, dot_1_scan_results)
results_callback(ip, dot_3_scan_results) elif address.ip.endswith(".3"):
results_callback(address, dot_3_scan_results)
else: else:
results_callback(ip, dead_host_scan_results) results_callback(address, dead_host_scan_results)
ip_scanner = MagicMock()
ip_scanner.scan = MagicMock(side_effect=scan)
return ip_scanner
class StubExploiter: class StubExploiter:
@ -101,11 +128,18 @@ class StubExploiter:
pass pass
def test_scan_result_processing(telemetry_messenger_spy): def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
p = Propagator(telemetry_messenger_spy, MockIPScanner(), StubExploiter(), VictimHostFactory()) p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, []
)
p.propagate( p.propagate(
{ {
"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it "network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since StubExploiter ignores it "exploiters": {}, # This is empty since StubExploiter ignores it
}, },
@ -141,10 +175,13 @@ class MockExploiter:
def exploit_hosts( def exploit_hosts(
self, exploiter_config, hosts_to_exploit, results_callback, scan_completed, stop self, exploiter_config, hosts_to_exploit, results_callback, scan_completed, stop
): ):
scan_completed.wait()
hte = [] hte = []
for _ in range(0, 2): for _ in range(0, 2):
hte.append(hosts_to_exploit.get()) hte.append(hosts_to_exploit.get())
assert hosts_to_exploit.empty()
for host in hte: for host in hte:
if host.ip_addr.endswith(".1"): if host.ip_addr.endswith(".1"):
results_callback( results_callback(
@ -157,7 +194,7 @@ class MockExploiter:
host, host,
ExploiterResultData(False, {}, {}, "SSH FAILED for .1"), ExploiterResultData(False, {}, {}, "SSH FAILED for .1"),
) )
if host.ip_addr.endswith(".2"): elif host.ip_addr.endswith(".2"):
results_callback( results_callback(
"PowerShellExploiter", "PowerShellExploiter",
host, host,
@ -168,7 +205,7 @@ class MockExploiter:
host, host,
ExploiterResultData(False, {}, {}, "SSH FAILED for .2"), ExploiterResultData(False, {}, {}, "SSH FAILED for .2"),
) )
if host.ip_addr.endswith(".3"): elif host.ip_addr.endswith(".3"):
results_callback( results_callback(
"PowerShellExploiter", "PowerShellExploiter",
host, host,
@ -181,11 +218,20 @@ class MockExploiter:
) )
def test_exploiter_result_processing(telemetry_messenger_spy): def test_exploiter_result_processing(
p = Propagator(telemetry_messenger_spy, MockIPScanner(), MockExploiter(), VictimHostFactory()) telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory
):
p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, []
)
p.propagate( p.propagate(
{ {
"targets": {"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]}, "targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it "network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it "exploiters": {}, # This is empty since MockExploiter ignores it
}, },
@ -211,3 +257,48 @@ def test_exploiter_result_processing(telemetry_messenger_spy):
assert not data["result"] assert not data["result"]
else: else:
assert data["result"] assert data["result"]
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")]
p = Propagator(
telemetry_messenger_spy,
mock_ip_scanner,
StubExploiter(),
mock_victim_host_factory,
local_network_interfaces,
)
p.propagate(
{
"targets": {
"subnet_scan_list": ["10.0.0.0/29", "172.10.20.30"],
"local_network_scan": True,
"blocked_ips": ["10.0.0.3"],
"inaccessible_subnets": ["10.0.0.128/30", "10.0.0.8/29"],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
Event(),
)
expected_ip_scan_list = [
"10.0.0.0",
"10.0.0.1",
"10.0.0.2",
"10.0.0.4",
"10.0.0.5",
"10.0.0.6",
"10.0.0.8",
"10.0.0.10",
"10.0.0.11",
"10.0.0.12",
"10.0.0.13",
"10.0.0.14",
"10.0.0.128",
"10.0.0.129",
"10.0.0.130",
"172.10.20.30",
]
actual_ip_scan_list = [address.ip for address in mock_ip_scanner.scan.call_args_list[0][0][0]]
assert actual_ip_scan_list == expected_ip_scan_list

View File

@ -0,0 +1,83 @@
from unittest.mock import MagicMock
import pytest
from infection_monkey.model import VictimHostFactory
from infection_monkey.network.scan_target_generator import NetworkAddress
@pytest.fixture
def mock_tunnel():
tunnel = MagicMock()
tunnel.get_tunnel_for_ip = lambda _: "1.2.3.4:1234"
return tunnel
@pytest.fixture(autouse=True)
def mock_get_interface_to_target(monkeypatch):
monkeypatch.setattr(
"infection_monkey.model.victim_host_factory.get_interface_to_target", lambda _: "1.1.1.1"
)
def test_factory_no_tunnel():
factory = VictimHostFactory(
tunnel=None, island_ip="192.168.56.1", island_port="5000", on_island=False
)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
assert victim.default_server == "192.168.56.1:5000"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel is None
assert victim.domain_name == ""
def test_factory_with_tunnel(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="5000", on_island=False
)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
assert victim.default_server == "192.168.56.1:5000"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel == "1.2.3.4:1234"
assert victim.domain_name == ""
def test_factory_on_island(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="99", on_island=True
)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server == "1.1.1.1:99"
assert victim.domain_name == "www.bogus.monkey"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel == "1.2.3.4:1234"
@pytest.mark.parametrize("default_port", ["", None])
def test_factory_no_port(mock_tunnel, default_port):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port=default_port, on_island=True
)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server == "1.1.1.1"
def test_factory_no_default_server(mock_tunnel):
factory = VictimHostFactory(tunnel=mock_tunnel, island_ip=None, island_port="", on_island=True)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server is None

View File

@ -96,9 +96,9 @@ def test_get_monkey_commandline_linux():
def test_build_monkey_commandline(): def test_build_monkey_commandline():
example_host = VictimHost(ip_addr="bla") example_host = VictimHost(ip_addr="bla")
example_host.set_default_server("101010") example_host.set_island_address("101010", "5000")
expected = f" -p {GUID} -s 101010 -d 0 -l /home/bla" expected = f" -p {GUID} -s 101010:5000 -d 0 -l /home/bla"
actual = build_monkey_commandline(target_host=example_host, depth=0, location="/home/bla") actual = build_monkey_commandline(target_host=example_host, depth=0, location="/home/bla")
assert expected == actual assert expected == actual