Merge pull request #1915 from guardicore/1904-service-resource-dependency-injection

1904 service resource dependency injection
This commit is contained in:
Mike Salvatore 2022-04-27 09:44:34 -04:00 committed by GitHub
commit 97300376ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1044 additions and 219 deletions

3
.gitignore vendored
View File

@ -101,3 +101,6 @@ venv/
# Hugo
.hugo_build.lock
# mypy
.mypy_cache

View File

@ -0,0 +1 @@
from .di_container import DIContainer

View File

@ -0,0 +1,133 @@
import inspect
from typing import Any, MutableMapping, Sequence, Type, TypeVar
T = TypeVar("T")
class UnregisteredTypeError(ValueError):
pass
class DIContainer:
"""
A dependency injection (DI) container that uses type annotations to resolve and inject
dependencies.
"""
def __init__(self):
self._type_registry = {}
self._instance_registry = {}
def register(self, interface: Type[T], concrete_type: Type[T]):
"""
Register a concrete type that satisfies a given interface.
:param interface: An interface or abstract base class that other classes depend upon
:param concrete_type: A type (class) that implements `interface`
"""
if not inspect.isclass(concrete_type):
raise TypeError(
"Expected a class, but received an instance of type "
f'"{concrete_type.__class__.__name__}"; Pass a class, not an instance, to '
"register(), or use register_instance() instead"
)
if not issubclass(concrete_type, interface):
raise TypeError(
f'Class "{concrete_type.__name__}" is not a subclass of {interface.__name__}'
)
self._type_registry[interface] = concrete_type
DIContainer._del_key(self._instance_registry, interface)
def register_instance(self, interface: Type[T], instance: T):
"""
Register a concrete instance that satisfies a given interface.
:param interface: An interface or abstract base class that other classes depend upon
:param instance: An instance (object) of a type that implements `interface`
"""
if not isinstance(instance, interface):
raise TypeError(
f'The provided instance of type "{instance.__class__.__name__}" '
f"is not an instance of {interface.__name__}"
)
self._instance_registry[interface] = instance
DIContainer._del_key(self._type_registry, interface)
def resolve(self, type_: Type[T]) -> T:
"""
Resolves all dependencies and returns a new instance of `type_` using constructor dependency
injection. Note that only positional arguments are resolved. Varargs, keyword-only args, and
default values are ignored.
:param type_: A type (class) to construct
:return: An instance of `type_`
"""
try:
return self._resolve_type(type_)
except UnregisteredTypeError:
pass
args = self.resolve_dependencies(type_)
return type_(*args)
def resolve_dependencies(self, type_: Type[T]) -> Sequence[Any]:
"""
Resolves all dependencies of type_ and returns a Sequence of objects that correspond type_'s
dependencies. Note that only positional arguments are resolved. Varargs, keyword-only args,
and default values are ignored.
:param type_: A type (class) to resolve dependencies for
:return: An Sequence of dependencies to be injected into type_'s constructor
"""
args = []
for arg_type in inspect.getfullargspec(type_).annotations.values():
instance = self._resolve_type(arg_type)
args.append(instance)
return tuple(args)
def _resolve_type(self, type_: Type[T]) -> T:
if type_ in self._type_registry:
return self._construct_new_instance(type_)
elif type_ in self._instance_registry:
return self._retrieve_registered_instance(type_)
raise UnregisteredTypeError(f'Failed to resolve unregistered type "{type_.__name__}"')
def _construct_new_instance(self, arg_type: Type[T]) -> T:
try:
return self._type_registry[arg_type]()
except TypeError:
# arg_type has dependencies that must be resolved. Recursively call resolve() to
# construct an instance of arg_type with all of the requesite dependencies injected.
return self.resolve(self._type_registry[arg_type])
def _retrieve_registered_instance(self, arg_type: Type[T]) -> T:
return self._instance_registry[arg_type]
def release(self, interface: Type[T]):
"""
Deregister's an interface
:param interface: The interface to release
"""
DIContainer._del_key(self._type_registry, interface)
DIContainer._del_key(self._instance_registry, interface)
@staticmethod
def _del_key(mapping: MutableMapping[T, Any], key: T):
"""
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
if the key does not exist.
:param MutableMapping: A mapping from which a key will be deleted
:param key: A key to delete from `mapping`
"""
try:
del mapping[key]
except KeyError:
pass

View File

@ -1,6 +1,7 @@
import hashlib
import os
from pathlib import Path
from typing import Iterable
class InvalidPath(Exception):
@ -21,3 +22,7 @@ def get_file_sha256_hash(filepath: Path):
sha256.update(block)
return sha256.hexdigest()
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter(lambda f: f.is_file(), dir_path.iterdir())

View File

@ -2,10 +2,10 @@ import filecmp
from pathlib import Path
from typing import Iterable, Set
from common.utils.file_utils import get_all_regular_files_in_directory
from infection_monkey.utils.dir_utils import (
file_extension_filter,
filter_files,
get_all_regular_files_in_directory,
is_not_shortcut_filter,
is_not_symlink_filter,
)

View File

@ -2,10 +2,6 @@ from pathlib import Path
from typing import Callable, Iterable, Set
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> Iterable[Path]:

View File

