Common: Move get_all_regular_files_in_directory() to utils.file_utils

This commit is contained in:
Mike Salvatore 2022-04-25 12:32:59 -04:00
parent 2f4ffad3f6
commit cd8fa699b0
7 changed files with 72 additions and 59 deletions

View File

@ -1,6 +1,7 @@
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Iterable
class InvalidPath(Exception): class InvalidPath(Exception):
@ -21,3 +22,7 @@ def get_file_sha256_hash(filepath: Path):
sha256.update(block) sha256.update(block)
return sha256.hexdigest() return sha256.hexdigest()
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter(lambda f: f.is_file(), dir_path.iterdir())

View File

@ -2,10 +2,10 @@ import filecmp
from pathlib import Path from pathlib import Path
from typing import Iterable, Set from typing import Iterable, Set
from common.utils.file_utils import get_all_regular_files_in_directory
from infection_monkey.utils.dir_utils import ( from infection_monkey.utils.dir_utils import (
file_extension_filter, file_extension_filter,
filter_files, filter_files,
get_all_regular_files_in_directory,
is_not_shortcut_filter, is_not_shortcut_filter,
is_not_symlink_filter, is_not_symlink_filter,
) )

View File

@ -2,10 +2,6 @@ from pathlib import Path
from typing import Callable, Iterable, Set from typing import Callable, Iterable, Set
def get_all_regular_files_in_directory(dir_path: Path) -> Iterable[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files( def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]] files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> Iterable[Path]: ) -> Iterable[Path]:

View File

@ -2,6 +2,7 @@ import shutil
from pathlib import Path from pathlib import Path
from typing import BinaryIO from typing import BinaryIO
from common.utils.file_utils import get_all_regular_files_in_directory
from monkey_island.cc.server_utils.file_utils import create_secure_directory from monkey_island.cc.server_utils.file_utils import create_secure_directory
from . import IFileStorageService from . import IFileStorageService
@ -50,5 +51,5 @@ class DirectoryFileStorageService(IFileStorageService):
return self._storage_directory / safe_file_name return self._storage_directory / safe_file_name
def delete_all_files(self): def delete_all_files(self):
for file in filter(lambda f: f.is_file(), self._storage_directory.iterdir()): for file in get_all_regular_files_in_directory(self._storage_directory):
file.unlink() file.unlink()

View File

@ -1,8 +1,14 @@
import os import os
import pytest import pytest
from tests.utils import add_files_to_dir, add_subdirs_to_dir
from common.utils.file_utils import InvalidPath, expand_path, get_file_sha256_hash from common.utils.file_utils import (
InvalidPath,
expand_path,
get_all_regular_files_in_directory,
get_file_sha256_hash,
)
def test_expand_user(patched_home_env): def test_expand_user(patched_home_env):
@ -26,3 +32,32 @@ def test_expand_path__empty_path_provided():
def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash): def test_get_file_sha256_hash(stable_file, stable_file_sha256_hash):
assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash assert get_file_sha256_hash(stable_file) == stable_file_sha256_hash
SUBDIRS = ["subdir1", "subdir2"]
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg"]
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path, SUBDIRS)
expected_return_value = []
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path, SUBDIRS)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
subdirs = add_subdirs_to_dir(tmp_path, SUBDIRS)
add_files_to_dir(subdirs[0], FILES)
files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value

View File

