Island: Add get_managed_instances() to AWSService

This commit is contained in:
Mike Salvatore 2022-05-09 10:09:17 -04:00
parent acabc835d4
commit 8995eb5d2f
2 changed files with 124 additions and 109 deletions

View File

@ -1,5 +1,5 @@
import logging
from typing import Iterable, Optional
from typing import Any, Dict, Iterable, Sequence
import boto3
import botocore
@ -26,51 +26,11 @@ class AWSService:
def island_aws_instance(self) -> AWSInstance:
return self._aws_instance
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
for id_ in instance_ids:
self._run_agent_on_managed_instance(id_)
def get_managed_instances(self) -> Sequence[Dict[str, str]]:
raw_managed_instances_info = self._get_raw_managed_instances()
return _filter_instance_info_from_aws_response(raw_managed_instances_info)
def _run_agent_on_managed_instance(self, instance_id: str):
pass
def filter_instance_data_from_aws_response(response):
return [
{
"instance_id": x[INSTANCE_ID_KEY],
"name": x[COMPUTER_NAME_KEY],
"os": x[PLATFORM_TYPE_KEY].lower(),
"ip_address": x[IP_ADDRESS_KEY],
}
for x in response[INSTANCE_INFORMATION_LIST_KEY]
]
aws_instance: Optional[AWSInstance] = None
def initialize():
global aws_instance
aws_instance = AWSInstance()
def is_on_aws():
return aws_instance.is_instance
def get_region():
return aws_instance.region
def get_account_id():
return aws_instance.account_id
def get_client(client_type):
return boto3.client(client_type, region_name=aws_instance.region)
def get_instances():
def _get_raw_managed_instances(self) -> Sequence[Dict[str, Any]]:
"""
Get the information for all instances with the relevant roles.
@ -81,12 +41,29 @@ def get_instances():
:raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance
"""
local_ssm_client = boto3.client("ssm", aws_instance.region)
local_ssm_client = boto3.client("ssm", self.island_aws_instance.region)
try:
response = local_ssm_client.describe_instance_information()
return response[INSTANCE_INFORMATION_LIST_KEY]
except botocore.exceptions.ClientError as err:
logger.warning("AWS client error while trying to get manage dinstances: {err}")
raise err
filtered_instances_data = filter_instance_data_from_aws_response(response)
return filtered_instances_data
except botocore.exceptions.ClientError as e:
logger.warning("AWS client error while trying to get instances: " + e)
raise e
def run_agent_on_managed_instances(self, instance_ids: Iterable[str]):
for id_ in instance_ids:
self._run_agent_on_managed_instance(id_)
def _run_agent_on_managed_instance(self, instance_id: str):
pass
def _filter_instance_info_from_aws_response(raw_managed_instances_info: Sequence[Dict[str, Any]]):
return [
{
"instance_id": managed_instance[INSTANCE_ID_KEY],
"name": managed_instance[COMPUTER_NAME_KEY],
"os": managed_instance[PLATFORM_TYPE_KEY].lower(),
"ip_address": managed_instance[IP_ADDRESS_KEY],
}
for managed_instance in raw_managed_instances_info
]

View File

@ -1,34 +1,38 @@
import json
import threading
from typing import Optional
from unittest import TestCase
from typing import Any, Dict, Optional, Sequence
import pytest
from common.aws import AWSInstance
from monkey_island.cc.services import AWSService
from monkey_island.cc.services.aws_service import filter_instance_data_from_aws_response
EXPECTED_INSTANCE_1 = {
"instance_id": "1",
"name": "comp1",
"os": "linux",
"ip_address": "192.168.1.1",
}
EXPECTED_INSTANCE_2 = {
"instance_id": "2",
"name": "comp2",
"os": "linux",
"ip_address": "192.168.1.2",
}
class TestAwsService(TestCase):
def test_filter_instance_data_from_aws_response(self):
json_response_full = """
{
"InstanceInformationList": [
EMPTY_INSTANCE_INFO_RESPONSE = []
FULL_INSTANCE_INFO_RESPONSE = [
{
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {
"string" : 6
}
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": "string",
"ComputerName": EXPECTED_INSTANCE_1["name"],
"IamRole": "string",
"InstanceId": "string",
"IPAddress": "string",
"InstanceId": EXPECTED_INSTANCE_1["instance_id"],
"IPAddress": EXPECTED_INSTANCE_1["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
@ -36,30 +40,36 @@ class TestAwsService(TestCase):
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": "string",
"PlatformType": EXPECTED_INSTANCE_1["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string"
}
],
"NextToken": "string"
}
"""
json_response_empty = """
"ResourceType": "string",
},
{
"InstanceInformationList": [],
"NextToken": "string"
}
"""
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_empty)), []
)
self.assertEqual(
filter_instance_data_from_aws_response(json.loads(json_response_full)),
[{"instance_id": "string", "ip_address": "string", "name": "string", "os": "string"}],
)
"ActivationId": "string",
"AgentVersion": "string",
"AssociationOverview": {
"DetailedStatus": "string",
"InstanceAssociationStatusAggregatedCount": {"string": 6},
},
"AssociationStatus": "string",
"ComputerName": EXPECTED_INSTANCE_2["name"],
"IamRole": "string",
"InstanceId": EXPECTED_INSTANCE_2["instance_id"],
"IPAddress": EXPECTED_INSTANCE_2["ip_address"],
"IsLatestVersion": "True",
"LastAssociationExecutionDate": 6,
"LastPingDateTime": 6,
"LastSuccessfulAssociationExecutionDate": 6,
"Name": "string",
"PingStatus": "string",
"PlatformName": "string",
"PlatformType": EXPECTED_INSTANCE_2["os"],
"PlatformVersion": "string",
"RegistrationDate": 6,
"ResourceType": "string",
},
]
class StubAWSInstance(AWSInstance):
@ -95,8 +105,12 @@ ACCOUNT_ID = "3"
@pytest.fixture
def aws_service():
aws_instance = StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
def aws_instance():
return StubAWSInstance(INSTANCE_ID, REGION, ACCOUNT_ID)
@pytest.fixture
def aws_service(aws_instance):
return AWSService(aws_instance)
@ -110,3 +124,27 @@ def test_region(aws_service):
def test_account_id(aws_service):
assert aws_service.island_aws_instance.account_id == ACCOUNT_ID
class MockAWSService(AWSService):
def __init__(self, aws_instance: AWSInstance, instance_info_response: Sequence[Dict[str, Any]]):
super().__init__(aws_instance)
self._instance_info_response = instance_info_response
def _get_raw_managed_instances(self):
return self._instance_info_response
def test_get_managed_instances__empty(aws_instance):
aws_service = MockAWSService(aws_instance, EMPTY_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 0
def test_get_managed_instances(aws_instance):
aws_service = MockAWSService(aws_instance, FULL_INSTANCE_INFO_RESPONSE)
instances = aws_service.get_managed_instances()
assert len(instances) == 2
assert instances[0] == EXPECTED_INSTANCE_1
assert instances[1] == EXPECTED_INSTANCE_2