Merge branch '1928-aws-service-refactor' into develop

Resolves #1928
This commit is contained in:
Mike Salvatore 2022-05-12 07:23:47 -04:00
commit 77702fcfbd
54 changed files with 1449 additions and 1102 deletions

View File

@ -1 +1 @@
from .di_container import DIContainer from .di_container import DIContainer, UnregisteredTypeError

View File

@ -0,0 +1 @@
from .aws_instance import AWSInstance

View File

@ -1,117 +1,52 @@
import json import threading
import logging
import re
from dataclasses import dataclass
from typing import Optional, Tuple
import requests from .aws_metadata import fetch_aws_instance_metadata
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" AWS_FETCH_METADATA_TIMEOUT = 10.0 # Seconds
AWS_LATEST_METADATA_URI_PREFIX = "http://{0}/latest/".format(AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS)
ACCOUNT_ID_KEY = "accountId"
logger = logging.getLogger(__name__)
AWS_TIMEOUT = 2
@dataclass class AWSTimeoutError(Exception):
class AwsInstanceInfo: """Raised when communications with AWS timeout"""
instance_id: Optional[str] = None
region: Optional[str] = None
account_id: Optional[str] = None
class AwsInstance: class AWSInstance:
""" """
Class which gives useful information about the current instance you're on. Class which gives useful information about the current instance you're on.
""" """
def __init__(self): def __init__(self):
self._is_instance, self._instance_info = AwsInstance._fetch_instance_info() self._instance_id = None
self._region = None
self._account_id = None
self._initialization_complete = threading.Event()
fetch_thread = threading.Thread(target=self._fetch_aws_instance_metadata, daemon=True)
fetch_thread.start()
def _fetch_aws_instance_metadata(self):
(self._instance_id, self._region, self._account_id) = fetch_aws_instance_metadata()
self._initialization_complete.set()
@property @property
def is_instance(self) -> bool: def is_instance(self) -> bool:
return self._is_instance self._wait_for_initialization_to_complete()
return bool(self._instance_id)
@property @property
def instance_id(self) -> str: def instance_id(self) -> str:
return self._instance_info.instance_id self._wait_for_initialization_to_complete()
return self._instance_id
@property @property
def region(self) -> str: def region(self) -> str:
return self._instance_info.region self._wait_for_initialization_to_complete()
return self._region
@property @property
def account_id(self) -> str: def account_id(self) -> str:
return self._instance_info.account_id self._wait_for_initialization_to_complete()
return self._account_id
@staticmethod def _wait_for_initialization_to_complete(self):
def _fetch_instance_info() -> Tuple[bool, AwsInstanceInfo]: if not self._initialization_complete.wait(AWS_FETCH_METADATA_TIMEOUT):
try: raise AWSTimeoutError("Timed out while attempting to retrieve metadata from AWS")
response = requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id",
timeout=AWS_TIMEOUT,
)
if not response:
return False, AwsInstanceInfo()
info = AwsInstanceInfo()
info.instance_id = response.text if response else False
info.region = AwsInstance._parse_region(
requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone",
timeout=AWS_TIMEOUT,
).text
)
except (requests.RequestException, IOError) as e:
logger.debug("Failed init of AwsInstance while getting metadata: {}".format(e))
return False, AwsInstanceInfo()
try:
info.account_id = AwsInstance._extract_account_id(
requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document",
timeout=AWS_TIMEOUT,
).text
)
except (requests.RequestException, json.decoder.JSONDecodeError, IOError) as e:
logger.debug(
"Failed init of AwsInstance while getting dynamic instance data: {}".format(e)
)
return False, AwsInstanceInfo()
return True, info
@staticmethod
def _parse_region(region_url_response):
# For a list of regions, see:
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts
# .RegionsAndAvailabilityZones.html
# This regex will find any AWS region format string in the response.
re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])"
finding = re.findall(re_phrase, region_url_response, re.IGNORECASE)
if finding:
return finding[0]
else:
return None
@staticmethod
def _extract_account_id(instance_identity_document_response):
"""
Extracts the account id from the dynamic/instance-identity/document metadata path.
Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more
solutions,
in case Amazon break this mechanism.
:param instance_identity_document_response: json returned via the web page
../dynamic/instance-identity/document
:return: The account id
"""
return json.loads(instance_identity_document_response)[ACCOUNT_ID_KEY]
def get_account_id(self):
"""
:return: the AWS account ID which "owns" this instance.
See https://docs.aws.amazon.com/general/latest/gr/acct-identifiers.html
"""
return self.account_id

View File

@ -0,0 +1,86 @@
import json
import logging
import re
from typing import Optional, Tuple
import requests
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254"
AWS_LATEST_METADATA_URI_PREFIX = f"http://{AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS}/latest/"
ACCOUNT_ID_KEY = "accountId"
logger = logging.getLogger(__name__)
AWS_TIMEOUT = 2
def fetch_aws_instance_metadata() -> Tuple[Optional[str], Optional[str], Optional[str]]:
instance_id = None
region = None
account_id = None
try:
instance_id = _fetch_aws_instance_id()
region = _fetch_aws_region()
account_id = _fetch_account_id()
except (
requests.RequestException,
IOError,
json.decoder.JSONDecodeError,
) as err:
logger.debug(f"Failed init of AWSInstance while getting metadata: {err}")
return (None, None, None)
return (instance_id, region, account_id)
def _fetch_aws_instance_id() -> Optional[str]:
url = AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id"
response = requests.get(
url,
timeout=AWS_TIMEOUT,
)
response.raise_for_status()
return response.text
def _fetch_aws_region() -> Optional[str]:
response = requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone",
timeout=AWS_TIMEOUT,
)
response.raise_for_status()
return _parse_region(response.text)
def _parse_region(region_url_response: str) -> Optional[str]:
# For a list of regions, see:
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts
# .RegionsAndAvailabilityZones.html
# This regex will find any AWS region format string in the response.
re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])"
finding = re.findall(re_phrase, region_url_response, re.IGNORECASE)
if finding:
return finding[0]
else:
return None
def _fetch_account_id() -> str:
"""
Fetches and extracts the account id from the dynamic/instance-identity/document metadata path.
Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more
solutions, in case Amazon break this mechanism.
:param instance_identity_document_response: json returned via the web page
../dynamic/instance-identity/document
:return: The account id
"""
response = requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document",
timeout=AWS_TIMEOUT,
)
response.raise_for_status()
return json.loads(response.text)[ACCOUNT_ID_KEY]

View File

@ -1,8 +0,0 @@
class Cmd(object):
"""
Class representing a command
"""
def __init__(self, cmd_runner, cmd_id):
self.cmd_runner = cmd_runner
self.cmd_id = cmd_id

View File

@ -1,10 +0,0 @@
class CmdResult(object):
"""
Class representing a command result
"""
def __init__(self, is_success, status_code=None, stdout=None, stderr=None):
self.is_success = is_success
self.status_code = status_code
self.stdout = stdout
self.stderr = stderr

View File

@ -1,154 +0,0 @@
import logging
import time
from abc import abstractmethod
from common.cmd.cmd_result import CmdResult
from common.cmd.cmd_status import CmdStatus
logger = logging.getLogger(__name__)
class CmdRunner(object):
"""
Interface for running commands on a remote machine
Since these classes are a bit complex, I provide a list of common terminology and formats:
* command line - a command line. e.g. 'echo hello'
* command - represent a single command which was already run. Always of type Cmd
* command id - any unique identifier of a command which was already run
* command result - represents the result of running a command. Always of type CmdResult
* command status - represents the current status of a command. Always of type CmdStatus
* command info - Any consistent structure representing additional information of a command
which was already run
* instance - a machine that commands will be run on. Can be any dictionary with 'instance_id'
as a field
* instance_id - any unique identifier of an instance (machine). Can be of any format
"""
# Default command timeout in seconds
DEFAULT_TIMEOUT = 5
# Time to sleep when waiting on commands.
WAIT_SLEEP_TIME = 1
def __init__(self, is_linux):
self.is_linux = is_linux
@staticmethod
def run_multiple_commands(instances, inst_to_cmd, inst_n_cmd_res_to_res):
"""
Run multiple commands on various instances
:param instances: List of instances.
:param inst_to_cmd: Function which receives an instance, runs a command asynchronously
and returns Cmd
:param inst_n_cmd_res_to_res: Function which receives an instance and CmdResult
and returns a parsed result (of any format)
:return: Dictionary with 'instance_id' as key and parsed result as value
"""
command_instance_dict = {}
for instance in instances:
command = inst_to_cmd(instance)
command_instance_dict[command] = instance
instance_results = {}
command_result_pairs = CmdRunner.wait_commands(list(command_instance_dict.keys()))
for command, result in command_result_pairs:
instance = command_instance_dict[command]
instance_results[instance["instance_id"]] = inst_n_cmd_res_to_res(instance, result)
return instance_results
@abstractmethod
def run_command_async(self, command_line):
"""
Runs the given command on the remote machine asynchronously.
:param command_line: The command line to run
:return: Command ID (in any format)
"""
raise NotImplementedError()
@staticmethod
def wait_commands(commands, timeout=DEFAULT_TIMEOUT):
"""
Waits on all commands up to given timeout
:param commands: list of commands (of type Cmd)
:param timeout: Timeout in seconds for command.
:return: commands and their results (tuple of Command and CmdResult)
"""
init_time = time.time()
curr_time = init_time
results = []
# TODO: Use timer.Timer
while (curr_time - init_time < timeout) and (len(commands) != 0):
for command in list(
commands
): # list(commands) clones the list. We do so because we remove items inside
CmdRunner._process_command(command, commands, results, True)
time.sleep(CmdRunner.WAIT_SLEEP_TIME)
curr_time = time.time()
for command in list(commands):
CmdRunner._process_command(command, commands, results, False)
for command, result in results:
if not result.is_success:
logger.error(
f"The command with id: {str(command.cmd_id)} failed. "
f"Status code: {str(result.status_code)}"
)
return results
@abstractmethod
def query_command(self, command_id):
"""
Queries the already run command for more info
:param command_id: The command ID to query
:return: Command info (in any format)
"""
raise NotImplementedError()
@abstractmethod
def get_command_result(self, command_info):
"""
Gets the result of the already run command
:param command_info: The command info of the command to get the result of
:return: CmdResult
"""
raise NotImplementedError()
@abstractmethod
def get_command_status(self, command_info):
"""
Gets the status of the already run command
:param command_info: The command info of the command to get the result of
:return: CmdStatus
"""
raise NotImplementedError()
@staticmethod
def _process_command(command, commands, results, should_process_only_finished):
"""
Removes the command from the list, processes its result and appends to results
:param command: Command to process. Must be in commands.
:param commands: List of unprocessed commands.
:param results: List of command results.
:param should_process_only_finished: If True, processes only if command finished.
:return: None
"""
c_runner = command.cmd_runner
c_id = command.cmd_id
try:
command_info = c_runner.query_command(c_id)
if (not should_process_only_finished) or c_runner.get_command_status(
command_info
) != CmdStatus.IN_PROGRESS:
commands.remove(command)
results.append((command, c_runner.get_command_result(command_info)))
except Exception:
logger.exception("Exception while querying command: `%s`", str(c_id))
if not should_process_only_finished:
commands.remove(command)
results.append((command, CmdResult(False)))