@ -1,66 +1,22 @@
import os import os
import pytest import pytest
from tests.utils import is_user_admin from tests.utils import add_files_to_dir, is_user_admin
from common.utils.file_utils import get_all_regular_files_in_directory
from infection_monkey.utils.dir_utils import ( from infection_monkey.utils.dir_utils import (
file_extension_filter, file_extension_filter,
filter_files, filter_files,
get_all_regular_files_in_directory,
is_not_shortcut_filter, is_not_shortcut_filter,
is_not_symlink_filter, is_not_symlink_filter,
) )
SHORTCUT = "shortcut.lnk" SHORTCUT = "shortcut.lnk"
FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT] FILES = ["file.jpg.zip", "file.xyz", "1.tar", "2.tgz", "2.png", "2.mpg", SHORTCUT]
SUBDIRS = ["subdir1", "subdir2"]
def add_subdirs_to_dir(parent_dir):
subdirs = [parent_dir / s for s in SUBDIRS]
for subdir in subdirs:
subdir.mkdir()
return subdirs
def add_files_to_dir(parent_dir):
files = [parent_dir / f for f in FILES]
for f in files:
f.touch()
return files
def test_get_all_regular_files_in_directory__no_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path)
expected_return_value = []
assert list(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__has_files(tmp_path, monkeypatch):
add_subdirs_to_dir(tmp_path)
files = add_files_to_dir(tmp_path)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_get_all_regular_files_in_directory__subdir_has_files(tmp_path, monkeypatch):
subdirs = add_subdirs_to_dir(tmp_path)
add_files_to_dir(subdirs[0])
files = add_files_to_dir(tmp_path)
expected_return_value = sorted(files)
assert sorted(get_all_regular_files_in_directory(tmp_path)) == expected_return_value
def test_filter_files__no_results(tmp_path): def test_filter_files__no_results(tmp_path):
add_files_to_dir(tmp_path) add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = list(filter_files(files_in_dir, [lambda _: False])) filtered_files = list(filter_files(files_in_dir, [lambda _: False]))
@ -69,7 +25,7 @@ def test_filter_files__no_results(tmp_path):
def test_filter_files__all_true(tmp_path): def test_filter_files__all_true(tmp_path):
files = add_files_to_dir(tmp_path) files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files) expected_return_value = sorted(files)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
@ -79,7 +35,7 @@ def test_filter_files__all_true(tmp_path):
def test_filter_files__multiple_filters(tmp_path): def test_filter_files__multiple_filters(tmp_path):
files = add_files_to_dir(tmp_path) files = add_files_to_dir(tmp_path, FILES)
expected_return_value = sorted(files[4:6]) expected_return_value = sorted(files[4:6])
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
@ -93,7 +49,7 @@ def test_filter_files__multiple_filters(tmp_path):
def test_file_extension_filter(tmp_path): def test_file_extension_filter(tmp_path):
valid_extensions = {".zip", ".xyz"} valid_extensions = {".zip", ".xyz"}
files = add_files_to_dir(tmp_path) files = add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)]) filtered_files = filter_files(files_in_dir, [file_extension_filter(valid_extensions)])
@ -105,7 +61,7 @@ def test_file_extension_filter(tmp_path):
os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows" os.name == "nt" and not is_user_admin(), reason="Test requires admin rights on Windows"
) )
def test_is_not_symlink_filter(tmp_path): def test_is_not_symlink_filter(tmp_path):
files = add_files_to_dir(tmp_path) files = add_files_to_dir(tmp_path, FILES)
link_path = tmp_path / "symlink.test" link_path = tmp_path / "symlink.test"
link_path.symlink_to(files[0], target_is_directory=False) link_path.symlink_to(files[0], target_is_directory=False)
@ -118,7 +74,7 @@ def test_is_not_symlink_filter(tmp_path):
def test_is_not_shortcut_filter(tmp_path): def test_is_not_shortcut_filter(tmp_path):
add_files_to_dir(tmp_path) add_files_to_dir(tmp_path, FILES)
files_in_dir = get_all_regular_files_in_directory(tmp_path) files_in_dir = get_all_regular_files_in_directory(tmp_path)
filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter])) filtered_files = list(filter_files(files_in_dir, [is_not_shortcut_filter]))

View File

@ -1,5 +1,7 @@
import ctypes import ctypes
import os import os
from pathlib import Path
from typing import Iterable
def is_user_admin(): def is_user_admin():
@ -11,3 +13,21 @@ def is_user_admin():
def raise_(ex): def raise_(ex):
raise ex raise ex
def add_subdirs_to_dir(parent_dir: Path, subdirs: Iterable[str]) -> Iterable[Path]:
subdir_paths = [parent_dir / s for s in subdirs]
for subdir in subdir_paths:
subdir.mkdir()
return subdir_paths
def add_files_to_dir(parent_dir: Path, file_names: Iterable[str]) -> Iterable[Path]:
files = [parent_dir / f for f in file_names]
for f in files:
f.touch()
return files