forked from p15670423/monkey
Merge pull request #1935 from guardicore/1928-stateful-aws-service
1928 stateful aws service
This commit is contained in:
commit
6b4e991fdc
|
@ -3,3 +3,5 @@ from .directory_file_storage_service import DirectoryFileStorageService
|
||||||
|
|
||||||
from .authentication.authentication_service import AuthenticationService
|
from .authentication.authentication_service import AuthenticationService
|
||||||
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
||||||
|
|
||||||
|
from .aws_service import AWSService
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Any, Iterable, Mapping, Sequence
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import botocore
|
import botocore
|
||||||
|
@ -15,43 +15,39 @@ IP_ADDRESS_KEY = "IPAddress"
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def filter_instance_data_from_aws_response(response):
|
class AWSService:
|
||||||
return [
|
def __init__(self, aws_instance: AWSInstance):
|
||||||
{
|
"""
|
||||||
"instance_id": x[INSTANCE_ID_KEY],
|
:param aws_instance: An AWSInstance object representing the AWS instance that the Island is
|
||||||
"name": x[COMPUTER_NAME_KEY],
|
running on
|
||||||
"os": x[PLATFORM_TYPE_KEY].lower(),
|
"""
|
||||||
"ip_address": x[IP_ADDRESS_KEY],
|
self._aws_instance = aws_instance
|
||||||
}
|
|
||||||
for x in response[INSTANCE_INFORMATION_LIST_KEY]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
def island_is_running_on_aws(self) -> bool:
|
||||||
|
"""
|
||||||
|
:return: True if the island is running on an AWS instance. False otherwise.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return self._aws_instance.is_instance
|
||||||
|
|
||||||
aws_instance: Optional[AWSInstance] = None
|
@property
|
||||||
|
def island_aws_instance(self) -> AWSInstance:
|
||||||
|
"""
|
||||||
|
:return: an AWSInstance object representing the AWS instance that the Island is running on.
|
||||||
|
:rtype: AWSInstance
|
||||||
|
"""
|
||||||
|
return self._aws_instance
|
||||||
|
|
||||||
|
def get_managed_instances(self) -> Sequence[Mapping[str, str]]:
|
||||||
|
"""
|
||||||
|
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
|
||||||
|
is accessible from the Island.
|
||||||
|
:rtype: Sequence[Mapping[str, str]]
|
||||||
|
"""
|
||||||
|
raw_managed_instances_info = self._get_raw_managed_instances()
|
||||||
|
return _filter_relevant_instance_info(raw_managed_instances_info)
|
||||||
|
|
||||||
def initialize():
|
def _get_raw_managed_instances(self) -> Sequence[Mapping[str, Any]]:
|
||||||
global aws_instance
|
|
||||||
aws_instance = AWSInstance()
|
|
||||||
|
|
||||||
|
|
||||||
def is_on_aws():
|
|
||||||
return aws_instance.is_instance
|
|
||||||
|
|
||||||
|
|
||||||
def get_region():
|
|
||||||
return aws_instance.region
|
|
||||||
|
|
||||||
|
|
||||||
def get_account_id():
|
|
||||||
return aws_instance.account_id
|
|
||||||
|
|
||||||
|
|
||||||
def get_client(client_type):
|
|
||||||
return boto3.client(client_type, region_name=aws_instance.region)
|
|
||||||
|
|
||||||
|
|
||||||
def get_instances():
|
|
||||||
"""
|
"""
|
||||||
Get the information for all instances with the relevant roles.
|
Get the information for all instances with the relevant roles.
|
||||||
|
|
||||||
|
@ -62,12 +58,40 @@ def get_instances():
|
||||||
:raises: botocore.exceptions.ClientError if can't describe local instance information.
|
:raises: botocore.exceptions.ClientError if can't describe local instance information.
|
||||||
:return: All visible instances from this instance
|
:return: All visible instances from this instance
|
||||||
"""
|
"""
|
||||||
local_ssm_client = boto3.client("ssm", aws_instance.region)
|
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region)
|
||||||
try:
|
try:
|
||||||
response = local_ssm_client.describe_instance_information()
|
response = local_ssm_client.describe_instance_information()
|
||||||
|
return response[INSTANCE_INFORMATION_LIST_KEY]
|
||||||
|
except botocore.exceptions.ClientError as err:
|
||||||
|
logger.warning("AWS client error while trying to get manage dinstances: {err}")
|
||||||
|
raise err
|
||||||
|
|
||||||
filtered_instances_data = filter_instance_data_from_aws_response(response)
|
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
|
||||||
return filtered_instances_data
|
for id_ in instance_ids:
|
||||||
except botocore.exceptions.ClientError as e:
|
self._run_agent_on_managed_instance(id_)
|
||||||
logger.warning("AWS client error while trying to get instances: " + e)
|
|
||||||
raise e
|
def _run_agent_on_managed_instance(self, instance_id: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):
|
||||||
|
"""
|
||||||
|
Consume raw instance data from the AWS API and return only those fields that are relevant for
|
||||||
|
Infection Monkey.
|
||||||
|
|
||||||
|
:param raw_managed_instances_info: The output of
|
||||||
|
DescribeInstanceInformation["InstanceInformation"] from the
|
||||||
|
AWS API
|
||||||
|
:return: A sequence of mappings, where each Mapping represents a managed AWS instance that
|
||||||
|
is accessible from the Island.
|
||||||
|
:rtype: Sequence[Mapping[str, str]]
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"instance_id": managed_instance[INSTANCE_ID_KEY],
|
||||||
|
"name": managed_instance[COMPUTER_NAME_KEY],
|
||||||
|
"os": managed_instance[PLATFORM_TYPE_KEY].lower(),
|
||||||
|
"ip_address": managed_instance[IP_ADDRESS_KEY],
|
||||||
|
}
|
||||||
|
for managed_instance in raw_managed_instances_info
|
||||||
|
]
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
from common import DIContainer
|
from common import DIContainer
|
||||||
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService, aws_service
|
from common.aws import AWSInstance
|
||||||
|
from monkey_island.cc.services import AWSService, DirectoryFileStorageService, IFileStorageService
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||||
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||||
|
|
||||||
|
@ -11,12 +11,12 @@ from . import AuthenticationService, JsonFileUserDatastore
|
||||||
|
|
||||||
def initialize_services(data_dir: Path) -> DIContainer:
|
def initialize_services(data_dir: Path) -> DIContainer:
|
||||||
container = DIContainer()
|
container = DIContainer()
|
||||||
|
container.register_instance(AWSInstance, AWSInstance())
|
||||||
|
|
||||||
container.register_instance(
|
container.register_instance(
|
||||||
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
|
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
|
||||||
)
|
)
|
||||||
|
container.register_instance(AWSService, container.resolve(AWSService))
|
||||||
# Takes a while so it's best to start it in the background
|
|
||||||
Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start()
|
|
||||||
|
|
||||||
# This is temporary until we get DI all worked out.
|
# This is temporary until we get DI all worked out.
|
||||||
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
|
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
|
||||||
|
|
|
@ -1,28 +1,38 @@
|
||||||
import json
|
import threading
|
||||||
from unittest import TestCase
|
from typing import Any, Dict, Optional, Sequence
|
||||||
|
|
||||||
from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response
|
import pytest
|
||||||
|
|
||||||
|
from common.aws import AWSInstance
|
||||||
|
from monkey_island.cc.services import AWSService
|
||||||
|
|
||||||
class TestAwsService(TestCase):
|
EXPECTED_INSTANCE_1 = {
|
||||||
def test_filter_instance_data_from_aws_response(self):
|
"instance_id": "1",
|
||||||
json_response_full = """
|
"name": "comp1",
|
||||||
{
|
"os": "linux",
|
||||||
"InstanceInformationList": [
|
"ip_address": "192.168.1.1",
|
||||||
|
}
|
||||||
|
EXPECTED_INSTANCE_2 = {
|
||||||
|
"instance_id": "2",
|
||||||
|
"name": "comp2",
|
||||||
|
"os": "linux",
|
||||||
|
"ip_address": "192.168.1.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
EMPTY_INSTANCE_INFO_RESPONSE = []
|
||||||
|
FULL_INSTANCE_INFO_RESPONSE = [
|
||||||
{
|
{
|
||||||
"ActivationId": "string",
|
"ActivationId": "string",
|
||||||
"AgentVersion": "string",
|
"AgentVersion": "string",
|
||||||
"AssociationOverview": {
|
"AssociationOverview": {
|
||||||
"DetailedStatus": "string",
|
"DetailedStatus": "string",
|
||||||
"InstanceAssociationStatusAggregatedCount": {
|
"InstanceAssociationStatusAggregatedCount": {"string": 6},
|
||||||
"string" : 6
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"AssociationStatus": "string",
|
"AssociationStatus": "string",
|
||||||
"ComputerName": "string",
|
"ComputerName": EXPECTED_INSTANCE_1["name"],
|
||||||
"IamRole": "string",
|
"IamRole": "string",
|
||||||
"InstanceId": "string",
|
"InstanceId": EXPECTED_INSTANCE_1["instance_id"],
|
||||||
"IPAddress": "string",
|
"IPAddress": EXPECTED_INSTANCE_1["ip_address"],
|
||||||
"IsLatestVersion": "True",
|
"IsLatestVersion": "True",
|
||||||
"LastAssociationExecutionDate": 6,
|
"LastAssociationExecutionDate": 6,
|
||||||
"LastPingDateTime": 6,
|
"LastPingDateTime": 6,
|
||||||
|
@ -30,27 +40,111 @@ class TestAwsService(TestCase):
|
||||||
"Name": "string",
|
"Name": "string",
|
||||||
"PingStatus": "string",
|
"PingStatus": "string",
|
||||||
"PlatformName": "string",
|
"PlatformName": "string",
|
||||||
"PlatformType": "string",
|
"PlatformType": EXPECTED_INSTANCE_1["os"],
|
||||||
"PlatformVersion": "string",
|
"PlatformVersion": "string",
|
||||||
"RegistrationDate": 6,
|
"RegistrationDate": 6,
|
||||||
"ResourceType": "string"
|
"ResourceType": "string",
|
||||||
}
|
},
|
||||||
],
|
|
||||||
"NextToken": "string"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
json_response_empty = """
|
|
||||||
{
|
{
|
||||||
"InstanceInformationList": [],
|
"ActivationId": "string",
|
||||||
"NextToken": "string"
|
"AgentVersion": "string",
|
||||||
}
|
"AssociationOverview": {
|
||||||
"""
|
"DetailedStatus": "string",
|
||||||
|
"InstanceAssociationStatusAggregatedCount": {"string": 6},
|
||||||
|
},
|
||||||
|
"AssociationStatus": "string",
|
||||||
|
"ComputerName": EXPECTED_INSTANCE_2["name"],
|
||||||
|
"IamRole": "string",
|
||||||
|
"InstanceId": EXPECTED_INSTANCE_2["instance_id"],
|
||||||
|
"IPAddress": EXPECTED_INSTANCE_2["ip_address"],
|
||||||
|
"IsLatestVersion": "True",
|
||||||
|
"LastAssociationExecutionDate": 6,
|
||||||
|
"LastPingDateTime": 6,
|
||||||
|
"LastSuccessfulAssociationExecutionDate": 6,
|
||||||
|
"Name": "string",
|
||||||
|
"PingStatus": "string",
|
||||||
|
"PlatformName": "string",
|
||||||
|
"PlatformType": EXPECTED_INSTANCE_2["os"],
|
||||||
|
"PlatformVersion": "string",
|
||||||
|
"RegistrationDate": 6,
|
||||||
|
"ResourceType": "string",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
filter_instance_data_from_aws_response(json.loads(json_response_empty)), []
|
class StubAWSInstance(AWSInstance):
|
||||||
)
|
def __init__(
|
||||||
self.assertEqual(
|
self,
|
||||||
filter_instance_data_from_aws_response(json.loads(json_response_full)),
|
instance_id: Optional[str] = None,
|
||||||
[{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}],
|
region: Optional[str] = None,
|
||||||
)
|
account_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._instance_id = instance_id
|
||||||
|
self._region = region
|
||||||
|
self._account_id = account_id
|
||||||
|
|
||||||
|
self._initialization_complete = threading.Event()
|
||||||
|
self._initialization_complete.set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_aws_is_on_aws__true():
|
||||||
|
aws_instance = StubAWSInstance("1")
|
||||||
|
aws_service = AWSService(aws_instance)
|
||||||
|
assert aws_service.island_is_running_on_aws() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_aws_is_on_aws__False():
|
||||||
|
aws_instance = StubAWSInstance()
|
||||||
|
aws_service = AWSService(aws_instance)
|
||||||
|
assert aws_service.island_is_running_on_aws() is False
|
||||||
|
|
||||||
|
|
||||||
|
INSTANCE_ID = "1"
|
||||||
|
REGION = "2"
|
||||||
|
ACCOUNT_ID = "3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def aws_instance():
|
||||||
|
return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def aws_service(aws_instance):
|
||||||
|
return AWSService(aws_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_instance_id(aws_service):
|
||||||
|
assert aws_service.island_aws_instance.instance_id == INSTANCE_ID
|
||||||
|
|
||||||
|
|
||||||
|
def test_region(aws_service):
|
||||||
|
assert aws_service.island_aws_instance.region == REGION
|
||||||
|
|
||||||
|
|
||||||
|
def test_account_id(aws_service):
|
||||||
|
assert aws_service.island_aws_instance.account_id == ACCOUNT_ID
|
||||||
|
|
||||||
|
|
||||||
|
class MockAWSService(AWSService):
|
||||||
|
def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]):
|
||||||
|
super().__init__(aws_instance)
|
||||||
|
self._instance_info_response = instance_info_response
|
||||||
|
|
||||||
|
def _get_raw_managed_instances(self):
|
||||||
|
return self._instance_info_response
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_managed_instances__empty(aws_instance):
|
||||||
|
aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE)
|
||||||
|
instances = aws_service.get_managed_instances()
|
||||||
|
assert len(instances) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_managed_instances(aws_instance):
|
||||||
|
aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE)
|
||||||
|
instances = aws_service.get_managed_instances()
|
||||||
|
|
||||||
|
assert len(instances) == 2
|
||||||
|
assert instances[0] == EXPECTED_INSTANCE_1
|
||||||
|
assert instances[1] == EXPECTED_INSTANCE_2
|
||||||
|
|
Loading…
Reference in New Issue