Merge branch '1928-run-agent-on-remote-instance' into 1928-aws-service-refactor

This commit is contained in:
Mike Salvatore 2022-05-11 10:38:51 -04:00
commit c685ce3725
40 changed files with 664 additions and 374 deletions

View File

@ -1 +1 @@
from .di_container import DIContainer from .di_container import DIContainer, UnregisteredTypeError

View File

@ -1,8 +0,0 @@
class Cmd(object):
"""
Class representing a command
"""
def __init__(self, cmd_runner, cmd_id):
self.cmd_runner = cmd_runner
self.cmd_id = cmd_id

View File

@ -1,10 +0,0 @@
class CmdResult(object):
"""
Class representing a command result
"""
def __init__(self, is_success, status_code=None, stdout=None, stderr=None):
self.is_success = is_success
self.status_code = status_code
self.stdout = stdout
self.stderr = stderr

View File

@ -1,154 +0,0 @@
import logging
import time
from abc import abstractmethod
from common.cmd.cmd_result import CmdResult
from common.cmd.cmd_status import CmdStatus
logger = logging.getLogger(__name__)
class CmdRunner(object):
"""
Interface for running commands on a remote machine
Since these classes are a bit complex, I provide a list of common terminology and formats:
* command line - a command line. e.g. 'echo hello'
* command - represent a single command which was already run. Always of type Cmd
* command id - any unique identifier of a command which was already run
* command result - represents the result of running a command. Always of type CmdResult
* command status - represents the current status of a command. Always of type CmdStatus
* command info - Any consistent structure representing additional information of a command
which was already run
* instance - a machine that commands will be run on. Can be any dictionary with 'instance_id'
as a field
* instance_id - any unique identifier of an instance (machine). Can be of any format
"""
# Default command timeout in seconds
DEFAULT_TIMEOUT = 5
# Time to sleep when waiting on commands.
WAIT_SLEEP_TIME = 1
def __init__(self, is_linux):
self.is_linux = is_linux
@staticmethod
def run_multiple_commands(instances, inst_to_cmd, inst_n_cmd_res_to_res):
"""
Run multiple commands on various instances
:param instances: List of instances.
:param inst_to_cmd: Function which receives an instance, runs a command asynchronously
and returns Cmd
:param inst_n_cmd_res_to_res: Function which receives an instance and CmdResult
and returns a parsed result (of any format)
:return: Dictionary with 'instance_id' as key and parsed result as value
"""
command_instance_dict = {}
for instance in instances:
command = inst_to_cmd(instance)
command_instance_dict[command] = instance
instance_results = {}
command_result_pairs = CmdRunner.wait_commands(list(command_instance_dict.keys()))
for command, result in command_result_pairs:
instance = command_instance_dict[command]
instance_results[instance["instance_id"]] = inst_n_cmd_res_to_res(instance, result)
return instance_results
@abstractmethod
def run_command_async(self, command_line):
"""
Runs the given command on the remote machine asynchronously.
:param command_line: The command line to run
:return: Command ID (in any format)
"""
raise NotImplementedError()
@staticmethod
def wait_commands(commands, timeout=DEFAULT_TIMEOUT):
"""
Waits on all commands up to given timeout
:param commands: list of commands (of type Cmd)
:param timeout: Timeout in seconds for command.
:return: commands and their results (tuple of Command and CmdResult)
"""
init_time = time.time()
curr_time = init_time
results = []
# TODO: Use timer.Timer
while (curr_time - init_time < timeout) and (len(commands) != 0):
for command in list(
commands
): # list(commands) clones the list. We do so because we remove items inside
CmdRunner._process_command(command, commands, results, True)
time.sleep(CmdRunner.WAIT_SLEEP_TIME)
curr_time = time.time()
for command in list(commands):
CmdRunner._process_command(command, commands, results, False)
for command, result in results:
if not result.is_success:
logger.error(
f"The command with id: {str(command.cmd_id)} failed. "
f"Status code: {str(result.status_code)}"
)
return results
@abstractmethod
def query_command(self, command_id):
"""
Queries the already run command for more info
:param command_id: The command ID to query
:return: Command info (in any format)
"""
raise NotImplementedError()
@abstractmethod
def get_command_result(self, command_info):
"""
Gets the result of the already run command
:param command_info: The command info of the command to get the result of
:return: CmdResult
"""
raise NotImplementedError()
@abstractmethod
def get_command_status(self, command_info):
"""
Gets the status of the already run command
:param command_info: The command info of the command to get the result of
:return: CmdStatus
"""
raise NotImplementedError()
@staticmethod
def _process_command(command, commands, results, should_process_only_finished):
"""
Removes the command from the list, processes its result and appends to results
:param command: Command to process. Must be in commands.
:param commands: List of unprocessed commands.
:param results: List of command results.
:param should_process_only_finished: If True, processes only if command finished.
:return: None
"""
c_runner = command.cmd_runner
c_id = command.cmd_id
try:
command_info = c_runner.query_command(c_id)
if (not should_process_only_finished) or c_runner.get_command_status(
command_info
) != CmdStatus.IN_PROGRESS:
commands.remove(command)
results.append((command, c_runner.get_command_result(command_info)))
except Exception:
logger.exception("Exception while querying command: `%s`", str(c_id))
if not should_process_only_finished:
commands.remove(command)
results.append((command, CmdResult(False)))

