Island: Return AWSCommandResults from start_infection_monkey_agent()
This commit is contained in:
parent
e5285f2f78
commit
2804ba9b07
|
@ -1 +1,2 @@
|
||||||
from .aws_service import AWSService
|
from .aws_service import AWSService
|
||||||
|
from .aws_command_runner import AWSCommandResults, AWSCommandStatus
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
import botocore
|
import botocore
|
||||||
|
|
||||||
|
@ -13,18 +15,34 @@ WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript"
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Make sure the return type is compatible with what RemoteRun is expecting. Add typehint.
|
class AWSCommandStatus(Enum):
|
||||||
|
SUCCESS = auto()
|
||||||
|
IN_PROGRESS = auto()
|
||||||
|
ERROR = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AWSCommandResults:
|
||||||
|
response_code: int
|
||||||
|
stdout: str
|
||||||
|
stderr: str
|
||||||
|
status: AWSCommandStatus
|
||||||
|
|
||||||
|
@property
|
||||||
|
def success(self):
|
||||||
|
return self.status == AWSCommandStatus.SUCCESS
|
||||||
|
|
||||||
|
|
||||||
def start_infection_monkey_agent(
|
def start_infection_monkey_agent(
|
||||||
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, island_ip: str
|
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, island_ip: str
|
||||||
):
|
) -> AWSCommandResults:
|
||||||
"""
|
"""
|
||||||
Run a command on a remote AWS instance
|
Run a command on a remote AWS instance
|
||||||
"""
|
"""
|
||||||
command = _get_run_agent_command(target_os, island_ip)
|
command = _get_run_agent_command(target_os, island_ip)
|
||||||
command_id = _run_command_async(aws_client, target_instance_id, target_os, command)
|
command_id = _run_command_async(aws_client, target_instance_id, target_os, command)
|
||||||
_wait_for_command_to_complete(aws_client, target_instance_id, command_id)
|
|
||||||
|
|
||||||
# TODO: Return result
|
return _wait_for_command_to_complete(aws_client, target_instance_id, command_id)
|
||||||
|
|
||||||
|
|
||||||
def _get_run_agent_command(target_os: str, island_ip: str):
|
def _get_run_agent_command(target_os: str, island_ip: str):
|
||||||
|
@ -85,44 +103,44 @@ def _run_command_async(
|
||||||
return command_id
|
return command_id
|
||||||
|
|
||||||
|
|
||||||
class AWSCommandError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_command_to_complete(
|
def _wait_for_command_to_complete(
|
||||||
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
|
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
|
||||||
):
|
) -> AWSCommandResults:
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
timer.set(REMOTE_COMMAND_TIMEOUT)
|
timer.set(REMOTE_COMMAND_TIMEOUT)
|
||||||
|
|
||||||
while not timer.is_expired():
|
while not timer.is_expired():
|
||||||
time.sleep(STATUS_CHECK_SLEEP_TIME)
|
time.sleep(STATUS_CHECK_SLEEP_TIME)
|
||||||
|
|
||||||
command_result = aws_client.get_command_invocation(
|
command_results = _fetch_command_results(aws_client, target_instance_id, command_id)
|
||||||
CommandId=command_id, InstanceId=target_instance_id
|
logger.debug(f"Command {command_id} status: {command_results.status.name}")
|
||||||
)
|
|
||||||
command_status = command_result["Status"]
|
|
||||||
|
|
||||||
logger.debug(f"Command {command_id} status: {command_status}")
|
if command_results.status != AWSCommandStatus.IN_PROGRESS:
|
||||||
|
return command_results
|
||||||
|
|
||||||
if command_status == "Success":
|
return command_results
|
||||||
break
|
|
||||||
|
|
||||||
if command_status != "InProgress":
|
|
||||||
# TODO: Create an exception for this occasion and raise it with useful information.
|
|
||||||
raise AWSCommandError(
|
|
||||||
f"AWS command failed." f" Command invocation contents: {command_result}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_command_results(
|
def _fetch_command_results(
|
||||||
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
|
aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str
|
||||||
):
|
) -> AWSCommandResults:
|
||||||
command_results = aws_client.ssm.get_command_invocation(
|
command_results = aws_client.get_command_invocation(
|
||||||
CommandId=command_id, InstanceId=target_instance_id
|
CommandId=command_id, InstanceId=target_instance_id
|
||||||
)
|
)
|
||||||
# TODO: put these into a dataclass and return
|
command_status = command_results["Status"]
|
||||||
# self.is_successful(command_info, True)
|
logger.debug(f"Command {command_id} status: {command_status}")
|
||||||
# command_results["ResponseCode"]
|
|
||||||
# command_results["StandardOutputContent"]
|
aws_command_result_status = None
|
||||||
# command_results["StandardErrorContent"]
|
if command_status == "Success":
|
||||||
|
aws_command_result_status = AWSCommandStatus.SUCCESS
|
||||||
|
elif command_status == "InProgress":
|
||||||
|
aws_command_result_status = AWSCommandStatus.IN_PROGRESS
|
||||||
|
else:
|
||||||
|
aws_command_result_status = AWSCommandStatus.ERROR
|
||||||
|
|
||||||
|
return AWSCommandResults(
|
||||||
|
command_results["ResponseCode"],
|
||||||
|
command_results["StandardOutputContent"],
|
||||||
|
command_results["StandardErrorContent"],
|
||||||
|
aws_command_result_status,
|
||||||
|
)
|
||||||
|
|
|
@ -6,7 +6,7 @@ import botocore
|
||||||
|
|
||||||
from common.aws.aws_instance import AWSInstance
|
from common.aws.aws_instance import AWSInstance
|
||||||
|
|
||||||
from .aws_command_runner import start_infection_monkey_agent
|
from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent
|
||||||
|
|
||||||
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
|
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
|
||||||
INSTANCE_ID_KEY = "InstanceId"
|
INSTANCE_ID_KEY = "InstanceId"
|
||||||
|
@ -68,15 +68,14 @@ class AWSService:
|
||||||
logger.warning("AWS client error while trying to get manage dinstances: {err}")
|
logger.warning("AWS client error while trying to get manage dinstances: {err}")
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
# TODO: Determine the return type
|
|
||||||
def run_agents_on_managed_instances(
|
def run_agents_on_managed_instances(
|
||||||
self, instances: Iterable[Mapping[str, str]], island_ip: str
|
self, instances: Iterable[Mapping[str, str]], island_ip: str
|
||||||
):
|
) -> Sequence[AWSCommandResults]:
|
||||||
"""
|
"""
|
||||||
Run an agent on one or more managed AWS instances.
|
Run an agent on one or more managed AWS instances.
|
||||||
:param instances: An iterable of instances that the agent will be run on
|
:param instances: An iterable of instances that the agent will be run on
|
||||||
:param island_ip: The IP address of the Island to pass to the new agents
|
:param island_ip: The IP address of the Island to pass to the new agents
|
||||||
:return: Mapping with 'instance_id' as a key the agent's status as a value
|
:return: A sequence of AWSCommandResults
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
@ -88,7 +87,9 @@ class AWSService:
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _run_agent_on_managed_instance(self, instance_id: str, os: str, island_ip: str):
|
def _run_agent_on_managed_instance(
|
||||||
|
self, instance_id: str, os: str, island_ip: str
|
||||||
|
) -> AWSCommandResults:
|
||||||
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||||
return start_infection_monkey_agent(ssm_client, instance_id, os, island_ip)
|
return start_infection_monkey_agent(ssm_client, instance_id, os, island_ip)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,8 @@ import pytest
|
||||||
from monkey_island.cc.services.aws.aws_command_runner import (
|
from monkey_island.cc.services.aws.aws_command_runner import (
|
||||||
LINUX_DOCUMENT_NAME,
|
LINUX_DOCUMENT_NAME,
|
||||||
WINDOWS_DOCUMENT_NAME,
|
WINDOWS_DOCUMENT_NAME,
|
||||||
|
AWSCommandResults,
|
||||||
|
AWSCommandStatus,
|
||||||
start_infection_monkey_agent,
|
start_infection_monkey_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -143,14 +145,14 @@ def successful_mock_client(send_command_response, success_response):
|
||||||
return aws_client
|
return aws_client
|
||||||
|
|
||||||
|
|
||||||
def test_correct_instance_id(successful_mock_client, send_command_response, success_response):
|
def test_correct_instance_id(successful_mock_client):
|
||||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
||||||
|
|
||||||
successful_mock_client.send_command.assert_called_once()
|
successful_mock_client.send_command.assert_called_once()
|
||||||
assert successful_mock_client.send_command.call_args.kwargs["InstanceIds"] == [INSTANCE_ID]
|
assert successful_mock_client.send_command.call_args.kwargs["InstanceIds"] == [INSTANCE_ID]
|
||||||
|
|
||||||
|
|
||||||
def test_linux_doc_name(successful_mock_client, send_command_response, success_response):
|
def test_linux_doc_name(successful_mock_client):
|
||||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
||||||
|
|
||||||
successful_mock_client.send_command.assert_called_once()
|
successful_mock_client.send_command.assert_called_once()
|
||||||
|
@ -159,7 +161,7 @@ def test_linux_doc_name(successful_mock_client, send_command_response, success_r
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_windows_doc_name(successful_mock_client, send_command_response, success_response):
|
def test_windows_doc_name(successful_mock_client):
|
||||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
|
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
|
||||||
|
|
||||||
successful_mock_client.send_command.assert_called_once()
|
successful_mock_client.send_command.assert_called_once()
|
||||||
|
@ -169,7 +171,7 @@ def test_windows_doc_name(successful_mock_client, send_command_response, success
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_linux_command(successful_mock_client, send_command_response, success_response):
|
def test_linux_command(successful_mock_client):
|
||||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
|
||||||
|
|
||||||
successful_mock_client.send_command.assert_called_once()
|
successful_mock_client.send_command.assert_called_once()
|
||||||
|
@ -178,7 +180,7 @@ def test_linux_command(successful_mock_client, send_command_response, success_re
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_windows_command(successful_mock_client, send_command_response, success_response):
|
def test_windows_command(successful_mock_client):
|
||||||
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
|
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
|
||||||
|
|
||||||
successful_mock_client.send_command.assert_called_once()
|
successful_mock_client.send_command.assert_called_once()
|
||||||
|
@ -188,27 +190,24 @@ def test_windows_command(successful_mock_client, send_command_response, success_
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_in_progress_no_timeout(send_command_response, in_progress_response, success_response):
|
def test_multiple_status_queries(send_command_response, in_progress_response, success_response):
|
||||||
aws_client = MagicMock()
|
aws_client = MagicMock()
|
||||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||||
aws_client.get_command_invocation = MagicMock(
|
aws_client.get_command_invocation = MagicMock(
|
||||||
side_effect=[in_progress_response, in_progress_response, success_response]
|
side_effect=[in_progress_response, in_progress_response, success_response]
|
||||||
)
|
)
|
||||||
|
|
||||||
# If this test fails, an exception will be raised
|
command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
||||||
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
assert command_results.status == AWSCommandStatus.SUCCESS
|
||||||
|
|
||||||
|
|
||||||
# TODO: Address this test case
|
|
||||||
"""
|
|
||||||
def test_in_progress_timeout(send_command_response, in_progress_response):
|
def test_in_progress_timeout(send_command_response, in_progress_response):
|
||||||
aws_client = MagicMock()
|
aws_client = MagicMock()
|
||||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||||
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
|
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
||||||
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
assert command_results.status == AWSCommandStatus.IN_PROGRESS
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_failed_command(send_command_response, error_response):
|
def test_failed_command(send_command_response, error_response):
|
||||||
|
@ -216,5 +215,18 @@ def test_failed_command(send_command_response, error_response):
|
||||||
aws_client.send_command = MagicMock(return_value=send_command_response)
|
aws_client.send_command = MagicMock(return_value=send_command_response)
|
||||||
aws_client.get_command_invocation = MagicMock(return_value=error_response)
|
aws_client.get_command_invocation = MagicMock(return_value=error_response)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
||||||
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
|
assert command_results.status == AWSCommandStatus.ERROR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status, success",
|
||||||
|
[
|
||||||
|
(AWSCommandStatus.SUCCESS, True),
|
||||||
|
(AWSCommandStatus.IN_PROGRESS, False),
|
||||||
|
(AWSCommandStatus.ERROR, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_command_resuls_status(status, success):
|
||||||
|
results = AWSCommandResults(0, "", "", status)
|
||||||
|
assert results.success == success
|
||||||
|
|
|
@ -167,3 +167,5 @@ _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py
|
||||||
_.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64)
|
_.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64)
|
||||||
GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
|
GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57)
|
||||||
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
|
architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25)
|
||||||
|
|
||||||
|
response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)
|
||||||
|
|
Loading…
Reference in New Issue