From 7c6ba2e276e1141c794789afbf94894f077ecefc Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Fri, 25 Mar 2022 08:27:45 -0400 Subject: [PATCH] Agent: Use iterators instead of lists for ransomware file filtering --- .../payload/ransomware/file_selectors.py | 4 ++-- .../payload/ransomware/ransomware.py | 22 ++++++++++--------- monkey/infection_monkey/utils/dir_utils.py | 8 +++---- .../payload/ransomware/test_file_selectors.py | 2 +- .../payload/ransomware/test_ransomware.py | 10 +++++---- .../infection_monkey/utils/test_dir_utils.py | 10 ++++----- 6 files changed, 30 insertions(+), 26 deletions(-) diff --git a/monkey/infection_monkey/payload/ransomware/file_selectors.py b/monkey/infection_monkey/payload/ransomware/file_selectors.py index bcdd87b46..1303970e7 100644 --- a/monkey/infection_monkey/payload/ransomware/file_selectors.py +++ b/monkey/infection_monkey/payload/ransomware/file_selectors.py @@ -1,6 +1,6 @@ import filecmp from pathlib import Path -from typing import List, Set +from typing import Iterable, Set from infection_monkey.utils.dir_utils import ( file_extension_filter, @@ -17,7 +17,7 @@ class ProductionSafeTargetFileSelector: def __init__(self, targeted_file_extensions: Set[str]): self._targeted_file_extensions = targeted_file_extensions - def __call__(self, target_dir: Path) -> List[Path]: + def __call__(self, target_dir: Path) -> Iterable[Path]: file_filters = [ file_extension_filter(self._targeted_file_extensions), is_not_shortcut_filter, diff --git a/monkey/infection_monkey/payload/ransomware/ransomware.py b/monkey/infection_monkey/payload/ransomware/ransomware.py index 9cf488c32..966476be2 100644 --- a/monkey/infection_monkey/payload/ransomware/ransomware.py +++ b/monkey/infection_monkey/payload/ransomware/ransomware.py @@ -1,7 +1,7 @@ import logging import threading from pathlib import Path -from typing import Callable, List +from typing import Callable, Iterable from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger @@ -18,7 +18,7 @@ class Ransomware: self, config: RansomwareOptions, encrypt_file: Callable[[Path], None], - select_files: Callable[[Path], List[Path]], + select_files: Callable[[Path], Iterable[Path]], leave_readme: Callable[[Path, Path], None], telemetry_messenger: ITelemetryMessenger, ): @@ -31,7 +31,9 @@ class Ransomware: self._target_directory = self._config.target_directory self._readme_file_path = ( - self._target_directory / README_FILE_NAME if self._target_directory else None + self._target_directory / README_FILE_NAME # type: ignore + if self._target_directory + else None ) def run(self, interrupt: threading.Event): @@ -41,23 +43,23 @@ class Ransomware: logger.info("Running ransomware payload") if self._config.encryption_enabled: - file_list = self._find_files() - self._encrypt_files(file_list, interrupt) + files_to_encrypt = self._find_files() + self._encrypt_files(files_to_encrypt, interrupt) if self._config.readme_enabled: self._leave_readme_in_target_directory(interrupt) - def _find_files(self) -> List[Path]: + def _find_files(self) -> Iterable[Path]: logger.info(f"Collecting files in {self._target_directory}") - return sorted(self._select_files(self._target_directory)) + return self._select_files(self._target_directory) # type: ignore - def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event): + def _encrypt_files(self, files_to_encrypt: Iterable[Path], interrupt: threading.Event): logger.info(f"Encrypting files in {self._target_directory}") interrupted_message = ( "Received a stop signal, skipping remaining files for encryption of ransomware payload" ) - for filepath in interruptible_iter(file_list, interrupt, interrupted_message): + for filepath in interruptible_iter(files_to_encrypt, interrupt, interrupted_message): try: logger.debug(f"Encrypting {filepath}") self._encrypt_file(filepath) @@ -76,6 +78,6 @@ class Ransomware: return try: - self._leave_readme(README_SRC, self._readme_file_path) + self._leave_readme(README_SRC, self._readme_file_path) # type: ignore except Exception as ex: logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}") diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index 2fd29af9e..da0a5e2e4 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -1,17 +1,17 @@ from pathlib import Path -from typing import Callable, Iterable, List, Set +from typing import Callable, Iterable, Set -def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]: +def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]: return filter_files(dir_path.iterdir(), [lambda f: f.is_file()]) def filter_files( files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]] -) -> List[Path]: +) -> Iterable[Path]: filtered_files = files for file_filter in file_filters: - filtered_files = [f for f in filtered_files if file_filter(f)] + filtered_files = filter(file_filter, filtered_files) return filtered_files diff --git a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_file_selectors.py b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_file_selectors.py index f779b733e..8b1309c07 100644 --- a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_file_selectors.py +++ b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_file_selectors.py @@ -24,7 +24,7 @@ def file_selector(): def test_select_targeted_files_only(ransomware_test_data, file_selector): - selected_files = file_selector(ransomware_test_data) + selected_files = list(file_selector(ransomware_test_data)) assert len(selected_files) == 2 assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files diff --git a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py index 365f9fecd..88f37037c 100644 --- a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py +++ b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py @@ -58,10 +58,12 @@ def mock_file_encryptor(): @pytest.fixture def mock_file_selector(ransomware_test_data): - selected_files = [ - ransomware_test_data / ALL_ZEROS_PDF, - ransomware_test_data / TEST_KEYBOARD_TXT, - ] + selected_files = iter( + [ + ransomware_test_data / ALL_ZEROS_PDF, + ransomware_test_data / TEST_KEYBOARD_TXT, + ] + ) return MagicMock(return_value=selected_files) diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py index 8ebddf280..adf18bf5a 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py @@ -38,7 +38,7 @@ def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch): add_subdirs_to_dir(tmp_path) expected_return_value = [] - assert get_all_regular_files_in_directory(tmp_path) == expected_return_value + assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch): @@ -63,7 +63,7 @@ def test_filter_files__no_results(tmp_path): add_files_to_dir(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, [lambda _: False]) + filtered_files = list(filter_files(files_in_dir, [lambda _: False])) assert len(filtered_files) == 0 @@ -109,8 +109,8 @@ def test_is_not_symlink_filter(tmp_path): link_path = tmp_path / "symlink.test" link_path.symlink_to(files[0], target_is_directory=False) - files_in_dir = get_all_regular_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, [is_not_symlink_filter]) + files_in_dir = list(get_all_regular_files_in_directory(tmp_path)) + filtered_files = list(filter_files(files_in_dir, [is_not_symlink_filter])) assert link_path in files_in_dir assert len(filtered_files) == len(FILES) @@ -121,7 +121,7 @@ def test_is_not_shortcut_filter(tmp_path): add_files_to_dir(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, [is_not_shortcut_filter]) + filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter])) assert len(filtered_files) == len(FILES) - 1 assert SHORTCUT not in [f.name for f in filtered_files]