@ -1,11 +1,13 @@
import os
import uuid
from datetime import timedelta
from typing import Type
import flask_restful
from flask import Flask, Response, send_from_directory
from werkzeug.exceptions import NotFound
from common import DIContainer
from monkey_island.cc.database import database, mongo
from monkey_island.cc.resources.agent_controls import StopAgentCheck, StopAllAgents
from monkey_island.cc.resources.attack.attack_report import AttackReport
@ -111,7 +113,17 @@ def init_app_url_rules(app):
app.add_url_rule("/<path:static_path>", "serve_static_file", serve_static_file)
def init_api_resources(api):
class FlaskDIWrapper:
def __init__(self, api: flask_restful.Api, container: DIContainer):
self._api = api
self._container = container
def add_resource(self, resource: Type[flask_restful.Resource], *urls: str):
dependencies = self._container.resolve_dependencies(resource)
self._api.add_resource(resource, *urls, resource_class_args=dependencies)
def init_api_resources(api: FlaskDIWrapper):
api.add_resource(Root, "/api")
api.add_resource(Registration, "/api/registration")
api.add_resource(Authenticate, "/api/auth")
@ -148,13 +160,18 @@ def init_api_resources(api):
api.add_resource(TelemetryFeed, "/api/telemetry-feed")
api.add_resource(Log, "/api/log")
api.add_resource(IslandLog, "/api/log/island/download")
api.add_resource(PBAFileDownload, "/api/pba/download/<string:filename>")
api.add_resource(
PBAFileDownload,
"/api/pba/download/<string:filename>",
)
api.add_resource(
FileUpload,
"/api/file-upload/<string:file_type>",
"/api/file-upload/<string:file_type>?load=<string:filename>",
"/api/file-upload/<string:file_type>?restore=<string:filename>",
"/api/file-upload/<string:target_os>",
"/api/file-upload/<string:target_os>?load=<string:filename>",
"/api/file-upload/<string:target_os>?restore=<string:filename>",
)
api.add_resource(PropagationCredentials, "/api/propagation-credentials/<string:guid>")
api.add_resource(RemoteRun, "/api/remote-monkey")
api.add_resource(VersionUpdate, "/api/version-update")
@ -168,7 +185,7 @@ def init_api_resources(api):
api.add_resource(TelemetryBlackboxEndpoint, "/api/test/telemetry")
def init_app(mongo_url):
def init_app(mongo_url: str, container: DIContainer):
app = Flask(__name__)
api = flask_restful.Api(app)
@ -177,6 +194,8 @@ def init_app(mongo_url):
init_app_config(app, mongo_url)
init_app_services(app)
init_app_url_rules(app)
init_api_resources(api)
flask_resource_manager = FlaskDIWrapper(api, container)
init_api_resources(flask_resource_manager)
return app

View File

