Merge pull request #1819 from guardicore/1612-interruptible-ransomware
1612 interruptible ransomware
This commit is contained in:
commit
1ec5be908d
|
@ -1,6 +1,6 @@
|
|||
import filecmp
|
||||
from pathlib import Path
|
||||
from typing import List, Set
|
||||
from typing import Iterable, Set
|
||||
|
||||
from infection_monkey.utils.dir_utils import (
|
||||
file_extension_filter,
|
||||
|
@ -17,7 +17,7 @@ class ProductionSafeTargetFileSelector:
|
|||
def __init__(self, targeted_file_extensions: Set[str]):
|
||||
self._targeted_file_extensions = targeted_file_extensions
|
||||
|
||||
def __call__(self, target_dir: Path) -> List[Path]:
|
||||
def __call__(self, target_dir: Path) -> Iterable[Path]:
|
||||
file_filters = [
|
||||
file_extension_filter(self._targeted_file_extensions),
|
||||
is_not_shortcut_filter,
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, List
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.utils.threading import interruptible_iter
|
||||
from infection_monkey.utils.threading import interruptible_function, interruptible_iter
|
||||
|
||||
from .consts import README_FILE_NAME, README_SRC
|
||||
from .ransomware_options import RansomwareOptions
|
||||
|
@ -18,7 +18,7 @@ class Ransomware:
|
|||
self,
|
||||
config: RansomwareOptions,
|
||||
encrypt_file: Callable[[Path], None],
|
||||
select_files: Callable[[Path], List[Path]],
|
||||
select_files: Callable[[Path], Iterable[Path]],
|
||||
leave_readme: Callable[[Path, Path], None],
|
||||
telemetry_messenger: ITelemetryMessenger,
|
||||
):
|
||||
|
@ -31,7 +31,9 @@ class Ransomware:
|
|||
|
||||
self._target_directory = self._config.target_directory
|
||||
self._readme_file_path = (
|
||||
self._target_directory / README_FILE_NAME if self._target_directory else None
|
||||
self._target_directory / README_FILE_NAME # type: ignore
|
||||
if self._target_directory
|
||||
else None
|
||||
)
|
||||
|
||||
def run(self, interrupt: threading.Event):
|
||||
|
@ -41,25 +43,26 @@ class Ransomware:
|
|||
logger.info("Running ransomware payload")
|
||||
|
||||
if self._config.encryption_enabled:
|
||||
file_list = self._find_files()
|
||||
self._encrypt_files(file_list, interrupt)
|
||||
files_to_encrypt = self._find_files()
|
||||
self._encrypt_files(files_to_encrypt, interrupt)
|
||||
|
||||
if self._config.readme_enabled:
|
||||
self._leave_readme_in_target_directory(interrupt)
|
||||
self._leave_readme_in_target_directory(interrupt=interrupt)
|
||||
|
||||
def _find_files(self) -> List[Path]:
|
||||
def _find_files(self) -> Iterable[Path]:
|
||||
logger.info(f"Collecting files in {self._target_directory}")
|
||||
return sorted(self._select_files(self._target_directory))
|
||||
return self._select_files(self._target_directory) # type: ignore
|
||||
|
||||
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
|
||||
def _encrypt_files(self, files_to_encrypt: Iterable[Path], interrupt: threading.Event):
|
||||
logger.info(f"Encrypting files in {self._target_directory}")
|
||||
|
||||
interrupted_message = (
|
||||
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
|
||||
)
|
||||
for filepath in interruptible_iter(file_list, interrupt, interrupted_message):
|
||||
interrupted_message = "Received a stop signal, skipping encryption of remaining files"
|
||||
for filepath in interruptible_iter(files_to_encrypt, interrupt, interrupted_message):
|
||||
try:
|
||||
logger.debug(f"Encrypting {filepath}")
|
||||
|
||||
# Note that encrypting a single file is not interruptible. This is so that we avoid
|
||||
# leaving half-encrypted files on the user's system.
|
||||
self._encrypt_file(filepath)
|
||||
self._send_telemetry(filepath, True, "")
|
||||
except Exception as ex:
|
||||
|
@ -70,12 +73,9 @@ class Ransomware:
|
|||
encryption_attempt = FileEncryptionTelem(str(filepath), success, error)
|
||||
self._telemetry_messenger.send_telemetry(encryption_attempt)
|
||||
|
||||
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
|
||||
if interrupt.is_set():
|
||||
logger.debug("Received a stop signal, skipping leave readme")
|
||||
return
|
||||
|
||||
@interruptible_function(msg="Received a stop signal, skipping leave readme")
|
||||
def _leave_readme_in_target_directory(self, *, interrupt: threading.Event):
|
||||
try:
|
||||
self._leave_readme(README_SRC, self._readme_file_path)
|
||||
self._leave_readme(README_SRC, self._readme_file_path) # type: ignore
|
||||
except Exception as ex:
|
||||
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
from pathlib import Path
|
||||
from typing import Callable, Iterable, List, Set
|
||||
from typing import Callable, Iterable, Set
|
||||
|
||||
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
||||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
||||
|
||||
|
||||
def filter_files(
|
||||
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||
) -> List[Path]:
|
||||
) -> Iterable[Path]:
|
||||
filtered_files = files
|
||||
for file_filter in file_filters:
|
||||
filtered_files = [f for f in filtered_files if file_filter(f)]
|
||||
filtered_files = filter(file_filter, filtered_files)
|
||||
|
||||
return filtered_files
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
from itertools import count
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Iterable, Tuple
|
||||
from typing import Any, Callable, Iterable, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,3 +54,56 @@ def interruptible_iter(
|
|||
break
|
||||
|
||||
yield i
|
||||
|
||||
|
||||
def interruptible_function(*, msg: Optional[str] = None, default_return_value: Any = None):
|
||||
"""
|
||||
This decorator allows a function to be skipped if an interrupt (threading.Event) is set. This is
|
||||
useful for interrupting running code without introducing duplicate `if` checks at the beginning
|
||||
of each function.
|
||||
|
||||
Note: It is required that the decorated function accept a keyword-only argument named
|
||||
"interrupt".
|
||||
|
||||
Example:
|
||||
def run_algorithm(*inputs, interrupt: threading.Event):
|
||||
return_value = do_action_1(inputs[1], interrupt=interrupt)
|
||||
return_value = do_action_2(return_value + inputs[2], interrupt=interrupt)
|
||||
return_value = do_action_3(return_value + inputs[3], interrupt=interrupt)
|
||||
|
||||
return return_value
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 1", default_return_value=0)
|
||||
def do_action_1(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
|
||||
def do_action_2(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
|
||||
def do_action_2(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
:param str msg: A message to log at the debug level if an interrupt is detected. Defaults to
|
||||
None.
|
||||
:param Any default_return_value: A value to return if the wrapped function is not run. Defaults
|
||||
to None.
|
||||
"""
|
||||
|
||||
def _decorator(fn):
|
||||
@wraps(fn)
|
||||
def _wrapper(*args, interrupt: Event, **kwargs):
|
||||
if interrupt.is_set():
|
||||
if msg:
|
||||
logger.debug(msg)
|
||||
return default_return_value
|
||||
|
||||
return fn(*args, interrupt=interrupt, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
return _decorator
|
||||
|
|
|
@ -24,7 +24,7 @@ def file_selector():
|
|||
|
||||
|
||||
def test_select_targeted_files_only(ransomware_test_data, file_selector):
|
||||
selected_files = file_selector(ransomware_test_data)
|
||||
selected_files = list(file_selector(ransomware_test_data))
|
||||
|
||||
assert len(selected_files) == 2
|
||||
assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files
|
||||
|
|
|
@ -58,10 +58,12 @@ def mock_file_encryptor():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_file_selector(ransomware_test_data):
|
||||
selected_files = [
|
||||
ransomware_test_data / ALL_ZEROS_PDF,
|
||||
ransomware_test_data / TEST_KEYBOARD_TXT,
|
||||
]
|
||||
selected_files = iter(
|
||||
[
|
||||
ransomware_test_data / ALL_ZEROS_PDF,
|
||||
ransomware_test_data / TEST_KEYBOARD_TXT,
|
||||
]
|
||||
)
|
||||
return MagicMock(return_value=selected_files)
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
|||
add_subdirs_to_dir(tmp_path)
|
||||
|
||||
expected_return_value = []
|
||||
assert get_all_regular_files_in_directory(tmp_path) == expected_return_value
|
||||
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
||||
|
@ -63,7 +63,7 @@ def test_filter_files__no_results(tmp_path):
|
|||
add_files_to_dir(tmp_path)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [lambda _: False])
|
||||
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
|
||||
|
||||
assert len(filtered_files) == 0
|
||||
|
||||
|
@ -109,8 +109,8 @@ def test_is_not_symlink_filter(tmp_path):
|
|||
link_path = tmp_path / "symlink.test"
|
||||
link_path.symlink_to(files[0], target_is_directory=False)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [is_not_symlink_filter])
|
||||
files_in_dir = list(get_all_regular_files_in_directory(tmp_path))
|
||||
filtered_files = list(filter_files(files_in_dir, [is_not_symlink_filter]))
|
||||
|
||||
assert link_path in files_in_dir
|
||||
assert len(filtered_files) == len(FILES)
|
||||
|
@ -121,7 +121,7 @@ def test_is_not_shortcut_filter(tmp_path):
|
|||
add_files_to_dir(tmp_path)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [is_not_shortcut_filter])
|
||||
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
|
||||
|
||||
assert len(filtered_files) == len(FILES) - 1
|
||||
assert SHORTCUT not in [f.name for f in filtered_files]
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from threading import Event, current_thread
|
||||
from typing import Any
|
||||
|
||||
from infection_monkey.utils.threading import (
|
||||
create_daemon_thread,
|
||||
interruptible_function,
|
||||
interruptible_iter,
|
||||
run_worker_threads,
|
||||
)
|
||||
|
@ -73,3 +75,55 @@ def test_worker_thread_names():
|
|||
assert "B-01" in thread_names
|
||||
assert "B-02" in thread_names
|
||||
assert len(thread_names) == 6
|
||||
|
||||
|
||||
class MockFunction:
|
||||
def __init__(self):
|
||||
self._call_count = 0
|
||||
|
||||
@property
|
||||
def call_count(self):
|
||||
return self._call_count
|
||||
|
||||
@property
|
||||
def return_value(self):
|
||||
return 42
|
||||
|
||||
def __call__(self, *_, interrupt: Event) -> Any:
|
||||
self._call_count += 1
|
||||
|
||||
return self.return_value
|
||||
|
||||
|
||||
def test_interruptible_decorator_calls_decorated_function():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function()(fn)
|
||||
|
||||
return_value = int_fn(interrupt=Event())
|
||||
|
||||
assert return_value == fn.return_value
|
||||
assert fn.call_count == 1
|
||||
|
||||
|
||||
def test_interruptible_decorator_skips_decorated_function():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function()(fn)
|
||||
interrupt = Event()
|
||||
interrupt.set()
|
||||
|
||||
return_value = int_fn(interrupt=interrupt)
|
||||
|
||||
assert return_value is None
|
||||
assert fn.call_count == 0
|
||||
|
||||
|
||||
def test_interruptible_decorator_returns_default_value_on_interrupt():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function(default_return_value=777)(fn)
|
||||
interrupt = Event()
|
||||
interrupt.set()
|
||||
|
||||
return_value = int_fn(interrupt=interrupt)
|
||||
|
||||
assert return_value == 777
|
||||
assert fn.call_count == 0
|
||||
|
|
Loading…
Reference in New Issue