diff --git a/monkey/monkey_island/cc/services/aws/aws_command_runner.py b/monkey/monkey_island/cc/services/aws/aws_command_runner.py index 88eb7e3b4..7983cf28a 100644 --- a/monkey/monkey_island/cc/services/aws/aws_command_runner.py +++ b/monkey/monkey_island/cc/services/aws/aws_command_runner.py @@ -7,6 +7,8 @@ from common.utils import Timer REMOTE_COMMAND_TIMEOUT = 5 STATUS_CHECK_SLEEP_TIME = 1 +LINUX_DOCUMENT_NAME = "AWS-RunShellScript" +WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript" logger = logging.getLogger(__name__) @@ -66,7 +68,7 @@ def _get_run_monkey_cmd_windows_line(island_ip): def _run_command_async( aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, command: str ): - doc_name = "AWS-RunShellScript" if target_os == "linux" else "AWS-RunPowerShellScript" + doc_name = LINUX_DOCUMENT_NAME if target_os == "linux" else WINDOWS_DOCUMENT_NAME logger.debug(f'Running command on {target_instance_id} -- {doc_name}: "{command}"') command_response = aws_client.send_command( diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py new file mode 100644 index 000000000..c86fe06d8 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/services/aws/test_aws_command_runner.py @@ -0,0 +1,220 @@ +from unittest.mock import MagicMock + +import pytest + +from monkey_island.cc.services.aws.aws_command_runner import ( + LINUX_DOCUMENT_NAME, + WINDOWS_DOCUMENT_NAME, + start_infection_monkey_agent, +) + +INSTANCE_ID = "BEEFFACE" +ISLAND_IP = "127.0.0.1" +""" + "commands": [ + "wget --no-check-certificate " + "https://172.31.32.78:5000/api/agent/download/linux " + "-O monkey-linux-64; chmod +x " + "monkey-linux-64; ./monkey-linux-64 " + "m0nk3y -s 172.31.32.78:5000" + ] + """ + + +@pytest.fixture +def send_command_response(): + return { + "Command": { + "CloudWatchOutputConfig": { + "CloudWatchLogGroupName": "", + "CloudWatchOutputEnabled": False, + }, + "CommandId": "fe3cf24f-71b7-42b9-93ca-e27c34dd0581", + "CompletedCount": 0, + "DocumentName": "AWS-RunShellScript", + "DocumentVersion": "$DEFAULT", + "InstanceIds": ["i-0b62d6f0b0d9d7e77"], + "OutputS3Region": "eu-central-1", + "Parameters": {"commands": []}, + "Status": "Pending", + "StatusDetails": "Pending", + "TargetCount": 1, + "Targets": [], + "TimeoutSeconds": 3600, + }, + "ResponseMetadata": { + "HTTPHeaders": { + "connection": "keep-alive", + "content-length": "973", + "content-type": "application/x-amz-json-1.1", + "date": "Tue, 10 May 2022 12:35:49 GMT", + "server": "Server", + "x-amzn-requestid": "110b1563-aaf0-4e09-bd23-2db465822be7", + }, + "HTTPStatusCode": 200, + "RequestId": "110b1563-aaf0-4e09-bd23-2db465822be7", + "RetryAttempts": 0, + }, + } + + +@pytest.fixture +def in_progress_response(): + return { + "CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False}, + "CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee", + "Comment": "", + "DocumentName": "AWS-RunShellScript", + "DocumentVersion": "$DEFAULT", + "ExecutionEndDateTime": "", + "InstanceId": "i-0b62d6f0b0d9d7e77", + "PluginName": "aws:runShellScript", + "ResponseCode": -1, + "StandardErrorContent": "", + "StandardErrorUrl": "", + "StandardOutputContent": "", + "StandardOutputUrl": "", + "Status": "InProgress", + "StatusDetails": "InProgress", + } + + +@pytest.fixture +def success_response(): + return { + "CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False}, + "CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee", + "Comment": "", + "DocumentName": "AWS-RunShellScript", + "DocumentVersion": "$DEFAULT", + "ExecutionEndDateTime": "", + "InstanceId": "i-0b62d6f0b0d9d7e77", + "PluginName": "aws:runShellScript", + "ResponseCode": -1, + "StandardErrorContent": "", + "StandardErrorUrl": "", + "StandardOutputContent": "", + "StandardOutputUrl": "", + "Status": "Success", + "StatusDetails": "Success", + } + + +@pytest.fixture +def error_response(): + return { + "CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False}, + "CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee", + "Comment": "", + "DocumentName": "AWS-RunShellScript", + "DocumentVersion": "$DEFAULT", + "ExecutionEndDateTime": "", + "InstanceId": "i-0b62d6f0b0d9d7e77", + "PluginName": "aws:runShellScript", + "ResponseCode": -1, + "StandardErrorContent": "ERROR RUNNING COMMAND", + "StandardErrorUrl": "", + "StandardOutputContent": "", + "StandardOutputUrl": "", + # NOTE: "Error" is technically not a valid value for this field, but we want to test that + # anything other than "Success" and "InProgress" is treated as an error. This is + # simpler than testing all of the different possible error cases. + "Status": "Error", + "StatusDetails": "Error", + } + + +@pytest.fixture(autouse=True) +def patch_timeouts(monkeypatch): + monkeypatch.setattr( + "monkey_island.cc.services.aws.aws_command_runner.REMOTE_COMMAND_TIMEOUT", 0.03 + ) + monkeypatch.setattr( + "monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01 + ) + + +@pytest.fixture +def successful_mock_client(send_command_response, success_response): + aws_client = MagicMock() + aws_client.send_command = MagicMock(return_value=send_command_response) + aws_client.get_command_invocation = MagicMock(return_value=success_response) + + return aws_client + + +def test_correct_instance_id(successful_mock_client, send_command_response, success_response): + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + + successful_mock_client.send_command.assert_called_once() + assert successful_mock_client.send_command.call_args.kwargs["InstanceIds"] == [INSTANCE_ID] + + +def test_linux_doc_name(successful_mock_client, send_command_response, success_response): + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + + successful_mock_client.send_command.assert_called_once() + assert ( + successful_mock_client.send_command.call_args.kwargs["DocumentName"] == LINUX_DOCUMENT_NAME + ) + + +def test_windows_doc_name(successful_mock_client, send_command_response, success_response): + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) + + successful_mock_client.send_command.assert_called_once() + assert ( + successful_mock_client.send_command.call_args.kwargs["DocumentName"] + == WINDOWS_DOCUMENT_NAME + ) + + +def test_linux_command(successful_mock_client, send_command_response, success_response): + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP) + + successful_mock_client.send_command.assert_called_once() + assert ( + "wget" in successful_mock_client.send_command.call_args.kwargs["Parameters"]["commands"][0] + ) + + +def test_windows_command(successful_mock_client, send_command_response, success_response): + start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP) + + successful_mock_client.send_command.assert_called_once() + assert ( + "DownloadFile" + in successful_mock_client.send_command.call_args.kwargs["Parameters"]["commands"][0] + ) + + +def test_in_progress_no_timeout(send_command_response, in_progress_response, success_response): + aws_client = MagicMock() + aws_client.send_command = MagicMock(return_value=send_command_response) + aws_client.get_command_invocation = MagicMock( + side_effect=[in_progress_response, in_progress_response, success_response] + ) + + # If this test fails, an exception will be raised + start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) + + +# TODO: Address this test case +""" +def test_in_progress_timeout(send_command_response, in_progress_response): + aws_client = MagicMock() + aws_client.send_command = MagicMock(return_value=send_command_response) + aws_client.get_command_invocation = MagicMock(return_value=in_progress_response) + + with pytest.raises(Exception): + start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP) +""" + + +def test_failed_command(send_command_response, error_response): + aws_client = MagicMock() + aws_client.send_command = MagicMock(return_value=send_command_response) + aws_client.get_command_invocation = MagicMock(return_value=error_response) + + with pytest.raises(Exception): + start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)