diff --git a/monkey/monkey_island/cc/app.py b/monkey/monkey_island/cc/app.py index 77d52ac8c..17e86fc23 100644 --- a/monkey/monkey_island/cc/app.py +++ b/monkey/monkey_island/cc/app.py @@ -47,6 +47,7 @@ from monkey_island.cc.resources.zero_trust.finding_event import ZeroTrustFinding from monkey_island.cc.resources.zero_trust.zero_trust_report import ZeroTrustReport from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.server_utils.custom_json_encoder import CustomJSONEncoder +from monkey_island.cc.services import DirectoryFileStorageService from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService from monkey_island.cc.services.representations import output_json @@ -149,7 +150,11 @@ def init_api_resources(api, data_dir: Path): api.add_resource(TelemetryFeed, "/api/telemetry-feed") api.add_resource(Log, "/api/log") api.add_resource(IslandLog, "/api/log/island/download") - api.add_resource(PBAFileDownload, "/api/pba/download/") + api.add_resource( + PBAFileDownload, + "/api/pba/download/", + resource_class_kwargs={"file_storage_service": DirectoryFileStorageService(data_dir)}, + ) api.add_resource( FileUpload, "/api/file-upload/", diff --git a/monkey/monkey_island/cc/resources/pba_file_download.py b/monkey/monkey_island/cc/resources/pba_file_download.py index df9766ed6..ba5857c53 100644 --- a/monkey/monkey_island/cc/resources/pba_file_download.py +++ b/monkey/monkey_island/cc/resources/pba_file_download.py @@ -1,7 +1,11 @@ -import flask_restful -from flask import send_from_directory +import logging -from monkey_island.cc.services.post_breach_files import PostBreachFilesService +import flask_restful +from flask import make_response, send_file + +from monkey_island.cc.services import IFileStorageService + +logger = logging.getLogger(__file__) class PBAFileDownload(flask_restful.Resource): @@ -9,7 +13,17 @@ class PBAFileDownload(flask_restful.Resource): File download endpoint used by monkey to download user's PBA file """ + def __init__(self, file_storage_service: IFileStorageService): + self._file_storage_service = file_storage_service + # Used by monkey. can't secure. - def get(self, filename): - custom_pba_dir = PostBreachFilesService.get_custom_pba_directory() - return send_from_directory(custom_pba_dir, filename) + def get(self, filename: str): + try: + file = self._file_storage_service.open_file(filename) + + # `send_file()` handles the closing of the open file. + return send_file(file, mimetype="application/octet-stream") + except OSError as ex: + error_msg = f"Failed to open file {filename}: {ex}" + logger.error(error_msg) + return make_response({"error": error_msg}, 404) diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py new file mode 100644 index 000000000..333ff4cf1 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_pba_file_download.py @@ -0,0 +1,17 @@ +def test_file_download_endpoint(tmp_path, flask_client): + file_contents = "HelloWorld!" + file_name = "test_file" + (tmp_path / file_name).write_text(file_contents) + + resp = flask_client.get(f"/api/pba/download/{file_name}") + + assert resp.status_code == 200 + assert next(resp.response).decode() == file_contents + + +def test_file_download_endpoint_404(tmp_path, flask_client): + file_name = "nonexistant_file" + + resp = flask_client.get(f"/api/pba/download/{file_name}") + + assert resp.status_code == 404