island: Move the specifics of saving pba files to pba service

This commit is contained in:
Mike Salvatore 2021-05-12 09:49:58 -04:00
parent 253588b3ac
commit 79eb7442ae
3 changed files with 29 additions and 12 deletions

View File

@ -1,9 +1,9 @@
import copy
import logging
from pathlib import Path
import flask_restful
from flask import Response, request, send_from_directory
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename
from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH
@ -45,27 +45,30 @@ class FileUpload(flask_restful.Resource):
:param file_type: Type indicates which file was received, linux or windows
:return: Returns flask response object with uploaded file's filename
"""
filename = FileUpload.upload_pba_file(request, (file_type == LINUX_PBA_TYPE))
filename = FileUpload.upload_pba_file(
request.files["filepond"], (file_type == LINUX_PBA_TYPE)
)
response = Response(response=filename, status=200, mimetype="text/plain")
return response
@staticmethod
def upload_pba_file(request_, is_linux=True):
def upload_pba_file(file_storage: FileStorage, is_linux=True):
"""
Uploads PBA file to island's file system
:param request_: Request object containing PBA file
:param is_linux: Boolean indicating if this file is for windows or for linux
:return: filename string
"""
filename = secure_filename(request_.files["filepond"].filename)
file_path = (
Path(PostBreachFilesService.get_custom_pba_directory()).joinpath(filename).absolute()
)
request_.files["filepond"].save(str(file_path))
filename = secure_filename(file_storage.filename)
file_contents = file_storage.read()
PostBreachFilesService.save_file(filename, file_contents)
ConfigService.set_config_value(
(PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename
)
return filename
@jwt_required

View File

@ -17,6 +17,12 @@ class PostBreachFilesService:
cls.DATA_DIR = data_dir
Path(cls.get_custom_pba_directory()).mkdir(mode=0o0700, parents=True, exist_ok=True)
@staticmethod
def save_file(filename: str, file_contents: bytes):
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename)
with open(file_path, "wb") as f:
f.write(file_contents)
@staticmethod
def remove_PBA_files():
for f in os.listdir(PostBreachFilesService.get_custom_pba_directory()):

View File

@ -15,10 +15,7 @@ def custom_pba_directory(tmpdir):
def create_custom_pba_file(filename):
assert os.path.isdir(PostBreachFilesService.get_custom_pba_directory())
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename)
open(file_path, "a").close()
PostBreachFilesService.save_file(filename, b"")
def test_remove_pba_files():
@ -59,3 +56,14 @@ def test_remove_nonexistant_file(monkeypatch):
PostBreachFilesService.remove_file("/nonexistant/file")
except Exception as ex:
pytest.fail(f"Unxepected exception: {ex}")
def test_save_file():
FILE_NAME = "test_file"
FILE_CONTENTS = b"hello"
PostBreachFilesService.save_file(FILE_NAME, FILE_CONTENTS)
expected_file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), FILE_NAME)
assert os.path.isfile(expected_file_path)
assert FILE_CONTENTS == open(expected_file_path, "rb").read()