Merge pull request #2043 from guardicore/1960-deserialize-config

1960 deserialize config
This commit is contained in:
Mike Salvatore 2022-06-27 08:35:11 -04:00 committed by GitHub
commit 44a6197422
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 514 additions and 436 deletions

View File

@ -1,8 +1,16 @@
from .agent_configuration import (
AgentConfiguration,
AgentConfigurationSchema,
from .agent_configuration import AgentConfiguration, InvalidConfigurationError
from .agent_sub_configurations import (
CustomPBAConfiguration,
PluginConfiguration,
ScanTargetConfiguration,
ICMPScanConfiguration,
TCPScanConfiguration,
NetworkScanConfiguration,
ExploitationOptionsConfiguration,
ExploiterConfiguration,
ExploitationConfiguration,
PropagationConfiguration,
)
from .default_agent_configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON,
build_default_agent_configuration,
DEFAULT_AGENT_CONFIGURATION,
)

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import List
from __future__ import annotations
from marshmallow import Schema, fields, post_load
from dataclasses import dataclass
from typing import Any, List, Mapping
from marshmallow import Schema, fields
from marshmallow.exceptions import MarshmallowError
from .agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema,
@ -15,6 +18,15 @@ from .agent_sub_configurations import (
)
class InvalidConfigurationError(Exception):
pass
INVALID_CONFIGURATION_ERROR_MESSAGE = (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:"
)
@dataclass(frozen=True)
class AgentConfiguration:
keep_tunnel_open_time: float
@ -24,6 +36,57 @@ class AgentConfiguration:
payloads: List[PluginConfiguration]
propagation: PropagationConfiguration
def __post_init__(self):
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
try:
AgentConfigurationSchema().dump(self)
except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a Mapping
:param config_mapping: A Mapping that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided Mapping does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().load(config_mapping)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_json(config_json: str) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a JSON string
:param config_json: A JSON string that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided JSON does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def to_json(config: AgentConfiguration) -> str:
"""
Serialize an AgentConfiguration to JSON
:param config: An AgentConfiguration
:return: A JSON string representing the AgentConfiguration
"""
return AgentConfigurationSchema().dumps(config)
class AgentConfigurationSchema(Schema):
keep_tunnel_open_time = fields.Float()
@ -32,7 +95,3 @@ class AgentConfigurationSchema(Schema):
credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema))
payloads = fields.List(fields.Nested(PluginConfigurationSchema))
propagation = fields.Nested(PropagationConfigurationSchema)
@post_load
def _make_agent_configuration(self, data, **kwargs):
return AgentConfiguration(**data)

View File

