Merge pull request #1819 from guardicore/1612-interruptible-ransomware

1612 interruptible ransomware
This commit is contained in:
Mike Salvatore 2022-03-28 09:15:30 -04:00 committed by GitHub
commit 1ec5be908d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 147 additions and 37 deletions

View File

@ -1,6 +1,6 @@
import filecmp
from pathlib import Path
from typing import List, Set
from typing import Iterable, Set
from infection_monkey.utils.dir_utils import (
file_extension_filter,
@ -17,7 +17,7 @@ class ProductionSafeTargetFileSelector:
def __init__(self, targeted_file_extensions: Set[str]):
self._targeted_file_extensions = targeted_file_extensions
def __call__(self, target_dir: Path) -> List[Path]:
def __call__(self, target_dir: Path) -> Iterable[Path]:
file_filters = [
file_extension_filter(self._targeted_file_extensions),
is_not_shortcut_filter,

View File

@ -1,11 +1,11 @@
import logging
import threading
from pathlib import Path
from typing import Callable, List
from typing import Callable, Iterable
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.threading import interruptible_iter
from infection_monkey.utils.threading import interruptible_function, interruptible_iter
from .consts import README_FILE_NAME, README_SRC
from .ransomware_options import RansomwareOptions
@ -18,7 +18,7 @@ class Ransomware:
self,
config: RansomwareOptions,
encrypt_file: Callable[[Path], None],
select_files: Callable[[Path], List[Path]],
select_files: Callable[[Path], Iterable[Path]],
leave_readme: Callable[[Path, Path], None],
telemetry_messenger: ITelemetryMessenger,
):
@ -31,7 +31,9 @@ class Ransomware:
self._target_directory = self._config.target_directory
self._readme_file_path = (
self._target_directory / README_FILE_NAME if self._target_directory else None
self._target_directory / README_FILE_NAME # type: ignore
if self._target_directory
else None
)
def run(self, interrupt: threading.Event):
@ -41,25 +43,26 @@ class Ransomware:
logger.info("Running ransomware payload")
if self._config.encryption_enabled:
file_list = self._find_files()
self._encrypt_files(file_list, interrupt)
files_to_encrypt = self._find_files()
self._encrypt_files(files_to_encrypt, interrupt)
if self._config.readme_enabled:
self._leave_readme_in_target_directory(interrupt)
self._leave_readme_in_target_directory(interrupt=interrupt)
def _find_files(self) -> List[Path]:
def _find_files(self) -> Iterable[Path]:
logger.info(f"Collecting files in {self._target_directory}")
return sorted(self._select_files(self._target_directory))
return self._select_files(self._target_directory) # type: ignore
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
def _encrypt_files(self, files_to_encrypt: Iterable[Path], interrupt: threading.Event):
logger.info(f"Encrypting files in {self._target_directory}")
interrupted_message = (
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
)
for filepath in interruptible_iter(file_list, interrupt, interrupted_message):
interrupted_message = "Received a stop signal, skipping encryption of remaining files"
for filepath in interruptible_iter(files_to_encrypt, interrupt, interrupted_message):
try:
logger.debug(f"Encrypting {filepath}")
# Note that encrypting a single file is not interruptible. This is so that we avoid
# leaving half-encrypted files on the user's system.
self._encrypt_file(filepath)
self._send_telemetry(filepath, True, "")
except Exception as ex:
@ -70,12 +73,9 @@ class Ransomware:
encryption_attempt = FileEncryptionTelem(str(filepath), success, error)
self._telemetry_messenger.send_telemetry(encryption_attempt)
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
if interrupt.is_set():
logger.debug("Received a stop signal, skipping leave readme")
return
@interruptible_function(msg="Received a stop signal, skipping leave readme")
def _leave_readme_in_target_directory(self, *, interrupt: threading.Event):
try:
self._leave_readme(README_SRC, self._readme_file_path)
self._leave_readme(README_SRC, self._readme_file_path) # type: ignore
except Exception as ex:
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")

View File

@ -1,17 +1,17 @@
from pathlib import Path
from typing import Callable, Iterable, List, Set
from typing import Callable, Iterable, Set
def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> List[Path]:
) -> Iterable[Path]:
filtered_files = files
for file_filter in file_filters:
filtered_files = [f for f in filtered_files if file_filter(f)]
filtered_files = filter(file_filter, filtered_files)
return filtered_files

View File