View File

@ -1,7 +0,0 @@
from enum import Enum
class CmdStatus(Enum):
IN_PROGRESS = 0
SUCCESS = 1
FAILURE = 2

View File

@ -0,0 +1 @@
from .timer import Timer

View File

@ -1,3 +1,7 @@
import queue
from typing import Any, List
class abstractstatic(staticmethod): class abstractstatic(staticmethod):
__slots__ = () __slots__ = ()
@ -15,3 +19,14 @@ class Singleton(type):
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls] return cls._instances[cls]
def queue_to_list(q: queue.Queue) -> List[Any]:
list_ = []
try:
while True:
list_.append(q.get_nowait())
except queue.Empty:
pass
return list_

View File

@ -3,6 +3,7 @@ import time
from pathlib import PurePath from pathlib import PurePath
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.utils import Timer
from infection_monkey.exploit.log4shell_utils import ( from infection_monkey.exploit.log4shell_utils import (
LINUX_EXPLOIT_TEMPLATE_PATH, LINUX_EXPLOIT_TEMPLATE_PATH,
WINDOWS_EXPLOIT_TEMPLATE_PATH, WINDOWS_EXPLOIT_TEMPLATE_PATH,
@ -21,7 +22,6 @@ from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
from infection_monkey.utils.monkey_dir import get_monkey_dir_path from infection_monkey.utils.monkey_dir import get_monkey_dir_path
from infection_monkey.utils.threading import interruptible_iter from infection_monkey.utils.threading import interruptible_iter
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -5,6 +5,7 @@ from pathlib import PurePath
import paramiko import paramiko
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.utils import Timer
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
from common.utils.exceptions import FailedExploitationError from common.utils.exceptions import FailedExploitationError
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
@ -17,7 +18,6 @@ from infection_monkey.telemetry.attack.t1222_telem import T1222Telem
from infection_monkey.utils.brute_force import generate_identity_secret_pairs from infection_monkey.utils.brute_force import generate_identity_secret_pairs
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
from infection_monkey.utils.threading import interruptible_iter from infection_monkey.utils.threading import interruptible_iter
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SSH_PORT = 22 SSH_PORT = 22

View File

@ -3,6 +3,7 @@ import threading
import time import time
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from common.utils import Timer
from infection_monkey.credential_store import ICredentialsStore from infection_monkey.credential_store import ICredentialsStore
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
@ -13,7 +14,6 @@ from infection_monkey.telemetry.credentials_telem import CredentialsTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter
from infection_monkey.utils.timer import Timer
from . import Exploiter, IPScanner, Propagator from . import Exploiter, IPScanner, Propagator
from .option_parsing import custom_pba_is_enabled from .option_parsing import custom_pba_is_enabled

View File

@ -4,9 +4,9 @@ import socket
import time import time
from typing import Iterable, Mapping, Tuple from typing import Iterable, Mapping, Tuple
from common.utils import Timer
from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,10 +2,10 @@ import queue
import threading import threading
from typing import Dict from typing import Dict
from common.utils import Timer
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.timer import Timer
DEFAULT_PERIOD = 5 DEFAULT_PERIOD = 5
WAKES_PER_PERIOD = 4 WAKES_PER_PERIOD = 4

View File

@ -4,11 +4,11 @@ import struct
import time import time
from threading import Event, Thread from threading import Event, Thread
from common.utils import Timer
from infection_monkey.network.firewall import app as firewall from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
from infection_monkey.transport.base import get_last_serve_time from infection_monkey.transport.base import get_last_serve_time
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import threading import threading
from functools import wraps from functools import wraps
from .timer import Timer from common.utils import Timer
def request_cache(ttl: float): def request_cache(ttl: float):

View File

@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
from common import DIContainer from common import DIContainer
from monkey_island.cc.database import database, mongo from monkey_island.cc.database import database, mongo
from monkey_island.cc.resources import RemoteRun
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
from monkey_island.cc.resources.attack.attack_report import AttackReport from monkey_island.cc.resources.attack.attack_report import AttackReport
from monkey_island.cc.resources.auth.auth import Authenticate, init_jwt from monkey_island.cc.resources.auth.auth import Authenticate, init_jwt
@ -38,7 +39,6 @@ from monkey_island.cc.resources.pba_file_download import PBAFileDownload
from monkey_island.cc.resources.pba_file_upload import FileUpload from monkey_island.cc.resources.pba_file_upload import FileUpload
from monkey_island.cc.resources.propagation_credentials import PropagationCredentials from monkey_island.cc.resources.propagation_credentials import PropagationCredentials
from monkey_island.cc.resources.ransomware_report import RansomwareReport from monkey_island.cc.resources.ransomware_report import RansomwareReport
from monkey_island.cc.resources.remote_run import RemoteRun
from monkey_island.cc.resources.root import Root from monkey_island.cc.resources.root import Root
from monkey_island.cc.resources.security_report import SecurityReport from monkey_island.cc.resources.security_report import SecurityReport
from monkey_island.cc.resources.telemetry import Telemetry from monkey_island.cc.resources.telemetry import Telemetry

View File

@ -0,0 +1 @@
from .remote_run import RemoteRun

View File

@ -1,12 +1,13 @@
import json import json
from typing import Sequence
import flask_restful import flask_restful
from botocore.exceptions import ClientError, NoCredentialsError from botocore.exceptions import ClientError, NoCredentialsError
from flask import jsonify, make_response, request from flask import jsonify, make_response, request
from monkey_island.cc.resources.auth.auth import jwt_required from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services import aws_service from monkey_island.cc.services import AWSService
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService from monkey_island.cc.services.aws import AWSCommandResults
CLIENT_ERROR_FORMAT = ( CLIENT_ERROR_FORMAT = (
"ClientError, error message: '{}'. Probably, the IAM role that has been associated with the " "ClientError, error message: '{}'. Probably, the IAM role that has been associated with the "
@ -19,20 +20,18 @@ NO_CREDS_ERROR_FORMAT = (
class RemoteRun(flask_restful.Resource): class RemoteRun(flask_restful.Resource):
def run_aws_monkeys(self, request_body): def __init__(self, aws_service: AWSService):
instances = request_body.get("instances") self._aws_service = aws_service
island_ip = request_body.get("island_ip")
return RemoteRunAwsService.run_aws_monkeys(instances, island_ip)
@jwt_required @jwt_required
def get(self): def get(self):
action = request.args.get("action") action = request.args.get("action")
if action == "list_aws": if action == "list_aws":
is_aws = aws_service.is_on_aws() is_aws = self._aws_service.island_is_running_on_aws()
resp = {"is_aws": is_aws} resp = {"is_aws": is_aws}
if is_aws: if is_aws:
try: try:
resp["instances"] = aws_service.get_instances() resp["instances"] = self._aws_service.get_managed_instances()
except NoCredentialsError as e: except NoCredentialsError as e:
resp["error"] = NO_CREDS_ERROR_FORMAT.format(e) resp["error"] = NO_CREDS_ERROR_FORMAT.format(e)
return jsonify(resp) return jsonify(resp)
@ -46,11 +45,32 @@ class RemoteRun(flask_restful.Resource):
@jwt_required @jwt_required
def post(self): def post(self):
body = json.loads(request.data) body = json.loads(request.data)
resp = {}
if body.get("type") == "aws": if body.get("type") == "aws":
result = self.run_aws_monkeys(body) results = self.run_aws_monkeys(body)
resp["result"] = result return RemoteRun._encode_results(results)
return jsonify(resp)
# default action # default action
return make_response({"error": "Invalid action"}, 500) return make_response({"error": "Invalid action"}, 500)
def run_aws_monkeys(self, request_body) -> Sequence[AWSCommandResults]:
instances = request_body.get("instances")
island_ip = request_body.get("island_ip")
return self._aws_service.run_agents_on_managed_instances(instances, island_ip)
@staticmethod
def _encode_results(results: Sequence[AWSCommandResults]):
result = list(map(RemoteRun._aws_command_results_to_encodable_dict, results))
response = {"result": result}
return jsonify(response)
@staticmethod
def _aws_command_results_to_encodable_dict(aws_command_results: AWSCommandResults):
return {
"instance_id": aws_command_results.instance_id,
"response_code": aws_command_results.response_code,
"stdout": aws_command_results.stdout,
"stderr": aws_command_results.stderr,
"status": aws_command_results.status.name.lower(),
}

View File

@ -1,29 +0,0 @@
from common.cmd.cmd_result import CmdResult
class AwsCmdResult(CmdResult):
"""
Class representing an AWS command result
"""
def __init__(self, command_info):
super(AwsCmdResult, self).__init__(
self.is_successful(command_info, True),
command_info["ResponseCode"],
command_info["StandardOutputContent"],
command_info["StandardErrorContent"],
)
self.command_info = command_info
@staticmethod
def is_successful(command_info, is_timeout=False):
"""
Determines whether the command was successful. If it timed out and was still in progress,
we assume it worked.
:param command_info: Command info struct (returned by ssm.get_command_invocation)
:param is_timeout: Whether the given command timed out
:return: True if successful, False otherwise.
"""
return (command_info["Status"] == "Success") or (
is_timeout and (command_info["Status"] == "InProgress")
)

View File

@ -1,45 +0,0 @@
import logging
import time
from common.cmd.cmd_runner import CmdRunner
from common.cmd.cmd_status import CmdStatus
from monkey_island.cc.server_utils.aws_cmd_result import AwsCmdResult
from monkey_island.cc.services import aws_service
logger = logging.getLogger(__name__)
class AwsCmdRunner(CmdRunner):
"""
Class for running commands on a remote AWS machine
"""
def __init__(self, is_linux, instance_id, region=None):
super(AwsCmdRunner, self).__init__(is_linux)
self.instance_id = instance_id
self.region = region
self.ssm = aws_service.get_client("ssm", region)
def query_command(self, command_id):
time.sleep(2)
return self.ssm.get_command_invocation(CommandId=command_id, InstanceId=self.instance_id)
def get_command_result(self, command_info):
return AwsCmdResult(command_info)
def get_command_status(self, command_info):
if command_info["Status"] == "InProgress":
return CmdStatus.IN_PROGRESS
elif command_info["Status"] == "Success":
return CmdStatus.SUCCESS
else:
return CmdStatus.FAILURE
def run_command_async(self, command_line):
doc_name = "AWS-RunShellScript" if self.is_linux else "AWS-RunPowerShellScript"
command_res = self.ssm.send_command(
DocumentName=doc_name,
Parameters={"commands": [command_line]},
InstanceIds=[self.instance_id],
)
return command_res["Command"]["CommandId"]

View File

@ -4,4 +4,8 @@ from .directory_file_storage_service import DirectoryFileStorageService
from .authentication.authentication_service import AuthenticationService from .authentication.authentication_service import AuthenticationService
from .authentication.json_file_user_datastore import JsonFileUserDatastore from .authentication.json_file_user_datastore import JsonFileUserDatastore
from .aws_service import AWSService from .aws import AWSService
# TODO: This is a temporary import to keep some tests passing. Remove it before merging #1928 to
# develop.
from .aws import aws_service

View File

@ -0,0 +1,2 @@
from .aws_service import AWSService
from .aws_command_runner import AWSCommandResults, AWSCommandStatus

View File

@ -0,0 +1,149 @@
import logging
import time
from dataclasses import dataclass
from enum import Enum, auto
import botocore
from common.utils import Timer
STATUS_CHECK_SLEEP_TIME = 1
LINUX_DOCUMENT_NAME = "AWS-RunShellScript"
WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript"
logger = logging.getLogger(__name__)
class AWSCommandStatus(Enum):
SUCCESS = auto()
IN_PROGRESS = auto()
ERROR = auto()
@dataclass(frozen=True)
class AWSCommandResults:
instance_id: str
response_code: int
stdout: str
stderr: str
status: AWSCommandStatus
@property
def success(self):
return self.status == AWSCommandStatus.SUCCESS
def start_infection_monkey_agent(
aws_client: botocore.client.BaseClient,
target_instance_id: str,
target_os: str,
island_ip: str,
timeout: float,
) -> AWSCommandResults:
"""
Run a command on a remote AWS instance
"""
command = _get_run_agent_command(target_os, island_ip)
command_id = _run_command_async(aws_client, target_instance_id, target_os, command)
_wait_for_command_to_complete(aws_client, target_instance_id, command_id, timeout)
return _fetch_command_results(aws_client, target_instance_id, command_id)
def _get_run_agent_command(target_os: str, island_ip: str):
if target_os == "linux":
return _get_run_monkey_cmd_linux_line(island_ip)
return _get_run_monkey_cmd_windows_line(island_ip)
def _get_run_monkey_cmd_linux_line(island_ip):
binary_name = "monkey-linux-64"
download_url = f"https://{island_ip}:5000/api/agent/download/linux"
download_cmd = f"wget --no-check-certificate {download_url} -O {binary_name}"
chmod_cmd = f"chmod +x {binary_name}"
run_agent_cmd = f"./{binary_name} m0nk3y -s {island_ip}:5000"
return f"{download_cmd}; {chmod_cmd}; {run_agent_cmd}"
def _get_run_monkey_cmd_windows_line(island_ip):
agent_exe_path = r".\\monkey.exe"
ignore_ssl_errors_cmd = (
"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {$true}"
)
download_url = f"https://{island_ip}:5000/api/agent/download/windows"
download_cmd = (
f"(New-Object System.Net.WebClient).DownloadFile('{download_url}', '{agent_exe_path}')"
)
run_agent_cmd = (
f"Start-Process -FilePath '{agent_exe_path}' -ArgumentList 'm0nk3y -s {island_ip}:5000'"
)
return f"{ignore_ssl_errors_cmd}; {download_cmd}; {run_agent_cmd};"
def _run_command_async(
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, command: str
):
doc_name = LINUX_DOCUMENT_NAME if target_os == "linux" else WINDOWS_DOCUMENT_NAME
logger.debug(f'Running command on {target_instance_id} -- {doc_name}: "{command}"')
command_response = aws_client.send_command(
DocumentName=doc_name,
Parameters={"commands": [command]},
InstanceIds=[target_instance_id],
)
command_id = command_response["Command"]["CommandId"]
logger.debug(
f"Started command on AWS instance {target_instance_id} with command ID {command_id}"
)
return command_id
def _wait_for_command_to_complete(
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str, timeout: float
):
timer = Timer()
timer.set(timeout)
while not timer.is_expired():
time.sleep(STATUS_CHECK_SLEEP_TIME)
command_results = _fetch_command_results(aws_client, target_instance_id, command_id)
logger.debug(f"Command {command_id} status: {command_results.status.name}")
if command_results.status != AWSCommandStatus.IN_PROGRESS:
return
def _fetch_command_results(
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
) -> AWSCommandResults:
command_results = aws_client.get_command_invocation(
CommandId=command_id, InstanceId=target_instance_id
)
command_status = command_results["Status"]
logger.debug(f"Command {command_id} status: {command_status}")
if command_status == "Success":
aws_command_result_status = AWSCommandStatus.SUCCESS
elif command_status == "InProgress":
aws_command_result_status = AWSCommandStatus.IN_PROGRESS
else:
aws_command_result_status = AWSCommandStatus.ERROR
return AWSCommandResults(
target_instance_id,
command_results["ResponseCode"],
command_results["StandardOutputContent"],
command_results["StandardErrorContent"],
aws_command_result_status,
)

View File

@ -1,11 +1,17 @@
import logging import logging
from queue import Queue
from threading import Thread
from typing import Any, Iterable, Mapping, Sequence from typing import Any, Iterable, Mapping, Sequence
import boto3 import boto3
import botocore import botocore
from common.aws.aws_instance import AWSInstance from common.aws.aws_instance import AWSInstance
from common.utils.code_utils import queue_to_list
from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent
DEFAULT_REMOTE_COMMAND_TIMEOUT = 5
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList" INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
INSTANCE_ID_KEY = "InstanceId" INSTANCE_ID_KEY = "InstanceId"
COMPUTER_NAME_KEY = "ComputerName" COMPUTER_NAME_KEY = "ComputerName"
@ -58,20 +64,52 @@ class AWSService:
:raises: botocore.exceptions.ClientError if can't describe local instance information. :raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance :return: All visible instances from this instance
""" """
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region) ssm_client = boto3.client("ssm", self.island_aws_instance.region)
try: try:
response = local_ssm_client.describe_instance_information() response = ssm_client.describe_instance_information()
return response[INSTANCE_INFORMATION_LIST_KEY] return response[INSTANCE_INFORMATION_LIST_KEY]
except botocore.exceptions.ClientError as err: except botocore.exceptions.ClientError as err:
logger.warning("AWS client error while trying to get manage dinstances: {err}") logger.warning("AWS client error while trying to get manage dinstances: {err}")
raise err raise err
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]): def run_agents_on_managed_instances(
for id_ in instance_ids: self,
self._run_agent_on_managed_instance(id_) instances: Iterable[Mapping[str, str]],
island_ip: str,
timeout: float = DEFAULT_REMOTE_COMMAND_TIMEOUT,
) -> Sequence[AWSCommandResults]:
"""
Run an agent on one or more managed AWS instances.
:param instances: An iterable of instances that the agent will be run on
:param island_ip: The IP address of the Island to pass to the new agents
:param timeout: The maximum number of seconds to wait for the agents to start
:return: A sequence of AWSCommandResults
"""
def _run_agent_on_managed_instance(self, instance_id: str): results_queue = Queue()
pass command_threads = []
for i in instances:
t = Thread(
target=self._run_agent_on_managed_instance,
args=(results_queue, i["instance_id"], i["os"], island_ip, timeout),
daemon=True,
)
t.start()
command_threads.append(t)
for thread in command_threads:
thread.join()
return queue_to_list(results_queue)
def _run_agent_on_managed_instance(
self, results_queue: Queue, instance_id: str, os: str, island_ip: str, timeout: float
):
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
command_results = start_infection_monkey_agent(
ssm_client, instance_id, os, island_ip, timeout
)
results_queue.put(command_results)
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]): def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):

