diff --git a/monkey/monkey_island/cc/services/__init__.py b/monkey/monkey_island/cc/services/__init__.py index 15d0b801d..2d13fae99 100644 --- a/monkey/monkey_island/cc/services/__init__.py +++ b/monkey/monkey_island/cc/services/__init__.py @@ -1,3 +1,4 @@ from .authentication.authentication_service import AuthenticationService from .authentication.json_file_user_datastore import JsonFileUserDatastore from .i_file_storage_service import IFileStorageService +from .directory_file_storage_service import DirectoryFileStorageService diff --git a/monkey/monkey_island/cc/services/directory_file_storage_service.py b/monkey/monkey_island/cc/services/directory_file_storage_service.py new file mode 100644 index 000000000..30aa2a4d0 --- /dev/null +++ b/monkey/monkey_island/cc/services/directory_file_storage_service.py @@ -0,0 +1,54 @@ +import shutil +from pathlib import Path +from typing import BinaryIO + +from monkey_island.cc.server_utils.file_utils import create_secure_directory + +from . import IFileStorageService + + +class DirectoryFileStorageService(IFileStorageService): + """ + A implementation of IFileStorageService that reads and writes files from/to the local + filesystem. + """ + + def __init__(self, storage_directory: Path): + """ + :param storage_directory: A Path object representing the directory where files will be + stored. If the directory does not exist, it will be created. + """ + if storage_directory.exists() and not storage_directory.is_dir(): + raise ValueError(f"The provided path must point to a directory: {storage_directory}") + + if not storage_directory.exists(): + create_secure_directory(storage_directory) + + self._storage_directory = storage_directory + + def save_file(self, unsafe_file_name: str, file_contents: BinaryIO): + safe_file_path = self._get_safe_file_path(unsafe_file_name) + + with open(safe_file_path, "wb") as dest: + shutil.copyfileobj(file_contents, dest) + + def open_file(self, unsafe_file_name: str) -> BinaryIO: + safe_file_path = self._get_safe_file_path(unsafe_file_name) + return open(safe_file_path, "rb") + + def delete_file(self, unsafe_file_name: str): + safe_file_path = self._get_safe_file_path(unsafe_file_name) + + safe_file_path.unlink() + + def _get_safe_file_path(self, unsafe_file_name: str): + # Remove any path information from the file name. + safe_file_name = Path(unsafe_file_name).resolve().name + + # TODO: Add super paranoid check + + return self._storage_directory / safe_file_name + + def delete_all_files(self): + for file in self._storage_directory.iterdir(): + file.unlink() diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py new file mode 100644 index 000000000..7d65489d1 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py @@ -0,0 +1,108 @@ +import io +from pathlib import Path + +import pytest +from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions + +from monkey_island.cc.server_utils.file_utils import is_windows_os +from monkey_island.cc.services import DirectoryFileStorageService + + +def test_error_if_file(tmp_path): + new_file = tmp_path / "new_file.txt" + new_file.write_text("HelloWorld!") + + with pytest.raises(ValueError): + DirectoryFileStorageService(new_file) + + +def test_directory_created(tmp_path): + new_dir = tmp_path / "new_dir" + + DirectoryFileStorageService(new_dir) + + assert new_dir.exists() and new_dir.is_dir() + + +@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.") +def test_directory_permissions__linux(tmp_path): + new_dir = tmp_path / "new_dir" + + DirectoryFileStorageService(new_dir) + + assert_linux_permissions(new_dir) + + +@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.") +def test_directory_permissions__windows(tmp_path): + new_dir = tmp_path / "new_dir" + + DirectoryFileStorageService(new_dir) + + assert_windows_permissions(new_dir) + + +def save_file(tmp_path, file_path_prefix=""): + file_name = "test.txt" + file_contents = "Hello World!" + expected_file_path = tmp_path / file_name + + fss = DirectoryFileStorageService(tmp_path) + fss.save_file(Path(file_path_prefix) / file_name, io.BytesIO(file_contents.encode())) + + assert expected_file_path.is_file() + assert expected_file_path.read_text() == file_contents + + +def delete_file(tmp_path, file_path_prefix=""): + file_name = "file.txt" + file = tmp_path / file_name + file.touch() + assert file.is_file() + + fss = DirectoryFileStorageService(tmp_path) + fss.delete_file(Path(file_path_prefix) / file_name) + + assert not file.exists() + + +def open_file(tmp_path, file_path_prefix=""): + file_name = "test.txt" + expected_file_contents = "Hello World!" + expected_file_path = tmp_path / file_name + expected_file_path.write_text(expected_file_contents) + + fss = DirectoryFileStorageService(tmp_path) + with fss.open_file(Path(file_path_prefix) / file_name) as f: + actual_file_contents = f.read() + + assert actual_file_contents == expected_file_contents.encode() + + +@pytest.mark.parametrize("fn", [save_file, open_file, delete_file]) +def test_fn(tmp_path, fn): + fn(tmp_path) + + +@pytest.mark.parametrize("fn", [save_file, open_file, delete_file]) +def test_fn__ignore_relative_path(tmp_path, fn): + fn(tmp_path, "../../") + + +@pytest.mark.parametrize("fn", [save_file, open_file, delete_file]) +def test_fn__ignore_absolute_path(tmp_path, fn): + if is_windows_os(): + fn(tmp_path, "C:\\Windows") + else: + fn(tmp_path, "/home/") + + +def test_remove_all_files(tmp_path): + for filename in ["1.txt", "2.txt", "3.txt"]: + (tmp_path / filename).touch() + + fss = DirectoryFileStorageService(tmp_path) + fss.delete_all_files() + + for file in tmp_path.iterdir(): + assert False, f"{tmp_path} was expected to be empty, but contained files"