Merge pull request #2195 from guardicore/refactor-island-boot

Refactor island boot
This commit is contained in:
Mike Salvatore 2022-08-15 08:35:00 -04:00 committed by GitHub
commit 4e9aa62c61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 95 deletions

View File

@ -3,6 +3,8 @@ import json
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
from typing import Sequence, Tuple
import gevent.hub import gevent.hub
import requests import requests
@ -11,6 +13,7 @@ from gevent.pywsgi import WSGIServer
from monkey_island.cc import Version from monkey_island.cc import Version
from monkey_island.cc.deployment import Deployment from monkey_island.cc.deployment import Deployment
from monkey_island.cc.server_utils.consts import ISLAND_PORT from monkey_island.cc.server_utils.consts import ISLAND_PORT
from monkey_island.cc.server_utils.island_logger import get_log_file_path
from monkey_island.cc.setup.config_setup import get_server_config from monkey_island.cc.setup.config_setup import get_server_config
# Add the monkey_island directory to the path, to make sure imports that don't start with # Add the monkey_island directory to the path, to make sure imports that don't start with
@ -27,6 +30,7 @@ from monkey_island.cc.arg_parser import parse_cli_args # noqa: E402
from monkey_island.cc.server_utils.consts import ( # noqa: E402 from monkey_island.cc.server_utils.consts import ( # noqa: E402
GEVENT_EXCEPTION_LOG, GEVENT_EXCEPTION_LOG,
MONGO_CONNECTION_TIMEOUT, MONGO_CONNECTION_TIMEOUT,
MONKEY_ISLAND_ABS_PATH,
) )
from monkey_island.cc.server_utils.island_logger import reset_logger, setup_logging # noqa: E402 from monkey_island.cc.server_utils.island_logger import reset_logger, setup_logging # noqa: E402
from monkey_island.cc.services.initialize import initialize_services # noqa: E402 from monkey_island.cc.services.initialize import initialize_services # noqa: E402
@ -45,44 +49,41 @@ logger = logging.getLogger(__name__)
def run_monkey_island(): def run_monkey_island():
island_args = parse_cli_args() island_args = parse_cli_args()
config_options = _extract_config(island_args) config_options = _extract_config(island_args)
_setup_data_dir(config_options.data_dir)
_exit_on_invalid_config_options(config_options) _exit_on_invalid_config_options(config_options)
_setup_data_dir(config_options.data_dir)
_configure_logging(config_options) _configure_logging(config_options)
container = _initialize_di_container(config_options.data_dir)
mongo_db_process = None ip_addresses, deployment, version = _collect_system_info()
if config_options.start_mongodb:
mongo_db_process = _start_mongodb(config_options.data_dir)
_connect_to_mongodb(mongo_db_process) _send_analytics(deployment, version)
container = _initialize_di_container(ip_addresses, version, config_options.data_dir)
_configure_gevent_exception_handling(config_options.data_dir) _initialize_mongodb_connection(config_options.start_mongodb, config_options.data_dir)
_start_island_server(island_args.setup_only, config_options, container) _start_island_server(ip_addresses, island_args.setup_only, config_options, container)
def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions: def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions:
try: try:
return get_server_config(island_args) return get_server_config(island_args)
except json.JSONDecodeError as ex: except json.JSONDecodeError as err:
print(f"Error loading server config: {ex}") print(f"Error loading server config: {err}")
sys.exit(1) sys.exit(1)
def _setup_data_dir(data_dir_path: Path): def _setup_data_dir(data_dir_path: Path):
try: try:
setup_data_dir(data_dir_path) setup_data_dir(data_dir_path)
except IncompatibleDataDirectory as ex: except IncompatibleDataDirectory as err:
print(f"Incompatible data directory: {ex}") print(f"Incompatible data directory: {err}")
sys.exit(1) sys.exit(1)
def _exit_on_invalid_config_options(config_options: IslandConfigOptions): def _exit_on_invalid_config_options(config_options: IslandConfigOptions):
try: try:
island_config_options_validator.raise_on_invalid_options(config_options) island_config_options_validator.raise_on_invalid_options(config_options)
except Exception as ex: except Exception as err:
print(f"Configuration error: {ex}") print(f"Configuration error: {err}")
sys.exit(1) sys.exit(1)
@ -91,8 +92,48 @@ def _configure_logging(config_options):
setup_logging(config_options.data_dir, config_options.log_level) setup_logging(config_options.data_dir, config_options.log_level)
def _initialize_di_container(data_dir: Path) -> DIContainer: def _collect_system_info() -> Tuple[Sequence[str], Deployment, Version]:
return initialize_services(data_dir) deployment = _get_deployment()
version = Version(get_version(), deployment)
return (get_ip_addresses(), deployment, version)
def _get_deployment() -> Deployment:
deployment_file_path = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "deployment.json"
try:
with open(deployment_file_path, "r") as deployment_info_file:
deployment_info = json.load(deployment_info_file)
return Deployment[deployment_info["deployment"].upper()]
except KeyError as err:
raise Exception(
f"The deployment file ({deployment_file_path}) did not contain the expected data: "
f"missing key {err}"
)
except Exception as err:
raise Exception(f"Failed to fetch the deployment from {deployment_file_path}: {err}")
def _initialize_di_container(
ip_addresses: Sequence[str], version: Version, data_dir: Path
) -> DIContainer:
container = DIContainer()
container.register_convention(Sequence[str], "ip_addresses", ip_addresses)
container.register_instance(Version, version)
container.register_convention(Path, "data_dir", data_dir)
container.register_convention(Path, "island_log_file_path", get_log_file_path(data_dir))
initialize_services(container, data_dir)
return container
def _initialize_mongodb_connection(start_mongodb: bool, data_dir: Path):
mongo_db_process = None
if start_mongodb:
mongo_db_process = _start_mongodb(data_dir)
_connect_to_mongodb(mongo_db_process)
def _start_mongodb(data_dir: Path) -> MongoDbProcess: def _start_mongodb(data_dir: Path) -> MongoDbProcess:
@ -105,34 +146,27 @@ def _start_mongodb(data_dir: Path) -> MongoDbProcess:
def _connect_to_mongodb(mongo_db_process: MongoDbProcess): def _connect_to_mongodb(mongo_db_process: MongoDbProcess):
try: try:
mongo_setup.connect_to_mongodb(MONGO_CONNECTION_TIMEOUT) mongo_setup.connect_to_mongodb(MONGO_CONNECTION_TIMEOUT)
except mongo_setup.MongoDBTimeOutError as ex: except mongo_setup.MongoDBTimeOutError as err:
if mongo_db_process and not mongo_db_process.is_running(): if mongo_db_process and not mongo_db_process.is_running():
logger.error( logger.error(
f"Failed to start MongoDB process. Check log at {mongo_db_process.log_file}." f"Failed to start MongoDB process. Check log at {mongo_db_process.log_file}."
) )
else: else:
logger.error(ex) logger.error(err)
sys.exit(1) sys.exit(1)
except mongo_setup.MongoDBVersionError as ex: except mongo_setup.MongoDBVersionError as err:
logger.error(ex) logger.error(err)
sys.exit(1) sys.exit(1)
def _configure_gevent_exception_handling(data_dir):
hub = gevent.hub.get_hub()
gevent_exception_log = open(data_dir / GEVENT_EXCEPTION_LOG, "w+", buffering=1)
atexit.register(gevent_exception_log.close)
# Send gevent's exception output to a log file.
# https://www.gevent.org/api/gevent.hub.html#gevent.hub.Hub.exception_stream
hub.exception_stream = gevent_exception_log
hub.handle_error = GeventHubErrorHandler(hub, logger)
def _start_island_server( def _start_island_server(
should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer ip_addresses: Sequence[str],
should_setup_only: bool,
config_options: IslandConfigOptions,
container: DIContainer,
): ):
_configure_gevent_exception_handling(config_options.data_dir)
app = init_app(mongo_setup.MONGO_URL, container) app = init_app(mongo_setup.MONGO_URL, container)
if should_setup_only: if should_setup_only:
@ -152,11 +186,22 @@ def _start_island_server(
log=_get_wsgi_server_logger(), log=_get_wsgi_server_logger(),
error_log=logger, error_log=logger,
) )
_log_init_info() _log_init_info(ip_addresses)
_send_analytics(container)
http_server.serve_forever() http_server.serve_forever()
def _configure_gevent_exception_handling(data_dir: Path):
hub = gevent.hub.get_hub()
gevent_exception_log = open(data_dir / GEVENT_EXCEPTION_LOG, "w+", buffering=1)
atexit.register(gevent_exception_log.close)
# Send gevent's exception output to a log file.
# https://www.gevent.org/api/gevent.hub.html#gevent.hub.Hub.exception_stream
hub.exception_stream = gevent_exception_log
hub.handle_error = GeventHubErrorHandler(hub, logger)
def _get_wsgi_server_logger() -> logging.Logger: def _get_wsgi_server_logger() -> logging.Logger:
wsgi_server_logger = logger.getChild("wsgi") wsgi_server_logger = logger.getChild("wsgi")
wsgi_server_logger.addFilter(PyWSGILoggingFilter()) wsgi_server_logger.addFilter(PyWSGILoggingFilter())
@ -164,31 +209,28 @@ def _get_wsgi_server_logger() -> logging.Logger:
return wsgi_server_logger return wsgi_server_logger
def _log_init_info(): def _log_init_info(ip_addresses: Sequence[str]):
logger.info("Monkey Island Server is running!") logger.info("Monkey Island Server is running!")
logger.info(f"version: {get_version()}") logger.info(f"version: {get_version()}")
_log_web_interface_access_urls() _log_web_interface_access_urls(ip_addresses)
def _log_web_interface_access_urls(): def _log_web_interface_access_urls(ip_addresses: Sequence[str]):
web_interface_urls = ", ".join([f"https://{ip}:{ISLAND_PORT}" for ip in get_ip_addresses()]) web_interface_urls = ", ".join([f"https://{ip}:{ISLAND_PORT}" for ip in ip_addresses])
logger.info( logger.info(
"To access the web interface, navigate to one of the the following URLs using your " "To access the web interface, navigate to one of the the following URLs using your "
f"browser: {web_interface_urls}" f"browser: {web_interface_urls}"
) )
ANALYTICS_URL = ( def _send_analytics(deployment: Deployment, version: Version):
"https://m15mjynko3.execute-api.us-east-1.amazonaws.com/default?version={" def _inner(deployment: Deployment, version: Version):
"version}&deployment={deployment}" url = (
"https://m15mjynko3.execute-api.us-east-1.amazonaws.com/default"
f"?version={version.version_number}&deployment={deployment.value}"
) )
def _send_analytics(di_container):
version = di_container.resolve(Version)
deployment = di_container.resolve(Deployment)
url = ANALYTICS_URL.format(deployment=deployment.value, version=version.version_number)
try: try:
response = requests.get(url).json() response = requests.get(url).json()
logger.info( logger.info(
@ -199,3 +241,5 @@ def _send_analytics(di_container):
logger.info( logger.info(
f"Failed to send deployment type and version number to the analytics server: {err}" f"Failed to send deployment type and version number to the analytics server: {err}"
) )
Thread(target=_inner, args=(deployment, version), daemon=True).start()

View File

@ -1,7 +1,5 @@
import json
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Sequence
from pubsub.core import Publisher from pubsub.core import Publisher
from pymongo import MongoClient from pymongo import MongoClient
@ -16,9 +14,6 @@ from common.aws import AWSInstance
from common.common_consts.telem_categories import TelemCategoryEnum from common.common_consts.telem_categories import TelemCategoryEnum
from common.event_queue import IEventQueue, PyPubSubEventQueue from common.event_queue import IEventQueue, PyPubSubEventQueue
from common.utils.file_utils import get_binary_io_sha256_hash from common.utils.file_utils import get_binary_io_sha256_hash
from common.version import get_version
from monkey_island.cc import Version
from monkey_island.cc.deployment import Deployment
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
AgentBinaryRepository, AgentBinaryRepository,
FileAgentConfigurationRepository, FileAgentConfigurationRepository,
@ -39,7 +34,6 @@ from monkey_island.cc.repository import (
) )
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor from monkey_island.cc.server_utils.encryption import ILockableEncryptor, RepositoryEncryptor
from monkey_island.cc.server_utils.island_logger import get_log_file_path
from monkey_island.cc.services import AWSService, IslandModeService, RepositoryService from monkey_island.cc.services import AWSService, IslandModeService, RepositoryService
from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData from monkey_island.cc.services.attack.technique_reports.T1003 import T1003, T1003GetReportData
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
@ -49,7 +43,6 @@ from monkey_island.cc.services.telemetry.processing.credentials.credentials_pars
from monkey_island.cc.services.telemetry.processing.processing import ( from monkey_island.cc.services.telemetry.processing.processing import (
TELEMETRY_CATEGORY_TO_PROCESSING_FUNC, TELEMETRY_CATEGORY_TO_PROCESSING_FUNC,
) )
from monkey_island.cc.services.utils.network_utils import get_ip_addresses
from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL from monkey_island.cc.setup.mongo.mongo_setup import MONGO_URL
from . import AuthenticationService from . import AuthenticationService
@ -58,21 +51,17 @@ from .reporting.report import ReportService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AGENT_BINARIES_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "binaries" AGENT_BINARIES_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "binaries"
DEPLOYMENT_FILE_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "deployment.json"
REPOSITORY_KEY_FILE_NAME = "repository_key.bin" REPOSITORY_KEY_FILE_NAME = "repository_key.bin"
def initialize_services(data_dir: Path) -> DIContainer: def initialize_services(container: DIContainer, data_dir: Path):
container = DIContainer() _register_conventions(container)
_register_conventions(container, data_dir)
container.register_instance(Deployment, _get_depyloyment_from_file(DEPLOYMENT_FILE_PATH))
container.register_instance(AWSInstance, AWSInstance()) container.register_instance(AWSInstance, AWSInstance())
container.register_instance(MongoClient, MongoClient(MONGO_URL, serverSelectionTimeoutMS=100)) container.register_instance(MongoClient, MongoClient(MONGO_URL, serverSelectionTimeoutMS=100))
container.register_instance( container.register_instance(
ILockableEncryptor, RepositoryEncryptor(data_dir / REPOSITORY_KEY_FILE_NAME) ILockableEncryptor, RepositoryEncryptor(data_dir / REPOSITORY_KEY_FILE_NAME)
) )
container.register_instance(Version, container.resolve(Version))
container.register(Publisher, Publisher) container.register(Publisher, Publisher)
container.register_instance(IEventQueue, container.resolve(PyPubSubEventQueue)) container.register_instance(IEventQueue, container.resolve(PyPubSubEventQueue))
@ -88,11 +77,8 @@ def initialize_services(data_dir: Path) -> DIContainer:
container.resolve(ICredentialsRepository), container.resolve(ICredentialsRepository),
) )
return container
def _register_conventions(container: DIContainer):
def _register_conventions(container: DIContainer, data_dir: Path):
container.register_convention(Path, "data_dir", data_dir)
container.register_convention( container.register_convention(
AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION
) )
@ -101,9 +87,6 @@ def _register_conventions(container: DIContainer, data_dir: Path):
"default_ransomware_agent_configuration", "default_ransomware_agent_configuration",
DEFAULT_RANSOMWARE_AGENT_CONFIGURATION, DEFAULT_RANSOMWARE_AGENT_CONFIGURATION,
) )
container.register_convention(Path, "island_log_file_path", get_log_file_path(data_dir))
container.register_convention(str, "version_number", get_version())
container.register_convention(Sequence[str], "ip_addresses", get_ip_addresses())
def _register_repositories(container: DIContainer, data_dir: Path): def _register_repositories(container: DIContainer, data_dir: Path):
@ -161,22 +144,6 @@ def _log_agent_binary_hashes(agent_binary_repository: IAgentBinaryRepository):
logger.info(f"{os} agent: SHA-256 hash: {binary_sha256_hash}") logger.info(f"{os} agent: SHA-256 hash: {binary_sha256_hash}")
# TODO: The deployment should probably be passed into initialize_services(), but we can rework that
# when we refactor this file.
def _get_depyloyment_from_file(file_path: Path) -> Deployment:
try:
with open(file_path, "r") as deployment_info_file:
deployment_info = json.load(deployment_info_file)
return Deployment[deployment_info["deployment"].upper()]
except KeyError as err:
raise Exception(
f"The deployment file ({file_path}) did not contain the expected data: "
f"missing key {err}"
)
except Exception as err:
raise Exception(f"Failed to fetch the deployment from {file_path}: {err}")
def _register_services(container: DIContainer): def _register_services(container: DIContainer):
container.register_instance(AWSService, container.resolve(AWSService)) container.register_instance(AWSService, container.resolve(AWSService))
container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService)) container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService))