From db8dfd9f179cd0fcdd7d4aedfab698e79a53b5eb Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 22 Jun 2021 13:43:27 -0400 Subject: [PATCH] agent: Refactor filter_files to accept a list of filters --- .../ransomware/ransomware_payload.py | 4 +++- monkey/infection_monkey/utils/dir_utils.py | 8 ++++++-- .../infection_monkey/utils/test_dir_utils.py | 17 +++++++---------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/monkey/infection_monkey/ransomware/ransomware_payload.py b/monkey/infection_monkey/ransomware/ransomware_payload.py index c09c35ab9..e89dc2625 100644 --- a/monkey/infection_monkey/ransomware/ransomware_payload.py +++ b/monkey/infection_monkey/ransomware/ransomware_payload.py @@ -24,8 +24,10 @@ class RansomewarePayload: self._encrypt_files(file_list) def _find_files(self): + file_filters = [file_extension_filter(VALID_FILE_EXTENSIONS_FOR_ENCRYPTION)] + all_files = get_all_files_in_directory(self.target_dir) - return filter_files(all_files, file_extension_filter(VALID_FILE_EXTENSIONS_FOR_ENCRYPTION)) + return filter_files(all_files, file_filters) def _encrypt_files(self, file_list): for file in file_list: diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index 23b4408a9..2a7797a7b 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -6,8 +6,12 @@ def get_all_files_in_directory(dir_path: Path) -> List[Path]: return [f for f in dir_path.iterdir() if f.is_file()] -def filter_files(files: List[Path], file_filter: Callable[[Path], bool]): - return [f for f in files if file_filter(f)] +def filter_files(files: List[Path], file_filters: List[Callable[[Path], bool]]): + filtered_files = files + for file_filter in file_filters: + filtered_files = [f for f in filtered_files if file_filter(f)] + + return filtered_files def file_extension_filter(file_extensions: Set): diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py index ac8312bd8..4657155c0 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py @@ -4,8 +4,7 @@ from infection_monkey.utils.dir_utils import ( get_all_files_in_directory, ) -FILE_1 = "file.jpg.zip" -FILE_2 = "file.xyz" +FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz"] SUBDIR_1 = "subdir1" SUBDIR_2 = "subdir2" @@ -22,9 +21,7 @@ def add_subdirs_to_dir(parent_dir): def add_files_to_dir(parent_dir): - file1 = parent_dir / FILE_1 - file2 = parent_dir / FILE_2 - files = [file1, file2] + files = [parent_dir / f for f in FILES] for f in files: f.touch() @@ -61,7 +58,7 @@ def test_filter_files__no_results(tmp_path): add_files_to_dir(tmp_path) files_in_dir = get_all_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, lambda _: False) + filtered_files = filter_files(files_in_dir, [lambda _: False]) assert len(filtered_files) == 0 @@ -71,17 +68,17 @@ def test_filter_files__all_true(tmp_path): expected_return_value = sorted(files) files_in_dir = get_all_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, lambda _: True) + filtered_files = filter_files(files_in_dir, [lambda _: True]) assert sorted(filtered_files) == expected_return_value def test_file_extension_filter(tmp_path): - valid_extensions = {".zip", ".tar"} + valid_extensions = {".zip", ".xyz"} files = add_files_to_dir(tmp_path) files_in_dir = get_all_files_in_directory(tmp_path) - filtered_files = filter_files(files_in_dir, file_extension_filter(valid_extensions)) + filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)]) - assert files[0:1] == filtered_files + assert sorted(files[0:2]) == sorted(filtered_files)