@ -1,7 +1,11 @@
import flask_restful
from flask import send_from_directory
import logging
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
import flask_restful
from flask import make_response, send_file
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
logger = logging.getLogger(__file__)
class PBAFileDownload(flask_restful.Resource):
@ -9,7 +13,17 @@ class PBAFileDownload(flask_restful.Resource):
File download endpoint used by monkey to download user's PBA file
"""
def __init__(self, file_storage_service: IFileStorageService):
self._file_storage_service = file_storage_service
# Used by monkey. can't secure.
def get(self, filename):
custom_pba_dir = PostBreachFilesService.get_custom_pba_directory()
return send_from_directory(custom_pba_dir, filename)
def get(self, filename: str):
try:
file = self._file_storage_service.open_file(filename)
# `send_file()` handles the closing of the open file.
return send_file(file, mimetype="application/octet-stream")
except FileRetrievalError as err:
error_msg = f"Failed to open file {filename}: {err}"
logger.error(error_msg)
return make_response({"error": error_msg}, 404)

View File

@ -1,15 +1,19 @@
import copy
import logging
from http import HTTPStatus
import flask_restful
from flask import Response, request, send_from_directory
from flask import Response, make_response, request, send_file
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename
from common.config_value_paths import PBA_LINUX_FILENAME_PATH, PBA_WINDOWS_FILENAME_PATH
from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
from monkey_island.cc.services.config import ConfigService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
logger = logging.getLogger(__file__)
# Front end uses these strings to identify which files to work with (linux or windows)
LINUX_PBA_TYPE = "PBAlinux"
@ -21,42 +25,60 @@ class FileUpload(flask_restful.Resource):
File upload endpoint used to exchange files with filepond component on the front-end
"""
def __init__(self, file_storage_service: IFileStorageService):
self._file_storage_service = file_storage_service
# TODO: Fix references/coupling to filepond
# TODO: Add comment explaining why this is basically a duplicate of the endpoint in the
# PBAFileDownload resource.
@jwt_required
def get(self, file_type):
def get(self, target_os):
"""
Sends file to filepond
:param file_type: Type indicates which file to send, linux or windows
:param target_os: Indicates which file to send, linux or windows
:return: Returns file contents
"""
if self.is_pba_file_type_supported(file_type):
if self._is_target_os_supported(target_os):
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
# Verify that file_name is indeed a file from config
if file_type == LINUX_PBA_TYPE:
if target_os == LINUX_PBA_TYPE:
# TODO: Make these paths Tuples so we don't need to copy them
filename = ConfigService.get_config_value(copy.deepcopy(PBA_LINUX_FILENAME_PATH))
else:
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
return send_from_directory(PostBreachFilesService.get_custom_pba_directory(), filename)
try:
file = self._file_storage_service.open_file(filename)
# `send_file()` handles the closing of the open file.
return send_file(file, mimetype="application/octet-stream")
except FileRetrievalError as err:
error_msg = f"Failed to open file {filename}: {err}"
logger.error(error_msg)
return make_response({"error": error_msg}, 404)
@jwt_required
def post(self, file_type):
def post(self, target_os):
"""
Receives user's uploaded file from filepond
:param file_type: Type indicates which file was received, linux or windows
:param target_os: Type indicates which file was received, linux or windows
:return: Returns flask response object with uploaded file's filename
"""
if self.is_pba_file_type_supported(file_type):
if self._is_target_os_supported(target_os):
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
filename = FileUpload.upload_pba_file(
request.files["filepond"], (file_type == LINUX_PBA_TYPE)
filename = self._upload_pba_file(
# TODO: This "filepond" string can be changed to be more generic in the `react-filepond`
# component.
request.files["filepond"],
(target_os == LINUX_PBA_TYPE),
)
response = Response(response=filename, status=200, mimetype="text/plain")
return response
@staticmethod
def upload_pba_file(file_storage: FileStorage, is_linux=True):
def _upload_pba_file(self, file_storage: FileStorage, is_linux=True):
"""
Uploads PBA file to island's file system
:param request_: Request object containing PBA file
@ -64,9 +86,7 @@ class FileUpload(flask_restful.Resource):
:return: filename string
"""
filename = secure_filename(file_storage.filename)
file_contents = file_storage.read()
PostBreachFilesService.save_file(filename, file_contents)
self._file_storage_service.save_file(filename, file_storage.stream)
ConfigService.set_config_value(
(PBA_LINUX_FILENAME_PATH if is_linux else PBA_WINDOWS_FILENAME_PATH), filename
@ -75,25 +95,25 @@ class FileUpload(flask_restful.Resource):
return filename
@jwt_required
def delete(self, file_type):
def delete(self, target_os):
"""
Deletes file that has been deleted on the front end
:param file_type: Type indicates which file was deleted, linux of windows
:param target_os: Type indicates which file was deleted, linux of windows
:return: Empty response
"""
if self.is_pba_file_type_supported(file_type):
if self._is_target_os_supported(target_os):
return Response(status=HTTPStatus.UNPROCESSABLE_ENTITY, mimetype="text/plain")
filename_path = (
PBA_LINUX_FILENAME_PATH if file_type == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
PBA_LINUX_FILENAME_PATH if target_os == "PBAlinux" else PBA_WINDOWS_FILENAME_PATH
)
filename = ConfigService.get_config_value(filename_path)
if filename:
PostBreachFilesService.remove_file(filename)
self._file_storage_service.delete_file(filename)
ConfigService.set_config_value(filename_path, "")
return {}
return make_response({}, 200)
@staticmethod
def is_pba_file_type_supported(file_type: str) -> bool:
return file_type not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}
def _is_target_os_supported(target_os: str) -> bool:
return target_os not in {LINUX_PBA_TYPE, WINDOWS_PBA_TYPE}

View File

@ -16,6 +16,7 @@ MONKEY_ISLAND_DIR_BASE_PATH = str(Path(__file__).parent.parent)
if str(MONKEY_ISLAND_DIR_BASE_PATH) not in sys.path:
sys.path.insert(0, MONKEY_ISLAND_DIR_BASE_PATH)
from common import DIContainer # noqa: E402
from common.version import get_version # noqa: E402
from monkey_island.cc.app import init_app # noqa: E402
from monkey_island.cc.arg_parser import IslandCmdArgs # noqa: E402
@ -47,7 +48,7 @@ def run_monkey_island():
_exit_on_invalid_config_options(config_options)
_configure_logging(config_options)
_initialize_globals(config_options.data_dir)
container = _initialize_di_container(config_options.data_dir)
mongo_db_process = None
if config_options.start_mongodb:
@ -56,7 +57,7 @@ def run_monkey_island():
_connect_to_mongodb(mongo_db_process)
_configure_gevent_exception_handling(config_options.data_dir)
_start_island_server(island_args.setup_only, config_options)
_start_island_server(island_args.setup_only, config_options, container)
def _extract_config(island_args: IslandCmdArgs) -> IslandConfigOptions:
@ -88,8 +89,8 @@ def _configure_logging(config_options):
setup_logging(config_options.data_dir, config_options.log_level)
def _initialize_globals(data_dir: Path):
initialize_services(data_dir)
def _initialize_di_container(data_dir: Path) -> DIContainer:
return initialize_services(data_dir)
def _start_mongodb(data_dir: Path) -> MongoDbProcess:
@ -127,9 +128,11 @@ def _configure_gevent_exception_handling(data_dir):
hub.handle_error = GeventHubErrorHandler(hub, logger)
def _start_island_server(should_setup_only, config_options: IslandConfigOptions):
def _start_island_server(
should_setup_only: bool, config_options: IslandConfigOptions, container: DIContainer
):
populate_exporter_list()
app = init_app(mongo_setup.MONGO_URL)
app = init_app(mongo_setup.MONGO_URL, container)
if should_setup_only:
logger.warning("Setup only flag passed. Exiting.")

View File

@ -1,2 +1,5 @@
from .i_file_storage_service import IFileStorageService, FileRetrievalError
from .directory_file_storage_service import DirectoryFileStorageService
from .authentication.authentication_service import AuthenticationService
from .authentication.json_file_user_datastore import JsonFileUserDatastore

View File

@ -1,3 +1,5 @@
from pathlib import Path
import bcrypt
from common.utils.exceptions import (
@ -23,7 +25,7 @@ class AuthenticationService:
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
# not a priority.
@classmethod
def initialize(cls, data_dir: str, user_datastore: IUserDatastore):
def initialize(cls, data_dir: Path, user_datastore: IUserDatastore):
cls.DATA_DIR = data_dir
cls.user_datastore = user_datastore

View File

@ -0,0 +1,75 @@
import logging
import shutil
from pathlib import Path
from typing import BinaryIO
from common.utils.file_utils import get_all_regular_files_in_directory
from monkey_island.cc.server_utils.file_utils import create_secure_directory
from . import FileRetrievalError, IFileStorageService
logger = logging.getLogger(__name__)
class DirectoryFileStorageService(IFileStorageService):
"""
A implementation of IFileStorageService that reads and writes files from/to the local
filesystem.
"""
def __init__(self, storage_directory: Path):
"""
:param storage_directory: A Path object representing the directory where files will be
stored. If the directory does not exist, it will be created.
"""
if storage_directory.exists() and not storage_directory.is_dir():
raise ValueError(f"The provided path must point to a directory: {storage_directory}")
if not storage_directory.exists():
create_secure_directory(storage_directory)
self._storage_directory = storage_directory
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
safe_file_path = self._get_safe_file_path(unsafe_file_name)
logger.debug(f"Saving file contents to {safe_file_path}")
with open(safe_file_path, "wb") as dest:
shutil.copyfileobj(file_contents, dest)
def open_file(self, unsafe_file_name: str) -> BinaryIO:
safe_file_path = self._get_safe_file_path(unsafe_file_name)
try:
logger.debug(f"Opening {safe_file_path}")
return open(safe_file_path, "rb")
except OSError as err:
logger.error(err)
raise FileRetrievalError(f"Failed to retrieve file {safe_file_path}: {err}") from err
def delete_file(self, unsafe_file_name: str):
safe_file_path = self._get_safe_file_path(unsafe_file_name)
try:
logger.debug(f"Deleting {safe_file_path}")
safe_file_path.unlink()
except FileNotFoundError:
# This method is idempotent.
pass
def _get_safe_file_path(self, unsafe_file_name: str):
# Remove any path information from the file name.
safe_file_name = Path(unsafe_file_name).resolve().name
safe_file_path = (self._storage_directory / safe_file_name).resolve()
# This is a paranoid check to avoid directory traversal attacks.
if self._storage_directory.resolve() not in safe_file_path.parents:
raise ValueError(f"The file named {unsafe_file_name} can not be safely retrieved")
logger.debug(f"Unsafe file name {unsafe_file_name} sanitized: {safe_file_path}")
return safe_file_path
def delete_all_files(self):
for file in get_all_regular_files_in_directory(self._storage_directory):
logger.debug(f"Deleting {file}")
file.unlink()

View File

@ -0,0 +1,53 @@
import abc
from typing import BinaryIO
class FileRetrievalError(ValueError):
pass
class IFileStorageService(metaclass=abc.ABCMeta):
"""
A service that allows the storage and retrieval of individual files.
"""
@abc.abstractmethod
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
"""
Save a file, identified by a name
:param unsafe_file_name: An unsanitized file name that will identify the file
:param file_contents: The data to be stored in the file
"""
pass
@abc.abstractmethod
def open_file(self, unsafe_file_name: str) -> BinaryIO:
"""
Open a file and return a file-like object
:param unsafe_file_name: An unsanitized file name that identifies the file to be opened
:return: A file-like object providing access to the file's contents
:rtype: io.BinaryIO
:raises FileRetrievalError: if the file cannot be opened
"""
pass
@abc.abstractmethod
def delete_file(self, unsafe_file_name: str):
"""
Delete a file
This method will delete the file specified by `unsafe_file_name`. This operation is
idempotent and will succeed if the file to be deleted does not exist.
:param unsafe_file_name: An unsanitized file name that identifies the file to be deleted
"""
pass
@abc.abstractmethod
def delete_all_files(self):
"""
Delete all files that have been stored using this service.
"""
pass

View File

@ -1,10 +1,22 @@
from pathlib import Path
from common import DIContainer
from monkey_island.cc.services import DirectoryFileStorageService, IFileStorageService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
from . import AuthenticationService, JsonFileUserDatastore
def initialize_services(data_dir):
PostBreachFilesService.initialize(data_dir)
def initialize_services(data_dir: Path) -> DIContainer:
container = DIContainer()
container.register_instance(
IFileStorageService, DirectoryFileStorageService(data_dir / "custom_pbas")
)
# This is temporary until we get DI all worked out.
PostBreachFilesService.initialize(container.resolve(IFileStorageService))
LocalMonkeyRunService.initialize(data_dir)
AuthenticationService.initialize(data_dir, JsonFileUserDatastore(data_dir))
return container

View File

@ -1,46 +1,24 @@
import logging
import os
from monkey_island.cc.server_utils.file_utils import create_secure_directory
from monkey_island.cc.services import IFileStorageService
logger = logging.getLogger(__name__)
# TODO: This service wraps an IFileStorageService for the sole purpose of making the
# `remove_PBA_files()` method available to the ConfigService. This whole service can be
# removed once ConfigService is refactored to be stateful (it already is but everything is
# still statically/globally scoped) and use dependency injection.
class PostBreachFilesService:
DATA_DIR = None
CUSTOM_PBA_DIRNAME = "custom_pbas"
_file_storage_service = None
# TODO: A number of these services should be instance objects instead of
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
# not a priority.
@classmethod
def initialize(cls, data_dir):
cls.DATA_DIR = data_dir
custom_pba_dir = cls.get_custom_pba_directory()
create_secure_directory(custom_pba_dir)
def initialize(cls, file_storage_service: IFileStorageService):
cls._file_storage_service = file_storage_service
@staticmethod
def save_file(filename: str, file_contents: bytes):
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), filename)
with open(file_path, "wb") as f:
f.write(file_contents)
@staticmethod
def remove_PBA_files():
for f in os.listdir(PostBreachFilesService.get_custom_pba_directory()):
PostBreachFilesService.remove_file(f)
@staticmethod
def remove_file(file_name):
file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), file_name)
try:
if os.path.exists(file_path):
os.remove(file_path)
except OSError as e:
logger.error("Can't remove previously uploaded post breach files: %s" % e)
@staticmethod
def get_custom_pba_directory():
return os.path.join(
PostBreachFilesService.DATA_DIR, PostBreachFilesService.CUSTOM_PBA_DIRNAME
)
@classmethod
def remove_PBA_files(cls):
cls._file_storage_service.delete_all_files()

