From bcc5265a998da21f09967b4736c4f7bdd63fdf13 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 21 Jun 2022 08:58:21 -0400 Subject: [PATCH] UT: Add test_file_download_endpoint_500() for PBAFileDownload --- .../cc/resources/test_pba_file_download.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py index caf86edab..067dcd102 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py @@ -5,7 +5,8 @@ import pytest from tests.common import StubDIContainer from tests.unit_tests.monkey_island.conftest import get_url_for_resource -from monkey_island.cc.repository import FileNotFoundError, IFileRepository +from monkey_island.cc import repository +from monkey_island.cc.repository import IFileRepository, RetrievalError from monkey_island.cc.resources.pba_file_download import PBAFileDownload FILE_NAME = "test_file" @@ -21,7 +22,7 @@ class MockFileRepository(IFileRepository): def open_file(self, unsafe_file_name: str) -> BinaryIO: if unsafe_file_name != FILE_NAME: - raise FileNotFoundError() + raise repository.FileNotFoundError() return self._file @@ -56,3 +57,25 @@ def test_file_download_endpoint_404(tmp_path, flask_client): resp = flask_client.get(download_url) assert resp.status_code == 404 + + +class OpenErrorFileRepository(MockFileRepository): + def open_file(self, unsafe_file_name: str) -> BinaryIO: + raise RetrievalError("Error retrieving file") + + +@pytest.fixture +def open_error_flask_client(build_flask_client): + container = StubDIContainer() + container.register(IFileRepository, OpenErrorFileRepository) + + with build_flask_client(container) as flask_client: + yield flask_client + + +def test_file_download_endpoint_500(tmp_path, open_error_flask_client): + download_url = get_url_for_resource(PBAFileDownload, filename="test") + + resp = open_error_flask_client.get(download_url) + + assert resp.status_code == 500