View File

@ -1,7 +0,0 @@
from enum import Enum
class CmdStatus(Enum):
IN_PROGRESS = 0
SUCCESS = 1
FAILURE = 2

View File

@ -0,0 +1 @@
from .timer import Timer

View File

@ -1,3 +1,7 @@
import queue
from typing import Any, List
class abstractstatic(staticmethod): class abstractstatic(staticmethod):
__slots__ = () __slots__ = ()
@ -15,3 +19,14 @@ class Singleton(type):
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls] return cls._instances[cls]
def queue_to_list(q: queue.Queue) -> List[Any]:
list_ = []
try:
while True:
list_.append(q.get_nowait())
except queue.Empty:
pass
return list_

View File

@ -3,6 +3,7 @@ import time
from pathlib import PurePath from pathlib import PurePath
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.utils import Timer
from infection_monkey.exploit.log4shell_utils import ( from infection_monkey.exploit.log4shell_utils import (
LINUX_EXPLOIT_TEMPLATE_PATH, LINUX_EXPLOIT_TEMPLATE_PATH,
WINDOWS_EXPLOIT_TEMPLATE_PATH, WINDOWS_EXPLOIT_TEMPLATE_PATH,
@ -21,7 +22,6 @@ from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
from infection_monkey.utils.monkey_dir import get_monkey_dir_path from infection_monkey.utils.monkey_dir import get_monkey_dir_path
from infection_monkey.utils.threading import interruptible_iter from infection_monkey.utils.threading import interruptible_iter
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -5,6 +5,7 @@ from pathlib import PurePath
import paramiko import paramiko
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.utils import Timer
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
from common.utils.exceptions import FailedExploitationError from common.utils.exceptions import FailedExploitationError
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
@ -17,7 +18,6 @@ from infection_monkey.telemetry.attack.t1222_telem import T1222Telem
from infection_monkey.utils.brute_force import generate_identity_secret_pairs from infection_monkey.utils.brute_force import generate_identity_secret_pairs
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
from infection_monkey.utils.threading import interruptible_iter from infection_monkey.utils.threading import interruptible_iter
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SSH_PORT = 22 SSH_PORT = 22

View File

@ -3,6 +3,7 @@ import threading
import time import time
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from common.utils import Timer
from infection_monkey.credential_store import ICredentialsStore from infection_monkey.credential_store import ICredentialsStore
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
@ -13,7 +14,6 @@ from infection_monkey.telemetry.credentials_telem import CredentialsTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter
from infection_monkey.utils.timer import Timer
from . import Exploiter, IPScanner, Propagator from . import Exploiter, IPScanner, Propagator
from .option_parsing import custom_pba_is_enabled from .option_parsing import custom_pba_is_enabled

View File

@ -4,9 +4,9 @@ import socket
import time import time
from typing import Iterable, Mapping, Tuple from typing import Iterable, Mapping, Tuple
from common.utils import Timer
from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,10 +2,10 @@ import queue
import threading import threading
from typing import Dict from typing import Dict
from common.utils import Timer
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.timer import Timer
DEFAULT_PERIOD = 5 DEFAULT_PERIOD = 5
WAKES_PER_PERIOD = 4 WAKES_PER_PERIOD = 4

View File

@ -4,11 +4,11 @@ import struct
import time import time
from threading import Event, Thread from threading import Event, Thread
from common.utils import Timer
from infection_monkey.network.firewall import app as firewall from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
from infection_monkey.transport.base import get_last_serve_time from infection_monkey.transport.base import get_last_serve_time
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
import logging import logging
from common.aws.aws_instance import AwsInstance from common.aws import AWSInstance
from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry
from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import ( from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import (
LegacyTelemetryMessengerAdapter, LegacyTelemetryMessengerAdapter,
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter): def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter):
logger.info("Collecting AWS info") logger.info("Collecting AWS info")
aws_instance = AwsInstance() aws_instance = AWSInstance()
if aws_instance.is_instance: if aws_instance.is_instance:
logger.info("Machine is an AWS instance") logger.info("Machine is an AWS instance")

View File

@ -1,7 +1,7 @@
import threading import threading
from functools import wraps from functools import wraps
from .timer import Timer from common.utils import Timer
def request_cache(ttl: float): def request_cache(ttl: float):

View File

@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
from common import DIContainer from common import DIContainer
from monkey_island.cc.database import database, mongo from monkey_island.cc.database import database, mongo
from monkey_island.cc.resources import RemoteRun
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
from monkey_island.cc.resources.attack.attack_report import AttackReport from monkey_island.cc.resources.attack.attack_report import AttackReport
from monkey_island.cc.resources.auth.auth import Authenticate, init_jwt from monkey_island.cc.resources.auth.auth import Authenticate, init_jwt
@ -38,7 +39,6 @@ from monkey_island.cc.resources.pba_file_download import PBAFileDownload
from monkey_island.cc.resources.pba_file_upload import FileUpload from monkey_island.cc.resources.pba_file_upload import FileUpload
from monkey_island.cc.resources.propagation_credentials import PropagationCredentials from monkey_island.cc.resources.propagation_credentials import PropagationCredentials
from monkey_island.cc.resources.ransomware_report import RansomwareReport from monkey_island.cc.resources.ransomware_report import RansomwareReport
from monkey_island.cc.resources.remote_run import RemoteRun
from monkey_island.cc.resources.root import Root from monkey_island.cc.resources.root import Root
from monkey_island.cc.resources.security_report import SecurityReport from monkey_island.cc.resources.security_report import SecurityReport
from monkey_island.cc.resources.telemetry import Telemetry from monkey_island.cc.resources.telemetry import Telemetry

View File

@ -0,0 +1 @@
from .remote_run import RemoteRun

View File

@ -1,12 +1,13 @@
import json import json
from typing import Sequence
import flask_restful import flask_restful
from botocore.exceptions import ClientError, NoCredentialsError from botocore.exceptions import ClientError, NoCredentialsError
from flask import jsonify, make_response, request from flask import jsonify, make_response, request
from monkey_island.cc.resources.auth.auth import jwt_required from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services import aws_service from monkey_island.cc.services import AWSService
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService from monkey_island.cc.services.aws import AWSCommandResults
CLIENT_ERROR_FORMAT = ( CLIENT_ERROR_FORMAT = (
"ClientError, error message: '{}'. Probably, the IAM role that has been associated with the " "ClientError, error message: '{}'. Probably, the IAM role that has been associated with the "
@ -19,20 +20,18 @@ NO_CREDS_ERROR_FORMAT = (
class RemoteRun(flask_restful.Resource): class RemoteRun(flask_restful.Resource):
def run_aws_monkeys(self, request_body): def __init__(self, aws_service: AWSService):
instances = request_body.get("instances") self._aws_service = aws_service
island_ip = request_body.get("island_ip")
return RemoteRunAwsService.run_aws_monkeys(instances, island_ip)
@jwt_required @jwt_required
def get(self): def get(self):
action = request.args.get("action") action = request.args.get("action")
if action == "list_aws": if action == "list_aws":
is_aws = aws_service.is_on_aws() is_aws = self._aws_service.island_is_running_on_aws()
resp = {"is_aws": is_aws} resp = {"is_aws": is_aws}
if is_aws: if is_aws:
try: try:
resp["instances"] = aws_service.get_instances() resp["instances"] = self._aws_service.get_managed_instances()
except NoCredentialsError as e: except NoCredentialsError as e:
resp["error"] = NO_CREDS_ERROR_FORMAT.format(e) resp["error"] = NO_CREDS_ERROR_FORMAT.format(e)
return jsonify(resp) return jsonify(resp)
@ -46,11 +45,28 @@ class RemoteRun(flask_restful.Resource):
@jwt_required @jwt_required
def post(self): def post(self):
body = json.loads(request.data) body = json.loads(request.data)
resp = {}
if body.get("type") == "aws": if body.get("type") == "aws":
result = self.run_aws_monkeys(body) results = self.run_aws_monkeys(body)
resp["result"] = result return RemoteRun._encode_results(results)
return jsonify(resp)
# default action # default action
return make_response({"error": "Invalid action"}, 500) return make_response({"error": "Invalid action"}, 500)
def run_aws_monkeys(self, request_body) -> Sequence[AWSCommandResults]:
instances = request_body.get("instances")
island_ip = request_body.get("island_ip")
return self._aws_service.run_agents_on_managed_instances(instances, island_ip)
@staticmethod
def _encode_results(results: Sequence[AWSCommandResults]):
result = list(map(RemoteRun._aws_command_results_to_encodable_dict, results))
response = {"result": result}
return jsonify(response)
@staticmethod
def _aws_command_results_to_encodable_dict(aws_command_results: AWSCommandResults):
res_dict = aws_command_results.__dict__
res_dict["status"] = res_dict["status"].name.lower()
return res_dict

View File

@ -3,7 +3,6 @@ import json
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
import gevent.hub import gevent.hub
from gevent.pywsgi import WSGIServer from gevent.pywsgi import WSGIServer
@ -29,7 +28,6 @@ from monkey_island.cc.server_utils.consts import ( # noqa: E402
) )
from monkey_island.cc.server_utils.island_logger import reset_logger, setup_logging # noqa: E402 from monkey_island.cc.server_utils.island_logger import reset_logger, setup_logging # noqa: E402
from monkey_island.cc.services.initialize import initialize_services # noqa: E402 from monkey_island.cc.services.initialize import initialize_services # noqa: E402
from monkey_island.cc.services.reporting.exporter_init import populate_exporter_list # noqa: E402
from monkey_island.cc.services.utils.network_utils import local_ip_addresses # noqa: E402 from monkey_island.cc.services.utils.network_utils import local_ip_addresses # noqa: E402
from monkey_island.cc.setup import island_config_options_validator # noqa: E402 from monkey_island.cc.setup import island_config_options_validator # noqa: E402
from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory, setup_data_dir # noqa: E402 from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory, setup_data_dir # noqa: E402
@ -132,8 +130,6 @@ def _configure_gevent_exception_handling(data_dir):
def _start_island_server( def _start_island_server(
should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer
): ):
# AWS exporter takes a long time to load
Thread(target=populate_exporter_list, name="Report exporter list", daemon=True).start()
app = init_app(mongo_setup.MONGO_URL, container) app = init_app(mongo_setup.MONGO_URL, container)
if should_setup_only: if should_setup_only:

View File

@ -1,29 +0,0 @@
from common.cmd.cmd_result import CmdResult
class AwsCmdResult(CmdResult):
"""
Class representing an AWS command result
"""
def __init__(self, command_info):
super(AwsCmdResult, self).__init__(
self.is_successful(command_info, True),
command_info["ResponseCode"],
command_info["StandardOutputContent"],
command_info["StandardErrorContent"],
)
self.command_info = command_info
@staticmethod
def is_successful(command_info, is_timeout=False):
"""
Determines whether the command was successful. If it timed out and was still in progress,
we assume it worked.
:param command_info: Command info struct (returned by ssm.get_command_invocation)
:param is_timeout: Whether the given command timed out
:return: True if successful, False otherwise.
"""
return (command_info["Status"] == "Success") or (
is_timeout and (command_info["Status"] == "InProgress")
)

View File

@ -1,45 +0,0 @@
import logging
import time
from common.cmd.cmd_runner import CmdRunner
from common.cmd.cmd_status import CmdStatus
from monkey_island.cc.server_utils.aws_cmd_result import AwsCmdResult
from monkey_island.cc.services import aws_service
logger = logging.getLogger(__name__)
class AwsCmdRunner(CmdRunner):
"""
Class for running commands on a remote AWS machine
"""
def __init__(self, is_linux, instance_id, region=None):
super(AwsCmdRunner, self).__init__(is_linux)
self.instance_id = instance_id
self.region = region
self.ssm = aws_service.get_client("ssm", region)
def query_command(self, command_id):
time.sleep(2)
return self.ssm.get_command_invocation(CommandId=command_id, InstanceId=self.instance_id)
def get_command_result(self, command_info):
return AwsCmdResult(command_info)
def get_command_status(self, command_info):
if command_info["Status"] == "InProgress":
return CmdStatus.IN_PROGRESS
elif command_info["Status"] == "Success":
return CmdStatus.SUCCESS
else:
return CmdStatus.FAILURE
def run_command_async(self, command_line):
doc_name = "AWS-RunShellScript" if self.is_linux else "AWS-RunPowerShellScript"
command_res = self.ssm.send_command(
DocumentName=doc_name,
Parameters={"commands": [command_line]},
InstanceIds=[self.instance_id],
)
return command_res["Command"]["CommandId"]

View File

@ -3,3 +3,9 @@ from .directory_file_storage_service import DirectoryFileStorageService
from .authentication.authentication_service import AuthenticationService from .authentication.authentication_service import AuthenticationService
from .authentication.json_file_user_datastore import JsonFileUserDatastore from .authentication.json_file_user_datastore import JsonFileUserDatastore
from .aws import AWSService
# TODO: This is a temporary import to keep some tests passing. Remove it before merging #1928 to
# develop.
from .aws import aws_service

View File

@ -0,0 +1,2 @@
from .aws_service import AWSService
from .aws_command_runner import AWSCommandResults, AWSCommandStatus

View File

@ -0,0 +1,149 @@
import logging
import time
from dataclasses import dataclass
from enum import Enum, auto
import botocore
from common.utils import Timer
STATUS_CHECK_SLEEP_TIME = 1
LINUX_DOCUMENT_NAME = "AWS-RunShellScript"
WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript"
logger = logging.getLogger(__name__)
class AWSCommandStatus(Enum):
SUCCESS = auto()
IN_PROGRESS = auto()
ERROR = auto()
@dataclass(frozen=True)
class AWSCommandResults:
instance_id: str
response_code: int
stdout: str
stderr: str
status: AWSCommandStatus
@property
def success(self):
return self.status == AWSCommandStatus.SUCCESS
def start_infection_monkey_agent(
aws_client: botocore.client.BaseClient,
target_instance_id: str,
target_os: str,
island_ip: str,
timeout: float,
) -> AWSCommandResults:
"""
Run a command on a remote AWS instance
"""
command = _get_run_agent_command(target_os, island_ip)
command_id = _run_command_async(aws_client, target_instance_id, target_os, command)
_wait_for_command_to_complete(aws_client, target_instance_id, command_id, timeout)
return _fetch_command_results(aws_client, target_instance_id, command_id)
def _get_run_agent_command(target_os: str, island_ip: str):
if target_os == "linux":
return _get_run_monkey_cmd_linux_line(island_ip)
return _get_run_monkey_cmd_windows_line(island_ip)
def _get_run_monkey_cmd_linux_line(island_ip):
binary_name = "monkey-linux-64"
download_url = f"https://{island_ip}:5000/api/agent/download/linux"
download_cmd = f"wget --no-check-certificate {download_url} -O {binary_name}"
chmod_cmd = f"chmod +x {binary_name}"
run_agent_cmd = f"./{binary_name} m0nk3y -s {island_ip}:5000"
return f"{download_cmd}; {chmod_cmd}; {run_agent_cmd}"
def _get_run_monkey_cmd_windows_line(island_ip):
agent_exe_path = r".\\monkey.exe"
ignore_ssl_errors_cmd = (
"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {$true}"
)
download_url = f"https://{island_ip}:5000/api/agent/download/windows"
download_cmd = (
f"(New-Object System.Net.WebClient).DownloadFile('{download_url}', '{agent_exe_path}')"
)
run_agent_cmd = (
f"Start-Process -FilePath '{agent_exe_path}' -ArgumentList 'm0nk3y -s {island_ip}:5000'"
)
return f"{ignore_ssl_errors_cmd}; {download_cmd}; {run_agent_cmd};"
def _run_command_async(
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, command: str
):
doc_name = LINUX_DOCUMENT_NAME if target_os == "linux" else WINDOWS_DOCUMENT_NAME
logger.debug(f'Running command on {target_instance_id} -- {doc_name}: "{command}"')
command_response = aws_client.send_command(
DocumentName=doc_name,
Parameters={"commands": [command]},
InstanceIds=[target_instance_id],
)
command_id = command_response["Command"]["CommandId"]
logger.debug(
f"Started command on AWS instance {target_instance_id} with command ID {command_id}"
)
return command_id
def _wait_for_command_to_complete(
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str, timeout: float
):
timer = Timer()
timer.set(timeout)
while not timer.is_expired():
time.sleep(STATUS_CHECK_SLEEP_TIME)
command_results = _fetch_command_results(aws_client, target_instance_id, command_id)
logger.debug(f"Command {command_id} status: {command_results.status.name}")
if command_results.status != AWSCommandStatus.IN_PROGRESS:
return
def _fetch_command_results(
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
) -> AWSCommandResults:
command_results = aws_client.get_command_invocation(
CommandId=command_id, InstanceId=target_instance_id
)
command_status = command_results["Status"]
logger.debug(f"Command {command_id} status: {command_status}")
if command_status == "Success":
aws_command_result_status = AWSCommandStatus.SUCCESS
elif command_status == "InProgress":
aws_command_result_status = AWSCommandStatus.IN_PROGRESS
else:
aws_command_result_status = AWSCommandStatus.ERROR
return AWSCommandResults(
target_instance_id,
command_results["ResponseCode"],
command_results["StandardOutputContent"],
command_results["StandardErrorContent"],
aws_command_result_status,
)

View File

@ -0,0 +1,135 @@
import logging
from queue import Queue
from threading import Thread
from typing import Any, Iterable, Mapping, Sequence
import boto3
import botocore
from common.aws.aws_instance import AWSInstance
from common.utils.code_utils import queue_to_list
from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent
DEFAULT_REMOTE_COMMAND_TIMEOUT = 5
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
INSTANCE_ID_KEY = "InstanceId"
COMPUTER_NAME_KEY = "ComputerName"
PLATFORM_TYPE_KEY = "PlatformType"
IP_ADDRESS_KEY = "IPAddress"
logger = logging.getLogger(__name__)
class AWSService:
def __init__(self, aws_instance: AWSInstance):
"""
:param aws_instance: An AWSInstance object representing the AWS instance that the Island is
running on
"""
self._aws_instance = aws_instance
def island_is_running_on_aws(self) -> bool:
"""
:return: True if the island is running on an AWS instance. False otherwise.
:rtype: bool
"""
return self._aws_instance.is_instance
@property
def island_aws_instance(self) -> AWSInstance:
"""
:return: an AWSInstance object representing the AWS instance that the Island is running on.
:rtype: AWSInstance
"""
return self._aws_instance
def get_managed_instances(self) -> Sequence[Mapping[str, str]]:
"""
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
is accessible from the Island.
:rtype: Sequence[Mapping[str, str]]
"""
raw_managed_instances_info = self._get_raw_managed_instances()
return _filter_relevant_instance_info(raw_managed_instances_info)
def _get_raw_managed_instances(self) -> Sequence[Mapping[str, Any]]:
"""
Get the information for all instances with the relevant roles.
This function will assume that it's running on an EC2 instance with the correct IAM role.
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
-role for details.
:raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance
"""
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
try:
response = ssm_client.describe_instance_information()
return response[INSTANCE_INFORMATION_LIST_KEY]
except botocore.exceptions.ClientError as err:
logger.warning("AWS client error while trying to get manage dinstances: {err}")
raise err
def run_agents_on_managed_instances(
self,
instances: Iterable[Mapping[str, str]],
island_ip: str,
timeout: float = DEFAULT_REMOTE_COMMAND_TIMEOUT,
) -> Sequence[AWSCommandResults]:
"""
Run an agent on one or more managed AWS instances.
:param instances: An iterable of instances that the agent will be run on
:param island_ip: The IP address of the Island to pass to the new agents
:param timeout: The maximum number of seconds to wait for the agents to start
:return: A sequence of AWSCommandResults
"""
results_queue = Queue()
command_threads = []
for i in instances:
t = Thread(
target=self._run_agent_on_managed_instance,
args=(results_queue, i["instance_id"], i["os"], island_ip, timeout),
daemon=True,
)
t.start()
command_threads.append(t)
for thread in command_threads:
thread.join()
return queue_to_list(results_queue)
def _run_agent_on_managed_instance(
self, results_queue: Queue, instance_id: str, os: str, island_ip: str, timeout: float
):
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
command_results = start_infection_monkey_agent(
ssm_client, instance_id, os, island_ip, timeout
)
results_queue.put(command_results)
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):
"""
Consume raw instance data from the AWS API and return only those fields that are relevant for
Infection Monkey.
:param raw_managed_instances_info: The output of
DescribeInstanceInformation["InstanceInformation"] from the
AWS API
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
is accessible from the Island.
:rtype: Sequence[Mapping[str, str]]
"""
return [
{
"instance_id": managed_instance[INSTANCE_ID_KEY],
"name": managed_instance[COMPUTER_NAME_KEY],
"os": managed_instance[PLATFORM_TYPE_KEY].lower(),
"ip_address": managed_instance[IP_ADDRESS_KEY],
}
for managed_instance in raw_managed_instances_info
]

View File

@ -1,99 +0,0 @@
import logging
from functools import wraps
from threading import Event
from typing import Callable, Optional
import boto3
import botocore
from common.aws.aws_instance import AwsInstance
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
INSTANCE_ID_KEY = "InstanceId"
COMPUTER_NAME_KEY = "ComputerName"
PLATFORM_TYPE_KEY = "PlatformType"
IP_ADDRESS_KEY = "IPAddress"
logger = logging.getLogger(__name__)
def filter_instance_data_from_aws_response(response):
return [
{
"instance_id": x[INSTANCE_ID_KEY],
"name": x[COMPUTER_NAME_KEY],
"os": x[PLATFORM_TYPE_KEY].lower(),
"ip_address": x[IP_ADDRESS_KEY],
}
for x in response[INSTANCE_INFORMATION_LIST_KEY]
]
aws_instance: Optional[AwsInstance] = None
AWS_INFO_FETCH_TIMEOUT = 10.0 # Seconds
init_done = Event()
def initialize():
global aws_instance
aws_instance = AwsInstance()
init_done.set()
def wait_init_done(fnc: Callable):
@wraps(fnc)
def inner():
awaited = init_done.wait(AWS_INFO_FETCH_TIMEOUT)
if not awaited:
logger.error(
f"AWS service couldn't initialize in time! "
f"Current timeout is {AWS_INFO_FETCH_TIMEOUT}, "
f"but AWS info took longer to fetch from metadata server."
)
return
fnc()
return inner
@wait_init_done
def is_on_aws():
return aws_instance.is_instance
@wait_init_done
def get_region():
return aws_instance.region
@wait_init_done
def get_account_id():
return aws_instance.account_id
@wait_init_done
def get_client(client_type):
return boto3.client(client_type, region_name=aws_instance.region)
@wait_init_done
def get_instances():
"""
Get the information for all instances with the relevant roles.
This function will assume that it's running on an EC2 instance with the correct IAM role.
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
-role for details.
:raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance
"""
local_ssm_client = boto3.client("ssm", aws_instance.region)
try:
response = local_ssm_client.describe_instance_information()
filtered_instances_data = filter_instance_data_from_aws_response(response)
return filtered_instances_data
except botocore.exceptions.ClientError as e:
logger.warning("AWS client error while trying to get instances: " + e)
raise e

View File

@ -1,26 +1,28 @@
from pathlib import Path from pathlib import Path
from threading import Thread
from common import DIContainer from common import DIContainer
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService, aws_service from common.aws import AWSInstance
from monkey_island.cc.services import AWSService, DirectoryFileStorageService, IFileStorageService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService from monkey_island.cc.services.post_breach_files import PostBreachFilesService
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
from . import AuthenticationService, JsonFileUserDatastore from . import AuthenticationService, JsonFileUserDatastore
from .reporting.report import ReportService
def initialize_services(data_dir: Path) -> DIContainer: def initialize_services(data_dir: Path) -> DIContainer:
container = DIContainer() container = DIContainer()
container.register_instance(AWSInstance, AWSInstance())
container.register_instance( container.register_instance(
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas") IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
) )
container.register_instance(AWSService, container.resolve(AWSService))
# Takes a while so it's best to start it in the background
Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start()
# This is temporary until we get DI all worked out. # This is temporary until we get DI all worked out.
PostBreachFilesService.initialize(container.resolve(IFileStorageService)) PostBreachFilesService.initialize(container.resolve(IFileStorageService))
LocalMonkeyRunService.initialize(data_dir) LocalMonkeyRunService.initialize(data_dir)
AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir)) AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir))
ReportService.initialize(container.resolve(AWSService))
return container return container