@ -1,7 +1,8 @@
import logging
from functools import wraps
from itertools import count
from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple
from typing import Any, Callable, Iterable, Optional, Tuple
logger = logging.getLogger(__name__)
@ -53,3 +54,56 @@ def interruptible_iter(
break
yield i
def interruptible_function(*, msg: Optional[str] = None, default_return_value: Any = None):
"""
This decorator allows a function to be skipped if an interrupt (threading.Event) is set. This is
useful for interrupting running code without introducing duplicate `if` checks at the beginning
of each function.
Note: It is required that the decorated function accept a keyword-only argument named
"interrupt".
Example:
def run_algorithm(*inputs, interrupt: threading.Event):
return_value = do_action_1(inputs[1], interrupt=interrupt)
return_value = do_action_2(return_value + inputs[2], interrupt=interrupt)
return_value = do_action_3(return_value + inputs[3], interrupt=interrupt)
return return_value
@interruptible_function(msg="Interrupt detected, skipping action 1", default_return_value=0)
def do_action_1(input, *, interrupt: threading.Event):
# Process input
...
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
def do_action_2(input, *, interrupt: threading.Event):
# Process input
...
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
def do_action_2(input, *, interrupt: threading.Event):
# Process input
...
:param str msg: A message to log at the debug level if an interrupt is detected. Defaults to
None.
:param Any default_return_value: A value to return if the wrapped function is not run. Defaults
to None.
"""
def _decorator(fn):
@wraps(fn)
def _wrapper(*args, interrupt: Event, **kwargs):
if interrupt.is_set():
if msg:
logger.debug(msg)
return default_return_value
return fn(*args, interrupt=interrupt, **kwargs)
return _wrapper
return _decorator

View File

@ -24,7 +24,7 @@ def file_selector():
def test_select_targeted_files_only(ransomware_test_data, file_selector):
selected_files = file_selector(ransomware_test_data)
selected_files = list(file_selector(ransomware_test_data))
assert len(selected_files) == 2
assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files

View File

@ -58,10 +58,12 @@ def mock_file_encryptor():
@pytest.fixture
def mock_file_selector(ransomware_test_data):
selected_files = [
ransomware_test_data / ALL_ZEROS_PDF,
ransomware_test_data / TEST_KEYBOARD_TXT,
]
selected_files = iter(
[
ransomware_test_data / ALL_ZEROS_PDF,
ransomware_test_data / TEST_KEYBOARD_TXT,
]
)
return MagicMock(return_value=selected_files)

View File

@ -38,7 +38,7 @@ def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path)
expected_return_value = []
assert get_all_regular_files_in_directory(tmp_path) == expected_return_value
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
@ -63,7 +63,7 @@ def test_filter_files__no_results(tmp_path):
add_files_to_dir(tmp_path)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [lambda _: False])
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
assert len(filtered_files) == 0
@ -109,8 +109,8 @@ def test_is_not_symlink_filter(tmp_path):
link_path = tmp_path / "symlink.test"
link_path.symlink_to(files[0], target_is_directory=False)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [is_not_symlink_filter])
files_in_dir = list(get_all_regular_files_in_directory(tmp_path))
filtered_files = list(filter_files(files_in_dir, [is_not_symlink_filter]))
assert link_path in files_in_dir
assert len(filtered_files) == len(FILES)
@ -121,7 +121,7 @@ def test_is_not_shortcut_filter(tmp_path):
add_files_to_dir(tmp_path)
files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [is_not_shortcut_filter])
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
assert len(filtered_files) == len(FILES) - 1
assert SHORTCUT not in [f.name for f in filtered_files]

View File

@ -1,8 +1,10 @@
import logging
from threading import Event, current_thread
from typing import Any
from infection_monkey.utils.threading import (
create_daemon_thread,
interruptible_function,
interruptible_iter,
run_worker_threads,
)
@ -73,3 +75,55 @@ def test_worker_thread_names():
assert "B-01" in thread_names
assert "B-02" in thread_names
assert len(thread_names) == 6
class MockFunction:
def __init__(self):
self._call_count = 0
@property
def call_count(self):
return self._call_count
@property
def return_value(self):
return 42
def __call__(self, *_, interrupt: Event) -> Any:
self._call_count += 1
return self.return_value
def test_interruptible_decorator_calls_decorated_function():
fn = MockFunction()
int_fn = interruptible_function()(fn)
return_value = int_fn(interrupt=Event())
assert return_value == fn.return_value
assert fn.call_count == 1
def test_interruptible_decorator_skips_decorated_function():
fn = MockFunction()
int_fn = interruptible_function()(fn)
interrupt = Event()
interrupt.set()
return_value = int_fn(interrupt=interrupt)
assert return_value is None
assert fn.call_count == 0
def test_interruptible_decorator_returns_default_value_on_interrupt():
fn = MockFunction()
int_fn = interruptible_function(default_return_value=777)(fn)
interrupt = Event()
interrupt.set()
return_value = int_fn(interrupt=interrupt)
assert return_value == 777
assert fn.call_count == 0