Island, UT: fix island config option extraction to also expand paths and add a UT for that

This commit is contained in:
VakarisZ 2021-11-30 12:12:29 +02:00
parent 03566d2966
commit 06f31791fc
2 changed files with 49 additions and 20 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dpath import util
from common.utils.file_utils import expand_path from common.utils.file_utils import expand_path
from monkey_island.cc.server_utils.consts import ( from monkey_island.cc.server_utils.consts import (
DEFAULT_CERTIFICATE_PATHS, DEFAULT_CERTIFICATE_PATHS,
@ -10,29 +12,51 @@ from monkey_island.cc.server_utils.consts import (
DEFAULT_START_MONGO_DB, DEFAULT_START_MONGO_DB,
) )
_DATA_DIR = "data_dir"
_SSL_CERT = "ssl_certificate"
_SSL_CERT_FILE = "ssl_certificate_file"
_SSL_CERT_KEY = "ssl_certificate_key_file"
_MONGODB = "mongodb"
_START_MONGODB = "start_mongodb"
_LOG_LEVEL = "log_level"
class IslandConfigOptions: class IslandConfigOptions:
def __init__(self, config_contents: dict = None): def __init__(self, config_contents: dict = None):
if not config_contents: if not config_contents:
config_contents = {} config_contents = {}
self.data_dir = expand_path(config_contents.get("data_dir", DEFAULT_DATA_DIR)) self.data_dir = expand_path(config_contents.get(_DATA_DIR, DEFAULT_DATA_DIR))
self.log_level = config_contents.get("log_level", DEFAULT_LOG_LEVEL) self.log_level = config_contents.get(_LOG_LEVEL, DEFAULT_LOG_LEVEL)
self.start_mongodb = config_contents.get( self.start_mongodb = config_contents.get(
"mongodb", {"start_mongodb": DEFAULT_START_MONGO_DB} _MONGODB, {_START_MONGODB: DEFAULT_START_MONGO_DB}
).get("start_mongodb", DEFAULT_START_MONGO_DB) ).get(_START_MONGODB, DEFAULT_START_MONGO_DB)
self.crt_path = expand_path( self.crt_path = expand_path(
config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get( config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get(
"ssl_certificate_file", DEFAULT_CRT_PATH _SSL_CERT_FILE, DEFAULT_CRT_PATH
) )
) )
self.key_path = expand_path( self.key_path = expand_path(
config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get( config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get(
"ssl_certificate_key_file", DEFAULT_KEY_PATH _SSL_CERT_KEY, DEFAULT_KEY_PATH
) )
) )
def update(self, target: dict): def update(self, target: dict):
target = self._expand_config_paths(target)
self.__dict__.update(target) self.__dict__.update(target)
@staticmethod
def _expand_config_paths(config: dict) -> dict:
config_paths = [_DATA_DIR, f"{_SSL_CERT}.{_SSL_CERT_FILE}", f"{_SSL_CERT}.{_SSL_CERT_KEY}"]
for config_path in config_paths:
try:
expanded_val = expand_path(util.get(config, config_path, "."))
util.set(config, config_path, expanded_val, ".")
except KeyError:
pass
return config

View File

@ -6,6 +6,7 @@ import pytest
import monkey_island.cc.setup.config_setup # noqa: F401 import monkey_island.cc.setup.config_setup # noqa: F401
from monkey_island.cc.arg_parser import IslandCmdArgs from monkey_island.cc.arg_parser import IslandCmdArgs
from monkey_island.cc.server_setup import _extract_config from monkey_island.cc.server_setup import _extract_config
from monkey_island.cc.server_utils.file_utils import is_windows_os
from monkey_island.cc.setup.island_config_options import IslandConfigOptions from monkey_island.cc.setup.island_config_options import IslandConfigOptions
BAD_JSON = '{"data_dir": "C:\\test\\test"' BAD_JSON = '{"data_dir": "C:\\test\\test"'
@ -53,32 +54,36 @@ def test_deployment_config_overrides_defaults(deployment_server_config_path):
) )
def test_cmd_config_overrides_everything( def test_cmd_config_overrides_everything(deployment_server_config_path, cmd_server_config_path):
deployment_server_config_path, cmd_server_config_path, user_default_server_config_path expected = IslandConfigOptions({"log_level": "/log_level_3"})
):
expected = IslandConfigOptions({"log_level": "/log_level_4"})
create_server_config(dumps({"log_level": "/log_level_2"}), deployment_server_config_path) create_server_config(dumps({"log_level": "/log_level_2"}), deployment_server_config_path)
create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path) create_server_config(dumps({"log_level": "/log_level_3"}), cmd_server_config_path)
create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path)
extracted_config = _extract_config( extracted_config = _extract_config(
IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path) IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path)
) )
assert expected.__dict__ == extracted_config.__dict__ assert expected.__dict__ == extracted_config.__dict__
def test_not_overriding_unspecified_values( def test_not_overriding_unspecified_values(deployment_server_config_path, cmd_server_config_path):
deployment_server_config_path, cmd_server_config_path, user_default_server_config_path expected = IslandConfigOptions({"log_level": "/log_level_2", "data_dir": "/data_dir1"})
):
expected = IslandConfigOptions({"log_level": "/log_level_4", "data_dir": "/data_dir1"})
create_server_config(dumps({"data_dir": "/data_dir1"}), deployment_server_config_path) create_server_config(dumps({"data_dir": "/data_dir1"}), deployment_server_config_path)
create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path) create_server_config(dumps({"log_level": "/log_level_2"}), cmd_server_config_path)
create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path)
extracted_config = _extract_config( extracted_config = _extract_config(
IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path) IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path)
) )
assert expected.__dict__ == extracted_config.__dict__ assert expected.__dict__ == extracted_config.__dict__
def test_paths_get_expanded(deployment_server_config_path):
if is_windows_os():
path = "%temp%/path"
else:
path = "$HOME/path"
create_server_config(dumps({"data_dir": path}), deployment_server_config_path)
extracted_config = _extract_config(IslandCmdArgs(setup_only=False, server_config_path=None))
assert not extracted_config.data_dir == path
def test_malformed_json(cmd_server_config_path): def test_malformed_json(cmd_server_config_path):
create_server_config(BAD_JSON, cmd_server_config_path) create_server_config(BAD_JSON, cmd_server_config_path)
with pytest.raises(SystemExit): with pytest.raises(SystemExit):