diff --git a/monkey/infection_monkey/ransomware/file_selectors.py b/monkey/infection_monkey/ransomware/file_selectors.py new file mode 100644 index 000000000..f34bc9ca4 --- /dev/null +++ b/monkey/infection_monkey/ransomware/file_selectors.py @@ -0,0 +1,21 @@ +from pathlib import Path +from typing import List, Set + +from infection_monkey.utils.dir_utils import ( + file_extension_filter, + filter_files, + get_all_regular_files_in_directory, + is_not_shortcut_filter, + is_not_symlink_filter, +) + + +def select_production_safe_target_files(target_dir: Path, extensions: Set) -> List[Path]: + file_filters = [ + file_extension_filter(extensions), + is_not_shortcut_filter, + is_not_symlink_filter, + ] + + all_files = get_all_regular_files_in_directory(target_dir) + return filter_files(all_files, file_filters) diff --git a/monkey/infection_monkey/ransomware/ransomware_payload.py b/monkey/infection_monkey/ransomware/ransomware_payload.py index 5a4c6b412..edb2e76a4 100644 --- a/monkey/infection_monkey/ransomware/ransomware_payload.py +++ b/monkey/infection_monkey/ransomware/ransomware_payload.py @@ -2,15 +2,9 @@ import logging from pathlib import Path from typing import List, Optional, Tuple +from infection_monkey.ransomware.file_selectors import select_production_safe_target_files from infection_monkey.ransomware.ransomware_bitflip_encryptor import RansomwareBitflipEncryptor from infection_monkey.ransomware.valid_file_extensions import VALID_FILE_EXTENSIONS_FOR_ENCRYPTION -from infection_monkey.utils.dir_utils import ( - file_extension_filter, - filter_files, - get_all_regular_files_in_directory, - is_not_shortcut_filter, - is_not_symlink_filter, -) from infection_monkey.utils.environment import is_windows_os LOG = logging.getLogger(__name__) @@ -36,18 +30,13 @@ class RansomewarePayload: file_list = self._find_files() self._encrypt_files(file_list) - def _find_files(self): + def _find_files(self) -> List[Path]: if not self._target_dir: return [] - file_filters = [ - file_extension_filter(self._valid_file_extensions_for_encryption), - is_not_shortcut_filter, - is_not_symlink_filter, - ] - - all_files = get_all_regular_files_in_directory(Path(self._target_dir)) - return filter_files(all_files, file_filters) + return select_production_safe_target_files( + Path(self._target_dir), self._valid_file_extensions_for_encryption + ) def _encrypt_files(self, file_list: List[Path]) -> List[Tuple[Path, Optional[Exception]]]: results = []