View File

@ -1,82 +0,0 @@
import logging
from common.cmd.cmd import Cmd
from common.cmd.cmd_runner import CmdRunner
from monkey_island.cc.server_utils.aws_cmd_runner import AwsCmdRunner
logger = logging.getLogger(__name__)
class RemoteRunAwsService:
@staticmethod
def run_aws_monkeys(instances, island_ip):
"""
Runs monkeys on the given instances
:param instances: List of instances to run on
:param island_ip: IP of island the monkey will communicate with
:return: Dictionary with instance ids as keys, and True/False as values if succeeded or not
"""
return CmdRunner.run_multiple_commands(
instances,
lambda instance: RemoteRunAwsService._run_aws_monkey_cmd_async(
instance["instance_id"],
RemoteRunAwsService._is_linux(instance["os"]),
island_ip,
),
lambda _, result: result.is_success,
)
@staticmethod
def _run_aws_monkey_cmd_async(instance_id, is_linux, island_ip):
"""
Runs a monkey remotely using AWS
:param instance_id: Instance ID of target
:param is_linux: Whether target is linux
:param island_ip: IP of the island which the instance will try to connect to
:return: Cmd
"""
cmd_text = RemoteRunAwsService._get_run_monkey_cmd_line(is_linux, island_ip)
return RemoteRunAwsService._run_aws_cmd_async(instance_id, is_linux, cmd_text)
@staticmethod
def _run_aws_cmd_async(instance_id, is_linux, cmd_line):
cmd_runner = AwsCmdRunner(is_linux, instance_id)
return Cmd(cmd_runner, cmd_runner.run_command_async(cmd_line))
@staticmethod
def _is_linux(os):
return "linux" == os
@staticmethod
def _get_run_monkey_cmd_linux_line(island_ip):
return (
r"wget --no-check-certificate https://"
+ island_ip
+ r":5000/api/agent/download/linux "
+ r"-O monkey-linux-64"
+ r"; chmod +x monkey-linux-64"
+ r"; ./monkey-linux-64"
+ r" m0nk3y -s "
+ island_ip
+ r":5000"
)
@staticmethod
def _get_run_monkey_cmd_windows_line(island_ip):
return (
r"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {"
r"$true}; (New-Object System.Net.WebClient).DownloadFile('https://"
+ island_ip
+ r":5000/api/agent/download/windows'"
+ r"'.\\monkey.exe'); "
r";Start-Process -FilePath '.\\monkey.exe' "
r"-ArgumentList 'm0nk3y -s " + island_ip + r":5000'; "
)
@staticmethod
def _get_run_monkey_cmd_line(is_linux, island_ip):
return (
RemoteRunAwsService._get_run_monkey_cmd_linux_line(island_ip)
if is_linux
else RemoteRunAwsService._get_run_monkey_cmd_windows_line(island_ip)
)

