UT: Add test_file_download_endpoint_500() for PBAFileDownload

This commit is contained in:
Mike Salvatore 2022-06-21 08:58:21 -04:00
parent 44795531b8
commit bcc5265a99
1 changed files with 25 additions and 2 deletions

View File

@ -5,7 +5,8 @@ import pytest
from tests.common import StubDIContainer from tests.common import StubDIContainer
from tests.unit_tests.monkey_island.conftest import get_url_for_resource from tests.unit_tests.monkey_island.conftest import get_url_for_resource
from monkey_island.cc.repository import FileNotFoundError, IFileRepository from monkey_island.cc import repository
from monkey_island.cc.repository import IFileRepository, RetrievalError
from monkey_island.cc.resources.pba_file_download import PBAFileDownload from monkey_island.cc.resources.pba_file_download import PBAFileDownload
FILE_NAME = "test_file" FILE_NAME = "test_file"
@ -21,7 +22,7 @@ class MockFileRepository(IFileRepository):
def open_file(self, unsafe_file_name: str) -> BinaryIO: def open_file(self, unsafe_file_name: str) -> BinaryIO:
if unsafe_file_name != FILE_NAME: if unsafe_file_name != FILE_NAME:
raise FileNotFoundError() raise repository.FileNotFoundError()
return self._file return self._file
@ -56,3 +57,25 @@ def test_file_download_endpoint_404(tmp_path, flask_client):
resp = flask_client.get(download_url) resp = flask_client.get(download_url)
assert resp.status_code == 404 assert resp.status_code == 404
class OpenErrorFileRepository(MockFileRepository):
def open_file(self, unsafe_file_name: str) -> BinaryIO:
raise RetrievalError("Error retrieving file")
@pytest.fixture
def open_error_flask_client(build_flask_client):
container = StubDIContainer()
container.register(IFileRepository, OpenErrorFileRepository)
with build_flask_client(container) as flask_client:
yield flask_client
def test_file_download_endpoint_500(tmp_path, open_error_flask_client):
download_url = get_url_for_resource(PBAFileDownload, filename="test")
resp = open_error_flask_client.get(download_url)
assert resp.status_code == 500