View File

@ -1,82 +0,0 @@
import logging
from common.cmd.cmd import Cmd
from common.cmd.cmd_runner import CmdRunner
from monkey_island.cc.server_utils.aws_cmd_runner import AwsCmdRunner
logger = logging.getLogger(__name__)
class RemoteRunAwsService:
@staticmethod
def run_aws_monkeys(instances, island_ip):
"""
Runs monkeys on the given instances
:param instances: List of instances to run on
:param island_ip: IP of island the monkey will communicate with
:return: Dictionary with instance ids as keys, and True/False as values if succeeded or not
"""
return CmdRunner.run_multiple_commands(
instances,
lambda instance: RemoteRunAwsService._run_aws_monkey_cmd_async(
instance["instance_id"],
RemoteRunAwsService._is_linux(instance["os"]),
island_ip,
),
lambda _, result: result.is_success,
)
@staticmethod
def _run_aws_monkey_cmd_async(instance_id, is_linux, island_ip):
"""
Runs a monkey remotely using AWS
:param instance_id: Instance ID of target
:param is_linux: Whether target is linux
:param island_ip: IP of the island which the instance will try to connect to
:return: Cmd
"""
cmd_text = RemoteRunAwsService._get_run_monkey_cmd_line(is_linux, island_ip)
return RemoteRunAwsService._run_aws_cmd_async(instance_id, is_linux, cmd_text)
@staticmethod
def _run_aws_cmd_async(instance_id, is_linux, cmd_line):
cmd_runner = AwsCmdRunner(is_linux, instance_id)
return Cmd(cmd_runner, cmd_runner.run_command_async(cmd_line))
@staticmethod
def _is_linux(os):
return "linux" == os
@staticmethod
def _get_run_monkey_cmd_linux_line(island_ip):
return (
r"wget --no-check-certificate https://"
+ island_ip
+ r":5000/api/agent/download/linux "
+ r"-O monkey-linux-64"
+ r"; chmod +x monkey-linux-64"
+ r"; ./monkey-linux-64"
+ r" m0nk3y -s "
+ island_ip
+ r":5000"
)
@staticmethod
def _get_run_monkey_cmd_windows_line(island_ip):
return (
r"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {"
r"$true}; (New-Object System.Net.WebClient).DownloadFile('https://"
+ island_ip
+ r":5000/api/agent/download/windows'"
+ r"'.\\monkey.exe'); "
r";Start-Process -FilePath '.\\monkey.exe' "
r"-ArgumentList 'm0nk3y -s " + island_ip + r":5000'; "
)
@staticmethod
def _get_run_monkey_cmd_line(is_linux, island_ip):
return (
RemoteRunAwsService._get_run_monkey_cmd_linux_line(island_ip)
if is_linux
else RemoteRunAwsService._get_run_monkey_cmd_windows_line(island_ip)
)

