Monkey island: change the methods in data_dir.py and version_file_setup.py to handle Path rather than str.

This commit is contained in:
VakarisZ 2021-10-15 16:08:20 +03:00
parent e77ed9769b
commit 15949a9ed5
4 changed files with 13 additions and 12 deletions

View File

@ -22,7 +22,7 @@ def _setup_config_by_cmd_arg(server_config_path) -> Tuple[IslandConfigOptions, s
# TODO refactor like in https://github.com/guardicore/monkey/pull/1528 because # TODO refactor like in https://github.com/guardicore/monkey/pull/1528 because
# there's absolutely no reason to be exposed to IslandConfigOptions extraction logic # there's absolutely no reason to be exposed to IslandConfigOptions extraction logic
# if you want to modify data directory related code. # if you want to modify data directory related code.
setup_data_dir(str(config.data_dir)) setup_data_dir(config.data_dir)
return config, server_config_path return config, server_config_path
@ -34,7 +34,7 @@ def _setup_default_config() -> Tuple[IslandConfigOptions, str]:
# TODO refactor like in https://github.com/guardicore/monkey/pull/1528 because # TODO refactor like in https://github.com/guardicore/monkey/pull/1528 because
# there's absolutely no reason to be exposed to IslandConfigOptions extraction logic # there's absolutely no reason to be exposed to IslandConfigOptions extraction logic
# if you want to modify data directory related code. # if you want to modify data directory related code.
setup_data_dir(str(default_data_dir)) setup_data_dir(default_data_dir)
server_config_path = server_config_handler.create_default_server_config_file(default_data_dir) server_config_path = server_config_handler.create_default_server_config_file(default_data_dir)
config = server_config_handler.load_server_config_from_file(server_config_path) config = server_config_handler.load_server_config_from_file(server_config_path)

View File

@ -1,5 +1,6 @@
import logging import logging
import shutil import shutil
from pathlib import Path
from common.version import get_version from common.version import get_version
from monkey_island.cc.server_utils.file_utils import create_secure_directory from monkey_island.cc.server_utils.file_utils import create_secure_directory
@ -12,15 +13,15 @@ from monkey_island.cc.setup.version_file_setup import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def setup_data_dir(data_dir_path: str): def setup_data_dir(data_dir_path: Path):
logger.info("Setting up data directory.") logger.info("Setting up data directory.")
_reset_data_dir(data_dir_path) _reset_data_dir(data_dir_path)
create_secure_directory(data_dir_path) create_secure_directory(str(data_dir_path))
write_version(data_dir_path) write_version(data_dir_path)
logger.info("Data directory set up.") logger.info("Data directory set up.")
def _reset_data_dir(data_dir_path: str): def _reset_data_dir(data_dir_path: Path):
try: try:
data_dir_version = get_version_from_dir(data_dir_path) data_dir_version = get_version_from_dir(data_dir_path)
except FileNotFoundError: except FileNotFoundError:

View File

@ -7,13 +7,13 @@ from common.version import get_version
_version_filename = "VERSION" _version_filename = "VERSION"
def get_version_from_dir(dir_path: str) -> str: def get_version_from_dir(dir_path: Path) -> str:
version_file_path = Path(dir_path, _version_filename) version_file_path = dir_path.joinpath(_version_filename)
return version_file_path.read_text() return version_file_path.read_text()
def write_version(dir_path: str): def write_version(dir_path: Path):
version_file_path = Path(dir_path, _version_filename) version_file_path = dir_path.joinpath(_version_filename)
version_file_path.write_text(get_version()) version_file_path.write_text(get_version())

View File

@ -29,7 +29,7 @@ def mocked_version_file_path(mocked_data_dir_path: Path) -> Path:
def test_setup_data_dir(mocked_data_dir_path, mocked_version_file_path): def test_setup_data_dir(mocked_data_dir_path, mocked_version_file_path):
data_dir_path = mocked_data_dir_path data_dir_path = mocked_data_dir_path
setup_data_dir(str(data_dir_path)) setup_data_dir(data_dir_path)
assert data_dir_path.is_dir() assert data_dir_path.is_dir()
version_file_path = mocked_version_file_path version_file_path = mocked_version_file_path
@ -43,7 +43,7 @@ def test_old_version_present(mocked_data_dir_path, mocked_version_file_path):
bogus_file_path.touch() bogus_file_path.touch()
# mock version # mock version
setup_data_dir(str(mocked_data_dir_path)) setup_data_dir(mocked_data_dir_path)
assert mocked_version_file_path.read_text() == current_version assert mocked_version_file_path.read_text() == current_version
assert not bogus_file_path.is_file() assert not bogus_file_path.is_file()
@ -55,6 +55,6 @@ def test_data_dir_setup_not_needed(mocked_data_dir_path, mocked_version_file_pat
bogus_file_path = mocked_data_dir_path.joinpath("test.txt") bogus_file_path = mocked_data_dir_path.joinpath("test.txt")
bogus_file_path.touch() bogus_file_path.touch()
setup_data_dir(str(mocked_data_dir_path)) setup_data_dir(mocked_data_dir_path)
assert mocked_version_file_path.read_text() == current_version assert mocked_version_file_path.read_text() == current_version
assert bogus_file_path.is_file() assert bogus_file_path.is_file()