From 7c0616eecdaf221ea2fe2ee2030a3c94382c6efd Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Wed, 11 May 2022 08:18:40 -0400 Subject: [PATCH] Island: Add timeout parameter to start_infection_monkey_agent() --- .../cc/services/aws/aws_command_runner.py | 13 ++++++---- .../cc/services/aws/aws_service.py | 15 ++++++++--- .../services/aws/test_aws_command_runner.py | 26 +++++++++++-------- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/monkey/monkey_island/cc/services/aws/aws_command_runner.py b/monkey/monkey_island/cc/services/aws/aws_command_runner.py index 9a85ba49a..e896e2cfb 100644 --- a/monkey/monkey_island/cc/services/aws/aws_command_runner.py +++ b/monkey/monkey_island/cc/services/aws/aws_command_runner.py @@ -7,7 +7,6 @@ import botocore from common.utils import Timer -REMOTE_COMMAND_TIMEOUT = 5 STATUS_CHECK_SLEEP_TIME = 1 LINUX_DOCUMENT_NAME = "AWS-RunShellScript" WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript" @@ -35,7 +34,11 @@ class AWSCommandResults: def start_infection_monkey_agent( - aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, island_ip: str + aws_client: botocore.client.BaseClient, + target_instance_id: str, + target_os: str, + island_ip: str, + timeout: float, ) -> AWSCommandResults: """ Run a command on a remote AWS instance @@ -43,7 +46,7 @@ def start_infection_monkey_agent( command = _get_run_agent_command(target_os, island_ip) command_id = _run_command_async(aws_client, target_instance_id, target_os, command) - _wait_for_command_to_complete(aws_client, target_instance_id, command_id) + _wait_for_command_to_complete(aws_client, target_instance_id, command_id, timeout) return _fetch_command_results(aws_client, target_instance_id, command_id) @@ -106,10 +109,10 @@ def _run_command_async( def _wait_for_command_to_complete( - aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str + aws_client: botocore.client.BaseClient, target_instance_id: str, command_id: str, timeout: float ): timer = Timer() - timer.set(REMOTE_COMMAND_TIMEOUT) + timer.set(timeout) while not timer.is_expired(): time.sleep(STATUS_CHECK_SLEEP_TIME) diff --git a/monkey/monkey_island/cc/services/aws/aws_service.py b/monkey/monkey_island/cc/services/aws/aws_service.py index 5a6151580..92a0e8bf5 100644 --- a/monkey/monkey_island/cc/services/aws/aws_service.py +++ b/monkey/monkey_island/cc/services/aws/aws_service.py @@ -11,6 +11,7 @@ from common.utils.code_utils import queue_to_list from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent +DEFAULT_REMOTE_COMMAND_TIMEOUT = 5 INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList" INSTANCE_ID_KEY = "InstanceId" COMPUTER_NAME_KEY = "ComputerName" @@ -72,12 +73,16 @@ class AWSService: raise err def run_agents_on_managed_instances( - self, instances: Iterable[Mapping[str, str]], island_ip: str + self, + instances: Iterable[Mapping[str, str]], + island_ip: str, + timeout: float = DEFAULT_REMOTE_COMMAND_TIMEOUT, ) -> Sequence[AWSCommandResults]: """ Run an agent on one or more managed AWS instances. :param instances: An iterable of instances that the agent will be run on :param island_ip: The IP address of the Island to pass to the new agents + :param timeout: The maximum number of seconds to wait for the agents to start :return: A sequence of AWSCommandResults """ @@ -86,7 +91,7 @@ class AWSService: for i in instances: t = Thread( target=self._run_agent_on_managed_instance, - args=(results_queue, i["instance_id"], i["os"], island_ip), + args=(results_queue, i["instance_id"], i["os"], island_ip, timeout), daemon=True, ) t.start() @@ -98,10 +103,12 @@ class AWSService: return queue_to_list(results_queue) def _run_agent_on_managed_instance( - self, results_queue: Queue, instance_id: str, os: str, island_ip: str + self, results_queue: Queue, instance_id: str, os: str, island_ip: str, timeout: float ): ssm_client = boto3.client("ssm", self.island_aws_instance.region) - command_results = start_infection_monkey_agent(ssm_client, instance_id, os, island_ip) + command_results = start_infection_monkey_agent( + ssm_client, instance_id, os, island_ip, timeout + ) results_queue.put(command_results) diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py index dd1877a6a..aa4cfdb4b 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py @@ -11,6 +11,7 @@ from monkey_island.cc.services.aws.aws_command_runner import ( start_infection_monkey_agent, ) +TIMEOUT = 0.03 INSTANCE_ID = "BEEFFACE" ISLAND_IP = "127.0.0.1" """ @@ -129,9 +130,6 @@ def error_response(): @pytest.fixture(autouse=True) def patch_timeouts(monkeypatch): - monkeypatch.setattr( - "monkey_island.cc.services.aws.aws_command_runner.REMOTE_COMMAND_TIMEOUT", 0.03 - ) monkeypatch.setattr( "monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01 ) @@ -147,7 +145,7 @@ def successful_mock_client(send_command_response, success_response): def test_correct_instance_id(successful_mock_client): - start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT) successful_mock_client.send_command.assert_called_once() call_args_kwargs = successful_mock_client.send_command.call_args[1] @@ -155,7 +153,7 @@ def test_correct_instance_id(successful_mock_client): def test_linux_doc_name(successful_mock_client): - start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT) successful_mock_client.send_command.assert_called_once() call_args_kwargs = successful_mock_client.send_command.call_args[1] @@ -163,7 +161,7 @@ def test_linux_doc_name(successful_mock_client): def test_windows_doc_name(successful_mock_client): - start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT) successful_mock_client.send_command.assert_called_once() call_args_kwargs = successful_mock_client.send_command.call_args[1] @@ -171,7 +169,7 @@ def test_windows_doc_name(successful_mock_client): def test_linux_command(successful_mock_client): - start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP, TIMEOUT) successful_mock_client.send_command.assert_called_once() call_args_kwargs = successful_mock_client.send_command.call_args[1] @@ -179,7 +177,7 @@ def test_linux_command(successful_mock_client): def test_windows_command(successful_mock_client): - start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT) successful_mock_client.send_command.assert_called_once() call_args_kwargs = successful_mock_client.send_command.call_args[1] @@ -193,7 +191,9 @@ def test_multiple_status_queries(send_command_response, in_progress_response, su side_effect=chain([in_progress_response, in_progress_response], repeat(success_response)) ) - command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + command_results = start_infection_monkey_agent( + aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT + ) assert command_results.status == AWSCommandStatus.SUCCESS @@ -202,7 +202,9 @@ def test_in_progress_timeout(send_command_response, in_progress_response): aws_client.send_command = MagicMock(return_value=send_command_response) aws_client.get_command_invocation = MagicMock(return_value=in_progress_response) - command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + command_results = start_infection_monkey_agent( + aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT + ) assert command_results.status == AWSCommandStatus.IN_PROGRESS @@ -211,7 +213,9 @@ def test_failed_command(send_command_response, error_response): aws_client.send_command = MagicMock(return_value=send_command_response) aws_client.get_command_invocation = MagicMock(return_value=error_response) - command_results = start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + command_results = start_infection_monkey_agent( + aws_client, INSTANCE_ID, "windows", ISLAND_IP, TIMEOUT + ) assert command_results.status == AWSCommandStatus.ERROR