View File

@ -3,6 +3,7 @@ import os
import platform
import stat
import subprocess
from pathlib import Path
from shutil import copyfile
from monkey_island.cc.resources.monkey_download import get_agent_executable_path
@ -19,7 +20,7 @@ class LocalMonkeyRunService:
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
# not a priority.
@classmethod
def initialize(cls, data_dir):
def initialize(cls, data_dir: Path):
cls.DATA_DIR = data_dir
@staticmethod

View File

@ -7,6 +7,9 @@ if is_windows_os():
FULL_CONTROL = 2032127
ACE_ACCESS_MODE_GRANT_ACCESS = win32security.GRANT_ACCESS
ACE_INHERIT_OBJECT_AND_CONTAINER = 3
else:
import os
import stat
def _get_acl_and_sid_from_path(path: str):
@ -33,3 +36,12 @@ def assert_windows_permissions(path: str):
assert ace_sid == user_sid
assert ace_permissions == FULL_CONTROL and ace_access_mode == ACE_ACCESS_MODE_GRANT_ACCESS
assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER
def assert_linux_permissions(path: str):
st = os.stat(path)
expected_mode = stat.S_IRWXU
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
assert expected_mode == actual_mode

View File

@ -0,0 +1,283 @@
import abc
import pytest
from common import DIContainer
class IServiceA(metaclass=abc.ABCMeta):
@abc.abstractmethod
def do_something(self):
pass
class IServiceB(metaclass=abc.ABCMeta):
pass
class ServiceA(IServiceA):
def do_something(self):
pass
class ServiceB(IServiceB):
pass
class TestClass1:
__test__ = False
def __init__(self, service_a: IServiceA):
self.service_a = service_a
class TestClass2:
__test__ = False
def __init__(self, service_b: IServiceB):
self.service_b = service_b
class TestClass3:
__test__ = False
def __init__(self, service_a: IServiceA, service_b: IServiceB):
self.service_a = service_a
self.service_b = service_b
@pytest.fixture
def container():
return DIContainer()
def test_register_resolve(container):
container.register(IServiceA, ServiceA)
test_1 = container.resolve(TestClass1)
assert isinstance(test_1.service_a, ServiceA)
def test_correct_instance_type_injected(container):
container.register(IServiceA, ServiceA)
container.register(IServiceB, ServiceB)
test_1 = container.resolve(TestClass1)
test_2 = container.resolve(TestClass2)
assert isinstance(test_1.service_a, ServiceA)
assert isinstance(test_2.service_b, ServiceB)
def test_multiple_correct_instance_types_injected(container):
container.register(IServiceA, ServiceA)
container.register(IServiceB, ServiceB)
test_3 = container.resolve(TestClass3)
assert isinstance(test_3.service_a, ServiceA)
assert isinstance(test_3.service_b, ServiceB)
def test_register_instance(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
test_1 = container.resolve(TestClass1)
assert id(service_a_instance) == id(test_1.service_a)
def test_register_multiple_instances(container):
service_a_instance = ServiceA()
service_b_instance = ServiceB()
container.register_instance(IServiceA, service_a_instance)
container.register_instance(IServiceB, service_b_instance)
test_3 = container.resolve(TestClass3)
assert id(service_a_instance) == id(test_3.service_a)
assert id(service_b_instance) == id(test_3.service_b)
def test_register_mixed_instance_and_type(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
container.register(IServiceB, ServiceB)
test_2 = container.resolve(TestClass2)
test_3 = container.resolve(TestClass3)
assert id(service_a_instance) == id(test_3.service_a)
assert isinstance(test_2.service_b, ServiceB)
assert isinstance(test_3.service_b, ServiceB)
assert id(test_2.service_b) != id(test_3.service_b)
def test_unregistered_type():
container = DIContainer()
with pytest.raises(ValueError):
container.resolve(TestClass1)
def test_type_registration_overwritten(container):
class ServiceA2(IServiceA):
def do_something(self):
pass
container.register(IServiceA, ServiceA)
container.register(IServiceA, ServiceA2)
test_1 = container.resolve(TestClass1)
assert isinstance(test_1.service_a, ServiceA2)
def test_instance_registration_overwritten(container):
service_a_instance_1 = ServiceA()
service_a_instance_2 = ServiceA()
container.register_instance(IServiceA, service_a_instance_1)
container.register_instance(IServiceA, service_a_instance_2)
test_1 = container.resolve(TestClass1)
assert id(test_1.service_a) != id(service_a_instance_1)
assert id(test_1.service_a) == id(service_a_instance_2)
def test_type_overrides_instance(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
container.register(IServiceA, ServiceA)
test_1 = container.resolve(TestClass1)
assert id(test_1.service_a) != id(service_a_instance)
assert isinstance(test_1.service_a, ServiceA)
def test_instance_overrides_type(container):
service_a_instance = ServiceA()
container.register(IServiceA, ServiceA)
container.register_instance(IServiceA, service_a_instance)
test_1 = container.resolve(TestClass1)
assert id(test_1.service_a) == id(service_a_instance)
def test_release_type(container):
container.register(IServiceA, ServiceA)
container.release(IServiceA)
with pytest.raises(ValueError):
container.resolve(TestClass1)
def test_release_instance(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
container.release(IServiceA)
with pytest.raises(ValueError):
container.resolve(TestClass1)
class IServiceC(metaclass=abc.ABCMeta):
pass
class ServiceC(IServiceC):
def __init__(self, service_a: IServiceA):
self.service_a = service_a
class TestClass4:
__test__ = False
def __init__(self, service_c: IServiceC):
self.service_c = service_c
def test_recursive_resolution__depth_2(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
container.register(IServiceC, ServiceC)
test4 = container.resolve(TestClass4)
assert isinstance(test4.service_c, ServiceC)
assert id(test4.service_c.service_a) == id(service_a_instance)
class IServiceD(metaclass=abc.ABCMeta):
pass
class ServiceD(IServiceD):
def __init__(self, service_c: IServiceC, service_b: IServiceB):
self.service_b = service_b
self.service_c = service_c
class TestClass5:
__test__ = False
def __init__(self, service_d: IServiceD):
self.service_d = service_d
def test_recursive_resolution__depth_3(container):
container.register(IServiceA, ServiceA)
container.register(IServiceB, ServiceB)
container.register(IServiceC, ServiceC)
container.register(IServiceD, ServiceD)
test5 = container.resolve(TestClass5)
assert isinstance(test5.service_d, ServiceD)
assert isinstance(test5.service_d.service_b, ServiceB)
assert isinstance(test5.service_d.service_c, ServiceC)
assert isinstance(test5.service_d.service_c.service_a, ServiceA)
def test_resolve_registered_interface(container):
container.register(IServiceA, ServiceA)
resolved_instance = container.resolve(IServiceA)
assert isinstance(resolved_instance, ServiceA)
def test_resolve_registered_instance(container):
service_a_instance = ServiceA()
container.register_instance(IServiceA, service_a_instance)
service_a_actual_instance = container.resolve(IServiceA)
assert id(service_a_actual_instance) == id(service_a_instance)
def test_resolve_dependencies(container):
container.register(IServiceA, ServiceA)
container.register(IServiceB, ServiceB)
dependencies = container.resolve_dependencies(TestClass3)
assert isinstance(dependencies[0], ServiceA)
assert isinstance(dependencies[1], ServiceB)
def test_register_instance_as_type(container):
service_a_instance = ServiceA()
with pytest.raises(TypeError):
container.register(IServiceA, service_a_instance)
def test_register_conflicting_type(container):
with pytest.raises(TypeError):
container.register(IServiceA, ServiceB)
def test_register_instance_with_conflicting_type(container):
service_b_instance = ServiceB()
with pytest.raises(TypeError):
container.register_instance(IServiceA, service_b_instance)

View File

@ -1,8 +1,14 @@
import os
import pytest
from tests.utils import add_files_to_dir, add_subdirs_to_dir
from common.utils.file_utils import InvalidPath, expand_path, get_file_sha256_hash
from common.utils.file_utils import (
InvalidPath,
expand_path,
get_all_regular_files_in_directory,
get_file_sha256_hash,
)
def test_expand_user(patched_home_env):
@ -26,3 +32,32 @@ def test_expand_path__empty_path_provided():
def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash):
assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash
SUBDIRS = ["subdir1", "subdir2"]
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg"]
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path, SUBDIRS)
expected_return_value = []
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path, SUBDIRS)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
subdirs = add_subdirs_to_dir(tmp_path, SUBDIRS)
add_files_to_dir(subdirs[0], FILES)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value

View File

@ -1,66 +1,22 @@
import os
import pytest
from tests.utils import is_user_admin
from tests.utils import add_files_to_dir, is_user_admin
from common.utils.file_utils import get_all_regular_files_in_directory
from infection_monkey.utils.dir_utils import (
file_extension_filter,
filter_files,
get_all_regular_files_in_directory,
is_not_shortcut_filter,
is_not_symlink_filter,
)
SHORTCUT = "shortcut.lnk"
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT]
SUBDIRS = ["subdir1", "subdir2"]
def add_subdirs_to_dir(parent_dir):
subdirs = [parent_dir / s for s in SUBDIRS]
for subdir in subdirs:
subdir.mkdir()
return subdirs
def add_files_to_dir(parent_dir):
files = [parent_dir / f for f in FILES]
for f in files:
f.touch()
return files
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path)
expected_return_value = []
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path)
files = add_files_to_dir(tmp_path)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
subdirs = add_subdirs_to_dir(tmp_path)
add_files_to_dir(subdirs[0])
files = add_files_to_dir(tmp_path)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_filter_files__no_results(tmp_path):
add_files_to_dir(tmp_path)
add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
@ -69,7 +25,7 @@ def test_filter_files__no_results(tmp_path):
def test_filter_files__all_true(tmp_path):
files = add_files_to_dir(tmp_path)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
@ -79,7 +35,7 @@ def test_filter_files__all_true(tmp_path):
def test_filter_files__multiple_filters(tmp_path):
files = add_files_to_dir(tmp_path)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files[4:6])
files_in_dir = get_all_regular_files_in_directory(tmp_path)
@ -93,7 +49,7 @@ def test_filter_files__multiple_filters(tmp_path):
def test_file_extension_filter(tmp_path):
valid_extensions = {".zip", ".xyz"}
files = add_files_to_dir(tmp_path)
files = add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)])
@ -105,7 +61,7 @@ def test_file_extension_filter(tmp_path):
os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows"
)
def test_is_not_symlink_filter(tmp_path):
files = add_files_to_dir(tmp_path)
files = add_files_to_dir(tmp_path, FILES)
link_path = tmp_path / "symlink.test"
link_path.symlink_to(files[0], target_is_directory=False)
@ -118,7 +74,7 @@ def test_is_not_symlink_filter(tmp_path):
def test_is_not_shortcut_filter(tmp_path):
add_files_to_dir(tmp_path)
add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))