View File

@ -1,21 +1,15 @@
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Mapping
import boto3 import boto3
from botocore.exceptions import UnknownServiceError from botocore.exceptions import UnknownServiceError
from monkey_island.cc.services import aws_service from common.aws import AWSInstance
from monkey_island.cc.services.reporting.exporter import Exporter
__authors__ = ["maor.rayzin", "shay.nehmad"]
from monkey_island.cc.services.reporting.issue_processing.exploit_processing.exploiter_descriptor_enum import ( # noqa:E501 (Long import) from monkey_island.cc.services.reporting.issue_processing.exploit_processing.exploiter_descriptor_enum import ( # noqa:E501 (Long import)
ExploiterDescriptorEnum, ExploiterDescriptorEnum,
) )
# noqa:E501 (Long import)
from monkey_island.cc.services.reporting.issue_processing.exploit_processing.exploiter_report_info import ( # noqa:E501 (Long import) from monkey_island.cc.services.reporting.issue_processing.exploit_processing.exploiter_report_info import ( # noqa:E501 (Long import)
CredentialType, CredentialType,
) )
@ -25,376 +19,351 @@ logger = logging.getLogger(__name__)
INFECTION_MONKEY_ARN = "324264561773:product/guardicore/aws-infection-monkey" INFECTION_MONKEY_ARN = "324264561773:product/guardicore/aws-infection-monkey"
class AWSExporter(Exporter): def handle_report(report_json: Mapping, aws_instance: AWSInstance):
@staticmethod findings_list = []
def handle_report(report_json): issues_list = report_json["recommendations"]["issues"]
if not issues_list:
findings_list = [] logger.info("No issues were found by the monkey, no need to send anything")
issues_list = report_json["recommendations"]["issues"]
if not issues_list:
logger.info("No issues were found by the monkey, no need to send anything")
return True
current_aws_region = aws_service.get_region()
for machine in issues_list:
for issue in issues_list[machine]:
try:
if "aws_instance_id" in issue:
findings_list.append(
AWSExporter._prepare_finding(issue, current_aws_region)
)
except AWSExporter.FindingNotFoundError as e:
logger.error(e)
if not AWSExporter._send_findings(findings_list, current_aws_region):
logger.error("Exporting findings to aws failed")
return False
return True return True
@staticmethod for machine in issues_list:
def merge_two_dicts(x, y): for issue in issues_list[machine]:
z = x.copy() # start with x's keys and values try:
z.update(y) # modifies z with y's keys and values & returns None if "aws_instance_id" in issue:
return z findings_list.append(_prepare_finding(issue, aws_instance))
except FindingNotFoundError as e:
logger.error(e)
@staticmethod if not _send_findings(findings_list, aws_instance.region):
def _prepare_finding(issue, region): logger.error("Exporting findings to aws failed")
findings_dict = { return False
"island_cross_segment": AWSExporter._handle_island_cross_segment_issue,
ExploiterDescriptorEnum.SSH.value.class_name: {
CredentialType.PASSWORD.value: AWSExporter._handle_ssh_issue,
CredentialType.KEY.value: AWSExporter._handle_ssh_key_issue,
},
"tunnel": AWSExporter._handle_tunnel_issue,
ExploiterDescriptorEnum.SMB.value.class_name: {
CredentialType.PASSWORD.value: AWSExporter._handle_smb_password_issue,
CredentialType.HASH.value: AWSExporter._handle_smb_pth_issue,
},
"shared_passwords": AWSExporter._handle_shared_passwords_issue,
ExploiterDescriptorEnum.WMI.value.class_name: {
CredentialType.PASSWORD.value: AWSExporter._handle_wmi_password_issue,
CredentialType.HASH.value: AWSExporter._handle_wmi_pth_issue,
},
"shared_passwords_domain": AWSExporter._handle_shared_passwords_domain_issue,
"shared_admins_domain": AWSExporter._handle_shared_admins_domain_issue,
"strong_users_on_crit": AWSExporter._handle_strong_users_on_crit_issue,
ExploiterDescriptorEnum.HADOOP.value.class_name: AWSExporter._handle_hadoop_issue,
}
configured_product_arn = INFECTION_MONKEY_ARN return True
product_arn = "arn:aws:securityhub:{region}:{arn}".format(
region=region, arn=configured_product_arn
)
instance_arn = "arn:aws:ec2:" + str(region) + ":instance:{instance_id}"
# Not suppressing error here on purpose.
account_id = aws_service.get_account_id()
logger.debug("aws account id acquired: {}".format(account_id))
aws_finding = {
"SchemaVersion": "2018-10-08",
"Id": uuid.uuid4().hex,
"ProductArn": product_arn,
"GeneratorId": issue["type"],
"AwsAccountId": account_id,
"RecordState": "ACTIVE",
"Types": ["Software and Configuration Checks/Vulnerabilities/CVE"],
"CreatedAt": datetime.now().isoformat() + "Z",
"UpdatedAt": datetime.now().isoformat() + "Z",
}
processor = AWSExporter._get_issue_processor(findings_dict, issue) def merge_two_dicts(x, y):
z = x.copy() # start with x's keys and values
z.update(y) # modifies z with y's keys and values & returns None
return z
return AWSExporter.merge_two_dicts(aws_finding, processor(issue, instance_arn))
@staticmethod def _prepare_finding(issue, aws_instance: AWSInstance):
def _get_issue_processor(finding_dict, issue): findings_dict = {
try: "island_cross_segment": _handle_island_cross_segment_issue,
processor = finding_dict[issue["type"]] ExploiterDescriptorEnum.SSH.value.class_name: {
if type(processor) == dict: CredentialType.PASSWORD.value: _handle_ssh_issue,
processor = processor[issue["credential_type"]] CredentialType.KEY.value: _handle_ssh_key_issue,
return processor },
except KeyError: "tunnel": _handle_tunnel_issue,
raise AWSExporter.FindingNotFoundError( ExploiterDescriptorEnum.SMB.value.class_name: {
f"Finding {issue['type']} not added as AWS exportable finding" CredentialType.PASSWORD.value: _handle_smb_password_issue,
) CredentialType.HASH.value: _handle_smb_pth_issue,
},
"shared_passwords": _handle_shared_passwords_issue,
ExploiterDescriptorEnum.WMI.value.class_name: {
CredentialType.PASSWORD.value: _handle_wmi_password_issue,
CredentialType.HASH.value: _handle_wmi_pth_issue,
},
"shared_passwords_domain": _handle_shared_passwords_domain_issue,
"shared_admins_domain": _handle_shared_admins_domain_issue,
"strong_users_on_crit": _handle_strong_users_on_crit_issue,
ExploiterDescriptorEnum.HADOOP.value.class_name: _handle_hadoop_issue,
}
class FindingNotFoundError(Exception): configured_product_arn = INFECTION_MONKEY_ARN
pass product_arn = "arn:aws:securityhub:{region}:{arn}".format(
region=aws_instance.region, arn=configured_product_arn
)
instance_arn = "arn:aws:ec2:" + str(aws_instance.region) + ":instance:{instance_id}"
account_id = aws_instance.account_id
logger.debug("aws account id acquired: {}".format(account_id))
@staticmethod aws_finding = {
def _send_findings(findings_list, region): "SchemaVersion": "2018-10-08",
try: "Id": uuid.uuid4().hex,
logger.debug("Trying to acquire securityhub boto3 client in " + region) "ProductArn": product_arn,
security_hub_client = boto3.client("securityhub", region_name=region) "GeneratorId": issue["type"],
logger.debug("Client acquired: {0}".format(repr(security_hub_client))) "AwsAccountId": account_id,
"RecordState": "ACTIVE",
"Types": ["Software and Configuration Checks/Vulnerabilities/CVE"],
"CreatedAt": datetime.now().isoformat() + "Z",
"UpdatedAt": datetime.now().isoformat() + "Z",
}
# Assumes the machine has the correct IAM role to do this, @see processor = _get_issue_processor(findings_dict, issue)
# https://github.com/guardicore/monkey/wiki/Monkey-Island:-Running-the-monkey-on-AWS
# -EC2-instances
import_response = security_hub_client.batch_import_findings(Findings=findings_list)
logger.debug("Import findings response: {0}".format(repr(import_response)))
if import_response["ResponseMetadata"]["HTTPStatusCode"] == 200: return merge_two_dicts(aws_finding, processor(issue, instance_arn))
return True
else:
return False
except UnknownServiceError as e:
logger.warning(
"AWS exporter called but AWS-CLI security hub service is not installed. "
"Error: {}".format(e)
)
return False
except Exception as e:
logger.exception("AWS security hub findings failed to send. Error: {}".format(e))
return False
@staticmethod
def _get_finding_resource(instance_id, instance_arn): def _get_issue_processor(finding_dict, issue):
if instance_id: try:
return [{"Type": "AwsEc2Instance", "Id": instance_arn.format(instance_id=instance_id)}] processor = finding_dict[issue["type"]]
if type(processor) == dict:
processor = processor[issue["credential_type"]]
return processor
except KeyError:
raise FindingNotFoundError(f"Finding {issue['type']} not added as AWS exportable finding")
class FindingNotFoundError(Exception):
pass
def _send_findings(findings_list, region):
try:
logger.debug("Trying to acquire securityhub boto3 client in " + region)
security_hub_client = boto3.client("securityhub", region_name=region)
logger.debug("Client acquired: {0}".format(repr(security_hub_client)))
# Assumes the machine has the correct IAM role to do this, @see
# https://github.com/guardicore/monkey/wiki/Monkey-Island:-Running-the-monkey-on-AWS
# -EC2-instances
import_response = security_hub_client.batch_import_findings(Findings=findings_list)
logger.debug("Import findings response: {0}".format(repr(import_response)))
if import_response["ResponseMetadata"]["HTTPStatusCode"] == 200:
return True
else: else:
return [{"Type": "Other", "Id": "None"}] return False
except UnknownServiceError as e:
@staticmethod logger.warning(
def _build_generic_finding( "AWS exporter called but AWS-CLI security hub service is not installed. "
severity, title, description, recommendation, instance_arn, instance_id=None "Error: {}".format(e)
):
finding = {
"Severity": {"Product": severity, "Normalized": 100},
"Resources": AWSExporter._get_finding_resource(instance_id, instance_arn),
"Title": title,
"Description": description,
"Remediation": {"Recommendation": {"Text": recommendation}},
}
return finding
@staticmethod
def _handle_tunnel_issue(issue, instance_arn):
return AWSExporter._build_generic_finding(
severity=5,
title="Weak segmentation - Machines were able to communicate over unused ports.",
description="Use micro-segmentation policies to disable communication other than "
"the required.",
recommendation="Machines are not locked down at port level. "
"Network tunnel was set up from {0} to {1}".format(issue["machine"], issue["dest"]),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
) )
return False
except Exception as e:
logger.exception("AWS security hub findings failed to send. Error: {}".format(e))
return False
@staticmethod
def _handle_smb_pth_issue(issue, instance_arn):
return AWSExporter._build_generic_finding( def _get_finding_resource(instance_id, instance_arn):
severity=5, if instance_id:
title="Machines are accessible using passwords supplied by the user during the " return [{"Type": "AwsEc2Instance", "Id": instance_arn.format(instance_id=instance_id)}]
"Monkey's configuration.", else:
description="Change {0}'s password to a complex one-use password that is not " return [{"Type": "Other", "Id": "None"}]
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {0}({1}) is vulnerable to a SMB attack. The Monkey "
"used a pass-the-hash attack over "
"SMB protocol with user {2}.".format(
issue["machine"], issue["ip_address"], issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod
def _handle_ssh_issue(issue, instance_arn):
return AWSExporter._build_generic_finding( def _build_generic_finding(
severity=1, severity, title, description, recommendation, instance_arn, instance_id=None
title="Machines are accessible using SSH passwords supplied by the user during " ):
"the Monkey's configuration.", finding = {
description="Change {0}'s password to a complex one-use password that is not " "Severity": {"Product": severity, "Normalized": 100},
"shared with other computers on the " "Resources": _get_finding_resource(instance_id, instance_arn),
"network.".format(issue["username"]), "Title": title,
recommendation="The machine {0} ({1}) is vulnerable to a SSH attack. The Monkey " "Description": description,
"authenticated over the SSH" "Remediation": {"Recommendation": {"Text": recommendation}},
" protocol with user {2} and its " }
"password.".format(issue["machine"], issue["ip_address"], issue["username"]),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod return finding
def _handle_ssh_key_issue(issue, instance_arn):
return AWSExporter._build_generic_finding(
severity=1,
title="Machines are accessible using SSH passwords supplied by the user during "
"the Monkey's configuration.",
description="Protect {ssh_key} private key with a pass phrase.".format(
ssh_key=issue["ssh_key"]
),
recommendation="The machine {machine} ({ip_address}) is vulnerable to a SSH "
"attack. The Monkey authenticated "
"over the SSH protocol with private key {ssh_key}.".format(
machine=issue["machine"], ip_address=issue["ip_address"], ssh_key=issue["ssh_key"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_tunnel_issue(issue, instance_arn):
def _handle_island_cross_segment_issue(issue, instance_arn): return _build_generic_finding(
severity=5,
title="Weak segmentation - Machines were able to communicate over unused ports.",
description="Use micro-segmentation policies to disable communication other than "
"the required.",
recommendation="Machines are not locked down at port level. "
"Network tunnel was set up from {0} to {1}".format(issue["machine"], issue["dest"]),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Weak segmentation - Machines from different segments are able to "
"communicate.",
description="Segment your network and make sure there is no communication between "
"machines from different "
"segments.",
recommendation="The network can probably be segmented. A monkey instance on \
{0} in the networks {1} \
could directly access the Monkey Island server in the networks {2}.".format(
issue["machine"], issue["networks"], issue["server_networks"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_smb_pth_issue(issue, instance_arn):
def _handle_shared_passwords_issue(issue, instance_arn): return _build_generic_finding(
severity=5,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {0}({1}) is vulnerable to a SMB attack. The Monkey "
"used a pass-the-hash attack over "
"SMB protocol with user {2}.".format(
issue["machine"], issue["ip_address"], issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Multiple users have the same password",
description="Some users are sharing passwords, this should be fixed by changing "
"passwords.",
recommendation="These users are sharing access password: {0}.".format(
issue["shared_with"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_ssh_issue(issue, instance_arn):
def _handle_smb_password_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Machines are accessible using SSH passwords supplied by the user during "
"the Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {0} ({1}) is vulnerable to a SSH attack. The Monkey "
"authenticated over the SSH"
" protocol with user {2} and its "
"password.".format(issue["machine"], issue["ip_address"], issue["username"]),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {0} ({1}) is vulnerable to a SMB attack. The Monkey "
"authenticated over the SMB "
"protocol with user {2} and its password.".format(
issue["machine"], issue["ip_address"], issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_ssh_key_issue(issue, instance_arn):
def _handle_wmi_password_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Machines are accessible using SSH passwords supplied by the user during "
"the Monkey's configuration.",
description="Protect {ssh_key} private key with a pass phrase.".format(
ssh_key=issue["ssh_key"]
),
recommendation="The machine {machine} ({ip_address}) is vulnerable to a SSH "
"attack. The Monkey authenticated "
"over the SSH protocol with private key {ssh_key}.".format(
machine=issue["machine"], ip_address=issue["ip_address"], ssh_key=issue["ssh_key"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.",
recommendation="The machine {machine} ({ip_address}) is vulnerable to a WMI "
"attack. The Monkey authenticated over "
"the WMI protocol with user {username} and its password.".format(
machine=issue["machine"], ip_address=issue["ip_address"], username=issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_island_cross_segment_issue(issue, instance_arn):
def _handle_wmi_pth_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Weak segmentation - Machines from different segments are able to " "communicate.",
description="Segment your network and make sure there is no communication between "
"machines from different "
"segments.",
recommendation="The network can probably be segmented. A monkey instance on \
{0} in the networks {1} \
could directly access the Monkey Island server in the networks {2}.".format(
issue["machine"], issue["networks"], issue["server_networks"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {machine} ({ip_address}) is vulnerable to a WMI "
"attack. The Monkey used a "
"pass-the-hash attack over WMI protocol with user {username}".format(
machine=issue["machine"], ip_address=issue["ip_address"], username=issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_shared_passwords_issue(issue, instance_arn):
def _handle_shared_passwords_domain_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Multiple users have the same password",
description="Some users are sharing passwords, this should be fixed by changing "
"passwords.",
recommendation="These users are sharing access password: {0}.".format(issue["shared_with"]),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Multiple users have the same password.",
description="Some domain users are sharing passwords, this should be fixed by "
"changing passwords.",
recommendation="These users are sharing access password: {shared_with}.".format(
shared_with=issue["shared_with"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_smb_password_issue(issue, instance_arn):
def _handle_shared_admins_domain_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {0} ({1}) is vulnerable to a SMB attack. The Monkey "
"authenticated over the SMB "
"protocol with user {2} and its password.".format(
issue["machine"], issue["ip_address"], issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Shared local administrator account - Different machines have the same "
"account as a local administrator.",
description="Make sure the right administrator accounts are managing the right "
"machines, and that there isn't "
"an unintentional local admin sharing.",
recommendation="Here is a list of machines which the account {username} is "
"defined as an administrator: "
"{shared_machines}".format(
username=issue["username"], shared_machines=issue["shared_machines"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_wmi_password_issue(issue, instance_arn):
def _handle_strong_users_on_crit_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.",
recommendation="The machine {machine} ({ip_address}) is vulnerable to a WMI "
"attack. The Monkey authenticated over "
"the WMI protocol with user {username} and its password.".format(
machine=issue["machine"], ip_address=issue["ip_address"], username=issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=1,
title="Mimikatz found login credentials of a user who has admin access to a "
"server defined as critical.",
description="This critical machine is open to attacks via strong users with "
"access to it.",
recommendation="The services: {services} have been found on the machine thus "
"classifying it as a critical "
"machine. These users has access to it:{threatening_users}.".format(
services=issue["services"], threatening_users=issue["threatening_users"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
@staticmethod def _handle_wmi_pth_issue(issue, instance_arn):
def _handle_hadoop_issue(issue, instance_arn): return _build_generic_finding(
severity=1,
title="Machines are accessible using passwords supplied by the user during the "
"Monkey's configuration.",
description="Change {0}'s password to a complex one-use password that is not "
"shared with other computers on the "
"network.".format(issue["username"]),
recommendation="The machine {machine} ({ip_address}) is vulnerable to a WMI "
"attack. The Monkey used a "
"pass-the-hash attack over WMI protocol with user {username}".format(
machine=issue["machine"], ip_address=issue["ip_address"], username=issue["username"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
return AWSExporter._build_generic_finding(
severity=10, def _handle_shared_passwords_domain_issue(issue, instance_arn):
title="Hadoop/Yarn servers are vulnerable to remote code execution.", return _build_generic_finding(
description="Run Hadoop in secure mode, add Kerberos authentication.", severity=1,
recommendation="The Hadoop server at {machine} ({ip_address}) is vulnerable to " title="Multiple users have the same password.",
"remote code execution attack." description="Some domain users are sharing passwords, this should be fixed by "
"The attack was made possible due to default Hadoop/Yarn " "changing passwords.",
"configuration being insecure.", recommendation="These users are sharing access password: {shared_with}.".format(
instance_arn=instance_arn, shared_with=issue["shared_with"]
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None, ),
) instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
def _handle_shared_admins_domain_issue(issue, instance_arn):
return _build_generic_finding(
severity=1,
title="Shared local administrator account - Different machines have the same "
"account as a local administrator.",
description="Make sure the right administrator accounts are managing the right "
"machines, and that there isn't "
"an unintentional local admin sharing.",
recommendation="Here is a list of machines which the account {username} is "
"defined as an administrator: "
"{shared_machines}".format(
username=issue["username"], shared_machines=issue["shared_machines"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
def _handle_strong_users_on_crit_issue(issue, instance_arn):
return _build_generic_finding(
severity=1,
title="Mimikatz found login credentials of a user who has admin access to a "
"server defined as critical.",
description="This critical machine is open to attacks via strong users with "
"access to it.",
recommendation="The services: {services} have been found on the machine thus "
"classifying it as a critical "
"machine. These users has access to it:{threatening_users}.".format(
services=issue["services"], threatening_users=issue["threatening_users"]
),
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)
def _handle_hadoop_issue(issue, instance_arn):
return _build_generic_finding(
severity=10,
title="Hadoop/Yarn servers are vulnerable to remote code execution.",
description="Run Hadoop in secure mode, add Kerberos authentication.",
recommendation="The Hadoop server at {machine} ({ip_address}) is vulnerable to "
"remote code execution attack."
"The attack was made possible due to default Hadoop/Yarn "
"configuration being insecure.",
instance_arn=instance_arn,
instance_id=issue["aws_instance_id"] if "aws_instance_id" in issue else None,
)

View File

@ -1,7 +0,0 @@
class Exporter(object):
def __init__(self):
pass
@staticmethod
def handle_report(report_json):
raise NotImplementedError

View File

@ -1,28 +0,0 @@
import logging
from monkey_island.cc.services import aws_service
from monkey_island.cc.services.reporting.aws_exporter import AWSExporter
from monkey_island.cc.services.reporting.report_exporter_manager import ReportExporterManager
logger = logging.getLogger(__name__)
def populate_exporter_list():
manager = ReportExporterManager()
try_add_aws_exporter_to_manager(manager)
if len(manager.get_exporters_list()) != 0:
logger.debug(
"Populated exporters list with the following exporters: {0}".format(
str(manager.get_exporters_list())
)
)
def try_add_aws_exporter_to_manager(manager):
# noinspection PyBroadException
try:
if aws_service.is_on_aws():
manager.add_exporter_to_list(AWSExporter)
except Exception:
logger.error("Failed adding aws exporter to manager. Exception info:", exc_info=True)

View File

@ -26,7 +26,6 @@ from monkey_island.cc.services.reporting.exploitations.monkey_exploitation impor
get_monkey_exploited, get_monkey_exploited,
) )
from monkey_island.cc.services.reporting.pth_report import PTHReportService from monkey_island.cc.services.reporting.pth_report import PTHReportService
from monkey_island.cc.services.reporting.report_exporter_manager import ReportExporterManager
from monkey_island.cc.services.reporting.report_generation_synchronisation import ( from monkey_island.cc.services.reporting.report_generation_synchronisation import (
safe_generate_regular_report, safe_generate_regular_report,
) )
@ -36,6 +35,8 @@ from monkey_island.cc.services.reporting.stolen_credentials import (
) )
from monkey_island.cc.services.utils.network_utils import get_subnets, local_ip_addresses from monkey_island.cc.services.utils.network_utils import get_subnets, local_ip_addresses
from .. import AWSService
from . import aws_exporter
from .issue_processing.exploit_processing.exploiter_descriptor_enum import ExploiterDescriptorEnum from .issue_processing.exploit_processing.exploiter_descriptor_enum import ExploiterDescriptorEnum
from .issue_processing.exploit_processing.processors.cred_exploit import CredentialType from .issue_processing.exploit_processing.processors.cred_exploit import CredentialType
from .issue_processing.exploit_processing.processors.exploit import ExploiterReportInfo from .issue_processing.exploit_processing.processors.exploit import ExploiterReportInfo
@ -44,11 +45,18 @@ logger = logging.getLogger(__name__)
class ReportService: class ReportService:
_aws_service = None
class DerivedIssueEnum: class DerivedIssueEnum:
WEAK_PASSWORD = "weak_password" WEAK_PASSWORD = "weak_password"
STOLEN_CREDS = "stolen_creds" STOLEN_CREDS = "stolen_creds"
ZEROLOGON_PASS_RESTORE_FAILED = "zerologon_pass_restore_failed" ZEROLOGON_PASS_RESTORE_FAILED = "zerologon_pass_restore_failed"
@classmethod
def initialize(cls, aws_service: AWSService):
cls._aws_service = aws_service
@staticmethod @staticmethod
def get_first_monkey_time(): def get_first_monkey_time():
return ( return (
@ -488,8 +496,8 @@ class ReportService:
"recommendations": {"issues": issues, "domain_issues": domain_issues}, "recommendations": {"issues": issues, "domain_issues": domain_issues},
"meta_info": {"latest_monkey_modifytime": monkey_latest_modify_time}, "meta_info": {"latest_monkey_modifytime": monkey_latest_modify_time},
} }
ReportExporterManager().export(report)
save_report(report) save_report(report)
aws_exporter.handle_report(report, ReportService._aws_service.island_aws_instance)
return report return report
@staticmethod @staticmethod

View File

@ -1,24 +0,0 @@
import logging
from common.utils.code_utils import Singleton
logger = logging.getLogger(__name__)
class ReportExporterManager(object, metaclass=Singleton):
def __init__(self):
self._exporters_set = set()
def get_exporters_list(self):
return self._exporters_set
def add_exporter_to_list(self, exporter):
self._exporters_set.add(exporter)
def export(self, report):
for exporter in self._exporters_set:
logger.debug("Trying to export using " + repr(exporter))
try:
exporter().handle_report(report)
except Exception as e:
logger.exception("Failed to export report, error: " + e)

View File

@ -70,10 +70,11 @@ function AWSInstanceTable(props) {
let color = 'inherit'; let color = 'inherit';
if (r) { if (r) {
let instId = r.original.instance_id; let instId = r.original.instance_id;
let runResult = getRunResults(instId);
if (isSelected(instId)) { if (isSelected(instId)) {
color = '#ffed9f'; color = '#ffed9f';
} else if (Object.prototype.hasOwnProperty.call(props.results, instId)) { } else if (runResult) {
color = props.results[instId] ? '#00f01b' : '#f00000' color = runResult.status === 'error' ? '#f00000' : '#00f01b'
} }
} }
@ -82,6 +83,15 @@ function AWSInstanceTable(props) {
}; };
} }
function getRunResults(instanceId) {
for(let result of props.results){
if (result.instance_id === instanceId){
return result
}
}
return false
}
return ( return (
<div className="data-table-container"> <div className="data-table-container">
<CheckboxTable <CheckboxTable

View File

@ -0,0 +1 @@
from .stub_di_container import StubDIContainer

View File

@ -0,0 +1,20 @@
from typing import Any, Sequence, Type, TypeVar
from unittest.mock import MagicMock
from common import DIContainer, UnregisteredTypeError
T = TypeVar("T")
class StubDIContainer(DIContainer):
def resolve(self, type_: Type[T]) -> T:
try:
return super().resolve(type_)
except UnregisteredTypeError:
return MagicMock()
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
try:
return super().resolve_dependencies(type_)
except UnregisteredTypeError:
return MagicMock()

View File

@ -0,0 +1,61 @@
import pytest
from common.aws import AWSInstance
INSTANCE_ID = "1234"
REGION = "USA"
ACCOUNT_ID = "4321"
@pytest.fixture
def patch_fetch_metadata(monkeypatch):
def inner(instance_id: str, region: str, account_id: str):
return_value = (instance_id, region, account_id)
monkeypatch.setattr(
"common.aws.aws_instance.fetch_aws_instance_metadata", lambda: return_value
)
return inner
@pytest.fixture(autouse=True)
def patch_fetch_metadata_default_values(patch_fetch_metadata):
patch_fetch_metadata(INSTANCE_ID, REGION, ACCOUNT_ID)
def test_is_instance__true():
aws_instance = AWSInstance()
assert aws_instance.is_instance is True
def test_is_instance__false_none(patch_fetch_metadata):
patch_fetch_metadata(None, "", "")
aws_instance = AWSInstance()
assert aws_instance.is_instance is False
def test_is_instance__false_empty_str(patch_fetch_metadata):
patch_fetch_metadata("", "", "")
aws_instance = AWSInstance()
assert aws_instance.is_instance is False
def test_instance_id():
aws_instance = AWSInstance()
assert aws_instance.instance_id == INSTANCE_ID
def test_region():
aws_instance = AWSInstance()
assert aws_instance.region == REGION
def test_account_id():
aws_instance = AWSInstance()
assert aws_instance.account_id == ACCOUNT_ID

View File

@ -4,7 +4,7 @@ import pytest
import requests import requests
import requests_mock import requests_mock
from common.aws.aws_instance import AWS_LATEST_METADATA_URI_PREFIX, AwsInstance from common.aws.aws_metadata import AWS_LATEST_METADATA_URI_PREFIX, fetch_aws_instance_metadata
INSTANCE_ID_RESPONSE = "i-1234567890abcdef0" INSTANCE_ID_RESPONSE = "i-1234567890abcdef0"
@ -38,7 +38,7 @@ EXPECTED_REGION = "us-west-2"
EXPECTED_ACCOUNT_ID = "123456789012" EXPECTED_ACCOUNT_ID = "123456789012"
def get_test_aws_instance( def patch_and_call_fetch_aws_instance_metadata(
text=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), text=MappingProxyType({"instance_id": None, "region": None, "account_id": None}),
exception=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), exception=MappingProxyType({"instance_id": None, "region": None, "account_id": None}),
): ):
@ -59,104 +59,85 @@ def get_test_aws_instance(
url, exc=exception["account_id"] url, exc=exception["account_id"]
) )
test_aws_instance_object = AwsInstance() return fetch_aws_instance_metadata()
return test_aws_instance_object
# all good data # all good data
@pytest.fixture @pytest.fixture
def good_data_mock_instance(): def good_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": AVAILABILITY_ZONE_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE,
"account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE,
} }
) )
yield instance
del instance
def test_is_instance_good_data(good_data_mock_instance): def test_instance_id_good_data(good_metadata):
assert good_data_mock_instance.is_instance assert good_metadata[0] == EXPECTED_INSTANCE_ID
def test_instance_id_good_data(good_data_mock_instance): def test_region_good_data(good_metadata):
assert good_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID assert good_metadata[1] == EXPECTED_REGION
def test_region_good_data(good_data_mock_instance): def test_account_id_good_data(good_metadata):
assert good_data_mock_instance.region == EXPECTED_REGION assert good_metadata[2] == EXPECTED_ACCOUNT_ID
def test_account_id_good_data(good_data_mock_instance):
assert good_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID
# 'region' bad data # 'region' bad data
@pytest.fixture @pytest.fixture
def bad_region_data_mock_instance(): def bad_region_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": "in-a-different-world", "region": "in-a-different-world",
"account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE,
} }
) )
yield instance
del instance
def test_is_instance_bad_region_data(bad_region_data_mock_instance): def test_instance_id_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.is_instance assert bad_region_metadata[0] == EXPECTED_INSTANCE_ID
def test_instance_id_bad_region_data(bad_region_data_mock_instance): def test_region_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID assert bad_region_metadata[1] is None
def test_region_bad_region_data(bad_region_data_mock_instance): def test_account_id_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.region is None assert bad_region_metadata[2] == EXPECTED_ACCOUNT_ID
def test_account_id_bad_region_data(bad_region_data_mock_instance):
assert bad_region_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID
# 'account_id' bad data # 'account_id' bad data
@pytest.fixture @pytest.fixture
def bad_account_id_data_mock_instance(): def bad_account_id_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": AVAILABILITY_ZONE_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE,
"account_id": "who-am-i", "account_id": "who-am-i",
} }
) )
yield instance
del instance
def test_is_instance_bad_account_id_data(bad_account_id_data_mock_instance): def test_instance_id_bad_account_id_data(bad_account_id_metadata):
assert not bad_account_id_data_mock_instance.is_instance assert bad_account_id_metadata[0] is None
def test_instance_id_bad_account_id_data(bad_account_id_data_mock_instance): def test_region_bad_account_id_data(bad_account_id_metadata):
assert bad_account_id_data_mock_instance.instance_id is None assert bad_account_id_metadata[1] is None
def test_region_bad_account_id_data(bad_account_id_data_mock_instance): def test_account_id_data_bad_account_id_data(bad_account_id_metadata):
assert bad_account_id_data_mock_instance.region is None assert bad_account_id_metadata[2] is None
def test_account_id_data_bad_account_id_data(bad_account_id_data_mock_instance):
assert bad_account_id_data_mock_instance.account_id is None
# 'region' bad requests # 'region' bad requests
@pytest.fixture @pytest.fixture
def bad_region_request_mock_instance(region_exception): def region_request_failure_metadata(region_exception):
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": None, "region": None,
@ -164,33 +145,26 @@ def bad_region_request_mock_instance(region_exception):
}, },
exception={"instance_id": None, "region": region_exception, "account_id": None}, exception={"instance_id": None, "region": region_exception, "account_id": None},
) )
yield instance
del instance
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_is_instance_bad_region_request(bad_region_request_mock_instance): def test_instance_id_bad_region_request(region_request_failure_metadata):
assert not bad_region_request_mock_instance.is_instance assert region_request_failure_metadata[0] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_instance_id_bad_region_request(bad_region_request_mock_instance): def test_region_bad_region_request(region_request_failure_metadata):
assert bad_region_request_mock_instance.instance_id is None assert region_request_failure_metadata[1] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_region_bad_region_request(bad_region_request_mock_instance): def test_account_id_bad_region_request(region_request_failure_metadata):
assert bad_region_request_mock_instance.region is None assert region_request_failure_metadata[2] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_account_id_bad_region_request(bad_region_request_mock_instance):
assert bad_region_request_mock_instance.account_id is None
# not found request # not found request
@pytest.fixture @pytest.fixture
def not_found_request_mock_instance(): def not_found_metadata():
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
# request made to get instance_id # request made to get instance_id
url = f"{AWS_LATEST_METADATA_URI_PREFIX}meta-data/instance-id" url = f"{AWS_LATEST_METADATA_URI_PREFIX}meta-data/instance-id"
@ -204,22 +178,16 @@ def not_found_request_mock_instance():
url = f"{AWS_LATEST_METADATA_URI_PREFIX}dynamic/instance-identity/document" url = f"{AWS_LATEST_METADATA_URI_PREFIX}dynamic/instance-identity/document"
m.get(url) m.get(url)
not_found_aws_instance_object = AwsInstance() return fetch_aws_instance_metadata()
yield not_found_aws_instance_object
del not_found_aws_instance_object
def test_is_instance_not_found_request(not_found_request_mock_instance): def test_instance_id_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.is_instance is False assert not_found_metadata[0] is None
def test_instance_id_not_found_request(not_found_request_mock_instance): def test_region_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.instance_id is None assert not_found_metadata[1] is None
def test_region_not_found_request(not_found_request_mock_instance): def test_account_id_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.region is None assert not_found_metadata[2] is None
def test_account_id_not_found_request(not_found_request_mock_instance):
assert not_found_request_mock_instance.account_id is None

View File

@ -0,0 +1,22 @@
from queue import Queue
from common.utils.code_utils import queue_to_list
def test_empty_queue_to_empty_list():
q = Queue()
list_ = queue_to_list(q)
assert len(list_) == 0
def test_queue_to_list():
expected_list = [8, 6, 7, 5, 3, 0, 9]
q = Queue()
for i in expected_list:
q.put(i)
list_ = queue_to_list(q)
assert list_ == expected_list

View File

@ -2,7 +2,7 @@ import time
import pytest import pytest
from infection_monkey.utils.timer import Timer from common.utils import Timer
@pytest.fixture @pytest.fixture

View File

@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest import pytest
from common.utils import Timer
from infection_monkey.utils.decorators import request_cache from infection_monkey.utils.decorators import request_cache
from infection_monkey.utils.timer import Timer
class MockTimer(Timer): class MockTimer(Timer):

View File

@ -2,8 +2,8 @@ import io
from typing import BinaryIO from typing import BinaryIO
import pytest import pytest
from tests.common import StubDIContainer
from common import DIContainer
from monkey_island.cc.services import FileRetrievalError, IFileStorageService from monkey_island.cc.services import FileRetrievalError, IFileStorageService
FILE_NAME = "test_file" FILE_NAME = "test_file"
@ -31,8 +31,8 @@ class MockFileStorageService(IFileStorageService):
@pytest.fixture @pytest.fixture
def flask_client(build_flask_client, tmp_path): def flask_client(build_flask_client):
container = DIContainer() container = StubDIContainer()
container.register(IFileStorageService, MockFileStorageService) container.register(IFileStorageService, MockFileStorageService)
with build_flask_client(container) as flask_client: with build_flask_client(container) as flask_client:

View File

@ -2,9 +2,9 @@ import io
from typing import BinaryIO from typing import BinaryIO
import pytest import pytest
from tests.common import StubDIContainer
from tests.utils import raise_ from tests.utils import raise_
from common import DIContainer
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
from monkey_island.cc.services import FileRetrievalError, IFileStorageService from monkey_island.cc.services import FileRetrievalError, IFileStorageService
@ -65,7 +65,7 @@ def file_storage_service():
@pytest.fixture @pytest.fixture
def flask_client(build_flask_client, file_storage_service): def flask_client(build_flask_client, file_storage_service):
container = DIContainer() container = StubDIContainer()
container.register_instance(IFileStorageService, file_storage_service) container.register_instance(IFileStorageService, file_storage_service)
with build_flask_client(container) as flask_client: with build_flask_client(container) as flask_client:

View File

@ -0,0 +1,108 @@
import json
from unittest.mock import MagicMock
import pytest
from tests.common import StubDIContainer
from monkey_island.cc.services import AWSService
from monkey_island.cc.services.aws import AWSCommandResults, AWSCommandStatus
@pytest.fixture
def mock_aws_service():
return MagicMock(spec=AWSService)
@pytest.fixture
def flask_client(build_flask_client, mock_aws_service):
container = StubDIContainer()
container.register_instance(AWSService, mock_aws_service)
with build_flask_client(container) as flask_client:
yield flask_client
def test_get_invalid_action(flask_client):
response = flask_client.get("/api/remote-monkey?action=INVALID")
assert response.text.rstrip() == "{}"
def test_get_no_action(flask_client):
response = flask_client.get("/api/remote-monkey")
assert response.text.rstrip() == "{}"
def test_get_not_aws(flask_client, mock_aws_service):
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=False)
response = flask_client.get("/api/remote-monkey?action=list_aws")
assert response.text.rstrip() == '{"is_aws":false}'
def test_get_instances(flask_client, mock_aws_service):
instances = [
{"instance_id": "1", "name": "name1", "os": "linux", "ip_address": "1.1.1.1"},
{"instance_id": "2", "name": "name2", "os": "windows", "ip_address": "2.2.2.2"},
{"instance_id": "3", "name": "name3", "os": "linux", "ip_address": "3.3.3.3"},
]
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=True)
mock_aws_service.get_managed_instances = MagicMock(return_value=instances)
response = flask_client.get("/api/remote-monkey?action=list_aws")
assert json.loads(response.text)["instances"] == instances
assert json.loads(response.text)["is_aws"] is True
# TODO: Test error cases for get()
def test_post_no_type(flask_client):
response = flask_client.post("/api/remote-monkey", data="{}")
assert response.status_code == 500
def test_post_invalid_type(flask_client):
response = flask_client.post("/api/remote-monkey", data='{"type": "INVALID"}')
assert response.status_code == 500
def test_post(flask_client, mock_aws_service):
request_body = json.dumps(
{
"type": "aws",
"instances": [
{"instance_id": "1", "os": "linux"},
{"instance_id": "2", "os": "linux"},
{"instance_id": "3", "os": "windows"},
],
"island_ip": "127.0.0.1",
}
)
mock_aws_service.run_agents_on_managed_instances = MagicMock(
return_value=[
AWSCommandResults("1", 0, "", "", AWSCommandStatus.SUCCESS),
AWSCommandResults("2", 0, "some_output", "", AWSCommandStatus.IN_PROGRESS),
AWSCommandResults("3", -1, "", "some_error", AWSCommandStatus.ERROR),
]
)
expected_result = [
{"instance_id": "1", "response_code": 0, "stdout": "", "stderr": "", "status": "success"},
{
"instance_id": "2",
"response_code": 0,
"stdout": "some_output",
"stderr": "",
"status": "in_progress",
},
{
"instance_id": "3",
"response_code": -1,
"stdout": "",
"stderr": "some_error",
"status": "error",
},
]
response = flask_client.post("/api/remote-monkey", data=request_body)
assert json.loads(response.text)["result"] == expected_result

View File

@ -0,0 +1,232 @@
from itertools import chain, repeat
from unittest.mock import MagicMock
import pytest
from monkey_island.cc.services.aws.aws_command_runner import (
LINUX_DOCUMENT_NAME,
WINDOWS_DOCUMENT_NAME,
AWSCommandResults,
AWSCommandStatus,
start_infection_monkey_agent,
)
TIMEOUT = 0.03
INSTANCE_ID = "BEEFFACE"
ISLAND_IP = "127.0.0.1"
"""
"commands": [
"wget --no-check-certificate "
"https://172.31.32.78:5000/api/agent/download/linux "
"-O monkey-linux-64; chmod +x "
"monkey-linux-64; ./monkey-linux-64 "
"m0nk3y -s 172.31.32.78:5000"
]
"""
@pytest.fixture
def send_command_response():
return {
"Command": {
"CloudWatchOutputConfig": {
"CloudWatchLogGroupName": "",
"CloudWatchOutputEnabled": False,
},
"CommandId": "fe3cf24f-71b7-42b9-93ca-e27c34dd0581",
"CompletedCount": 0,
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"InstanceIds": ["i-0b62d6f0b0d9d7e77"],
"OutputS3Region": "eu-central-1",
"Parameters": {"commands": []},
"Status": "Pending",
"StatusDetails": "Pending",
"TargetCount": 1,
"Targets": [],
"TimeoutSeconds": 3600,
},
"ResponseMetadata": {
"HTTPHeaders": {
"connection": "keep-alive",
"content-length": "973",
"content-type": "application/x-amz-json-1.1",
"date": "Tue, 10 May 2022 12:35:49 GMT",
"server": "Server",
"x-amzn-requestid": "110b1563-aaf0-4e09-bd23-2db465822be7",
},
"HTTPStatusCode": 200,
"RequestId": "110b1563-aaf0-4e09-bd23-2db465822be7",
"RetryAttempts": 0,
},
}
@pytest.fixture
def in_progress_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "InProgress",
"StatusDetails": "InProgress",
}
@pytest.fixture
def success_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "Success",
"StatusDetails": "Success",
}
@pytest.fixture
def error_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "ERROR RUNNING COMMAND",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
# NOTE: "Error" is technically not a valid value for this field, but we want to test that
# anything other than "Success" and "InProgress" is treated as an error. This is
# simpler than testing all of the different possible error cases.
"Status": "Error",
"StatusDetails": "Error",
}
@pytest.fixture(autouse=True)
def patch_timeouts(monkeypatch):
monkeypatch.setattr(
"monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01
)
@pytest.fixture
def successful_mock_client(send_command_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=success_response)
return aws_client
def test_correct_instance_id(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["InstanceIds"] == [INSTANCE_ID]
def test_linux_doc_name(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["DocumentName"] == LINUX_DOCUMENT_NAME
def test_windows_doc_name(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["DocumentName"] == WINDOWS_DOCUMENT_NAME
def test_linux_command(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert "wget" in call_args_kwargs["Parameters"]["commands"][0]
def test_windows_command(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert "DownloadFile" in call_args_kwargs["Parameters"]["commands"][0]
def test_multiple_status_queries(send_command_response, in_progress_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(
side_effect=chain([in_progress_response, in_progress_response], repeat(success_response))
)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.SUCCESS
def test_in_progress_timeout(send_command_response, in_progress_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.IN_PROGRESS
def test_failed_command(send_command_response, error_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=error_response)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.ERROR
@pytest.mark.parametrize(
"status, success",
[
(AWSCommandStatus.SUCCESS, True),
(AWSCommandStatus.IN_PROGRESS, False),
(AWSCommandStatus.ERROR, False),
],
)
def test_command_resuls_status(status, success):
results = AWSCommandResults(INSTANCE_ID, 0, "", "", status)
assert results.success == success

View File

@ -0,0 +1,150 @@
import threading
from typing import Any, Dict, Optional, Sequence
import pytest
from common.aws import AWSInstance
from monkey_island.cc.services import AWSService
EXPECTED_INSTANCE_1 = {
"instance_id": "1",
"name": "comp1",
"os": "linux",
"ip_address": "192.168.1.1",
}
EXPECTED_INSTANCE_2 = {
"instance_id": "2",
"name": "comp2",
"os": "linux",
"ip_address": "192.168.1.2",
}
EMPTY_INSTANCE_INFO_RESPONSE = []
FULL_INSTANCE_INFO_RESPONSE = [
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": EXPECTED_INSTANCE_1["name"],
"IamRole": "string",
"InstanceId": EXPECTED_INSTANCE_1["instance_id"],
"IPAddress": EXPECTED_INSTANCE_1["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": EXPECTED_INSTANCE_1["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string",
},
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": EXPECTED_INSTANCE_2["name"],
"IamRole": "string",
"InstanceId": EXPECTED_INSTANCE_2["instance_id"],
"IPAddress": EXPECTED_INSTANCE_2["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": EXPECTED_INSTANCE_2["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string",
},
]
class StubAWSInstance(AWSInstance):
def __init__(
self,
instance_id: Optional[str] = None,
region: Optional[str] = None,
account_id: Optional[str] = None,
):
self._instance_id = instance_id
self._region = region
self._account_id = account_id
self._initialization_complete = threading.Event()
self._initialization_complete.set()
def test_aws_is_on_aws__true():
aws_instance = StubAWSInstance("1")
aws_service = AWSService(aws_instance)
assert aws_service.island_is_running_on_aws() is True
def test_aws_is_on_aws__False():
aws_instance = StubAWSInstance()
aws_service = AWSService(aws_instance)
assert aws_service.island_is_running_on_aws() is False
INSTANCE_ID = "1"
REGION = "2"
ACCOUNT_ID = "3"
@pytest.fixture
def aws_instance():
return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
@pytest.fixture
def aws_service(aws_instance):
return AWSService(aws_instance)
def test_instance_id(aws_service):
assert aws_service.island_aws_instance.instance_id == INSTANCE_ID
def test_region(aws_service):
assert aws_service.island_aws_instance.region == REGION
def test_account_id(aws_service):
assert aws_service.island_aws_instance.account_id == ACCOUNT_ID
class MockAWSService(AWSService):
def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]):
super().__init__(aws_instance)
self._instance_info_response = instance_info_response
def _get_raw_managed_instances(self):
return self._instance_info_response
def test_get_managed_instances__empty(aws_instance):
aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 0
def test_get_managed_instances(aws_instance):
aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 2
assert instances[0] == EXPECTED_INSTANCE_1
assert instances[1] == EXPECTED_INSTANCE_2

View File

@ -1,56 +0,0 @@
import json
from unittest import TestCase
from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response
class TestAwsService(TestCase):
def test_filter_instance_data_from_aws_response(self):
json_response_full = """
{
"InstanceInformationList": [
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {
"string" : 6
}
},
"AssociationStatus": "string",
"ComputerName": "string",
"IamRole": "string",
"InstanceId": "string",
"IPAddress": "string",
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": "string",
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string"
}
],
"NextToken": "string"
}
"""
json_response_empty = """
{
"InstanceInformationList": [],
"NextToken": "string"
}
"""
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_empty)), []
)
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_full)),
[{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}],
)

View File

@ -167,3 +167,5 @@ _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py
_.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64) _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64)
GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57) GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25) architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)