Merge pull request #1934 from guardicore/1928-lazy-load-aws-instance

1928 lazy load aws instance
This commit is contained in:
Mike Salvatore 2022-05-09 08:54:05 -04:00 committed by GitHub
commit 4f023621d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 225 additions and 200 deletions

View File

@ -0,0 +1 @@
from .aws_instance import AWSInstance

View File

@ -1,117 +1,52 @@
import json import threading
import logging
import re
from dataclasses import dataclass
from typing import Optional, Tuple
import requests from .aws_metadata import fetch_aws_instance_metadata
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" AWS_FETCH_METADATA_TIMEOUT = 10.0 # Seconds
AWS_LATEST_METADATA_URI_PREFIX = "http://{0}/latest/".format(AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS)
ACCOUNT_ID_KEY = "accountId"
logger = logging.getLogger(__name__)
AWS_TIMEOUT = 2
@dataclass class AWSTimeoutError(Exception):
class AwsInstanceInfo: """Raised when communications with AWS timeout"""
instance_id: Optional[str] = None
region: Optional[str] = None
account_id: Optional[str] = None
class AwsInstance: class AWSInstance:
""" """
Class which gives useful information about the current instance you're on. Class which gives useful information about the current instance you're on.
""" """
def __init__(self): def __init__(self):
self._is_instance, self._instance_info = AwsInstance._fetch_instance_info() self._instance_id = None
self._region = None
self._account_id = None
self._initialization_complete = threading.Event()
fetch_thread = threading.Thread(target=self._fetch_aws_instance_metadata, daemon=True)
fetch_thread.start()
def _fetch_aws_instance_metadata(self):
(self._instance_id, self._region, self._account_id) = fetch_aws_instance_metadata()
self._initialization_complete.set()
@property @property
def is_instance(self) -> bool: def is_instance(self) -> bool:
return self._is_instance self._wait_for_initialization_to_complete()
return bool(self._instance_id)
@property @property
def instance_id(self) -> str: def instance_id(self) -> str:
return self._instance_info.instance_id self._wait_for_initialization_to_complete()
return self._instance_id
@property @property
def region(self) -> str: def region(self) -> str:
return self._instance_info.region self._wait_for_initialization_to_complete()
return self._region
@property @property
def account_id(self) -> str: def account_id(self) -> str:
return self._instance_info.account_id self._wait_for_initialization_to_complete()
return self._account_id
@staticmethod def _wait_for_initialization_to_complete(self):
def _fetch_instance_info() -> Tuple[bool, AwsInstanceInfo]: if not self._initialization_complete.wait(AWS_FETCH_METADATA_TIMEOUT):
try: raise AWSTimeoutError("Timed out while attempting to retrieve metadata from AWS")
response = requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id",
timeout=AWS_TIMEOUT,
)
if not response:
return False, AwsInstanceInfo()
info = AwsInstanceInfo()
info.instance_id = response.text if response else False
info.region = AwsInstance._parse_region(
requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone",
timeout=AWS_TIMEOUT,
).text
)
except (requests.RequestException, IOError) as e:
logger.debug("Failed init of AwsInstance while getting metadata: {}".format(e))
return False, AwsInstanceInfo()
try:
info.account_id = AwsInstance._extract_account_id(
requests.get(
AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document",
timeout=AWS_TIMEOUT,
).text
)
except (requests.RequestException, json.decoder.JSONDecodeError, IOError) as e:
logger.debug(
"Failed init of AwsInstance while getting dynamic instance data: {}".format(e)
)
return False, AwsInstanceInfo()
return True, info
@staticmethod
def _parse_region(region_url_response):
# For a list of regions, see:
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts
# .RegionsAndAvailabilityZones.html
# This regex will find any AWS region format string in the response.
re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])"
finding = re.findall(re_phrase, region_url_response, re.IGNORECASE)
if finding:
return finding[0]
else:
return None
@staticmethod
def _extract_account_id(instance_identity_document_response):
"""
Extracts the account id from the dynamic/instance-identity/document metadata path.
Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more
solutions,
in case Amazon break this mechanism.
:param instance_identity_document_response: json returned via the web page
../dynamic/instance-identity/document
:return: The account id
"""
return json.loads(instance_identity_document_response)[ACCOUNT_ID_KEY]
def get_account_id(self):
"""
:return: the AWS account ID which "owns" this instance.
See https://docs.aws.amazon.com/general/latest/gr/acct-identifiers.html
"""
return self.account_id