View File

@ -1,3 +1,5 @@
from unittest.mock import MagicMock
import flask_jwt_extended
import flask_restful
import pytest
@ -9,15 +11,28 @@ import monkey_island.cc.resources.island_mode
from monkey_island.cc.services.representations import output_json
@pytest.fixture(scope="session")
@pytest.fixture
def flask_client(monkeypatch_session):
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
with mock_init_app().test_client() as client:
container = MagicMock()
container.resolve_dependencies.return_value = []
with mock_init_app(container).test_client() as client:
yield client
def mock_init_app():
@pytest.fixture
def build_flask_client(monkeypatch_session):
def inner(container):
monkeypatch_session.setattr(flask_jwt_extended, "verify_jwt_in_request", lambda: None)
return mock_init_app(container).test_client()
return inner
def mock_init_app(container):
app = Flask(__name__)
app.config["SECRET_KEY"] = "test_key"
@ -25,7 +40,8 @@ def mock_init_app():
api.representations = {"application/json": output_json}
monkey_island.cc.app.init_app_url_rules(app)
monkey_island.cc.app.init_api_resources(api)
flask_resource_manager = monkey_island.cc.app.FlaskDIWrapper(api, container)
monkey_island.cc.app.init_api_resources(flask_resource_manager)
flask_jwt_extended.JWTManager(app)

