diff --git a/monkey/common/utils/file_utils.py b/monkey/common/utils/file_utils.py index fd2c85ec1..c9cb78139 100644 --- a/monkey/common/utils/file_utils.py +++ b/monkey/common/utils/file_utils.py @@ -1,6 +1,7 @@ import hashlib import os from pathlib import Path +from typing import Iterable class InvalidPath(Exception): @@ -21,3 +22,7 @@ def get_file_sha256_hash(filepath: Path): sha256.update(block) return sha256.hexdigest() + + +def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]: + return filter(lambda f: f.is_file(), dir_path.iterdir()) diff --git a/monkey/infection_monkey/payload/ransomware/file_selectors.py b/monkey/infection_monkey/payload/ransomware/file_selectors.py index 1303970e7..1857d7d63 100644 --- a/monkey/infection_monkey/payload/ransomware/file_selectors.py +++ b/monkey/infection_monkey/payload/ransomware/file_selectors.py @@ -2,10 +2,10 @@ import filecmp from pathlib import Path from typing import Iterable, Set +from common.utils.file_utils import get_all_regular_files_in_directory from infection_monkey.utils.dir_utils import ( file_extension_filter, filter_files, - get_all_regular_files_in_directory, is_not_shortcut_filter, is_not_symlink_filter, ) diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index da0a5e2e4..9cd20a316 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -2,10 +2,6 @@ from pathlib import Path from typing import Callable, Iterable, Set -def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]: - return filter_files(dir_path.iterdir(), [lambda f: f.is_file()]) - - def filter_files( files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]] ) -> Iterable[Path]: diff --git a/monkey/monkey_island/cc/services/directory_file_storage_service.py b/monkey/monkey_island/cc/services/directory_file_storage_service.py index 48fb95bfc..60c3c1370 100644 --- a/monkey/monkey_island/cc/services/directory_file_storage_service.py +++ b/monkey/monkey_island/cc/services/directory_file_storage_service.py @@ -2,6 +2,7 @@ import shutil from pathlib import Path from typing import BinaryIO +from common.utils.file_utils import get_all_regular_files_in_directory from monkey_island.cc.server_utils.file_utils import create_secure_directory from . import IFileStorageService @@ -50,5 +51,5 @@ class DirectoryFileStorageService(IFileStorageService): return self._storage_directory / safe_file_name def delete_all_files(self): - for file in filter(lambda f: f.is_file(), self._storage_directory.iterdir()): + for file in get_all_regular_files_in_directory(self._storage_directory): file.unlink() diff --git a/monkey/tests/unit_tests/common/utils/test_common_file_utils.py b/monkey/tests/unit_tests/common/utils/test_common_file_utils.py index 79d00d027..aac13839e 100644 --- a/monkey/tests/unit_tests/common/utils/test_common_file_utils.py +++ b/monkey/tests/unit_tests/common/utils/test_common_file_utils.py @@ -1,8 +1,14 @@ import os import pytest +from tests.utils import add_files_to_dir, add_subdirs_to_dir -from common.utils.file_utils import InvalidPath, expand_path, get_file_sha256_hash +from common.utils.file_utils import ( + InvalidPath, + expand_path, + get_all_regular_files_in_directory, + get_file_sha256_hash, +) def test_expand_user(patched_home_env): @@ -26,3 +32,32 @@ def test_expand_path__empty_path_provided(): def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash): assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash + + +SUBDIRS = ["subdir1", "subdir2"] +FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg"] + + +def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch): + add_subdirs_to_dir(tmp_path, SUBDIRS) + + expected_return_value = [] + assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value + + +def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch): + add_subdirs_to_dir(tmp_path, SUBDIRS) + files = add_files_to_dir(tmp_path, FILES) + + expected_return_value = sorted(files) + assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value + + +def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch): + subdirs = add_subdirs_to_dir(tmp_path, SUBDIRS) + add_files_to_dir(subdirs[0], FILES) + + files = add_files_to_dir(tmp_path, FILES) + + expected_return_value = sorted(files) + assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py index adf18bf5a..fe7b499bf 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py @@ -1,66 +1,22 @@ import os import pytest -from tests.utils import is_user_admin +from tests.utils import add_files_to_dir, is_user_admin +from common.utils.file_utils import get_all_regular_files_in_directory from infection_monkey.utils.dir_utils import ( file_extension_filter, filter_files, - get_all_regular_files_in_directory, is_not_shortcut_filter, is_not_symlink_filter, ) SHORTCUT = "shortcut.lnk" FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT] -SUBDIRS = ["subdir1", "subdir2"] - - -def add_subdirs_to_dir(parent_dir): - subdirs = [parent_dir / s for s in SUBDIRS] - - for subdir in subdirs: - subdir.mkdir() - - return subdirs - - -def add_files_to_dir(parent_dir): - files = [parent_dir / f for f in FILES] - - for f in files: - f.touch() - - return files - - -def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch): - add_subdirs_to_dir(tmp_path) - - expected_return_value = [] - assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value - - -def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch): - add_subdirs_to_dir(tmp_path) - files = add_files_to_dir(tmp_path) - - expected_return_value = sorted(files) - assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value - - -def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch): - subdirs = add_subdirs_to_dir(tmp_path) - add_files_to_dir(subdirs[0]) - - files = add_files_to_dir(tmp_path) - - expected_return_value = sorted(files) - assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value def test_filter_files__no_results(tmp_path): - add_files_to_dir(tmp_path) + add_files_to_dir(tmp_path, FILES) files_in_dir = get_all_regular_files_in_directory(tmp_path) filtered_files = list(filter_files(files_in_dir, [lambda _: False])) @@ -69,7 +25,7 @@ def test_filter_files__no_results(tmp_path): def test_filter_files__all_true(tmp_path): - files = add_files_to_dir(tmp_path) + files = add_files_to_dir(tmp_path, FILES) expected_return_value = sorted(files) files_in_dir = get_all_regular_files_in_directory(tmp_path) @@ -79,7 +35,7 @@ def test_filter_files__all_true(tmp_path): def test_filter_files__multiple_filters(tmp_path): - files = add_files_to_dir(tmp_path) + files = add_files_to_dir(tmp_path, FILES) expected_return_value = sorted(files[4:6]) files_in_dir = get_all_regular_files_in_directory(tmp_path) @@ -93,7 +49,7 @@ def test_filter_files__multiple_filters(tmp_path): def test_file_extension_filter(tmp_path): valid_extensions = {".zip", ".xyz"} - files = add_files_to_dir(tmp_path) + files = add_files_to_dir(tmp_path, FILES) files_in_dir = get_all_regular_files_in_directory(tmp_path) filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)]) @@ -105,7 +61,7 @@ def test_file_extension_filter(tmp_path): os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows" ) def test_is_not_symlink_filter(tmp_path): - files = add_files_to_dir(tmp_path) + files = add_files_to_dir(tmp_path, FILES) link_path = tmp_path / "symlink.test" link_path.symlink_to(files[0], target_is_directory=False) @@ -118,7 +74,7 @@ def test_is_not_symlink_filter(tmp_path): def test_is_not_shortcut_filter(tmp_path): - add_files_to_dir(tmp_path) + add_files_to_dir(tmp_path, FILES) files_in_dir = get_all_regular_files_in_directory(tmp_path) filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter])) diff --git a/monkey/tests/utils.py b/monkey/tests/utils.py index 9b57a9cc7..b2268e572 100644 --- a/monkey/tests/utils.py +++ b/monkey/tests/utils.py @@ -1,5 +1,7 @@ import ctypes import os +from pathlib import Path +from typing import Iterable def is_user_admin(): @@ -11,3 +13,21 @@ def is_user_admin(): def raise_(ex): raise ex + + +def add_subdirs_to_dir(parent_dir: Path, subdirs: Iterable[str]) -> Iterable[Path]: + subdir_paths = [parent_dir / s for s in subdirs] + + for subdir in subdir_paths: + subdir.mkdir() + + return subdir_paths + + +def add_files_to_dir(parent_dir: Path, file_names: Iterable[str]) -> Iterable[Path]: + files = [parent_dir / f for f in file_names] + + for f in files: + f.touch() + + return files