diff --git a/monkey/monkey_island/cc/setup/data_dir.py b/monkey/monkey_island/cc/setup/data_dir.py index bee9d0a05..c47aa1680 100644 --- a/monkey/monkey_island/cc/setup/data_dir.py +++ b/monkey/monkey_island/cc/setup/data_dir.py @@ -11,6 +11,7 @@ from monkey_island.cc.setup.version_file_setup import ( ) logger = logging.getLogger(__name__) +_data_dir_backup_suffix = ".old" def setup_data_dir(data_dir_path: Path): @@ -26,20 +27,28 @@ def _reset_data_dir(data_dir_path: Path): data_dir_version = get_version_from_dir(data_dir_path) except FileNotFoundError: logger.info("Version file not found on the data directory.") - _remove_data_dir(data_dir_path) + _backup_old_data_dir(data_dir_path) return island_version = get_version() logger.info(f"Version found in the data directory: {data_dir_version}") logger.info(f"Currently running version: {island_version}") if is_version_greater(island_version, data_dir_version): - _remove_data_dir(data_dir_path) + _backup_old_data_dir(data_dir_path) -def _remove_data_dir(data_dir_path: str): - logger.info("Attempting to remove data directory.") +def _backup_old_data_dir(data_dir_path: Path): + logger.info("Attempting to backup old data directory.") try: - shutil.rmtree(data_dir_path) - logger.info("Data directory removed.") + backup_path = _get_backup_path(data_dir_path) + if backup_path.is_dir(): + shutil.rmtree(backup_path) + Path(data_dir_path).replace(backup_path) + logger.info(f"Old data directory moved to {backup_path}.") except FileNotFoundError: - logger.info("Data directory not found, nothing to remove.") + logger.info("Old data directory not found, nothing to backup.") + + +def _get_backup_path(data_dir_path: Path) -> Path: + backup_dir_name = data_dir_path.name + _data_dir_backup_suffix + return Path(data_dir_path.parent, backup_dir_name) diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py index 096c5b0f3..dfabfa54e 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py @@ -2,8 +2,8 @@ from pathlib import Path import pytest -from monkey_island.cc.setup.data_dir import setup_data_dir -from monkey_island.cc.setup.version_file_setup import _version_filename +from monkey_island.cc.setup.data_dir import _get_backup_path, setup_data_dir +from monkey_island.cc.setup.version_file_setup import _version_filename, get_version_from_dir current_version = "1.1.1" old_version = "1.1.0" @@ -42,11 +42,29 @@ def test_old_version_present(mocked_data_dir_path, mocked_version_file_path): bogus_file_path = mocked_data_dir_path.joinpath("test.txt") bogus_file_path.touch() - # mock version setup_data_dir(mocked_data_dir_path) assert mocked_version_file_path.read_text() == current_version assert not bogus_file_path.is_file() + assert _get_backup_path(mocked_data_dir_path).joinpath("test.txt").is_file() + + +def test_old_version_and_backup_present(mocked_data_dir_path, mocked_version_file_path): + mocked_data_dir_path.mkdir() + mocked_version_file_path.write_text(old_version) + + old_backup_path = _get_backup_path(mocked_data_dir_path) + old_backup_path.mkdir() + bogus_file_path = old_backup_path.joinpath("test.txt") + bogus_file_path.touch() + + setup_data_dir(mocked_data_dir_path) + new_backup_path = old_backup_path + + # Make sure old backup got deleted and new backup took it's place + assert mocked_version_file_path.read_text() == current_version + assert get_version_from_dir(new_backup_path) == old_version + assert not _get_backup_path(mocked_data_dir_path).joinpath("test.txt").is_file() def test_data_dir_setup_not_needed(mocked_data_dir_path, mocked_version_file_path):