diff --git a/monkey/monkey_island/cc/environment/__init__.py b/monkey/monkey_island/cc/environment/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/monkey/monkey_island/cc/environment/server_config_handler.py b/monkey/monkey_island/cc/environment/server_config_handler.py deleted file mode 100644 index be70a8bec..000000000 --- a/monkey/monkey_island/cc/environment/server_config_handler.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - -from monkey_island.cc.setup.island_config_options import IslandConfigOptions - - -def load_server_config_from_file(server_config_path) -> IslandConfigOptions: - with open(server_config_path, "r") as f: - config_content = f.read() - config = json.loads(config_content) - - return IslandConfigOptions(config) diff --git a/monkey/monkey_island/cc/server_setup.py b/monkey/monkey_island/cc/server_setup.py index 3de58ef8e..61d13a96a 100644 --- a/monkey/monkey_island/cc/server_setup.py +++ b/monkey/monkey_island/cc/server_setup.py @@ -10,7 +10,7 @@ import gevent.hub from gevent.pywsgi import WSGIServer from monkey_island.cc.server_utils.consts import ISLAND_PORT -from monkey_island.cc.setup.config_setup import extract_server_config +from monkey_island.cc.setup.config_setup import get_server_config # Add the monkey_island directory to the path, to make sure imports that don't start with # "monkey_island." work. @@ -64,10 +64,7 @@ def run_monkey_island(): def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions: try: - return extract_server_config(island_args) - except OSError as ex: - print(f"Error opening server config file: {ex}") - exit(1) + return get_server_config(island_args) except json.JSONDecodeError as ex: print(f"Error loading server config: {ex}") exit(1) diff --git a/monkey/monkey_island/cc/setup/config_setup.py b/monkey/monkey_island/cc/setup/config_setup.py index 501df5b73..b997e35d4 100644 --- a/monkey/monkey_island/cc/setup/config_setup.py +++ b/monkey/monkey_island/cc/setup/config_setup.py @@ -1,14 +1,37 @@ +import json +from logging import getLogger +from pathlib import Path + from common.utils.file_utils import expand_path from monkey_island.cc.arg_parser import IslandCmdArgs -from monkey_island.cc.environment import server_config_handler from monkey_island.cc.server_utils.consts import DEFAULT_SERVER_CONFIG_PATH from monkey_island.cc.setup.island_config_options import IslandConfigOptions +logger = getLogger(__name__) + + +def get_server_config(island_args: IslandCmdArgs) -> IslandConfigOptions: + config = IslandConfigOptions({}) + + update_config_from_file(config, DEFAULT_SERVER_CONFIG_PATH) -def extract_server_config(island_args: IslandCmdArgs) -> IslandConfigOptions: if island_args.server_config_path: path_to_config = expand_path(island_args.server_config_path) - else: - path_to_config = DEFAULT_SERVER_CONFIG_PATH + update_config_from_file(config, path_to_config) - return server_config_handler.load_server_config_from_file(path_to_config) + return config + + +def update_config_from_file(config: IslandConfigOptions, config_path: Path): + try: + config_from_file = load_server_config_from_file(config_path) + config.update(config_from_file) + except OSError: + logger.info(f"Server config not found in path {config_path}") + + +def load_server_config_from_file(server_config_path) -> IslandConfigOptions: + with open(server_config_path, "r") as f: + config_content = f.read() + config = json.loads(config_content) + return IslandConfigOptions(config) diff --git a/monkey/monkey_island/cc/setup/island_config_options.py b/monkey/monkey_island/cc/setup/island_config_options.py index 66a49306a..12b141923 100644 --- a/monkey/monkey_island/cc/setup/island_config_options.py +++ b/monkey/monkey_island/cc/setup/island_config_options.py @@ -31,3 +31,6 @@ class IslandConfigOptions: "ssl_certificate_key_file", DEFAULT_KEY_PATH ) ) + + def update(self, target: IslandConfigOptions): + self.__dict__.update(target.__dict__) diff --git a/monkey/tests/unit_tests/monkey_island/cc/environment/__init__.py b/monkey/tests/unit_tests/monkey_island/cc/environment/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py b/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py new file mode 100644 index 000000000..cb3511b10 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/test_server_setup.py @@ -0,0 +1,69 @@ +from json import dumps +from pathlib import Path + +import pytest + +import monkey_island.cc.setup.config_setup # noqa: F401 +from monkey_island.cc.arg_parser import IslandCmdArgs +from monkey_island.cc.server_setup import _extract_config +from monkey_island.cc.setup.island_config_options import IslandConfigOptions + +BAD_JSON = '{"data_dir": "C:\\test\\test"' + + +@pytest.fixture +def user_server_config_path(tmpdir) -> Path: + return tmpdir / "fake_server_config.json" + + +@pytest.fixture +def deployment_server_config_path(tmpdir) -> Path: + return tmpdir / "fake_server_config2.json" + + +def create_server_config(config_contents: str, server_config_path: Path): + with open(server_config_path, "w") as file: + file.write(config_contents) + + +@pytest.fixture(autouse=True) +def mock_deployment_config_path(monkeypatch, deployment_server_config_path): + monkeypatch.setattr( + "monkey_island.cc.setup.config_setup.DEFAULT_SERVER_CONFIG_PATH", + deployment_server_config_path, + ) + + +def test_extract_config_defaults(): + expected = IslandConfigOptions({}) + assert ( + expected.__dict__ + == _extract_config(IslandCmdArgs(setup_only=False, server_config_path=None)).__dict__ + ) + + +def test_package_config_overrides_defaults(deployment_server_config_path): + expected = IslandConfigOptions({"key_path": "/key_path_2"}) + create_server_config(dumps({"key_path": "/key_path_2"}), deployment_server_config_path) + assert ( + expected.__dict__ + == _extract_config(IslandCmdArgs(setup_only=False, server_config_path=None)).__dict__ + ) + + +def test_user_config_overrides_package_config( + deployment_server_config_path, user_server_config_path +): + expected = IslandConfigOptions({"key_path": "/key_path_3"}) + create_server_config(dumps({"key_path": "/key_path_2"}), deployment_server_config_path) + create_server_config(dumps({"key_path": "/key_path_3"}), user_server_config_path) + extracted_config = _extract_config( + IslandCmdArgs(setup_only=False, server_config_path=user_server_config_path) + ) + assert expected.__dict__ == extracted_config.__dict__ + + +def test_malformed_json(user_server_config_path): + create_server_config(BAD_JSON, user_server_config_path) + with pytest.raises(SystemExit): + _extract_config(IslandCmdArgs(setup_only=False, server_config_path=user_server_config_path))