Island, UT: fix island config option extraction to also expand paths and add a UT for that

This commit is contained in:
VakarisZ 2021-11-30 12:12:29 +02:00
parent 03566d2966
commit 06f31791fc
2 changed files with 49 additions and 20 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from dpath import util
from common.utils.file_utils import expand_path
from monkey_island.cc.server_utils.consts import (
DEFAULT_CERTIFICATE_PATHS,
@ -10,29 +12,51 @@ from monkey_island.cc.server_utils.consts import (
DEFAULT_START_MONGO_DB,
)
_DATA_DIR = "data_dir"
_SSL_CERT = "ssl_certificate"
_SSL_CERT_FILE = "ssl_certificate_file"
_SSL_CERT_KEY = "ssl_certificate_key_file"
_MONGODB = "mongodb"
_START_MONGODB = "start_mongodb"
_LOG_LEVEL = "log_level"
class IslandConfigOptions:
def __init__(self, config_contents: dict = None):
if not config_contents:
config_contents = {}
self.data_dir = expand_path(config_contents.get("data_dir", DEFAULT_DATA_DIR))
self.data_dir = expand_path(config_contents.get(_DATA_DIR, DEFAULT_DATA_DIR))
self.log_level = config_contents.get("log_level", DEFAULT_LOG_LEVEL)
self.log_level = config_contents.get(_LOG_LEVEL, DEFAULT_LOG_LEVEL)
self.start_mongodb = config_contents.get(
"mongodb", {"start_mongodb": DEFAULT_START_MONGO_DB}
).get("start_mongodb", DEFAULT_START_MONGO_DB)
_MONGODB, {_START_MONGODB: DEFAULT_START_MONGO_DB}
).get(_START_MONGODB, DEFAULT_START_MONGO_DB)
self.crt_path = expand_path(
config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get(
"ssl_certificate_file", DEFAULT_CRT_PATH
config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get(
_SSL_CERT_FILE, DEFAULT_CRT_PATH
)
)
self.key_path = expand_path(
config_contents.get("ssl_certificate", DEFAULT_CERTIFICATE_PATHS).get(
"ssl_certificate_key_file", DEFAULT_KEY_PATH
config_contents.get(_SSL_CERT, DEFAULT_CERTIFICATE_PATHS).get(
_SSL_CERT_KEY, DEFAULT_KEY_PATH
)
)
def update(self, target: dict):
target = self._expand_config_paths(target)
self.__dict__.update(target)
@staticmethod
def _expand_config_paths(config: dict) -> dict:
config_paths = [_DATA_DIR, f"{_SSL_CERT}.{_SSL_CERT_FILE}", f"{_SSL_CERT}.{_SSL_CERT_KEY}"]
for config_path in config_paths:
try:
expanded_val = expand_path(util.get(config, config_path, "."))
util.set(config, config_path, expanded_val, ".")
except KeyError:
pass
return config

View File

@ -6,6 +6,7 @@ import pytest
import monkey_island.cc.setup.config_setup # noqa: F401
from monkey_island.cc.arg_parser import IslandCmdArgs
from monkey_island.cc.server_setup import _extract_config
from monkey_island.cc.server_utils.file_utils import is_windows_os
from monkey_island.cc.setup.island_config_options import IslandConfigOptions
BAD_JSON = '{"data_dir": "C:\\test\\test"'
@ -53,32 +54,36 @@ def test_deployment_config_overrides_defaults(deployment_server_config_path):
)
def test_cmd_config_overrides_everything(
deployment_server_config_path, cmd_server_config_path, user_default_server_config_path
):
expected = IslandConfigOptions({"log_level": "/log_level_4"})
def test_cmd_config_overrides_everything(deployment_server_config_path, cmd_server_config_path):
expected = IslandConfigOptions({"log_level": "/log_level_3"})
create_server_config(dumps({"log_level": "/log_level_2"}), deployment_server_config_path)
create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path)
create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path)
create_server_config(dumps({"log_level": "/log_level_3"}), cmd_server_config_path)
extracted_config = _extract_config(
IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path)
)
assert expected.__dict__ == extracted_config.__dict__
def test_not_overriding_unspecified_values(
deployment_server_config_path, cmd_server_config_path, user_default_server_config_path
):
expected = IslandConfigOptions({"log_level": "/log_level_4", "data_dir": "/data_dir1"})
def test_not_overriding_unspecified_values(deployment_server_config_path, cmd_server_config_path):
expected = IslandConfigOptions({"log_level": "/log_level_2", "data_dir": "/data_dir1"})
create_server_config(dumps({"data_dir": "/data_dir1"}), deployment_server_config_path)
create_server_config(dumps({"log_level": "/log_level_3"}), user_default_server_config_path)
create_server_config(dumps({"log_level": "/log_level_4"}), cmd_server_config_path)
create_server_config(dumps({"log_level": "/log_level_2"}), cmd_server_config_path)
extracted_config = _extract_config(
IslandCmdArgs(setup_only=False, server_config_path=cmd_server_config_path)
)
assert expected.__dict__ == extracted_config.__dict__
def test_paths_get_expanded(deployment_server_config_path):
if is_windows_os():
path = "%temp%/path"
else:
path = "$HOME/path"
create_server_config(dumps({"data_dir": path}), deployment_server_config_path)
extracted_config = _extract_config(IslandCmdArgs(setup_only=False, server_config_path=None))
assert not extracted_config.data_dir == path
def test_malformed_json(cmd_server_config_path):
create_server_config(BAD_JSON, cmd_server_config_path)
with pytest.raises(SystemExit):