forked from p15670423/monkey
Merge pull request #1915 from guardicore/1904-service-resource-dependency-injection
1904 service resource dependency injection
This commit is contained in:
commit
97300376ef
|
@ -101,3 +101,6 @@ venv/
|
||||||
|
|
||||||
# Hugo
|
# Hugo
|
||||||
.hugo_build.lock
|
.hugo_build.lock
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
from .di_container import DIContainer
|
|
@ -0,0 +1,133 @@
|
||||||
|
import inspect
|
||||||
|
from typing import Any, MutableMapping, Sequence, Type, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class UnregisteredTypeError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DIContainer:
|
||||||
|
"""
|
||||||
|
A dependency injection (DI) container that uses type annotations to resolve and inject
|
||||||
|
dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._type_registry = {}
|
||||||
|
self._instance_registry = {}
|
||||||
|
|
||||||
|
def register(self, interface: Type[T], concrete_type: Type[T]):
|
||||||
|
"""
|
||||||
|
Register a concrete type that satisfies a given interface.
|
||||||
|
|
||||||
|
:param interface: An interface or abstract base class that other classes depend upon
|
||||||
|
:param concrete_type: A type (class) that implements `interface`
|
||||||
|
"""
|
||||||
|
if not inspect.isclass(concrete_type):
|
||||||
|
raise TypeError(
|
||||||
|
"Expected a class, but received an instance of type "
|
||||||
|
f'"{concrete_type.__class__.__name__}"; Pass a class, not an instance, to '
|
||||||
|
"register(), or use register_instance() instead"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not issubclass(concrete_type, interface):
|
||||||
|
raise TypeError(
|
||||||
|
f'Class "{concrete_type.__name__}" is not a subclass of {interface.__name__}'
|
||||||
|
)
|
||||||
|
|
||||||
|
self._type_registry[interface] = concrete_type
|
||||||
|
DIContainer._del_key(self._instance_registry, interface)
|
||||||
|
|
||||||
|
def register_instance(self, interface: Type[T], instance: T):
|
||||||
|
"""
|
||||||
|
Register a concrete instance that satisfies a given interface.
|
||||||
|
|
||||||
|
:param interface: An interface or abstract base class that other classes depend upon
|
||||||
|
:param instance: An instance (object) of a type that implements `interface`
|
||||||
|
"""
|
||||||
|
if not isinstance(instance, interface):
|
||||||
|
raise TypeError(
|
||||||
|
f'The provided instance of type "{instance.__class__.__name__}" '
|
||||||
|
f"is not an instance of {interface.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._instance_registry[interface] = instance
|
||||||
|
DIContainer._del_key(self._type_registry, interface)
|
||||||
|
|
||||||
|
def resolve(self, type_: Type[T]) -> T:
|
||||||
|
"""
|
||||||
|
Resolves all dependencies and returns a new instance of `type_` using constructor dependency
|
||||||
|
injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and
|
||||||
|
default values are ignored.
|
||||||
|
|
||||||
|
:param type_: A type (class) to construct
|
||||||
|
:return: An instance of `type_`
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._resolve_type(type_)
|
||||||
|
except UnregisteredTypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
args = self.resolve_dependencies(type_)
|
||||||
|
return type_(*args)
|
||||||
|
|
||||||
|
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
|
||||||
|
"""
|
||||||
|
Resolves all dependencies of type_ and returns a Sequence of objects that correspond type_'s
|
||||||
|
dependencies. Note that only positional arguments are resolved. Varargs, keyword-only args,
|
||||||
|
and default values are ignored.
|
||||||
|
|
||||||
|
:param type_: A type (class) to resolve dependencies for
|
||||||
|
:return: An Sequence of dependencies to be injected into type_'s constructor
|
||||||
|
"""
|
||||||
|
args = []
|
||||||
|
|
||||||
|
for arg_type in inspect.getfullargspec(type_).annotations.values():
|
||||||
|
instance = self._resolve_type(arg_type)
|
||||||
|
args.append(instance)
|
||||||
|
|
||||||
|
return tuple(args)
|
||||||
|
|
||||||
|
def _resolve_type(self, type_: Type[T]) -> T:
|
||||||
|
if type_ in self._type_registry:
|
||||||
|
return self._construct_new_instance(type_)
|
||||||
|
elif type_ in self._instance_registry:
|
||||||
|
return self._retrieve_registered_instance(type_)
|
||||||
|
|
||||||
|
raise UnregisteredTypeError(f'Failed to resolve unregistered type "{type_.__name__}"')
|
||||||
|
|
||||||
|
def _construct_new_instance(self, arg_type: Type[T]) -> T:
|
||||||
|
try:
|
||||||
|
return self._type_registry[arg_type]()
|
||||||
|
except TypeError:
|
||||||
|
# arg_type has dependencies that must be resolved. Recursively call resolve() to
|
||||||
|
# construct an instance of arg_type with all of the requesite dependencies injected.
|
||||||
|
return self.resolve(self._type_registry[arg_type])
|
||||||
|
|
||||||
|
def _retrieve_registered_instance(self, arg_type: Type[T]) -> T:
|
||||||
|
return self._instance_registry[arg_type]
|
||||||
|
|
||||||
|
def release(self, interface: Type[T]):
|
||||||
|
"""
|
||||||
|
Deregister's an interface
|
||||||
|
|
||||||
|
:param interface: The interface to release
|
||||||
|
"""
|
||||||
|
DIContainer._del_key(self._type_registry, interface)
|
||||||
|
DIContainer._del_key(self._instance_registry, interface)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _del_key(mapping: MutableMapping[T, Any], key: T):
|
||||||
|
"""
|
||||||
|
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
|
||||||
|
if the key does not exist.
|
||||||
|
|
||||||
|
:param MutableMapping: A mapping from which a key will be deleted
|
||||||
|
:param key: A key to delete from `mapping`
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
del mapping[key]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
|
@ -1,6 +1,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
|
||||||
class InvalidPath(Exception):
|
class InvalidPath(Exception):
|
||||||
|
@ -21,3 +22,7 @@ def get_file_sha256_hash(filepath: Path):
|
||||||
sha256.update(block)
|
sha256.update(block)
|
||||||
|
|
||||||
return sha256.hexdigest()
|
return sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
||||||
|
return filter(lambda f: f.is_file(), dir_path.iterdir())
|
||||||
|
|
|
@ -2,10 +2,10 @@ import filecmp
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Set
|
from typing import Iterable, Set
|
||||||
|
|
||||||
|
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||||
from infection_monkey.utils.dir_utils import (
|
from infection_monkey.utils.dir_utils import (
|
||||||
file_extension_filter,
|
file_extension_filter,
|
||||||
filter_files,
|
filter_files,
|
||||||
get_all_regular_files_in_directory,
|
|
||||||
is_not_shortcut_filter,
|
is_not_shortcut_filter,
|
||||||
is_not_symlink_filter,
|
is_not_symlink_filter,
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,10 +2,6 @@ from pathlib import Path
|
||||||
from typing import Callable, Iterable, Set
|
from typing import Callable, Iterable, Set
|
||||||
|
|
||||||
|
|
||||||
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
|
||||||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
|
||||||
|
|
||||||
|
|
||||||
def filter_files(
|
def filter_files(
|
||||||
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||||
) -> Iterable[Path]:
|
) -> Iterable[Path]:
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
import flask_restful
|
import flask_restful
|
||||||
from flask import Flask, Response, send_from_directory
|
from flask import Flask, Response, send_from_directory
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from common import DIContainer
|
||||||
from monkey_island.cc.database import database, mongo
|
from monkey_island.cc.database import database, mongo
|
||||||
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
|
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
|
||||||
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
from monkey_island.cc.resources.attack.attack_report import AttackReport
|
||||||
|
@ -111,7 +113,17 @@ def init_app_url_rules(app):
|
||||||
app.add_url_rule("/<path:static_path>", "serve_static_file", serve_static_file)
|
app.add_url_rule("/<path:static_path>", "serve_static_file", serve_static_file)
|
||||||
|
|
||||||
|
|
||||||
def init_api_resources(api):
|
class FlaskDIWrapper:
|
||||||
|
def __init__(self, api: flask_restful.Api, container: DIContainer):
|
||||||
|
self._api = api
|
||||||
|
self._container = container
|
||||||
|
|
||||||
|
def add_resource(self, resource: Type[flask_restful.Resource], *urls: str):
|
||||||
|
dependencies = self._container.resolve_dependencies(resource)
|
||||||
|
self._api.add_resource(resource, *urls, resource_class_args=dependencies)
|
||||||
|
|
||||||
|
|
||||||
|
def init_api_resources(api: FlaskDIWrapper):
|
||||||
api.add_resource(Root, "/api")
|
api.add_resource(Root, "/api")
|
||||||
api.add_resource(Registration, "/api/registration")
|
api.add_resource(Registration, "/api/registration")
|
||||||
api.add_resource(Authenticate, "/api/auth")
|
api.add_resource(Authenticate, "/api/auth")
|
||||||
|
@ -148,13 +160,18 @@ def init_api_resources(api):
|
||||||
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
|
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
|
||||||
api.add_resource(Log, "/api/log")
|
api.add_resource(Log, "/api/log")
|
||||||
api.add_resource(IslandLog, "/api/log/island/download")
|
api.add_resource(IslandLog, "/api/log/island/download")
|
||||||
api.add_resource(PBAFileDownload, "/api/pba/download/<string:filename>")
|
|
||||||
|
api.add_resource(
|
||||||
|
PBAFileDownload,
|
||||||
|
"/api/pba/download/<string:filename>",
|
||||||
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
FileUpload,
|
FileUpload,
|
||||||
"/api/file-upload/<string:file_type>",
|
"/api/file-upload/<string:target_os>",
|
||||||
"/api/file-upload/<string:file_type>?load=<string:filename>",
|
"/api/file-upload/<string:target_os>?load=<string:filename>",
|
||||||
"/api/file-upload/<string:file_type>?restore=<string:filename>",
|
"/api/file-upload/<string:target_os>?restore=<string:filename>",
|
||||||
)
|
)
|
||||||
|
|
||||||
api.add_resource(PropagationCredentials, "/api/propagation-credentials/<string:guid>")
|
api.add_resource(PropagationCredentials, "/api/propagation-credentials/<string:guid>")
|
||||||
api.add_resource(RemoteRun, "/api/remote-monkey")
|
api.add_resource(RemoteRun, "/api/remote-monkey")
|
||||||
api.add_resource(VersionUpdate, "/api/version-update")
|
api.add_resource(VersionUpdate, "/api/version-update")
|
||||||
|
@ -168,7 +185,7 @@ def init_api_resources(api):
|
||||||
api.add_resource(TelemetryBlackboxEndpoint, "/api/test/telemetry")
|
api.add_resource(TelemetryBlackboxEndpoint, "/api/test/telemetry")
|
||||||
|
|
||||||
|
|
||||||
def init_app(mongo_url):
|
def init_app(mongo_url: str, container: DIContainer):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
api = flask_restful.Api(app)
|
api = flask_restful.Api(app)
|
||||||
|
@ -177,6 +194,8 @@ def init_app(mongo_url):
|
||||||
init_app_config(app, mongo_url)
|
init_app_config(app, mongo_url)
|
||||||
init_app_services(app)
|
init_app_services(app)
|
||||||
init_app_url_rules(app)
|
init_app_url_rules(app)
|
||||||
init_api_resources(api)
|
|
||||||
|
flask_resource_manager = FlaskDIWrapper(api, container)
|
||||||
|
init_api_resources(flask_resource_manager)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
import flask_restful
|
import logging
|
||||||
from flask import send_from_directory
|
|
||||||
|
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
import flask_restful
|
||||||
|
from flask import make_response, send_file
|
||||||
|
|
||||||
|
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class PBAFileDownload(flask_restful.Resource):
|
class PBAFileDownload(flask_restful.Resource):
|
||||||
|
@ -9,7 +13,17 @@ class PBAFileDownload(flask_restful.Resource):
|
||||||
File download endpoint used by monkey to download user's PBA file
|
File download endpoint used by monkey to download user's PBA file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file_storage_service: IFileStorageService):
|
||||||
|
self._file_storage_service = file_storage_service
|
||||||
|
|
||||||
# Used by monkey. can't secure.
|
# Used by monkey. can't secure.
|
||||||
def get(self, filename):
|
def get(self, filename: str):
|
||||||
custom_pba_dir = PostBreachFilesService.get_custom_pba_directory()
|
try:
|
||||||
return send_from_directory(custom_pba_dir, filename)
|
file = self._file_storage_service.open_file(filename)
|
||||||
|
|
||||||
|
# `send_file()` handles the closing of the open file.
|
||||||
|
return send_file(file, mimetype="application/octet-stream")
|
||||||
|
except FileRetrievalError as err:
|
||||||
|
error_msg = f"Failed to open file {filename}: {err}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return make_response({"error": error_msg}, 404)
|
||||||
|
|
|
@ -1,15 +1,19 @@
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
import flask_restful
|
import flask_restful
|
||||||
from flask import Response, request, send_from_directory
|
from flask import Response, make_response, request, send_file
|
||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH
|
from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH
|
||||||
from monkey_island.cc.resources.auth.auth import jwt_required
|
from monkey_island.cc.resources.auth.auth import jwt_required
|
||||||
|
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||||
from monkey_island.cc.services.config import ConfigService
|
from monkey_island.cc.services.config import ConfigService
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
# Front end uses these strings to identify which files to work with (linux or windows)
|
# Front end uses these strings to identify which files to work with (linux or windows)
|
||||||
LINUX_PBA_TYPE = "PBAlinux"
|
LINUX_PBA_TYPE = "PBAlinux"
|
||||||
|
@ -21,42 +25,60 @@ class FileUpload(flask_restful.Resource):
|
||||||
File upload endpoint used to exchange files with filepond component on the front-end
|
File upload endpoint used to exchange files with filepond component on the front-end
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file_storage_service: IFileStorageService):
|
||||||
|
self._file_storage_service = file_storage_service
|
||||||
|
|
||||||
|
# TODO: Fix references/coupling to filepond
|
||||||
|
# TODO: Add comment explaining why this is basically a duplicate of the endpoint in the
|
||||||
|
# PBAFileDownload resource.
|
||||||
@jwt_required
|
@jwt_required
|
||||||
def get(self, file_type):
|
def get(self, target_os):
|
||||||
"""
|
"""
|
||||||
Sends file to filepond
|
Sends file to filepond
|
||||||
:param file_type: Type indicates which file to send, linux or windows
|
:param target_os: Indicates which file to send, linux or windows
|
||||||
:return: Returns file contents
|
:return: Returns file contents
|
||||||
"""
|
"""
|
||||||
if self.is_pba_file_type_supported(file_type):
|
if self._is_target_os_supported(target_os):
|
||||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||||
|
|
||||||
# Verify that file_name is indeed a file from config
|
# Verify that file_name is indeed a file from config
|
||||||
if file_type == LINUX_PBA_TYPE:
|
if target_os == LINUX_PBA_TYPE:
|
||||||
|
# TODO: Make these paths Tuples so we don't need to copy them
|
||||||
filename = ConfigService.get_config_value(copy.deepcopy(PBA_LINUX_FILENAME_PATH))
|
filename = ConfigService.get_config_value(copy.deepcopy(PBA_LINUX_FILENAME_PATH))
|
||||||
else:
|
else:
|
||||||
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
|
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
|
||||||
return send_from_directory(PostBreachFilesService.get_custom_pba_directory(), filename)
|
|
||||||
|
try:
|
||||||
|
file = self._file_storage_service.open_file(filename)
|
||||||
|
|
||||||
|
# `send_file()` handles the closing of the open file.
|
||||||
|
return send_file(file, mimetype="application/octet-stream")
|
||||||
|
except FileRetrievalError as err:
|
||||||
|
error_msg = f"Failed to open file {filename}: {err}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return make_response({"error": error_msg}, 404)
|
||||||
|
|
||||||
@jwt_required
|
@jwt_required
|
||||||
def post(self, file_type):
|
def post(self, target_os):
|
||||||
"""
|
"""
|
||||||
Receives user's uploaded file from filepond
|
Receives user's uploaded file from filepond
|
||||||
:param file_type: Type indicates which file was received, linux or windows
|
:param target_os: Type indicates which file was received, linux or windows
|
||||||
:return: Returns flask response object with uploaded file's filename
|
:return: Returns flask response object with uploaded file's filename
|
||||||
"""
|
"""
|
||||||
if self.is_pba_file_type_supported(file_type):
|
if self._is_target_os_supported(target_os):
|
||||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||||
|
|
||||||
filename = FileUpload.upload_pba_file(
|
filename = self._upload_pba_file(
|
||||||
request.files["filepond"], (file_type == LINUX_PBA_TYPE)
|
# TODO: This "filepond" string can be changed to be more generic in the `react-filepond`
|
||||||
|
# component.
|
||||||
|
request.files["filepond"],
|
||||||
|
(target_os == LINUX_PBA_TYPE),
|
||||||
)
|
)
|
||||||
|
|
||||||
response = Response(response=filename, status=200, mimetype="text/plain")
|
response = Response(response=filename, status=200, mimetype="text/plain")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@staticmethod
|
def _upload_pba_file(self, file_storage: FileStorage, is_linux=True):
|
||||||
def upload_pba_file(file_storage: FileStorage, is_linux=True):
|
|
||||||
"""
|
"""
|
||||||
Uploads PBA file to island's file system
|
Uploads PBA file to island's file system
|
||||||
:param request_: Request object containing PBA file
|
:param request_: Request object containing PBA file
|
||||||
|
@ -64,9 +86,7 @@ class FileUpload(flask_restful.Resource):
|
||||||
:return: filename string
|
:return: filename string
|
||||||
"""
|
"""
|
||||||
filename = secure_filename(file_storage.filename)
|
filename = secure_filename(file_storage.filename)
|
||||||
file_contents = file_storage.read()
|
self._file_storage_service.save_file(filename, file_storage.stream)
|
||||||
|
|
||||||
PostBreachFilesService.save_file(filename, file_contents)
|
|
||||||
|
|
||||||
ConfigService.set_config_value(
|
ConfigService.set_config_value(
|
||||||
(PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename
|
(PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename
|
||||||
|
@ -75,25 +95,25 @@ class FileUpload(flask_restful.Resource):
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
@jwt_required
|
@jwt_required
|
||||||
def delete(self, file_type):
|
def delete(self, target_os):
|
||||||
"""
|
"""
|
||||||
Deletes file that has been deleted on the front end
|
Deletes file that has been deleted on the front end
|
||||||
:param file_type: Type indicates which file was deleted, linux of windows
|
:param target_os: Type indicates which file was deleted, linux of windows
|
||||||
:return: Empty response
|
:return: Empty response
|
||||||
"""
|
"""
|
||||||
if self.is_pba_file_type_supported(file_type):
|
if self._is_target_os_supported(target_os):
|
||||||
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
|
||||||
|
|
||||||
filename_path = (
|
filename_path = (
|
||||||
PBA_LINUX_FILENAME_PATH if file_type == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
|
PBA_LINUX_FILENAME_PATH if target_os == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
|
||||||
)
|
)
|
||||||
filename = ConfigService.get_config_value(filename_path)
|
filename = ConfigService.get_config_value(filename_path)
|
||||||
if filename:
|
if filename:
|
||||||
PostBreachFilesService.remove_file(filename)
|
self._file_storage_service.delete_file(filename)
|
||||||
ConfigService.set_config_value(filename_path, "")
|
ConfigService.set_config_value(filename_path, "")
|
||||||
|
|
||||||
return {}
|
return make_response({}, 200)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_pba_file_type_supported(file_type: str) -> bool:
|
def _is_target_os_supported(target_os: str) -> bool:
|
||||||
return file_type not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}
|
return target_os not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}
|
||||||
|
|
|
@ -16,6 +16,7 @@ MONKEY_ISLAND_DIR_BASE_PATH = str(Path(__file__).parent.parent)
|
||||||
if str(MONKEY_ISLAND_DIR_BASE_PATH) not in sys.path:
|
if str(MONKEY_ISLAND_DIR_BASE_PATH) not in sys.path:
|
||||||
sys.path.insert(0, MONKEY_ISLAND_DIR_BASE_PATH)
|
sys.path.insert(0, MONKEY_ISLAND_DIR_BASE_PATH)
|
||||||
|
|
||||||
|
from common import DIContainer # noqa: E402
|
||||||
from common.version import get_version # noqa: E402
|
from common.version import get_version # noqa: E402
|
||||||
from monkey_island.cc.app import init_app # noqa: E402
|
from monkey_island.cc.app import init_app # noqa: E402
|
||||||
from monkey_island.cc.arg_parser import IslandCmdArgs # noqa: E402
|
from monkey_island.cc.arg_parser import IslandCmdArgs # noqa: E402
|
||||||
|
@ -47,7 +48,7 @@ def run_monkey_island():
|
||||||
_exit_on_invalid_config_options(config_options)
|
_exit_on_invalid_config_options(config_options)
|
||||||
|
|
||||||
_configure_logging(config_options)
|
_configure_logging(config_options)
|
||||||
_initialize_globals(config_options.data_dir)
|
container = _initialize_di_container(config_options.data_dir)
|
||||||
|
|
||||||
mongo_db_process = None
|
mongo_db_process = None
|
||||||
if config_options.start_mongodb:
|
if config_options.start_mongodb:
|
||||||
|
@ -56,7 +57,7 @@ def run_monkey_island():
|
||||||
_connect_to_mongodb(mongo_db_process)
|
_connect_to_mongodb(mongo_db_process)
|
||||||
|
|
||||||
_configure_gevent_exception_handling(config_options.data_dir)
|
_configure_gevent_exception_handling(config_options.data_dir)
|
||||||
_start_island_server(island_args.setup_only, config_options)
|
_start_island_server(island_args.setup_only, config_options, container)
|
||||||
|
|
||||||
|
|
||||||
def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions:
|
def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions:
|
||||||
|
@ -88,8 +89,8 @@ def _configure_logging(config_options):
|
||||||
setup_logging(config_options.data_dir, config_options.log_level)
|
setup_logging(config_options.data_dir, config_options.log_level)
|
||||||
|
|
||||||
|
|
||||||
def _initialize_globals(data_dir: Path):
|
def _initialize_di_container(data_dir: Path) -> DIContainer:
|
||||||
initialize_services(data_dir)
|
return initialize_services(data_dir)
|
||||||
|
|
||||||
|
|
||||||
def _start_mongodb(data_dir: Path) -> MongoDbProcess:
|
def _start_mongodb(data_dir: Path) -> MongoDbProcess:
|
||||||
|
@ -127,9 +128,11 @@ def _configure_gevent_exception_handling(data_dir):
|
||||||
hub.handle_error = GeventHubErrorHandler(hub, logger)
|
hub.handle_error = GeventHubErrorHandler(hub, logger)
|
||||||
|
|
||||||
|
|
||||||
def _start_island_server(should_setup_only, config_options: IslandConfigOptions):
|
def _start_island_server(
|
||||||
|
should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer
|
||||||
|
):
|
||||||
populate_exporter_list()
|
populate_exporter_list()
|
||||||
app = init_app(mongo_setup.MONGO_URL)
|
app = init_app(mongo_setup.MONGO_URL, container)
|
||||||
|
|
||||||
if should_setup_only:
|
if should_setup_only:
|
||||||
logger.warning("Setup only flag passed. Exiting.")
|
logger.warning("Setup only flag passed. Exiting.")
|
||||||
|
|
|
@ -1,2 +1,5 @@
|
||||||
|
from .i_file_storage_service import IFileStorageService, FileRetrievalError
|
||||||
|
from .directory_file_storage_service import DirectoryFileStorageService
|
||||||
|
|
||||||
from .authentication.authentication_service import AuthenticationService
|
from .authentication.authentication_service import AuthenticationService
|
||||||
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
from .authentication.json_file_user_datastore import JsonFileUserDatastore
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from common.utils.exceptions import (
|
from common.utils.exceptions import (
|
||||||
|
@ -23,7 +25,7 @@ class AuthenticationService:
|
||||||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||||
# not a priority.
|
# not a priority.
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, data_dir: str, user_datastore: IUserDatastore):
|
def initialize(cls, data_dir: Path, user_datastore: IUserDatastore):
|
||||||
cls.DATA_DIR = data_dir
|
cls.DATA_DIR = data_dir
|
||||||
cls.user_datastore = user_datastore
|
cls.user_datastore = user_datastore
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||||
|
from monkey_island.cc.server_utils.file_utils import create_secure_directory
|
||||||
|
|
||||||
|
from . import FileRetrievalError, IFileStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryFileStorageService(IFileStorageService):
|
||||||
|
"""
|
||||||
|
A implementation of IFileStorageService that reads and writes files from/to the local
|
||||||
|
filesystem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storage_directory: Path):
|
||||||
|
"""
|
||||||
|
:param storage_directory: A Path object representing the directory where files will be
|
||||||
|
stored. If the directory does not exist, it will be created.
|
||||||
|
"""
|
||||||
|
if storage_directory.exists() and not storage_directory.is_dir():
|
||||||
|
raise ValueError(f"The provided path must point to a directory: {storage_directory}")
|
||||||
|
|
||||||
|
if not storage_directory.exists():
|
||||||
|
create_secure_directory(storage_directory)
|
||||||
|
|
||||||
|
self._storage_directory = storage_directory
|
||||||
|
|
||||||
|
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||||
|
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||||
|
|
||||||
|
logger.debug(f"Saving file contents to {safe_file_path}")
|
||||||
|
with open(safe_file_path, "wb") as dest:
|
||||||
|
shutil.copyfileobj(file_contents, dest)
|
||||||
|
|
||||||
|
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||||
|
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"Opening {safe_file_path}")
|
||||||
|
return open(safe_file_path, "rb")
|
||||||
|
except OSError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise FileRetrievalError(f"Failed to retrieve file {safe_file_path}: {err}") from err
|
||||||
|
|
||||||
|
def delete_file(self, unsafe_file_name: str):
|
||||||
|
safe_file_path = self._get_safe_file_path(unsafe_file_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"Deleting {safe_file_path}")
|
||||||
|
safe_file_path.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
# This method is idempotent.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_safe_file_path(self, unsafe_file_name: str):
|
||||||
|
# Remove any path information from the file name.
|
||||||
|
safe_file_name = Path(unsafe_file_name).resolve().name
|
||||||
|
safe_file_path = (self._storage_directory / safe_file_name).resolve()
|
||||||
|
|
||||||
|
# This is a paranoid check to avoid directory traversal attacks.
|
||||||
|
if self._storage_directory.resolve() not in safe_file_path.parents:
|
||||||
|
raise ValueError(f"The file named {unsafe_file_name} can not be safely retrieved")
|
||||||
|
|
||||||
|
logger.debug(f"Unsafe file name {unsafe_file_name} sanitized: {safe_file_path}")
|
||||||
|
return safe_file_path
|
||||||
|
|
||||||
|
def delete_all_files(self):
|
||||||
|
for file in get_all_regular_files_in_directory(self._storage_directory):
|
||||||
|
logger.debug(f"Deleting {file}")
|
||||||
|
file.unlink()
|
|
@ -0,0 +1,53 @@
|
||||||
|
import abc
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
|
||||||
|
class FileRetrievalError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IFileStorageService(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
A service that allows the storage and retrieval of individual files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||||
|
"""
|
||||||
|
Save a file, identified by a name
|
||||||
|
|
||||||
|
:param unsafe_file_name: An unsanitized file name that will identify the file
|
||||||
|
:param file_contents: The data to be stored in the file
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||||
|
"""
|
||||||
|
Open a file and return a file-like object
|
||||||
|
|
||||||
|
:param unsafe_file_name: An unsanitized file name that identifies the file to be opened
|
||||||
|
:return: A file-like object providing access to the file's contents
|
||||||
|
:rtype: io.BinaryIO
|
||||||
|
:raises FileRetrievalError: if the file cannot be opened
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def delete_file(self, unsafe_file_name: str):
|
||||||
|
"""
|
||||||
|
Delete a file
|
||||||
|
|
||||||
|
This method will delete the file specified by `unsafe_file_name`. This operation is
|
||||||
|
idempotent and will succeed if the file to be deleted does not exist.
|
||||||
|
|
||||||
|
:param unsafe_file_name: An unsanitized file name that identifies the file to be deleted
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def delete_all_files(self):
|
||||||
|
"""
|
||||||
|
Delete all files that have been stored using this service.
|
||||||
|
"""
|
||||||
|
pass
|
|
@ -1,10 +1,22 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from common import DIContainer
|
||||||
|
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||||
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||||
|
|
||||||
from . import AuthenticationService, JsonFileUserDatastore
|
from . import AuthenticationService, JsonFileUserDatastore
|
||||||
|
|
||||||
|
|
||||||
def initialize_services(data_dir):
|
def initialize_services(data_dir: Path) -> DIContainer:
|
||||||
PostBreachFilesService.initialize(data_dir)
|
container = DIContainer()
|
||||||
|
container.register_instance(
|
||||||
|
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is temporary until we get DI all worked out.
|
||||||
|
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
|
||||||
LocalMonkeyRunService.initialize(data_dir)
|
LocalMonkeyRunService.initialize(data_dir)
|
||||||
AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir))
|
AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir))
|
||||||
|
|
||||||
|
return container
|
||||||
|
|
|
@ -1,46 +1,24 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
from monkey_island.cc.server_utils.file_utils import create_secure_directory
|
from monkey_island.cc.services import IFileStorageService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This service wraps an IFileStorageService for the sole purpose of making the
|
||||||
|
# `remove_PBA_files()` method available to the ConfigService. This whole service can be
|
||||||
|
# removed once ConfigService is refactored to be stateful (it already is but everything is
|
||||||
|
# still statically/globally scoped) and use dependency injection.
|
||||||
class PostBreachFilesService:
|
class PostBreachFilesService:
|
||||||
DATA_DIR = None
|
_file_storage_service = None
|
||||||
CUSTOM_PBA_DIRNAME = "custom_pbas"
|
|
||||||
|
|
||||||
# TODO: A number of these services should be instance objects instead of
|
# TODO: A number of these services should be instance objects instead of
|
||||||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||||
# not a priority.
|
# not a priority.
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, data_dir):
|
def initialize(cls, file_storage_service: IFileStorageService):
|
||||||
cls.DATA_DIR = data_dir
|
cls._file_storage_service = file_storage_service
|
||||||
custom_pba_dir = cls.get_custom_pba_directory()
|
|
||||||
create_secure_directory(custom_pba_dir)
|
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def save_file(filename: str, file_contents: bytes):
|
def remove_PBA_files(cls):
|
||||||
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename)
|
cls._file_storage_service.delete_all_files()
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(file_contents)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_PBA_files():
|
|
||||||
for f in os.listdir(PostBreachFilesService.get_custom_pba_directory()):
|
|
||||||
PostBreachFilesService.remove_file(f)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_file(file_name):
|
|
||||||
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), file_name)
|
|
||||||
try:
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
os.remove(file_path)
|
|
||||||
except OSError as e:
|
|
||||||
logger.error("Can't remove previously uploaded post breach files: %s" % e)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_custom_pba_directory():
|
|
||||||
return os.path.join(
|
|
||||||
PostBreachFilesService.DATA_DIR, PostBreachFilesService.CUSTOM_PBA_DIRNAME
|
|
||||||
)
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import os
|
||||||
import platform
|
import platform
|
||||||
import stat
|
import stat
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
|
||||||
from monkey_island.cc.resources.monkey_download import get_agent_executable_path
|
from monkey_island.cc.resources.monkey_download import get_agent_executable_path
|
||||||
|
@ -19,7 +20,7 @@ class LocalMonkeyRunService:
|
||||||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||||
# not a priority.
|
# not a priority.
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, data_dir):
|
def initialize(cls, data_dir: Path):
|
||||||
cls.DATA_DIR = data_dir
|
cls.DATA_DIR = data_dir
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -7,6 +7,9 @@ if is_windows_os():
|
||||||
FULL_CONTROL = 2032127
|
FULL_CONTROL = 2032127
|
||||||
ACE_ACCESS_MODE_GRANT_ACCESS = win32security.GRANT_ACCESS
|
ACE_ACCESS_MODE_GRANT_ACCESS = win32security.GRANT_ACCESS
|
||||||
ACE_INHERIT_OBJECT_AND_CONTAINER = 3
|
ACE_INHERIT_OBJECT_AND_CONTAINER = 3
|
||||||
|
else:
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
|
||||||
|
|
||||||
def _get_acl_and_sid_from_path(path: str):
|
def _get_acl_and_sid_from_path(path: str):
|
||||||
|
@ -33,3 +36,12 @@ def assert_windows_permissions(path: str):
|
||||||
assert ace_sid == user_sid
|
assert ace_sid == user_sid
|
||||||
assert ace_permissions == FULL_CONTROL and ace_access_mode == ACE_ACCESS_MODE_GRANT_ACCESS
|
assert ace_permissions == FULL_CONTROL and ace_access_mode == ACE_ACCESS_MODE_GRANT_ACCESS
|
||||||
assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER
|
assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER
|
||||||
|
|
||||||
|
|
||||||
|
def assert_linux_permissions(path: str):
|
||||||
|
st = os.stat(path)
|
||||||
|
|
||||||
|
expected_mode = stat.S_IRWXU
|
||||||
|
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||||
|
|
||||||
|
assert expected_mode == actual_mode
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
import abc
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common import DIContainer
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceA(metaclass=abc.ABCMeta):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def do_something(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceB(metaclass=abc.ABCMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceA(IServiceA):
|
||||||
|
def do_something(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceB(IServiceB):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass1:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_a: IServiceA):
|
||||||
|
self.service_a = service_a
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass2:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_b: IServiceB):
|
||||||
|
self.service_b = service_b
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass3:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_a: IServiceA, service_b: IServiceB):
|
||||||
|
self.service_a = service_a
|
||||||
|
self.service_b = service_b
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def container():
|
||||||
|
return DIContainer()
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_resolve(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert isinstance(test_1.service_a, ServiceA)
|
||||||
|
|
||||||
|
|
||||||
|
def test_correct_instance_type_injected(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
test_2 = container.resolve(TestClass2)
|
||||||
|
|
||||||
|
assert isinstance(test_1.service_a, ServiceA)
|
||||||
|
assert isinstance(test_2.service_b, ServiceB)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_correct_instance_types_injected(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
test_3 = container.resolve(TestClass3)
|
||||||
|
|
||||||
|
assert isinstance(test_3.service_a, ServiceA)
|
||||||
|
assert isinstance(test_3.service_b, ServiceB)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_instance(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert id(service_a_instance) == id(test_1.service_a)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_multiple_instances(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
service_b_instance = ServiceB()
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register_instance(IServiceB, service_b_instance)
|
||||||
|
test_3 = container.resolve(TestClass3)
|
||||||
|
|
||||||
|
assert id(service_a_instance) == id(test_3.service_a)
|
||||||
|
assert id(service_b_instance) == id(test_3.service_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_mixed_instance_and_type(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
test_2 = container.resolve(TestClass2)
|
||||||
|
test_3 = container.resolve(TestClass3)
|
||||||
|
|
||||||
|
assert id(service_a_instance) == id(test_3.service_a)
|
||||||
|
assert isinstance(test_2.service_b, ServiceB)
|
||||||
|
assert isinstance(test_3.service_b, ServiceB)
|
||||||
|
assert id(test_2.service_b) != id(test_3.service_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unregistered_type():
|
||||||
|
container = DIContainer()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_registration_overwritten(container):
|
||||||
|
class ServiceA2(IServiceA):
|
||||||
|
def do_something(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceA, ServiceA2)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert isinstance(test_1.service_a, ServiceA2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_instance_registration_overwritten(container):
|
||||||
|
service_a_instance_1 = ServiceA()
|
||||||
|
service_a_instance_2 = ServiceA()
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance_1)
|
||||||
|
container.register_instance(IServiceA, service_a_instance_2)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert id(test_1.service_a) != id(service_a_instance_1)
|
||||||
|
assert id(test_1.service_a) == id(service_a_instance_2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_overrides_instance(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert id(test_1.service_a) != id(service_a_instance)
|
||||||
|
assert isinstance(test_1.service_a, ServiceA)
|
||||||
|
|
||||||
|
|
||||||
|
def test_instance_overrides_type(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
test_1 = container.resolve(TestClass1)
|
||||||
|
|
||||||
|
assert id(test_1.service_a) == id(service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_release_type(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.release(IServiceA)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_release_instance(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
|
||||||
|
container.release(IServiceA)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
container.resolve(TestClass1)
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceC(metaclass=abc.ABCMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceC(IServiceC):
|
||||||
|
def __init__(self, service_a: IServiceA):
|
||||||
|
self.service_a = service_a
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass4:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_c: IServiceC):
|
||||||
|
self.service_c = service_c
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_resolution__depth_2(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
container.register(IServiceC, ServiceC)
|
||||||
|
|
||||||
|
test4 = container.resolve(TestClass4)
|
||||||
|
|
||||||
|
assert isinstance(test4.service_c, ServiceC)
|
||||||
|
assert id(test4.service_c.service_a) == id(service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
|
class IServiceD(metaclass=abc.ABCMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceD(IServiceD):
|
||||||
|
def __init__(self, service_c: IServiceC, service_b: IServiceB):
|
||||||
|
self.service_b = service_b
|
||||||
|
self.service_c = service_c
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass5:
|
||||||
|
__test__ = False
|
||||||
|
|
||||||
|
def __init__(self, service_d: IServiceD):
|
||||||
|
self.service_d = service_d
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_resolution__depth_3(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
container.register(IServiceC, ServiceC)
|
||||||
|
container.register(IServiceD, ServiceD)
|
||||||
|
|
||||||
|
test5 = container.resolve(TestClass5)
|
||||||
|
|
||||||
|
assert isinstance(test5.service_d, ServiceD)
|
||||||
|
assert isinstance(test5.service_d.service_b, ServiceB)
|
||||||
|
assert isinstance(test5.service_d.service_c, ServiceC)
|
||||||
|
assert isinstance(test5.service_d.service_c.service_a, ServiceA)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_registered_interface(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
|
||||||
|
resolved_instance = container.resolve(IServiceA)
|
||||||
|
|
||||||
|
assert isinstance(resolved_instance, ServiceA)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_registered_instance(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
container.register_instance(IServiceA, service_a_instance)
|
||||||
|
|
||||||
|
service_a_actual_instance = container.resolve(IServiceA)
|
||||||
|
|
||||||
|
assert id(service_a_actual_instance) == id(service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_dependencies(container):
|
||||||
|
container.register(IServiceA, ServiceA)
|
||||||
|
container.register(IServiceB, ServiceB)
|
||||||
|
|
||||||
|
dependencies = container.resolve_dependencies(TestClass3)
|
||||||
|
|
||||||
|
assert isinstance(dependencies[0], ServiceA)
|
||||||
|
assert isinstance(dependencies[1], ServiceB)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_instance_as_type(container):
|
||||||
|
service_a_instance = ServiceA()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
container.register(IServiceA, service_a_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_conflicting_type(container):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
container.register(IServiceA, ServiceB)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_instance_with_conflicting_type(container):
|
||||||
|
service_b_instance = ServiceB()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
container.register_instance(IServiceA, service_b_instance)
|
|
@ -1,8 +1,14 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from tests.utils import add_files_to_dir, add_subdirs_to_dir
|
||||||
|
|
||||||
from common.utils.file_utils import InvalidPath, expand_path, get_file_sha256_hash
|
from common.utils.file_utils import (
|
||||||
|
InvalidPath,
|
||||||
|
expand_path,
|
||||||
|
get_all_regular_files_in_directory,
|
||||||
|
get_file_sha256_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_expand_user(patched_home_env):
|
def test_expand_user(patched_home_env):
|
||||||
|
@ -26,3 +32,32 @@ def test_expand_path__empty_path_provided():
|
||||||
|
|
||||||
def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash):
|
def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash):
|
||||||
assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash
|
assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash
|
||||||
|
|
||||||
|
|
||||||
|
SUBDIRS = ["subdir1", "subdir2"]
|
||||||
|
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
||||||
|
add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||||
|
|
||||||
|
expected_return_value = []
|
||||||
|
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
||||||
|
add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||||
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
|
|
||||||
|
expected_return_value = sorted(files)
|
||||||
|
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
|
||||||
|
subdirs = add_subdirs_to_dir(tmp_path, SUBDIRS)
|
||||||
|
add_files_to_dir(subdirs[0], FILES)
|
||||||
|
|
||||||
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
|
|
||||||
|
expected_return_value = sorted(files)
|
||||||
|
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||||
|
|
|
@ -1,66 +1,22 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tests.utils import is_user_admin
|
from tests.utils import add_files_to_dir, is_user_admin
|
||||||
|
|
||||||
|
from common.utils.file_utils import get_all_regular_files_in_directory
|
||||||
from infection_monkey.utils.dir_utils import (
|
from infection_monkey.utils.dir_utils import (
|
||||||
file_extension_filter,
|
file_extension_filter,
|
||||||
filter_files,
|
filter_files,
|
||||||
get_all_regular_files_in_directory,
|
|
||||||
is_not_shortcut_filter,
|
is_not_shortcut_filter,
|
||||||
is_not_symlink_filter,
|
is_not_symlink_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
SHORTCUT = "shortcut.lnk"
|
SHORTCUT = "shortcut.lnk"
|
||||||
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT]
|
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT]
|
||||||
SUBDIRS = ["subdir1", "subdir2"]
|
|
||||||
|
|
||||||
|
|
||||||
def add_subdirs_to_dir(parent_dir):
|
|
||||||
subdirs = [parent_dir / s for s in SUBDIRS]
|
|
||||||
|
|
||||||
for subdir in subdirs:
|
|
||||||
subdir.mkdir()
|
|
||||||
|
|
||||||
return subdirs
|
|
||||||
|
|
||||||
|
|
||||||
def add_files_to_dir(parent_dir):
|
|
||||||
files = [parent_dir / f for f in FILES]
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
f.touch()
|
|
||||||
|
|
||||||
return files
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
|
||||||
add_subdirs_to_dir(tmp_path)
|
|
||||||
|
|
||||||
expected_return_value = []
|
|
||||||
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
|
||||||
add_subdirs_to_dir(tmp_path)
|
|
||||||
files = add_files_to_dir(tmp_path)
|
|
||||||
|
|
||||||
expected_return_value = sorted(files)
|
|
||||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
|
|
||||||
subdirs = add_subdirs_to_dir(tmp_path)
|
|
||||||
add_files_to_dir(subdirs[0])
|
|
||||||
|
|
||||||
files = add_files_to_dir(tmp_path)
|
|
||||||
|
|
||||||
expected_return_value = sorted(files)
|
|
||||||
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_filter_files__no_results(tmp_path):
|
def test_filter_files__no_results(tmp_path):
|
||||||
add_files_to_dir(tmp_path)
|
add_files_to_dir(tmp_path, FILES)
|
||||||
|
|
||||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||||
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
|
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
|
||||||
|
@ -69,7 +25,7 @@ def test_filter_files__no_results(tmp_path):
|
||||||
|
|
||||||
|
|
||||||
def test_filter_files__all_true(tmp_path):
|
def test_filter_files__all_true(tmp_path):
|
||||||
files = add_files_to_dir(tmp_path)
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
expected_return_value = sorted(files)
|
expected_return_value = sorted(files)
|
||||||
|
|
||||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||||
|
@ -79,7 +35,7 @@ def test_filter_files__all_true(tmp_path):
|
||||||
|
|
||||||
|
|
||||||
def test_filter_files__multiple_filters(tmp_path):
|
def test_filter_files__multiple_filters(tmp_path):
|
||||||
files = add_files_to_dir(tmp_path)
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
expected_return_value = sorted(files[4:6])
|
expected_return_value = sorted(files[4:6])
|
||||||
|
|
||||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||||
|
@ -93,7 +49,7 @@ def test_filter_files__multiple_filters(tmp_path):
|
||||||
def test_file_extension_filter(tmp_path):
|
def test_file_extension_filter(tmp_path):
|
||||||
valid_extensions = {".zip", ".xyz"}
|
valid_extensions = {".zip", ".xyz"}
|
||||||
|
|
||||||
files = add_files_to_dir(tmp_path)
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
|
|
||||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||||
filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)])
|
filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)])
|
||||||
|
@ -105,7 +61,7 @@ def test_file_extension_filter(tmp_path):
|
||||||
os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows"
|
os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows"
|
||||||
)
|
)
|
||||||
def test_is_not_symlink_filter(tmp_path):
|
def test_is_not_symlink_filter(tmp_path):
|
||||||
files = add_files_to_dir(tmp_path)
|
files = add_files_to_dir(tmp_path, FILES)
|
||||||
link_path = tmp_path / "symlink.test"
|
link_path = tmp_path / "symlink.test"
|
||||||
link_path.symlink_to(files[0], target_is_directory=False)
|
link_path.symlink_to(files[0], target_is_directory=False)
|
||||||
|
|
||||||
|
@ -118,7 +74,7 @@ def test_is_not_symlink_filter(tmp_path):
|
||||||
|
|
||||||
|
|
||||||
def test_is_not_shortcut_filter(tmp_path):
|
def test_is_not_shortcut_filter(tmp_path):
|
||||||
add_files_to_dir(tmp_path)
|
add_files_to_dir(tmp_path, FILES)
|
||||||
|
|
||||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||||
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
|
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import flask_jwt_extended
|
import flask_jwt_extended
|
||||||
import flask_restful
|
import flask_restful
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -9,15 +11,28 @@ import monkey_island.cc.resources.island_mode
|
||||||
from monkey_island.cc.services.representations import output_json
|
from monkey_island.cc.services.representations import output_json
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture
|
||||||
def flask_client(monkeypatch_session):
|
def flask_client(monkeypatch_session):
|
||||||
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
|
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
|
||||||
|
|
||||||
with mock_init_app().test_client() as client:
|
container = MagicMock()
|
||||||
|
container.resolve_dependencies.return_value = []
|
||||||
|
|
||||||
|
with mock_init_app(container).test_client() as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
def mock_init_app():
|
@pytest.fixture
|
||||||
|
def build_flask_client(monkeypatch_session):
|
||||||
|
def inner(container):
|
||||||
|
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
|
||||||
|
|
||||||
|
return mock_init_app(container).test_client()
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def mock_init_app(container):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config["SECRET_KEY"] = "test_key"
|
app.config["SECRET_KEY"] = "test_key"
|
||||||
|
|
||||||
|
@ -25,7 +40,8 @@ def mock_init_app():
|
||||||
api.representations = {"application/json": output_json}
|
api.representations = {"application/json": output_json}
|
||||||
|
|
||||||
monkey_island.cc.app.init_app_url_rules(app)
|
monkey_island.cc.app.init_app_url_rules(app)
|
||||||
monkey_island.cc.app.init_api_resources(api)
|
flask_resource_manager = monkey_island.cc.app.FlaskDIWrapper(api, container)
|
||||||
|
monkey_island.cc.app.init_api_resources(flask_resource_manager)
|
||||||
|
|
||||||
flask_jwt_extended.JWTManager(app)
|
flask_jwt_extended.JWTManager(app)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
import io
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common import DIContainer
|
||||||
|
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||||
|
|
||||||
|
FILE_NAME = "test_file"
|
||||||
|
FILE_CONTENTS = b"HelloWorld!"
|
||||||
|
|
||||||
|
|
||||||
|
class MockFileStorageService(IFileStorageService):
|
||||||
|
def __init__(self):
|
||||||
|
self._file = io.BytesIO(FILE_CONTENTS)
|
||||||
|
|
||||||
|
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||||
|
if unsafe_file_name != FILE_NAME:
|
||||||
|
raise FileRetrievalError()
|
||||||
|
|
||||||
|
return self._file
|
||||||
|
|
||||||
|
def delete_file(self, unsafe_file_name: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete_all_files(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flask_client(build_flask_client, tmp_path):
|
||||||
|
container = DIContainer()
|
||||||
|
container.register(IFileStorageService, MockFileStorageService)
|
||||||
|
|
||||||
|
with build_flask_client(container) as flask_client:
|
||||||
|
yield flask_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_download_endpoint(tmp_path, flask_client):
|
||||||
|
resp = flask_client.get(f"/api/pba/download/{FILE_NAME}")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert next(resp.response) == FILE_CONTENTS
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_download_endpoint_404(tmp_path, flask_client):
|
||||||
|
nonexistant_file_name = "nonexistant_file"
|
||||||
|
|
||||||
|
resp = flask_client.get(f"/api/pba/download/{nonexistant_file_name}")
|
||||||
|
|
||||||
|
assert resp.status_code == 404
|
|
@ -1,10 +1,16 @@
|
||||||
|
import io
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tests.utils import raise_
|
from tests.utils import raise_
|
||||||
|
|
||||||
|
from common import DIContainer
|
||||||
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
|
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
|
||||||
|
|
||||||
TEST_FILE = b"""-----------------------------1
|
TEST_FILE_CONTENTS = b"m0nk3y"
|
||||||
|
TEST_FILE = (
|
||||||
|
b"""-----------------------------1
|
||||||
Content-Disposition: form-data; name="filepond"
|
Content-Disposition: form-data; name="filepond"
|
||||||
|
|
||||||
{}
|
{}
|
||||||
|
@ -12,13 +18,11 @@ Content-Disposition: form-data; name="filepond"
|
||||||
Content-Disposition: form-data; name="filepond"; filename="test.py"
|
Content-Disposition: form-data; name="filepond"; filename="test.py"
|
||||||
Content-Type: text/x-python
|
Content-Type: text/x-python
|
||||||
|
|
||||||
m0nk3y
|
"""
|
||||||
|
+ TEST_FILE_CONTENTS
|
||||||
|
+ b"""
|
||||||
-----------------------------1--"""
|
-----------------------------1--"""
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def custom_pba_directory(tmpdir):
|
|
||||||
PostBreachFilesService.initialize(tmpdir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -35,8 +39,41 @@ def mock_get_config_value(monkeypatch):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockFileStorageService(IFileStorageService):
|
||||||
|
def __init__(self):
|
||||||
|
self._file = None
|
||||||
|
|
||||||
|
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
|
||||||
|
self._file = io.BytesIO(file_contents.read())
|
||||||
|
|
||||||
|
def open_file(self, unsafe_file_name: str) -> BinaryIO:
|
||||||
|
if self._file is None:
|
||||||
|
raise FileRetrievalError()
|
||||||
|
return self._file
|
||||||
|
|
||||||
|
def delete_file(self, unsafe_file_name: str):
|
||||||
|
self._file = None
|
||||||
|
|
||||||
|
def delete_all_files(self):
|
||||||
|
self.delete_file("")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_storage_service():
|
||||||
|
return MockFileStorageService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flask_client(build_flask_client, file_storage_service):
|
||||||
|
container = DIContainer()
|
||||||
|
container.register_instance(IFileStorageService, file_storage_service)
|
||||||
|
|
||||||
|
with build_flask_client(container) as flask_client:
|
||||||
|
yield flask_client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||||
def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value):
|
def test_pba_file_upload_post(flask_client, pba_os, mock_set_config_value):
|
||||||
resp = flask_client.post(
|
resp = flask_client.post(
|
||||||
f"/api/file-upload/{pba_os}",
|
f"/api/file-upload/{pba_os}",
|
||||||
data=TEST_FILE,
|
data=TEST_FILE,
|
||||||
|
@ -46,7 +83,7 @@ def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_config_value):
|
def test_pba_file_upload_post__invalid(flask_client, mock_set_config_value):
|
||||||
resp = flask_client.post(
|
resp = flask_client.post(
|
||||||
"/api/file-upload/bogus",
|
"/api/file-upload/bogus",
|
||||||
data=TEST_FILE,
|
data=TEST_FILE,
|
||||||
|
@ -58,12 +95,9 @@ def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_confi
|
||||||
|
|
||||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||||
def test_pba_file_upload_post__internal_server_error(
|
def test_pba_file_upload_post__internal_server_error(
|
||||||
flask_client, pba_os, monkeypatch, mock_set_config_value
|
flask_client, pba_os, mock_set_config_value, file_storage_service
|
||||||
):
|
):
|
||||||
monkeypatch.setattr(
|
file_storage_service.save_file = lambda x, y: raise_(Exception())
|
||||||
"monkey_island.cc.resources.pba_file_upload.FileUpload.upload_pba_file",
|
|
||||||
lambda x, y: raise_(Exception()),
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = flask_client.post(
|
resp = flask_client.post(
|
||||||
f"/api/file-upload/{pba_os}",
|
f"/api/file-upload/{pba_os}",
|
||||||
|
@ -75,16 +109,14 @@ def test_pba_file_upload_post__internal_server_error(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||||
def test_pba_file_upload_get__file_not_found(
|
def test_pba_file_upload_get__file_not_found(flask_client, pba_os, mock_get_config_value):
|
||||||
flask_client, pba_os, monkeypatch, mock_get_config_value
|
|
||||||
):
|
|
||||||
resp = flask_client.get(f"/api/file-upload/{pba_os}?load=bogus_mogus.py")
|
resp = flask_client.get(f"/api/file-upload/{pba_os}?load=bogus_mogus.py")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
|
||||||
def test_pba_file_upload_endpoint(
|
def test_pba_file_upload_endpoint(
|
||||||
flask_client, pba_os, monkeypatch, mock_get_config_value, mock_set_config_value
|
flask_client, pba_os, mock_get_config_value, mock_set_config_value
|
||||||
):
|
):
|
||||||
resp_post = flask_client.post(
|
resp_post = flask_client.post(
|
||||||
f"/api/file-upload/{pba_os}",
|
f"/api/file-upload/{pba_os}",
|
||||||
|
@ -95,7 +127,7 @@ def test_pba_file_upload_endpoint(
|
||||||
|
|
||||||
resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py")
|
resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py")
|
||||||
assert resp_get.status_code == 200
|
assert resp_get.status_code == 200
|
||||||
assert resp_get.data.decode() == "m0nk3y"
|
assert resp_get.data == TEST_FILE_CONTENTS
|
||||||
# Closing the response closes the file handle, else it can't be deleted
|
# Closing the response closes the file handle, else it can't be deleted
|
||||||
resp_get.close()
|
resp_get.close()
|
||||||
|
|
||||||
|
@ -111,7 +143,7 @@ def test_pba_file_upload_endpoint(
|
||||||
|
|
||||||
|
|
||||||
def test_pba_file_upload_endpoint__invalid(
|
def test_pba_file_upload_endpoint__invalid(
|
||||||
flask_client, monkeypatch, mock_set_config_value, mock_get_config_value
|
flask_client, mock_set_config_value, mock_get_config_value
|
||||||
):
|
):
|
||||||
resp_post = flask_client.post(
|
resp_post = flask_client.post(
|
||||||
"/api/file-upload/bogus",
|
"/api/file-upload/bogus",
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tests.monkey_island.utils import assert_windows_permissions
|
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
|
||||||
|
|
||||||
from monkey_island.cc.server_utils.file_utils import (
|
from monkey_island.cc.server_utils.file_utils import (
|
||||||
create_secure_directory,
|
create_secure_directory,
|
||||||
|
@ -39,12 +39,8 @@ def test_create_secure_directory__no_parent_dir(test_path_nested):
|
||||||
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
||||||
def test_create_secure_directory__perm_linux(test_path):
|
def test_create_secure_directory__perm_linux(test_path):
|
||||||
create_secure_directory(test_path)
|
create_secure_directory(test_path)
|
||||||
st = os.stat(test_path)
|
|
||||||
|
|
||||||
expected_mode = stat.S_IRWXU
|
assert_linux_permissions(test_path)
|
||||||
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
|
||||||
|
|
||||||
assert expected_mode == actual_mode
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
|
||||||
|
|
||||||
|
from monkey_island.cc.server_utils.file_utils import is_windows_os
|
||||||
|
from monkey_island.cc.services import DirectoryFileStorageService, FileRetrievalError
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_if_storage_directory_is_file(tmp_path):
|
||||||
|
new_file = tmp_path / "new_file.txt"
|
||||||
|
new_file.write_text("HelloWorld!")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
DirectoryFileStorageService(new_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_directory_created(tmp_path):
|
||||||
|
new_dir = tmp_path / "new_dir"
|
||||||
|
|
||||||
|
DirectoryFileStorageService(new_dir)
|
||||||
|
|
||||||
|
assert new_dir.exists() and new_dir.is_dir()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
||||||
|
def test_directory_permissions__linux(tmp_path):
|
||||||
|
new_dir = tmp_path / "new_dir"
|
||||||
|
|
||||||
|
DirectoryFileStorageService(new_dir)
|
||||||
|
|
||||||
|
assert_linux_permissions(new_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
||||||
|
def test_directory_permissions__windows(tmp_path):
|
||||||
|
new_dir = tmp_path / "new_dir"
|
||||||
|
|
||||||
|
DirectoryFileStorageService(new_dir)
|
||||||
|
|
||||||
|
assert_windows_permissions(new_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def save_file(tmp_path, file_path_prefix=""):
|
||||||
|
file_name = "test.txt"
|
||||||
|
file_contents = "Hello World!"
|
||||||
|
expected_file_path = tmp_path / file_name
|
||||||
|
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
fss.save_file(Path(file_path_prefix) / file_name, io.BytesIO(file_contents.encode()))
|
||||||
|
|
||||||
|
assert expected_file_path.is_file()
|
||||||
|
assert expected_file_path.read_text() == file_contents
|
||||||
|
|
||||||
|
|
||||||
|
def delete_file(tmp_path, file_path_prefix=""):
|
||||||
|
file_name = "file.txt"
|
||||||
|
file = tmp_path / file_name
|
||||||
|
file.touch()
|
||||||
|
assert file.is_file()
|
||||||
|
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
fss.delete_file(Path(file_path_prefix) / file_name)
|
||||||
|
|
||||||
|
assert not file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def open_file(tmp_path, file_path_prefix=""):
|
||||||
|
file_name = "test.txt"
|
||||||
|
expected_file_contents = "Hello World!"
|
||||||
|
expected_file_path = tmp_path / file_name
|
||||||
|
expected_file_path.write_text(expected_file_contents)
|
||||||
|
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
with fss.open_file(Path(file_path_prefix) / file_name) as f:
|
||||||
|
actual_file_contents = f.read()
|
||||||
|
|
||||||
|
assert actual_file_contents == expected_file_contents.encode()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||||
|
def test_fn(tmp_path, fn):
|
||||||
|
fn(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||||
|
def test_fn__ignore_relative_path(tmp_path, fn):
|
||||||
|
fn(tmp_path, "../../")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
|
||||||
|
def test_fn__ignore_absolute_path(tmp_path, fn):
|
||||||
|
if is_windows_os():
|
||||||
|
fn(tmp_path, "C:\\Windows")
|
||||||
|
else:
|
||||||
|
fn(tmp_path, "/home/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_all_files(tmp_path):
|
||||||
|
for filename in ["1.txt", "2.txt", "3.txt"]:
|
||||||
|
(tmp_path / filename).touch()
|
||||||
|
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
fss.delete_all_files()
|
||||||
|
|
||||||
|
for file in tmp_path.iterdir():
|
||||||
|
assert False, f"{tmp_path} was expected to be empty, but contained files"
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_all_files__skip_directories(tmp_path):
|
||||||
|
test_dir = tmp_path / "test_dir"
|
||||||
|
test_dir.mkdir()
|
||||||
|
for filename in ["1.txt", "2.txt", "3.txt"]:
|
||||||
|
(tmp_path / filename).touch()
|
||||||
|
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
fss.delete_all_files()
|
||||||
|
|
||||||
|
for file in tmp_path.iterdir():
|
||||||
|
assert file.name == test_dir.name
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_nonexistant_file(tmp_path):
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
|
||||||
|
# This test will fail if this call raises an exception.
|
||||||
|
fss.delete_file("nonexistant_file.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_nonexistant_file(tmp_path):
|
||||||
|
fss = DirectoryFileStorageService(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(FileRetrievalError):
|
||||||
|
fss.open_file("nonexistant_file.txt")
|
|
@ -1,29 +1,31 @@
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tests.monkey_island.utils import assert_windows_permissions
|
|
||||||
from tests.utils import raise_
|
from tests.utils import raise_
|
||||||
|
|
||||||
from monkey_island.cc.server_utils.file_utils import is_windows_os
|
from monkey_island.cc.services import DirectoryFileStorageService
|
||||||
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_storage_service(tmp_path):
|
||||||
|
return DirectoryFileStorageService(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def custom_pba_directory(tmpdir):
|
def post_breach_files_service(file_storage_service):
|
||||||
PostBreachFilesService.initialize(tmpdir)
|
PostBreachFilesService.initialize(file_storage_service)
|
||||||
|
|
||||||
|
|
||||||
def create_custom_pba_file(filename):
|
def test_remove_pba_files(file_storage_service, tmp_path):
|
||||||
PostBreachFilesService.save_file(filename, b"")
|
file_storage_service.save_file("linux_file", io.BytesIO(b""))
|
||||||
|
file_storage_service.save_file("windows_file", io.BytesIO(b""))
|
||||||
|
assert not dir_is_empty(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def test_remove_pba_files():
|
|
||||||
create_custom_pba_file("linux_file")
|
|
||||||
create_custom_pba_file("windows_file")
|
|
||||||
|
|
||||||
assert not dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
|
|
||||||
PostBreachFilesService.remove_PBA_files()
|
PostBreachFilesService.remove_PBA_files()
|
||||||
assert dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
|
|
||||||
|
assert dir_is_empty(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def dir_is_empty(dir_path):
|
def dir_is_empty(dir_path):
|
||||||
|
@ -31,45 +33,11 @@ def dir_is_empty(dir_path):
|
||||||
return len(dir_contents) == 0
|
return len(dir_contents) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
|
def test_remove_failure(file_storage_service, monkeypatch):
|
||||||
def test_custom_pba_dir_permissions_linux():
|
|
||||||
st = os.stat(PostBreachFilesService.get_custom_pba_directory())
|
|
||||||
|
|
||||||
assert st.st_mode == 0o40700
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
|
|
||||||
def test_custom_pba_dir_permissions_windows():
|
|
||||||
pba_dir = PostBreachFilesService.get_custom_pba_directory()
|
|
||||||
|
|
||||||
assert_windows_permissions(pba_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def test_remove_failure(monkeypatch):
|
|
||||||
monkeypatch.setattr(os, "remove", lambda x: raise_(OSError("Permission denied")))
|
monkeypatch.setattr(os, "remove", lambda x: raise_(OSError("Permission denied")))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_custom_pba_file("windows_file")
|
file_storage_service.save_file("windows_file", io.BytesIO(b""))
|
||||||
PostBreachFilesService.remove_PBA_files()
|
PostBreachFilesService.remove_PBA_files()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
pytest.fail(f"Unxepected exception: {ex}")
|
pytest.fail(f"Unxepected exception: {ex}")
|
||||||
|
|
||||||
|
|
||||||
def test_remove_nonexistant_file(monkeypatch):
|
|
||||||
monkeypatch.setattr(os, "remove", lambda x: raise_(FileNotFoundError("FileNotFound")))
|
|
||||||
|
|
||||||
try:
|
|
||||||
PostBreachFilesService.remove_file("/nonexistant/file")
|
|
||||||
except Exception as ex:
|
|
||||||
pytest.fail(f"Unxepected exception: {ex}")
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_file():
|
|
||||||
FILE_NAME = "test_file"
|
|
||||||
FILE_CONTENTS = b"hello"
|
|
||||||
PostBreachFilesService.save_file(FILE_NAME, FILE_CONTENTS)
|
|
||||||
|
|
||||||
expected_file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), FILE_NAME)
|
|
||||||
|
|
||||||
assert os.path.isfile(expected_file_path)
|
|
||||||
assert FILE_CONTENTS == open(expected_file_path, "rb").read()
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
|
||||||
def is_user_admin():
|
def is_user_admin():
|
||||||
|
@ -11,3 +13,21 @@ def is_user_admin():
|
||||||
|
|
||||||
def raise_(ex):
|
def raise_(ex):
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
|
||||||
|
def add_subdirs_to_dir(parent_dir: Path, subdirs: Iterable[str]) -> Iterable[Path]:
|
||||||
|
subdir_paths = [parent_dir / s for s in subdirs]
|
||||||
|
|
||||||
|
for subdir in subdir_paths:
|
||||||
|
subdir.mkdir()
|
||||||
|
|
||||||
|
return subdir_paths
|
||||||
|
|
||||||
|
|
||||||
|
def add_files_to_dir(parent_dir: Path, file_names: Iterable[str]) -> Iterable[Path]:
|
||||||
|
files = [parent_dir / f for f in file_names]
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
f.touch()
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
Loading…
Reference in New Issue