diff --git a/monkey/infection_monkey/ransomware/utils.py b/monkey/infection_monkey/ransomware/utils.py index 8ad1279c7..dde0ff973 100644 --- a/monkey/infection_monkey/ransomware/utils.py +++ b/monkey/infection_monkey/ransomware/utils.py @@ -1,16 +1,14 @@ -import os +from pathlib import Path from typing import List from infection_monkey.ransomware.valid_file_extensions import VALID_FILE_EXTENSIONS_FOR_ENCRYPTION -from infection_monkey.utils.dir_utils import get_all_files_in_directory +from infection_monkey.utils.dir_utils import ( + file_extension_filter, + filter_files, + get_all_files_in_directory, +) -def get_files_to_encrypt(dir_path: str) -> List[str]: - all_files = get_all_files_in_directory(dir_path) - - files_to_encrypt = [] - for file in all_files: - if os.path.splitext(file)[1] in VALID_FILE_EXTENSIONS_FOR_ENCRYPTION: - files_to_encrypt.append(file) - - return files_to_encrypt +def get_files_to_encrypt(dir_path: str) -> List[Path]: + all_files = get_all_files_in_directory(Path(dir_path)) + return filter_files(all_files, file_extension_filter(VALID_FILE_EXTENSIONS_FOR_ENCRYPTION)) diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index b64e4e13f..23b4408a9 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Callable, List +from typing import Callable, List, Set def get_all_files_in_directory(dir_path: Path) -> List[Path]: @@ -8,3 +8,10 @@ def get_all_files_in_directory(dir_path: Path) -> List[Path]: def filter_files(files: List[Path], file_filter: Callable[[Path], bool]): return [f for f in files if file_filter(f)] + + +def file_extension_filter(file_extensions: Set): + def inner_filter(f: Path): + return f.suffix in file_extensions + + return inner_filter diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py index 604efc9cb..ac8312bd8 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py @@ -1,4 +1,5 @@ from infection_monkey.utils.dir_utils import ( + file_extension_filter, filter_files, get_all_files_in_directory, ) @@ -73,3 +74,14 @@ def test_filter_files__all_true(tmp_path): filtered_files = filter_files(files_in_dir, lambda _: True) assert sorted(filtered_files) == expected_return_value + + +def test_file_extension_filter(tmp_path): + valid_extensions = {".zip", ".tar"} + + files = add_files_to_dir(tmp_path) + + files_in_dir = get_all_files_in_directory(tmp_path) + filtered_files = filter_files(files_in_dir, file_extension_filter(valid_extensions)) + + assert files[0:1] == filtered_files