UT: Refactor PBA resource upload/download tests to use DI

This commit is contained in:
Mike Salvatore 2022-04-26 12:58:22 -04:00
parent d084b00367
commit e296cd5225
3 changed files with 107 additions and 15 deletions

View File

@ -1,3 +1,5 @@
from unittest.mock import MagicMock
import flask_jwt_extended import flask_jwt_extended
import flask_restful import flask_restful
import pytest import pytest
@ -10,14 +12,27 @@ from monkey_island.cc.services.representations import output_json
@pytest.fixture @pytest.fixture
def flask_client(monkeypatch_session, tmp_path): def flask_client(monkeypatch_session):
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None) monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
with mock_init_app(tmp_path).test_client() as client: container = MagicMock()
container.resolve_dependencies.return_value = []
with mock_init_app(container).test_client() as client:
yield client yield client
def mock_init_app(data_dir): @pytest.fixture
def build_flask_client(monkeypatch_session):
def inner(container):
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
return mock_init_app(container).test_client()
return inner
def mock_init_app(container):
app = Flask(__name__) app = Flask(__name__)
app.config["SECRET_KEY"] = "test_key" app.config["SECRET_KEY"] = "test_key"
@ -25,7 +40,8 @@ def mock_init_app(data_dir):
api.representations = {"application/json": output_json} api.representations = {"application/json": output_json}
monkey_island.cc.app.init_app_url_rules(app) monkey_island.cc.app.init_app_url_rules(app)
monkey_island.cc.app.init_api_resources(api, data_dir) flask_resource_manager = monkey_island.cc.app.FlaskResourceManager(api, container)
monkey_island.cc.app.init_api_resources(flask_resource_manager)
flask_jwt_extended.JWTManager(app) flask_jwt_extended.JWTManager(app)

View File

@ -1,17 +1,54 @@
def test_file_download_endpoint(tmp_path, flask_client): import io
file_contents = "HelloWorld!" from typing import BinaryIO
file_name = "test_file"
(tmp_path / "custom_pbas" / file_name).write_text(file_contents)
resp = flask_client.get(f"/api/pba/download/{file_name}") import pytest
from common import DIContainer
from monkey_island.cc.services import IFileStorageService
FILE_NAME = "test_file"
FILE_CONTENTS = b"HelloWorld!"
class MockFileStorageService(IFileStorageService):
def __init__(self):
self._file = io.BytesIO(FILE_CONTENTS)
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
pass
def open_file(self, unsafe_file_name: str) -> BinaryIO:
if unsafe_file_name != FILE_NAME:
raise OSError()
return self._file
def delete_file(self, unsafe_file_name: str):
pass
def delete_all_files(self):
pass
@pytest.fixture
def flask_client(build_flask_client, tmp_path):
container = DIContainer()
container.register(IFileStorageService, MockFileStorageService)
with build_flask_client(container) as flask_client:
yield flask_client
def test_file_download_endpoint(tmp_path, flask_client):
resp = flask_client.get(f"/api/pba/download/{FILE_NAME}")
assert resp.status_code == 200 assert resp.status_code == 200
assert next(resp.response).decode() == file_contents assert next(resp.response) == FILE_CONTENTS
def test_file_download_endpoint_404(tmp_path, flask_client): def test_file_download_endpoint_404(tmp_path, flask_client):
file_name = "nonexistant_file" nonexistant_file_name = "nonexistant_file"
resp = flask_client.get(f"/api/pba/download/{file_name}") resp = flask_client.get(f"/api/pba/download/{nonexistant_file_name}")
assert resp.status_code == 404 assert resp.status_code == 404

View File

@ -1,9 +1,16 @@
import io
from typing import BinaryIO
import pytest import pytest
from tests.utils import raise_ from tests.utils import raise_
from common import DIContainer
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
from monkey_island.cc.services import IFileStorageService
TEST_FILE = b"""-----------------------------1 TEST_FILE_CONTENTS = b"m0nk3y"
TEST_FILE = (
b"""-----------------------------1
Content-Disposition: form-data; name="filepond" Content-Disposition: form-data; name="filepond"
{} {}
@ -11,8 +18,11 @@ Content-Disposition: form-data; name="filepond"
Content-Disposition: form-data; name="filepond"; filename="test.py" Content-Disposition: form-data; name="filepond"; filename="test.py"
Content-Type: text/x-python Content-Type: text/x-python
m0nk3y """
+ TEST_FILE_CONTENTS
+ b"""
-----------------------------1--""" -----------------------------1--"""
)
@pytest.fixture @pytest.fixture
@ -29,6 +39,35 @@ def mock_get_config_value(monkeypatch):
) )
class MockFileStorageService(IFileStorageService):
def __init__(self):
self._file = None
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
self._file = io.BytesIO(file_contents.read())
def open_file(self, unsafe_file_name: str) -> BinaryIO:
if self._file is None:
# TODO: Add FileRetrievalError
raise OSError()
return self._file
def delete_file(self, unsafe_file_name: str):
self._file = None
def delete_all_files(self):
self.delete_file("")
@pytest.fixture
def flask_client(build_flask_client, tmp_path):
container = DIContainer()
container.register(IFileStorageService, MockFileStorageService)
with build_flask_client(container) as flask_client:
yield flask_client
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE]) @pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value): def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value):
resp = flask_client.post( resp = flask_client.post(
@ -89,7 +128,7 @@ def test_pba_file_upload_endpoint(
resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py") resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py")
assert resp_get.status_code == 200 assert resp_get.status_code == 200
assert resp_get.data.decode() == "m0nk3y" assert resp_get.data == TEST_FILE_CONTENTS
# Closing the response closes the file handle, else it can't be deleted # Closing the response closes the file handle, else it can't be deleted
resp_get.close() resp_get.close()