@ -1,208 +1,115 @@
from . import AgentConfiguration, AgentConfigurationSchema
from . import AgentConfiguration
from .agent_sub_configurations import (
CustomPBAConfiguration,
ExploitationConfiguration,
ExploitationOptionsConfiguration,
ExploiterConfiguration,
ICMPScanConfiguration,
NetworkScanConfiguration,
PluginConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
TCPScanConfiguration,
)
DEFAULT_AGENT_CONFIGURATION_JSON = """{
"keep_tunnel_open_time": 30,
"post_breach_actions": [
{
"name": "CommunicateAsBackdoorUser",
"options": {}
},
{
"name": "ModifyShellStartupFiles",
"options": {}
},
{
"name": "HiddenFiles",
"options": {}
},
{
"name": "TrapCommand",
"options": {}
},
{
"name": "ChangeSetuidSetgid",
"options": {}
},
{
"name": "ScheduleJobs",
"options": {}
},
{
"name": "Timestomping",
"options": {}
},
{
"name": "AccountDiscovery",
"options": {}
},
{
"name": "ProcessListCollection",
"options": {}
}
],
"credential_collectors": [
{
"name": "MimikatzCollector",
"options": {}
},
{
"name": "SSHCollector",
"options": {}
}
],
"payloads": [
{
"name": "ransomware",
"options": {
"encryption": {
"enabled": true,
"directories": {
"linux_target_dir": "",
"windows_target_dir": ""
}
},
"other_behaviors": {
"readme": true
}
}
}
],
"custom_pbas": {
"linux_command": "",
"linux_filename": "",
"windows_command": "",
"windows_filename": ""
},
"propagation": {
"maximum_depth": 2,
"network_scan": {
"tcp": {
"timeout": 3000,
"ports": [
22,
80,
135,
443,
445,
2222,
3306,
3389,
5985,
5986,
7001,
8008,
8080,
8088,
8983,
9200,
9600
]
},
"icmp": {
"timeout": 1000
},
"fingerprinters": [
{
"name": "elastic",
"options": {}
},
{
"name": "http",
"options": {
"http_ports": [
80,
443,
7001,
8008,
8080,
8983,
9200,
9600
]
}
},
{
"name": "mssql",
"options": {}
},
{
"name": "smb",
"options": {}
},
{
"name": "ssh",
"options": {}
}
],
"targets": {
"blocked_ips": [],
"inaccessible_subnets": [],
"local_network_scan": true,
"subnets": []
}
},
"exploitation": {
"options": {
"http_ports": [
80,
443,
7001,
8008,
8080,
8983,
9200,
9600
]
},
"brute_force": [
{
"name": "MSSQLExploiter",
"options": {}
PBAS = [
"CommunicateAsBackdoorUser",
"ModifyShellStartupFiles",
"HiddenFiles",
"TrapCommand",
"ChangeSetuidSetgid",
"ScheduleJobs",
"Timestomping",
"AccountDiscovery",
"ProcessListCollection",
]
},
{
"name": "PowerShellExploiter",
"options": {}
CREDENTIAL_COLLECTORS = ["MimikatzCollector", "SSHCollector"]
},
{
"name": "SSHExploiter",
"options": {}
PBA_CONFIGURATION = [PluginConfiguration(pba, {}) for pba in PBAS]
CREDENTIAL_COLLECTOR_CONFIGURATION = [
PluginConfiguration(collector, {}) for collector in CREDENTIAL_COLLECTORS
]
},
{
"name": "SmbExploiter",
"options": {
"smb_download_timeout": 30
}
RANSOMWARE_OPTIONS = {
"encryption": {
"enabled": True,
"directories": {"linux_target_dir": "", "windows_target_dir": ""},
},
"other_behaviors": {"readme": True},
}
},
{
"name": "WmiExploiter",
"options": {
"smb_download_timeout": 30
}
PAYLOAD_CONFIGURATION = [PluginConfiguration("ransomware", RANSOMWARE_OPTIONS)]
}
],
"vulnerability": [
{
"name": "HadoopExploiter",
"options": {}
CUSTOM_PBA_CONFIGURATION = CustomPBAConfiguration(
linux_command="", linux_filename="", windows_command="", windows_filename=""
)
},
{
"name": "Log4ShellExploiter",
"options": {}
TCP_PORTS = [
22,
80,
135,
443,
445,
2222,
3306,
3389,
5985,
5986,
7001,
8008,
8080,
8088,
8983,
9200,
9600,
]
}
]
}
}
}
"""
TCP_SCAN_CONFIGURATION = TCPScanConfiguration(timeout=3.0, ports=TCP_PORTS)
ICMP_CONFIGURATION = ICMPScanConfiguration(timeout=1.0)
HTTP_PORTS = [80, 443, 7001, 8008, 8080, 8983, 9200, 9600]
FINGERPRINTERS = [
PluginConfiguration("elastic", {}),
PluginConfiguration("http", {"http_ports": HTTP_PORTS}),
PluginConfiguration("mssql", {}),
PluginConfiguration("smb", {}),
PluginConfiguration("ssh", {}),
]
SCAN_TARGET_CONFIGURATION = ScanTargetConfiguration([], [], True, [])
NETWORK_SCAN_CONFIGURATION = NetworkScanConfiguration(
TCP_SCAN_CONFIGURATION, ICMP_CONFIGURATION, FINGERPRINTERS, SCAN_TARGET_CONFIGURATION
)
def build_default_agent_configuration() -> AgentConfiguration:
schema = AgentConfigurationSchema()
return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
EXPLOITATION_OPTIONS_CONFIGURATION = ExploitationOptionsConfiguration(HTTP_PORTS)
BRUTE_FORCE_EXPLOITERS = [
ExploiterConfiguration("MSSQLExploiter", {}),
ExploiterConfiguration("PowerShellExploiter", {}),
ExploiterConfiguration("SSHExploiter", {}),
ExploiterConfiguration("SmbExploiter", {"smb_download_timeout": 30}),
ExploiterConfiguration("WmiExploiter", {"smb_download_timeout": 30}),
]
VULNERABILITY_EXPLOITERS = [
ExploiterConfiguration("Log4ShellExploiter", {}),
ExploiterConfiguration("HadoopExploiter", {}),
]
EXPLOITATION_CONFIGURATION = ExploitationConfiguration(
EXPLOITATION_OPTIONS_CONFIGURATION, BRUTE_FORCE_EXPLOITERS, VULNERABILITY_EXPLOITERS
)
PROPAGATION_CONFIGURATION = PropagationConfiguration(
maximum_depth=2,
network_scan=NETWORK_SCAN_CONFIGURATION,
exploitation=EXPLOITATION_CONFIGURATION,
)
DEFAULT_AGENT_CONFIGURATION = AgentConfiguration(
keep_tunnel_open_time=30,
custom_pbas=CUSTOM_PBA_CONFIGURATION,
post_breach_actions=PBA_CONFIGURATION,
credential_collectors=CREDENTIAL_COLLECTOR_CONFIGURATION,
payloads=PAYLOAD_CONFIGURATION,
propagation=PROPAGATION_CONFIGURATION,
)

View File

@ -2,5 +2,12 @@ from enum import Enum
class OperatingSystems(Enum):
"""
An Enum representing all supported operating systems
This Enum represents all operating systems that Infection Monkey supports. The value of each
member is the member's name in all lower-case characters.
"""
LINUX = "linux"
WINDOWS = "windows"

View File

@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
"""Raise on failed attempt to extract domain controller's name"""
# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error
# instead and remove this one.
class InvalidConfigurationError(Exception):
"""Raise when configuration is invalid"""

View File