View File

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
def populate_exporter_list(): def populate_exporter_list():
manager = ReportExporterManager() manager = ReportExporterManager()
try_add_aws_exporter_to_manager(manager) # try_add_aws_exporter_to_manager(manager)
if len(manager.get_exporters_list()) != 0: if len(manager.get_exporters_list()) != 0:
logger.debug( logger.debug(

View File

@ -70,10 +70,11 @@ function AWSInstanceTable(props) {
let color = 'inherit'; let color = 'inherit';
if (r) { if (r) {
let instId = r.original.instance_id; let instId = r.original.instance_id;
let runResult = getRunResults(instId);
if (isSelected(instId)) { if (isSelected(instId)) {
color = '#ffed9f'; color = '#ffed9f';
} else if (Object.prototype.hasOwnProperty.call(props.results, instId)) { } else if (runResult) {
color = props.results[instId] ? '#00f01b' : '#f00000' color = runResult.status === "error" ? '#f00000' : '#00f01b'
} }
} }
@ -82,6 +83,15 @@ function AWSInstanceTable(props) {
}; };
} }
function getRunResults(instanceId) {
for(let result of props.results){
if (result.instance_id === instanceId){
return result
}
}
return false
}
return ( return (
<div className="data-table-container"> <div className="data-table-container">
<CheckboxTable <CheckboxTable

View File

@ -0,0 +1 @@
from .stub_di_container import StubDIContainer

View File

@ -0,0 +1,20 @@
from typing import Any, Sequence, Type, TypeVar
from unittest.mock import MagicMock
from common import DIContainer, UnregisteredTypeError
T = TypeVar("T")
class StubDIContainer(DIContainer):
def resolve(self, type_: Type[T]) -> T:
try:
return super().resolve(type_)
except UnregisteredTypeError:
return MagicMock()
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
try:
return super().resolve_dependencies(type_)
except UnregisteredTypeError:
return MagicMock()

View File

@ -0,0 +1,22 @@
from queue import Queue
from common.utils.code_utils import queue_to_list
def test_empty_queue_to_empty_list():
q = Queue()
list_ = queue_to_list(q)
assert len(list_) == 0
def test_queue_to_list():
expected_list = [8, 6, 7, 5, 3, 0, 9]
q = Queue()
for i in expected_list:
q.put(i)
list_ = queue_to_list(q)
assert list_ == expected_list

View File

@ -2,7 +2,7 @@ import time
import pytest import pytest
from infection_monkey.utils.timer import Timer from common.utils import Timer
@pytest.fixture @pytest.fixture

View File

@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest import pytest
from common.utils import Timer
from infection_monkey.utils.decorators import request_cache from infection_monkey.utils.decorators import request_cache
from infection_monkey.utils.timer import Timer
class MockTimer(Timer): class MockTimer(Timer):

View File

@ -2,8 +2,8 @@ import io
from typing import BinaryIO from typing import BinaryIO
import pytest import pytest
from tests.common import StubDIContainer
from common import DIContainer
from monkey_island.cc.services import FileRetrievalError, IFileStorageService from monkey_island.cc.services import FileRetrievalError, IFileStorageService
FILE_NAME = "test_file" FILE_NAME = "test_file"
@ -31,8 +31,8 @@ class MockFileStorageService(IFileStorageService):
@pytest.fixture @pytest.fixture
def flask_client(build_flask_client, tmp_path): def flask_client(build_flask_client):
container = DIContainer() container = StubDIContainer()
container.register(IFileStorageService, MockFileStorageService) container.register(IFileStorageService, MockFileStorageService)
with build_flask_client(container) as flask_client: with build_flask_client(container) as flask_client:

View File

@ -2,9 +2,9 @@ import io
from typing import BinaryIO from typing import BinaryIO
import pytest import pytest
from tests.common import StubDIContainer
from tests.utils import raise_ from tests.utils import raise_
from common import DIContainer
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
from monkey_island.cc.services import FileRetrievalError, IFileStorageService from monkey_island.cc.services import FileRetrievalError, IFileStorageService
@ -65,7 +65,7 @@ def file_storage_service():
@pytest.fixture @pytest.fixture
def flask_client(build_flask_client, file_storage_service): def flask_client(build_flask_client, file_storage_service):
container = DIContainer() container = StubDIContainer()
container.register_instance(IFileStorageService, file_storage_service) container.register_instance(IFileStorageService, file_storage_service)
with build_flask_client(container) as flask_client: with build_flask_client(container) as flask_client:

View File

@ -0,0 +1,108 @@
import json
from unittest.mock import MagicMock
import pytest
from tests.common import StubDIContainer
from monkey_island.cc.services import AWSService
from monkey_island.cc.services.aws import AWSCommandResults, AWSCommandStatus
@pytest.fixture
def mock_aws_service():
return MagicMock(spec=AWSService)
@pytest.fixture
def flask_client(build_flask_client, mock_aws_service):
container = StubDIContainer()
container.register_instance(AWSService, mock_aws_service)
with build_flask_client(container) as flask_client:
yield flask_client
def test_get_invalid_action(flask_client):
response = flask_client.get("/api/remote-monkey?action=INVALID")
assert response.text.rstrip() == "{}"
def test_get_no_action(flask_client):
response = flask_client.get("/api/remote-monkey")
assert response.text.rstrip() == "{}"
def test_get_not_aws(flask_client, mock_aws_service):
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=False)
response = flask_client.get("/api/remote-monkey?action=list_aws")
assert response.text.rstrip() == '{"is_aws":false}'
def test_get_instances(flask_client, mock_aws_service):
instances = [
{"instance_id": "1", "name": "name1", "os": "linux", "ip_address": "1.1.1.1"},
{"instance_id": "2", "name": "name2", "os": "windows", "ip_address": "2.2.2.2"},
{"instance_id": "3", "name": "name3", "os": "linux", "ip_address": "3.3.3.3"},
]
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=True)
mock_aws_service.get_managed_instances = MagicMock(return_value=instances)
response = flask_client.get("/api/remote-monkey?action=list_aws")
assert json.loads(response.text)["instances"] == instances
assert json.loads(response.text)["is_aws"] is True
# TODO: Test error cases for get()
def test_post_no_type(flask_client):
response = flask_client.post("/api/remote-monkey", data="{}")
assert response.status_code == 500
def test_post_invalid_type(flask_client):
response = flask_client.post("/api/remote-monkey", data='{"type": "INVALID"}')
assert response.status_code == 500
def test_post(flask_client, mock_aws_service):
request_body = json.dumps(
{
"type": "aws",
"instances": [
{"instance_id": "1", "os": "linux"},
{"instance_id": "2", "os": "linux"},
{"instance_id": "3", "os": "windows"},
],
"island_ip": "127.0.0.1",
}
)
mock_aws_service.run_agents_on_managed_instances = MagicMock(
return_value=[
AWSCommandResults("1", 0, "", "", AWSCommandStatus.SUCCESS),
AWSCommandResults("2", 0, "some_output", "", AWSCommandStatus.IN_PROGRESS),
AWSCommandResults("3", -1, "", "some_error", AWSCommandStatus.ERROR),
]
)
expected_result = [
{"instance_id": "1", "response_code": 0, "stdout": "", "stderr": "", "status": "success"},
{
"instance_id": "2",
"response_code": 0,
"stdout": "some_output",
"stderr": "",
"status": "in_progress",
},
{
"instance_id": "3",
"response_code": -1,
"stdout": "",
"stderr": "some_error",
"status": "error",
},
]
response = flask_client.post("/api/remote-monkey", data=request_body)
assert json.loads(response.text)["result"] == expected_result

