agent: Simplify get_all_files_in_directory() with list comprehension

This commit is contained in:
Mike Salvatore 2021-06-22 12:15:03 -04:00
parent efef40edf9
commit e2dfd6a5e3
2 changed files with 34 additions and 24 deletions

View File

@ -1,8 +1,8 @@
import os from pathlib import Path
from typing import List from typing import List
def get_all_files_in_directory(dir_path: str) -> List: def get_all_files_in_directory(dir_path: str) -> List:
return list( path = Path(dir_path)
filter(os.path.isfile, [os.path.join(dir_path, item) for item in os.listdir(dir_path)])
) return [str(f) for f in path.iterdir() if f.is_file()]

View File

@ -1,4 +1,5 @@
import os import os
from pathlib import Path
from infection_monkey.utils.dir_utils import get_all_files_in_directory from infection_monkey.utils.dir_utils import get_all_files_in_directory
@ -8,39 +9,48 @@ SUBDIR_1 = "subdir1"
SUBDIR_2 = "subdir2" SUBDIR_2 = "subdir2"
def test_get_all_files_in_directory__no_files(tmpdir, monkeypatch): def add_subdirs_to_dir(parent_dir):
subdir1 = os.path.join(tmpdir, SUBDIR_1) subdir1 = os.path.join(parent_dir, SUBDIR_1)
subdir2 = os.path.join(tmpdir, SUBDIR_2) subdir2 = os.path.join(parent_dir, SUBDIR_2)
subdirs = [subdir1, subdir2] subdirs = [subdir1, subdir2]
for subdir in subdirs: for subdir in subdirs:
os.mkdir(subdir) os.mkdir(subdir)
all_items_in_dir = subdirs return subdirs
monkeypatch.setattr("os.listdir", lambda _: all_items_in_dir)
def add_files_to_dir(parent_dir):
file1 = os.path.join(parent_dir, FILE_1)
file2 = os.path.join(parent_dir, FILE_2)
files = [file1, file2]
for f in files:
Path(f).touch()
return files
def test_get_all_files_in_directory__no_files(tmpdir, monkeypatch):
add_subdirs_to_dir(tmpdir)
expected_return_value = [] expected_return_value = []
assert get_all_files_in_directory(tmpdir) == expected_return_value assert get_all_files_in_directory(tmpdir) == expected_return_value
def test_get_all_files_in_directory__has_files(tmpdir, monkeypatch): def test_get_all_files_in_directory__has_files(tmpdir, monkeypatch):
subdir1 = os.path.join(tmpdir, SUBDIR_1) add_subdirs_to_dir(tmpdir)
subdir2 = os.path.join(tmpdir, SUBDIR_2) files = add_files_to_dir(tmpdir)
subdirs = [subdir1, subdir2]
file1 = os.path.join(tmpdir, FILE_1) expected_return_value = sorted(files)
file2 = os.path.join(tmpdir, FILE_2) assert sorted(get_all_files_in_directory(tmpdir)) == expected_return_value
files = [file1, file2]
for subdir in subdirs:
os.mkdir(subdir)
for file in files: def test_get_all_files_in_directory__subdir_has_files(tmpdir, monkeypatch):
with open(file, "w") as _: subdirs = add_subdirs_to_dir(tmpdir)
pass add_files_to_dir(subdirs[0])
all_items_in_dir = subdirs + files files = add_files_to_dir(tmpdir)
monkeypatch.setattr("os.listdir", lambda _: all_items_in_dir)
expected_return_value = files expected_return_value = sorted(files)
assert get_all_files_in_directory(tmpdir) == expected_return_value assert sorted(get_all_files_in_directory(tmpdir)) == expected_return_value