@ -5,6 +5,7 @@ from typing import Mapping
import requests
from common import OperatingSystems
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from . import IAgentRepository
@ -22,18 +23,22 @@ class CachingAgentRepository(IAgentRepository):
self._proxies = proxies
self._lock = threading.Lock()
def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO:
def get_agent_binary(
self, operating_system: OperatingSystems, architecture: str = None
) -> io.BytesIO:
# If multiple calls to get_agent_binary() are made simultaneously before the result of
# _download_binary_from_island() is cached, then multiple requests will be sent to the
# island. Add a mutex in front of the call to _download_agent_binary_from_island() so
# that only one request per OS will be sent to the island.
with self._lock:
return io.BytesIO(self._download_binary_from_island(os))
return io.BytesIO(self._download_binary_from_island(operating_system))
@lru_cache(maxsize=None)
def _download_binary_from_island(self, os: str) -> bytes:
def _download_binary_from_island(self, operating_system: OperatingSystems) -> bytes:
os_name = operating_system.value
response = requests.get( # noqa: DUO123
f"{self._island_url}/api/agent-binaries/{os}",
f"{self._island_url}/api/agent-binaries/{os_name}",
verify=False,
proxies=self._proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,

View File

@ -1,6 +1,8 @@
import abc
import io
from common import OperatingSystems
# TODO: The Island also has an IAgentRepository with a totally different interface. At the moment,
# the Island and Agent have different needs, but at some point we should unify these.
@ -13,12 +15,13 @@ class IAgentRepository(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO:
def get_agent_binary(
self, operating_system: OperatingSystems, architecture: str = None
) -> io.BytesIO:
"""
Retrieve the appropriate agent binary from the repository.
:param str os: The name of the operating system on which the agent binary will run
:param str architecture: Reserved
:param operating_system: The name of the operating system on which the agent binary will run
:param architecture: Reserved
:return: A file-like object for the requested agent binary
:rtype: io.BytesIO
"""
pass

View File

@ -129,7 +129,7 @@ class Log4ShellExploiter(WebRCE):
}
def _build_java_class(self, exploit_command: str) -> bytes:
if OperatingSystems.LINUX in self.host.os["type"]:
if OperatingSystems.LINUX == self.host.os["type"]:
return build_exploit_bytecode(exploit_command, LINUX_EXPLOIT_TEMPLATE_PATH)
else:
return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH)

View File

@ -2,6 +2,7 @@ import logging
from pathlib import Path, PurePath
from typing import List, Optional
from common import OperatingSystems
from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions, get_auth_options
from infection_monkey.exploit.powershell_utils.credentials import (
@ -162,7 +163,7 @@ class PowerShellExploiter(HostExploiter):
temp_monkey_binary_filepath.unlink()
def _create_local_agent_file(self, binary_path):
agent_binary_bytes = self.agent_repository.get_agent_binary("windows")
agent_binary_bytes = self.agent_repository.get_agent_binary(OperatingSystems.WINDOWS)
with open(binary_path, "wb") as f:
f.write(agent_binary_bytes.getvalue())

View File

@ -57,6 +57,6 @@ class HTTPTools(object):
httpd.start()
lock.acquire()
return (
"http://%s:%s/%s" % (local_ip, local_port, urllib.parse.quote(host.os["type"])),
f"http://{local_ip}:{local_port}/{urllib.parse.quote(host.os['type'].value)}",
httpd,
)

View File

@ -1,5 +1,7 @@
import abc
from common.configuration import AgentConfiguration
class IControlChannel(metaclass=abc.ABCMeta):
@abc.abstractmethod
@ -11,10 +13,10 @@ class IControlChannel(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def get_config(self) -> dict:
def get_config(self) -> AgentConfiguration:
"""
:return: A dictionary containing Agent Configuration
:rtype: dict
:return: An AgentConfiguration object
:rtype: AgentConfiguration
"""
pass

View File

@ -1,8 +1,9 @@
import logging
import threading
import time
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Iterable, List, Optional
from common.configuration import CustomPBAConfiguration, PluginConfiguration
from common.utils import Timer
from infection_monkey.credential_store import ICredentialsStore
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -13,7 +14,7 @@ from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.credentials_telem import CredentialsTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.utils.propagation import should_propagate
from infection_monkey.utils.propagation import maximum_depth_reached
from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter
from . import Exploiter, IPScanner, Propagator
@ -111,7 +112,7 @@ class AutomatedMaster(IMaster):
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
@staticmethod
def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int):
def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int) -> Any:
tries = 0
while tries < max_tries:
try:
@ -141,7 +142,7 @@ class AutomatedMaster(IMaster):
try:
config = AutomatedMaster._try_communicate_with_island(
self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT
)["config"]
)
except IslandCommunicationError as e:
logger.error(f"An error occurred while fetching configuration: {e}")
return
@ -150,7 +151,7 @@ class AutomatedMaster(IMaster):
target=self._run_plugins,
name="CredentialCollectorThread",
args=(
config["credential_collectors"],
config.credential_collectors,
"credential collector",
self._collect_credentials,
),
@ -158,7 +159,7 @@ class AutomatedMaster(IMaster):
pba_thread = create_daemon_thread(
target=self._run_pbas,
name="PBAThread",
args=(config["post_breach_actions"].items(), self._run_pba, config["custom_pbas"]),
args=(config.post_breach_actions, self._run_pba, config.custom_pbas),
)
credential_collector_thread.start()
@ -173,52 +174,56 @@ class AutomatedMaster(IMaster):
current_depth = self._current_depth if self._current_depth is not None else 0
logger.info(f"Current depth is {current_depth}")
if should_propagate(self._control_channel.get_config(), self._current_depth):
self._propagator.propagate(config["propagation"], current_depth, self._stop)
if maximum_depth_reached(config.propagation.maximum_depth, self._current_depth):
self._propagator.propagate(config.propagation, current_depth, self._stop)
else:
logger.info("Skipping propagation: maximum depth reached")
payload_thread = create_daemon_thread(
target=self._run_plugins,
name="PayloadThread",
args=(config["payloads"].items(), "payload", self._run_payload),
args=(config.payloads, "payload", self._run_payload),
)
payload_thread.start()
payload_thread.join()
pba_thread.join()
def _collect_credentials(self, collector: str):
credentials = self._puppet.run_credential_collector(collector, {})
def _collect_credentials(self, collector: PluginConfiguration):
credentials = self._puppet.run_credential_collector(collector.name, collector.options)
if credentials:
self._telemetry_messenger.send_telemetry(CredentialsTelem(credentials))
else:
logger.debug(f"No credentials were collected by {collector}")
def _run_pba(self, pba: Tuple[str, Dict]):
name = pba[0]
options = pba[1]
for pba_data in self._puppet.run_pba(name, options):
def _run_pba(self, pba: PluginConfiguration):
for pba_data in self._puppet.run_pba(pba.name, pba.options):
self._telemetry_messenger.send_telemetry(PostBreachTelem(pba_data))
def _run_payload(self, payload: Tuple[str, Dict]):
name = payload[0]
options = payload[1]
self._puppet.run_payload(name, options, self._stop)
def _run_payload(self, payload: PluginConfiguration):
self._puppet.run_payload(payload.name, payload.options, self._stop)
def _run_pbas(
self, plugins: Iterable[Any], callback: Callable[[Any], None], custom_pba_options: Mapping
self,
plugins: Iterable[PluginConfiguration],
callback: Callable[[Any], None],
custom_pba_options: CustomPBAConfiguration,
):
self._run_plugins(plugins, "post-breach action", callback)
if custom_pba_is_enabled(custom_pba_options):
self._run_plugins([("CustomPBA", custom_pba_options)], "post-breach action", callback)
self._run_plugins(
[PluginConfiguration(name="CustomPBA", options=custom_pba_options.__dict__)],
"post-breach action",
callback,
)
def _run_plugins(
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None]
self,
plugins: Iterable[PluginConfiguration],
plugin_type: str,
callback: Callable[[Any], None],
):
logger.info(f"Running {plugin_type}s")
logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")

View File

@ -1,10 +1,12 @@
import json
import logging
from pprint import pformat
from typing import Mapping
import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.configuration import AgentConfiguration
from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -47,7 +49,7 @@ class ControlChannel(IControlChannel):
) as e:
raise IslandCommunicationError(e)
def get_config(self) -> dict:
def get_config(self) -> AgentConfiguration:
try:
response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/agent",
@ -57,7 +59,10 @@ class ControlChannel(IControlChannel):
)
response.raise_for_status()
return json.loads(response.content.decode())
config_dict = json.loads(response.text)["config"]
logger.debug(f"Received configuration:\n{pformat(json.loads(response.text))}")
return AgentConfiguration.from_mapping(config_dict)
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,

View File

@ -5,9 +5,13 @@ from copy import deepcopy
from itertools import chain
from queue import Queue
from threading import Event
from typing import Callable, Dict, List, Mapping
from typing import Callable, Dict, Sequence
from common import OperatingSystems
from common.configuration.agent_sub_configurations import (
ExploitationConfiguration,
ExploiterConfiguration,
)
from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost
@ -46,7 +50,7 @@ class Exploiter:
def exploit_hosts(
self,
exploiter_config: Dict,
exploiter_config: ExploitationConfiguration,
hosts_to_exploit: Queue,
current_depth: int,
results_callback: Callback,
@ -56,7 +60,7 @@ class Exploiter:
exploiters_to_run = self._process_exploiter_config(exploiter_config)
logger.debug(
"Agent is configured to run the following exploiters in order: "
f"{', '.join([e['name'] for e in exploiters_to_run])}"
f"{', '.join([e.name for e in exploiters_to_run])}"
)
exploit_args = (
@ -75,24 +79,26 @@ class Exploiter:
)
@staticmethod
def _process_exploiter_config(exploiter_config: Mapping) -> List[Mapping]:
def _process_exploiter_config(
exploiter_config: ExploitationConfiguration,
) -> Sequence[ExploiterConfiguration]:
# Run vulnerability exploiters before brute force exploiters to minimize the effect of
# account lockout due to invalid credentials
ordered_exploiters = chain(
exploiter_config["vulnerability"], exploiter_config["brute_force"]
)
ordered_exploiters = chain(exploiter_config.vulnerability, exploiter_config.brute_force)
exploiters_to_run = list(deepcopy(ordered_exploiters))
extended_exploiters = []
for exploiter in exploiters_to_run:
# This order allows exploiter-specific options to
# override general options for all exploiters.
exploiter["options"] = {**exploiter_config["options"], **exploiter["options"]}
options = {**exploiter_config.options.__dict__, **exploiter.options}
extended_exploiters.append(ExploiterConfiguration(exploiter.name, options))
return exploiters_to_run
return extended_exploiters
def _exploit_hosts_on_queue(
self,
exploiters_to_run: List[Dict],
exploiters_to_run: Sequence[ExploiterConfiguration],
hosts_to_exploit: Queue,
current_depth: int,
results_callback: Callback,
@ -119,7 +125,7 @@ class Exploiter:
def _run_all_exploiters(
self,
exploiters_to_run: List[Dict],
exploiters_to_run: Sequence[ExploiterConfiguration],
victim_host: VictimHost,
current_depth: int,
results_callback: Callback,
@ -127,11 +133,10 @@ class Exploiter:
):
for exploiter in interruptible_iter(exploiters_to_run, stop):
exploiter_name = exploiter["name"]
exploiter_name = exploiter.name
victim_os = victim_host.os.get("type")
# We want to try all exploiters if the victim's OS is unknown
print(victim_os)
if victim_os is not None and victim_os not in SUPPORTED_OS[exploiter_name]:
logger.debug(
f"Skipping {exploiter_name} because it does not support "
@ -140,7 +145,7 @@ class Exploiter:
continue
exploiter_results = self._run_exploiter(
exploiter_name, exploiter["options"], victim_host, current_depth, stop
exploiter_name, exploiter.options, victim_host, current_depth, stop
)
results_callback(exploiter_name, victim_host, exploiter_results)

View File

@ -3,8 +3,13 @@ import queue
import threading
from queue import Queue
from threading import Event
from typing import Any, Callable, Dict, List
from typing import Callable, Dict, Sequence
from common.configuration.agent_sub_configurations import (
NetworkScanConfiguration,
PluginConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import (
FingerprintData,
IPuppet,
@ -29,8 +34,8 @@ class IPScanner:
def scan(
self,
addresses_to_scan: List[NetworkAddress],
options: Dict,
addresses_to_scan: Sequence[NetworkAddress],
options: ScanTargetConfiguration,
results_callback: Callback,
stop: Event,
):
@ -49,12 +54,16 @@ class IPScanner:
)
def _scan_addresses(
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
self,
addresses: Queue,
options: NetworkScanConfiguration,
results_callback: Callback,
stop: Event,
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
logger.debug(f"Starting scan .read -- Thread ID: {threading.get_ident()}")
icmp_timeout = options.icmp.timeout
tcp_timeout = options.tcp.timeout
tcp_ports = options.tcp.ports
try:
while not stop.is_set():
@ -66,7 +75,7 @@ class IPScanner:
fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data):
fingerprinters = options["fingerprinters"]
fingerprinters = options.fingerprinters
fingerprint_data = self._run_fingerprinters(
address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
)
@ -90,7 +99,7 @@ class IPScanner:
def _run_fingerprinters(
self,
ip: str,
fingerprinters: List[Dict[str, Any]],
fingerprinters: Sequence[PluginConfiguration],
ping_scan_data: PingScanData,
port_scan_data: Dict[int, PortScanData],
stop: Event,
@ -98,8 +107,8 @@ class IPScanner:
fingerprint_data = {}
for f in interruptible_iter(fingerprinters, stop):
fingerprint_data[f["name"]] = self._puppet.fingerprint(
f["name"], ip, ping_scan_data, port_scan_data, f["options"]
fingerprint_data[f.name] = self._puppet.fingerprint(
f.name, ip, ping_scan_data, port_scan_data, f.options
)
return fingerprint_data

View File

@ -1,13 +1,12 @@
from typing import Dict
from common.configuration import CustomPBAConfiguration
from infection_monkey.utils.environment import is_windows_os
def custom_pba_is_enabled(pba_options: Dict) -> bool:
def custom_pba_is_enabled(pba_options: CustomPBAConfiguration) -> bool:
if not is_windows_os():
if pba_options["linux_command"]:
if pba_options.linux_command:
return True
else:
if pba_options["windows_command"]:
if pba_options.windows_command:
return True
return False

View File

@ -1,8 +1,13 @@
import logging
from queue import Queue
from threading import Event
from typing import Dict, List
from typing import List
from common.configuration import (
NetworkScanConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import (
ExploiterResultData,
FingerprintData,
@ -39,14 +44,18 @@ class Propagator:
self._local_network_interfaces = local_network_interfaces
self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, current_depth: int, stop: Event):
def propagate(
self, propagation_config: PropagationConfiguration, current_depth: int, stop: Event
):
logger.info("Attempting to propagate")
network_scan_completed = Event()
self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread(
target=self._scan_network, name="PropagatorScanThread", args=(propagation_config, stop)
target=self._scan_network,
name="PropagatorScanThread",
args=(propagation_config.network_scan, stop),
)
exploit_thread = create_daemon_thread(
target=self._exploit_hosts,
@ -64,22 +73,21 @@ class Propagator:
logger.info("Finished attempting to propagate")
def _scan_network(self, propagation_config: Dict, stop: Event):
def _scan_network(self, scan_config: NetworkScanConfiguration, stop: Event):
logger.info("Starting network scan")
target_config = propagation_config["targets"]
scan_config = propagation_config["network_scan"]
addresses_to_scan = self._compile_scan_target_list(target_config)
addresses_to_scan = self._compile_scan_target_list(scan_config.targets)
self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan")
def _compile_scan_target_list(self, target_config: Dict) -> List[NetworkAddress]:
ranges_to_scan = target_config["subnet_scan_list"]
inaccessible_subnets = target_config["inaccessible_subnets"]
blocklisted_ips = target_config["blocked_ips"]
enable_local_network_scan = target_config["local_network_scan"]
def _compile_scan_target_list(
self, target_config: ScanTargetConfiguration
) -> List[NetworkAddress]:
ranges_to_scan = target_config.subnets
inaccessible_subnets = target_config.inaccessible_subnets
blocklisted_ips = target_config.blocked_ips
enable_local_network_scan = target_config.local_network_scan
return compile_scan_target_list(
self._local_network_interfaces,
@ -134,14 +142,14 @@ class Propagator:
def _exploit_hosts(
self,
propagation_config: Dict,
propagation_config: PropagationConfiguration,
current_depth: int,
network_scan_completed: Event,
stop: Event,
):
logger.info("Exploiting victims")
exploiter_config = propagation_config["exploiters"]
exploiter_config = propagation_config.exploitation
self._exploiter.exploit_hosts(
exploiter_config,
self._hosts_to_exploit,

View File

@ -17,7 +17,7 @@ class VictimHost(object):
return self.__dict__
def is_windows(self) -> bool:
return OperatingSystems.WINDOWS in self.os["type"]
return OperatingSystems.WINDOWS == self.os["type"]
def __hash__(self):
return hash(self.ip_addr)

View File

@ -78,7 +78,7 @@ from infection_monkey.utils.monkey_dir import (
remove_monkey_dir,
)
from infection_monkey.utils.monkey_log_path import get_agent_log_path
from infection_monkey.utils.propagation import should_propagate
from infection_monkey.utils.propagation import maximum_depth_reached
from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers
logger = logging.getLogger(__name__)
@ -173,9 +173,11 @@ class InfectionMonkey:
config = control_channel.get_config()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel(
config["config"]["keep_tunnel_open_time"]
config.keep_tunnel_open_time
)
if self._monkey_inbound_tunnel and should_propagate(config, self._current_depth):
if self._monkey_inbound_tunnel and maximum_depth_reached(
config.propagation.maximum_depth, self._current_depth
):
self._inbound_tunnel_opened = True
self._monkey_inbound_tunnel.start()

View File

@ -16,7 +16,7 @@ class ClearCommandHistory(PBA):
super().__init__(telemetry_messenger, name=POST_BREACH_CLEAR_CMD_HISTORY)
def run(self, options: Dict) -> Iterable[PostBreachData]:
results = [pba.run() for pba in self.clear_command_history_pba_list()]
results = [pba.run(options) for pba in self.clear_command_history_pba_list()]
if results:
# `self.command` is empty here
self.pba_data.append(PostBreachData(self.name, self.command, results))
@ -53,7 +53,7 @@ class ClearCommandHistory(PBA):
linux_cmd=linux_cmds,
)
def run(self) -> Tuple[str, bool]:
def run(self, options: Dict) -> Tuple[str, bool]:
if self.command:
try:
output = subprocess.check_output( # noqa: DUO116

View File

@ -62,7 +62,7 @@ class FileServHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
f.close()
def send_head(self):
if self.path != "/" + urllib.parse.quote(self.victim_os):
if self.path != "/" + urllib.parse.quote(self.victim_os.value):
self.send_error(500, "")
return None, 0, 0
try:

View File

@ -1,2 +1,2 @@
def should_propagate(config: dict, current_depth: int) -> bool:
return config["config"]["depth"] > current_depth
def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool:
return maximum_depth > current_depth

View File

@ -1,6 +1,6 @@
import io
from common.configuration import AgentConfiguration, AgentConfigurationSchema
from common.configuration import AgentConfiguration
from monkey_island.cc import repository
from monkey_island.cc.repository import (
IAgentConfigurationRepository,
@ -17,21 +17,20 @@ class FileAgentConfigurationRepository(IAgentConfigurationRepository):
):
self._default_agent_configuration = default_agent_configuration
self._file_repository = file_repository
self._schema = AgentConfigurationSchema()
def get_configuration(self) -> AgentConfiguration:
try:
with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f:
configuration_json = f.read().decode()
return self._schema.loads(configuration_json)
return AgentConfiguration.from_json(configuration_json)
except repository.FileNotFoundError:
return self._default_agent_configuration
except Exception as err:
raise RetrievalError(f"Error retrieving the agent configuration: {err}")
def store_configuration(self, agent_configuration: AgentConfiguration):
configuration_json = self._schema.dumps(agent_configuration)
configuration_json = AgentConfiguration.to_json(agent_configuration)
self._file_repository.save_file(
AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode())

View File

@ -1,9 +1,9 @@
import json
import marshmallow
from flask import make_response, request
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject
from common.configuration.agent_configuration import InvalidConfigurationError
from monkey_island.cc.repository import IAgentConfigurationRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required
@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource):
def __init__(self, agent_configuration_repository: IAgentConfigurationRepository):
self._agent_configuration_repository = agent_configuration_repository
self._schema = AgentConfigurationSchema()
@jwt_required
def get(self):
configuration = self._agent_configuration_repository.get_configuration()
configuration_json = self._schema.dumps(configuration)
configuration_json = AgentConfigurationObject.to_json(configuration)
return make_response(configuration_json, 200)
@jwt_required
def post(self):
try:
configuration_object = self._schema.loads(request.data)
configuration_object = AgentConfigurationObject.from_json(request.data)
self._agent_configuration_repository.store_configuration(configuration_object)
return make_response({}, 200)
except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err:
except (InvalidConfigurationError, json.JSONDecodeError) as err:
return make_response(
{"message": f"Invalid configuration supplied: {err}"},
400,

View File

@ -179,6 +179,7 @@ class ConfigService:
should_encrypt=True,
)
@staticmethod
def _filter_none_values(data):
if isinstance(data, dict):
return {
@ -460,7 +461,7 @@ class ConfigService:
formatted_tcp_scan_config = {}
formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field]
formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field] / 1000
ports = ConfigService._union_tcp_and_http_ports(
config[flat_tcp_ports_field], config[flat_http_ports_field]
@ -484,7 +485,7 @@ class ConfigService:
flat_ping_timeout_field = "ping_scan_timeout"
formatted_icmp_scan_config = {}
formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field]
formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field] / 1000
config.pop(flat_ping_timeout_field, None)

View File

@ -3,7 +3,7 @@ from pathlib import Path
from common import DIContainer
from common.aws import AWSInstance
from common.configuration import AgentConfiguration, build_default_agent_configuration
from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration
from common.utils.file_utils import get_binary_io_sha256_hash
from monkey_island.cc.repository import (
AgentBinaryRepository,
@ -32,7 +32,7 @@ def initialize_services(data_dir: Path) -> DIContainer:
container.register_convention(Path, "data_dir", data_dir)
container.register_convention(
AgentConfiguration, "default_agent_configuration", build_default_agent_configuration()
AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION
)
container.register_instance(AWSInstance, AWSInstance())

View File

@ -1,12 +1,12 @@
from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_configuration import AgentConfiguration
from monkey_island.cc.repository import IAgentConfigurationRepository
class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository):
def __init__(self):
self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION)
self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
def get_configuration(self):
return self._configuration

View File

@ -1,3 +1,7 @@
import json
from copy import deepcopy
import pytest
from tests.common.example_agent_configuration import (
AGENT_CONFIGURATION,
BLOCKED_IPS,
@ -23,11 +27,8 @@ from tests.common.example_agent_configuration import (
WINDOWS_FILENAME,
)
from common.configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON,
AgentConfiguration,
AgentConfigurationSchema,
)
from common.configuration import AgentConfiguration, InvalidConfigurationError
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema,
ExploitationConfigurationSchema,
@ -157,10 +158,8 @@ def test_propagation_configuration():
def test_agent_configuration():
schema = AgentConfigurationSchema()
config = schema.load(AGENT_CONFIGURATION)
config_dict = schema.dump(config)
config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
config_json = AgentConfiguration.to_json(config)
assert isinstance(config, AgentConfiguration)
assert config.keep_tunnel_open_time == 30
@ -169,12 +168,53 @@ def test_agent_configuration():
assert isinstance(config.credential_collectors[0], PluginConfiguration)
assert isinstance(config.payloads[0], PluginConfiguration)
assert isinstance(config.propagation, PropagationConfiguration)
assert config_dict == AGENT_CONFIGURATION
assert json.loads(config_json) == AGENT_CONFIGURATION
def test_default_agent_configuration():
def test_incorrect_type():
valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
with pytest.raises(InvalidConfigurationError):
valid_config_dict = valid_config.__dict__
valid_config_dict["keep_tunnel_open_time"] = "not_a_float"
AgentConfiguration(**valid_config_dict)
def test_from_dict():
schema = AgentConfigurationSchema()
dict_ = deepcopy(AGENT_CONFIGURATION)
config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_mapping(dict_)
assert schema.dump(config) == dict_
def test_from_dict__invalid_data():
dict_ = deepcopy(AGENT_CONFIGURATION)
dict_["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_mapping(dict_)
def test_from_json():
schema = AgentConfigurationSchema()
dict_ = deepcopy(AGENT_CONFIGURATION)
config = AgentConfiguration.from_json(json.dumps(dict_))
assert isinstance(config, AgentConfiguration)
assert schema.dump(config) == dict_
def test_from_json__invalid_data():
invalid_dict = deepcopy(AGENT_CONFIGURATION)
invalid_dict["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_json(json.dumps(invalid_dict))
def test_to_json():
config = deepcopy(AGENT_CONFIGURATION)
assert json.loads(AgentConfiguration.to_json(config)) == AGENT_CONFIGURATION

View File

@ -9,7 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch
MONKEY_BASE_PATH = str(Path(__file__).parent.parent.parent)
sys.path.insert(0, MONKEY_BASE_PATH)
from common.configuration import AgentConfiguration, build_default_agent_configuration # noqa: E402
from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration # noqa: E402
@pytest.fixture(scope="session")
@ -60,4 +60,4 @@ def load_monkey_config(data_for_tests_dir) -> Callable[[str], Dict]:
@pytest.fixture
def default_agent_configuration() -> AgentConfiguration:
return build_default_agent_configuration()
return DEFAULT_AGENT_CONFIGURATION

View File

@ -8,6 +8,10 @@ import pytest
from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet
from common import OperatingSystems
from common.configuration.agent_sub_configurations import (
ExploitationConfiguration,
ExploiterConfiguration,
)
from infection_monkey.master import Exploiter
from infection_monkey.model import VictimHost
@ -35,18 +39,18 @@ def callback():
@pytest.fixture
def exploiter_config():
return {
"options": {"dropper_path_linux": "/tmp/monkey"},
"brute_force": [
{"name": "MSSQLExploiter", "options": {"timeout": 10}},
{"name": "SSHExploiter", "options": {}},
{"name": "WmiExploiter", "options": {"timeout": 10}},
],
"vulnerability": [
{"name": "ZerologonExploiter", "options": {}},
],
}
def exploiter_config(default_agent_configuration):
brute_force = [
ExploiterConfiguration(name="MSSQLExploiter", options={"timeout": 10}),
ExploiterConfiguration(name="SSHExploiter", options={}),
ExploiterConfiguration(name="WmiExploiter", options={"timeout": 10}),
]
vulnerability = [ExploiterConfiguration(name="ZerologonExploiter", options={})]
return ExploitationConfiguration(
options=default_agent_configuration.propagation.exploitation.options,
brute_force=brute_force,
vulnerability=vulnerability,
)
@pytest.fixture

View File

@ -5,6 +5,12 @@ from unittest.mock import MagicMock
import pytest
from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet
from common.configuration.agent_sub_configurations import (
ICMPScanConfiguration,
NetworkScanConfiguration,
PluginConfiguration,
TCPScanConfiguration,
)
from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus
from infection_monkey.master import IPScanner
from infection_monkey.network import NetworkAddress
@ -14,28 +20,31 @@ LINUX_OS = "linux"
@pytest.fixture
def scan_config():
return {
"tcp": {
"timeout_ms": 3000,
"ports": [
22,
445,
3389,
443,
8008,
3306,
],
},
"icmp": {
"timeout_ms": 1000,
},
"fingerprinters": [
{"name": "HTTPFinger", "options": {}},
{"name": "SMBFinger", "options": {}},
{"name": "SSHFinger", "options": {}},
def scan_config(default_agent_configuration):
tcp_config = TCPScanConfiguration(
timeout=3,
ports=[
22,
445,
3389,
443,
8008,
3306,
],
}
)
icmp_config = ICMPScanConfiguration(timeout=1)
fingerprinter_config = [
PluginConfiguration(name="HTTPFinger", options={}),
PluginConfiguration(name="SMBFinger", options={}),
PluginConfiguration(name="SSHFinger", options={}),
]
scan_config = NetworkScanConfiguration(
tcp_config,
icmp_config,
fingerprinter_config,
default_agent_configuration.propagation.network_scan.targets,
)
return scan_config
@pytest.fixture

View File

@ -3,6 +3,11 @@ from unittest.mock import MagicMock
import pytest
from common.configuration.agent_sub_configurations import (
NetworkScanConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import (
ExploiterResultData,
FingerprintData,
@ -135,24 +140,37 @@ class StubExploiter:
pass
def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
def get_propagation_config(
default_agent_configuration, scan_target_config: ScanTargetConfiguration
):
network_scan = NetworkScanConfiguration(
default_agent_configuration.propagation.network_scan.tcp,
default_agent_configuration.propagation.network_scan.icmp,
default_agent_configuration.propagation.network_scan.fingerprinters,
scan_target_config,
)
propagation_config = PropagationConfiguration(
default_agent_configuration.propagation.maximum_depth,
network_scan,
default_agent_configuration.propagation.exploitation,
)
return propagation_config
def test_scan_result_processing(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
):
p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, []
)
p.propagate(
{
"targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since StubExploiter ignores it
},
1,
Event(),
targets = ScanTargetConfiguration(
blocked_ips=[],
inaccessible_subnets=[],
local_network_scan=False,
subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"],
)
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
assert len(telemetry_messenger_spy.telemetries) == 3
@ -237,25 +255,20 @@ class MockExploiter:
def test_exploiter_result_processing(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
):
p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, []
)
p.propagate(
{
"targets": {
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"local_network_scan": False,
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
1,
Event(),
targets = ScanTargetConfiguration(
blocked_ips=[],
inaccessible_subnets=[],
local_network_scan=False,
subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"],
)
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)]
assert len(exploit_telems) == 4
@ -278,7 +291,9 @@ def test_exploiter_result_processing(
assert data["propagation_result"]
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
def test_scan_target_generation(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
):
local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")]
p = Propagator(
telemetry_messenger_spy,
@ -287,20 +302,15 @@ def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_v
mock_victim_host_factory,
local_network_interfaces,
)
p.propagate(
{
"targets": {
"subnet_scan_list": ["10.0.0.0/29", "172.10.20.30"],
"local_network_scan": True,
"blocked_ips": ["10.0.0.3"],
"inaccessible_subnets": ["10.0.0.128/30", "10.0.0.8/29"],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
1,
Event(),
targets = ScanTargetConfiguration(
blocked_ips=["10.0.0.3"],
inaccessible_subnets=["10.0.0.128/30", "10.0.0.8/29"],
local_network_scan=True,
subnets=["10.0.0.0/29", "172.10.20.30"],
)
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
expected_ip_scan_list = [
"10.0.0.0",
"10.0.0.1",

View File

@ -1,32 +1,22 @@
from infection_monkey.utils.propagation import should_propagate
from infection_monkey.utils.propagation import maximum_depth_reached
def get_config(max_depth):
return {"config": {"depth": max_depth}}
def test_should_propagate_current_less_than_max():
max_depth = 2
def test_maximum_depth_reached__current_less_than_max():
maximum_depth = 2
current_depth = 1
config = get_config(max_depth)
assert should_propagate(config, current_depth) is True
assert maximum_depth_reached(maximum_depth, current_depth) is True
def test_should_propagate_current_greater_than_max():
max_depth = 2
def test_maximum_depth_reached__current_greater_than_max():
maximum_depth = 2
current_depth = 3
config = get_config(max_depth)
assert should_propagate(config, current_depth) is False
assert maximum_depth_reached(maximum_depth, current_depth) is False
def test_should_propagate_current_equal_to_max():
max_depth = 2
current_depth = max_depth
def test_maximum_depth_reached__current_equal_to_max():
maximum_depth = 2
current_depth = maximum_depth
config = get_config(max_depth)
assert should_propagate(config, current_depth) is False
assert maximum_depth_reached(maximum_depth, current_depth) is False

View File

@ -2,7 +2,7 @@ import pytest
from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository
from common.configuration import AgentConfigurationSchema
from common.configuration import AgentConfiguration
from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError
@ -12,8 +12,7 @@ def repository(default_agent_configuration):
def test_store_agent_config(repository):
schema = AgentConfigurationSchema()
agent_configuration = schema.load(AGENT_CONFIGURATION)
agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
repository.store_configuration(agent_configuration)
retrieved_agent_configuration = repository.get_configuration()

View File

@ -99,7 +99,7 @@ def test_format_config_for_agent__propagation():
def test_format_config_for_agent__network_scan():
expected_network_scan_config = {
"tcp": {
"timeout": 3000,
"timeout": 3.0,
"ports": [
22,
80,
@ -117,7 +117,7 @@ def test_format_config_for_agent__network_scan():
],
},
"icmp": {
"timeout": 1000,
"timeout": 1.0,
},
"targets": {
"blocked_ips": ["192.168.1.1", "192.168.1.100"],