Agent, Island, Common: Refactor AwsService from class to package

This also changes AwsInstance from singleton and instead the aws_service package is used as one
This commit is contained in:
vakarisz 2022-05-02 14:51:06 +03:00
parent 7b2ff1e159
commit f3a5a7090b
10 changed files with 79 additions and 107 deletions

View File

@ -6,8 +6,6 @@ from typing import Optional, Tuple
import requests import requests
from common.utils.code_utils import Singleton
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254"
AWS_LATEST_METADATA_URI_PREFIX = "http://{0}/latest/".format(AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS) AWS_LATEST_METADATA_URI_PREFIX = "http://{0}/latest/".format(AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS)
ACCOUNT_ID_KEY = "accountId" ACCOUNT_ID_KEY = "accountId"
@ -29,8 +27,6 @@ class AwsInstance:
Class which gives useful information about the current instance you're on. Class which gives useful information about the current instance you're on.
""" """
__metaclass__ = Singleton
def __init__(self): def __init__(self):
self._is_instance, self._instance_info = AwsInstance._fetch_instance_info() self._is_instance, self._instance_info = AwsInstance._fetch_instance_info()

View File

@ -1,4 +1,7 @@
import logging import logging
from functools import wraps
from threading import Event
from typing import Callable, Optional
import boto3 import boto3
import botocore import botocore
@ -26,32 +29,50 @@ def filter_instance_data_from_aws_response(response):
] ]
class AwsService(object): aws_instance: Optional[AwsInstance] = None
""" AWS_INFO_FETCH_TIMEOUT = 10.0 # Seconds
A wrapper class around the boto3 client and session modules, which supplies various AWS init_done = Event()
services.
This class will assume:
1. That it's running on an EC2 instance
2. That the instance is associated with the correct IAM role. See
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#iam-role
for details.
"""
region = None def initialize():
global aws_instance
aws_instance = AwsInstance()
init_done.set()
@staticmethod
def set_region(region):
AwsService.region = region
@staticmethod def wait_init_done(fnc: Callable):
def get_client(client_type, region=None): @wraps(fnc)
return boto3.client( def inner():
client_type, region_name=region if region is not None else AwsService.region awaited = init_done.wait(AWS_INFO_FETCH_TIMEOUT)
if not awaited:
logger.error(
f"AWS service couldn't initialize in time! "
f"Current timeout is {AWS_INFO_FETCH_TIMEOUT}, "
f"but AWS info took longer to fetch from metadata server."
) )
return
fnc()
@staticmethod return inner
def get_instances():
@wait_init_done
def is_on_aws():
return aws_instance.is_instance
@wait_init_done
def get_region():
return aws_instance.region
@wait_init_done
def get_client(client_type):
return boto3.client(client_type, region_name=aws_instance.region)
@wait_init_done
def get_instances():
""" """
Get the information for all instances with the relevant roles. Get the information for all instances with the relevant roles.
@ -62,8 +83,7 @@ class AwsService(object):
:raises: botocore.exceptions.ClientError if can't describe local instance information. :raises: botocore.exceptions.ClientError if can't describe local instance information.
:return: All visible instances from this instance :return: All visible instances from this instance
""" """
current_instance = AwsInstance() local_ssm_client = boto3.client("ssm", aws_instance.region)
local_ssm_client = boto3.client("ssm", current_instance.region)
try: try:
response = local_ssm_client.describe_instance_information() response = local_ssm_client.describe_instance_information()

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from common.aws.aws_service import AwsService from common.aws import aws_service
from common.cmd.aws.aws_cmd_result import AwsCmdResult from common.cmd.aws.aws_cmd_result import AwsCmdResult
from common.cmd.cmd_runner import CmdRunner from common.cmd.cmd_runner import CmdRunner
from common.cmd.cmd_status import CmdStatus from common.cmd.cmd_status import CmdStatus
@ -18,7 +18,7 @@ class AwsCmdRunner(CmdRunner):
super(AwsCmdRunner, self).__init__(is_linux) super(AwsCmdRunner, self).__init__(is_linux)
self.instance_id = instance_id self.instance_id = instance_id
self.region = region self.region = region
self.ssm = AwsService.get_client("ssm", region) self.ssm = aws_service.get_client("ssm", region)
def query_command(self, command_id): def query_command(self, command_id):
time.sleep(2) time.sleep(2)

View File

