Agent: Use iterators instead of lists for ransomware file filtering

This commit is contained in:
Mike Salvatore 2022-03-25 08:27:45 -04:00
parent 4316329384
commit 7c6ba2e276
6 changed files with 30 additions and 26 deletions

View File

@ -1,6 +1,6 @@
import filecmp import filecmp
from pathlib import Path from pathlib import Path
from typing import List, Set from typing import Iterable, Set
from infection_monkey.utils.dir_utils import ( from infection_monkey.utils.dir_utils import (
file_extension_filter, file_extension_filter,
@ -17,7 +17,7 @@ class ProductionSafeTargetFileSelector:
def __init__(self, targeted_file_extensions: Set[str]): def __init__(self, targeted_file_extensions: Set[str]):
self._targeted_file_extensions = targeted_file_extensions self._targeted_file_extensions = targeted_file_extensions
def __call__(self, target_dir: Path) -> List[Path]: def __call__(self, target_dir: Path) -> Iterable[Path]:
file_filters = [ file_filters = [
file_extension_filter(self._targeted_file_extensions), file_extension_filter(self._targeted_file_extensions),
is_not_shortcut_filter, is_not_shortcut_filter,

View File

@ -1,7 +1,7 @@
import logging import logging
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Callable, List from typing import Callable, Iterable
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
@ -18,7 +18,7 @@ class Ransomware:
self, self,
config: RansomwareOptions, config: RansomwareOptions,
encrypt_file: Callable[[Path], None], encrypt_file: Callable[[Path], None],
select_files: Callable[[Path], List[Path]], select_files: Callable[[Path], Iterable[Path]],
leave_readme: Callable[[Path, Path], None], leave_readme: Callable[[Path, Path], None],
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
): ):
@ -31,7 +31,9 @@ class Ransomware:
self._target_directory = self._config.target_directory self._target_directory = self._config.target_directory
self._readme_file_path = ( self._readme_file_path = (
self._target_directory / README_FILE_NAME if self._target_directory else None self._target_directory / README_FILE_NAME # type: ignore
if self._target_directory
else None
) )
def run(self, interrupt: threading.Event): def run(self, interrupt: threading.Event):
@ -41,23 +43,23 @@ class Ransomware:
logger.info("Running ransomware payload") logger.info("Running ransomware payload")
if self._config.encryption_enabled: if self._config.encryption_enabled:
file_list = self._find_files() files_to_encrypt = self._find_files()
self._encrypt_files(file_list, interrupt) self._encrypt_files(files_to_encrypt, interrupt)
if self._config.readme_enabled: if self._config.readme_enabled:
self._leave_readme_in_target_directory(interrupt) self._leave_readme_in_target_directory(interrupt)
def _find_files(self) -> List[Path]: def _find_files(self) -> Iterable[Path]:
logger.info(f"Collecting files in {self._target_directory}") logger.info(f"Collecting files in {self._target_directory}")
return sorted(self._select_files(self._target_directory)) return self._select_files(self._target_directory) # type: ignore
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event): def _encrypt_files(self, files_to_encrypt: Iterable[Path], interrupt: threading.Event):
logger.info(f"Encrypting files in {self._target_directory}") logger.info(f"Encrypting files in {self._target_directory}")
interrupted_message = ( interrupted_message = (
"Received a stop signal, skipping remaining files for encryption of ransomware payload" "Received a stop signal, skipping remaining files for encryption of ransomware payload"
) )
for filepath in interruptible_iter(file_list, interrupt, interrupted_message): for filepath in interruptible_iter(files_to_encrypt, interrupt, interrupted_message):
try: try:
logger.debug(f"Encrypting {filepath}") logger.debug(f"Encrypting {filepath}")
self._encrypt_file(filepath) self._encrypt_file(filepath)
@ -76,6 +78,6 @@ class Ransomware:
return return
try: try:
self._leave_readme(README_SRC, self._readme_file_path) self._leave_readme(README_SRC, self._readme_file_path) # type: ignore
except Exception as ex: except Exception as ex:
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}") logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")

View File

@ -1,17 +1,17 @@
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, List, Set from typing import Callable, Iterable, Set
def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]: def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()]) return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files( def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]] files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> List[Path]: ) -> Iterable[Path]:
filtered_files = files filtered_files = files
for file_filter in file_filters: for file_filter in file_filters:
filtered_files = [f for f in filtered_files if file_filter(f)] filtered_files = filter(file_filter, filtered_files)
return filtered_files return filtered_files

View File

@ -24,7 +24,7 @@ def file_selector():
def test_select_targeted_files_only(ransomware_test_data, file_selector): def test_select_targeted_files_only(ransomware_test_data, file_selector):
selected_files = file_selector(ransomware_test_data) selected_files = list(file_selector(ransomware_test_data))
assert len(selected_files) == 2 assert len(selected_files) == 2
assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files

View File

@ -58,10 +58,12 @@ def mock_file_encryptor():
@pytest.fixture @pytest.fixture
def mock_file_selector(ransomware_test_data): def mock_file_selector(ransomware_test_data):
selected_files = [ selected_files = iter(
[
ransomware_test_data / ALL_ZEROS_PDF, ransomware_test_data / ALL_ZEROS_PDF,
ransomware_test_data / TEST_KEYBOARD_TXT, ransomware_test_data / TEST_KEYBOARD_TXT,
] ]
)
return MagicMock(return_value=selected_files) return MagicMock(return_value=selected_files)

View File

@ -38,7 +38,7 @@ def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path) add_subdirs_to_dir(tmp_path)
expected_return_value = [] expected_return_value = []
assert get_all_regular_files_in_directory(tmp_path) == expected_return_value assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch): def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
@ -63,7 +63,7 @@ def test_filter_files__no_results(tmp_path):
add_files_to_dir(tmp_path) add_files_to_dir(tmp_path)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [lambda _: False]) filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
assert len(filtered_files) == 0 assert len(filtered_files) == 0
@ -109,8 +109,8 @@ def test_is_not_symlink_filter(tmp_path):
link_path = tmp_path / "symlink.test" link_path = tmp_path / "symlink.test"
link_path.symlink_to(files[0], target_is_directory=False) link_path.symlink_to(files[0], target_is_directory=False)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = list(get_all_regular_files_in_directory(tmp_path))
filtered_files = filter_files(files_in_dir, [is_not_symlink_filter]) filtered_files = list(filter_files(files_in_dir, [is_not_symlink_filter]))
assert link_path in files_in_dir assert link_path in files_in_dir
assert len(filtered_files) == len(FILES) assert len(filtered_files) == len(FILES)
@ -121,7 +121,7 @@ def test_is_not_shortcut_filter(tmp_path):
add_files_to_dir(tmp_path) add_files_to_dir(tmp_path)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [is_not_shortcut_filter]) filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
assert len(filtered_files) == len(FILES) - 1 assert len(filtered_files) == len(FILES) - 1
assert SHORTCUT not in [f.name for f in filtered_files] assert SHORTCUT not in [f.name for f in filtered_files]