View File

@ -0,0 +1,232 @@
from itertools import chain, repeat
from unittest.mock import MagicMock
import pytest
from monkey_island.cc.services.aws.aws_command_runner import (
LINUX_DOCUMENT_NAME,
WINDOWS_DOCUMENT_NAME,
AWSCommandResults,
AWSCommandStatus,
start_infection_monkey_agent,
)
TIMEOUT = 0.03
INSTANCE_ID = "BEEFFACE"
ISLAND_IP = "127.0.0.1"
"""
"commands": [
"wget --no-check-certificate "
"https://172.31.32.78:5000/api/agent/download/linux "
"-O monkey-linux-64; chmod +x "
"monkey-linux-64; ./monkey-linux-64 "
"m0nk3y -s 172.31.32.78:5000"
]
"""
@pytest.fixture
def send_command_response():
return {
"Command": {
"CloudWatchOutputConfig": {
"CloudWatchLogGroupName": "",
"CloudWatchOutputEnabled": False,
},
"CommandId": "fe3cf24f-71b7-42b9-93ca-e27c34dd0581",
"CompletedCount": 0,
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"InstanceIds": ["i-0b62d6f0b0d9d7e77"],
"OutputS3Region": "eu-central-1",
"Parameters": {"commands": []},
"Status": "Pending",
"StatusDetails": "Pending",
"TargetCount": 1,
"Targets": [],
"TimeoutSeconds": 3600,
},
"ResponseMetadata": {
"HTTPHeaders": {
"connection": "keep-alive",
"content-length": "973",
"content-type": "application/x-amz-json-1.1",
"date": "Tue, 10 May 2022 12:35:49 GMT",
"server": "Server",
"x-amzn-requestid": "110b1563-aaf0-4e09-bd23-2db465822be7",
},
"HTTPStatusCode": 200,
"RequestId": "110b1563-aaf0-4e09-bd23-2db465822be7",
"RetryAttempts": 0,
},
}
@pytest.fixture
def in_progress_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "InProgress",
"StatusDetails": "InProgress",
}
@pytest.fixture
def success_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "Success",
"StatusDetails": "Success",
}
@pytest.fixture
def error_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "ERROR RUNNING COMMAND",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
# NOTE: "Error" is technically not a valid value for this field, but we want to test that
# anything other than "Success" and "InProgress" is treated as an error. This is
# simpler than testing all of the different possible error cases.
"Status": "Error",
"StatusDetails": "Error",
}
@pytest.fixture(autouse=True)
def patch_timeouts(monkeypatch):
monkeypatch.setattr(
"monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01
)
@pytest.fixture
def successful_mock_client(send_command_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=success_response)
return aws_client
def test_correct_instance_id(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["InstanceIds"] == [INSTANCE_ID]
def test_linux_doc_name(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["DocumentName"] == LINUX_DOCUMENT_NAME
def test_windows_doc_name(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert call_args_kwargs["DocumentName"] == WINDOWS_DOCUMENT_NAME
def test_linux_command(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert "wget" in call_args_kwargs["Parameters"]["commands"][0]
def test_windows_command(successful_mock_client):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
successful_mock_client.send_command.assert_called_once()
call_args_kwargs = successful_mock_client.send_command.call_args[1]
assert "DownloadFile" in call_args_kwargs["Parameters"]["commands"][0]
def test_multiple_status_queries(send_command_response, in_progress_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(
side_effect=chain([in_progress_response, in_progress_response], repeat(success_response))
)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.SUCCESS
def test_in_progress_timeout(send_command_response, in_progress_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.IN_PROGRESS
def test_failed_command(send_command_response, error_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=error_response)
command_results = start_infection_monkey_agent(
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
)
assert command_results.status == AWSCommandStatus.ERROR
@pytest.mark.parametrize(
"status, success",
[
(AWSCommandStatus.SUCCESS, True),
(AWSCommandStatus.IN_PROGRESS, False),
(AWSCommandStatus.ERROR, False),
],
)
def test_command_resuls_status(status, success):
results = AWSCommandResults(INSTANCE_ID, 0, "", "", status)
assert results.success == success

View File

@ -167,3 +167,5 @@ _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py
_.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64) _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64)
GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57) GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25) architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)