Merge pull request #1935 from guardicore/1928-stateful-aws-service
1928 stateful aws service
This commit is contained in:
commit
6b4e991fdc
|
@ -3,3 +3,5 @@ from .directory_file_storage_service import DirectoryFileStorageService
|
|||
|
||||
from .authentication.authentication_service import AuthenticationService
|
||||
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
||||
|
||||
from .aws_service import AWSService
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Iterable, Mapping, Sequence
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
|
@ -15,59 +15,83 @@ IP_ADDRESS_KEY = "IPAddress"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_instance_data_from_aws_response(response):
|
||||
class AWSService:
|
||||
def __init__(self, aws_instance: AWSInstance):
|
||||
"""
|
||||
:param aws_instance: An AWSInstance object representing the AWS instance that the Island is
|
||||
running on
|
||||
"""
|
||||
self._aws_instance = aws_instance
|
||||
|
||||
def island_is_running_on_aws(self) -> bool:
|
||||
"""
|
||||
:return: True if the island is running on an AWS instance. False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._aws_instance.is_instance
|
||||
|
||||
@property
|
||||
def island_aws_instance(self) -> AWSInstance:
|
||||
"""
|
||||
:return: an AWSInstance object representing the AWS instance that the Island is running on.
|
||||
:rtype: AWSInstance
|
||||
"""
|
||||
return self._aws_instance
|
||||
|
||||
def get_managed_instances(self) -> Sequence[Mapping[str, str]]:
|
||||
"""
|
||||
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
|
||||
is accessible from the Island.
|
||||
:rtype: Sequence[Mapping[str, str]]
|
||||
"""
|
||||
raw_managed_instances_info = self._get_raw_managed_instances()
|
||||
return _filter_relevant_instance_info(raw_managed_instances_info)
|
||||
|
||||
def _get_raw_managed_instances(self) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Get the information for all instances with the relevant roles.
|
||||
|
||||
This function will assume that it's running on an EC2 instance with the correct IAM role.
|
||||
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
|
||||
-role for details.
|
||||
|
||||
:raises: botocore.exceptions.ClientError if can't describe local instance information.
|
||||
:return: All visible instances from this instance
|
||||
"""
|
||||
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||
try:
|
||||
response = local_ssm_client.describe_instance_information()
|
||||
return response[INSTANCE_INFORMATION_LIST_KEY]
|
||||
except botocore.exceptions.ClientError as err:
|
||||
logger.warning("AWS client error while trying to get manage dinstances: {err}")
|
||||
raise err
|
||||
|
||||
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
|
||||
for id_ in instance_ids:
|
||||
self._run_agent_on_managed_instance(id_)
|
||||
|
||||
def _run_agent_on_managed_instance(self, instance_id: str):
|
||||
pass
|
||||
|
||||
|
||||
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):
|
||||
"""
|
||||
Consume raw instance data from the AWS API and return only those fields that are relevant for
|
||||
Infection Monkey.
|
||||
|
||||
:param raw_managed_instances_info: The output of
|
||||
DescribeInstanceInformation["InstanceInformation"] from the
|
||||
AWS API
|
||||
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
|
||||
is accessible from the Island.
|
||||
:rtype: Sequence[Mapping[str, str]]
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"instance_id": x[INSTANCE_ID_KEY],
|
||||
"name": x[COMPUTER_NAME_KEY],
|
||||
"os": x[PLATFORM_TYPE_KEY].lower(),
|
||||
"ip_address": x[IP_ADDRESS_KEY],
|
||||
"instance_id": managed_instance[INSTANCE_ID_KEY],
|
||||
"name": managed_instance[COMPUTER_NAME_KEY],
|
||||
"os": managed_instance[PLATFORM_TYPE_KEY].lower(),
|
||||
"ip_address": managed_instance[IP_ADDRESS_KEY],
|
||||
}
|
||||
for x in response[INSTANCE_INFORMATION_LIST_KEY]
|
||||
for managed_instance in raw_managed_instances_info
|
||||
]
|
||||
|
||||
|
||||
aws_instance: Optional[AWSInstance] = None
|
||||
|
||||
|
||||
def initialize():
|
||||
global aws_instance
|
||||
aws_instance = AWSInstance()
|
||||
|
||||
|
||||
def is_on_aws():
|
||||
return aws_instance.is_instance
|
||||
|
||||
|
||||
def get_region():
|
||||
return aws_instance.region
|
||||
|
||||
|
||||
def get_account_id():
|
||||
return aws_instance.account_id
|
||||
|
||||
|
||||
def get_client(client_type):
|
||||
return boto3.client(client_type, region_name=aws_instance.region)
|
||||
|
||||
|
||||
def get_instances():
|
||||
"""
|
||||
Get the information for all instances with the relevant roles.
|
||||
|
||||
This function will assume that it's running on an EC2 instance with the correct IAM role.
|
||||
See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam
|
||||
-role for details.
|
||||
|
||||
:raises: botocore.exceptions.ClientError if can't describe local instance information.
|
||||
:return: All visible instances from this instance
|
||||
"""
|
||||
local_ssm_client = boto3.client("ssm", aws_instance.region)
|
||||
try:
|
||||
response = local_ssm_client.describe_instance_information()
|
||||
|
||||
filtered_instances_data = filter_instance_data_from_aws_response(response)
|
||||
return filtered_instances_data
|
||||
except botocore.exceptions.ClientError as e:
|
||||
logger.warning("AWS client error while trying to get instances: " + e)
|
||||
raise e
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService, aws_service
|
||||
from common.aws import AWSInstance
|
||||
from monkey_island.cc.services import AWSService, DirectoryFileStorageService, IFileStorageService
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||
|
||||
|
@ -11,12 +11,12 @@ from . import AuthenticationService, JsonFileUserDatastore
|
|||
|
||||
def initialize_services(data_dir: Path) -> DIContainer:
|
||||
container = DIContainer()
|
||||
container.register_instance(AWSInstance, AWSInstance())
|
||||
|
||||
container.register_instance(
|
||||
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
|
||||
)
|
||||
|
||||
# Takes a while so it's best to start it in the background
|
||||
Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start()
|
||||
container.register_instance(AWSService, container.resolve(AWSService))
|
||||
|
||||
# This is temporary until we get DI all worked out.
|
||||
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
|
||||
|
|
|
@ -1,56 +1,150 @@
|
|||
import json
|
||||
from unittest import TestCase
|
||||
import threading
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response
|
||||
import pytest
|
||||
|
||||
from common.aws import AWSInstance
|
||||
from monkey_island.cc.services import AWSService
|
||||
|
||||
EXPECTED_INSTANCE_1 = {
|
||||
"instance_id": "1",
|
||||
"name": "comp1",
|
||||
"os": "linux",
|
||||
"ip_address": "192.168.1.1",
|
||||
}
|
||||
EXPECTED_INSTANCE_2 = {
|
||||
"instance_id": "2",
|
||||
"name": "comp2",
|
||||
"os": "linux",
|
||||
"ip_address": "192.168.1.2",
|
||||
}
|
||||
|
||||
EMPTY_INSTANCE_INFO_RESPONSE = []
|
||||
FULL_INSTANCE_INFO_RESPONSE = [
|
||||
{
|
||||
"ActivationId": "string",
|
||||
"AgentVersion": "string",
|
||||
"AssociationOverview": {
|
||||
"DetailedStatus": "string",
|
||||
"InstanceAssociationStatusAggregatedCount": {"string": 6},
|
||||
},
|
||||
"AssociationStatus": "string",
|
||||
"ComputerName": EXPECTED_INSTANCE_1["name"],
|
||||
"IamRole": "string",
|
||||
"InstanceId": EXPECTED_INSTANCE_1["instance_id"],
|
||||
"IPAddress": EXPECTED_INSTANCE_1["ip_address"],
|
||||
"IsLatestVersion": "True",
|
||||
"LastAssociationExecutionDate": 6,
|
||||
"LastPingDateTime": 6,
|
||||
"LastSuccessfulAssociationExecutionDate": 6,
|
||||
"Name": "string",
|
||||
"PingStatus": "string",
|
||||
"PlatformName": "string",
|
||||
"PlatformType": EXPECTED_INSTANCE_1["os"],
|
||||
"PlatformVersion": "string",
|
||||
"RegistrationDate": 6,
|
||||
"ResourceType": "string",
|
||||
},
|
||||
{
|
||||
"ActivationId": "string",
|
||||
"AgentVersion": "string",
|
||||
"AssociationOverview": {
|
||||
"DetailedStatus": "string",
|
||||
"InstanceAssociationStatusAggregatedCount": {"string": 6},
|
||||
},
|
||||
"AssociationStatus": "string",
|
||||
"ComputerName": EXPECTED_INSTANCE_2["name"],
|
||||
"IamRole": "string",
|
||||
"InstanceId": EXPECTED_INSTANCE_2["instance_id"],
|
||||
"IPAddress": EXPECTED_INSTANCE_2["ip_address"],
|
||||
"IsLatestVersion": "True",
|
||||
"LastAssociationExecutionDate": 6,
|
||||
"LastPingDateTime": 6,
|
||||
"LastSuccessfulAssociationExecutionDate": 6,
|
||||
"Name": "string",
|
||||
"PingStatus": "string",
|
||||
"PlatformName": "string",
|
||||
"PlatformType": EXPECTED_INSTANCE_2["os"],
|
||||
"PlatformVersion": "string",
|
||||
"RegistrationDate": 6,
|
||||
"ResourceType": "string",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TestAwsService(TestCase):
|
||||
def test_filter_instance_data_from_aws_response(self):
|
||||
json_response_full = """
|
||||
{
|
||||
"InstanceInformationList": [
|
||||
{
|
||||
"ActivationId": "string",
|
||||
"AgentVersion": "string",
|
||||
"AssociationOverview": {
|
||||
"DetailedStatus": "string",
|
||||
"InstanceAssociationStatusAggregatedCount": {
|
||||
"string" : 6
|
||||
}
|
||||
},
|
||||
"AssociationStatus": "string",
|
||||
"ComputerName": "string",
|
||||
"IamRole": "string",
|
||||
"InstanceId": "string",
|
||||
"IPAddress": "string",
|
||||
"IsLatestVersion": "True",
|
||||
"LastAssociationExecutionDate": 6,
|
||||
"LastPingDateTime": 6,
|
||||
"LastSuccessfulAssociationExecutionDate": 6,
|
||||
"Name": "string",
|
||||
"PingStatus": "string",
|
||||
"PlatformName": "string",
|
||||
"PlatformType": "string",
|
||||
"PlatformVersion": "string",
|
||||
"RegistrationDate": 6,
|
||||
"ResourceType": "string"
|
||||
}
|
||||
],
|
||||
"NextToken": "string"
|
||||
}
|
||||
"""
|
||||
class StubAWSInstance(AWSInstance):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
):
|
||||
self._instance_id = instance_id
|
||||
self._region = region
|
||||
self._account_id = account_id
|
||||
|
||||
json_response_empty = """
|
||||
{
|
||||
"InstanceInformationList": [],
|
||||
"NextToken": "string"
|
||||
}
|
||||
"""
|
||||
self._initialization_complete = threading.Event()
|
||||
self._initialization_complete.set()
|
||||
|
||||
self.assertEqual(
|
||||
filter_instance_data_from_aws_response(json.loads(json_response_empty)), []
|
||||
)
|
||||
self.assertEqual(
|
||||
filter_instance_data_from_aws_response(json.loads(json_response_full)),
|
||||
[{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}],
|
||||
)
|
||||
|
||||
def test_aws_is_on_aws__true():
|
||||
aws_instance = StubAWSInstance("1")
|
||||
aws_service = AWSService(aws_instance)
|
||||
assert aws_service.island_is_running_on_aws() is True
|
||||
|
||||
|
||||
def test_aws_is_on_aws__False():
|
||||
aws_instance = StubAWSInstance()
|
||||
aws_service = AWSService(aws_instance)
|
||||
assert aws_service.island_is_running_on_aws() is False
|
||||
|
||||
|
||||
INSTANCE_ID = "1"
|
||||
REGION = "2"
|
||||
ACCOUNT_ID = "3"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aws_instance():
|
||||
return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aws_service(aws_instance):
|
||||
return AWSService(aws_instance)
|
||||
|
||||
|
||||
def test_instance_id(aws_service):
|
||||
assert aws_service.island_aws_instance.instance_id == INSTANCE_ID
|
||||
|
||||
|
||||
def test_region(aws_service):
|
||||
assert aws_service.island_aws_instance.region == REGION
|
||||
|
||||
|
||||
def test_account_id(aws_service):
|
||||
assert aws_service.island_aws_instance.account_id == ACCOUNT_ID
|
||||
|
||||
|
||||
class MockAWSService(AWSService):
|
||||
def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]):
|
||||
super().__init__(aws_instance)
|
||||
self._instance_info_response = instance_info_response
|
||||
|
||||
def _get_raw_managed_instances(self):
|
||||
return self._instance_info_response
|
||||
|
||||
|
||||
def test_get_managed_instances__empty(aws_instance):
|
||||
aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE)
|
||||
instances = aws_service.get_managed_instances()
|
||||
assert len(instances) == 0
|
||||
|
||||
|
||||
def test_get_managed_instances(aws_instance):
|
||||
aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE)
|
||||
instances = aws_service.get_managed_instances()
|
||||
|
||||
assert len(instances) == 2
|
||||
assert instances[0] == EXPECTED_INSTANCE_1
|
||||
assert instances[1] == EXPECTED_INSTANCE_2
|
||||
|
|
Loading…
Reference in New Issue