diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py index 903528be2..db58ed173 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import flask_jwt_extended import flask_restful import pytest @@ -10,14 +12,27 @@ from monkey_island.cc.services.representations import output_json @pytest.fixture -def flask_client(monkeypatch_session, tmp_path): +def flask_client(monkeypatch_session): monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None) - with mock_init_app(tmp_path).test_client() as client: + container = MagicMock() + container.resolve_dependencies.return_value = [] + + with mock_init_app(container).test_client() as client: yield client -def mock_init_app(data_dir): +@pytest.fixture +def build_flask_client(monkeypatch_session): + def inner(container): + monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None) + + return mock_init_app(container).test_client() + + return inner + + +def mock_init_app(container): app = Flask(__name__) app.config["SECRET_KEY"] = "test_key" @@ -25,7 +40,8 @@ def mock_init_app(data_dir): api.representations = {"application/json": output_json} monkey_island.cc.app.init_app_url_rules(app) - monkey_island.cc.app.init_api_resources(api, data_dir) + flask_resource_manager = monkey_island.cc.app.FlaskResourceManager(api, container) + monkey_island.cc.app.init_api_resources(flask_resource_manager) flask_jwt_extended.JWTManager(app) diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py index 5b6dcad3f..6b76115cb 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py @@ -1,17 +1,54 @@ -def test_file_download_endpoint(tmp_path, flask_client): - file_contents = "HelloWorld!" - file_name = "test_file" - (tmp_path / "custom_pbas" / file_name).write_text(file_contents) +import io +from typing import BinaryIO - resp = flask_client.get(f"/api/pba/download/{file_name}") +import pytest + +from common import DIContainer +from monkey_island.cc.services import IFileStorageService + +FILE_NAME = "test_file" +FILE_CONTENTS = b"HelloWorld!" + + +class MockFileStorageService(IFileStorageService): + def __init__(self): + self._file = io.BytesIO(FILE_CONTENTS) + + def save_file(self, unsafe_file_name: str, file_contents: BinaryIO): + pass + + def open_file(self, unsafe_file_name: str) -> BinaryIO: + if unsafe_file_name != FILE_NAME: + raise OSError() + + return self._file + + def delete_file(self, unsafe_file_name: str): + pass + + def delete_all_files(self): + pass + + +@pytest.fixture +def flask_client(build_flask_client, tmp_path): + container = DIContainer() + container.register(IFileStorageService, MockFileStorageService) + + with build_flask_client(container) as flask_client: + yield flask_client + + +def test_file_download_endpoint(tmp_path, flask_client): + resp = flask_client.get(f"/api/pba/download/{FILE_NAME}") assert resp.status_code == 200 - assert next(resp.response).decode() == file_contents + assert next(resp.response) == FILE_CONTENTS def test_file_download_endpoint_404(tmp_path, flask_client): - file_name = "nonexistant_file" + nonexistant_file_name = "nonexistant_file" - resp = flask_client.get(f"/api/pba/download/{file_name}") + resp = flask_client.get(f"/api/pba/download/{nonexistant_file_name}") assert resp.status_code == 404 diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py index ca5f0f175..28b908c86 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_upload.py @@ -1,9 +1,16 @@ +import io +from typing import BinaryIO + import pytest from tests.utils import raise_ +from common import DIContainer from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE +from monkey_island.cc.services import IFileStorageService -TEST_FILE = b"""-----------------------------1 +TEST_FILE_CONTENTS = b"m0nk3y" +TEST_FILE = ( + b"""-----------------------------1 Content-Disposition: form-data; name="filepond" {} @@ -11,8 +18,11 @@ Content-Disposition: form-data; name="filepond" Content-Disposition: form-data; name="filepond"; filename="test.py" Content-Type: text/x-python -m0nk3y +""" + + TEST_FILE_CONTENTS + + b""" -----------------------------1--""" +) @pytest.fixture @@ -29,6 +39,35 @@ def mock_get_config_value(monkeypatch): ) +class MockFileStorageService(IFileStorageService): + def __init__(self): + self._file = None + + def save_file(self, unsafe_file_name: str, file_contents: BinaryIO): + self._file = io.BytesIO(file_contents.read()) + + def open_file(self, unsafe_file_name: str) -> BinaryIO: + if self._file is None: + # TODO: Add FileRetrievalError + raise OSError() + return self._file + + def delete_file(self, unsafe_file_name: str): + self._file = None + + def delete_all_files(self): + self.delete_file("") + + +@pytest.fixture +def flask_client(build_flask_client, tmp_path): + container = DIContainer() + container.register(IFileStorageService, MockFileStorageService) + + with build_flask_client(container) as flask_client: + yield flask_client + + @pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE]) def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value): resp = flask_client.post( @@ -89,7 +128,7 @@ def test_pba_file_upload_endpoint( resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py") assert resp_get.status_code == 200 - assert resp_get.data.decode() == "m0nk3y" + assert resp_get.data == TEST_FILE_CONTENTS # Closing the response closes the file handle, else it can't be deleted resp_get.close()