forked from p15670423/monkey
Merge pull request #1915 from guardicore/1904-service-resource-dependency-injection
1904 service resource dependency injection
This commit is contained in:
commit
97300376ef
|
@ -101,3 +101,6 @@ venv/
|
|||
|
||||
# Hugo
|
||||
.hugo_build.lock
|
||||
|
||||
# mypy
|
||||
.mypy_cache
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .di_container import DIContainer
|
|
@ -0,0 +1,133 @@
|
|||
import inspect
|
||||
from typing import Any, MutableMapping, Sequence, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UnregisteredTypeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DIContainer:
|
||||
"""
|
||||
A dependency injection (DI) container that uses type annotations to resolve and inject
|
||||
dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._type_registry = {}
|
||||
self._instance_registry = {}
|
||||
|
||||
def register(self, interface: Type[T], concrete_type: Type[T]):
|
||||
"""
|
||||
Register a concrete type that satisfies a given interface.
|
||||
|
||||
:param interface: An interface or abstract base class that other classes depend upon
|
||||
:param concrete_type: A type (class) that implements `interface`
|
||||
"""
|
||||
if not inspect.isclass(concrete_type):
|
||||
raise TypeError(
|
||||
"Expected a class, but received an instance of type "
|
||||
f'"{concrete_type.__class__.__name__}"; Pass a class, not an instance, to '
|
||||
"register(), or use register_instance() instead"
|
||||
)
|
||||
|
||||
if not issubclass(concrete_type, interface):
|
||||
raise TypeError(
|
||||
f'Class "{concrete_type.__name__}" is not a subclass of {interface.__name__}'
|
||||
)
|
||||
|
||||
self._type_registry[interface] = concrete_type
|
||||
DIContainer._del_key(self._instance_registry, interface)
|
||||
|
||||
def register_instance(self, interface: Type[T], instance: T):
|
||||
"""
|
||||
Register a concrete instance that satisfies a given interface.
|
||||
|
||||
:param interface: An interface or abstract base class that other classes depend upon
|
||||
:param instance: An instance (object) of a type that implements `interface`
|
||||
"""
|
||||
if not isinstance(instance, interface):
|
||||
raise TypeError(
|
||||
f'The provided instance of type "{instance.__class__.__name__}" '
|
||||
f"is not an instance of {interface.__name__}"
|
||||
)
|
||||
|
||||
self._instance_registry[interface] = instance
|
||||
DIContainer._del_key(self._type_registry, interface)
|
||||
|
||||
def resolve(self, type_: Type[T]) -> T:
|
||||
"""
|
||||
Resolves all dependencies and returns a new instance of `type_` using constructor dependency
|
||||
injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and
|
||||
default values are ignored.
|
||||
|
||||
:param type_: A type (class) to construct
|
||||
:return: An instance of `type_`
|
||||
"""
|
||||
try:
|
||||
return self._resolve_type(type_)
|
||||
except UnregisteredTypeError:
|
||||
pass
|
||||
|
||||
args = self.resolve_dependencies(type_)
|
||||
return type_(*args)
|
||||
|
||||
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
|
||||
"""
|
||||
Resolves all dependencies of type_ and returns a Sequence of objects that correspond type_'s
|
||||
dependencies. Note that only positional arguments are resolved. Varargs, keyword-only args,
|
||||
and default values are ignored.
|
||||
|
||||
:param type_: A type (class) to resolve dependencies for
|
||||
:return: An Sequence of dependencies to be injected into type_'s constructor
|
||||
"""
|
||||
args = []
|
||||
|
||||
for arg_type in inspect.getfullargspec(type_).annotations.values():
|
||||
instance = self._resolve_type(arg_type)
|
||||
args.append(instance)
|
||||
|
||||
return tuple(args)
|
||||
|
||||
def _resolve_type(self, type_: Type[T]) -> T:
|
||||
if type_ in self._type_registry:
|
||||
return self._construct_new_instance(type_)
|
||||
elif type_ in self._instance_registry:
|
||||
return self._retrieve_registered_instance(type_)
|
||||
|
||||
raise UnregisteredTypeError(f'Failed to resolve unregistered type "{type_.__name__}"')
|
||||
|
||||
def _construct_new_instance(self, arg_type: Type[T]) -> T:
|
||||
try:
|
||||
return self._type_registry[arg_type]()
|
||||
except TypeError:
|
||||
# arg_type has dependencies that must be resolved. Recursively call resolve() to
|
||||
# construct an instance of arg_type with all of the requesite dependencies injected.
|
||||
return self.resolve(self._type_registry[arg_type])
|
||||
|
||||
def _retrieve_registered_instance(self, arg_type: Type[T]) -> T:
|
||||
return self._instance_registry[arg_type]
|
||||
|
||||
def release(self, interface: Type[T]):
|
||||
"""
|
||||
Deregister's an interface
|
||||
|
||||
:param interface: The interface to release
|
||||
"""
|
||||
DIContainer._del_key(self._type_registry, interface)
|
||||
DIContainer._del_key(self._instance_registry, interface)
|
||||
|
||||
@staticmethod
|
||||
def _del_key(mapping: MutableMapping[T, Any], key: T):
|
||||
"""
|
||||
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
|
||||
if the key does not exist.
|
||||
|
||||
:param MutableMapping: A mapping from which a key will be deleted
|
||||
:param key: A key to delete from `mapping`
|
||||
"""
|
||||
try:
|
||||
del mapping[key]
|
||||
except KeyError:
|
||||
pass
|
|
@ -1,6 +1,7 @@
|
|||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class InvalidPath(Exception):
|
||||
|
@ -21,3 +22,7 @@ def get_file_sha256_hash(filepath: Path):
|
|||
sha256.update(block)
|
||||
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
||||
return filter(lambda f: f.is_file(), dir_path.iterdir())
|
||||
|
|
|
@ -2,10 +2,10 @@ import filecmp
|
|||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||
from infection_monkey.utils.dir_utils import (
|
||||
file_extension_filter,
|
||||
filter_files,
|
||||
get_all_regular_files_in_directory,
|
||||
is_not_shortcut_filter,
|
||||
is_not_symlink_filter,
|
||||
)
|
||||
|
|
|
@ -2,10 +2,6 @@ from pathlib import Path
|
|||
from typing import Callable, Iterable, Set
|
||||
|
||||
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
||||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
||||
|
||||
|
||||
def filter_files(
|
||||
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||
) -> Iterable[Path]:
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import os
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Type
|
||||
|
||||
import flask_restful
|
||||
from flask import Flask, Response, send_from_directory
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.database import database, mongo
|
||||
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
|
||||
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
||||
|
@ -111,7 +113,17 @@ def init_app_url_rules(app):
|
|||
app.add_url_rule("/<path:static_path>", "serve_static_file", serve_static_file)
|
||||
|
||||
|
||||
def init_api_resources(api):
|
||||
class FlaskDIWrapper:
|
||||
def __init__(self, api: flask_restful.Api, container: DIContainer):
|
||||
self._api = api
|
||||
self._container = container
|
||||
|
||||
def add_resource(self, resource: Type[flask_restful.Resource], *urls: str):
|
||||
dependencies = self._container.resolve_dependencies(resource)
|
||||
self._api.add_resource(resource, *urls, resource_class_args=dependencies)
|
||||
|
||||
|
||||
def init_api_resources(api: FlaskDIWrapper):
|
||||
api.add_resource(Root, "/api")
|
||||
api.add_resource(Registration, "/api/registration")
|
||||
api.add_resource(Authenticate, "/api/auth")
|
||||
|
@ -148,13 +160,18 @@ def init_api_resources(api):
|
|||
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
|
||||
api.add_resource(Log, "/api/log")
|
||||
api.add_resource(IslandLog, "/api/log/island/download")
|
||||
api.add_resource(PBAFileDownload, "/api/pba/download/<string:filename>")
|
||||
|
||||
api.add_resource(
|
||||
PBAFileDownload,
|
||||
"/api/pba/download/<string:filename>",
|
||||
)
|
||||
api.add_resource(
|
||||
FileUpload,
|
||||
"/api/file-upload/<string:file_type>",
|
||||
"/api/file-upload/<string:file_type>?load=<string:filename>",
|
||||
"/api/file-upload/<string:file_type>?restore=<string:filename>",
|
||||
"/api/file-upload/<string:target_os>",
|
||||
"/api/file-upload/<string:target_os>?load=<string:filename>",
|
||||
"/api/file-upload/<string:target_os>?restore=<string:filename>",
|
||||
)
|
||||
|
||||
api.add_resource(PropagationCredentials, "/api/propagation-credentials/<string:guid>")
|
||||
api.add_resource(RemoteRun, "/api/remote-monkey")
|
||||
api.add_resource(VersionUpdate, "/api/version-update")
|
||||
|
@ -168,7 +185,7 @@ def init_api_resources(api):
|
|||
api.add_resource(TelemetryBlackboxEndpoint, "/api/test/telemetry")
|
||||
|
||||
|
||||
def init_app(mongo_url):
|
||||
def init_app(mongo_url: str, container: DIContainer):
|
||||
app = Flask(__name__)
|
||||
|
||||
api = flask_restful.Api(app)
|
||||
|
@ -177,6 +194,8 @@ def init_app(mongo_url):
|
|||
init_app_config(app, mongo_url)
|
||||
init_app_services(app)
|
||||
init_app_url_rules(app)
|
||||
init_api_resources(api)
|
||||
|
||||
flask_resource_manager = FlaskDIWrapper(api, container)
|
||||
init_api_resources(flask_resource_manager)
|
||||
|
||||
return app
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
import flask_restful
|
||||
from flask import send_from_directory
|
||||
import logging
|
||||
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
import flask_restful
|
||||
from flask import make_response, send_file
|
||||
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class PBAFileDownload(flask_restful.Resource):
|
||||
|
@ -9,7 +13,17 @@ class PBAFileDownload(flask_restful.Resource):
|
|||
File download endpoint used by monkey to download user's PBA file
|
||||
"""
|
||||
|
||||
def __init__(self, file_storage_service: IFileStorageService):
|
||||
self._file_storage_service = file_storage_service
|
||||
|
||||
# Used by monkey. can't secure.
|
||||
def get(self, filename):
|
||||
custom_pba_dir = PostBreachFilesService.get_custom_pba_directory()
|
||||
return send_from_directory(custom_pba_dir, filename)
|
||||
def get(self, filename: str):
|
||||
try:
|
||||
file = self._file_storage_service.open_file(filename)
|
||||
|
||||
# `send_file()` handles the closing of the open file.
|
||||
return send_file(file, mimetype="application/octet-stream")
|
||||
except FileRetrievalError as err:
|
||||
error_msg = f"Failed to open file {filename}: {err}"
|
||||
logger.error(error_msg)
|
||||
return make_response({"error": error_msg}, 404)
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import copy
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
|
||||
import flask_restful
|
||||
from flask import Response, request, send_from_directory
|
||||
from flask import Response, make_response, request, send_file
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH
|
||||
from monkey_island.cc.resources.auth.auth import jwt_required
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
from monkey_island.cc.services.config import ConfigService
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
# Front end uses these strings to identify which files to work with (linux or windows)
|
||||
LINUX_PBA_TYPE = "PBAlinux"
|
||||
|
@ -21,42 +25,60 @@ class FileUpload(flask_restful.Resource):
|
|||
File upload endpoint used to exchange files with filepond component on the front-end
|
||||
"""
|
||||
|
||||
def __init__(self, file_storage_service: IFileStorageService):
|
||||
self._file_storage_service = file_storage_service
|
||||
|
||||
# TODO: Fix references/coupling to filepond
|
||||
# TODO: Add comment explaining why this is basically a duplicate of the endpoint in the
|
||||
# PBAFileDownload resource.
|
||||
@jwt_required
|
||||
def get(self, file_type):
|
||||
def get(self, target_os):
|
||||
"""
|
||||
Sends file to filepond
|
||||
:param file_type: Type indicates which file to send, linux or windows
|
||||
:param target_os: Indicates which file to send, linux or windows
|
||||
:return: Returns file contents
|
||||
"""
|
||||
if self.is_pba_file_type_supported(file_type):
|
||||
if self._is_target_os_supported(target_os):
|
||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||
|
||||
# Verify that file_name is indeed a file from config
|
||||
if file_type == LINUX_PBA_TYPE:
|
||||
if target_os == LINUX_PBA_TYPE:
|
||||
# TODO: Make these paths Tuples so we don't need to copy them
|
||||
filename = ConfigService.get_config_value(copy.deepcopy(PBA_LINUX_FILENAME_PATH))
|
||||
else:
|
||||
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
|
||||
return send_from_directory(PostBreachFilesService.get_custom_pba_directory(), filename)
|
||||
|
||||
try:
|
||||
file = self._file_storage_service.open_file(filename)
|
||||
|
||||
# `send_file()` handles the closing of the open file.
|
||||
return send_file(file, mimetype="application/octet-stream")
|
||||
except FileRetrievalError as err:
|
||||
error_msg = f"Failed to open file {filename}: {err}"
|
||||
logger.error(error_msg)
|
||||
return make_response({"error": error_msg}, 404)
|
||||
|
||||
@jwt_required
|
||||
def post(self, file_type):
|
||||
def post(self, target_os):
|
||||
"""
|
||||
Receives user's uploaded file from filepond
|
||||
:param file_type: Type indicates which file was received, linux or windows
|
||||
:param target_os: Type indicates which file was received, linux or windows
|
||||
:return: Returns flask response object with uploaded file's filename
|
||||
"""
|
||||
if self.is_pba_file_type_supported(file_type):
|
||||
if self._is_target_os_supported(target_os):
|
||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||
|
||||
filename = FileUpload.upload_pba_file(
|
||||
request.files["filepond"], (file_type == LINUX_PBA_TYPE)
|
||||
filename = self._upload_pba_file(
|
||||
# TODO: This "filepond" string can be changed to be more generic in the `react-filepond`
|
||||
# component.
|
||||
request.files["filepond"],
|
||||
(target_os == LINUX_PBA_TYPE),
|
||||
)
|
||||
|
||||
response = Response(response=filename, status=200, mimetype="text/plain")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def upload_pba_file(file_storage: FileStorage, is_linux=True):
|
||||
def _upload_pba_file(self, file_storage: FileStorage, is_linux=True):
|
||||
"""
|
||||
Uploads PBA file to island's file system
|
||||
:param request_: Request object containing PBA file
|
||||
|
@ -64,9 +86,7 @@ class FileUpload(flask_restful.Resource):
|
|||
:return: filename string
|
||||
"""
|
||||
filename = secure_filename(file_storage.filename)
|
||||
file_contents = file_storage.read()
|
||||
|
||||
PostBreachFilesService.save_file(filename, file_contents)
|
||||
self._file_storage_service.save_file(filename, file_storage.stream)
|
||||
|
||||
ConfigService.set_config_value(
|
||||
(PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename
|
||||
|
@ -75,25 +95,25 @@ class FileUpload(flask_restful.Resource):
|
|||
return filename
|
||||
|
||||
@jwt_required
|
||||
def delete(self, file_type):
|
||||
def delete(self, target_os):
|
||||
"""
|
||||
Deletes file that has been deleted on the front end
|
||||
:param file_type: Type indicates which file was deleted, linux of windows
|
||||
:param target_os: Type indicates which file was deleted, linux of windows
|
||||
:return: Empty response
|
||||
"""
|
||||
if self.is_pba_file_type_supported(file_type):
|
||||
if self._is_target_os_supported(target_os):
|
||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||
|
||||
filename_path = (
|
||||
PBA_LINUX_FILENAME_PATH if file_type == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
|
||||
PBA_LINUX_FILENAME_PATH if target_os == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
|
||||
)
|
||||
filename = ConfigService.get_config_value(filename_path)
|
||||
if filename:
|
||||
PostBreachFilesService.remove_file(filename)
|
||||
self._file_storage_service.delete_file(filename)
|
||||
ConfigService.set_config_value(filename_path, "")
|
||||
|
||||
return {}
|
||||
return make_response({}, 200)
|
||||
|
||||
@staticmethod
|
||||
def is_pba_file_type_supported(file_type: str) -> bool:
|
||||
return file_type not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}
|
||||
def _is_target_os_supported(target_os: str) -> bool:
|
||||
return target_os not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}
|
||||
|
|
|
@ -16,6 +16,7 @@ MONKEY_ISLAND_DIR_BASE_PATH = str(Path(__file__).parent.parent)
|
|||
if str(MONKEY_ISLAND_DIR_BASE_PATH) not in sys.path:
|
||||
sys.path.insert(0, MONKEY_ISLAND_DIR_BASE_PATH)
|
||||
|
||||
from common import DIContainer # noqa: E402
|
||||
from common.version import get_version # noqa: E402
|
||||
from monkey_island.cc.app import init_app # noqa: E402
|
||||
from monkey_island.cc.arg_parser import IslandCmdArgs # noqa: E402
|
||||
|
@ -47,7 +48,7 @@ def run_monkey_island():
|
|||
_exit_on_invalid_config_options(config_options)
|
||||
|
||||
_configure_logging(config_options)
|
||||
_initialize_globals(config_options.data_dir)
|
||||
container = _initialize_di_container(config_options.data_dir)
|
||||
|
||||
mongo_db_process = None
|
||||
if config_options.start_mongodb:
|
||||
|
@ -56,7 +57,7 @@ def run_monkey_island():
|
|||
_connect_to_mongodb(mongo_db_process)
|
||||
|
||||
_configure_gevent_exception_handling(config_options.data_dir)
|
||||
_start_island_server(island_args.setup_only, config_options)
|
||||
_start_island_server(island_args.setup_only, config_options, container)
|
||||
|
||||
|
||||
def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions:
|
||||
|
@ -88,8 +89,8 @@ def _configure_logging(config_options):
|
|||
setup_logging(config_options.data_dir, config_options.log_level)
|
||||
|
||||
|
||||
def _initialize_globals(data_dir: Path):
|
||||
initialize_services(data_dir)
|
||||
def _initialize_di_container(data_dir: Path) -> DIContainer:
|
||||
return initialize_services(data_dir)
|
||||
|
||||
|
||||
def _start_mongodb(data_dir: Path) -> MongoDbProcess:
|
||||
|
@ -127,9 +128,11 @@ def _configure_gevent_exception_handling(data_dir):
|
|||
hub.handle_error = GeventHubErrorHandler(hub, logger)
|
||||
|
||||
|
||||
def _start_island_server(should_setup_only, config_options: IslandConfigOptions):
|
||||
def _start_island_server(
|
||||
should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer
|
||||
):
|
||||
populate_exporter_list()
|
||||
app = init_app(mongo_setup.MONGO_URL)
|
||||
app = init_app(mongo_setup.MONGO_URL, container)
|
||||
|
||||
if should_setup_only:
|
||||
logger.warning("Setup only flag passed. Exiting.")
|
||||
|
|
|
@ -1,2 +1,5 @@
|
|||
from .i_file_storage_service import IFileStorageService, FileRetrievalError
|
||||
from .directory_file_storage_service import DirectoryFileStorageService
|
||||
|
||||
from .authentication.authentication_service import AuthenticationService
|
||||
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from pathlib import Path
|
||||
|
||||
import bcrypt
|
||||
|
||||
from common.utils.exceptions import (
|
||||
|
@ -23,7 +25,7 @@ class AuthenticationService:
|
|||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||
# not a priority.
|
||||
@classmethod
|
||||
def initialize(cls, data_dir: str, user_datastore: IUserDatastore):
|
||||
def initialize(cls, data_dir: Path, user_datastore: IUserDatastore):
|
||||
cls.DATA_DIR = data_dir
|
||||
cls.user_datastore = user_datastore
|
||||
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
|
||||
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||
from monkey_island.cc.server_utils.file_utils import create_secure_directory
|
||||
|
||||
from . import FileRetrievalError, IFileStorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DirectoryFileStorageService(IFileStorageService):
|
||||
"""
|
||||
A implementation of IFileStorageService that reads and writes files from/to the local
|
||||
filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_directory: Path):
|
||||
"""
|
||||
:param storage_directory: A Path object representing the directory where files will be
|
||||
stored. If the directory does not exist, it will be created.
|
||||
"""
|
||||
if storage_directory.exists() and not storage_directory.is_dir():
|
||||
raise ValueError(f"The provided path must point to a directory: {storage_directory}")
|
||||
|
||||
if not storage_directory.exists():
|
||||
create_secure_directory(storage_directory)
|
||||
|
||||
self._storage_directory = storage_directory
|
||||
|
||||
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||
|
||||
logger.debug(f"Saving file contents to {safe_file_path}")
|
||||
with open(safe_file_path, "wb") as dest:
|
||||
shutil.copyfileobj(file_contents, dest)
|
||||
|
||||
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||
|
||||
try:
|
||||
logger.debug(f"Opening {safe_file_path}")
|
||||
return open(safe_file_path, "rb")
|
||||
except OSError as err:
|
||||
logger.error(err)
|
||||
raise FileRetrievalError(f"Failed to retrieve file {safe_file_path}: {err}") from err
|
||||
|
||||
def delete_file(self, unsafe_file_name: str):
|
||||
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||
|
||||
try:
|
||||
logger.debug(f"Deleting {safe_file_path}")
|
||||
safe_file_path.unlink()
|
||||
except FileNotFoundError:
|
||||
# This method is idempotent.
|
||||
pass
|
||||
|
||||
def _get_safe_file_path(self, unsafe_file_name: str):
|
||||
# Remove any path information from the file name.
|
||||
safe_file_name = Path(unsafe_file_name).resolve().name
|
||||
safe_file_path = (self._storage_directory / safe_file_name).resolve()
|
||||
|
||||
# This is a paranoid check to avoid directory traversal attacks.
|
||||
if self._storage_directory.resolve() not in safe_file_path.parents:
|
||||
raise ValueError(f"The file named {unsafe_file_name} can not be safely retrieved")
|
||||
|
||||
logger.debug(f"Unsafe file name {unsafe_file_name} sanitized: {safe_file_path}")
|
||||
return safe_file_path
|
||||
|
||||
def delete_all_files(self):
|
||||
for file in get_all_regular_files_in_directory(self._storage_directory):
|
||||
logger.debug(f"Deleting {file}")
|
||||
file.unlink()
|
|
@ -0,0 +1,53 @@
|
|||
import abc
|
||||
from typing import BinaryIO
|
||||
|
||||
|
||||
class FileRetrievalError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class IFileStorageService(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
A service that allows the storage and retrieval of individual files.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||
"""
|
||||
Save a file, identified by a name
|
||||
|
||||
:param unsafe_file_name: An unsanitized file name that will identify the file
|
||||
:param file_contents: The data to be stored in the file
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||
"""
|
||||
Open a file and return a file-like object
|
||||
|
||||
:param unsafe_file_name: An unsanitized file name that identifies the file to be opened
|
||||
:return: A file-like object providing access to the file's contents
|
||||
:rtype: io.BinaryIO
|
||||
:raises FileRetrievalError: if the file cannot be opened
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_file(self, unsafe_file_name: str):
|
||||
"""
|
||||
Delete a file
|
||||
|
||||
This method will delete the file specified by `unsafe_file_name`. This operation is
|
||||
idempotent and will succeed if the file to be deleted does not exist.
|
||||
|
||||
:param unsafe_file_name: An unsanitized file name that identifies the file to be deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_all_files(self):
|
||||
"""
|
||||
Delete all files that have been stored using this service.
|
||||
"""
|
||||
pass
|
|
@ -1,10 +1,22 @@
|
|||
from pathlib import Path
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||
|
||||
from . import AuthenticationService, JsonFileUserDatastore
|
||||
|
||||
|
||||
def initialize_services(data_dir):
|
||||
PostBreachFilesService.initialize(data_dir)
|
||||
def initialize_services(data_dir: Path) -> DIContainer:
|
||||
container = DIContainer()
|
||||
container.register_instance(
|
||||
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
|
||||
)
|
||||
|
||||
# This is temporary until we get DI all worked out.
|
||||
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
|
||||
LocalMonkeyRunService.initialize(data_dir)
|
||||
AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir))
|
||||
|
||||
return container
|
||||
|
|
|
@ -1,46 +1,24 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from monkey_island.cc.server_utils.file_utils import create_secure_directory
|
||||
from monkey_island.cc.services import IFileStorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: This service wraps an IFileStorageService for the sole purpose of making the
|
||||
# `remove_PBA_files()` method available to the ConfigService. This whole service can be
|
||||
# removed once ConfigService is refactored to be stateful (it already is but everything is
|
||||
# still statically/globally scoped) and use dependency injection.
|
||||
class PostBreachFilesService:
|
||||
DATA_DIR = None
|
||||
CUSTOM_PBA_DIRNAME = "custom_pbas"
|
||||
_file_storage_service = None
|
||||
|
||||
# TODO: A number of these services should be instance objects instead of
|
||||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||
# not a priority.
|
||||
@classmethod
|
||||
def initialize(cls, data_dir):
|
||||
cls.DATA_DIR = data_dir
|
||||
custom_pba_dir = cls.get_custom_pba_directory()
|
||||
create_secure_directory(custom_pba_dir)
|
||||
def initialize(cls, file_storage_service: IFileStorageService):
|
||||
cls._file_storage_service = file_storage_service
|
||||
|
||||
@staticmethod
|
||||
def save_file(filename: str, file_contents: bytes):
|
||||
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
@staticmethod
|
||||
def remove_PBA_files():
|
||||
for f in os.listdir(PostBreachFilesService.get_custom_pba_directory()):
|
||||
PostBreachFilesService.remove_file(f)
|
||||
|
||||
@staticmethod
|
||||
def remove_file(file_name):
|
||||
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), file_name)
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
except OSError as e:
|
||||
logger.error("Can't remove previously uploaded post breach files: %s" % e)
|
||||
|
||||
@staticmethod
|
||||
def get_custom_pba_directory():
|
||||
return os.path.join(
|
||||
PostBreachFilesService.DATA_DIR, PostBreachFilesService.CUSTOM_PBA_DIRNAME
|
||||
)
|
||||
@classmethod
|
||||
def remove_PBA_files(cls):
|
||||
cls._file_storage_service.delete_all_files()
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import platform
|
||||
import stat
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
from monkey_island.cc.resources.monkey_download import get_agent_executable_path
|
||||
|
@ -19,7 +20,7 @@ class LocalMonkeyRunService:
|
|||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||
# not a priority.
|
||||
@classmethod
|
||||
def initialize(cls, data_dir):
|
||||
def initialize(cls, data_dir: Path):
|
||||
cls.DATA_DIR = data_dir
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -7,6 +7,9 @@ if is_windows_os():
|
|||
FULL_CONTROL = 2032127
|
||||
ACE_ACCESS_MODE_GRANT_ACCESS = win32security.GRANT_ACCESS
|
||||
ACE_INHERIT_OBJECT_AND_CONTAINER = 3
|
||||
else:
|
||||
import os
|
||||
import stat
|
||||
|
||||
|
||||
def _get_acl_and_sid_from_path(path: str):
|
||||
|
@ -33,3 +36,12 @@ def assert_windows_permissions(path: str):
|
|||
assert ace_sid == user_sid
|
||||
assert ace_permissions == FULL_CONTROL and ace_access_mode == ACE_ACCESS_MODE_GRANT_ACCESS
|
||||
assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER
|
||||
|
||||
|
||||
def assert_linux_permissions(path: str):
|
||||
st = os.stat(path)
|
||||
|
||||
expected_mode = stat.S_IRWXU
|
||||
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
|
||||
assert expected_mode == actual_mode
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
import abc
|
||||
|
||||
import pytest
|
||||
|
||||
from common import DIContainer
|
||||
|
||||
|
||||
class IServiceA(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def do_something(self):
|
||||
pass
|
||||
|
||||
|
||||
class IServiceB(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceA(IServiceA):
|
||||
def do_something(self):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceB(IServiceB):
|
||||
pass
|
||||
|
||||
|
||||
class TestClass1:
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, service_a: IServiceA):
|
||||
self.service_a = service_a
|
||||
|
||||
|
||||
class TestClass2:
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, service_b: IServiceB):
|
||||
self.service_b = service_b
|
||||
|
||||
|
||||
class TestClass3:
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, service_a: IServiceA, service_b: IServiceB):
|
||||
self.service_a = service_a
|
||||
self.service_b = service_b
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def container():
|
||||
return DIContainer()
|
||||
|
||||
|
||||
def test_register_resolve(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert isinstance(test_1.service_a, ServiceA)
|
||||
|
||||
|
||||
def test_correct_instance_type_injected(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register(IServiceB, ServiceB)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
test_2 = container.resolve(TestClass2)
|
||||
|
||||
assert isinstance(test_1.service_a, ServiceA)
|
||||
assert isinstance(test_2.service_b, ServiceB)
|
||||
|
||||
|
||||
def test_multiple_correct_instance_types_injected(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register(IServiceB, ServiceB)
|
||||
test_3 = container.resolve(TestClass3)
|
||||
|
||||
assert isinstance(test_3.service_a, ServiceA)
|
||||
assert isinstance(test_3.service_b, ServiceB)
|
||||
|
||||
|
||||
def test_register_instance(container):
|
||||
service_a_instance = ServiceA()
|
||||
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert id(service_a_instance) == id(test_1.service_a)
|
||||
|
||||
|
||||
def test_register_multiple_instances(container):
|
||||
service_a_instance = ServiceA()
|
||||
service_b_instance = ServiceB()
|
||||
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
container.register_instance(IServiceB, service_b_instance)
|
||||
test_3 = container.resolve(TestClass3)
|
||||
|
||||
assert id(service_a_instance) == id(test_3.service_a)
|
||||
assert id(service_b_instance) == id(test_3.service_b)
|
||||
|
||||
|
||||
def test_register_mixed_instance_and_type(container):
|
||||
service_a_instance = ServiceA()
|
||||
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
container.register(IServiceB, ServiceB)
|
||||
test_2 = container.resolve(TestClass2)
|
||||
test_3 = container.resolve(TestClass3)
|
||||
|
||||
assert id(service_a_instance) == id(test_3.service_a)
|
||||
assert isinstance(test_2.service_b, ServiceB)
|
||||
assert isinstance(test_3.service_b, ServiceB)
|
||||
assert id(test_2.service_b) != id(test_3.service_b)
|
||||
|
||||
|
||||
def test_unregistered_type():
|
||||
container = DIContainer()
|
||||
with pytest.raises(ValueError):
|
||||
container.resolve(TestClass1)
|
||||
|
||||
|
||||
def test_type_registration_overwritten(container):
|
||||
class ServiceA2(IServiceA):
|
||||
def do_something(self):
|
||||
pass
|
||||
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register(IServiceA, ServiceA2)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert isinstance(test_1.service_a, ServiceA2)
|
||||
|
||||
|
||||
def test_instance_registration_overwritten(container):
|
||||
service_a_instance_1 = ServiceA()
|
||||
service_a_instance_2 = ServiceA()
|
||||
|
||||
container.register_instance(IServiceA, service_a_instance_1)
|
||||
container.register_instance(IServiceA, service_a_instance_2)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert id(test_1.service_a) != id(service_a_instance_1)
|
||||
assert id(test_1.service_a) == id(service_a_instance_2)
|
||||
|
||||
|
||||
def test_type_overrides_instance(container):
|
||||
service_a_instance = ServiceA()
|
||||
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
container.register(IServiceA, ServiceA)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert id(test_1.service_a) != id(service_a_instance)
|
||||
assert isinstance(test_1.service_a, ServiceA)
|
||||
|
||||
|
||||
def test_instance_overrides_type(container):
|
||||
service_a_instance = ServiceA()
|
||||
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
test_1 = container.resolve(TestClass1)
|
||||
|
||||
assert id(test_1.service_a) == id(service_a_instance)
|
||||
|
||||
|
||||
def test_release_type(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.release(IServiceA)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
container.resolve(TestClass1)
|
||||
|
||||
|
||||
def test_release_instance(container):
|
||||
service_a_instance = ServiceA()
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
|
||||
container.release(IServiceA)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
container.resolve(TestClass1)
|
||||
|
||||
|
||||
class IServiceC(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceC(IServiceC):
|
||||
def __init__(self, service_a: IServiceA):
|
||||
self.service_a = service_a
|
||||
|
||||
|
||||
class TestClass4:
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, service_c: IServiceC):
|
||||
self.service_c = service_c
|
||||
|
||||
|
||||
def test_recursive_resolution__depth_2(container):
|
||||
service_a_instance = ServiceA()
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
container.register(IServiceC, ServiceC)
|
||||
|
||||
test4 = container.resolve(TestClass4)
|
||||
|
||||
assert isinstance(test4.service_c, ServiceC)
|
||||
assert id(test4.service_c.service_a) == id(service_a_instance)
|
||||
|
||||
|
||||
class IServiceD(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceD(IServiceD):
|
||||
def __init__(self, service_c: IServiceC, service_b: IServiceB):
|
||||
self.service_b = service_b
|
||||
self.service_c = service_c
|
||||
|
||||
|
||||
class TestClass5:
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, service_d: IServiceD):
|
||||
self.service_d = service_d
|
||||
|
||||
|
||||
def test_recursive_resolution__depth_3(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register(IServiceB, ServiceB)
|
||||
container.register(IServiceC, ServiceC)
|
||||
container.register(IServiceD, ServiceD)
|
||||
|
||||
test5 = container.resolve(TestClass5)
|
||||
|
||||
assert isinstance(test5.service_d, ServiceD)
|
||||
assert isinstance(test5.service_d.service_b, ServiceB)
|
||||
assert isinstance(test5.service_d.service_c, ServiceC)
|
||||
assert isinstance(test5.service_d.service_c.service_a, ServiceA)
|
||||
|
||||
|
||||
def test_resolve_registered_interface(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
|
||||
resolved_instance = container.resolve(IServiceA)
|
||||
|
||||
assert isinstance(resolved_instance, ServiceA)
|
||||
|
||||
|
||||
def test_resolve_registered_instance(container):
|
||||
service_a_instance = ServiceA()
|
||||
container.register_instance(IServiceA, service_a_instance)
|
||||
|
||||
service_a_actual_instance = container.resolve(IServiceA)
|
||||
|
||||
assert id(service_a_actual_instance) == id(service_a_instance)
|
||||
|
||||
|
||||
def test_resolve_dependencies(container):
|
||||
container.register(IServiceA, ServiceA)
|
||||
container.register(IServiceB, ServiceB)
|
||||
|
||||
dependencies = container.resolve_dependencies(TestClass3)
|
||||
|
||||
assert isinstance(dependencies[0], ServiceA)
|
||||
assert isinstance(dependencies[1], ServiceB)
|
||||
|
||||
|
||||
def test_register_instance_as_type(container):
|
||||
service_a_instance = ServiceA()
|
||||
with pytest.raises(TypeError):
|
||||
container.register(IServiceA, service_a_instance)
|
||||
|
||||
|
||||
def test_register_conflicting_type(container):
|
||||
with pytest.raises(TypeError):
|
||||
container.register(IServiceA, ServiceB)
|
||||
|
||||
|
||||
def test_register_instance_with_conflicting_type(container):
|
||||
service_b_instance = ServiceB()
|
||||
with pytest.raises(TypeError):
|
||||
container.register_instance(IServiceA, service_b_instance)
|
|
@ -1,8 +1,14 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from tests.utils import add_files_to_dir, add_subdirs_to_dir
|
||||
|
||||
from common.utils.file_utils import InvalidPath, expand_path, get_file_sha256_hash
|
||||
from common.utils.file_utils import (
|
||||
InvalidPath,
|
||||
expand_path,
|
||||
get_all_regular_files_in_directory,
|
||||
get_file_sha256_hash,
|
||||
)
|
||||
|
||||
|
||||
def test_expand_user(patched_home_env):
|
||||
|
@ -26,3 +32,32 @@ def test_expand_path__empty_path_provided():
|
|||
|
||||
def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash):
|
||||
assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash
|
||||
|
||||
|
||||
SUBDIRS = ["subdir1", "subdir2"]
|
||||
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg"]
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
||||
add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||
|
||||
expected_return_value = []
|
||||
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
||||
add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
|
||||
expected_return_value = sorted(files)
|
||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
|
||||
subdirs = add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||
add_files_to_dir(subdirs[0], FILES)
|
||||
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
|
||||
expected_return_value = sorted(files)
|
||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
|
|
@ -1,66 +1,22 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from tests.utils import is_user_admin
|
||||
from tests.utils import add_files_to_dir, is_user_admin
|
||||
|
||||
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||
from infection_monkey.utils.dir_utils import (
|
||||
file_extension_filter,
|
||||
filter_files,
|
||||
get_all_regular_files_in_directory,
|
||||
is_not_shortcut_filter,
|
||||
is_not_symlink_filter,
|
||||
)
|
||||
|
||||
SHORTCUT = "shortcut.lnk"
|
||||
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT]
|
||||
SUBDIRS = ["subdir1", "subdir2"]
|
||||
|
||||
|
||||
def add_subdirs_to_dir(parent_dir):
|
||||
subdirs = [parent_dir / s for s in SUBDIRS]
|
||||
|
||||
for subdir in subdirs:
|
||||
subdir.mkdir()
|
||||
|
||||
return subdirs
|
||||
|
||||
|
||||
def add_files_to_dir(parent_dir):
|
||||
files = [parent_dir / f for f in FILES]
|
||||
|
||||
for f in files:
|
||||
f.touch()
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
||||
add_subdirs_to_dir(tmp_path)
|
||||
|
||||
expected_return_value = []
|
||||
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
||||
add_subdirs_to_dir(tmp_path)
|
||||
files = add_files_to_dir(tmp_path)
|
||||
|
||||
expected_return_value = sorted(files)
|
||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
|
||||
subdirs = add_subdirs_to_dir(tmp_path)
|
||||
add_files_to_dir(subdirs[0])
|
||||
|
||||
files = add_files_to_dir(tmp_path)
|
||||
|
||||
expected_return_value = sorted(files)
|
||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_filter_files__no_results(tmp_path):
|
||||
add_files_to_dir(tmp_path)
|
||||
add_files_to_dir(tmp_path, FILES)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
|
||||
|
@ -69,7 +25,7 @@ def test_filter_files__no_results(tmp_path):
|
|||
|
||||
|
||||
def test_filter_files__all_true(tmp_path):
|
||||
files = add_files_to_dir(tmp_path)
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
expected_return_value = sorted(files)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
|
@ -79,7 +35,7 @@ def test_filter_files__all_true(tmp_path):
|
|||
|
||||
|
||||
def test_filter_files__multiple_filters(tmp_path):
|
||||
files = add_files_to_dir(tmp_path)
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
expected_return_value = sorted(files[4:6])
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
|
@ -93,7 +49,7 @@ def test_filter_files__multiple_filters(tmp_path):
|
|||
def test_file_extension_filter(tmp_path):
|
||||
valid_extensions = {".zip", ".xyz"}
|
||||
|
||||
files = add_files_to_dir(tmp_path)
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)])
|
||||
|
@ -105,7 +61,7 @@ def test_file_extension_filter(tmp_path):
|
|||
os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows"
|
||||
)
|
||||
def test_is_not_symlink_filter(tmp_path):
|
||||
files = add_files_to_dir(tmp_path)
|
||||
files = add_files_to_dir(tmp_path, FILES)
|
||||
link_path = tmp_path / "symlink.test"
|
||||
link_path.symlink_to(files[0], target_is_directory=False)
|
||||
|
||||
|
@ -118,7 +74,7 @@ def test_is_not_symlink_filter(tmp_path):
|
|||
|
||||
|
||||
def test_is_not_shortcut_filter(tmp_path):
|
||||
add_files_to_dir(tmp_path)
|
||||
add_files_to_dir(tmp_path, FILES)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import flask_jwt_extended
|
||||
import flask_restful
|
||||
import pytest
|
||||
|
@ -9,15 +11,28 @@ import monkey_island.cc.resources.island_mode
|
|||
from monkey_island.cc.services.representations import output_json
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture
|
||||
def flask_client(monkeypatch_session):
|
||||
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
|
||||
|
||||
with mock_init_app().test_client() as client:
|
||||
container = MagicMock()
|
||||
container.resolve_dependencies.return_value = []
|
||||
|
||||
with mock_init_app(container).test_client() as client:
|
||||
yield client
|
||||
|
||||
|
||||
def mock_init_app():
|
||||
@pytest.fixture
|
||||
def build_flask_client(monkeypatch_session):
|
||||
def inner(container):
|
||||
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
|
||||
|
||||
return mock_init_app(container).test_client()
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def mock_init_app(container):
|
||||
app = Flask(__name__)
|
||||
app.config["SECRET_KEY"] = "test_key"
|
||||
|
||||
|
@ -25,7 +40,8 @@ def mock_init_app():
|
|||
api.representations = {"application/json": output_json}
|
||||
|
||||
monkey_island.cc.app.init_app_url_rules(app)
|
||||
monkey_island.cc.app.init_api_resources(api)
|
||||
flask_resource_manager = monkey_island.cc.app.FlaskDIWrapper(api, container)
|
||||
monkey_island.cc.app.init_api_resources(flask_resource_manager)
|
||||
|
||||
flask_jwt_extended.JWTManager(app)
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import io
|
||||
from typing import BinaryIO
|
||||
|
||||
import pytest
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
|
||||
FILE_NAME = "test_file"
|
||||
FILE_CONTENTS = b"HelloWorld!"
|
||||
|
||||
|
||||
class MockFileStorageService(IFileStorageService):
|
||||
def __init__(self):
|
||||
self._file = io.BytesIO(FILE_CONTENTS)
|
||||
|
||||
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||
pass
|
||||
|
||||
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||
if unsafe_file_name != FILE_NAME:
|
||||
raise FileRetrievalError()
|
||||
|
||||
return self._file
|
||||
|
||||
def delete_file(self, unsafe_file_name: str):
|
||||
pass
|
||||
|
||||
def delete_all_files(self):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client, tmp_path):
|
||||
container = DIContainer()
|
||||
container.register(IFileStorageService, MockFileStorageService)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
yield flask_client
|
||||
|
||||
|
||||
def test_file_download_endpoint(tmp_path, flask_client):
|
||||
resp = flask_client.get(f"/api/pba/download/{FILE_NAME}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert next(resp.response) == FILE_CONTENTS
|
||||
|
||||
|
||||
def test_file_download_endpoint_404(tmp_path, flask_client):
|
||||
nonexistant_file_name = "nonexistant_file"
|
||||
|
||||
resp = flask_client.get(f"/api/pba/download/{nonexistant_file_name}")
|
||||
|
||||
assert resp.status_code == 404
|
|
@ -1,10 +1,16 @@
|
|||
import io
|
||||
from typing import BinaryIO
|
||||
|
||||
import pytest
|
||||
from tests.utils import raise_
|
||||
|
||||
from common import DIContainer
|
||||
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||
|
||||
TEST_FILE = b"""-----------------------------1
|
||||
TEST_FILE_CONTENTS = b"m0nk3y"
|
||||
TEST_FILE = (
|
||||
b"""-----------------------------1
|
||||
Content-Disposition: form-data; name="filepond"
|
||||
|
||||
{}
|
||||
|
@ -12,13 +18,11 @@ Content-Disposition: form-data; name="filepond"
|
|||
Content-Disposition: form-data; name="filepond"; filename="test.py"
|
||||
Content-Type: text/x-python
|
||||
|
||||
m0nk3y
|
||||
"""
|
||||
+ TEST_FILE_CONTENTS
|
||||
+ b"""
|
||||
-----------------------------1--"""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def custom_pba_directory(tmpdir):
|
||||
PostBreachFilesService.initialize(tmpdir)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -35,8 +39,41 @@ def mock_get_config_value(monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
class MockFileStorageService(IFileStorageService):
|
||||
def __init__(self):
|
||||
self._file = None
|
||||
|
||||
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||
self._file = io.BytesIO(file_contents.read())
|
||||
|
||||
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||
if self._file is None:
|
||||
raise FileRetrievalError()
|
||||
return self._file
|
||||
|
||||
def delete_file(self, unsafe_file_name: str):
|
||||
self._file = None
|
||||
|
||||
def delete_all_files(self):
|
||||
self.delete_file("")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_storage_service():
|
||||
return MockFileStorageService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client, file_storage_service):
|
||||
container = DIContainer()
|
||||
container.register_instance(IFileStorageService, file_storage_service)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
yield flask_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||
def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value):
|
||||
def test_pba_file_upload_post(flask_client, pba_os, mock_set_config_value):
|
||||
resp = flask_client.post(
|
||||
f"/api/file-upload/{pba_os}",
|
||||
data=TEST_FILE,
|
||||
|
@ -46,7 +83,7 @@ def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config
|
|||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_config_value):
|
||||
def test_pba_file_upload_post__invalid(flask_client, mock_set_config_value):
|
||||
resp = flask_client.post(
|
||||
"/api/file-upload/bogus",
|
||||
data=TEST_FILE,
|
||||
|
@ -58,12 +95,9 @@ def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_confi
|
|||
|
||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||
def test_pba_file_upload_post__internal_server_error(
|
||||
flask_client, pba_os, monkeypatch, mock_set_config_value
|
||||
flask_client, pba_os, mock_set_config_value, file_storage_service
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
"monkey_island.cc.resources.pba_file_upload.FileUpload.upload_pba_file",
|
||||
lambda x, y: raise_(Exception()),
|
||||
)
|
||||
file_storage_service.save_file = lambda x, y: raise_(Exception())
|
||||
|
||||
resp = flask_client.post(
|
||||
f"/api/file-upload/{pba_os}",
|
||||
|
@ -75,16 +109,14 @@ def test_pba_file_upload_post__internal_server_error(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||
def test_pba_file_upload_get__file_not_found(
|
||||
flask_client, pba_os, monkeypatch, mock_get_config_value
|
||||
):
|
||||
def test_pba_file_upload_get__file_not_found(flask_client, pba_os, mock_get_config_value):
|
||||
resp = flask_client.get(f"/api/file-upload/{pba_os}?load=bogus_mogus.py")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||
def test_pba_file_upload_endpoint(
|
||||
flask_client, pba_os, monkeypatch, mock_get_config_value, mock_set_config_value
|
||||
flask_client, pba_os, mock_get_config_value, mock_set_config_value
|
||||
):
|
||||
resp_post = flask_client.post(
|
||||
f"/api/file-upload/{pba_os}",
|
||||
|
@ -95,7 +127,7 @@ def test_pba_file_upload_endpoint(
|
|||
|
||||
resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py")
|
||||
assert resp_get.status_code == 200
|
||||
assert resp_get.data.decode() == "m0nk3y"
|
||||
assert resp_get.data == TEST_FILE_CONTENTS
|
||||
# Closing the response closes the file handle, else it can't be deleted
|
||||
resp_get.close()
|
||||
|
||||
|
@ -111,7 +143,7 @@ def test_pba_file_upload_endpoint(
|
|||
|
||||
|
||||
def test_pba_file_upload_endpoint__invalid(
|
||||
flask_client, monkeypatch, mock_set_config_value, mock_get_config_value
|
||||
flask_client, mock_set_config_value, mock_get_config_value
|
||||
):
|
||||
resp_post = flask_client.post(
|
||||
"/api/file-upload/bogus",
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
import stat
|
||||
|
||||
import pytest
|
||||
from tests.monkey_island.utils import assert_windows_permissions
|
||||
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
|
||||
|
||||
from monkey_island.cc.server_utils.file_utils import (
|
||||
create_secure_directory,
|
||||
|
@ -39,12 +39,8 @@ def test_create_secure_directory__no_parent_dir(test_path_nested):
|
|||
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
||||
def test_create_secure_directory__perm_linux(test_path):
|
||||
create_secure_directory(test_path)
|
||||
st = os.stat(test_path)
|
||||
|
||||
expected_mode = stat.S_IRWXU
|
||||
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
|
||||
assert expected_mode == actual_mode
|
||||
assert_linux_permissions(test_path)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
import io
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
|
||||
|
||||
from monkey_island.cc.server_utils.file_utils import is_windows_os
|
||||
from monkey_island.cc.services import DirectoryFileStorageService, FileRetrievalError
|
||||
|
||||
|
||||
def test_error_if_storage_directory_is_file(tmp_path):
|
||||
new_file = tmp_path / "new_file.txt"
|
||||
new_file.write_text("HelloWorld!")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DirectoryFileStorageService(new_file)
|
||||
|
||||
|
||||
def test_directory_created(tmp_path):
|
||||
new_dir = tmp_path / "new_dir"
|
||||
|
||||
DirectoryFileStorageService(new_dir)
|
||||
|
||||
assert new_dir.exists() and new_dir.is_dir()
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
||||
def test_directory_permissions__linux(tmp_path):
|
||||
new_dir = tmp_path / "new_dir"
|
||||
|
||||
DirectoryFileStorageService(new_dir)
|
||||
|
||||
assert_linux_permissions(new_dir)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
||||
def test_directory_permissions__windows(tmp_path):
|
||||
new_dir = tmp_path / "new_dir"
|
||||
|
||||
DirectoryFileStorageService(new_dir)
|
||||
|
||||
assert_windows_permissions(new_dir)
|
||||
|
||||
|
||||
def save_file(tmp_path, file_path_prefix=""):
|
||||
file_name = "test.txt"
|
||||
file_contents = "Hello World!"
|
||||
expected_file_path = tmp_path / file_name
|
||||
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
fss.save_file(Path(file_path_prefix) / file_name, io.BytesIO(file_contents.encode()))
|
||||
|
||||
assert expected_file_path.is_file()
|
||||
assert expected_file_path.read_text() == file_contents
|
||||
|
||||
|
||||
def delete_file(tmp_path, file_path_prefix=""):
|
||||
file_name = "file.txt"
|
||||
file = tmp_path / file_name
|
||||
file.touch()
|
||||
assert file.is_file()
|
||||
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
fss.delete_file(Path(file_path_prefix) / file_name)
|
||||
|
||||
assert not file.exists()
|
||||
|
||||
|
||||
def open_file(tmp_path, file_path_prefix=""):
|
||||
file_name = "test.txt"
|
||||
expected_file_contents = "Hello World!"
|
||||
expected_file_path = tmp_path / file_name
|
||||
expected_file_path.write_text(expected_file_contents)
|
||||
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
with fss.open_file(Path(file_path_prefix) / file_name) as f:
|
||||
actual_file_contents = f.read()
|
||||
|
||||
assert actual_file_contents == expected_file_contents.encode()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||
def test_fn(tmp_path, fn):
|
||||
fn(tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||
def test_fn__ignore_relative_path(tmp_path, fn):
|
||||
fn(tmp_path, "../../")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||
def test_fn__ignore_absolute_path(tmp_path, fn):
|
||||
if is_windows_os():
|
||||
fn(tmp_path, "C:\\Windows")
|
||||
else:
|
||||
fn(tmp_path, "/home/")
|
||||
|
||||
|
||||
def test_remove_all_files(tmp_path):
|
||||
for filename in ["1.txt", "2.txt", "3.txt"]:
|
||||
(tmp_path / filename).touch()
|
||||
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
fss.delete_all_files()
|
||||
|
||||
for file in tmp_path.iterdir():
|
||||
assert False, f"{tmp_path} was expected to be empty, but contained files"
|
||||
|
||||
|
||||
def test_remove_all_files__skip_directories(tmp_path):
|
||||
test_dir = tmp_path / "test_dir"
|
||||
test_dir.mkdir()
|
||||
for filename in ["1.txt", "2.txt", "3.txt"]:
|
||||
(tmp_path / filename).touch()
|
||||
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
fss.delete_all_files()
|
||||
|
||||
for file in tmp_path.iterdir():
|
||||
assert file.name == test_dir.name
|
||||
|
||||
|
||||
def test_remove_nonexistant_file(tmp_path):
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
|
||||
# This test will fail if this call raises an exception.
|
||||
fss.delete_file("nonexistant_file.txt")
|
||||
|
||||
|
||||
def test_open_nonexistant_file(tmp_path):
|
||||
fss = DirectoryFileStorageService(tmp_path)
|
||||
|
||||
with pytest.raises(FileRetrievalError):
|
||||
fss.open_file("nonexistant_file.txt")
|
|
@ -1,29 +1,31 @@
|
|||
import io
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from tests.monkey_island.utils import assert_windows_permissions
|
||||
from tests.utils import raise_
|
||||
|
||||
from monkey_island.cc.server_utils.file_utils import is_windows_os
|
||||
from monkey_island.cc.services import DirectoryFileStorageService
|
||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_storage_service(tmp_path):
|
||||
return DirectoryFileStorageService(tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def custom_pba_directory(tmpdir):
|
||||
PostBreachFilesService.initialize(tmpdir)
|
||||
def post_breach_files_service(file_storage_service):
|
||||
PostBreachFilesService.initialize(file_storage_service)
|
||||
|
||||
|
||||
def create_custom_pba_file(filename):
|
||||
PostBreachFilesService.save_file(filename, b"")
|
||||
def test_remove_pba_files(file_storage_service, tmp_path):
|
||||
file_storage_service.save_file("linux_file", io.BytesIO(b""))
|
||||
file_storage_service.save_file("windows_file", io.BytesIO(b""))
|
||||
assert not dir_is_empty(tmp_path)
|
||||
|
||||
|
||||
def test_remove_pba_files():
|
||||
create_custom_pba_file("linux_file")
|
||||
create_custom_pba_file("windows_file")
|
||||
|
||||
assert not dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
|
||||
PostBreachFilesService.remove_PBA_files()
|
||||
assert dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
|
||||
|
||||
assert dir_is_empty(tmp_path)
|
||||
|
||||
|
||||
def dir_is_empty(dir_path):
|
||||
|
@ -31,45 +33,11 @@ def dir_is_empty(dir_path):
|
|||
return len(dir_contents) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
||||
def test_custom_pba_dir_permissions_linux():
|
||||
st = os.stat(PostBreachFilesService.get_custom_pba_directory())
|
||||
|
||||
assert st.st_mode == 0o40700
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
||||
def test_custom_pba_dir_permissions_windows():
|
||||
pba_dir = PostBreachFilesService.get_custom_pba_directory()
|
||||
|
||||
assert_windows_permissions(pba_dir)
|
||||
|
||||
|
||||
def test_remove_failure(monkeypatch):
|
||||
def test_remove_failure(file_storage_service, monkeypatch):
|
||||
monkeypatch.setattr(os, "remove", lambda x: raise_(OSError("Permission denied")))
|
||||
|
||||
try:
|
||||
create_custom_pba_file("windows_file")
|
||||
file_storage_service.save_file("windows_file", io.BytesIO(b""))
|
||||
PostBreachFilesService.remove_PBA_files()
|
||||
except Exception as ex:
|
||||
pytest.fail(f"Unxepected exception: {ex}")
|
||||
|
||||
|
||||
def test_remove_nonexistant_file(monkeypatch):
|
||||
monkeypatch.setattr(os, "remove", lambda x: raise_(FileNotFoundError("FileNotFound")))
|
||||
|
||||
try:
|
||||
PostBreachFilesService.remove_file("/nonexistant/file")
|
||||
except Exception as ex:
|
||||
pytest.fail(f"Unxepected exception: {ex}")
|
||||
|
||||
|
||||
def test_save_file():
|
||||
FILE_NAME = "test_file"
|
||||
FILE_CONTENTS = b"hello"
|
||||
PostBreachFilesService.save_file(FILE_NAME, FILE_CONTENTS)
|
||||
|
||||
expected_file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), FILE_NAME)
|
||||
|
||||
assert os.path.isfile(expected_file_path)
|
||||
assert FILE_CONTENTS == open(expected_file_path, "rb").read()
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import ctypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def is_user_admin():
|
||||
|
@ -11,3 +13,21 @@ def is_user_admin():
|
|||
|
||||
def raise_(ex):
|
||||
raise ex
|
||||
|
||||
|
||||
def add_subdirs_to_dir(parent_dir: Path, subdirs: Iterable[str]) -> Iterable[Path]:
|
||||
subdir_paths = [parent_dir / s for s in subdirs]
|
||||
|
||||
for subdir in subdir_paths:
|
||||
subdir.mkdir()
|
||||
|
||||
return subdir_paths
|
||||
|
||||
|
||||
def add_files_to_dir(parent_dir: Path, file_names: Iterable[str]) -> Iterable[Path]:
|
||||
files = [parent_dir / f for f in file_names]
|
||||
|
||||
for f in files:
|
||||
f.touch()
|
||||
|
||||
return files
|
||||
|
|
Loading…
Reference in New Issue