diff --git a/monkey/monkey_island/cc/setup/data_dir.py b/monkey/monkey_island/cc/setup/data_dir.py index c95bc897d..d4d176ba7 100644 --- a/monkey/monkey_island/cc/setup/data_dir.py +++ b/monkey/monkey_island/cc/setup/data_dir.py @@ -1,5 +1,44 @@ +import logging +import shutil + +from common.version import get_version from monkey_island.cc.server_utils.file_utils import create_secure_directory +from monkey_island.cc.setup.version_file_setup import ( + get_version_from_dir, + is_version_greater, + write_version, +) + +logger = logging.getLogger(__name__) -def setup_data_dir(path: str): - create_secure_directory(path) +def setup_data_dir(data_dir_path: str): + logger.info("Setting up data directory.") + _reset_data_dir(data_dir_path) + create_secure_directory(data_dir_path) + write_version(data_dir_path) + logger.info("Data directory set up.") + + +def _reset_data_dir(data_dir_path: str): + try: + data_dir_version = get_version_from_dir(data_dir_path) + except FileNotFoundError: + logger.info("Version file not found on the data directory.") + _remove_data_dir(data_dir_path) + return + + island_version = get_version() + logger.info(f"Version found in the data directory: {data_dir_version}") + logger.info(f"Currently running version: {island_version}") + if is_version_greater(island_version, data_dir_version): + _remove_data_dir(data_dir_path) + + +def _remove_data_dir(data_dir_path: str): + logger.info("Attempting to remove data directory.") + try: + shutil.rmtree(data_dir_path) + logger.info("Data directory removed.") + except FileNotFoundError: + logger.info("Data directory not found, nothing to remove.") diff --git a/monkey/monkey_island/cc/setup/version_file_setup.py b/monkey/monkey_island/cc/setup/version_file_setup.py new file mode 100644 index 000000000..ac45c201b --- /dev/null +++ b/monkey/monkey_island/cc/setup/version_file_setup.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from packaging import version + +from common.version import get_version + +_version_filename = "VERSION" + + +def get_version_from_dir(dir_path: str) -> str: + version_file_path = Path(dir_path, _version_filename) + return version_file_path.read_text() + + +def write_version(dir_path: str): + version_file_path = Path(dir_path, _version_filename) + version_file_path.write_text(get_version()) + + +def is_version_greater(version1: str, version2: str) -> bool: + return version.parse(version1) > version.parse(version2) diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py new file mode 100644 index 000000000..81ed1f51d --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/test_data_dir.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import pytest + +from monkey_island.cc.setup.data_dir import setup_data_dir +from monkey_island.cc.setup.version_file_setup import _version_filename + +current_version = "1.1.1" +old_version = "1.1.0" + + +@pytest.fixture(autouse=True) +def mock_version(monkeypatch): + monkeypatch.setattr("monkey_island.cc.setup.data_dir.get_version", lambda: current_version) + monkeypatch.setattr( + "monkey_island.cc.setup.version_file_setup.get_version", lambda: current_version + ) + + +@pytest.fixture +def mocked_data_dir_path(tmpdir) -> Path: + return Path(tmpdir, "data_dir") + + +@pytest.fixture +def mocked_version_file_path(mocked_data_dir_path: Path) -> Path: + return mocked_data_dir_path.joinpath(_version_filename) + + +def test_setup_data_dir(mocked_data_dir_path, mocked_version_file_path): + data_dir_path = mocked_data_dir_path + setup_data_dir(str(data_dir_path)) + assert data_dir_path.is_dir() + + version_file_path = mocked_version_file_path + assert version_file_path.read_text() == current_version + + +def test_old_version_present(mocked_data_dir_path, mocked_version_file_path): + mocked_data_dir_path.mkdir() + mocked_version_file_path.write_text(old_version) + bogus_file_path = mocked_data_dir_path.joinpath("test.txt") + bogus_file_path.touch() + + # mock version + setup_data_dir(str(mocked_data_dir_path)) + + assert mocked_version_file_path.read_text() == current_version + assert not bogus_file_path.is_file() + + +def test_data_dir_setup_not_needed(mocked_data_dir_path, mocked_version_file_path): + mocked_data_dir_path.mkdir() + mocked_version_file_path.write_text(current_version) + bogus_file_path = mocked_data_dir_path.joinpath("test.txt") + bogus_file_path.touch() + + setup_data_dir(str(mocked_data_dir_path)) + assert mocked_version_file_path.read_text() == current_version + assert bogus_file_path.is_file()