Merge pull request #1935 from guardicore/1928-stateful-aws-service

1928 stateful aws service
This commit is contained in:
Mike Salvatore 2022-05-10 07:57:29 -04:00 committed by GitHub
commit 6b4e991fdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 229 additions and 109 deletions

View File

@ -3,3 +3,5 @@ from .directory_file_storage_service import DirectoryFileStorageService
from .authentication.authentication_service import AuthenticationService
from .authentication.json_file_user_datastore import JsonFileUserDatastore
from .aws_service import AWSService

View File

@ -1,5 +1,5 @@
import logging
from typing import Optional
from typing import Any, Iterable, Mapping, Sequence
import boto3
import botocore
@ -15,59 +15,83 @@ IP_ADDRESS_KEY = "IPAddress"
logger = logging.getLogger(__name__)
def filter_instance_data_from_aws_response(response):
class AWSService:
def __init__(self, aws_instance: AWSInstance):
"""
:param aws_instance: An AWSInstance object representing the AWS instance that the Island is
running on
"""
self._aws_instance = aws_instance
def island_is_running_on_aws(self) -> bool:
"""
:return: True if the island is running on an AWS instance. False otherwise.
:rtype: bool
"""
return self._aws_instance.is_instance
@property
def island_aws_instance(self) -> AWSInstance:
"""
:return: an AWSInstance object representing the AWS instance that the Island is running on.
:rtype: AWSInstance
"""
return self._aws_instance
def get_managed_instances(self) -> Sequence[Mapping[str, str]]:
"""
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
is accessible from the Island.
:rtype: Sequence[Mapping[str, str]]
"""
raw_managed_instances_info = self._get_raw_managed_instances()
return _filter_relevant_instance_info(raw_managed_instances_info)
def _get_raw_managed_instances(self) -> Sequence[Mapping[str, Any]]:
"""
Get the information for all instances with the relevant roles.
This function will assume that it's running on an EC2 instance with the correct IAM role.
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
-role for details.
:raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance
"""
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region)
try:
response = local_ssm_client.describe_instance_information()
return response[INSTANCE_INFORMATION_LIST_KEY]
except botocore.exceptions.ClientError as err:
logger.warning("AWS client error while trying to get manage dinstances: {err}")
raise err
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
for id_ in instance_ids:
self._run_agent_on_managed_instance(id_)
def _run_agent_on_managed_instance(self, instance_id: str):
pass
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):
"""
Consume raw instance data from the AWS API and return only those fields that are relevant for
Infection Monkey.
:param raw_managed_instances_info: The output of
DescribeInstanceInformation["InstanceInformation"] from the
AWS API
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
is accessible from the Island.
:rtype: Sequence[Mapping[str, str]]
"""
return [
{
"instance_id": x[INSTANCE_ID_KEY],
"name": x[COMPUTER_NAME_KEY],
"os": x[PLATFORM_TYPE_KEY].lower(),
"ip_address": x[IP_ADDRESS_KEY],
"instance_id": managed_instance[INSTANCE_ID_KEY],
"name": managed_instance[COMPUTER_NAME_KEY],
"os": managed_instance[PLATFORM_TYPE_KEY].lower(),
"ip_address": managed_instance[IP_ADDRESS_KEY],
}
for x in response[INSTANCE_INFORMATION_LIST_KEY]
for managed_instance in raw_managed_instances_info
]
aws_instance: Optional[AWSInstance] = None
def initialize():
global aws_instance
aws_instance = AWSInstance()
def is_on_aws():
return aws_instance.is_instance
def get_region():
return aws_instance.region
def get_account_id():
return aws_instance.account_id
def get_client(client_type):
return boto3.client(client_type, region_name=aws_instance.region)
def get_instances():
"""
Get the information for all instances with the relevant roles.
This function will assume that it's running on an EC2 instance with the correct IAM role.
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
-role for details.
:raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance
"""
local_ssm_client = boto3.client("ssm", aws_instance.region)
try:
response = local_ssm_client.describe_instance_information()
filtered_instances_data = filter_instance_data_from_aws_response(response)
return filtered_instances_data
except botocore.exceptions.ClientError as e:
logger.warning("AWS client error while trying to get instances: " + e)
raise e

