forked from p34709852/monkey
Island: Use IFileStorageService in PBAFileDownload resource
This commit is contained in:
parent
c03a5aac4b
commit
d1e18e9dbd
|
@ -47,6 +47,7 @@ from monkey_island.cc.resources.zero_trust.finding_event import ZeroTrustFinding
|
||||||
from monkey_island.cc.resources.zero_trust.zero_trust_report import ZeroTrustReport
|
from monkey_island.cc.resources.zero_trust.zero_trust_report import ZeroTrustReport
|
||||||
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
|
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
|
||||||
from monkey_island.cc.server_utils.custom_json_encoder import CustomJSONEncoder
|
from monkey_island.cc.server_utils.custom_json_encoder import CustomJSONEncoder
|
||||||
|
from monkey_island.cc.services import DirectoryFileStorageService
|
||||||
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService
|
from monkey_island.cc.services.remote_run_aws import RemoteRunAwsService
|
||||||
from monkey_island.cc.services.representations import output_json
|
from monkey_island.cc.services.representations import output_json
|
||||||
|
|
||||||
|
@ -149,7 +150,11 @@ def init_api_resources(api, data_dir: Path):
|
||||||
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
|
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
|
||||||
api.add_resource(Log, "/api/log")
|
api.add_resource(Log, "/api/log")
|
||||||
api.add_resource(IslandLog, "/api/log/island/download")
|
api.add_resource(IslandLog, "/api/log/island/download")
|
||||||
api.add_resource(PBAFileDownload, "/api/pba/download/<string:filename>")
|
api.add_resource(
|
||||||
|
PBAFileDownload,
|
||||||
|
"/api/pba/download/<string:filename>",
|
||||||
|
resource_class_kwargs={"file_storage_service": DirectoryFileStorageService(data_dir)},
|
||||||
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
FileUpload,
|
FileUpload,
|
||||||
"/api/file-upload/<string:file_type>",
|
"/api/file-upload/<string:file_type>",
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
import flask_restful
|
import logging
|
||||||
from flask import send_from_directory
|
|
||||||
|
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
import flask_restful
|
||||||
|
from flask import make_response, send_file
|
||||||
|
|
||||||
|
from monkey_island.cc.services import IFileStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class PBAFileDownload(flask_restful.Resource):
|
class PBAFileDownload(flask_restful.Resource):
|
||||||
|
@ -9,7 +13,17 @@ class PBAFileDownload(flask_restful.Resource):
|
||||||
File download endpoint used by monkey to download user's PBA file
|
File download endpoint used by monkey to download user's PBA file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file_storage_service: IFileStorageService):
|
||||||
|
self._file_storage_service = file_storage_service
|
||||||
|
|
||||||
# Used by monkey. can't secure.
|
# Used by monkey. can't secure.
|
||||||
def get(self, filename):
|
def get(self, filename: str):
|
||||||
custom_pba_dir = PostBreachFilesService.get_custom_pba_directory()
|
try:
|
||||||
return send_from_directory(custom_pba_dir, filename)
|
file = self._file_storage_service.open_file(filename)
|
||||||
|
|
||||||
|
# `send_file()` handles the closing of the open file.
|
||||||
|
return send_file(file, mimetype="application/octet-stream")
|
||||||
|
except OSError as ex:
|
||||||
|
error_msg = f"Failed to open file {filename}: {ex}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return make_response({"error": error_msg}, 404)
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
def test_file_download_endpoint(tmp_path, flask_client):
|
||||||
|
file_contents = "HelloWorld!"
|
||||||
|
file_name = "test_file"
|
||||||
|
(tmp_path / file_name).write_text(file_contents)
|
||||||
|
|
||||||
|
resp = flask_client.get(f"/api/pba/download/{file_name}")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert next(resp.response).decode() == file_contents
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_download_endpoint_404(tmp_path, flask_client):
|
||||||
|
file_name = "nonexistant_file"
|
||||||
|
|
||||||
|
resp = flask_client.get(f"/api/pba/download/{file_name}")
|
||||||
|
|
||||||
|
assert resp.status_code == 404
|
Loading…
Reference in New Issue