@ -1,6 +1,6 @@
import logging import logging
from common.aws.aws_instance import AwsInstance from common.aws import aws_service
from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry from infection_monkey.telemetry.aws_instance_telem import AWSInstanceTelemetry
from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import ( from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import (
LegacyTelemetryMessengerAdapter, LegacyTelemetryMessengerAdapter,
@ -10,16 +10,12 @@ from infection_monkey.utils.threading import create_daemon_thread
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _running_on_aws(aws_instance: AwsInstance) -> bool:
return aws_instance.is_instance
def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter): def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter):
logger.info("Collecting AWS info") logger.info("Collecting AWS info")
aws_instance = AwsInstance() aws_instance = aws_service.initialize()
if _running_on_aws(aws_instance): if aws_service.is_on_aws():
logger.info("Machine is an AWS instance") logger.info("Machine is an AWS instance")
telemetry_messenger.send_telemetry(AWSInstanceTelemetry(aws_instance.instance_id)) telemetry_messenger.send_telemetry(AWSInstanceTelemetry(aws_instance.instance_id))
else: else:

View File

@ -1,7 +1,6 @@
import os import os
import uuid import uuid
from datetime import timedelta from datetime import timedelta
from threading import Thread
from typing import Type from typing import Type
import flask_restful import flask_restful
@ -49,7 +48,6 @@ from monkey_island.cc.resources.zero_trust.finding_event import ZeroTrustFinding
from monkey_island.cc.resources.zero_trust.zero_trust_report import ZeroTrustReport from monkey_island.cc.resources.zero_trust.zero_trust_report import ZeroTrustReport
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.server_utils.custom_json_encoder import CustomJSONEncoder from monkey_island.cc.server_utils.custom_json_encoder import CustomJSONEncoder
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService
from monkey_island.cc.services.representations import output_json from monkey_island.cc.services.representations import output_json
HOME_FILE = "index.html" HOME_FILE = "index.html"
@ -104,10 +102,6 @@ def init_app_services(app):
with app.app_context(): with app.app_context():
database.init() database.init()
# If running on AWS, this will initialize the instance data, which is used "later" in the
# execution of the island. Run on a daemon thread since it's slow.
Thread(target=RemoteRunAwsService.init, name="AWS check thread", daemon=True).start()
def init_app_url_rules(app): def init_app_url_rules(app):
app.add_url_rule("/", "serve_home", serve_home) app.add_url_rule("/", "serve_home", serve_home)

View File

@ -4,7 +4,7 @@ import flask_restful
from botocore.exceptions import ClientError, NoCredentialsError from botocore.exceptions import ClientError, NoCredentialsError
from flask import jsonify, make_response, request from flask import jsonify, make_response, request
from common.aws.aws_service import AwsService from common.aws import aws_service
from monkey_island.cc.resources.auth.auth import jwt_required from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService
@ -19,10 +19,6 @@ NO_CREDS_ERROR_FORMAT = (
class RemoteRun(flask_restful.Resource): class RemoteRun(flask_restful.Resource):
def __init__(self):
super(RemoteRun, self).__init__()
RemoteRunAwsService.init()
def run_aws_monkeys(self, request_body): def run_aws_monkeys(self, request_body):
instances = request_body.get("instances") instances = request_body.get("instances")
island_ip = request_body.get("island_ip") island_ip = request_body.get("island_ip")
@ -32,11 +28,11 @@ class RemoteRun(flask_restful.Resource):
def get(self): def get(self):
action = request.args.get("action") action = request.args.get("action")
if action == "list_aws": if action == "list_aws":
is_aws = RemoteRunAwsService.is_running_on_aws() is_aws = aws_service.is_on_aws()
resp = {"is_aws": is_aws} resp = {"is_aws": is_aws}
if is_aws: if is_aws:
try: try:
resp["instances"] = AwsService.get_instances() resp["instances"] = aws_service.get_instances()
except NoCredentialsError as e: except NoCredentialsError as e:
resp["error"] = NO_CREDS_ERROR_FORMAT.format(e) resp["error"] = NO_CREDS_ERROR_FORMAT.format(e)
return jsonify(resp) return jsonify(resp)
@ -52,7 +48,6 @@ class RemoteRun(flask_restful.Resource):
body = json.loads(request.data) body = json.loads(request.data)
resp = {} resp = {}
if body.get("type") == "aws": if body.get("type") == "aws":
RemoteRunAwsService.update_aws_region_authless()
result = self.run_aws_monkeys(body) result = self.run_aws_monkeys(body)
resp["result"] = result resp["result"] = result
return jsonify(resp) return jsonify(resp)

View File

@ -1,6 +1,8 @@
from pathlib import Path from pathlib import Path
from threading import Thread
from common import DIContainer from common import DIContainer
from common.aws import aws_service
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService from monkey_island.cc.services.post_breach_files import PostBreachFilesService
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
@ -14,6 +16,9 @@ def initialize_services(data_dir: Path) -> DIContainer:
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas") IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
) )
# Takes a while so it's best to start it in the background
Thread(target=aws_service.initialize, name="AwsService initialization", daemon=True).start()
# This is temporary until we get DI all worked out. # This is temporary until we get DI all worked out.
PostBreachFilesService.initialize(container.resolve(IFileStorageService)) PostBreachFilesService.initialize(container.resolve(IFileStorageService))
LocalMonkeyRunService.initialize(data_dir) LocalMonkeyRunService.initialize(data_dir)

