diff --git a/monkey/monkey_island/cc/services/aws_service.py b/monkey/monkey_island/cc/services/aws_service.py index 1c15bb6d5..d589fad96 100644 --- a/monkey/monkey_island/cc/services/aws_service.py +++ b/monkey/monkey_island/cc/services/aws_service.py @@ -1,5 +1,5 @@ import logging -from typing import Iterable, Optional +from typing import Any, Dict, Iterable, Sequence import boto3 import botocore @@ -26,6 +26,29 @@ class AWSService: def island_aws_instance(self) -> AWSInstance: return self._aws_instance + def get_managed_instances(self) -> Sequence[Dict[str, str]]: + raw_managed_instances_info = self._get_raw_managed_instances() + return _filter_instance_info_from_aws_response(raw_managed_instances_info) + + def _get_raw_managed_instances(self) -> Sequence[Dict[str, Any]]: + """ + Get the information for all instances with the relevant roles. + + This function will assume that it's running on an EC2 instance with the correct IAM role. + See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam + -role for details. + + :raises: botocore.exceptions.ClientError if can't describe local instance information. + :return: All visible instances from this instance + """ + local_ssm_client = boto3.client("ssm", self.island_aws_instance.region) + try: + response = local_ssm_client.describe_instance_information() + return response[INSTANCE_INFORMATION_LIST_KEY] + except botocore.exceptions.ClientError as err: + logger.warning("AWS client error while trying to get manage dinstances: {err}") + raise err + def run_agent_on_managed_instances(self, instance_ids: Iterable[str]): for id_ in instance_ids: self._run_agent_on_managed_instance(id_) @@ -34,59 +57,13 @@ class AWSService: pass -def filter_instance_data_from_aws_response(response): +def _filter_instance_info_from_aws_response(raw_managed_instances_info: Sequence[Dict[str, Any]]): return [ { - "instance_id": x[INSTANCE_ID_KEY], - "name": x[COMPUTER_NAME_KEY], - "os": x[PLATFORM_TYPE_KEY].lower(), - "ip_address": x[IP_ADDRESS_KEY], + "instance_id": managed_instance[INSTANCE_ID_KEY], + "name": managed_instance[COMPUTER_NAME_KEY], + "os": managed_instance[PLATFORM_TYPE_KEY].lower(), + "ip_address": managed_instance[IP_ADDRESS_KEY], } - for x in response[INSTANCE_INFORMATION_LIST_KEY] + for managed_instance in raw_managed_instances_info ] - - -aws_instance: Optional[AWSInstance] = None - - -def initialize(): - global aws_instance - aws_instance = AWSInstance() - - -def is_on_aws(): - return aws_instance.is_instance - - -def get_region(): - return aws_instance.region - - -def get_account_id(): - return aws_instance.account_id - - -def get_client(client_type): - return boto3.client(client_type, region_name=aws_instance.region) - - -def get_instances(): - """ - Get the information for all instances with the relevant roles. - - This function will assume that it's running on an EC2 instance with the correct IAM role. - See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam - -role for details. - - :raises: botocore.exceptions.ClientError if can't describe local instance information. - :return: All visible instances from this instance - """ - local_ssm_client = boto3.client("ssm", aws_instance.region) - try: - response = local_ssm_client.describe_instance_information() - - filtered_instances_data = filter_instance_data_from_aws_response(response) - return filtered_instances_data - except botocore.exceptions.ClientError as e: - logger.warning("AWS client error while trying to get instances: " + e) - raise e diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py index bd85595f5..8be306310 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py @@ -1,65 +1,75 @@ -import json import threading -from typing import Optional -from unittest import TestCase +from typing import Any, Dict, Optional, Sequence import pytest from common.aws import AWSInstance from monkey_island.cc.services import AWSService -from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response +EXPECTED_INSTANCE_1 = { + "instance_id": "1", + "name": "comp1", + "os": "linux", + "ip_address": "192.168.1.1", +} +EXPECTED_INSTANCE_2 = { + "instance_id": "2", + "name": "comp2", + "os": "linux", + "ip_address": "192.168.1.2", +} -class TestAwsService(TestCase): - def test_filter_instance_data_from_aws_response(self): - json_response_full = """ - { - "InstanceInformationList": [ - { - "ActivationId": "string", - "AgentVersion": "string", - "AssociationOverview": { - "DetailedStatus": "string", - "InstanceAssociationStatusAggregatedCount": { - "string" : 6 - } - }, - "AssociationStatus": "string", - "ComputerName": "string", - "IamRole": "string", - "InstanceId": "string", - "IPAddress": "string", - "IsLatestVersion": "True", - "LastAssociationExecutionDate": 6, - "LastPingDateTime": 6, - "LastSuccessfulAssociationExecutionDate": 6, - "Name": "string", - "PingStatus": "string", - "PlatformName": "string", - "PlatformType": "string", - "PlatformVersion": "string", - "RegistrationDate": 6, - "ResourceType": "string" - } - ], - "NextToken": "string" - } - """ - - json_response_empty = """ - { - "InstanceInformationList": [], - "NextToken": "string" - } - """ - - self.assertEqual( - filter_instance_data_from_aws_response(json.loads(json_response_empty)), [] - ) - self.assertEqual( - filter_instance_data_from_aws_response(json.loads(json_response_full)), - [{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}], - ) +EMPTY_INSTANCE_INFO_RESPONSE = [] +FULL_INSTANCE_INFO_RESPONSE = [ + { + "ActivationId": "string", + "AgentVersion": "string", + "AssociationOverview": { + "DetailedStatus": "string", + "InstanceAssociationStatusAggregatedCount": {"string": 6}, + }, + "AssociationStatus": "string", + "ComputerName": EXPECTED_INSTANCE_1["name"], + "IamRole": "string", + "InstanceId": EXPECTED_INSTANCE_1["instance_id"], + "IPAddress": EXPECTED_INSTANCE_1["ip_address"], + "IsLatestVersion": "True", + "LastAssociationExecutionDate": 6, + "LastPingDateTime": 6, + "LastSuccessfulAssociationExecutionDate": 6, + "Name": "string", + "PingStatus": "string", + "PlatformName": "string", + "PlatformType": EXPECTED_INSTANCE_1["os"], + "PlatformVersion": "string", + "RegistrationDate": 6, + "ResourceType": "string", + }, + { + "ActivationId": "string", + "AgentVersion": "string", + "AssociationOverview": { + "DetailedStatus": "string", + "InstanceAssociationStatusAggregatedCount": {"string": 6}, + }, + "AssociationStatus": "string", + "ComputerName": EXPECTED_INSTANCE_2["name"], + "IamRole": "string", + "InstanceId": EXPECTED_INSTANCE_2["instance_id"], + "IPAddress": EXPECTED_INSTANCE_2["ip_address"], + "IsLatestVersion": "True", + "LastAssociationExecutionDate": 6, + "LastPingDateTime": 6, + "LastSuccessfulAssociationExecutionDate": 6, + "Name": "string", + "PingStatus": "string", + "PlatformName": "string", + "PlatformType": EXPECTED_INSTANCE_2["os"], + "PlatformVersion": "string", + "RegistrationDate": 6, + "ResourceType": "string", + }, +] class StubAWSInstance(AWSInstance): @@ -95,8 +105,12 @@ ACCOUNT_ID = "3" @pytest.fixture -def aws_service(): - aws_instance = StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID) +def aws_instance(): + return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID) + + +@pytest.fixture +def aws_service(aws_instance): return AWSService(aws_instance) @@ -110,3 +124,27 @@ def test_region(aws_service): def test_account_id(aws_service): assert aws_service.island_aws_instance.account_id == ACCOUNT_ID + + +class MockAWSService(AWSService): + def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]): + super().__init__(aws_instance) + self._instance_info_response = instance_info_response + + def _get_raw_managed_instances(self): + return self._instance_info_response + + +def test_get_managed_instances__empty(aws_instance): + aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE) + instances = aws_service.get_managed_instances() + assert len(instances) == 0 + + +def test_get_managed_instances(aws_instance): + aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE) + instances = aws_service.get_managed_instances() + + assert len(instances) == 2 + assert instances[0] == EXPECTED_INSTANCE_1 + assert instances[1] == EXPECTED_INSTANCE_2