View File

@ -0,0 +1,86 @@
import json
import logging
import re
from typing import Optional, Tuple
import requests
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254"
AWS_LATEST_METADATA_URI_PREFIX = f"http://{AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS}/latest/"
ACCOUNT_ID_KEY = "accountId"
logger = logging.getLogger(__name__)
AWS_TIMEOUT = 2
def fetch_aws_instance_metadata() -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Query the AWS instance metadata service for this machine's identity.

    The instance ID, region and account ID are fetched in that order. The result is
    all-or-nothing: if any request fails, or any response cannot be interpreted, no partial
    data is returned.

    :return: A tuple of (instance_id, region, account_id). All three are None if any lookup
             failed. On success, region may still be None if the availability zone response
             did not contain a recognizable region.
    """
    try:
        instance_id = _fetch_aws_instance_id()
        region = _fetch_aws_region()
        account_id = _fetch_account_id()
    except (
        requests.RequestException,
        IOError,
        json.decoder.JSONDecodeError,
        # The identity document may be valid JSON and still be unusable: a missing account ID
        # key raises KeyError and a non-object document (e.g. a JSON list) raises TypeError.
        # Treat both like any other malformed response instead of letting them escape.
        KeyError,
        TypeError,
    ) as err:
        logger.debug(f"Failed init of AWSInstance while getting metadata: {err}")
        return (None, None, None)

    return (instance_id, region, account_id)
def _fetch_aws_instance_id() -> Optional[str]:
    """
    Request this machine's instance ID from the meta-data/instance-id metadata path.

    :return: The body of the metadata service's response
    :raises requests.HTTPError: If the metadata service responds with an error status
    """
    reply = requests.get(
        AWS_LATEST_METADATA_URI_PREFIX + "meta-data/instance-id", timeout=AWS_TIMEOUT
    )
    reply.raise_for_status()

    return reply.text
def _fetch_aws_region() -> Optional[str]:
    """
    Determine this machine's region from the meta-data/placement/availability-zone path.

    :return: The region parsed out of the availability zone response, or None if the
             response does not contain a recognizable region
    :raises requests.HTTPError: If the metadata service responds with an error status
    """
    availability_zone_url = (
        AWS_LATEST_METADATA_URI_PREFIX + "meta-data/placement/availability-zone"
    )
    reply = requests.get(availability_zone_url, timeout=AWS_TIMEOUT)
    reply.raise_for_status()

    availability_zone = reply.text
    return _parse_region(availability_zone)
def _parse_region(region_url_response: str) -> Optional[str]:
# For a list of regions, see:
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts
# .RegionsAndAvailabilityZones.html
# This regex will find any AWS region format string in the response.
re_phrase = r"((?:us|eu|ap|ca|cn|sa)-[a-z]*-[0-9])"
finding = re.findall(re_phrase, region_url_response, re.IGNORECASE)
if finding:
return finding[0]
else:
return None
def _fetch_account_id() -> str:
    """
    Retrieve the account ID from the dynamic/instance-identity/document metadata path.

    The response is parsed as JSON and the value stored under the "accountId" key is returned.
    Based on https://forums.aws.amazon.com/message.jspa?messageID=409028 which has a few more
    solutions, in case Amazon break this mechanism.

    :return: The account id
    :raises requests.HTTPError: If the metadata service responds with an error status
    :raises json.decoder.JSONDecodeError: If the response body is not valid JSON
    """
    identity_document_url = AWS_LATEST_METADATA_URI_PREFIX + "dynamic/instance-identity/document"
    reply = requests.get(identity_document_url, timeout=AWS_TIMEOUT)
    reply.raise_for_status()

    identity_document = json.loads(reply.text)
    return identity_document[ACCOUNT_ID_KEY]

View File

@ -1,6 +1,6 @@
import logging import logging
from common.aws.aws_instance import AwsInstance from common.aws import AWSInstance
from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry
from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import ( from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import (
LegacyTelemetryMessengerAdapter, LegacyTelemetryMessengerAdapter,
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter): def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter):
logger.info("Collecting AWS info") logger.info("Collecting AWS info")
aws_instance = AwsInstance() aws_instance = AWSInstance()
if aws_instance.is_instance: if aws_instance.is_instance:
logger.info("Machine is an AWS instance") logger.info("Machine is an AWS instance")

View File

@ -1,12 +1,10 @@
import logging import logging
from functools import wraps from typing import Optional
from threading import Event
from typing import Callable, Optional
import boto3 import boto3
import botocore import botocore
from common.aws.aws_instance import AwsInstance from common.aws.aws_instance import AWSInstance
INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList" INSTANCE_INFORMATION_LIST_KEY = "InstanceInformationList"
INSTANCE_ID_KEY = "InstanceId" INSTANCE_ID_KEY = "InstanceId"
@ -29,54 +27,30 @@ def filter_instance_data_from_aws_response(response):
] ]
aws_instance: Optional[AwsInstance] = None aws_instance: Optional[AWSInstance] = None
AWS_INFO_FETCH_TIMEOUT = 10.0 # Seconds
init_done = Event()
def initialize(): def initialize():
global aws_instance global aws_instance
aws_instance = AwsInstance() aws_instance = AWSInstance()
init_done.set()
def wait_init_done(fnc: Callable):
@wraps(fnc)
def inner():
awaited = init_done.wait(AWS_INFO_FETCH_TIMEOUT)
if not awaited:
logger.error(
f"AWS service couldn't initialize in time! "
f"Current timeout is {AWS_INFO_FETCH_TIMEOUT}, "
f"but AWS info took longer to fetch from metadata server."
)
return
fnc()
return inner
@wait_init_done
def is_on_aws(): def is_on_aws():
return aws_instance.is_instance return aws_instance.is_instance
@wait_init_done
def get_region(): def get_region():
return aws_instance.region return aws_instance.region
@wait_init_done
def get_account_id(): def get_account_id():
return aws_instance.account_id return aws_instance.account_id
@wait_init_done
def get_client(client_type): def get_client(client_type):
return boto3.client(client_type, region_name=aws_instance.region) return boto3.client(client_type, region_name=aws_instance.region)
@wait_init_done
def get_instances(): def get_instances():
""" """
Get the information for all instances with the relevant roles. Get the information for all instances with the relevant roles.

View File

@ -0,0 +1,61 @@
import pytest
from common.aws import AWSInstance
INSTANCE_ID = "1234"
REGION = "USA"
ACCOUNT_ID = "4321"
@pytest.fixture
def patch_fetch_metadata(monkeypatch):
    """
    Provide a callable that replaces the metadata fetcher used by AWSInstance.

    Calling the returned function with an instance ID, region and account ID makes
    fetch_aws_instance_metadata() return exactly that tuple for the rest of the test.
    """

    def _patch(instance_id: str, region: str, account_id: str):
        def _fake_fetch():
            return (instance_id, region, account_id)

        monkeypatch.setattr("common.aws.aws_instance.fetch_aws_instance_metadata", _fake_fetch)

    return _patch
@pytest.fixture(autouse=True)
def patch_fetch_metadata_default_values(patch_fetch_metadata):
    """Patch the metadata fetcher with the module's default values for every test."""
    default_metadata = (INSTANCE_ID, REGION, ACCOUNT_ID)
    patch_fetch_metadata(*default_metadata)