View File

@ -1,8 +1,8 @@
from pathlib import Path
from threading import Thread
from common import DIContainer
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService, aws_service
from common.aws import AWSInstance
from monkey_island.cc.services import AWSService, DirectoryFileStorageService, IFileStorageService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
@ -11,12 +11,12 @@ from . import AuthenticationService, JsonFileUserDatastore
def initialize_services(data_dir: Path) -> DIContainer:
container = DIContainer()
container.register_instance(AWSInstance, AWSInstance())
container.register_instance(
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
)
# Takes a while so it's best to start it in the background
Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start()
container.register_instance(AWSService, container.resolve(AWSService))
# This is temporary until we get DI all worked out.
PostBreachFilesService.initialize(container.resolve(IFileStorageService))

View File

@ -1,56 +1,150 @@
import json
from unittest import TestCase
import threading
from typing import Any, Dict, Optional, Sequence
from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response
import pytest
from common.aws import AWSInstance
from monkey_island.cc.services import AWSService
EXPECTED_INSTANCE_1 = {
"instance_id": "1",
"name": "comp1",
"os": "linux",
"ip_address": "192.168.1.1",
}
EXPECTED_INSTANCE_2 = {
"instance_id": "2",
"name": "comp2",
"os": "linux",
"ip_address": "192.168.1.2",
}
EMPTY_INSTANCE_INFO_RESPONSE = []
FULL_INSTANCE_INFO_RESPONSE = [
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": EXPECTED_INSTANCE_1["name"],
"IamRole": "string",
"InstanceId": EXPECTED_INSTANCE_1["instance_id"],
"IPAddress": EXPECTED_INSTANCE_1["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": EXPECTED_INSTANCE_1["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string",
},
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": EXPECTED_INSTANCE_2["name"],
"IamRole": "string",
"InstanceId": EXPECTED_INSTANCE_2["instance_id"],
"IPAddress": EXPECTED_INSTANCE_2["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": EXPECTED_INSTANCE_2["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string",
},
]
class TestAwsService(TestCase):
def test_filter_instance_data_from_aws_response(self):
json_response_full = """
{
"InstanceInformationList": [
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {
"string" : 6
}
},
"AssociationStatus": "string",
"ComputerName": "string",
"IamRole": "string",
"InstanceId": "string",
"IPAddress": "string",
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": "string",
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string"
}
],
"NextToken": "string"
}
"""
class StubAWSInstance(AWSInstance):
def __init__(
self,
instance_id: Optional[str] = None,
region: Optional[str] = None,
account_id: Optional[str] = None,
):
self._instance_id = instance_id
self._region = region
self._account_id = account_id
json_response_empty = """
{
"InstanceInformationList": [],
"NextToken": "string"
}
"""
self._initialization_complete = threading.Event()
self._initialization_complete.set()
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_empty)), []
)
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_full)),
[{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}],
)
def test_aws_is_on_aws__true():
aws_instance = StubAWSInstance("1")
aws_service = AWSService(aws_instance)
assert aws_service.island_is_running_on_aws() is True
def test_aws_is_on_aws__False():
aws_instance = StubAWSInstance()
aws_service = AWSService(aws_instance)
assert aws_service.island_is_running_on_aws() is False
INSTANCE_ID = "1"
REGION = "2"
ACCOUNT_ID = "3"
@pytest.fixture
def aws_instance():
return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
@pytest.fixture
def aws_service(aws_instance):
return AWSService(aws_instance)
def test_instance_id(aws_service):
assert aws_service.island_aws_instance.instance_id == INSTANCE_ID
def test_region(aws_service):
assert aws_service.island_aws_instance.region == REGION
def test_account_id(aws_service):
assert aws_service.island_aws_instance.account_id == ACCOUNT_ID
class MockAWSService(AWSService):
def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]):
super().__init__(aws_instance)
self._instance_info_response = instance_info_response
def _get_raw_managed_instances(self):
return self._instance_info_response
def test_get_managed_instances__empty(aws_instance):
aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 0
def test_get_managed_instances(aws_instance):
aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 2
assert instances[0] == EXPECTED_INSTANCE_1
assert instances[1] == EXPECTED_INSTANCE_2