From 79eb7442ae60490e53755dedb7029c923ef844d9 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Wed, 12 May 2021 09:49:58 -0400 Subject: [PATCH] island: Move the specifics of saving pba files to pba service --- .../cc/resources/pba_file_upload.py | 19 +++++++++++-------- .../cc/services/post_breach_files.py | 6 ++++++ .../cc/services/test_post_breach_files.py | 16 ++++++++++++---- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/monkey/monkey_island/cc/resources/pba_file_upload.py b/monkey/monkey_island/cc/resources/pba_file_upload.py index 369bc3c42..39da8324f 100644 --- a/monkey/monkey_island/cc/resources/pba_file_upload.py +++ b/monkey/monkey_island/cc/resources/pba_file_upload.py @@ -1,9 +1,9 @@ import copy import logging -from pathlib import Path import flask_restful from flask import Response, request, send_from_directory +from werkzeug.datastructures import FileStorage from werkzeug.utils import secure_filename from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH @@ -45,27 +45,30 @@ class FileUpload(flask_restful.Resource): :param file_type: Type indicates which file was received, linux or windows :return: Returns flask response object with uploaded file's filename """ - filename = FileUpload.upload_pba_file(request, (file_type == LINUX_PBA_TYPE)) + filename = FileUpload.upload_pba_file( + request.files["filepond"], (file_type == LINUX_PBA_TYPE) + ) response = Response(response=filename, status=200, mimetype="text/plain") return response @staticmethod - def upload_pba_file(request_, is_linux=True): + def upload_pba_file(file_storage: FileStorage, is_linux=True): """ Uploads PBA file to island's file system :param request_: Request object containing PBA file :param is_linux: Boolean indicating if this file is for windows or for linux :return: filename string """ - filename = secure_filename(request_.files["filepond"].filename) - file_path = ( - Path(PostBreachFilesService.get_custom_pba_directory()).joinpath(filename).absolute() - ) - request_.files["filepond"].save(str(file_path)) + filename = secure_filename(file_storage.filename) + file_contents = file_storage.read() + + PostBreachFilesService.save_file(filename, file_contents) + ConfigService.set_config_value( (PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename ) + return filename @jwt_required diff --git a/monkey/monkey_island/cc/services/post_breach_files.py b/monkey/monkey_island/cc/services/post_breach_files.py index 06d2ffe48..94569db37 100644 --- a/monkey/monkey_island/cc/services/post_breach_files.py +++ b/monkey/monkey_island/cc/services/post_breach_files.py @@ -17,6 +17,12 @@ class PostBreachFilesService: cls.DATA_DIR = data_dir Path(cls.get_custom_pba_directory()).mkdir(mode=0o0700, parents=True, exist_ok=True) + @staticmethod + def save_file(filename: str, file_contents: bytes): + file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename) + with open(file_path, "wb") as f: + f.write(file_contents) + @staticmethod def remove_PBA_files(): for f in os.listdir(PostBreachFilesService.get_custom_pba_directory()): diff --git a/monkey/tests/monkey_island/cc/services/test_post_breach_files.py b/monkey/tests/monkey_island/cc/services/test_post_breach_files.py index cc1c80e1f..3c3fe82fe 100644 --- a/monkey/tests/monkey_island/cc/services/test_post_breach_files.py +++ b/monkey/tests/monkey_island/cc/services/test_post_breach_files.py @@ -15,10 +15,7 @@ def custom_pba_directory(tmpdir): def create_custom_pba_file(filename): - assert os.path.isdir(PostBreachFilesService.get_custom_pba_directory()) - - file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename) - open(file_path, "a").close() + PostBreachFilesService.save_file(filename, b"") def test_remove_pba_files(): @@ -59,3 +56,14 @@ def test_remove_nonexistant_file(monkeypatch): PostBreachFilesService.remove_file("/nonexistant/file") except Exception as ex: pytest.fail(f"Unxepected exception: {ex}") + + +def test_save_file(): + FILE_NAME = "test_file" + FILE_CONTENTS = b"hello" + PostBreachFilesService.save_file(FILE_NAME, FILE_CONTENTS) + + expected_file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), FILE_NAME) + + assert os.path.isfile(expected_file_path) + assert FILE_CONTENTS == open(expected_file_path, "rb").read()