diff --git a/monkey/monkey_island/cc/resources/pba_file_download.py b/monkey/monkey_island/cc/resources/pba_file_download.py index ba5857c53..a11e964b1 100644 --- a/monkey/monkey_island/cc/resources/pba_file_download.py +++ b/monkey/monkey_island/cc/resources/pba_file_download.py @@ -3,7 +3,7 @@ import logging import flask_restful from flask import make_response, send_file -from monkey_island.cc.services import IFileStorageService +from monkey_island.cc.services import FileRetrievalError, IFileStorageService logger = logging.getLogger(__file__) @@ -23,7 +23,7 @@ class PBAFileDownload(flask_restful.Resource): # `send_file()` handles the closing of the open file. return send_file(file, mimetype="application/octet-stream") - except OSError as ex: - error_msg = f"Failed to open file {filename}: {ex}" + except FileRetrievalError as err: + error_msg = f"Failed to open file {filename}: {err}" logger.error(error_msg) return make_response({"error": error_msg}, 404) diff --git a/monkey/monkey_island/cc/resources/pba_file_upload.py b/monkey/monkey_island/cc/resources/pba_file_upload.py index d7b170a9f..2a27b324a 100644 --- a/monkey/monkey_island/cc/resources/pba_file_upload.py +++ b/monkey/monkey_island/cc/resources/pba_file_upload.py @@ -9,7 +9,7 @@ from werkzeug.utils import secure_filename from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH from monkey_island.cc.resources.auth.auth import jwt_required -from monkey_island.cc.services import IFileStorageService +from monkey_island.cc.services import FileRetrievalError, IFileStorageService from monkey_island.cc.services.config import ConfigService logger = logging.getLogger(__file__) @@ -53,8 +53,8 @@ class FileUpload(flask_restful.Resource): # `send_file()` handles the closing of the open file. return send_file(file, mimetype="application/octet-stream") - except OSError as ex: - error_msg = f"Failed to open file {filename}: {ex}" + except FileRetrievalError as err: + error_msg = f"Failed to open file {filename}: {err}" logger.error(error_msg) return make_response({"error": error_msg}, 404) diff --git a/monkey/monkey_island/cc/services/__init__.py b/monkey/monkey_island/cc/services/__init__.py index 3f57b4fc0..43aa39382 100644 --- a/monkey/monkey_island/cc/services/__init__.py +++ b/monkey/monkey_island/cc/services/__init__.py @@ -1,4 +1,4 @@ -from .i_file_storage_service import IFileStorageService +from .i_file_storage_service import IFileStorageService, FileRetrievalError from .directory_file_storage_service import DirectoryFileStorageService from .authentication.authentication_service import AuthenticationService diff --git a/monkey/monkey_island/cc/services/directory_file_storage_service.py b/monkey/monkey_island/cc/services/directory_file_storage_service.py index d6fd267fe..d2d0c811d 100644 --- a/monkey/monkey_island/cc/services/directory_file_storage_service.py +++ b/monkey/monkey_island/cc/services/directory_file_storage_service.py @@ -5,7 +5,7 @@ from typing import BinaryIO from common.utils.file_utils import get_all_regular_files_in_directory from monkey_island.cc.server_utils.file_utils import create_secure_directory -from . import IFileStorageService +from . import FileRetrievalError, IFileStorageService class DirectoryFileStorageService(IFileStorageService): @@ -35,7 +35,11 @@ class DirectoryFileStorageService(IFileStorageService): def open_file(self, unsafe_file_name: str) -> BinaryIO: safe_file_path = self._get_safe_file_path(unsafe_file_name) - return open(safe_file_path, "rb") + + try: + return open(safe_file_path, "rb") + except OSError as err: + raise FileRetrievalError(f"Failed to retrieve file {safe_file_path}: {err}") from err def delete_file(self, unsafe_file_name: str): safe_file_path = self._get_safe_file_path(unsafe_file_name) diff --git a/monkey/monkey_island/cc/services/i_file_storage_service.py b/monkey/monkey_island/cc/services/i_file_storage_service.py index a27cbf6ad..5903c07d1 100644 --- a/monkey/monkey_island/cc/services/i_file_storage_service.py +++ b/monkey/monkey_island/cc/services/i_file_storage_service.py @@ -2,6 +2,10 @@ import abc from typing import BinaryIO +class FileRetrievalError(ValueError): + pass + + class IFileStorageService(metaclass=abc.ABCMeta): """ A service that allows the storage and retrieval of individual files. @@ -25,6 +29,7 @@ class IFileStorageService(metaclass=abc.ABCMeta): :param unsafe_file_name: An unsanitized file name that identifies the file to be opened :return: A file-like object providing access to the file's contents :rtype: io.BinaryIO + :raises FileRetrievalError: if the file cannot be opened """ pass diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py index 6b76115cb..d0a8ec48f 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py @@ -4,7 +4,7 @@ from typing import BinaryIO import pytest from common import DIContainer -from monkey_island.cc.services import IFileStorageService +from monkey_island.cc.services import FileRetrievalError, IFileStorageService FILE_NAME = "test_file" FILE_CONTENTS = b"HelloWorld!" @@ -19,7 +19,7 @@ class MockFileStorageService(IFileStorageService): def open_file(self, unsafe_file_name: str) -> BinaryIO: if unsafe_file_name != FILE_NAME: - raise OSError() + raise FileRetrievalError() return self._file diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py index 6f615e19f..da1eb2c02 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py @@ -6,7 +6,7 @@ from tests.utils import raise_ from common import DIContainer from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE -from monkey_island.cc.services import IFileStorageService +from monkey_island.cc.services import FileRetrievalError, IFileStorageService TEST_FILE_CONTENTS = b"m0nk3y" TEST_FILE = ( @@ -48,8 +48,7 @@ class MockFileStorageService(IFileStorageService): def open_file(self, unsafe_file_name: str) -> BinaryIO: if self._file is None: - # TODO: Add FileRetrievalError - raise OSError() + raise FileRetrievalError() return self._file def delete_file(self, unsafe_file_name: str): diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py index 07f65151d..ea40aa49d 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_directory_file_storage_service.py @@ -5,10 +5,10 @@ import pytest from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions from monkey_island.cc.server_utils.file_utils import is_windows_os -from monkey_island.cc.services import DirectoryFileStorageService +from monkey_island.cc.services import DirectoryFileStorageService, FileRetrievalError -def test_error_if_file(tmp_path): +def test_error_if_storage_directory_is_file(tmp_path): new_file = tmp_path / "new_file.txt" new_file.write_text("HelloWorld!") @@ -126,3 +126,10 @@ def test_remove_nonexistant_file(tmp_path): # This test will fail if this call raises an exception. fss.delete_file("nonexistant_file.txt") + + +def test_open_nonexistant_file(tmp_path): + fss = DirectoryFileStorageService(tmp_path) + + with pytest.raises(FileRetrievalError): + fss.open_file("nonexistant_file.txt")