Island: Move data directory deletion prompt to data_dir.py

This commit is contained in:
Shreya Malviya 2021-10-20 20:42:49 +05:30
parent b4c48e7cfb
commit d048f4e4ce
3 changed files with 48 additions and 53 deletions

View File

@ -1,11 +1,10 @@
import atexit import atexit
import json import json
import logging import logging
import shutil
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Optional, Tuple from typing import Tuple
import gevent.hub import gevent.hub
from gevent.pywsgi import WSGIServer from gevent.pywsgi import WSGIServer
@ -33,7 +32,7 @@ from monkey_island.cc.services.initialize import initialize_services # noqa: E4
from monkey_island.cc.services.reporting.exporter_init import populate_exporter_list # noqa: E402 from monkey_island.cc.services.reporting.exporter_init import populate_exporter_list # noqa: E402
from monkey_island.cc.services.utils.network_utils import local_ip_addresses # noqa: E402 from monkey_island.cc.services.utils.network_utils import local_ip_addresses # noqa: E402
from monkey_island.cc.setup import island_config_options_validator # noqa: E402 from monkey_island.cc.setup import island_config_options_validator # noqa: E402
from monkey_island.cc.setup.data_dir import OldDataError # noqa: E402 from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory # noqa: E402
from monkey_island.cc.setup.gevent_hub_error_handler import GeventHubErrorHandler # noqa: E402 from monkey_island.cc.setup.gevent_hub_error_handler import GeventHubErrorHandler # noqa: E402
from monkey_island.cc.setup.island_config_options import IslandConfigOptions # noqa: E402 from monkey_island.cc.setup.island_config_options import IslandConfigOptions # noqa: E402
from monkey_island.cc.setup.mongo import mongo_setup # noqa: E402 from monkey_island.cc.setup.mongo import mongo_setup # noqa: E402
@ -70,29 +69,7 @@ def _setup_data_dir(island_args: IslandCmdArgs) -> Tuple[IslandConfigOptions, st
except json.JSONDecodeError as ex: except json.JSONDecodeError as ex:
print(f"Error loading server config: {ex}") print(f"Error loading server config: {ex}")
exit(1) exit(1)
except OldDataError as ex: except IncompatibleDataDirectory:
return _handle_existing_data_directory(ex, island_args)
def _handle_existing_data_directory(
exception: Exception, island_args: IslandCmdArgs
) -> Optional[Tuple[IslandConfigOptions, str]]:
user_response = input(
f"\nExisting data directory ({exception.old_data_dir}) needs to be deleted."
" All data from previous runs will be lost. Proceed to delete? (y/n) "
)
if user_response == "y":
shutil.rmtree(exception.old_data_dir)
print("\nOld data directory was deleted. Trying to set up again...\n")
return _setup_data_dir(island_args)
elif user_response == "n":
print(
"\nExiting. Please backup and delete the existing data directory. Then, try again."
"\nTo learn how to restore and use a backup, please refer to the documentation.\n"
)
exit(1)
else:
print("\nExiting. Unrecognized response, please try again.\n")
exit(1) exit(1)

View File

@ -1,4 +1,5 @@
import logging import logging
import shutil
from pathlib import Path from pathlib import Path
from common.version import get_version from common.version import get_version
@ -7,28 +8,44 @@ from monkey_island.cc.setup.version_file_setup import get_version_from_dir, writ
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_data_dir_backup_suffix = ".old"
class IncompatibleDataDirectory(Exception):
class OldDataError(Exception): pass
def __init__(self, old_data_dir: Path) -> None:
self.old_data_dir = old_data_dir
def setup_data_dir(data_dir_path: Path) -> None: def setup_data_dir(data_dir_path: Path) -> None:
logger.info(f"Setting up data directory at {data_dir_path}.") logger.info(f"Setting up data directory at {data_dir_path}.")
if data_dir_path.exists(): if data_dir_path.exists() and _data_dir_version_mismatch_exists(data_dir_path):
logger.info(f"Data directory already exists at {data_dir_path}.") logger.info("Version in data directory does not match the Island's version.")
_check_current_data_dir(data_dir_path) _handle_old_data_directory(data_dir_path)
create_secure_directory(str(data_dir_path)) create_secure_directory(str(data_dir_path))
write_version(data_dir_path) write_version(data_dir_path)
logger.info("Data directory set up.") logger.info("Data directory set up.")
def _check_current_data_dir(data_dir_path: Path) -> None: def _handle_old_data_directory(data_dir_path: Path) -> None:
if _data_dir_version_mismatch_exists(data_dir_path): should_delete_data_directory = _prompt_user_to_delete_data_directory(data_dir_path)
logger.info("Version in data directory does not match the Island's version.") if should_delete_data_directory:
raise OldDataError(data_dir_path) shutil.rmtree(data_dir_path)
logger.info(f"{data_dir_path} was deleted.")
else:
logger.error(
"Unable to set up data directory. Please backup and delete the existing data directory"
f" ({data_dir_path}). Then, try again. To learn how to restore and use a backup, please"
" refer to the documentation."
)
raise IncompatibleDataDirectory()
def _prompt_user_to_delete_data_directory(data_dir_path: Path) -> bool:
user_response = input(
f"\nExisting data directory ({data_dir_path}) needs to be deleted."
" All data from previous runs will be lost. Proceed to delete? (y/n) "
)
print()
if user_response.lower() in {"y", "yes"}:
return True
return False
def _data_dir_version_mismatch_exists(data_dir_path: Path) -> bool: def _data_dir_version_mismatch_exists(data_dir_path: Path) -> bool:

View File

@ -2,8 +2,8 @@ from pathlib import Path
import pytest import pytest
from monkey_island.cc.setup.data_dir import _get_backup_path, setup_data_dir from monkey_island.cc.setup.data_dir import IncompatibleDataDirectory, setup_data_dir
from monkey_island.cc.setup.version_file_setup import _version_filename, get_version_from_dir from monkey_island.cc.setup.version_file_setup import _version_filename
current_version = "1.1.1" current_version = "1.1.1"
old_version = "1.1.0" old_version = "1.1.0"
@ -36,7 +36,9 @@ def test_setup_data_dir(temp_data_dir_path, temp_version_file_path):
assert version_file_path.read_text() == current_version assert version_file_path.read_text() == current_version
def test_old_version_present(temp_data_dir_path, temp_version_file_path): def test_old_version_removed(monkeypatch, temp_data_dir_path, temp_version_file_path):
monkeypatch.setattr("builtins.input", lambda _: "y")
temp_data_dir_path.mkdir() temp_data_dir_path.mkdir()
temp_version_file_path.write_text(old_version) temp_version_file_path.write_text(old_version)
bogus_file_path = temp_data_dir_path.joinpath("test.txt") bogus_file_path = temp_data_dir_path.joinpath("test.txt")
@ -46,25 +48,24 @@ def test_old_version_present(temp_data_dir_path, temp_version_file_path):
assert temp_version_file_path.read_text() == current_version assert temp_version_file_path.read_text() == current_version
assert not bogus_file_path.is_file() assert not bogus_file_path.is_file()
assert _get_backup_path(temp_data_dir_path).joinpath("test.txt").is_file()
def test_old_version_and_backup_present(temp_data_dir_path, temp_version_file_path): @pytest.mark.parametrize("input_value", ["n", "x"])
def test_old_version_not_removed(
monkeypatch, temp_data_dir_path, temp_version_file_path, input_value
):
monkeypatch.setattr("builtins.input", lambda _: input_value)
temp_data_dir_path.mkdir() temp_data_dir_path.mkdir()
temp_version_file_path.write_text(old_version) temp_version_file_path.write_text(old_version)
bogus_file_path = temp_data_dir_path.joinpath("test.txt")
old_backup_path = _get_backup_path(temp_data_dir_path)
old_backup_path.mkdir()
bogus_file_path = old_backup_path.joinpath("test.txt")
bogus_file_path.touch() bogus_file_path.touch()
with pytest.raises(IncompatibleDataDirectory):
setup_data_dir(temp_data_dir_path) setup_data_dir(temp_data_dir_path)
new_backup_path = old_backup_path
# Make sure old backup got deleted and new backup took it's place assert temp_version_file_path.read_text() == old_version
assert temp_version_file_path.read_text() == current_version assert bogus_file_path.is_file()
assert get_version_from_dir(new_backup_path) == old_version
assert not _get_backup_path(temp_data_dir_path).joinpath("test.txt").is_file()
def test_data_dir_setup_not_needed(temp_data_dir_path, temp_version_file_path): def test_data_dir_setup_not_needed(temp_data_dir_path, temp_version_file_path):