diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py index a40766d5e..532282e8b 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py @@ -1,12 +1,17 @@ +from typing import BinaryIO from unittest.mock import MagicMock import flask_jwt_extended import pytest +from tests.common import StubDIContainer from tests.unit_tests.monkey_island.conftest import init_mock_app import monkey_island.cc.app import monkey_island.cc.resources.auth.auth import monkey_island.cc.resources.island_mode +from monkey_island.cc.repository import IFileRepository, RetrievalError + +from .mock_file_repository import MockFileRepository @pytest.fixture @@ -38,3 +43,17 @@ def get_mock_app(container): flask_jwt_extended.JWTManager(app) return app + + +class OpenErrorFileRepository(MockFileRepository): + def open_file(self, unsafe_file_name: str) -> BinaryIO: + raise RetrievalError("Error retrieving file") + + +@pytest.fixture +def open_error_flask_client(build_flask_client): + container = StubDIContainer() + container.register(IFileRepository, OpenErrorFileRepository) + + with build_flask_client(container) as flask_client: + yield flask_client diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py index 885f69609..eb189b4e7 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py @@ -1,10 +1,8 @@ -from typing import BinaryIO - import pytest from tests.common import StubDIContainer from tests.unit_tests.monkey_island.conftest import get_url_for_resource -from monkey_island.cc.repository import IFileRepository, RetrievalError +from monkey_island.cc.repository import IFileRepository from monkey_island.cc.resources.pba_file_download import PBAFileDownload from .mock_file_repository import FILE_CONTENTS, FILE_NAME, MockFileRepository @@ -36,20 +34,6 @@ def test_file_download_endpoint_404(tmp_path, flask_client): assert resp.status_code == 404 -class OpenErrorFileRepository(MockFileRepository): - def open_file(self, unsafe_file_name: str) -> BinaryIO: - raise RetrievalError("Error retrieving file") - - -@pytest.fixture -def open_error_flask_client(build_flask_client): - container = StubDIContainer() - container.register(IFileRepository, OpenErrorFileRepository) - - with build_flask_client(container) as flask_client: - yield flask_client - - def test_file_download_endpoint_500(tmp_path, open_error_flask_client): download_url = get_url_for_resource(PBAFileDownload, filename="test")