forked from p15670423/monkey
Merge branch '1928-run-agent-on-remote-instance' into 1928-aws-service-refactor
This commit is contained in:
commit
c685ce3725
|
@ -1 +1 @@
|
|||
from .di_container import DIContainer
|
||||
from .di_container import DIContainer, UnregisteredTypeError
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
class Cmd(object):
|
||||
"""
|
||||
Class representing a command
|
||||
"""
|
||||
|
||||
def __init__(self, cmd_runner, cmd_id):
|
||||
self.cmd_runner = cmd_runner
|
||||
self.cmd_id = cmd_id
|
|
@ -1,10 +0,0 @@
|
|||
class CmdResult(object):
|
||||
"""
|
||||
Class representing a command result
|
||||
"""
|
||||
|
||||
def __init__(self, is_success, status_code=None, stdout=None, stderr=None):
|
||||
self.is_success = is_success
|
||||
self.status_code = status_code
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
|
@ -1,154 +0,0 @@
|
|||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
|
||||
from common.cmd.cmd_result import CmdResult
|
||||
from common.cmd.cmd_status import CmdStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CmdRunner(object):
|
||||
"""
|
||||
Interface for running commands on a remote machine
|
||||
|
||||
Since these classes are a bit complex, I provide a list of common terminology and formats:
|
||||
* command line - a command line. e.g. 'echo hello'
|
||||
* command - represent a single command which was already run. Always of type Cmd
|
||||
* command id - any unique identifier of a command which was already run
|
||||
* command result - represents the result of running a command. Always of type CmdResult
|
||||
* command status - represents the current status of a command. Always of type CmdStatus
|
||||
* command info - Any consistent structure representing additional information of a command
|
||||
which was already run
|
||||
* instance - a machine that commands will be run on. Can be any dictionary with 'instance_id'
|
||||
as a field
|
||||
* instance_id - any unique identifier of an instance (machine). Can be of any format
|
||||
"""
|
||||
|
||||
# Default command timeout in seconds
|
||||
DEFAULT_TIMEOUT = 5
|
||||
# Time to sleep when waiting on commands.
|
||||
WAIT_SLEEP_TIME = 1
|
||||
|
||||
def __init__(self, is_linux):
|
||||
self.is_linux = is_linux
|
||||
|
||||
@staticmethod
|
||||
def run_multiple_commands(instances, inst_to_cmd, inst_n_cmd_res_to_res):
|
||||
"""
|
||||
Run multiple commands on various instances
|
||||
:param instances: List of instances.
|
||||
:param inst_to_cmd: Function which receives an instance, runs a command asynchronously
|
||||
and returns Cmd
|
||||
:param inst_n_cmd_res_to_res: Function which receives an instance and CmdResult
|
||||
and returns a parsed result (of any format)
|
||||
:return: Dictionary with 'instance_id' as key and parsed result as value
|
||||
"""
|
||||
command_instance_dict = {}
|
||||
|
||||
for instance in instances:
|
||||
command = inst_to_cmd(instance)
|
||||
command_instance_dict[command] = instance
|
||||
|
||||
instance_results = {}
|
||||
command_result_pairs = CmdRunner.wait_commands(list(command_instance_dict.keys()))
|
||||
for command, result in command_result_pairs:
|
||||
instance = command_instance_dict[command]
|
||||
instance_results[instance["instance_id"]] = inst_n_cmd_res_to_res(instance, result)
|
||||
|
||||
return instance_results
|
||||
|
||||
@abstractmethod
|
||||
def run_command_async(self, command_line):
|
||||
"""
|
||||
Runs the given command on the remote machine asynchronously.
|
||||
:param command_line: The command line to run
|
||||
:return: Command ID (in any format)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def wait_commands(commands, timeout=DEFAULT_TIMEOUT):
|
||||
"""
|
||||
Waits on all commands up to given timeout
|
||||
:param commands: list of commands (of type Cmd)
|
||||
:param timeout: Timeout in seconds for command.
|
||||
:return: commands and their results (tuple of Command and CmdResult)
|
||||
"""
|
||||
init_time = time.time()
|
||||
curr_time = init_time
|
||||
|
||||
results = []
|
||||
# TODO: Use timer.Timer
|
||||
while (curr_time - init_time < timeout) and (len(commands) != 0):
|
||||
for command in list(
|
||||
commands
|
||||
): # list(commands) clones the list. We do so because we remove items inside
|
||||
CmdRunner._process_command(command, commands, results, True)
|
||||
|
||||
time.sleep(CmdRunner.WAIT_SLEEP_TIME)
|
||||
curr_time = time.time()
|
||||
|
||||
for command in list(commands):
|
||||
CmdRunner._process_command(command, commands, results, False)
|
||||
|
||||
for command, result in results:
|
||||
if not result.is_success:
|
||||
logger.error(
|
||||
f"The command with id: {str(command.cmd_id)} failed. "
|
||||
f"Status code: {str(result.status_code)}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@abstractmethod
|
||||
def query_command(self, command_id):
|
||||
"""
|
||||
Queries the already run command for more info
|
||||
:param command_id: The command ID to query
|
||||
:return: Command info (in any format)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_command_result(self, command_info):
|
||||
"""
|
||||
Gets the result of the already run command
|
||||
:param command_info: The command info of the command to get the result of
|
||||
:return: CmdResult
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_command_status(self, command_info):
|
||||
"""
|
||||
Gets the status of the already run command
|
||||
:param command_info: The command info of the command to get the result of
|
||||
:return: CmdStatus
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _process_command(command, commands, results, should_process_only_finished):
|
||||
"""
|
||||
Removes the command from the list, processes its result and appends to results
|
||||
:param command: Command to process. Must be in commands.
|
||||
:param commands: List of unprocessed commands.
|
||||
:param results: List of command results.
|
||||
:param should_process_only_finished: If True, processes only if command finished.
|
||||
:return: None
|
||||
"""
|
||||
c_runner = command.cmd_runner
|
||||
c_id = command.cmd_id
|
||||
try:
|
||||
command_info = c_runner.query_command(c_id)
|
||||
if (not should_process_only_finished) or c_runner.get_command_status(
|
||||
command_info
|
||||
) != CmdStatus.IN_PROGRESS:
|
||||
commands.remove(command)
|
||||
results.append((command, c_runner.get_command_result(command_info)))
|
||||
except Exception:
|
||||
logger.exception("Exception while querying command: `%s`", str(c_id))
|
||||
if not should_process_only_finished:
|
||||
commands.remove(command)
|
||||
results.append((command, CmdResult(False)))
|
|
@ -1,7 +0,0 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class CmdStatus(Enum):
|
||||
IN_PROGRESS = 0
|
||||
SUCCESS = 1
|
||||
FAILURE = 2
|
|
@ -0,0 +1 @@
|
|||
from .timer import Timer
|
|
@ -1,3 +1,7 @@
|
|||
import queue
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
class abstractstatic(staticmethod):
|
||||
__slots__ = ()
|
||||
|
||||
|
@ -15,3 +19,14 @@ class Singleton(type):
|
|||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
def queue_to_list(q: queue.Queue) -> List[Any]:
|
||||
list_ = []
|
||||
try:
|
||||
while True:
|
||||
list_.append(q.get_nowait())
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
return list_
|
||||
|
|
|
@ -3,6 +3,7 @@ import time
|
|||
from pathlib import PurePath
|
||||
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from common.utils import Timer
|
||||
from infection_monkey.exploit.log4shell_utils import (
|
||||
LINUX_EXPLOIT_TEMPLATE_PATH,
|
||||
WINDOWS_EXPLOIT_TEMPLATE_PATH,
|
||||
|
@ -21,7 +22,6 @@ from infection_monkey.network.tools import get_interface_to_target
|
|||
from infection_monkey.utils.commands import build_monkey_commandline
|
||||
from infection_monkey.utils.monkey_dir import get_monkey_dir_path
|
||||
from infection_monkey.utils.threading import interruptible_iter
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from pathlib import PurePath
|
|||
import paramiko
|
||||
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from common.utils import Timer
|
||||
from common.utils.attack_utils import ScanStatus
|
||||
from common.utils.exceptions import FailedExploitationError
|
||||
from infection_monkey.exploit.HostExploiter import HostExploiter
|
||||
|
@ -17,7 +18,6 @@ from infection_monkey.telemetry.attack.t1222_telem import T1222Telem
|
|||
from infection_monkey.utils.brute_force import generate_identity_secret_pairs
|
||||
from infection_monkey.utils.commands import build_monkey_commandline
|
||||
from infection_monkey.utils.threading import interruptible_iter
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SSH_PORT = 22
|
||||
|
|
|
@ -3,6 +3,7 @@ import threading
|
|||
import time
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.credential_store import ICredentialsStore
|
||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||
from infection_monkey.i_master import IMaster
|
||||
|
@ -13,7 +14,6 @@ from infection_monkey.telemetry.credentials_telem import CredentialsTelem
|
|||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
||||
from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
from . import Exploiter, IPScanner, Propagator
|
||||
from .option_parsing import custom_pba_is_enabled
|
||||
|
|
|
@ -4,9 +4,9 @@ import socket
|
|||
import time
|
||||
from typing import Iterable, Mapping, Tuple
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.i_puppet import PortScanData, PortStatus
|
||||
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@ import queue
|
|||
import threading
|
||||
from typing import Dict
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
|
||||
from infection_monkey.telemetry.i_telem import ITelem
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
DEFAULT_PERIOD = 5
|
||||
WAKES_PER_PERIOD = 4
|
||||
|
|
|
@ -4,11 +4,11 @@ import struct
|
|||
import time
|
||||
from threading import Event, Thread
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.network.firewall import app as firewall
|
||||
from infection_monkey.network.info import get_free_tcp_port, local_ips
|
||||
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
|
||||
from infection_monkey.transport.base import get_last_serve_time
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import threading
|
||||
from functools import wraps
|
||||
|
||||
from .timer import Timer
|
||||
from common.utils import Timer
|
||||
|
||||
|
||||
def request_cache(ttl: float):
|
||||
|
|
|
@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.database import database, mongo
|
||||
from monkey_island.cc.resources import RemoteRun
|
||||
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
|
||||
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
||||
from monkey_island.cc.resources.auth.auth import Authenticate, init_jwt
|
||||
|
@ -38,7 +39,6 @@ from monkey_island.cc.resources.pba_file_download import PBAFileDownload
|
|||
from monkey_island.cc.resources.pba_file_upload import FileUpload
|
||||
from monkey_island.cc.resources.propagation_credentials import PropagationCredentials
|
||||
from monkey_island.cc.resources.ransomware_report import RansomwareReport
|
||||
from monkey_island.cc.resources.remote_run import RemoteRun
|
||||
from monkey_island.cc.resources.root import Root
|
||||
from monkey_island.cc.resources.security_report import SecurityReport
|
||||
from monkey_island.cc.resources.telemetry import Telemetry
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .remote_run import RemoteRun
|
|
@ -1,12 +1,13 @@
|
|||
import json
|
||||
from typing import Sequence
|
||||
|
||||
import flask_restful
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
from flask import jsonify, make_response, request
|
||||
|
||||
from monkey_island.cc.resources.auth.auth import jwt_required
|
||||
from monkey_island.cc.services import aws_service
|
||||
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService
|
||||
from monkey_island.cc.services import AWSService
|
||||
from monkey_island.cc.services.aws import AWSCommandResults
|
||||
|
||||
CLIENT_ERROR_FORMAT = (
|
||||
"ClientError, error message: '{}'. Probably, the IAM role that has been associated with the "
|
||||
|
@ -19,20 +20,18 @@ NO_CREDS_ERROR_FORMAT = (
|
|||
|
||||
|
||||
class RemoteRun(flask_restful.Resource):
|
||||
def run_aws_monkeys(self, request_body):
|
||||
instances = request_body.get("instances")
|
||||
island_ip = request_body.get("island_ip")
|
||||
return RemoteRunAwsService.run_aws_monkeys(instances, island_ip)
|
||||
def __init__(self, aws_service: AWSService):
|
||||
self._aws_service = aws_service
|
||||
|
||||
@jwt_required
|
||||
def get(self):
|
||||
action = request.args.get("action")
|
||||
if action == "list_aws":
|
||||
is_aws = aws_service.is_on_aws()
|
||||
is_aws = self._aws_service.island_is_running_on_aws()
|
||||
resp = {"is_aws": is_aws}
|
||||
if is_aws:
|
||||
try:
|
||||
resp["instances"] = aws_service.get_instances()
|
||||
resp["instances"] = self._aws_service.get_managed_instances()
|
||||
except NoCredentialsError as e:
|
||||
resp["error"] = NO_CREDS_ERROR_FORMAT.format(e)
|
||||
return jsonify(resp)
|
||||
|
@ -46,11 +45,32 @@ class RemoteRun(flask_restful.Resource):
|
|||
@jwt_required
|
||||
def post(self):
|
||||
body = json.loads(request.data)
|
||||
resp = {}
|
||||
if body.get("type") == "aws":
|
||||
result = self.run_aws_monkeys(body)
|
||||
resp["result"] = result
|
||||
return jsonify(resp)
|
||||
results = self.run_aws_monkeys(body)
|
||||
return RemoteRun._encode_results(results)
|
||||
|
||||
# default action
|
||||
return make_response({"error": "Invalid action"}, 500)
|
||||
|
||||
def run_aws_monkeys(self, request_body) -> Sequence[AWSCommandResults]:
|
||||
instances = request_body.get("instances")
|
||||
island_ip = request_body.get("island_ip")
|
||||
|
||||
return self._aws_service.run_agents_on_managed_instances(instances, island_ip)
|
||||
|
||||
@staticmethod
|
||||
def _encode_results(results: Sequence[AWSCommandResults]):
|
||||
result = list(map(RemoteRun._aws_command_results_to_encodable_dict, results))
|
||||
response = {"result": result}
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
@staticmethod
|
||||
def _aws_command_results_to_encodable_dict(aws_command_results: AWSCommandResults):
|
||||
return {
|
||||
"instance_id": aws_command_results.instance_id,
|
||||
"response_code": aws_command_results.response_code,
|
||||
"stdout": aws_command_results.stdout,
|
||||
"stderr": aws_command_results.stderr,
|
||||
"status": aws_command_results.status.name.lower(),
|
||||
}
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
from common.cmd.cmd_result import CmdResult
|
||||
|
||||
|
||||
class AwsCmdResult(CmdResult):
|
||||
"""
|
||||
Class representing an AWS command result
|
||||
"""
|
||||
|
||||
def __init__(self, command_info):
|
||||
super(AwsCmdResult, self).__init__(
|
||||
self.is_successful(command_info, True),
|
||||
command_info["ResponseCode"],
|
||||
command_info["StandardOutputContent"],
|
||||
command_info["StandardErrorContent"],
|
||||
)
|
||||
self.command_info = command_info
|
||||
|
||||
@staticmethod
|
||||
def is_successful(command_info, is_timeout=False):
|
||||
"""
|
||||
Determines whether the command was successful. If it timed out and was still in progress,
|
||||
we assume it worked.
|
||||
:param command_info: Command info struct (returned by ssm.get_command_invocation)
|
||||
:param is_timeout: Whether the given command timed out
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
return (command_info["Status"] == "Success") or (
|
||||
is_timeout and (command_info["Status"] == "InProgress")
|
||||
)
|
|
@ -1,45 +0,0 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from common.cmd.cmd_runner import CmdRunner
|
||||
from common.cmd.cmd_status import CmdStatus
|
||||
from monkey_island.cc.server_utils.aws_cmd_result import AwsCmdResult
|
||||
from monkey_island.cc.services import aws_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AwsCmdRunner(CmdRunner):
|
||||
"""
|
||||
Class for running commands on a remote AWS machine
|
||||
"""
|
||||
|
||||
def __init__(self, is_linux, instance_id, region=None):
|
||||
super(AwsCmdRunner, self).__init__(is_linux)
|
||||
self.instance_id = instance_id
|
||||
self.region = region
|
||||
self.ssm = aws_service.get_client("ssm", region)
|
||||
|
||||
def query_command(self, command_id):
|
||||
time.sleep(2)
|
||||
return self.ssm.get_command_invocation(CommandId=command_id, InstanceId=self.instance_id)
|
||||
|
||||
def get_command_result(self, command_info):
|
||||
return AwsCmdResult(command_info)
|
||||
|
||||
def get_command_status(self, command_info):
|
||||
if command_info["Status"] == "InProgress":
|
||||
return CmdStatus.IN_PROGRESS
|
||||
elif command_info["Status"] == "Success":
|
||||
return CmdStatus.SUCCESS
|
||||
else:
|
||||
return CmdStatus.FAILURE
|
||||
|
||||
def run_command_async(self, command_line):
|
||||
doc_name = "AWS-RunShellScript" if self.is_linux else "AWS-RunPowerShellScript"
|
||||
command_res = self.ssm.send_command(
|
||||
DocumentName=doc_name,
|
||||
Parameters={"commands": [command_line]},
|
||||
InstanceIds=[self.instance_id],
|
||||
)
|
||||
return command_res["Command"]["CommandId"]
|
|
@ -4,4 +4,8 @@ from .directory_file_storage_service import DirectoryFileStorageService
|
|||
from .authentication.authentication_service import AuthenticationService
|
||||
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
||||
|
||||
from .aws_service import AWSService
|
||||
from .aws import AWSService
|
||||
|
||||
# TODO: This is a temporary import to keep some tests passing. Remove it before merging #1928 to
|
||||
# develop.
|
||||
from .aws import aws_service
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
from .aws_service import AWSService
|
||||
from .aws_command_runner import AWSCommandResults, AWSCommandStatus
|
|
@ -0,0 +1,149 @@
|
|||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
|
||||
import botocore
|
||||
|
||||
from common.utils import Timer
|
||||
|
||||
STATUS_CHECK_SLEEP_TIME = 1
|
||||
LINUX_DOCUMENT_NAME = "AWS-RunShellScript"
|
||||
WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AWSCommandStatus(Enum):
|
||||
SUCCESS = auto()
|
||||
IN_PROGRESS = auto()
|
||||
ERROR = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AWSCommandResults:
|
||||
instance_id: str
|
||||
response_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
status: AWSCommandStatus
|
||||
|
||||
@property
|
||||
def success(self):
|
||||
return self.status == AWSCommandStatus.SUCCESS
|
||||
|
||||
|
||||
def start_infection_monkey_agent(
|
||||
aws_client: botocore.client.BaseClient,
|
||||
target_instance_id: str,
|
||||
target_os: str,
|
||||
island_ip: str,
|
||||
timeout: float,
|
||||
) -> AWSCommandResults:
|
||||
"""
|
||||
Run a command on a remote AWS instance
|
||||
"""
|
||||
command = _get_run_agent_command(target_os, island_ip)
|
||||
command_id = _run_command_async(aws_client, target_instance_id, target_os, command)
|
||||
|
||||
_wait_for_command_to_complete(aws_client, target_instance_id, command_id, timeout)
|
||||
return _fetch_command_results(aws_client, target_instance_id, command_id)
|
||||
|
||||
|
||||
def _get_run_agent_command(target_os: str, island_ip: str):
|
||||
if target_os == "linux":
|
||||
return _get_run_monkey_cmd_linux_line(island_ip)
|
||||
|
||||
return _get_run_monkey_cmd_windows_line(island_ip)
|
||||
|
||||
|
||||
def _get_run_monkey_cmd_linux_line(island_ip):
|
||||
binary_name = "monkey-linux-64"
|
||||
|
||||
download_url = f"https://{island_ip}:5000/api/agent/download/linux"
|
||||
download_cmd = f"wget --no-check-certificate {download_url} -O {binary_name}"
|
||||
|
||||
chmod_cmd = f"chmod +x {binary_name}"
|
||||
run_agent_cmd = f"./{binary_name} m0nk3y -s {island_ip}:5000"
|
||||
|
||||
return f"{download_cmd}; {chmod_cmd}; {run_agent_cmd}"
|
||||
|
||||
|
||||
def _get_run_monkey_cmd_windows_line(island_ip):
|
||||
agent_exe_path = r".\\monkey.exe"
|
||||
|
||||
ignore_ssl_errors_cmd = (
|
||||
"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {$true}"
|
||||
)
|
||||
|
||||
download_url = f"https://{island_ip}:5000/api/agent/download/windows"
|
||||
download_cmd = (
|
||||
f"(New-Object System.Net.WebClient).DownloadFile('{download_url}', '{agent_exe_path}')"
|
||||
)
|
||||
|
||||
run_agent_cmd = (
|
||||
f"Start-Process -FilePath '{agent_exe_path}' -ArgumentList 'm0nk3y -s {island_ip}:5000'"
|
||||
)
|
||||
|
||||
return f"{ignore_ssl_errors_cmd}; {download_cmd}; {run_agent_cmd};"
|
||||
|
||||
|
||||
def _run_command_async(
|
||||
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, command: str
|
||||
):
|
||||
doc_name = LINUX_DOCUMENT_NAME if target_os == "linux" else WINDOWS_DOCUMENT_NAME
|
||||
|
||||
logger.debug(f'Running command on {target_instance_id} -- {doc_name}: "{command}"')
|
||||
command_response = aws_client.send_command(
|
||||
DocumentName=doc_name,
|
||||
Parameters={"commands": [command]},
|
||||
InstanceIds=[target_instance_id],
|
||||
)
|
||||
|
||||
command_id = command_response["Command"]["CommandId"]
|
||||
logger.debug(
|
||||
f"Started command on AWS instance {target_instance_id} with command ID {command_id}"
|
||||
)
|
||||
|
||||
return command_id
|
||||
|
||||
|
||||
def _wait_for_command_to_complete(
|
||||
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str, timeout: float
|
||||
):
|
||||
timer = Timer()
|
||||
timer.set(timeout)
|
||||
|
||||
while not timer.is_expired():
|
||||
time.sleep(STATUS_CHECK_SLEEP_TIME)
|
||||
|
||||
command_results = _fetch_command_results(aws_client, target_instance_id, command_id)
|
||||
logger.debug(f"Command {command_id} status: {command_results.status.name}")
|
||||
|
||||
if command_results.status != AWSCommandStatus.IN_PROGRESS:
|
||||
return
|
||||
|
||||
|
||||
def _fetch_command_results(
|
||||
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
|
||||
) -> AWSCommandResults:
|
||||
command_results = aws_client.get_command_invocation(
|
||||
CommandId=command_id, InstanceId=target_instance_id
|
||||
)
|
||||
command_status = command_results["Status"]
|
||||
logger.debug(f"Command {command_id} status: {command_status}")
|
||||
|
||||
if command_status == "Success":
|
||||
aws_command_result_status = AWSCommandStatus.SUCCESS
|
||||
elif command_status == "InProgress":
|
||||
aws_command_result_status = AWSCommandStatus.IN_PROGRESS
|
||||
else:
|
||||
aws_command_result_status = AWSCommandStatus.ERROR
|
||||
|
||||
return AWSCommandResults(
|
||||
target_instance_id,
|
||||
command_results["ResponseCode"],
|
||||
command_results["StandardOutputContent"],
|
||||
command_results["StandardErrorContent"],
|
||||
aws_command_result_status,
|
||||
)
|
|
@ -1,11 +1,17 @@
|
|||
import logging
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, Iterable, Mapping, Sequence
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
|
||||
from common.aws.aws_instance import AWSInstance
|
||||
from common.utils.code_utils import queue_to_list
|
||||
|
||||
from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent
|
||||
|
||||
DEFAULT_REMOTE_COMMAND_TIMEOUT = 5
|
||||
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
|
||||
INSTANCE_ID_KEY = "InstanceId"
|
||||
COMPUTER_NAME_KEY = "ComputerName"
|
||||
|
@ -58,20 +64,52 @@ class AWSService:
|
|||
:raises: botocore.exceptions.ClientError if can't describe local instance information.
|
||||
:return: All visible instances from this instance
|
||||
"""
|
||||
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||
try:
|
||||
response = local_ssm_client.describe_instance_information()
|
||||
response = ssm_client.describe_instance_information()
|
||||
return response[INSTANCE_INFORMATION_LIST_KEY]
|
||||
except botocore.exceptions.ClientError as err:
|
||||
logger.warning("AWS client error while trying to get manage dinstances: {err}")
|
||||
raise err
|
||||
|
||||
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
|
||||
for id_ in instance_ids:
|
||||
self._run_agent_on_managed_instance(id_)
|
||||
def run_agents_on_managed_instances(
|
||||
self,
|
||||
instances: Iterable[Mapping[str, str]],
|
||||
island_ip: str,
|
||||
timeout: float = DEFAULT_REMOTE_COMMAND_TIMEOUT,
|
||||
) -> Sequence[AWSCommandResults]:
|
||||
"""
|
||||
Run an agent on one or more managed AWS instances.
|
||||
:param instances: An iterable of instances that the agent will be run on
|
||||
:param island_ip: The IP address of the Island to pass to the new agents
|
||||
:param timeout: The maximum number of seconds to wait for the agents to start
|
||||
:return: A sequence of AWSCommandResults
|
||||
"""
|
||||
|
||||
def _run_agent_on_managed_instance(self, instance_id: str):
|
||||
pass
|
||||
results_queue = Queue()
|
||||
command_threads = []
|
||||
for i in instances:
|
||||
t = Thread(
|
||||
target=self._run_agent_on_managed_instance,
|
||||
args=(results_queue, i["instance_id"], i["os"], island_ip, timeout),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
command_threads.append(t)
|
||||
|
||||
for thread in command_threads:
|
||||
thread.join()
|
||||
|
||||
return queue_to_list(results_queue)
|
||||
|
||||
def _run_agent_on_managed_instance(
|
||||
self, results_queue: Queue, instance_id: str, os: str, island_ip: str, timeout: float
|
||||
):
|
||||
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||
command_results = start_infection_monkey_agent(
|
||||
ssm_client, instance_id, os, island_ip, timeout
|
||||
)
|
||||
results_queue.put(command_results)
|
||||
|
||||
|
||||
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):
|
|
@ -1,82 +0,0 @@
|
|||
import logging
|
||||
|
||||
from common.cmd.cmd import Cmd
|
||||
from common.cmd.cmd_runner import CmdRunner
|
||||
from monkey_island.cc.server_utils.aws_cmd_runner import AwsCmdRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteRunAwsService:
|
||||
@staticmethod
|
||||
def run_aws_monkeys(instances, island_ip):
|
||||
"""
|
||||
Runs monkeys on the given instances
|
||||
:param instances: List of instances to run on
|
||||
:param island_ip: IP of island the monkey will communicate with
|
||||
:return: Dictionary with instance ids as keys, and True/False as values if succeeded or not
|
||||
"""
|
||||
return CmdRunner.run_multiple_commands(
|
||||
instances,
|
||||
lambda instance: RemoteRunAwsService._run_aws_monkey_cmd_async(
|
||||
instance["instance_id"],
|
||||
RemoteRunAwsService._is_linux(instance["os"]),
|
||||
island_ip,
|
||||
),
|
||||
lambda _, result: result.is_success,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _run_aws_monkey_cmd_async(instance_id, is_linux, island_ip):
|
||||
"""
|
||||
Runs a monkey remotely using AWS
|
||||
:param instance_id: Instance ID of target
|
||||
:param is_linux: Whether target is linux
|
||||
:param island_ip: IP of the island which the instance will try to connect to
|
||||
:return: Cmd
|
||||
"""
|
||||
cmd_text = RemoteRunAwsService._get_run_monkey_cmd_line(is_linux, island_ip)
|
||||
return RemoteRunAwsService._run_aws_cmd_async(instance_id, is_linux, cmd_text)
|
||||
|
||||
@staticmethod
|
||||
def _run_aws_cmd_async(instance_id, is_linux, cmd_line):
|
||||
cmd_runner = AwsCmdRunner(is_linux, instance_id)
|
||||
return Cmd(cmd_runner, cmd_runner.run_command_async(cmd_line))
|
||||
|
||||
@staticmethod
|
||||
def _is_linux(os):
|
||||
return "linux" == os
|
||||
|
||||
@staticmethod
|
||||
def _get_run_monkey_cmd_linux_line(island_ip):
|
||||
return (
|
||||
r"wget --no-check-certificate https://"
|
||||
+ island_ip
|
||||
+ r":5000/api/agent/download/linux "
|
||||
+ r"-O monkey-linux-64"
|
||||
+ r"; chmod +x monkey-linux-64"
|
||||
+ r"; ./monkey-linux-64"
|
||||
+ r" m0nk3y -s "
|
||||
+ island_ip
|
||||
+ r":5000"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_run_monkey_cmd_windows_line(island_ip):
|
||||
return (
|
||||
r"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {"
|
||||
r"$true}; (New-Object System.Net.WebClient).DownloadFile('https://"
|
||||
+ island_ip
|
||||
+ r":5000/api/agent/download/windows'"
|
||||
+ r"'.\\monkey.exe'); "
|
||||
r";Start-Process -FilePath '.\\monkey.exe' "
|
||||
r"-ArgumentList 'm0nk3y -s " + island_ip + r":5000'; "
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_run_monkey_cmd_line(is_linux, island_ip):
|
||||
return (
|
||||
RemoteRunAwsService._get_run_monkey_cmd_linux_line(island_ip)
|
||||
if is_linux
|
||||
else RemoteRunAwsService._get_run_monkey_cmd_windows_line(island_ip)
|
||||
)
|
|
@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def populate_exporter_list():
|
||||
manager = ReportExporterManager()
|
||||
try_add_aws_exporter_to_manager(manager)
|
||||
# try_add_aws_exporter_to_manager(manager)
|
||||
|
||||
if len(manager.get_exporters_list()) != 0:
|
||||
logger.debug(
|
||||
|
|
|
@ -70,10 +70,11 @@ function AWSInstanceTable(props) {
|
|||
let color = 'inherit';
|
||||
if (r) {
|
||||
let instId = r.original.instance_id;
|
||||
let runResult = getRunResults(instId);
|
||||
if (isSelected(instId)) {
|
||||
color = '#ffed9f';
|
||||
} else if (Object.prototype.hasOwnProperty.call(props.results, instId)) {
|
||||
color = props.results[instId] ? '#00f01b' : '#f00000'
|
||||
} else if (runResult) {
|
||||
color = runResult.status === "error" ? '#f00000' : '#00f01b'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -82,6 +83,15 @@ function AWSInstanceTable(props) {
|
|||
};
|
||||
}
|
||||
|
||||
function getRunResults(instanceId) {
|
||||
for(let result of props.results){
|
||||
if (result.instance_id === instanceId){
|
||||
return result
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="data-table-container">
|
||||
<CheckboxTable
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .stub_di_container import StubDIContainer
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Any, Sequence, Type, TypeVar
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from common import DIContainer, UnregisteredTypeError
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class StubDIContainer(DIContainer):
|
||||
def resolve(self, type_: Type[T]) -> T:
|
||||
try:
|
||||
return super().resolve(type_)
|
||||
except UnregisteredTypeError:
|
||||
return MagicMock()
|
||||
|
||||
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
|
||||
try:
|
||||
return super().resolve_dependencies(type_)
|
||||
except UnregisteredTypeError:
|
||||
return MagicMock()
|
|
@ -0,0 +1,22 @@
|
|||
from queue import Queue
|
||||
|
||||
from common.utils.code_utils import queue_to_list
|
||||
|
||||
|
||||
def test_empty_queue_to_empty_list():
|
||||
q = Queue()
|
||||
|
||||
list_ = queue_to_list(q)
|
||||
|
||||
assert len(list_) == 0
|
||||
|
||||
|
||||
def test_queue_to_list():
|
||||
expected_list = [8, 6, 7, 5, 3, 0, 9]
|
||||
q = Queue()
|
||||
for i in expected_list:
|
||||
q.put(i)
|
||||
|
||||
list_ = queue_to_list(q)
|
||||
|
||||
assert list_ == expected_list
|
|
@ -2,7 +2,7 @@ import time
|
|||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.utils.timer import Timer
|
||||
from common.utils import Timer
|
||||
|
||||
|
||||
@pytest.fixture
|
|
@ -3,8 +3,8 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.utils.decorators import request_cache
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
|
||||
class MockTimer(Timer):
|
||||
|
|
|
@ -2,8 +2,8 @@ import io
|
|||
from typing import BinaryIO
|
||||
|
||||
import pytest
|
||||
from tests.common import StubDIContainer
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
|
||||
FILE_NAME = "test_file"
|
||||
|
@ -31,8 +31,8 @@ class MockFileStorageService(IFileStorageService):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client, tmp_path):
|
||||
container = DIContainer()
|
||||
def flask_client(build_flask_client):
|
||||
container = StubDIContainer()
|
||||
container.register(IFileStorageService, MockFileStorageService)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
|
|
|
@ -2,9 +2,9 @@ import io
|
|||
from typing import BinaryIO
|
||||
|
||||
import pytest
|
||||
from tests.common import StubDIContainer
|
||||
from tests.utils import raise_
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
|
||||
|
@ -65,7 +65,7 @@ def file_storage_service():
|
|||
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client, file_storage_service):
|
||||
container = DIContainer()
|
||||
container = StubDIContainer()
|
||||
container.register_instance(IFileStorageService, file_storage_service)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from tests.common import StubDIContainer
|
||||
|
||||
from monkey_island.cc.services import AWSService
|
||||
from monkey_island.cc.services.aws import AWSCommandResults, AWSCommandStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_aws_service():
|
||||
return MagicMock(spec=AWSService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client, mock_aws_service):
|
||||
container = StubDIContainer()
|
||||
container.register_instance(AWSService, mock_aws_service)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
yield flask_client
|
||||
|
||||
|
||||
def test_get_invalid_action(flask_client):
|
||||
response = flask_client.get("/api/remote-monkey?action=INVALID")
|
||||
assert response.text.rstrip() == "{}"
|
||||
|
||||
|
||||
def test_get_no_action(flask_client):
|
||||
response = flask_client.get("/api/remote-monkey")
|
||||
assert response.text.rstrip() == "{}"
|
||||
|
||||
|
||||
def test_get_not_aws(flask_client, mock_aws_service):
|
||||
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=False)
|
||||
response = flask_client.get("/api/remote-monkey?action=list_aws")
|
||||
assert response.text.rstrip() == '{"is_aws":false}'
|
||||
|
||||
|
||||
def test_get_instances(flask_client, mock_aws_service):
|
||||
instances = [
|
||||
{"instance_id": "1", "name": "name1", "os": "linux", "ip_address": "1.1.1.1"},
|
||||
{"instance_id": "2", "name": "name2", "os": "windows", "ip_address": "2.2.2.2"},
|
||||
{"instance_id": "3", "name": "name3", "os": "linux", "ip_address": "3.3.3.3"},
|
||||
]
|
||||
mock_aws_service.island_is_running_on_aws = MagicMock(return_value=True)
|
||||
mock_aws_service.get_managed_instances = MagicMock(return_value=instances)
|
||||
|
||||
response = flask_client.get("/api/remote-monkey?action=list_aws")
|
||||
|
||||
assert json.loads(response.text)["instances"] == instances
|
||||
assert json.loads(response.text)["is_aws"] is True
|
||||
|
||||
|
||||
# TODO: Test error cases for get()
|
||||
|
||||
|
||||
def test_post_no_type(flask_client):
|
||||
response = flask_client.post("/api/remote-monkey", data="{}")
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_post_invalid_type(flask_client):
|
||||
response = flask_client.post("/api/remote-monkey", data='{"type": "INVALID"}')
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_post(flask_client, mock_aws_service):
|
||||
request_body = json.dumps(
|
||||
{
|
||||
"type": "aws",
|
||||
"instances": [
|
||||
{"instance_id": "1", "os": "linux"},
|
||||
{"instance_id": "2", "os": "linux"},
|
||||
{"instance_id": "3", "os": "windows"},
|
||||
],
|
||||
"island_ip": "127.0.0.1",
|
||||
}
|
||||
)
|
||||
mock_aws_service.run_agents_on_managed_instances = MagicMock(
|
||||
return_value=[
|
||||
AWSCommandResults("1", 0, "", "", AWSCommandStatus.SUCCESS),
|
||||
AWSCommandResults("2", 0, "some_output", "", AWSCommandStatus.IN_PROGRESS),
|
||||
AWSCommandResults("3", -1, "", "some_error", AWSCommandStatus.ERROR),
|
||||
]
|
||||
)
|
||||
expected_result = [
|
||||
{"instance_id": "1", "response_code": 0, "stdout": "", "stderr": "", "status": "success"},
|
||||
{
|
||||
"instance_id": "2",
|
||||
"response_code": 0,
|
||||
"stdout": "some_output",
|
||||
"stderr": "",
|
||||
"status": "in_progress",
|
||||
},
|
||||
{
|
||||
"instance_id": "3",
|
||||
"response_code": -1,
|
||||
"stdout": "",
|
||||
"stderr": "some_error",
|
||||
"status": "error",
|
||||
},
|
||||
]
|
||||
|
||||
response = flask_client.post("/api/remote-monkey", data=request_body)
|
||||
|
||||
assert json.loads(response.text)["result"] == expected_result
|
|
@ -0,0 +1,232 @@
|
|||
from itertools import chain, repeat
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from monkey_island.cc.services.aws.aws_command_runner import (
|
||||
LINUX_DOCUMENT_NAME,
|
||||
WINDOWS_DOCUMENT_NAME,
|
||||
AWSCommandResults,
|
||||
AWSCommandStatus,
|
||||
start_infection_monkey_agent,
|
||||
)
|
||||
|
||||
TIMEOUT = 0.03
|
||||
INSTANCE_ID = "BEEFFACE"
|
||||
ISLAND_IP = "127.0.0.1"
|
||||
"""
|
||||
"commands": [
|
||||
"wget --no-check-certificate "
|
||||
"https://172.31.32.78:5000/api/agent/download/linux "
|
||||
"-O monkey-linux-64; chmod +x "
|
||||
"monkey-linux-64; ./monkey-linux-64 "
|
||||
"m0nk3y -s 172.31.32.78:5000"
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def send_command_response():
|
||||
return {
|
||||
"Command": {
|
||||
"CloudWatchOutputConfig": {
|
||||
"CloudWatchLogGroupName": "",
|
||||
"CloudWatchOutputEnabled": False,
|
||||
},
|
||||
"CommandId": "fe3cf24f-71b7-42b9-93ca-e27c34dd0581",
|
||||
"CompletedCount": 0,
|
||||
"DocumentName": "AWS-RunShellScript",
|
||||
"DocumentVersion": "$DEFAULT",
|
||||
"InstanceIds": ["i-0b62d6f0b0d9d7e77"],
|
||||
"OutputS3Region": "eu-central-1",
|
||||
"Parameters": {"commands": []},
|
||||
"Status": "Pending",
|
||||
"StatusDetails": "Pending",
|
||||
"TargetCount": 1,
|
||||
"Targets": [],
|
||||
"TimeoutSeconds": 3600,
|
||||
},
|
||||
"ResponseMetadata": {
|
||||
"HTTPHeaders": {
|
||||
"connection": "keep-alive",
|
||||
"content-length": "973",
|
||||
"content-type": "application/x-amz-json-1.1",
|
||||
"date": "Tue, 10 May 2022 12:35:49 GMT",
|
||||
"server": "Server",
|
||||
"x-amzn-requestid": "110b1563-aaf0-4e09-bd23-2db465822be7",
|
||||
},
|
||||
"HTTPStatusCode": 200,
|
||||
"RequestId": "110b1563-aaf0-4e09-bd23-2db465822be7",
|
||||
"RetryAttempts": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def in_progress_response():
|
||||
return {
|
||||
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
|
||||
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
|
||||
"Comment": "",
|
||||
"DocumentName": "AWS-RunShellScript",
|
||||
"DocumentVersion": "$DEFAULT",
|
||||
"ExecutionEndDateTime": "",
|
||||
"InstanceId": "i-0b62d6f0b0d9d7e77",
|
||||
"PluginName": "aws:runShellScript",
|
||||
"ResponseCode": -1,
|
||||
"StandardErrorContent": "",
|
||||
"StandardErrorUrl": "",
|
||||
"StandardOutputContent": "",
|
||||
"StandardOutputUrl": "",
|
||||
"Status": "InProgress",
|
||||
"StatusDetails": "InProgress",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def success_response():
|
||||
return {
|
||||
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
|
||||
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
|
||||
"Comment": "",
|
||||
"DocumentName": "AWS-RunShellScript",
|
||||
"DocumentVersion": "$DEFAULT",
|
||||
"ExecutionEndDateTime": "",
|
||||
"InstanceId": "i-0b62d6f0b0d9d7e77",
|
||||
"PluginName": "aws:runShellScript",
|
||||
"ResponseCode": -1,
|
||||
"StandardErrorContent": "",
|
||||
"StandardErrorUrl": "",
|
||||
"StandardOutputContent": "",
|
||||
"StandardOutputUrl": "",
|
||||
"Status": "Success",
|
||||
"StatusDetails": "Success",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def error_response():
|
||||
return {
|
||||
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
|
||||
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
|
||||
"Comment": "",
|
||||
"DocumentName": "AWS-RunShellScript",
|
||||
"DocumentVersion": "$DEFAULT",
|
||||
"ExecutionEndDateTime": "",
|
||||
"InstanceId": "i-0b62d6f0b0d9d7e77",
|
||||
"PluginName": "aws:runShellScript",
|
||||
"ResponseCode": -1,
|
||||
"StandardErrorContent": "ERROR RUNNING COMMAND",
|
||||
"StandardErrorUrl": "",
|
||||
"StandardOutputContent": "",
|
||||
"StandardOutputUrl": "",
|
||||
# NOTE: "Error" is technically not a valid value for this field, but we want to test that
|
||||
# anything other than "Success" and "InProgress" is treated as an error. This is
|
||||
# simpler than testing all of the different possible error cases.
|
||||
"Status": "Error",
|
||||
"StatusDetails": "Error",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_timeouts(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def successful_mock_client(send_command_response, success_response):
|
||||
aws_client = MagicMock()
|
||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||
aws_client.get_command_invocation = MagicMock(return_value=success_response)
|
||||
|
||||
return aws_client
|
||||
|
||||
|
||||
def test_correct_instance_id(successful_mock_client):
|
||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
|
||||
|
||||
successful_mock_client.send_command.assert_called_once()
|
||||
call_args_kwargs = successful_mock_client.send_command.call_args[1]
|
||||
assert call_args_kwargs["InstanceIds"] == [INSTANCE_ID]
|
||||
|
||||
|
||||
def test_linux_doc_name(successful_mock_client):
|
||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
|
||||
|
||||
successful_mock_client.send_command.assert_called_once()
|
||||
call_args_kwargs = successful_mock_client.send_command.call_args[1]
|
||||
assert call_args_kwargs["DocumentName"] == LINUX_DOCUMENT_NAME
|
||||
|
||||
|
||||
def test_windows_doc_name(successful_mock_client):
|
||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
|
||||
|
||||
successful_mock_client.send_command.assert_called_once()
|
||||
call_args_kwargs = successful_mock_client.send_command.call_args[1]
|
||||
assert call_args_kwargs["DocumentName"] == WINDOWS_DOCUMENT_NAME
|
||||
|
||||
|
||||
def test_linux_command(successful_mock_client):
|
||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT)
|
||||
|
||||
successful_mock_client.send_command.assert_called_once()
|
||||
call_args_kwargs = successful_mock_client.send_command.call_args[1]
|
||||
assert "wget" in call_args_kwargs["Parameters"]["commands"][0]
|
||||
|
||||
|
||||
def test_windows_command(successful_mock_client):
|
||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT)
|
||||
|
||||
successful_mock_client.send_command.assert_called_once()
|
||||
call_args_kwargs = successful_mock_client.send_command.call_args[1]
|
||||
assert "DownloadFile" in call_args_kwargs["Parameters"]["commands"][0]
|
||||
|
||||
|
||||
def test_multiple_status_queries(send_command_response, in_progress_response, success_response):
|
||||
aws_client = MagicMock()
|
||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||
aws_client.get_command_invocation = MagicMock(
|
||||
side_effect=chain([in_progress_response, in_progress_response], repeat(success_response))
|
||||
)
|
||||
|
||||
command_results = start_infection_monkey_agent(
|
||||
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
|
||||
)
|
||||
assert command_results.status == AWSCommandStatus.SUCCESS
|
||||
|
||||
|
||||
def test_in_progress_timeout(send_command_response, in_progress_response):
|
||||
aws_client = MagicMock()
|
||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
|
||||
|
||||
command_results = start_infection_monkey_agent(
|
||||
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
|
||||
)
|
||||
assert command_results.status == AWSCommandStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_failed_command(send_command_response, error_response):
|
||||
aws_client = MagicMock()
|
||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||
aws_client.get_command_invocation = MagicMock(return_value=error_response)
|
||||
|
||||
command_results = start_infection_monkey_agent(
|
||||
aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT
|
||||
)
|
||||
assert command_results.status == AWSCommandStatus.ERROR
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status, success",
|
||||
[
|
||||
(AWSCommandStatus.SUCCESS, True),
|
||||
(AWSCommandStatus.IN_PROGRESS, False),
|
||||
(AWSCommandStatus.ERROR, False),
|
||||
],
|
||||
)
|
||||
def test_command_resuls_status(status, success):
|
||||
results = AWSCommandResults(INSTANCE_ID, 0, "", "", status)
|
||||
assert results.success == success
|
|
@ -167,3 +167,5 @@ _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py
|
|||
_.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64)
|
||||
GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
|
||||
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
|
||||
|
||||
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)
|
||||
|
|
Loading…
Reference in New Issue