View File

@ -0,0 +1,54 @@
import io
from typing import BinaryIO
import pytest
from common import DIContainer
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
FILE_NAME = "test_file"
FILE_CONTENTS = b"HelloWorld!"
class MockFileStorageService(IFileStorageService):
def __init__(self):
self._file = io.BytesIO(FILE_CONTENTS)
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
pass
def open_file(self, unsafe_file_name: str) -> BinaryIO:
if unsafe_file_name != FILE_NAME:
raise FileRetrievalError()
return self._file
def delete_file(self, unsafe_file_name: str):
pass
def delete_all_files(self):
pass
@pytest.fixture
def flask_client(build_flask_client, tmp_path):
container = DIContainer()
container.register(IFileStorageService, MockFileStorageService)
with build_flask_client(container) as flask_client:
yield flask_client
def test_file_download_endpoint(tmp_path, flask_client):
resp = flask_client.get(f"/api/pba/download/{FILE_NAME}")
assert resp.status_code == 200
assert next(resp.response) == FILE_CONTENTS
def test_file_download_endpoint_404(tmp_path, flask_client):
nonexistant_file_name = "nonexistant_file"
resp = flask_client.get(f"/api/pba/download/{nonexistant_file_name}")
assert resp.status_code == 404

View File

@ -1,10 +1,16 @@
import io
from typing import BinaryIO
import pytest
from tests.utils import raise_
from common import DIContainer
from monkey_island.cc.resources.pba_file_upload import LINUX_PBA_TYPE, WINDOWS_PBA_TYPE
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
from monkey_island.cc.services import FileRetrievalError, IFileStorageService
TEST_FILE = b"""-----------------------------1
TEST_FILE_CONTENTS = b"m0nk3y"
TEST_FILE = (
b"""-----------------------------1
Content-Disposition: form-data; name="filepond"
{}
@ -12,13 +18,11 @@ Content-Disposition: form-data; name="filepond"
Content-Disposition: form-data; name="filepond"; filename="test.py"
Content-Type: text/x-python
m0nk3y
"""
+ TEST_FILE_CONTENTS
+ b"""
-----------------------------1--"""
@pytest.fixture(autouse=True)
def custom_pba_directory(tmpdir):
PostBreachFilesService.initialize(tmpdir)
)
@pytest.fixture
@ -35,8 +39,41 @@ def mock_get_config_value(monkeypatch):
)
class MockFileStorageService(IFileStorageService):
def __init__(self):
self._file = None
def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):
self._file = io.BytesIO(file_contents.read())
def open_file(self, unsafe_file_name: str) -> BinaryIO:
if self._file is None:
raise FileRetrievalError()
return self._file
def delete_file(self, unsafe_file_name: str):
self._file = None
def delete_all_files(self):
self.delete_file("")
@pytest.fixture
def file_storage_service():
return MockFileStorageService()
@pytest.fixture
def flask_client(build_flask_client, file_storage_service):
container = DIContainer()
container.register_instance(IFileStorageService, file_storage_service)
with build_flask_client(container) as flask_client:
yield flask_client
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config_value):
def test_pba_file_upload_post(flask_client, pba_os, mock_set_config_value):
resp = flask_client.post(
f"/api/file-upload/{pba_os}",
data=TEST_FILE,
@ -46,7 +83,7 @@ def test_pba_file_upload_post(flask_client, pba_os, monkeypatch, mock_set_config
assert resp.status_code == 200
def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_config_value):
def test_pba_file_upload_post__invalid(flask_client, mock_set_config_value):
resp = flask_client.post(
"/api/file-upload/bogus",
data=TEST_FILE,
@ -58,12 +95,9 @@ def test_pba_file_upload_post__invalid(flask_client, monkeypatch, mock_set_confi
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
def test_pba_file_upload_post__internal_server_error(
flask_client, pba_os, monkeypatch, mock_set_config_value
flask_client, pba_os, mock_set_config_value, file_storage_service
):
monkeypatch.setattr(
"monkey_island.cc.resources.pba_file_upload.FileUpload.upload_pba_file",
lambda x, y: raise_(Exception()),
)
file_storage_service.save_file = lambda x, y: raise_(Exception())
resp = flask_client.post(
f"/api/file-upload/{pba_os}",
@ -75,16 +109,14 @@ def test_pba_file_upload_post__internal_server_error(
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
def test_pba_file_upload_get__file_not_found(
flask_client, pba_os, monkeypatch, mock_get_config_value
):
def test_pba_file_upload_get__file_not_found(flask_client, pba_os, mock_get_config_value):
resp = flask_client.get(f"/api/file-upload/{pba_os}?load=bogus_mogus.py")
assert resp.status_code == 404
@pytest.mark.parametrize("pba_os", [LINUX_PBA_TYPE, WINDOWS_PBA_TYPE])
def test_pba_file_upload_endpoint(
flask_client, pba_os, monkeypatch, mock_get_config_value, mock_set_config_value
flask_client, pba_os, mock_get_config_value, mock_set_config_value
):
resp_post = flask_client.post(
f"/api/file-upload/{pba_os}",
@ -95,7 +127,7 @@ def test_pba_file_upload_endpoint(
resp_get = flask_client.get(f"/api/file-upload/{pba_os}?load=test.py")
assert resp_get.status_code == 200
assert resp_get.data.decode() == "m0nk3y"
assert resp_get.data == TEST_FILE_CONTENTS
# Closing the response closes the file handle, else it can't be deleted
resp_get.close()
@ -111,7 +143,7 @@ def test_pba_file_upload_endpoint(
def test_pba_file_upload_endpoint__invalid(
flask_client, monkeypatch, mock_set_config_value, mock_get_config_value
flask_client, mock_set_config_value, mock_get_config_value
):
resp_post = flask_client.post(
"/api/file-upload/bogus",

View File

@ -2,7 +2,7 @@ import os
import stat
import pytest
from tests.monkey_island.utils import assert_windows_permissions
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
from monkey_island.cc.server_utils.file_utils import (
create_secure_directory,
@ -39,12 +39,8 @@ def test_create_secure_directory__no_parent_dir(test_path_nested):
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
def test_create_secure_directory__perm_linux(test_path):
create_secure_directory(test_path)
st = os.stat(test_path)
expected_mode = stat.S_IRWXU
actual_mode = st.st_mode & (stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
assert expected_mode == actual_mode
assert_linux_permissions(test_path)
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")

View File

@ -0,0 +1,135 @@
import io
from pathlib import Path
import pytest
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions
from monkey_island.cc.server_utils.file_utils import is_windows_os
from monkey_island.cc.services import DirectoryFileStorageService, FileRetrievalError
def test_error_if_storage_directory_is_file(tmp_path):
new_file = tmp_path / "new_file.txt"
new_file.write_text("HelloWorld!")
with pytest.raises(ValueError):
DirectoryFileStorageService(new_file)
def test_directory_created(tmp_path):
new_dir = tmp_path / "new_dir"
DirectoryFileStorageService(new_dir)
assert new_dir.exists() and new_dir.is_dir()
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
def test_directory_permissions__linux(tmp_path):
new_dir = tmp_path / "new_dir"
DirectoryFileStorageService(new_dir)
assert_linux_permissions(new_dir)
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
def test_directory_permissions__windows(tmp_path):
new_dir = tmp_path / "new_dir"
DirectoryFileStorageService(new_dir)
assert_windows_permissions(new_dir)
def save_file(tmp_path, file_path_prefix=""):
file_name = "test.txt"
file_contents = "Hello World!"
expected_file_path = tmp_path / file_name
fss = DirectoryFileStorageService(tmp_path)
fss.save_file(Path(file_path_prefix) / file_name, io.BytesIO(file_contents.encode()))
assert expected_file_path.is_file()
assert expected_file_path.read_text() == file_contents
def delete_file(tmp_path, file_path_prefix=""):
file_name = "file.txt"
file = tmp_path / file_name
file.touch()
assert file.is_file()
fss = DirectoryFileStorageService(tmp_path)
fss.delete_file(Path(file_path_prefix) / file_name)
assert not file.exists()
def open_file(tmp_path, file_path_prefix=""):
file_name = "test.txt"
expected_file_contents = "Hello World!"
expected_file_path = tmp_path / file_name
expected_file_path.write_text(expected_file_contents)
fss = DirectoryFileStorageService(tmp_path)
with fss.open_file(Path(file_path_prefix) / file_name) as f:
actual_file_contents = f.read()
assert actual_file_contents == expected_file_contents.encode()
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
def test_fn(tmp_path, fn):
fn(tmp_path)
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
def test_fn__ignore_relative_path(tmp_path, fn):
fn(tmp_path, "../../")
@pytest.mark.parametrize("fn", [save_file, open_file, delete_file])
def test_fn__ignore_absolute_path(tmp_path, fn):
if is_windows_os():
fn(tmp_path, "C:\\Windows")
else:
fn(tmp_path, "/home/")
def test_remove_all_files(tmp_path):
for filename in ["1.txt", "2.txt", "3.txt"]:
(tmp_path / filename).touch()
fss = DirectoryFileStorageService(tmp_path)
fss.delete_all_files()
for file in tmp_path.iterdir():
assert False, f"{tmp_path} was expected to be empty, but contained files"
def test_remove_all_files__skip_directories(tmp_path):
test_dir = tmp_path / "test_dir"
test_dir.mkdir()
for filename in ["1.txt", "2.txt", "3.txt"]:
(tmp_path / filename).touch()
fss = DirectoryFileStorageService(tmp_path)
fss.delete_all_files()
for file in tmp_path.iterdir():
assert file.name == test_dir.name
def test_remove_nonexistant_file(tmp_path):
fss = DirectoryFileStorageService(tmp_path)
# This test will fail if this call raises an exception.
fss.delete_file("nonexistant_file.txt")
def test_open_nonexistant_file(tmp_path):
fss = DirectoryFileStorageService(tmp_path)
with pytest.raises(FileRetrievalError):
fss.open_file("nonexistant_file.txt")

View File

@ -1,29 +1,31 @@
import io
import os
import pytest
from tests.monkey_island.utils import assert_windows_permissions
from tests.utils import raise_
from monkey_island.cc.server_utils.file_utils import is_windows_os
from monkey_island.cc.services import DirectoryFileStorageService
from monkey_island.cc.services.post_breach_files import PostBreachFilesService
@pytest.fixture
def file_storage_service(tmp_path):
return DirectoryFileStorageService(tmp_path)
@pytest.fixture(autouse=True)
def custom_pba_directory(tmpdir):
PostBreachFilesService.initialize(tmpdir)
def post_breach_files_service(file_storage_service):
PostBreachFilesService.initialize(file_storage_service)
def create_custom_pba_file(filename):
PostBreachFilesService.save_file(filename, b"")
def test_remove_pba_files(file_storage_service, tmp_path):
file_storage_service.save_file("linux_file", io.BytesIO(b""))
file_storage_service.save_file("windows_file", io.BytesIO(b""))
assert not dir_is_empty(tmp_path)
def test_remove_pba_files():
create_custom_pba_file("linux_file")
create_custom_pba_file("windows_file")
assert not dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
PostBreachFilesService.remove_PBA_files()
assert dir_is_empty(PostBreachFilesService.get_custom_pba_directory())
assert dir_is_empty(tmp_path)
def dir_is_empty(dir_path):
@ -31,45 +33,11 @@ def dir_is_empty(dir_path):
return len(dir_contents) == 0
@pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.")
def test_custom_pba_dir_permissions_linux():
st = os.stat(PostBreachFilesService.get_custom_pba_directory())
assert st.st_mode == 0o40700
@pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.")
def test_custom_pba_dir_permissions_windows():
pba_dir = PostBreachFilesService.get_custom_pba_directory()
assert_windows_permissions(pba_dir)
def test_remove_failure(monkeypatch):
def test_remove_failure(file_storage_service, monkeypatch):
monkeypatch.setattr(os, "remove", lambda x: raise_(OSError("Permission denied")))
try:
create_custom_pba_file("windows_file")
file_storage_service.save_file("windows_file", io.BytesIO(b""))
PostBreachFilesService.remove_PBA_files()
except Exception as ex:
pytest.fail(f"Unxepected exception: {ex}")
def test_remove_nonexistant_file(monkeypatch):
monkeypatch.setattr(os, "remove", lambda x: raise_(FileNotFoundError("FileNotFound")))
try:
PostBreachFilesService.remove_file("/nonexistant/file")
except Exception as ex:
pytest.fail(f"Unxepected exception: {ex}")
def test_save_file():
FILE_NAME = "test_file"
FILE_CONTENTS = b"hello"
PostBreachFilesService.save_file(FILE_NAME, FILE_CONTENTS)
expected_file_path = os.path.join(PostBreachFilesService.get_custom_pba_directory(), FILE_NAME)
assert os.path.isfile(expected_file_path)
assert FILE_CONTENTS == open(expected_file_path, "rb").read()

View File

@ -1,5 +1,7 @@
import ctypes
import os
from pathlib import Path
from typing import Iterable
def is_user_admin():
@ -11,3 +13,21 @@ def is_user_admin():
def raise_(ex):
raise ex
def add_subdirs_to_dir(parent_dir: Path, subdirs: Iterable[str]) -> Iterable[Path]:
subdir_paths = [parent_dir / s for s in subdirs]
for subdir in subdir_paths:
subdir.mkdir()
return subdir_paths
def add_files_to_dir(parent_dir: Path, file_names: Iterable[str]) -> Iterable[Path]:
files = [parent_dir / f for f in file_names]
for f in files:
f.touch()
return files