def test_is_instance__true():
    """With the default (non-empty) instance ID, is_instance is True."""
    assert AWSInstance().is_instance is True
def test_is_instance__false_none(patch_fetch_metadata):
    """When the fetched instance ID is None, is_instance is False."""
    patch_fetch_metadata(None, "", "")

    assert AWSInstance().is_instance is False
def test_is_instance__false_empty_str(patch_fetch_metadata):
    """When the fetched instance ID is an empty string, is_instance is False."""
    patch_fetch_metadata("", "", "")

    assert AWSInstance().is_instance is False
def test_instance_id():
    """The instance_id property exposes the fetched instance ID."""
    assert AWSInstance().instance_id == INSTANCE_ID
def test_region():
    """The region property exposes the fetched region."""
    assert AWSInstance().region == REGION
def test_account_id():
    """The account_id property exposes the fetched account ID."""
    assert AWSInstance().account_id == ACCOUNT_ID

View File

@ -4,7 +4,7 @@ import pytest
import requests import requests
import requests_mock import requests_mock
from common.aws.aws_instance import AWS_LATEST_METADATA_URI_PREFIX, AwsInstance from common.aws.aws_metadata import AWS_LATEST_METADATA_URI_PREFIX, fetch_aws_instance_metadata
INSTANCE_ID_RESPONSE = "i-1234567890abcdef0" INSTANCE_ID_RESPONSE = "i-1234567890abcdef0"
@ -38,7 +38,7 @@ EXPECTED_REGION = "us-west-2"
EXPECTED_ACCOUNT_ID = "123456789012" EXPECTED_ACCOUNT_ID = "123456789012"
def get_test_aws_instance( def patch_and_call_fetch_aws_instance_metadata(
text=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), text=MappingProxyType({"instance_id": None, "region": None, "account_id": None}),
exception=MappingProxyType({"instance_id": None, "region": None, "account_id": None}), exception=MappingProxyType({"instance_id": None, "region": None, "account_id": None}),
): ):
@ -59,104 +59,85 @@ def get_test_aws_instance(
url, exc=exception["account_id"] url, exc=exception["account_id"]
) )
test_aws_instance_object = AwsInstance() return fetch_aws_instance_metadata()
return test_aws_instance_object
# all good data # all good data
@pytest.fixture @pytest.fixture
def good_data_mock_instance(): def good_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": AVAILABILITY_ZONE_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE,
"account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE,
} }
) )
yield instance
del instance
def test_is_instance_good_data(good_data_mock_instance): def test_instance_id_good_data(good_metadata):
assert good_data_mock_instance.is_instance assert good_metadata[0] == EXPECTED_INSTANCE_ID
def test_instance_id_good_data(good_data_mock_instance): def test_region_good_data(good_metadata):
assert good_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID assert good_metadata[1] == EXPECTED_REGION
def test_region_good_data(good_data_mock_instance): def test_account_id_good_data(good_metadata):
assert good_data_mock_instance.region == EXPECTED_REGION assert good_metadata[2] == EXPECTED_ACCOUNT_ID
def test_account_id_good_data(good_data_mock_instance):
assert good_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID
# 'region' bad data # 'region' bad data
@pytest.fixture @pytest.fixture
def bad_region_data_mock_instance(): def bad_region_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": "in-a-different-world", "region": "in-a-different-world",
"account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE, "account_id": INSTANCE_IDENTITY_DOCUMENT_RESPONSE,
} }
) )
yield instance
del instance
def test_is_instance_bad_region_data(bad_region_data_mock_instance): def test_instance_id_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.is_instance assert bad_region_metadata[0] == EXPECTED_INSTANCE_ID
def test_instance_id_bad_region_data(bad_region_data_mock_instance): def test_region_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.instance_id == EXPECTED_INSTANCE_ID assert bad_region_metadata[1] is None
def test_region_bad_region_data(bad_region_data_mock_instance): def test_account_id_bad_region_data(bad_region_metadata):
assert bad_region_data_mock_instance.region is None assert bad_region_metadata[2] == EXPECTED_ACCOUNT_ID
def test_account_id_bad_region_data(bad_region_data_mock_instance):
assert bad_region_data_mock_instance.account_id == EXPECTED_ACCOUNT_ID
# 'account_id' bad data # 'account_id' bad data
@pytest.fixture @pytest.fixture
def bad_account_id_data_mock_instance(): def bad_account_id_metadata():
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": AVAILABILITY_ZONE_RESPONSE, "region": AVAILABILITY_ZONE_RESPONSE,
"account_id": "who-am-i", "account_id": "who-am-i",
} }
) )
yield instance
del instance
def test_is_instance_bad_account_id_data(bad_account_id_data_mock_instance): def test_instance_id_bad_account_id_data(bad_account_id_metadata):
assert not bad_account_id_data_mock_instance.is_instance assert bad_account_id_metadata[0] is None
def test_instance_id_bad_account_id_data(bad_account_id_data_mock_instance): def test_region_bad_account_id_data(bad_account_id_metadata):
assert bad_account_id_data_mock_instance.instance_id is None assert bad_account_id_metadata[1] is None
def test_region_bad_account_id_data(bad_account_id_data_mock_instance): def test_account_id_data_bad_account_id_data(bad_account_id_metadata):
assert bad_account_id_data_mock_instance.region is None assert bad_account_id_metadata[2] is None
def test_account_id_data_bad_account_id_data(bad_account_id_data_mock_instance):
assert bad_account_id_data_mock_instance.account_id is None
# 'region' bad requests # 'region' bad requests
@pytest.fixture @pytest.fixture
def bad_region_request_mock_instance(region_exception): def region_request_failure_metadata(region_exception):
instance = get_test_aws_instance( return patch_and_call_fetch_aws_instance_metadata(
text={ text={
"instance_id": INSTANCE_ID_RESPONSE, "instance_id": INSTANCE_ID_RESPONSE,
"region": None, "region": None,
@ -164,33 +145,26 @@ def bad_region_request_mock_instance(region_exception):
}, },
exception={"instance_id": None, "region": region_exception, "account_id": None}, exception={"instance_id": None, "region": region_exception, "account_id": None},
) )
yield instance
del instance
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_is_instance_bad_region_request(bad_region_request_mock_instance): def test_instance_id_bad_region_request(region_request_failure_metadata):
assert not bad_region_request_mock_instance.is_instance assert region_request_failure_metadata[0] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_instance_id_bad_region_request(bad_region_request_mock_instance): def test_region_bad_region_request(region_request_failure_metadata):
assert bad_region_request_mock_instance.instance_id is None assert region_request_failure_metadata[1] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError]) @pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_region_bad_region_request(bad_region_request_mock_instance): def test_account_id_bad_region_request(region_request_failure_metadata):
assert bad_region_request_mock_instance.region is None assert region_request_failure_metadata[2] is None
@pytest.mark.parametrize("region_exception", [requests.RequestException, IOError])
def test_account_id_bad_region_request(bad_region_request_mock_instance):
assert bad_region_request_mock_instance.account_id is None
# not found request # not found request
@pytest.fixture @pytest.fixture
def not_found_request_mock_instance(): def not_found_metadata():
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
# request made to get instance_id # request made to get instance_id
url = f"{AWS_LATEST_METADATA_URI_PREFIX}meta-data/instance-id" url = f"{AWS_LATEST_METADATA_URI_PREFIX}meta-data/instance-id"
@ -204,22 +178,16 @@ def not_found_request_mock_instance():
url = f"{AWS_LATEST_METADATA_URI_PREFIX}dynamic/instance-identity/document" url = f"{AWS_LATEST_METADATA_URI_PREFIX}dynamic/instance-identity/document"
m.get(url) m.get(url)
not_found_aws_instance_object = AwsInstance() return fetch_aws_instance_metadata()
yield not_found_aws_instance_object
del not_found_aws_instance_object
def test_is_instance_not_found_request(not_found_request_mock_instance): def test_instance_id_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.is_instance is False assert not_found_metadata[0] is None
def test_instance_id_not_found_request(not_found_request_mock_instance): def test_region_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.instance_id is None assert not_found_metadata[1] is None
def test_region_not_found_request(not_found_request_mock_instance): def test_account_id_not_found_request(not_found_metadata):
assert not_found_request_mock_instance.region is None assert not_found_metadata[2] is None
def test_account_id_not_found_request(not_found_request_mock_instance):
assert not_found_request_mock_instance.account_id is None