View File

@ -1,34 +1,13 @@
import logging import logging
from threading import Event
from common.aws.aws_instance import AwsInstance
from common.aws.aws_service import AwsService
from common.cmd.aws.aws_cmd_runner import AwsCmdRunner from common.cmd.aws.aws_cmd_runner import AwsCmdRunner
from common.cmd.cmd import Cmd from common.cmd.cmd import Cmd
from common.cmd.cmd_runner import CmdRunner from common.cmd.cmd_runner import CmdRunner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AWS_INFO_FETCH_TIMEOUT = 10 # Seconds
aws_info_fetch_done = Event()
class RemoteRunAwsService: class RemoteRunAwsService:
aws_instance = None
def __init__(self):
pass
@staticmethod
def init():
"""
Initializes service. Subsequent calls to this function have no effect.
Must be called at least once (in entire monkey lifetime) before usage of functions
:return: None
"""
if RemoteRunAwsService.aws_instance is None:
RemoteRunAwsService.aws_instance = AwsInstance()
aws_info_fetch_done.set()
@staticmethod @staticmethod
def run_aws_monkeys(instances, island_ip): def run_aws_monkeys(instances, island_ip):
""" """
@ -47,18 +26,6 @@ class RemoteRunAwsService:
lambda _, result: result.is_success, lambda _, result: result.is_success,
) )
@staticmethod
def is_running_on_aws():
aws_info_fetch_done.wait(AWS_INFO_FETCH_TIMEOUT)
return RemoteRunAwsService.aws_instance.is_instance
@staticmethod
def update_aws_region_authless():
"""
Updates the AWS region without auth params (via IAM role)
"""
AwsService.set_region(RemoteRunAwsService.aws_instance.region)
@staticmethod @staticmethod
def _run_aws_monkey_cmd_async(instance_id, is_linux, island_ip): def _run_aws_monkey_cmd_async(instance_id, is_linux, island_ip):
""" """

View File

@ -5,6 +5,7 @@ from datetime import datetime
import boto3 import boto3
from botocore.exceptions import UnknownServiceError from botocore.exceptions import UnknownServiceError
from common.aws import aws_service
from common.aws.aws_instance import AwsInstance from common.aws.aws_instance import AwsInstance
from monkey_island.cc.services.reporting.exporter import Exporter from monkey_island.cc.services.reporting.exporter import Exporter
@ -35,8 +36,7 @@ class AWSExporter(Exporter):
logger.info("No issues were found by the monkey, no need to send anything") logger.info("No issues were found by the monkey, no need to send anything")
return True return True
# Not suppressing error here on purpose. current_aws_region = aws_service.get_region()
current_aws_region = AwsInstance().region
for machine in issues_list: for machine in issues_list:
for issue in issues_list[machine]: for issue in issues_list[machine]:

View File

@ -1,6 +1,6 @@
import logging import logging
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService from common.aws import aws_service
from monkey_island.cc.services.reporting.aws_exporter import AWSExporter from monkey_island.cc.services.reporting.aws_exporter import AWSExporter
from monkey_island.cc.services.reporting.report_exporter_manager import ReportExporterManager from monkey_island.cc.services.reporting.report_exporter_manager import ReportExporterManager
@ -22,8 +22,7 @@ def populate_exporter_list():
def try_add_aws_exporter_to_manager(manager): def try_add_aws_exporter_to_manager(manager):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
RemoteRunAwsService.init() if aws_service.is_on_aws():
if RemoteRunAwsService.is_running_on_aws():
manager.add_exporter_to_list(AWSExporter) manager.add_exporter_to_list(AWSExporter)
except Exception: except Exception:
logger.error("Failed adding aws exporter to manager. Exception info:", exc_info=True) logger.error("Failed adding aws exporter to manager. Exception info:", exc_info=True)