UT: Add unit tests for aws_command_runner

This commit is contained in:
Mike Salvatore 2022-05-10 09:47:46 -04:00
parent 79eb584c5d
commit 27f8195be5
2 changed files with 223 additions and 1 deletions

View File

@ -7,6 +7,8 @@ from common.utils import Timer
REMOTE_COMMAND_TIMEOUT = 5
STATUS_CHECK_SLEEP_TIME = 1
LINUX_DOCUMENT_NAME = "AWS-RunShellScript"
WINDOWS_DOCUMENT_NAME = "AWS-RunPowerShellScript"
logger = logging.getLogger(__name__)
@ -66,7 +68,7 @@ def _get_run_monkey_cmd_windows_line(island_ip):
def _run_command_async(
aws_client: botocore.client.BaseClient, target_instance_id: str, target_os: str, command: str
):
doc_name = "AWS-RunShellScript" if target_os == "linux" else "AWS-RunPowerShellScript"
doc_name = LINUX_DOCUMENT_NAME if target_os == "linux" else WINDOWS_DOCUMENT_NAME
logger.debug(f'Running command on {target_instance_id} -- {doc_name}: "{command}"')
command_response = aws_client.send_command(

View File

@ -0,0 +1,220 @@
from unittest.mock import MagicMock
import pytest
from monkey_island.cc.services.aws.aws_command_runner import (
LINUX_DOCUMENT_NAME,
WINDOWS_DOCUMENT_NAME,
start_infection_monkey_agent,
)
INSTANCE_ID = "BEEFFACE"
ISLAND_IP = "127.0.0.1"
"""
"commands": [
"wget --no-check-certificate "
"https://172.31.32.78:5000/api/agent/download/linux "
"-O monkey-linux-64; chmod +x "
"monkey-linux-64; ./monkey-linux-64 "
"m0nk3y -s 172.31.32.78:5000"
]
"""
@pytest.fixture
def send_command_response():
return {
"Command": {
"CloudWatchOutputConfig": {
"CloudWatchLogGroupName": "",
"CloudWatchOutputEnabled": False,
},
"CommandId": "fe3cf24f-71b7-42b9-93ca-e27c34dd0581",
"CompletedCount": 0,
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"InstanceIds": ["i-0b62d6f0b0d9d7e77"],
"OutputS3Region": "eu-central-1",
"Parameters": {"commands": []},
"Status": "Pending",
"StatusDetails": "Pending",
"TargetCount": 1,
"Targets": [],
"TimeoutSeconds": 3600,
},
"ResponseMetadata": {
"HTTPHeaders": {
"connection": "keep-alive",
"content-length": "973",
"content-type": "application/x-amz-json-1.1",
"date": "Tue, 10 May 2022 12:35:49 GMT",
"server": "Server",
"x-amzn-requestid": "110b1563-aaf0-4e09-bd23-2db465822be7",
},
"HTTPStatusCode": 200,
"RequestId": "110b1563-aaf0-4e09-bd23-2db465822be7",
"RetryAttempts": 0,
},
}
@pytest.fixture
def in_progress_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "InProgress",
"StatusDetails": "InProgress",
}
@pytest.fixture
def success_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"Status": "Success",
"StatusDetails": "Success",
}
@pytest.fixture
def error_response():
return {
"CloudWatchOutputConfig": {"CloudWatchLogGroupName": "", "CloudWatchOutputEnabled": False},
"CommandId": "a5332cc6-0f9f-48e6-826a-d4bd7cabc2ee",
"Comment": "",
"DocumentName": "AWS-RunShellScript",
"DocumentVersion": "$DEFAULT",
"ExecutionEndDateTime": "",
"InstanceId": "i-0b62d6f0b0d9d7e77",
"PluginName": "aws:runShellScript",
"ResponseCode": -1,
"StandardErrorContent": "ERROR RUNNING COMMAND",
"StandardErrorUrl": "",
"StandardOutputContent": "",
"StandardOutputUrl": "",
# NOTE: "Error" is technically not a valid value for this field, but we want to test that
# anything other than "Success" and "InProgress" is treated as an error. This is
# simpler than testing all of the different possible error cases.
"Status": "Error",
"StatusDetails": "Error",
}
@pytest.fixture(autouse=True)
def patch_timeouts(monkeypatch):
monkeypatch.setattr(
"monkey_island.cc.services.aws.aws_command_runner.REMOTE_COMMAND_TIMEOUT", 0.03
)
monkeypatch.setattr(
"monkey_island.cc.services.aws.aws_command_runner.STATUS_CHECK_SLEEP_TIME", 0.01
)
@pytest.fixture
def successful_mock_client(send_command_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=success_response)
return aws_client
def test_correct_instance_id(successful_mock_client, send_command_response, success_response):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
successful_mock_client.send_command.assert_called_once()
assert successful_mock_client.send_command.call_args.kwargs["InstanceIds"] == [INSTANCE_ID]
def test_linux_doc_name(successful_mock_client, send_command_response, success_response):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
successful_mock_client.send_command.assert_called_once()
assert (
successful_mock_client.send_command.call_args.kwargs["DocumentName"] == LINUX_DOCUMENT_NAME
)
def test_windows_doc_name(successful_mock_client, send_command_response, success_response):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
successful_mock_client.send_command.assert_called_once()
assert (
successful_mock_client.send_command.call_args.kwargs["DocumentName"]
== WINDOWS_DOCUMENT_NAME
)
def test_linux_command(successful_mock_client, send_command_response, success_response):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "linux", ISLAND_IP)
successful_mock_client.send_command.assert_called_once()
assert (
"wget" in successful_mock_client.send_command.call_args.kwargs["Parameters"]["commands"][0]
)
def test_windows_command(successful_mock_client, send_command_response, success_response):
start_infection_monkey_agent(successful_mock_client, INSTANCE_ID, "windows", ISLAND_IP)
successful_mock_client.send_command.assert_called_once()
assert (
"DownloadFile"
in successful_mock_client.send_command.call_args.kwargs["Parameters"]["commands"][0]
)
def test_in_progress_no_timeout(send_command_response, in_progress_response, success_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(
side_effect=[in_progress_response, in_progress_response, success_response]
)
# If this test fails, an exception will be raised
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
# TODO: Address this test case
"""
def test_in_progress_timeout(send_command_response, in_progress_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=in_progress_response)
with pytest.raises(Exception):
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)
"""
def test_failed_command(send_command_response, error_response):
aws_client = MagicMock()
aws_client.send_command = MagicMock(return_value=send_command_response)
aws_client.get_command_invocation = MagicMock(return_value=error_response)
with pytest.raises(Exception):
start_infection_monkey_agent(aws_client, INSTANCE_ID, "windows", ISLAND_IP)