Merge pull request #2043 from guardicore/1960-deserialize-config

1960 deserialize config
This commit is contained in:
Mike Salvatore 2022-06-27 08:35:11 -04:00 committed by GitHub
commit 44a6197422
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 514 additions and 436 deletions

View File

@ -1,8 +1,16 @@
from .agent_configuration import ( from .agent_configuration import AgentConfiguration, InvalidConfigurationError
AgentConfiguration, from .agent_sub_configurations import (
AgentConfigurationSchema, CustomPBAConfiguration,
PluginConfiguration,
ScanTargetConfiguration,
ICMPScanConfiguration,
TCPScanConfiguration,
NetworkScanConfiguration,
ExploitationOptionsConfiguration,
ExploiterConfiguration,
ExploitationConfiguration,
PropagationConfiguration,
) )
from .default_agent_configuration import ( from .default_agent_configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON, DEFAULT_AGENT_CONFIGURATION,
build_default_agent_configuration,
) )

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass from __future__ import annotations
from typing import List
from marshmallow import Schema, fields, post_load from dataclasses import dataclass
from typing import Any, List, Mapping
from marshmallow import Schema, fields
from marshmallow.exceptions import MarshmallowError
from .agent_sub_configuration_schemas import ( from .agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
@ -15,6 +18,15 @@ from .agent_sub_configurations import (
) )
class InvalidConfigurationError(Exception):
pass
INVALID_CONFIGURATION_ERROR_MESSAGE = (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:"
)
@dataclass(frozen=True) @dataclass(frozen=True)
class AgentConfiguration: class AgentConfiguration:
keep_tunnel_open_time: float keep_tunnel_open_time: float
@ -24,6 +36,57 @@ class AgentConfiguration:
payloads: List[PluginConfiguration] payloads: List[PluginConfiguration]
propagation: PropagationConfiguration propagation: PropagationConfiguration
def __post_init__(self):
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
try:
AgentConfigurationSchema().dump(self)
except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a Mapping
:param config_mapping: A Mapping that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided Mapping does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().load(config_mapping)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_json(config_json: str) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a JSON string
:param config_json: A JSON string that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided JSON does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def to_json(config: AgentConfiguration) -> str:
"""
Serialize an AgentConfiguration to JSON
:param config: An AgentConfiguration
:return: A JSON string representing the AgentConfiguration
"""
return AgentConfigurationSchema().dumps(config)
class AgentConfigurationSchema(Schema): class AgentConfigurationSchema(Schema):
keep_tunnel_open_time = fields.Float() keep_tunnel_open_time = fields.Float()
@ -32,7 +95,3 @@ class AgentConfigurationSchema(Schema):
credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema)) credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema))
payloads = fields.List(fields.Nested(PluginConfigurationSchema)) payloads = fields.List(fields.Nested(PluginConfigurationSchema))
propagation = fields.Nested(PropagationConfigurationSchema) propagation = fields.Nested(PropagationConfigurationSchema)
@post_load
def _make_agent_configuration(self, data, **kwargs):
return AgentConfiguration(**data)

View File

