forked from p15670423/monkey
agent: Add file_extension_filter to dir_utils
This commit is contained in:
parent
cf2cdc4ab8
commit
5c1902ca73
|
@ -1,16 +1,14 @@
|
||||||
import os
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from infection_monkey.ransomware.valid_file_extensions import VALID_FILE_EXTENSIONS_FOR_ENCRYPTION
|
from infection_monkey.ransomware.valid_file_extensions import VALID_FILE_EXTENSIONS_FOR_ENCRYPTION
|
||||||
from infection_monkey.utils.dir_utils import get_all_files_in_directory
|
from infection_monkey.utils.dir_utils import (
|
||||||
|
file_extension_filter,
|
||||||
|
filter_files,
|
||||||
|
get_all_files_in_directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_files_to_encrypt(dir_path: str) -> List[str]:
|
def get_files_to_encrypt(dir_path: str) -> List[Path]:
|
||||||
all_files = get_all_files_in_directory(dir_path)
|
all_files = get_all_files_in_directory(Path(dir_path))
|
||||||
|
return filter_files(all_files, file_extension_filter(VALID_FILE_EXTENSIONS_FOR_ENCRYPTION))
|
||||||
files_to_encrypt = []
|
|
||||||
for file in all_files:
|
|
||||||
if os.path.splitext(file)[1] in VALID_FILE_EXTENSIONS_FOR_ENCRYPTION:
|
|
||||||
files_to_encrypt.append(file)
|
|
||||||
|
|
||||||
return files_to_encrypt
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Set
|
||||||
|
|
||||||
|
|
||||||
def get_all_files_in_directory(dir_path: Path) -> List[Path]:
|
def get_all_files_in_directory(dir_path: Path) -> List[Path]:
|
||||||
|
@ -8,3 +8,10 @@ def get_all_files_in_directory(dir_path: Path) -> List[Path]:
|
||||||
|
|
||||||
def filter_files(files: List[Path], file_filter: Callable[[Path], bool]):
|
def filter_files(files: List[Path], file_filter: Callable[[Path], bool]):
|
||||||
return [f for f in files if file_filter(f)]
|
return [f for f in files if file_filter(f)]
|
||||||
|
|
||||||
|
|
||||||
|
def file_extension_filter(file_extensions: Set):
|
||||||
|
def inner_filter(f: Path):
|
||||||
|
return f.suffix in file_extensions
|
||||||
|
|
||||||
|
return inner_filter
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from infection_monkey.utils.dir_utils import (
|
from infection_monkey.utils.dir_utils import (
|
||||||
|
file_extension_filter,
|
||||||
filter_files,
|
filter_files,
|
||||||
get_all_files_in_directory,
|
get_all_files_in_directory,
|
||||||
)
|
)
|
||||||
|
@ -73,3 +74,14 @@ def test_filter_files__all_true(tmp_path):
|
||||||
filtered_files = filter_files(files_in_dir, lambda _: True)
|
filtered_files = filter_files(files_in_dir, lambda _: True)
|
||||||
|
|
||||||
assert sorted(filtered_files) == expected_return_value
|
assert sorted(filtered_files) == expected_return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_extension_filter(tmp_path):
|
||||||
|
valid_extensions = {".zip", ".tar"}
|
||||||
|
|
||||||
|
files = add_files_to_dir(tmp_path)
|
||||||
|
|
||||||
|
files_in_dir = get_all_files_in_directory(tmp_path)
|
||||||
|
filtered_files = filter_files(files_in_dir, file_extension_filter(valid_extensions))
|
||||||
|
|
||||||
|
assert files[0:1] == filtered_files
|
||||||
|
|
Loading…
Reference in New Issue