diff --git a/monkey/common/aws/__init__.py b/monkey/common/aws/__init__.py index e69de29bb..1bf4dafaa 100644 --- a/monkey/common/aws/__init__.py +++ b/monkey/common/aws/__init__.py @@ -0,0 +1 @@ +from .aws_instance import AWSInstance diff --git a/monkey/common/aws/aws_instance.py b/monkey/common/aws/aws_instance.py index d99c87117..76c3cac9a 100644 --- a/monkey/common/aws/aws_instance.py +++ b/monkey/common/aws/aws_instance.py @@ -1,117 +1,52 @@ -import json -import logging -import re -from dataclasses import dataclass -from typing import Optional, Tuple +import threading -import requests +from .aws_metadata import fetch_aws_instance_metadata -AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" -AWS_LATEST_METADATA_URI_PREFIX = "http://{0}/latest/".format(AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS) -ACCOUNT_ID_KEY = "accountId" - -logger = logging.getLogger(__name__) - -AWS_TIMEOUT = 2 +AWS_FETCH_METADATA_TIMEOUT = 10.0 # Seconds -@dataclass -class AwsInstanceInfo: - instance_id: Optional[str] = None - region: Optional[str] = None - account_id: Optional[str] = None +class AWSTimeoutError(Exception): + """Raised when communications with AWS time out""" -class AwsInstance: +class AWSInstance: """ Class which gives useful information about the current instance you're on. 
""" def __init__(self): - self._is_instance, self._instance_info = AwsInstance._fetch_instance_info() + self._instance_id = None + self._region = None + self._account_id = None + self._initialization_complete = threading.Event() + + fetch_thread = threading.Thread(target=self._fetch_aws_instance_metadata, daemon=True) + fetch_thread.start() + + def _fetch_aws_instance_metadata(self): + (self._instance_id, self._region, self._account_id) = fetch_aws_instance_metadata() + self._initialization_complete.set() @property def is_instance(self) -> bool: - return self._is_instance + self._wait_for_initialization_to_complete() + return bool(self._instance_id) @property def instance_id(self) -> str: - return self._instance_info.instance_id + self._wait_for_initialization_to_complete() + return self._instance_id @property def region(self) -> str: - return self._instance_info.region + self._wait_for_initialization_to_complete() + return self._region @property def account_id(self) -> str: - return self._instance_info.account_id + self._wait_for_initialization_to_complete() + return self._account_id - @staticmethod - def _fetch_instance_info() -> Tuple[bool, AwsInstanceInfo]: - try: - response = requests.get( - AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id", - timeout=AWS_TIMEOUT, - ) - if not response: - return False, AwsInstanceInfo() - - info = AwsInstanceInfo() - info.instance_id = response.text if response else False - info.region = AwsInstance._parse_region( - requests.get( - AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone", - timeout=AWS_TIMEOUT, - ).text - ) - except (requests.RequestException, IOError) as e: - logger.debug("Failed init of AwsInstance while getting metadata: {}".format(e)) - return False, AwsInstanceInfo() - - try: - info.account_id = AwsInstance._extract_account_id( - requests.get( - AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document", - timeout=AWS_TIMEOUT, - ).text - ) - except 
(requests.RequestException, json.decoder.JSONDecodeError, IOError) as e: - logger.debug( - "Failed init of AwsInstance while getting dynamic instance data: {}".format(e) - ) - return False, AwsInstanceInfo() - - return True, info - - @staticmethod - def _parse_region(region_url_response): - # For a list of regions, see: - # https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts - # .RegionsAndAvailabilityZones.html - # This regex will find any AWS region format string in the response. - re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])" - finding = re.findall(re_phrase, region_url_response, re.IGNORECASE) - if finding: - return finding[0] - else: - return None - - @staticmethod - def _extract_account_id(instance_identity_document_response): - """ - Extracts the account id from the dynamic/instance-identity/document metadata path. - Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more - solutions, - in case Amazon break this mechanism. - :param instance_identity_document_response: json returned via the web page - ../dynamic/instance-identity/document - :return: The account id - """ - return json.loads(instance_identity_document_response)[ACCOUNT_ID_KEY] - - def get_account_id(self): - """ - :return: the AWS account ID which "owns" this instance. 
- See https://docs.aws.amazon.com/general/latest/gr/acct-identifiers.html - """ - return self.account_id + def _wait_for_initialization_to_complete(self): + if not self._initialization_complete.wait(AWS_FETCH_METADATA_TIMEOUT): + raise AWSTimeoutError("Timed out while attempting to retrieve metadata from AWS") diff --git a/monkey/common/aws/aws_metadata.py b/monkey/common/aws/aws_metadata.py new file mode 100644 index 000000000..634d41c49 --- /dev/null +++ b/monkey/common/aws/aws_metadata.py @@ -0,0 +1,86 @@ +import json +import logging +import re +from typing import Optional, Tuple + +import requests + +AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" +AWS_LATEST_METADATA_URI_PREFIX = f"http://{AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS}/latest/" +ACCOUNT_ID_KEY = "accountId" + +logger = logging.getLogger(__name__) + +AWS_TIMEOUT = 2 + + +def fetch_aws_instance_metadata() -> Tuple[Optional[str], Optional[str], Optional[str]]: + instance_id = None + region = None + account_id = None + + try: + instance_id = _fetch_aws_instance_id() + region = _fetch_aws_region() + account_id = _fetch_account_id() + except ( + requests.RequestException, + IOError, + json.decoder.JSONDecodeError, + ) as err: + logger.debug(f"Failed init of AWSInstance while getting metadata: {err}") + return (None, None, None) + + return (instance_id, region, account_id) + + +def _fetch_aws_instance_id() -> Optional[str]: + url = AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id" + response = requests.get( + url, + timeout=AWS_TIMEOUT, + ) + response.raise_for_status() + + return response.text + + +def _fetch_aws_region() -> Optional[str]: + response = requests.get( + AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone", + timeout=AWS_TIMEOUT, + ) + response.raise_for_status() + + return _parse_region(response.text) + + +def _parse_region(region_url_response: str) -> Optional[str]: + # For a list of regions, see: + # 
https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts + # .RegionsAndAvailabilityZones.html + # This regex will find any AWS region format string in the response. + re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])" + finding = re.findall(re_phrase, region_url_response, re.IGNORECASE) + if finding: + return finding[0] + else: + return None + + +def _fetch_account_id() -> str: + """ + Fetches and extracts the account id from the dynamic/instance-identity/document metadata path. + Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more + solutions, in case Amazon break this mechanism. + The function takes no arguments; it requests the JSON document served at + ../dynamic/instance-identity/document and reads its "accountId" key. + :return: The account id + """ + response = requests.get( + AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document", + timeout=AWS_TIMEOUT, + ) + response.raise_for_status() + + return json.loads(response.text)[ACCOUNT_ID_KEY] diff --git a/monkey/infection_monkey/utils/aws_environment_check.py b/monkey/infection_monkey/utils/aws_environment_check.py index 203882425..dc074cd6a 100644 --- a/monkey/infection_monkey/utils/aws_environment_check.py +++ b/monkey/infection_monkey/utils/aws_environment_check.py @@ -1,6 +1,6 @@ import logging -from common.aws.aws_instance import AwsInstance +from common.aws import AWSInstance from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import ( LegacyTelemetryMessengerAdapter, @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter): logger.info("Collecting AWS info") - aws_instance = AwsInstance() + aws_instance = AWSInstance() if aws_instance.is_instance: logger.info("Machine is an AWS instance") diff --git a/monkey/monkey_island/cc/services/aws_service.py 
b/monkey/monkey_island/cc/services/aws_service.py index b0f252608..12432b8b7 100644 --- a/monkey/monkey_island/cc/services/aws_service.py +++ b/monkey/monkey_island/cc/services/aws_service.py @@ -1,12 +1,10 @@ import logging -from functools import wraps -from threading import Event -from typing import Callable, Optional +from typing import Optional import boto3 import botocore -from common.aws.aws_instance import AwsInstance +from common.aws.aws_instance import AWSInstance INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList" INSTANCE_ID_KEY = "InstanceId" @@ -29,54 +27,30 @@ def filter_instance_data_from_aws_response(response): ] -aws_instance: Optional[AwsInstance] = None -AWS_INFO_FETCH_TIMEOUT = 10.0 # Seconds -init_done = Event() +aws_instance: Optional[AWSInstance] = None def initialize(): global aws_instance - aws_instance = AwsInstance() - init_done.set() + aws_instance = AWSInstance() -def wait_init_done(fnc: Callable): - @wraps(fnc) - def inner(): - awaited = init_done.wait(AWS_INFO_FETCH_TIMEOUT) - if not awaited: - logger.error( - f"AWS service couldn't initialize in time! " - f"Current timeout is {AWS_INFO_FETCH_TIMEOUT}, " - f"but AWS info took longer to fetch from metadata server." - ) - return - fnc() - - return inner - - -@wait_init_done def is_on_aws(): return aws_instance.is_instance -@wait_init_done def get_region(): return aws_instance.region -@wait_init_done def get_account_id(): return aws_instance.account_id -@wait_init_done def get_client(client_type): return boto3.client(client_type, region_name=aws_instance.region) -@wait_init_done def get_instances(): """ Get the information for all instances with the relevant roles. 
diff --git a/monkey/tests/unit_tests/common/aws/test_aws_instance.py b/monkey/tests/unit_tests/common/aws/test_aws_instance.py new file mode 100644 index 000000000..3ba3329f4 --- /dev/null +++ b/monkey/tests/unit_tests/common/aws/test_aws_instance.py @@ -0,0 +1,61 @@ +import pytest + +from common.aws import AWSInstance + +INSTANCE_ID = "1234" +REGION = "USA" +ACCOUNT_ID = "4321" + + +@pytest.fixture +def patch_fetch_metadata(monkeypatch): + def inner(instance_id: str, region: str, account_id: str): + return_value = (instance_id, region, account_id) + monkeypatch.setattr( + "common.aws.aws_instance.fetch_aws_instance_metadata", lambda: return_value + ) + + return inner + + +@pytest.fixture(autouse=True) +def patch_fetch_metadata_default_values(patch_fetch_metadata): + patch_fetch_metadata(INSTANCE_ID, REGION, ACCOUNT_ID) + + +def test_is_instance__true(): + aws_instance = AWSInstance() + + assert aws_instance.is_instance is True + + +def test_is_instance__false_none(patch_fetch_metadata): + patch_fetch_metadata(None, "", "") + aws_instance = AWSInstance() + + assert aws_instance.is_instance is False + + +def test_is_instance__false_empty_str(patch_fetch_metadata): + patch_fetch_metadata("", "", "") + aws_instance = AWSInstance() + + assert aws_instance.is_instance is False + + +def test_instance_id(): + aws_instance = AWSInstance() + + assert aws_instance.instance_id == INSTANCE_ID + + +def test_region(): + aws_instance = AWSInstance() + + assert aws_instance.region == REGION + + +def test_account_id(): + aws_instance = AWSInstance() + + assert aws_instance.account_id == ACCOUNT_ID diff --git a/monkey/tests/unit_tests/common/cloud/aws/test_aws_instance.py b/monkey/tests/unit_tests/common/aws/test_aws_metadata.py similarity index 50% rename from monkey/tests/unit_tests/common/cloud/aws/test_aws_instance.py rename to monkey/tests/unit_tests/common/aws/test_aws_metadata.py index 070b14092..470e41e01 100644 --- 
a/monkey/tests/unit_tests/common/cloud/aws/test_aws_instance.py +++ b/monkey/tests/unit_tests/common/aws/test_aws_metadata.py @@ -4,7 +4,7 @@ import pytest import requests import requests_mock -from common.aws.aws_instance import AWS_LATEST_METADATA_URI_PREFIX, AwsInstance +from common.aws.aws_metadata import AWS_LATEST_METADATA_URI_PREFIX, fetch_aws_instance_metadata INSTANCE_ID_RESPONSE = "i-1234567890abcdef0" @@ -38,7 +38,7 @@ EXPECTED_REGION = "us-west-2" EXPECTED_ACCOUNT_ID = "123456789012" -def get_test_aws_instance( +def patch_and_call_fetch_aws_instance_metadata( text=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), exception=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), ): @@ -59,104 +59,85 @@ def get_test_aws_instance( url, exc=exception["account_id"] ) - test_aws_instance_object = AwsInstance() - return test_aws_instance_object + return fetch_aws_instance_metadata() # all good data @pytest.fixture -def good_data_mock_instance(): - instance = get_test_aws_instance( +def good_metadata(): + return patch_and_call_fetch_aws_instance_metadata( text={ "instance_id": INSTANCE_ID_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE, "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, } ) - yield instance - del instance -def test_is_instance_good_data(good_data_mock_instance): - assert good_data_mock_instance.is_instance +def test_instance_id_good_data(good_metadata): + assert good_metadata[0] == EXPECTED_INSTANCE_ID -def test_instance_id_good_data(good_data_mock_instance): - assert good_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID +def test_region_good_data(good_metadata): + assert good_metadata[1] == EXPECTED_REGION -def test_region_good_data(good_data_mock_instance): - assert good_data_mock_instance.region == EXPECTED_REGION - - -def test_account_id_good_data(good_data_mock_instance): - assert good_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID +def test_account_id_good_data(good_metadata): + 
assert good_metadata[2] == EXPECTED_ACCOUNT_ID # 'region' bad data @pytest.fixture -def bad_region_data_mock_instance(): - instance = get_test_aws_instance( +def bad_region_metadata(): + return patch_and_call_fetch_aws_instance_metadata( text={ "instance_id": INSTANCE_ID_RESPONSE, "region": "in-a-different-world", "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, } ) - yield instance - del instance -def test_is_instance_bad_region_data(bad_region_data_mock_instance): - assert bad_region_data_mock_instance.is_instance +def test_instance_id_bad_region_data(bad_region_metadata): + assert bad_region_metadata[0] == EXPECTED_INSTANCE_ID -def test_instance_id_bad_region_data(bad_region_data_mock_instance): - assert bad_region_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID +def test_region_bad_region_data(bad_region_metadata): + assert bad_region_metadata[1] is None -def test_region_bad_region_data(bad_region_data_mock_instance): - assert bad_region_data_mock_instance.region is None - - -def test_account_id_bad_region_data(bad_region_data_mock_instance): - assert bad_region_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID +def test_account_id_bad_region_data(bad_region_metadata): + assert bad_region_metadata[2] == EXPECTED_ACCOUNT_ID # 'account_id' bad data @pytest.fixture -def bad_account_id_data_mock_instance(): - instance = get_test_aws_instance( +def bad_account_id_metadata(): + return patch_and_call_fetch_aws_instance_metadata( text={ "instance_id": INSTANCE_ID_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE, "account_id": "who-am-i", } ) - yield instance - del instance -def test_is_instance_bad_account_id_data(bad_account_id_data_mock_instance): - assert not bad_account_id_data_mock_instance.is_instance +def test_instance_id_bad_account_id_data(bad_account_id_metadata): + assert bad_account_id_metadata[0] is None -def test_instance_id_bad_account_id_data(bad_account_id_data_mock_instance): - assert bad_account_id_data_mock_instance.instance_id is None 
+def test_region_bad_account_id_data(bad_account_id_metadata): + assert bad_account_id_metadata[1] is None -def test_region_bad_account_id_data(bad_account_id_data_mock_instance): - assert bad_account_id_data_mock_instance.region is None - - -def test_account_id_data_bad_account_id_data(bad_account_id_data_mock_instance): - assert bad_account_id_data_mock_instance.account_id is None +def test_account_id_data_bad_account_id_data(bad_account_id_metadata): + assert bad_account_id_metadata[2] is None # 'region' bad requests @pytest.fixture -def bad_region_request_mock_instance(region_exception): - instance = get_test_aws_instance( +def region_request_failure_metadata(region_exception): + return patch_and_call_fetch_aws_instance_metadata( text={ "instance_id": INSTANCE_ID_RESPONSE, "region": None, @@ -164,33 +145,26 @@ def bad_region_request_mock_instance(region_exception): }, exception={"instance_id": None, "region": region_exception, "account_id": None}, ) - yield instance - del instance @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) -def test_is_instance_bad_region_request(bad_region_request_mock_instance): - assert not bad_region_request_mock_instance.is_instance +def test_instance_id_bad_region_request(region_request_failure_metadata): + assert region_request_failure_metadata[0] is None @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) -def test_instance_id_bad_region_request(bad_region_request_mock_instance): - assert bad_region_request_mock_instance.instance_id is None +def test_region_bad_region_request(region_request_failure_metadata): + assert region_request_failure_metadata[1] is None @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) -def test_region_bad_region_request(bad_region_request_mock_instance): - assert bad_region_request_mock_instance.region is None - - -@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) -def 
test_account_id_bad_region_request(bad_region_request_mock_instance): - assert bad_region_request_mock_instance.account_id is None +def test_account_id_bad_region_request(region_request_failure_metadata): + assert region_request_failure_metadata[2] is None # not found request @pytest.fixture -def not_found_request_mock_instance(): +def not_found_metadata(): with requests_mock.Mocker() as m: # request made to get instance_id url = f"{AWS_LATEST_METADATA_URI_PREFIX}meta-data/instance-id" @@ -204,22 +178,16 @@ def not_found_request_mock_instance(): url = f"{AWS_LATEST_METADATA_URI_PREFIX}dynamic/instance-identity/document" m.get(url) - not_found_aws_instance_object = AwsInstance() - yield not_found_aws_instance_object - del not_found_aws_instance_object + return fetch_aws_instance_metadata() -def test_is_instance_not_found_request(not_found_request_mock_instance): - assert not_found_request_mock_instance.is_instance is False +def test_instance_id_not_found_request(not_found_metadata): + assert not_found_metadata[0] is None -def test_instance_id_not_found_request(not_found_request_mock_instance): - assert not_found_request_mock_instance.instance_id is None +def test_region_not_found_request(not_found_metadata): + assert not_found_metadata[1] is None -def test_region_not_found_request(not_found_request_mock_instance): - assert not_found_request_mock_instance.region is None - - -def test_account_id_not_found_request(not_found_request_mock_instance): - assert not_found_request_mock_instance.account_id is None +def test_account_id_not_found_request(not_found_metadata): + assert not_found_metadata[2] is None