diff --git a/monkey/monkey_island/cc/server_setup.py b/monkey/monkey_island/cc/server_setup.py index 05a73280f..8b8e433ea 100644 --- a/monkey/monkey_island/cc/server_setup.py +++ b/monkey/monkey_island/cc/server_setup.py @@ -1,11 +1,10 @@ import atexit import json import logging -import shutil import sys from pathlib import Path from threading import Thread -from typing import Optional, Tuple +from typing import Tuple import gevent.hub from gevent.pywsgi import WSGIServer @@ -33,7 +32,7 @@ from monkey_island.cc.services.initialize import initialize_services # noqa: E4 from monkey_island.cc.services.reporting.exporter_init import populate_exporter_list # noqa: E402 from monkey_island.cc.services.utils.network_utils import local_ip_addresses # noqa: E402 from monkey_island.cc.setup import island_config_options_validator # noqa: E402 -from monkey_island.cc.setup.data_dir import OldDataError # noqa: E402 +from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory # noqa: E402 from monkey_island.cc.setup.gevent_hub_error_handler import GeventHubErrorHandler # noqa: E402 from monkey_island.cc.setup.island_config_options import IslandConfigOptions # noqa: E402 from monkey_island.cc.setup.mongo import mongo_setup # noqa: E402 @@ -70,29 +69,7 @@ def _setup_data_dir(island_args: IslandCmdArgs) -> Tuple[IslandConfigOptions, st except json.JSONDecodeError as ex: print(f"Error loading server config: {ex}") exit(1) - except OldDataError as ex: - return _handle_existing_data_directory(ex, island_args) - - -def _handle_existing_data_directory( - exception: Exception, island_args: IslandCmdArgs -) -> Optional[Tuple[IslandConfigOptions, str]]: - user_response = input( - f"\nExisting data directory ({exception.old_data_dir}) needs to be deleted." - " All data from previous runs will be lost. Proceed to delete? (y/n) " - ) - if user_response == "y": - shutil.rmtree(exception.old_data_dir) - print("\nOld data directory was deleted. Trying to set up again...\n") - return _setup_data_dir(island_args) - elif user_response == "n": - print( - "\nExiting. Please backup and delete the existing data directory. Then, try again." - "\nTo learn how to restore and use a backup, please refer to the documentation.\n" - ) - exit(1) - else: - print("\nExiting. Unrecognized response, please try again.\n") + except IncompatibleDataDirectory: exit(1) diff --git a/monkey/monkey_island/cc/setup/data_dir.py b/monkey/monkey_island/cc/setup/data_dir.py index 245a4f14c..c728dca04 100644 --- a/monkey/monkey_island/cc/setup/data_dir.py +++ b/monkey/monkey_island/cc/setup/data_dir.py @@ -1,4 +1,5 @@ import logging +import shutil from pathlib import Path from common.version import get_version @@ -7,28 +8,44 @@ from monkey_island.cc.setup.version_file_setup import get_version_from_dir, writ logger = logging.getLogger(__name__) -_data_dir_backup_suffix = ".old" - -class OldDataError(Exception): - def __init__(self, old_data_dir: Path) -> None: - self.old_data_dir = old_data_dir +class IncompatibleDataDirectory(Exception): + pass def setup_data_dir(data_dir_path: Path) -> None: logger.info(f"Setting up data directory at {data_dir_path}.") - if data_dir_path.exists(): - logger.info(f"Data directory already exists at {data_dir_path}.") - _check_current_data_dir(data_dir_path) + if data_dir_path.exists() and _data_dir_version_mismatch_exists(data_dir_path): + logger.info("Version in data directory does not match the Island's version.") + _handle_old_data_directory(data_dir_path) create_secure_directory(str(data_dir_path)) write_version(data_dir_path) logger.info("Data directory set up.") -def _check_current_data_dir(data_dir_path: Path) -> None: - if _data_dir_version_mismatch_exists(data_dir_path): - logger.info("Version in data directory does not match the Island's version.") - raise OldDataError(data_dir_path) +def _handle_old_data_directory(data_dir_path: Path) -> None: + should_delete_data_directory = _prompt_user_to_delete_data_directory(data_dir_path) + if should_delete_data_directory: + shutil.rmtree(data_dir_path) + logger.info(f"{data_dir_path} was deleted.") + else: + logger.error( + "Unable to set up data directory. Please backup and delete the existing data directory" + f" ({data_dir_path}). Then, try again. To learn how to restore and use a backup, please" + " refer to the documentation." + ) + raise IncompatibleDataDirectory() + + +def _prompt_user_to_delete_data_directory(data_dir_path: Path) -> bool: + user_response = input( + f"\nExisting data directory ({data_dir_path}) needs to be deleted." + " All data from previous runs will be lost. Proceed to delete? (y/n) " + ) + print() + if user_response.lower() in {"y", "yes"}: + return True + return False def _data_dir_version_mismatch_exists(data_dir_path: Path) -> bool: diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py index 907e62741..696b604dc 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py @@ -2,8 +2,8 @@ from pathlib import Path import pytest -from monkey_island.cc.setup.data_dir import _get_backup_path, setup_data_dir -from monkey_island.cc.setup.version_file_setup import _version_filename, get_version_from_dir +from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory, setup_data_dir +from monkey_island.cc.setup.version_file_setup import _version_filename current_version = "1.1.1" old_version = "1.1.0" @@ -36,7 +36,9 @@ def test_setup_data_dir(temp_data_dir_path, temp_version_file_path): assert version_file_path.read_text() == current_version -def test_old_version_present(temp_data_dir_path, temp_version_file_path): +def test_old_version_removed(monkeypatch, temp_data_dir_path, temp_version_file_path): + monkeypatch.setattr("builtins.input", lambda _: "y") + temp_data_dir_path.mkdir() temp_version_file_path.write_text(old_version) bogus_file_path = temp_data_dir_path.joinpath("test.txt") @@ -46,25 +48,24 @@ def test_old_version_present(temp_data_dir_path, temp_version_file_path): assert temp_version_file_path.read_text() == current_version assert not bogus_file_path.is_file() - assert _get_backup_path(temp_data_dir_path).joinpath("test.txt").is_file() -def test_old_version_and_backup_present(temp_data_dir_path, temp_version_file_path): +@pytest.mark.parametrize("input_value", ["n", "x"]) +def test_old_version_not_removed( + monkeypatch, temp_data_dir_path, temp_version_file_path, input_value +): + monkeypatch.setattr("builtins.input", lambda _: input_value) + temp_data_dir_path.mkdir() temp_version_file_path.write_text(old_version) - - old_backup_path = _get_backup_path(temp_data_dir_path) - old_backup_path.mkdir() - bogus_file_path = old_backup_path.joinpath("test.txt") + bogus_file_path = temp_data_dir_path.joinpath("test.txt") bogus_file_path.touch() - setup_data_dir(temp_data_dir_path) - new_backup_path = old_backup_path + with pytest.raises(IncompatibleDataDirectory): + setup_data_dir(temp_data_dir_path) - # Make sure old backup got deleted and new backup took it's place - assert temp_version_file_path.read_text() == current_version - assert get_version_from_dir(new_backup_path) == old_version - assert not _get_backup_path(temp_data_dir_path).joinpath("test.txt").is_file() + assert temp_version_file_path.read_text() == old_version + assert bogus_file_path.is_file() def test_data_dir_setup_not_needed(temp_data_dir_path, temp_version_file_path):