diff --git a/monkey/monkey_island/cc/services/__init__.py b/monkey/monkey_island/cc/services/__init__.py index 43aa39382..7d96ee5c4 100644 --- a/monkey/monkey_island/cc/services/__init__.py +++ b/monkey/monkey_island/cc/services/__init__.py @@ -3,3 +3,5 @@ from .directory_file_storage_service import DirectoryFileStorageService from .authentication.authentication_service import AuthenticationService from .authentication.json_file_user_datastore import JsonFileUserDatastore + +from .aws_service import AWSService diff --git a/monkey/monkey_island/cc/services/aws_service.py b/monkey/monkey_island/cc/services/aws_service.py index 12432b8b7..1a8dec455 100644 --- a/monkey/monkey_island/cc/services/aws_service.py +++ b/monkey/monkey_island/cc/services/aws_service.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Any, Iterable, Mapping, Sequence import boto3 import botocore @@ -15,59 +15,83 @@ IP_ADDRESS_KEY = "IPAddress" logger = logging.getLogger(__name__) -def filter_instance_data_from_aws_response(response): +class AWSService: + def __init__(self, aws_instance: AWSInstance): + """ + :param aws_instance: An AWSInstance object representing the AWS instance that the Island is + running on + """ + self._aws_instance = aws_instance + + def island_is_running_on_aws(self) -> bool: + """ + :return: True if the island is running on an AWS instance. False otherwise. + :rtype: bool + """ + return self._aws_instance.is_instance + + @property + def island_aws_instance(self) -> AWSInstance: + """ + :return: an AWSInstance object representing the AWS instance that the Island is running on. + :rtype: AWSInstance + """ + return self._aws_instance + + def get_managed_instances(self) -> Sequence[Mapping[str, str]]: + """ + :return: A sequence of mappings, where each Mapping represents a managed AWS instance that + is accessible from the Island. + :rtype: Sequence[Mapping[str, str]] + """ + raw_managed_instances_info = self._get_raw_managed_instances() + return _filter_relevant_instance_info(raw_managed_instances_info) + + def _get_raw_managed_instances(self) -> Sequence[Mapping[str, Any]]: + """ + Get the information for all instances with the relevant roles. + + This function will assume that it's running on an EC2 instance with the correct IAM role. + See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam + -role for details. + + :raises: botocore.exceptions.ClientError if can't describe local instance information. + :return: All visible instances from this instance + """ + local_ssm_client = boto3.client("ssm", self.island_aws_instance.region) + try: + response = local_ssm_client.describe_instance_information() + return response[INSTANCE_INFORMATION_LIST_KEY] + except botocore.exceptions.ClientError as err: + logger.warning("AWS client error while trying to get manage dinstances: {err}") + raise err + + def run_agent_on_managed_instances(self, instance_ids: Iterable[str]): + for id_ in instance_ids: + self._run_agent_on_managed_instance(id_) + + def _run_agent_on_managed_instance(self, instance_id: str): + pass + + +def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]): + """ + Consume raw instance data from the AWS API and return only those fields that are relevant for + Infection Monkey. + + :param raw_managed_instances_info: The output of + DescribeInstanceInformation["InstanceInformation"] from the + AWS API + :return: A sequence of mappings, where each Mapping represents a managed AWS instance that + is accessible from the Island. + :rtype: Sequence[Mapping[str, str]] + """ return [ { - "instance_id": x[INSTANCE_ID_KEY], - "name": x[COMPUTER_NAME_KEY], - "os": x[PLATFORM_TYPE_KEY].lower(), - "ip_address": x[IP_ADDRESS_KEY], + "instance_id": managed_instance[INSTANCE_ID_KEY], + "name": managed_instance[COMPUTER_NAME_KEY], + "os": managed_instance[PLATFORM_TYPE_KEY].lower(), + "ip_address": managed_instance[IP_ADDRESS_KEY], } - for x in response[INSTANCE_INFORMATION_LIST_KEY] + for managed_instance in raw_managed_instances_info ] - - -aws_instance: Optional[AWSInstance] = None - - -def initialize(): - global aws_instance - aws_instance = AWSInstance() - - -def is_on_aws(): - return aws_instance.is_instance - - -def get_region(): - return aws_instance.region - - -def get_account_id(): - return aws_instance.account_id - - -def get_client(client_type): - return boto3.client(client_type, region_name=aws_instance.region) - - -def get_instances(): - """ - Get the information for all instances with the relevant roles. - - This function will assume that it's running on an EC2 instance with the correct IAM role. - See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam - -role for details. - - :raises: botocore.exceptions.ClientError if can't describe local instance information. - :return: All visible instances from this instance - """ - local_ssm_client = boto3.client("ssm", aws_instance.region) - try: - response = local_ssm_client.describe_instance_information() - - filtered_instances_data = filter_instance_data_from_aws_response(response) - return filtered_instances_data - except botocore.exceptions.ClientError as e: - logger.warning("AWS client error while trying to get instances: " + e) - raise e diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index faa3bcef9..8e0540329 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -1,8 +1,8 @@ from pathlib import Path -from threading import Thread from common import DIContainer -from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService, aws_service +from common.aws import AWSInstance +from monkey_island.cc.services import AWSService, DirectoryFileStorageService, IFileStorageService from monkey_island.cc.services.post_breach_files import PostBreachFilesService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService @@ -11,12 +11,12 @@ from . import AuthenticationService, JsonFileUserDatastore def initialize_services(data_dir: Path) -> DIContainer: container = DIContainer() + container.register_instance(AWSInstance, AWSInstance()) + container.register_instance( IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas") ) - - # Takes a while so it's best to start it in the background - Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start() + container.register_instance(AWSService, container.resolve(AWSService)) # This is temporary until we get DI all worked out. PostBreachFilesService.initialize(container.resolve(IFileStorageService)) diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py index 0d8a71f36..8be306310 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_aws_service.py @@ -1,56 +1,150 @@ -import json -from unittest import TestCase +import threading +from typing import Any, Dict, Optional, Sequence -from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response +import pytest + +from common.aws import AWSInstance +from monkey_island.cc.services import AWSService + +EXPECTED_INSTANCE_1 = { + "instance_id": "1", + "name": "comp1", + "os": "linux", + "ip_address": "192.168.1.1", +} +EXPECTED_INSTANCE_2 = { + "instance_id": "2", + "name": "comp2", + "os": "linux", + "ip_address": "192.168.1.2", +} + +EMPTY_INSTANCE_INFO_RESPONSE = [] +FULL_INSTANCE_INFO_RESPONSE = [ + { + "ActivationId": "string", + "AgentVersion": "string", + "AssociationOverview": { + "DetailedStatus": "string", + "InstanceAssociationStatusAggregatedCount": {"string": 6}, + }, + "AssociationStatus": "string", + "ComputerName": EXPECTED_INSTANCE_1["name"], + "IamRole": "string", + "InstanceId": EXPECTED_INSTANCE_1["instance_id"], + "IPAddress": EXPECTED_INSTANCE_1["ip_address"], + "IsLatestVersion": "True", + "LastAssociationExecutionDate": 6, + "LastPingDateTime": 6, + "LastSuccessfulAssociationExecutionDate": 6, + "Name": "string", + "PingStatus": "string", + "PlatformName": "string", + "PlatformType": EXPECTED_INSTANCE_1["os"], + "PlatformVersion": "string", + "RegistrationDate": 6, + "ResourceType": "string", + }, + { + "ActivationId": "string", + "AgentVersion": "string", + "AssociationOverview": { + "DetailedStatus": "string", + "InstanceAssociationStatusAggregatedCount": {"string": 6}, + }, + "AssociationStatus": "string", + "ComputerName": EXPECTED_INSTANCE_2["name"], + "IamRole": "string", + "InstanceId": EXPECTED_INSTANCE_2["instance_id"], + "IPAddress": EXPECTED_INSTANCE_2["ip_address"], + "IsLatestVersion": "True", + "LastAssociationExecutionDate": 6, + "LastPingDateTime": 6, + "LastSuccessfulAssociationExecutionDate": 6, + "Name": "string", + "PingStatus": "string", + "PlatformName": "string", + "PlatformType": EXPECTED_INSTANCE_2["os"], + "PlatformVersion": "string", + "RegistrationDate": 6, + "ResourceType": "string", + }, +] -class TestAwsService(TestCase): - def test_filter_instance_data_from_aws_response(self): - json_response_full = """ - { - "InstanceInformationList": [ - { - "ActivationId": "string", - "AgentVersion": "string", - "AssociationOverview": { - "DetailedStatus": "string", - "InstanceAssociationStatusAggregatedCount": { - "string" : 6 - } - }, - "AssociationStatus": "string", - "ComputerName": "string", - "IamRole": "string", - "InstanceId": "string", - "IPAddress": "string", - "IsLatestVersion": "True", - "LastAssociationExecutionDate": 6, - "LastPingDateTime": 6, - "LastSuccessfulAssociationExecutionDate": 6, - "Name": "string", - "PingStatus": "string", - "PlatformName": "string", - "PlatformType": "string", - "PlatformVersion": "string", - "RegistrationDate": 6, - "ResourceType": "string" - } - ], - "NextToken": "string" - } - """ +class StubAWSInstance(AWSInstance): + def __init__( + self, + instance_id: Optional[str] = None, + region: Optional[str] = None, + account_id: Optional[str] = None, + ): + self._instance_id = instance_id + self._region = region + self._account_id = account_id - json_response_empty = """ - { - "InstanceInformationList": [], - "NextToken": "string" - } - """ + self._initialization_complete = threading.Event() + self._initialization_complete.set() - self.assertEqual( - filter_instance_data_from_aws_response(json.loads(json_response_empty)), [] - ) - self.assertEqual( - filter_instance_data_from_aws_response(json.loads(json_response_full)), - [{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}], - ) + +def test_aws_is_on_aws__true(): + aws_instance = StubAWSInstance("1") + aws_service = AWSService(aws_instance) + assert aws_service.island_is_running_on_aws() is True + + +def test_aws_is_on_aws__False(): + aws_instance = StubAWSInstance() + aws_service = AWSService(aws_instance) + assert aws_service.island_is_running_on_aws() is False + + +INSTANCE_ID = "1" +REGION = "2" +ACCOUNT_ID = "3" + + +@pytest.fixture +def aws_instance(): + return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID) + + +@pytest.fixture +def aws_service(aws_instance): + return AWSService(aws_instance) + + +def test_instance_id(aws_service): + assert aws_service.island_aws_instance.instance_id == INSTANCE_ID + + +def test_region(aws_service): + assert aws_service.island_aws_instance.region == REGION + + +def test_account_id(aws_service): + assert aws_service.island_aws_instance.account_id == ACCOUNT_ID + + +class MockAWSService(AWSService): + def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]): + super().__init__(aws_instance) + self._instance_info_response = instance_info_response + + def _get_raw_managed_instances(self): + return self._instance_info_response + + +def test_get_managed_instances__empty(aws_instance): + aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE) + instances = aws_service.get_managed_instances() + assert len(instances) == 0 + + +def test_get_managed_instances(aws_instance): + aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE) + instances = aws_service.get_managed_instances() + + assert len(instances) == 2 + assert instances[0] == EXPECTED_INSTANCE_1 + assert instances[1] == EXPECTED_INSTANCE_2