diff --git a/monkey/monkey_island/cc/services/aws/__init__.py b/monkey/monkey_island/cc/services/aws/__init__.py index ff20c89a2..78ed07bc5 100644 --- a/monkey/monkey_island/cc/services/aws/__init__.py +++ b/monkey/monkey_island/cc/services/aws/__init__.py @@ -1 +1,2 @@ from .aws_service import AWSService +from .aws_command_runner import AWSCommandResults, AWSCommandStatus diff --git a/monkey/monkey_island/cc/services/aws/aws_command_runner.py b/monkey/monkey_island/cc/services/aws/aws_command_runner.py index b6e9aa16e..69ada244c 100644 --- a/monkey/monkey_island/cc/services/aws/aws_command_runner.py +++ b/monkey/monkey_island/cc/services/aws/aws_command_runner.py @@ -1,5 +1,7 @@ import logging import time +from dataclasses import dataclass +from enum import Enum, auto import botocore @@ -13,18 +15,34 @@ WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript" logger = logging.getLogger(__name__) -# TODO: Make sure the return type is compatible with what RemoteRun is expecting. Add typehint. +class AWSCommandStatus(Enum): + SUCCESS = auto() + IN_PROGRESS = auto() + ERROR = auto() + + +@dataclass(frozen=True) +class AWSCommandResults: + response_code: int + stdout: str + stderr: str + status: AWSCommandStatus + + @property + def success(self): + return self.status == AWSCommandStatus.SUCCESS + + def start_infection_monkey_agent( aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, island_ip: str -): +) -> AWSCommandResults: """ Run a command on a remote AWS instance """ command = _get_run_agent_command(target_os, island_ip) command_id = _run_command_async(aws_client, target_instance_id, target_os, command) - _wait_for_command_to_complete(aws_client, target_instance_id, command_id) - # TODO: Return result + return _wait_for_command_to_complete(aws_client, target_instance_id, command_id) def _get_run_agent_command(target_os: str, island_ip: str): @@ -85,44 +103,44 @@ def _run_command_async( return command_id -class AWSCommandError(Exception): - pass - - def _wait_for_command_to_complete( aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str -): +) -> AWSCommandResults: timer = Timer() timer.set(REMOTE_COMMAND_TIMEOUT) while not timer.is_expired(): time.sleep(STATUS_CHECK_SLEEP_TIME) - command_result = aws_client.get_command_invocation( - CommandId=command_id, InstanceId=target_instance_id - ) - command_status = command_result["Status"] + command_results = _fetch_command_results(aws_client, target_instance_id, command_id) + logger.debug(f"Command {command_id} status: {command_results.status.name}") - logger.debug(f"Command {command_id} status: {command_status}") + if command_results.status != AWSCommandStatus.IN_PROGRESS: + return command_results - if command_status == "Success": - break - - if command_status != "InProgress": - # TODO: Create an exception for this occasion and raise it with useful information. - raise AWSCommandError( - f"AWS command failed." f" Command invocation contents: {command_result}" - ) + return command_results def _fetch_command_results( aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str -): - command_results = aws_client.ssm.get_command_invocation( +) -> AWSCommandResults: + command_results = aws_client.get_command_invocation( CommandId=command_id, InstanceId=target_instance_id ) - # TODO: put these into a dataclass and return - # self.is_successful(command_info, True) - # command_results["ResponseCode"] - # command_results["StandardOutputContent"] - # command_results["StandardErrorContent"] + command_status = command_results["Status"] + logger.debug(f"Command {command_id} status: {command_status}") + + aws_command_result_status = None + if command_status == "Success": + aws_command_result_status = AWSCommandStatus.SUCCESS + elif command_status == "InProgress": + aws_command_result_status = AWSCommandStatus.IN_PROGRESS + else: + aws_command_result_status = AWSCommandStatus.ERROR + + return AWSCommandResults( + command_results["ResponseCode"], + command_results["StandardOutputContent"], + command_results["StandardErrorContent"], + aws_command_result_status, + ) diff --git a/monkey/monkey_island/cc/services/aws/aws_service.py b/monkey/monkey_island/cc/services/aws/aws_service.py index db0ac5468..cf4acd89d 100644 --- a/monkey/monkey_island/cc/services/aws/aws_service.py +++ b/monkey/monkey_island/cc/services/aws/aws_service.py @@ -6,7 +6,7 @@ import botocore from common.aws.aws_instance import AWSInstance -from .aws_command_runner import start_infection_monkey_agent +from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList" INSTANCE_ID_KEY = "InstanceId" @@ -68,15 +68,14 @@ class AWSService: logger.warning("AWS client error while trying to get manage dinstances: {err}") raise err - # TODO: Determine the return type def run_agents_on_managed_instances( self, instances: Iterable[Mapping[str, str]], island_ip: str - ): + ) -> Sequence[AWSCommandResults]: """ Run an agent on one or more managed AWS instances. :param instances: An iterable of instances that the agent will be run on :param island_ip: The IP address of the Island to pass to the new agents - :return: Mapping with 'instance_id' as a key the agent's status as a value + :return: A sequence of AWSCommandResults """ results = [] @@ -88,7 +87,9 @@ class AWSService: return results - def _run_agent_on_managed_instance(self, instance_id: str, os: str, island_ip: str): + def _run_agent_on_managed_instance( + self, instance_id: str, os: str, island_ip: str + ) -> AWSCommandResults: ssm_client = boto3.client("ssm", self.island_aws_instance.region) return start_infection_monkey_agent(ssm_client, instance_id, os, island_ip) diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py index c86fe06d8..3386772ed 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py @@ -5,6 +5,8 @@ import pytest from monkey_island.cc.services.aws.aws_command_runner import ( LINUX_DOCUMENT_NAME, WINDOWS_DOCUMENT_NAME, + AWSCommandResults, + AWSCommandStatus, start_infection_monkey_agent, ) @@ -143,14 +145,14 @@ def successful_mock_client(send_command_response, success_response): return aws_client -def test_correct_instance_id(successful_mock_client, send_command_response, success_response): +def test_correct_instance_id(successful_mock_client): start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) successful_mock_client.send_command.assert_called_once() assert successful_mock_client.send_command.call_args.kwargs["InstanceIds"] == [INSTANCE_ID] -def test_linux_doc_name(successful_mock_client, send_command_response, success_response): +def test_linux_doc_name(successful_mock_client): start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) successful_mock_client.send_command.assert_called_once() @@ -159,7 +161,7 @@ def test_linux_doc_name(successful_mock_client, send_command_response, success_r ) -def test_windows_doc_name(successful_mock_client, send_command_response, success_response): +def test_windows_doc_name(successful_mock_client): start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) successful_mock_client.send_command.assert_called_once() @@ -169,7 +171,7 @@ def test_windows_doc_name(successful_mock_client, send_command_response, success ) -def test_linux_command(successful_mock_client, send_command_response, success_response): +def test_linux_command(successful_mock_client): start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) successful_mock_client.send_command.assert_called_once() @@ -178,7 +180,7 @@ def test_linux_command(successful_mock_client, send_command_response, success_re ) -def test_windows_command(successful_mock_client, send_command_response, success_response): +def test_windows_command(successful_mock_client): start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) successful_mock_client.send_command.assert_called_once() @@ -188,27 +190,24 @@ def test_windows_command(successful_mock_client, send_command_response, success_ ) -def test_in_progress_no_timeout(send_command_response, in_progress_response, success_response): +def test_multiple_status_queries(send_command_response, in_progress_response, success_response): aws_client = MagicMock() aws_client.send_command = MagicMock(return_value=send_command_response) aws_client.get_command_invocation = MagicMock( side_effect=[in_progress_response, in_progress_response, success_response] ) - # If this test fails, an exception will be raised - start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + assert command_results.status == AWSCommandStatus.SUCCESS -# TODO: Address this test case -""" def test_in_progress_timeout(send_command_response, in_progress_response): aws_client = MagicMock() aws_client.send_command = MagicMock(return_value=send_command_response) aws_client.get_command_invocation = MagicMock(return_value=in_progress_response) - with pytest.raises(Exception): - start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) -""" + command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + assert command_results.status == AWSCommandStatus.IN_PROGRESS def test_failed_command(send_command_response, error_response): @@ -216,5 +215,18 @@ def test_failed_command(send_command_response, error_response): aws_client.send_command = MagicMock(return_value=send_command_response) aws_client.get_command_invocation = MagicMock(return_value=error_response) - with pytest.raises(Exception): - start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + assert command_results.status == AWSCommandStatus.ERROR + + +@pytest.mark.parametrize( + "status, success", + [ + (AWSCommandStatus.SUCCESS, True), + (AWSCommandStatus.IN_PROGRESS, False), + (AWSCommandStatus.ERROR, False), + ], +) +def test_command_resuls_status(status, success): + results = AWSCommandResults(0, "", "", status) + assert results.success == success diff --git a/vulture_allowlist.py b/vulture_allowlist.py index c8c1378d4..cec60a0c9 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -167,3 +167,5 @@ _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py _.instance_name # unused attribute (monkey/common/cloud/azure/azure_instance.py:64) GCPHandler # unused function (envs/monkey_zoo/blackbox/test_blackbox.py:57) architecture # unused variable (monkey/infection_monkey/exploit/caching_agent_repository.py:25) + +response_code # unused variable (monkey/monkey_island/cc/services/aws/aws_command_runner.py:26)