From cf2cdc4ab8b6a468886f56bd24ff205bd362a680 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 22 Jun 2021 13:16:53 -0400 Subject: [PATCH] agent: Add filter_files() function to dir_utils --- monkey/infection_monkey/utils/dir_utils.py | 6 ++++- .../infection_monkey/utils/test_dir_utils.py | 24 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/monkey/infection_monkey/utils/dir_utils.py b/monkey/infection_monkey/utils/dir_utils.py index 4350a279d..b64e4e13f 100644 --- a/monkey/infection_monkey/utils/dir_utils.py +++ b/monkey/infection_monkey/utils/dir_utils.py @@ -1,6 +1,10 @@ from pathlib import Path -from typing import List +from typing import Callable, List def get_all_files_in_directory(dir_path: Path) -> List[Path]: return [f for f in dir_path.iterdir() if f.is_file()] + + +def filter_files(files: List[Path], file_filter: Callable[[Path], bool]): + return [f for f in files if file_filter(f)] diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py index 7369c5f41..604efc9cb 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_dir_utils.py @@ -1,4 +1,7 @@ -from infection_monkey.utils.dir_utils import get_all_files_in_directory +from infection_monkey.utils.dir_utils import ( + filter_files, + get_all_files_in_directory, +) FILE_1 = "file.jpg.zip" FILE_2 = "file.xyz" @@ -51,3 +54,22 @@ def test_get_all_files_in_directory__subdir_has_files(tmp_path, monkeypatch): expected_return_value = sorted(files) assert sorted(get_all_files_in_directory(tmp_path)) == expected_return_value + + +def test_filter_files__no_results(tmp_path): + add_files_to_dir(tmp_path) + + files_in_dir = get_all_files_in_directory(tmp_path) + filtered_files = filter_files(files_in_dir, lambda _: False) + + assert len(filtered_files) == 0 + + +def test_filter_files__all_true(tmp_path): + files = add_files_to_dir(tmp_path) + expected_return_value = sorted(files) + + files_in_dir = get_all_files_in_directory(tmp_path) + filtered_files = filter_files(files_in_dir, lambda _: True) + + assert sorted(filtered_files) == expected_return_value