diff --git a/monkey/monkey_island/cc/setup/island_config_options.py b/monkey/monkey_island/cc/setup/island_config_options.py index 71e9ba27f..384ed3079 100644 --- a/monkey/monkey_island/cc/setup/island_config_options.py +++ b/monkey/monkey_island/cc/setup/island_config_options.py @@ -1,5 +1,7 @@ from __future__ import annotations +from dpath import util + from common.utils.file_utils import expand_path from monkey_island.cc.server_utils.consts import ( DEFAULT_CERTIFICATE_PATHS, @@ -10,29 +12,51 @@ from monkey_island.cc.server_utils.consts import ( DEFAULT_START_MONGO_DB, ) +_DATA_DIR = "data_dir" +_SSL_CERT = "ssl_certificate" +_SSL_CERT_FILE = "ssl_certificate_file" +_SSL_CERT_KEY = "ssl_certificate_key_file" +_MONGODB = "mongodb" +_START_MONGODB = "start_mongodb" +_LOG_LEVEL = "log_level" + class IslandConfigOptions: def __init__(self, config_contents: dict = None): if not config_contents: config_contents = {} - self.data_dir = expand_path(config_contents.get("data_dir", DEFAULT_DATA_DIR)) + self.data_dir = expand_path(config_contents.get(_DATA_DIR, DEFAULT_DATA_DIR)) - self.log_level = config_contents.get("log_level", DEFAULT_LOG_LEVEL) + self.log_level = config_contents.get(_LOG_LEVEL, DEFAULT_LOG_LEVEL) self.start_mongodb = config_contents.get( - "mongodb", {"start_mongodb": DEFAULT_START_MONGO_DB} - ).get("start_mongodb", DEFAULT_START_MONGO_DB) + _MONGODB, {_START_MONGODB: DEFAULT_START_MONGO_DB} + ).get(_START_MONGODB, DEFAULT_START_MONGO_DB) self.crt_path = expand_path( - config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get( - "ssl_certificate_file", DEFAULT_CRT_PATH + config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get( + _SSL_CERT_FILE, DEFAULT_CRT_PATH ) ) self.key_path = expand_path( - config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get( - "ssl_certificate_key_file", DEFAULT_KEY_PATH + config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get( + _SSL_CERT_KEY, DEFAULT_KEY_PATH ) ) def update(self, target: dict): + target = self._expand_config_paths(target) self.__dict__.update(target) + + @staticmethod + def _expand_config_paths(config: dict) -> dict: + config_paths = [_DATA_DIR, f"{_SSL_CERT}.{_SSL_CERT_FILE}", f"{_SSL_CERT}.{_SSL_CERT_KEY}"] + + for config_path in config_paths: + try: + expanded_val = expand_path(util.get(config, config_path, ".")) + util.set(config, config_path, expanded_val, ".") + except KeyError: + pass + + return config diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py b/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py index a6f4c59ca..226842607 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py @@ -6,6 +6,7 @@ import pytest import monkey_island.cc.setup.config_setup # noqa: F401 from monkey_island.cc.arg_parser import IslandCmdArgs from monkey_island.cc.server_setup import _extract_config +from monkey_island.cc.server_utils.file_utils import is_windows_os from monkey_island.cc.setup.island_config_options import IslandConfigOptions BAD_JSON = '{"data_dir": "C:\\test\\test"' @@ -53,32 +54,36 @@ def test_deployment_config_overrides_defaults(deployment_server_config_path): ) -def test_cmd_config_overrides_everything( - deployment_server_config_path, cmd_server_config_path, user_default_server_config_path -): - expected = IslandConfigOptions({"log_level": "/log_level_4"}) +def test_cmd_config_overrides_everything(deployment_server_config_path, cmd_server_config_path): + expected = IslandConfigOptions({"log_level": "/log_level_3"}) create_server_config(dumps({"log_level": "/log_level_2"}), deployment_server_config_path) - create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path) - create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path) + create_server_config(dumps({"log_level": "/log_level_3"}), cmd_server_config_path) extracted_config = _extract_config( IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path) ) assert expected.__dict__ == extracted_config.__dict__ -def test_not_overriding_unspecified_values( - deployment_server_config_path, cmd_server_config_path, user_default_server_config_path -): - expected = IslandConfigOptions({"log_level": "/log_level_4", "data_dir": "/data_dir1"}) +def test_not_overriding_unspecified_values(deployment_server_config_path, cmd_server_config_path): + expected = IslandConfigOptions({"log_level": "/log_level_2", "data_dir": "/data_dir1"}) create_server_config(dumps({"data_dir": "/data_dir1"}), deployment_server_config_path) - create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path) - create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path) + create_server_config(dumps({"log_level": "/log_level_2"}), cmd_server_config_path) extracted_config = _extract_config( IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path) ) assert expected.__dict__ == extracted_config.__dict__ +def test_paths_get_expanded(deployment_server_config_path): + if is_windows_os(): + path = "%temp%/path" + else: + path = "$HOME/path" + create_server_config(dumps({"data_dir": path}), deployment_server_config_path) + extracted_config = _extract_config(IslandCmdArgs(setup_only=False, server_config_path=None)) + assert not extracted_config.data_dir == path + + def test_malformed_json(cmd_server_config_path): create_server_config(BAD_JSON, cmd_server_config_path) with pytest.raises(SystemExit):