@ -1,84 +1,51 @@
from . import AgentConfiguration, AgentConfigurationSchema from . import AgentConfiguration
from .agent_sub_configurations import (
CustomPBAConfiguration,
ExploitationConfiguration,
ExploitationOptionsConfiguration,
ExploiterConfiguration,
ICMPScanConfiguration,
NetworkScanConfiguration,
PluginConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
TCPScanConfiguration,
)
DEFAULT_AGENT_CONFIGURATION_JSON = """{ PBAS = [
"keep_tunnel_open_time": 30, "CommunicateAsBackdoorUser",
"post_breach_actions": [ "ModifyShellStartupFiles",
{ "HiddenFiles",
"name": "CommunicateAsBackdoorUser", "TrapCommand",
"options": {} "ChangeSetuidSetgid",
}, "ScheduleJobs",
{ "Timestomping",
"name": "ModifyShellStartupFiles", "AccountDiscovery",
"options": {} "ProcessListCollection",
}, ]
{
"name": "HiddenFiles", CREDENTIAL_COLLECTORS = ["MimikatzCollector", "SSHCollector"]
"options": {}
}, PBA_CONFIGURATION = [PluginConfiguration(pba, {}) for pba in PBAS]
{ CREDENTIAL_COLLECTOR_CONFIGURATION = [
"name": "TrapCommand", PluginConfiguration(collector, {}) for collector in CREDENTIAL_COLLECTORS
"options": {} ]
},
{ RANSOMWARE_OPTIONS = {
"name": "ChangeSetuidSetgid",
"options": {}
},
{
"name": "ScheduleJobs",
"options": {}
},
{
"name": "Timestomping",
"options": {}
},
{
"name": "AccountDiscovery",
"options": {}
},
{
"name": "ProcessListCollection",
"options": {}
}
],
"credential_collectors": [
{
"name": "MimikatzCollector",
"options": {}
},
{
"name": "SSHCollector",
"options": {}
}
],
"payloads": [
{
"name": "ransomware",
"options": {
"encryption": { "encryption": {
"enabled": true, "enabled": True,
"directories": { "directories": {"linux_target_dir": "", "windows_target_dir": ""},
"linux_target_dir": "",
"windows_target_dir": ""
}
}, },
"other_behaviors": { "other_behaviors": {"readme": True},
"readme": true
} }
}
} PAYLOAD_CONFIGURATION = [PluginConfiguration("ransomware", RANSOMWARE_OPTIONS)]
],
"custom_pbas": { CUSTOM_PBA_CONFIGURATION = CustomPBAConfiguration(
"linux_command": "", linux_command="", linux_filename="", windows_command="", windows_filename=""
"linux_filename": "", )
"windows_command": "",
"windows_filename": "" TCP_PORTS = [
},
"propagation": {
"maximum_depth": 2,
"network_scan": {
"tcp": {
"timeout": 3000,
"ports": [
22, 22,
80, 80,
135, 135,
@ -95,114 +62,54 @@ DEFAULT_AGENT_CONFIGURATION_JSON = """{
8088, 8088,
8983, 8983,
9200, 9200,
9600 9600,
] ]
},
"icmp": { TCP_SCAN_CONFIGURATION = TCPScanConfiguration(timeout=3.0, ports=TCP_PORTS)
"timeout": 1000 ICMP_CONFIGURATION = ICMPScanConfiguration(timeout=1.0)
}, HTTP_PORTS = [80, 443, 7001, 8008, 8080, 8983, 9200, 9600]
"fingerprinters": [ FINGERPRINTERS = [
{ PluginConfiguration("elastic", {}),
"name": "elastic", PluginConfiguration("http", {"http_ports": HTTP_PORTS}),
"options": {} PluginConfiguration("mssql", {}),
}, PluginConfiguration("smb", {}),
{ PluginConfiguration("ssh", {}),
"name": "http",
"options": {
"http_ports": [
80,
443,
7001,
8008,
8080,
8983,
9200,
9600
] ]
}
}, SCAN_TARGET_CONFIGURATION = ScanTargetConfiguration([], [], True, [])
{ NETWORK_SCAN_CONFIGURATION = NetworkScanConfiguration(
"name": "mssql", TCP_SCAN_CONFIGURATION, ICMP_CONFIGURATION, FINGERPRINTERS, SCAN_TARGET_CONFIGURATION
"options": {} )
},
{ EXPLOITATION_OPTIONS_CONFIGURATION = ExploitationOptionsConfiguration(HTTP_PORTS)
"name": "smb", BRUTE_FORCE_EXPLOITERS = [
"options": {} ExploiterConfiguration("MSSQLExploiter", {}),
}, ExploiterConfiguration("PowerShellExploiter", {}),
{ ExploiterConfiguration("SSHExploiter", {}),
"name": "ssh", ExploiterConfiguration("SmbExploiter", {"smb_download_timeout": 30}),
"options": {} ExploiterConfiguration("WmiExploiter", {"smb_download_timeout": 30}),
}
],
"targets": {
"blocked_ips": [],
"inaccessible_subnets": [],
"local_network_scan": true,
"subnets": []
}
},
"exploitation": {
"options": {
"http_ports": [
80,
443,
7001,
8008,
8080,
8983,
9200,
9600
] ]
},
"brute_force": [
{
"name": "MSSQLExploiter",
"options": {}
}, VULNERABILITY_EXPLOITERS = [
{ ExploiterConfiguration("Log4ShellExploiter", {}),
"name": "PowerShellExploiter", ExploiterConfiguration("HadoopExploiter", {}),
"options": {}
},
{
"name": "SSHExploiter",
"options": {}
},
{
"name": "SmbExploiter",
"options": {
"smb_download_timeout": 30
}
},
{
"name": "WmiExploiter",
"options": {
"smb_download_timeout": 30
}
}
],
"vulnerability": [
{
"name": "HadoopExploiter",
"options": {}
},
{
"name": "Log4ShellExploiter",
"options": {}
}
] ]
}
}
}
"""
EXPLOITATION_CONFIGURATION = ExploitationConfiguration(
EXPLOITATION_OPTIONS_CONFIGURATION, BRUTE_FORCE_EXPLOITERS, VULNERABILITY_EXPLOITERS
)
def build_default_agent_configuration() -> AgentConfiguration: PROPAGATION_CONFIGURATION = PropagationConfiguration(
schema = AgentConfigurationSchema() maximum_depth=2,
return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) network_scan=NETWORK_SCAN_CONFIGURATION,
exploitation=EXPLOITATION_CONFIGURATION,
)
DEFAULT_AGENT_CONFIGURATION = AgentConfiguration(
keep_tunnel_open_time=30,
custom_pbas=CUSTOM_PBA_CONFIGURATION,
post_breach_actions=PBA_CONFIGURATION,
credential_collectors=CREDENTIAL_COLLECTOR_CONFIGURATION,
payloads=PAYLOAD_CONFIGURATION,
propagation=PROPAGATION_CONFIGURATION,
)

View File

@ -2,5 +2,12 @@ from enum import Enum
class OperatingSystems(Enum): class OperatingSystems(Enum):
"""
An Enum representing all supported operating systems
This Enum represents all operating systems that Infection Monkey supports. The value of each
member is the member's name in all lower-case characters.
"""
LINUX = "linux" LINUX = "linux"
WINDOWS = "windows" WINDOWS = "windows"

View File

@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
"""Raise on failed attempt to extract domain controller's name""" """Raise on failed attempt to extract domain controller's name"""
# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error
# instead and remove this one.
class InvalidConfigurationError(Exception): class InvalidConfigurationError(Exception):
"""Raise when configuration is invalid""" """Raise when configuration is invalid"""

View File

@ -5,6 +5,7 @@ from typing import Mapping
import requests import requests
from common import OperatingSystems
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from . import IAgentRepository from . import IAgentRepository
@ -22,18 +23,22 @@ class CachingAgentRepository(IAgentRepository):
self._proxies = proxies self._proxies = proxies
self._lock = threading.Lock() self._lock = threading.Lock()
def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO: def get_agent_binary(
self, operating_system: OperatingSystems, architecture: str = None
) -> io.BytesIO:
# If multiple calls to get_agent_binary() are made simultaneously before the result of # If multiple calls to get_agent_binary() are made simultaneously before the result of
# _download_binary_from_island() is cached, then multiple requests will be sent to the # _download_binary_from_island() is cached, then multiple requests will be sent to the
# island. Add a mutex in front of the call to _download_agent_binary_from_island() so # island. Add a mutex in front of the call to _download_agent_binary_from_island() so
# that only one request per OS will be sent to the island. # that only one request per OS will be sent to the island.
with self._lock: with self._lock:
return io.BytesIO(self._download_binary_from_island(os)) return io.BytesIO(self._download_binary_from_island(operating_system))
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _download_binary_from_island(self, os: str) -> bytes: def _download_binary_from_island(self, operating_system: OperatingSystems) -> bytes:
os_name = operating_system.value
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"{self._island_url}/api/agent-binaries/{os}", f"{self._island_url}/api/agent-binaries/{os_name}",
verify=False, verify=False,
proxies=self._proxies, proxies=self._proxies,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,

View File

@ -1,6 +1,8 @@
import abc import abc
import io import io
from common import OperatingSystems
# TODO: The Island also has an IAgentRepository with a totally different interface. At the moment, # TODO: The Island also has an IAgentRepository with a totally different interface. At the moment,
# the Island and Agent have different needs, but at some point we should unify these. # the Island and Agent have different needs, but at some point we should unify these.
@ -13,12 +15,13 @@ class IAgentRepository(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO: def get_agent_binary(
self, operating_system: OperatingSystems, architecture: str = None
) -> io.BytesIO:
""" """
Retrieve the appropriate agent binary from the repository. Retrieve the appropriate agent binary from the repository.
:param str os: The name of the operating system on which the agent binary will run :param operating_system: The name of the operating system on which the agent binary will run
:param str architecture: Reserved :param architecture: Reserved
:return: A file-like object for the requested agent binary :return: A file-like object for the requested agent binary
:rtype: io.BytesIO
""" """
pass pass

View File

@ -129,7 +129,7 @@ class Log4ShellExploiter(WebRCE):
} }
def _build_java_class(self, exploit_command: str) -> bytes: def _build_java_class(self, exploit_command: str) -> bytes:
if OperatingSystems.LINUX in self.host.os["type"]: if OperatingSystems.LINUX == self.host.os["type"]:
return build_exploit_bytecode(exploit_command, LINUX_EXPLOIT_TEMPLATE_PATH) return build_exploit_bytecode(exploit_command, LINUX_EXPLOIT_TEMPLATE_PATH)
else: else:
return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH) return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH)

View File

@ -2,6 +2,7 @@ import logging
from pathlib import Path, PurePath from pathlib import Path, PurePath
from typing import List, Optional from typing import List, Optional
from common import OperatingSystems
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions, get_auth_options from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions, get_auth_options
from infection_monkey.exploit.powershell_utils.credentials import ( from infection_monkey.exploit.powershell_utils.credentials import (
@ -162,7 +163,7 @@ class PowerShellExploiter(HostExploiter):
temp_monkey_binary_filepath.unlink() temp_monkey_binary_filepath.unlink()
def _create_local_agent_file(self, binary_path): def _create_local_agent_file(self, binary_path):
agent_binary_bytes = self.agent_repository.get_agent_binary("windows") agent_binary_bytes = self.agent_repository.get_agent_binary(OperatingSystems.WINDOWS)
with open(binary_path, "wb") as f: with open(binary_path, "wb") as f:
f.write(agent_binary_bytes.getvalue()) f.write(agent_binary_bytes.getvalue())

View File

@ -57,6 +57,6 @@ class HTTPTools(object):
httpd.start() httpd.start()
lock.acquire() lock.acquire()
return ( return (
"http://%s:%s/%s" % (local_ip, local_port, urllib.parse.quote(host.os["type"])), f"http://{local_ip}:{local_port}/{urllib.parse.quote(host.os['type'].value)}",
httpd, httpd,
) )

View File

@ -1,5 +1,7 @@
import abc import abc
from common.configuration import AgentConfiguration
class IControlChannel(metaclass=abc.ABCMeta): class IControlChannel(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
@ -11,10 +13,10 @@ class IControlChannel(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_config(self) -> dict: def get_config(self) -> AgentConfiguration:
""" """
:return: A dictionary containing Agent Configuration :return: An AgentConfiguration object
:rtype: dict :rtype: AgentConfiguration
""" """
pass pass

View File

@ -1,8 +1,9 @@
import logging import logging
import threading import threading
import time import time
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Iterable, List, Optional
from common.configuration import CustomPBAConfiguration, PluginConfiguration
from common.utils import Timer from common.utils import Timer
from infection_monkey.credential_store import ICredentialsStore from infection_monkey.credential_store import ICredentialsStore
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -13,7 +14,7 @@ from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.credentials_telem import CredentialsTelem from infection_monkey.telemetry.credentials_telem import CredentialsTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.utils.propagation import should_propagate from infection_monkey.utils.propagation import maximum_depth_reached
from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter
from . import Exploiter, IPScanner, Propagator from . import Exploiter, IPScanner, Propagator
@ -111,7 +112,7 @@ class AutomatedMaster(IMaster):
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC) time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
@staticmethod @staticmethod
def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int): def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int) -> Any:
tries = 0 tries = 0
while tries < max_tries: while tries < max_tries:
try: try:
@ -141,7 +142,7 @@ class AutomatedMaster(IMaster):
try: try:
config = AutomatedMaster._try_communicate_with_island( config = AutomatedMaster._try_communicate_with_island(
self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT
)["config"] )
except IslandCommunicationError as e: except IslandCommunicationError as e:
logger.error(f"An error occurred while fetching configuration: {e}") logger.error(f"An error occurred while fetching configuration: {e}")
return return
@ -150,7 +151,7 @@ class AutomatedMaster(IMaster):
target=self._run_plugins, target=self._run_plugins,
name="CredentialCollectorThread", name="CredentialCollectorThread",
args=( args=(
config["credential_collectors"], config.credential_collectors,
"credential collector", "credential collector",
self._collect_credentials, self._collect_credentials,
), ),
@ -158,7 +159,7 @@ class AutomatedMaster(IMaster):
pba_thread = create_daemon_thread( pba_thread = create_daemon_thread(
target=self._run_pbas, target=self._run_pbas,
name="PBAThread", name="PBAThread",
args=(config["post_breach_actions"].items(), self._run_pba, config["custom_pbas"]), args=(config.post_breach_actions, self._run_pba, config.custom_pbas),
) )
credential_collector_thread.start() credential_collector_thread.start()
@ -173,52 +174,56 @@ class AutomatedMaster(IMaster):
current_depth = self._current_depth if self._current_depth is not None else 0 current_depth = self._current_depth if self._current_depth is not None else 0
logger.info(f"Current depth is {current_depth}") logger.info(f"Current depth is {current_depth}")
if should_propagate(self._control_channel.get_config(), self._current_depth): if maximum_depth_reached(config.propagation.maximum_depth, self._current_depth):
self._propagator.propagate(config["propagation"], current_depth, self._stop) self._propagator.propagate(config.propagation, current_depth, self._stop)
else: else:
logger.info("Skipping propagation: maximum depth reached") logger.info("Skipping propagation: maximum depth reached")
payload_thread = create_daemon_thread( payload_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
name="PayloadThread", name="PayloadThread",
args=(config["payloads"].items(), "payload", self._run_payload), args=(config.payloads, "payload", self._run_payload),
) )
payload_thread.start() payload_thread.start()
payload_thread.join() payload_thread.join()
pba_thread.join() pba_thread.join()
def _collect_credentials(self, collector: str): def _collect_credentials(self, collector: PluginConfiguration):
credentials = self._puppet.run_credential_collector(collector, {}) credentials = self._puppet.run_credential_collector(collector.name, collector.options)
if credentials: if credentials:
self._telemetry_messenger.send_telemetry(CredentialsTelem(credentials)) self._telemetry_messenger.send_telemetry(CredentialsTelem(credentials))
else: else:
logger.debug(f"No credentials were collected by {collector}") logger.debug(f"No credentials were collected by {collector}")
def _run_pba(self, pba: Tuple[str, Dict]): def _run_pba(self, pba: PluginConfiguration):
name = pba[0] for pba_data in self._puppet.run_pba(pba.name, pba.options):
options = pba[1]
for pba_data in self._puppet.run_pba(name, options):
self._telemetry_messenger.send_telemetry(PostBreachTelem(pba_data)) self._telemetry_messenger.send_telemetry(PostBreachTelem(pba_data))
def _run_payload(self, payload: Tuple[str, Dict]): def _run_payload(self, payload: PluginConfiguration):
name = payload[0] self._puppet.run_payload(payload.name, payload.options, self._stop)
options = payload[1]
self._puppet.run_payload(name, options, self._stop)
def _run_pbas( def _run_pbas(
self, plugins: Iterable[Any], callback: Callable[[Any], None], custom_pba_options: Mapping self,
plugins: Iterable[PluginConfiguration],
callback: Callable[[Any], None],
custom_pba_options: CustomPBAConfiguration,
): ):
self._run_plugins(plugins, "post-breach action", callback) self._run_plugins(plugins, "post-breach action", callback)
if custom_pba_is_enabled(custom_pba_options): if custom_pba_is_enabled(custom_pba_options):
self._run_plugins([("CustomPBA", custom_pba_options)], "post-breach action", callback) self._run_plugins(
[PluginConfiguration(name="CustomPBA", options=custom_pba_options.__dict__)],
"post-breach action",
callback,
)
def _run_plugins( def _run_plugins(
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None] self,
plugins: Iterable[PluginConfiguration],
plugin_type: str,
callback: Callable[[Any], None],
): ):
logger.info(f"Running {plugin_type}s") logger.info(f"Running {plugin_type}s")
logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run") logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")

View File

@ -1,10 +1,12 @@
import json import json
import logging import logging
from pprint import pformat
from typing import Mapping from typing import Mapping
import requests import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.configuration import AgentConfiguration
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -47,7 +49,7 @@ class ControlChannel(IControlChannel):
) as e: ) as e:
raise IslandCommunicationError(e) raise IslandCommunicationError(e)
def get_config(self) -> dict: def get_config(self) -> AgentConfiguration:
try: try:
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/agent", f"https://{self._control_channel_server}/api/agent",
@ -57,7 +59,10 @@ class ControlChannel(IControlChannel):
) )
response.raise_for_status() response.raise_for_status()
return json.loads(response.content.decode()) config_dict = json.loads(response.text)["config"]
logger.debug(f"Received configuration:\n{pformat(json.loads(response.text))}")
return AgentConfiguration.from_mapping(config_dict)
except ( except (
json.JSONDecodeError, json.JSONDecodeError,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,

View File

@ -5,9 +5,13 @@ from copy import deepcopy
from itertools import chain from itertools import chain
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import Callable, Dict, List, Mapping from typing import Callable, Dict, Sequence
from common import OperatingSystems from common import OperatingSystems
from common.configuration.agent_sub_configurations import (
ExploitationConfiguration,
ExploiterConfiguration,
)
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
@ -46,7 +50,7 @@ class Exploiter:
def exploit_hosts( def exploit_hosts(
self, self,
exploiter_config: Dict, exploiter_config: ExploitationConfiguration,
hosts_to_exploit: Queue, hosts_to_exploit: Queue,
current_depth: int, current_depth: int,
results_callback: Callback, results_callback: Callback,
@ -56,7 +60,7 @@ class Exploiter:
exploiters_to_run = self._process_exploiter_config(exploiter_config) exploiters_to_run = self._process_exploiter_config(exploiter_config)
logger.debug( logger.debug(
"Agent is configured to run the following exploiters in order: " "Agent is configured to run the following exploiters in order: "
f"{', '.join([e['name'] for e in exploiters_to_run])}" f"{', '.join([e.name for e in exploiters_to_run])}"
) )
exploit_args = ( exploit_args = (
@ -75,24 +79,26 @@ class Exploiter:
) )
@staticmethod @staticmethod
def _process_exploiter_config(exploiter_config: Mapping) -> List[Mapping]: def _process_exploiter_config(
exploiter_config: ExploitationConfiguration,
) -> Sequence[ExploiterConfiguration]:
# Run vulnerability exploiters before brute force exploiters to minimize the effect of # Run vulnerability exploiters before brute force exploiters to minimize the effect of
# account lockout due to invalid credentials # account lockout due to invalid credentials
ordered_exploiters = chain( ordered_exploiters = chain(exploiter_config.vulnerability, exploiter_config.brute_force)
exploiter_config["vulnerability"], exploiter_config["brute_force"]
)
exploiters_to_run = list(deepcopy(ordered_exploiters)) exploiters_to_run = list(deepcopy(ordered_exploiters))
extended_exploiters = []
for exploiter in exploiters_to_run: for exploiter in exploiters_to_run:
# This order allows exploiter-specific options to # This order allows exploiter-specific options to
# override general options for all exploiters. # override general options for all exploiters.
exploiter["options"] = {**exploiter_config["options"], **exploiter["options"]} options = {**exploiter_config.options.__dict__, **exploiter.options}
extended_exploiters.append(ExploiterConfiguration(exploiter.name, options))
return exploiters_to_run return extended_exploiters
def _exploit_hosts_on_queue( def _exploit_hosts_on_queue(
self, self,
exploiters_to_run: List[Dict], exploiters_to_run: Sequence[ExploiterConfiguration],
hosts_to_exploit: Queue, hosts_to_exploit: Queue,
current_depth: int, current_depth: int,
results_callback: Callback, results_callback: Callback,
@ -119,7 +125,7 @@ class Exploiter:
def _run_all_exploiters( def _run_all_exploiters(
self, self,
exploiters_to_run: List[Dict], exploiters_to_run: Sequence[ExploiterConfiguration],
victim_host: VictimHost, victim_host: VictimHost,
current_depth: int, current_depth: int,
results_callback: Callback, results_callback: Callback,
@ -127,11 +133,10 @@ class Exploiter:
): ):
for exploiter in interruptible_iter(exploiters_to_run, stop): for exploiter in interruptible_iter(exploiters_to_run, stop):
exploiter_name = exploiter["name"] exploiter_name = exploiter.name
victim_os = victim_host.os.get("type") victim_os = victim_host.os.get("type")
# We want to try all exploiters if the victim's OS is unknown # We want to try all exploiters if the victim's OS is unknown
print(victim_os)
if victim_os is not None and victim_os not in SUPPORTED_OS[exploiter_name]: if victim_os is not None and victim_os not in SUPPORTED_OS[exploiter_name]:
logger.debug( logger.debug(
f"Skipping {exploiter_name} because it does not support " f"Skipping {exploiter_name} because it does not support "
@ -140,7 +145,7 @@ class Exploiter:
continue continue
exploiter_results = self._run_exploiter( exploiter_results = self._run_exploiter(
exploiter_name, exploiter["options"], victim_host, current_depth, stop exploiter_name, exploiter.options, victim_host, current_depth, stop
) )
results_callback(exploiter_name, victim_host, exploiter_results) results_callback(exploiter_name, victim_host, exploiter_results)

View File

@ -3,8 +3,13 @@ import queue
import threading import threading
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import Any, Callable, Dict, List from typing import Callable, Dict, Sequence
from common.configuration.agent_sub_configurations import (
NetworkScanConfiguration,
PluginConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
FingerprintData, FingerprintData,
IPuppet, IPuppet,
@ -29,8 +34,8 @@ class IPScanner:
def scan( def scan(
self, self,
addresses_to_scan: List[NetworkAddress], addresses_to_scan: Sequence[NetworkAddress],
options: Dict, options: ScanTargetConfiguration,
results_callback: Callback, results_callback: Callback,
stop: Event, stop: Event,
): ):
@ -49,12 +54,16 @@ class IPScanner:
) )
def _scan_addresses( def _scan_addresses(
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event self,
addresses: Queue,
options: NetworkScanConfiguration,
results_callback: Callback,
stop: Event,
): ):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan .read -- Thread ID: {threading.get_ident()}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000 icmp_timeout = options.icmp.timeout
tcp_timeout = options["tcp"]["timeout_ms"] / 1000 tcp_timeout = options.tcp.timeout
tcp_ports = options["tcp"]["ports"] tcp_ports = options.tcp.ports
try: try:
while not stop.is_set(): while not stop.is_set():
@ -66,7 +75,7 @@ class IPScanner:
fingerprint_data = {} fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data): if IPScanner.port_scan_found_open_port(port_scan_data):
fingerprinters = options["fingerprinters"] fingerprinters = options.fingerprinters
fingerprint_data = self._run_fingerprinters( fingerprint_data = self._run_fingerprinters(
address.ip, fingerprinters, ping_scan_data, port_scan_data, stop address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
) )
@ -90,7 +99,7 @@ class IPScanner:
def _run_fingerprinters( def _run_fingerprinters(
self, self,
ip: str, ip: str,
fingerprinters: List[Dict[str, Any]], fingerprinters: Sequence[PluginConfiguration],
ping_scan_data: PingScanData, ping_scan_data: PingScanData,
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
stop: Event, stop: Event,
@ -98,8 +107,8 @@ class IPScanner:
fingerprint_data = {} fingerprint_data = {}
for f in interruptible_iter(fingerprinters, stop): for f in interruptible_iter(fingerprinters, stop):
fingerprint_data[f["name"]] = self._puppet.fingerprint( fingerprint_data[f.name] = self._puppet.fingerprint(
f["name"], ip, ping_scan_data, port_scan_data, f["options"] f.name, ip, ping_scan_data, port_scan_data, f.options
) )
return fingerprint_data return fingerprint_data

View File

@ -1,13 +1,12 @@
from typing import Dict from common.configuration import CustomPBAConfiguration
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
def custom_pba_is_enabled(pba_options: Dict) -> bool: def custom_pba_is_enabled(pba_options: CustomPBAConfiguration) -> bool:
if not is_windows_os(): if not is_windows_os():
if pba_options["linux_command"]: if pba_options.linux_command:
return True return True
else: else:
if pba_options["windows_command"]: if pba_options.windows_command:
return True return True
return False return False

View File

@ -1,8 +1,13 @@
import logging import logging
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import Dict, List from typing import List
from common.configuration import (
NetworkScanConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
FingerprintData, FingerprintData,
@ -39,14 +44,18 @@ class Propagator:
self._local_network_interfaces = local_network_interfaces self._local_network_interfaces = local_network_interfaces
self._hosts_to_exploit = None self._hosts_to_exploit = None
def propagate(self, propagation_config: Dict, current_depth: int, stop: Event): def propagate(
self, propagation_config: PropagationConfiguration, current_depth: int, stop: Event
):
logger.info("Attempting to propagate") logger.info("Attempting to propagate")
network_scan_completed = Event() network_scan_completed = Event()
self._hosts_to_exploit = Queue() self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread( scan_thread = create_daemon_thread(
target=self._scan_network, name="PropagatorScanThread", args=(propagation_config, stop) target=self._scan_network,
name="PropagatorScanThread",
args=(propagation_config.network_scan, stop),
) )
exploit_thread = create_daemon_thread( exploit_thread = create_daemon_thread(
target=self._exploit_hosts, target=self._exploit_hosts,
@ -64,22 +73,21 @@ class Propagator:
logger.info("Finished attempting to propagate") logger.info("Finished attempting to propagate")
def _scan_network(self, propagation_config: Dict, stop: Event): def _scan_network(self, scan_config: NetworkScanConfiguration, stop: Event):
logger.info("Starting network scan") logger.info("Starting network scan")
target_config = propagation_config["targets"] addresses_to_scan = self._compile_scan_target_list(scan_config.targets)
scan_config = propagation_config["network_scan"]
addresses_to_scan = self._compile_scan_target_list(target_config)
self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop) self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop)
logger.info("Finished network scan") logger.info("Finished network scan")
def _compile_scan_target_list(self, target_config: Dict) -> List[NetworkAddress]: def _compile_scan_target_list(
ranges_to_scan = target_config["subnet_scan_list"] self, target_config: ScanTargetConfiguration
inaccessible_subnets = target_config["inaccessible_subnets"] ) -> List[NetworkAddress]:
blocklisted_ips = target_config["blocked_ips"] ranges_to_scan = target_config.subnets
enable_local_network_scan = target_config["local_network_scan"] inaccessible_subnets = target_config.inaccessible_subnets
blocklisted_ips = target_config.blocked_ips
enable_local_network_scan = target_config.local_network_scan
return compile_scan_target_list( return compile_scan_target_list(
self._local_network_interfaces, self._local_network_interfaces,
@ -134,14 +142,14 @@ class Propagator:
def _exploit_hosts( def _exploit_hosts(
self, self,
propagation_config: Dict, propagation_config: PropagationConfiguration,
current_depth: int, current_depth: int,
network_scan_completed: Event, network_scan_completed: Event,
stop: Event, stop: Event,
): ):
logger.info("Exploiting victims") logger.info("Exploiting victims")
exploiter_config = propagation_config["exploiters"] exploiter_config = propagation_config.exploitation
self._exploiter.exploit_hosts( self._exploiter.exploit_hosts(
exploiter_config, exploiter_config,
self._hosts_to_exploit, self._hosts_to_exploit,

View File

@ -17,7 +17,7 @@ class VictimHost(object):
return self.__dict__ return self.__dict__
def is_windows(self) -> bool: def is_windows(self) -> bool:
return OperatingSystems.WINDOWS in self.os["type"] return OperatingSystems.WINDOWS == self.os["type"]
def __hash__(self): def __hash__(self):
return hash(self.ip_addr) return hash(self.ip_addr)

View File

@ -78,7 +78,7 @@ from infection_monkey.utils.monkey_dir import (
remove_monkey_dir, remove_monkey_dir,
) )
from infection_monkey.utils.monkey_log_path import get_agent_log_path from infection_monkey.utils.monkey_log_path import get_agent_log_path
from infection_monkey.utils.propagation import should_propagate from infection_monkey.utils.propagation import maximum_depth_reached
from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -173,9 +173,11 @@ class InfectionMonkey:
config = control_channel.get_config() config = control_channel.get_config()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel( self._monkey_inbound_tunnel = self._control_client.create_control_tunnel(
config["config"]["keep_tunnel_open_time"] config.keep_tunnel_open_time
) )
if self._monkey_inbound_tunnel and should_propagate(config, self._current_depth): if self._monkey_inbound_tunnel and maximum_depth_reached(
config.propagation.maximum_depth, self._current_depth
):
self._inbound_tunnel_opened = True self._inbound_tunnel_opened = True
self._monkey_inbound_tunnel.start() self._monkey_inbound_tunnel.start()

View File

@ -16,7 +16,7 @@ class ClearCommandHistory(PBA):
super().__init__(telemetry_messenger, name=POST_BREACH_CLEAR_CMD_HISTORY) super().__init__(telemetry_messenger, name=POST_BREACH_CLEAR_CMD_HISTORY)
def run(self, options: Dict) -> Iterable[PostBreachData]: def run(self, options: Dict) -> Iterable[PostBreachData]:
results = [pba.run() for pba in self.clear_command_history_pba_list()] results = [pba.run(options) for pba in self.clear_command_history_pba_list()]
if results: if results:
# `self.command` is empty here # `self.command` is empty here
self.pba_data.append(PostBreachData(self.name, self.command, results)) self.pba_data.append(PostBreachData(self.name, self.command, results))
@ -53,7 +53,7 @@ class ClearCommandHistory(PBA):
linux_cmd=linux_cmds, linux_cmd=linux_cmds,
) )
def run(self) -> Tuple[str, bool]: def run(self, options: Dict) -> Tuple[str, bool]:
if self.command: if self.command:
try: try:
output = subprocess.check_output( # noqa: DUO116 output = subprocess.check_output( # noqa: DUO116

View File

@ -62,7 +62,7 @@ class FileServHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
f.close() f.close()
def send_head(self): def send_head(self):
if self.path != "/" + urllib.parse.quote(self.victim_os): if self.path != "/" + urllib.parse.quote(self.victim_os.value):
self.send_error(500, "") self.send_error(500, "")
return None, 0, 0 return None, 0, 0
try: try:

View File

@ -1,2 +1,2 @@
def should_propagate(config: dict, current_depth: int) -> bool: def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool:
return config["config"]["depth"] > current_depth return maximum_depth > current_depth

View File

@ -1,6 +1,6 @@
import io import io
from common.configuration import AgentConfiguration, AgentConfigurationSchema from common.configuration import AgentConfiguration
from monkey_island.cc import repository from monkey_island.cc import repository
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
IAgentConfigurationRepository, IAgentConfigurationRepository,
@ -17,21 +17,20 @@ class FileAgentConfigurationRepository(IAgentConfigurationRepository):
): ):
self._default_agent_configuration = default_agent_configuration self._default_agent_configuration = default_agent_configuration
self._file_repository = file_repository self._file_repository = file_repository
self._schema = AgentConfigurationSchema()
def get_configuration(self) -> AgentConfiguration: def get_configuration(self) -> AgentConfiguration:
try: try:
with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f: with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f:
configuration_json = f.read().decode() configuration_json = f.read().decode()
return self._schema.loads(configuration_json) return AgentConfiguration.from_json(configuration_json)
except repository.FileNotFoundError: except repository.FileNotFoundError:
return self._default_agent_configuration return self._default_agent_configuration
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving the agent configuration: {err}") raise RetrievalError(f"Error retrieving the agent configuration: {err}")
def store_configuration(self, agent_configuration: AgentConfiguration): def store_configuration(self, agent_configuration: AgentConfiguration):
configuration_json = self._schema.dumps(agent_configuration) configuration_json = AgentConfiguration.to_json(agent_configuration)
self._file_repository.save_file( self._file_repository.save_file(
AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode()) AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode())

View File

@ -1,9 +1,9 @@
import json import json
import marshmallow
from flask import make_response, request from flask import make_response, request
from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject
from common.configuration.agent_configuration import InvalidConfigurationError
from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.repository import IAgentConfigurationRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required from monkey_island.cc.resources.request_authentication import jwt_required
@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource):
def __init__(self, agent_configuration_repository: IAgentConfigurationRepository): def __init__(self, agent_configuration_repository: IAgentConfigurationRepository):
self._agent_configuration_repository = agent_configuration_repository self._agent_configuration_repository = agent_configuration_repository
self._schema = AgentConfigurationSchema()
@jwt_required @jwt_required
def get(self): def get(self):
configuration = self._agent_configuration_repository.get_configuration() configuration = self._agent_configuration_repository.get_configuration()
configuration_json = self._schema.dumps(configuration) configuration_json = AgentConfigurationObject.to_json(configuration)
return make_response(configuration_json, 200) return make_response(configuration_json, 200)
@jwt_required @jwt_required
def post(self): def post(self):
try: try:
configuration_object = self._schema.loads(request.data) configuration_object = AgentConfigurationObject.from_json(request.data)
self._agent_configuration_repository.store_configuration(configuration_object) self._agent_configuration_repository.store_configuration(configuration_object)
return make_response({}, 200) return make_response({}, 200)
except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err: except (InvalidConfigurationError, json.JSONDecodeError) as err:
return make_response( return make_response(
{"message": f"Invalid configuration supplied: {err}"}, {"message": f"Invalid configuration supplied: {err}"},
400, 400,

View File

@ -179,6 +179,7 @@ class ConfigService:
should_encrypt=True, should_encrypt=True,
) )
@staticmethod
def _filter_none_values(data): def _filter_none_values(data):
if isinstance(data, dict): if isinstance(data, dict):
return { return {
@ -460,7 +461,7 @@ class ConfigService:
formatted_tcp_scan_config = {} formatted_tcp_scan_config = {}
formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field] formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field] / 1000
ports = ConfigService._union_tcp_and_http_ports( ports = ConfigService._union_tcp_and_http_ports(
config[flat_tcp_ports_field], config[flat_http_ports_field] config[flat_tcp_ports_field], config[flat_http_ports_field]
@ -484,7 +485,7 @@ class ConfigService:
flat_ping_timeout_field = "ping_scan_timeout" flat_ping_timeout_field = "ping_scan_timeout"
formatted_icmp_scan_config = {} formatted_icmp_scan_config = {}
formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field] formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field] / 1000
config.pop(flat_ping_timeout_field, None) config.pop(flat_ping_timeout_field, None)

View File

@ -3,7 +3,7 @@ from pathlib import Path
from common import DIContainer from common import DIContainer
from common.aws import AWSInstance from common.aws import AWSInstance
from common.configuration import AgentConfiguration, build_default_agent_configuration from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration
from common.utils.file_utils import get_binary_io_sha256_hash from common.utils.file_utils import get_binary_io_sha256_hash
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
AgentBinaryRepository, AgentBinaryRepository,
@ -32,7 +32,7 @@ def initialize_services(data_dir: Path) -> DIContainer:
container.register_convention(Path, "data_dir", data_dir) container.register_convention(Path, "data_dir", data_dir)
container.register_convention( container.register_convention(
AgentConfiguration, "default_agent_configuration", build_default_agent_configuration() AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION
) )
container.register_instance(AWSInstance, AWSInstance()) container.register_instance(AWSInstance, AWSInstance())

View File

@ -1,12 +1,12 @@
from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_configuration import AgentConfiguration
from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.repository import IAgentConfigurationRepository
class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository): class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository):
def __init__(self): def __init__(self):
self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION) self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
def get_configuration(self): def get_configuration(self):
return self._configuration return self._configuration

View File

@ -1,3 +1,7 @@
import json
from copy import deepcopy
import pytest
from tests.common.example_agent_configuration import ( from tests.common.example_agent_configuration import (
AGENT_CONFIGURATION, AGENT_CONFIGURATION,
BLOCKED_IPS, BLOCKED_IPS,
@ -23,11 +27,8 @@ from tests.common.example_agent_configuration import (
WINDOWS_FILENAME, WINDOWS_FILENAME,
) )
from common.configuration import ( from common.configuration import AgentConfiguration, InvalidConfigurationError
DEFAULT_AGENT_CONFIGURATION_JSON, from common.configuration.agent_configuration import AgentConfigurationSchema
AgentConfiguration,
AgentConfigurationSchema,
)
from common.configuration.agent_sub_configuration_schemas import ( from common.configuration.agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
ExploitationConfigurationSchema, ExploitationConfigurationSchema,
@ -157,10 +158,8 @@ def test_propagation_configuration():
def test_agent_configuration(): def test_agent_configuration():
schema = AgentConfigurationSchema() config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
config_json = AgentConfiguration.to_json(config)
config = schema.load(AGENT_CONFIGURATION)
config_dict = schema.dump(config)
assert isinstance(config, AgentConfiguration) assert isinstance(config, AgentConfiguration)
assert config.keep_tunnel_open_time == 30 assert config.keep_tunnel_open_time == 30
@ -169,12 +168,53 @@ def test_agent_configuration():
assert isinstance(config.credential_collectors[0], PluginConfiguration) assert isinstance(config.credential_collectors[0], PluginConfiguration)
assert isinstance(config.payloads[0], PluginConfiguration) assert isinstance(config.payloads[0], PluginConfiguration)
assert isinstance(config.propagation, PropagationConfiguration) assert isinstance(config.propagation, PropagationConfiguration)
assert config_dict == AGENT_CONFIGURATION assert json.loads(config_json) == AGENT_CONFIGURATION
def test_default_agent_configuration(): def test_incorrect_type():
valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
with pytest.raises(InvalidConfigurationError):
valid_config_dict = valid_config.__dict__
valid_config_dict["keep_tunnel_open_time"] = "not_a_float"
AgentConfiguration(**valid_config_dict)
def test_from_dict():
schema = AgentConfigurationSchema() schema = AgentConfigurationSchema()
dict_ = deepcopy(AGENT_CONFIGURATION)
config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) config = AgentConfiguration.from_mapping(dict_)
assert schema.dump(config) == dict_
def test_from_dict__invalid_data():
dict_ = deepcopy(AGENT_CONFIGURATION)
dict_["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_mapping(dict_)
def test_from_json():
schema = AgentConfigurationSchema()
dict_ = deepcopy(AGENT_CONFIGURATION)
config = AgentConfiguration.from_json(json.dumps(dict_))
assert isinstance(config, AgentConfiguration) assert isinstance(config, AgentConfiguration)
assert schema.dump(config) == dict_
def test_from_json__invalid_data():
invalid_dict = deepcopy(AGENT_CONFIGURATION)
invalid_dict["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_json(json.dumps(invalid_dict))
def test_to_json():
config = deepcopy(AGENT_CONFIGURATION)
assert json.loads(AgentConfiguration.to_json(config)) == AGENT_CONFIGURATION

View File

@ -9,7 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch
MONKEY_BASE_PATH = str(Path(__file__).parent.parent.parent) MONKEY_BASE_PATH = str(Path(__file__).parent.parent.parent)
sys.path.insert(0, MONKEY_BASE_PATH) sys.path.insert(0, MONKEY_BASE_PATH)
from common.configuration import AgentConfiguration, build_default_agent_configuration # noqa: E402 from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration # noqa: E402
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -60,4 +60,4 @@ def load_monkey_config(data_for_tests_dir) -> Callable[[str], Dict]:
@pytest.fixture @pytest.fixture
def default_agent_configuration() -> AgentConfiguration: def default_agent_configuration() -> AgentConfiguration:
return build_default_agent_configuration() return DEFAULT_AGENT_CONFIGURATION

View File

@ -8,6 +8,10 @@ import pytest
from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet
from common import OperatingSystems from common import OperatingSystems
from common.configuration.agent_sub_configurations import (
ExploitationConfiguration,
ExploiterConfiguration,
)
from infection_monkey.master import Exploiter from infection_monkey.master import Exploiter
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
@ -35,18 +39,18 @@ def callback():
@pytest.fixture @pytest.fixture
def exploiter_config(): def exploiter_config(default_agent_configuration):
return { brute_force = [
"options": {"dropper_path_linux": "/tmp/monkey"}, ExploiterConfiguration(name="MSSQLExploiter", options={"timeout": 10}),
"brute_force": [ ExploiterConfiguration(name="SSHExploiter", options={}),
{"name": "MSSQLExploiter", "options": {"timeout": 10}}, ExploiterConfiguration(name="WmiExploiter", options={"timeout": 10}),
{"name": "SSHExploiter", "options": {}}, ]
{"name": "WmiExploiter", "options": {"timeout": 10}}, vulnerability = [ExploiterConfiguration(name="ZerologonExploiter", options={})]
], return ExploitationConfiguration(
"vulnerability": [ options=default_agent_configuration.propagation.exploitation.options,
{"name": "ZerologonExploiter", "options": {}}, brute_force=brute_force,
], vulnerability=vulnerability,
} )
@pytest.fixture @pytest.fixture

View File

@ -5,6 +5,12 @@ from unittest.mock import MagicMock
import pytest import pytest
from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet
from common.configuration.agent_sub_configurations import (
ICMPScanConfiguration,
NetworkScanConfiguration,
PluginConfiguration,
TCPScanConfiguration,
)
from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus
from infection_monkey.master import IPScanner from infection_monkey.master import IPScanner
from infection_monkey.network import NetworkAddress from infection_monkey.network import NetworkAddress
@ -14,11 +20,10 @@ LINUX_OS = "linux"
@pytest.fixture @pytest.fixture
def scan_config(): def scan_config(default_agent_configuration):
return { tcp_config = TCPScanConfiguration(
"tcp": { timeout=3,
"timeout_ms": 3000, ports=[
"ports": [
22, 22,
445, 445,
3389, 3389,
@ -26,16 +31,20 @@ def scan_config():
8008, 8008,
3306, 3306,
], ],
}, )
"icmp": { icmp_config = ICMPScanConfiguration(timeout=1)
"timeout_ms": 1000, fingerprinter_config = [
}, PluginConfiguration(name="HTTPFinger", options={}),
"fingerprinters": [ PluginConfiguration(name="SMBFinger", options={}),
{"name": "HTTPFinger", "options": {}}, PluginConfiguration(name="SSHFinger", options={}),
{"name": "SMBFinger", "options": {}}, ]
{"name": "SSHFinger", "options": {}}, scan_config = NetworkScanConfiguration(
], tcp_config,
} icmp_config,
fingerprinter_config,
default_agent_configuration.propagation.network_scan.targets,
)
return scan_config
@pytest.fixture @pytest.fixture

View File

@ -3,6 +3,11 @@ from unittest.mock import MagicMock
import pytest import pytest
from common.configuration.agent_sub_configurations import (
NetworkScanConfiguration,
PropagationConfiguration,
ScanTargetConfiguration,
)
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
FingerprintData, FingerprintData,
@ -135,24 +140,37 @@ class StubExploiter:
pass pass
def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): def get_propagation_config(
default_agent_configuration, scan_target_config: ScanTargetConfiguration
):
network_scan = NetworkScanConfiguration(
default_agent_configuration.propagation.network_scan.tcp,
default_agent_configuration.propagation.network_scan.icmp,
default_agent_configuration.propagation.network_scan.fingerprinters,
scan_target_config,
)
propagation_config = PropagationConfiguration(
default_agent_configuration.propagation.maximum_depth,
network_scan,
default_agent_configuration.propagation.exploitation,
)
return propagation_config
def test_scan_result_processing(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
):
p = Propagator( p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, [] telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, []
) )
p.propagate( targets = ScanTargetConfiguration(
{ blocked_ips=[],
"targets": { inaccessible_subnets=[],
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"], local_network_scan=False,
"local_network_scan": False, subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"inaccessible_subnets": [],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since StubExploiter ignores it
},
1,
Event(),
) )
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
assert len(telemetry_messenger_spy.telemetries) == 3 assert len(telemetry_messenger_spy.telemetries) == 3
@ -237,25 +255,20 @@ class MockExploiter:
def test_exploiter_result_processing( def test_exploiter_result_processing(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
): ):
p = Propagator( p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, [] telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, []
) )
p.propagate(
{ targets = ScanTargetConfiguration(
"targets": { blocked_ips=[],
"subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"], inaccessible_subnets=[],
"local_network_scan": False, local_network_scan=False,
"inaccessible_subnets": [], subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"],
"blocked_ips": [],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
1,
Event(),
) )
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)] exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)]
assert len(exploit_telems) == 4 assert len(exploit_telems) == 4
@ -278,7 +291,9 @@ def test_exploiter_result_processing(
assert data["propagation_result"] assert data["propagation_result"]
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): def test_scan_target_generation(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration
):
local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")] local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")]
p = Propagator( p = Propagator(
telemetry_messenger_spy, telemetry_messenger_spy,
@ -287,20 +302,15 @@ def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_v
mock_victim_host_factory, mock_victim_host_factory,
local_network_interfaces, local_network_interfaces,
) )
p.propagate( targets = ScanTargetConfiguration(
{ blocked_ips=["10.0.0.3"],
"targets": { inaccessible_subnets=["10.0.0.128/30", "10.0.0.8/29"],
"subnet_scan_list": ["10.0.0.0/29", "172.10.20.30"], local_network_scan=True,
"local_network_scan": True, subnets=["10.0.0.0/29", "172.10.20.30"],
"blocked_ips": ["10.0.0.3"],
"inaccessible_subnets": ["10.0.0.128/30", "10.0.0.8/29"],
},
"network_scan": {}, # This is empty since MockIPscanner ignores it
"exploiters": {}, # This is empty since MockExploiter ignores it
},
1,
Event(),
) )
propagation_config = get_propagation_config(default_agent_configuration, targets)
p.propagate(propagation_config, 1, Event())
expected_ip_scan_list = [ expected_ip_scan_list = [
"10.0.0.0", "10.0.0.0",
"10.0.0.1", "10.0.0.1",

View File

@ -1,32 +1,22 @@
from infection_monkey.utils.propagation import should_propagate from infection_monkey.utils.propagation import maximum_depth_reached
def get_config(max_depth): def test_maximum_depth_reached__current_less_than_max():
return {"config": {"depth": max_depth}} maximum_depth = 2
def test_should_propagate_current_less_than_max():
max_depth = 2
current_depth = 1 current_depth = 1
config = get_config(max_depth) assert maximum_depth_reached(maximum_depth, current_depth) is True
assert should_propagate(config, current_depth) is True
def test_should_propagate_current_greater_than_max(): def test_maximum_depth_reached__current_greater_than_max():
max_depth = 2 maximum_depth = 2
current_depth = 3 current_depth = 3
config = get_config(max_depth) assert maximum_depth_reached(maximum_depth, current_depth) is False
assert should_propagate(config, current_depth) is False
def test_should_propagate_current_equal_to_max(): def test_maximum_depth_reached__current_equal_to_max():
max_depth = 2 maximum_depth = 2
current_depth = max_depth current_depth = maximum_depth
config = get_config(max_depth) assert maximum_depth_reached(maximum_depth, current_depth) is False
assert should_propagate(config, current_depth) is False

View File

@ -2,7 +2,7 @@ import pytest
from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository
from common.configuration import AgentConfigurationSchema from common.configuration import AgentConfiguration
from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError
@ -12,8 +12,7 @@ def repository(default_agent_configuration):
def test_store_agent_config(repository): def test_store_agent_config(repository):
schema = AgentConfigurationSchema() agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
agent_configuration = schema.load(AGENT_CONFIGURATION)
repository.store_configuration(agent_configuration) repository.store_configuration(agent_configuration)
retrieved_agent_configuration = repository.get_configuration() retrieved_agent_configuration = repository.get_configuration()

View File

@ -99,7 +99,7 @@ def test_format_config_for_agent__propagation():
def test_format_config_for_agent__network_scan(): def test_format_config_for_agent__network_scan():
expected_network_scan_config = { expected_network_scan_config = {
"tcp": { "tcp": {
"timeout": 3000, "timeout": 3.0,
"ports": [ "ports": [
22, 22,
80, 80,
@ -117,7 +117,7 @@ def test_format_config_for_agent__network_scan():
], ],
}, },
"icmp": { "icmp": {
"timeout": 1000, "timeout": 1.0,
}, },
"targets": { "targets": {
"blocked_ips": ["192.168.1.1", "192.168.1.100"], "blocked_ips": ["192.168.1.1", "192.168.1.100"],