diff --git a/monkey/common/utils/file_utils.py b/monkey/common/utils/file_utils.py index 225fb8732..6110cd020 100644 --- a/monkey/common/utils/file_utils.py +++ b/monkey/common/utils/file_utils.py @@ -1,5 +1,11 @@ import os +class InvalidPath(Exception): + pass + + def expand_path(path: str) -> str: + if not path: + raise InvalidPath("Empty path provided") return os.path.expandvars(os.path.expanduser(path)) diff --git a/monkey/infection_monkey/ransomware/ransomware_payload.py b/monkey/infection_monkey/ransomware/ransomware_payload.py index a58b50545..f2e3eb476 100644 --- a/monkey/infection_monkey/ransomware/ransomware_payload.py +++ b/monkey/infection_monkey/ransomware/ransomware_payload.py @@ -4,7 +4,7 @@ from pathlib import Path from pprint import pformat from typing import List, Optional, Tuple -from common.utils.file_utils import expand_path +from common.utils.file_utils import InvalidPath, expand_path from infection_monkey.ransomware.bitflip_encryptor import BitflipEncryptor from infection_monkey.ransomware.file_selectors import select_production_safe_target_files from infection_monkey.ransomware.targeted_file_extensions import TARGETED_FILE_EXTENSIONS @@ -28,15 +28,7 @@ class RansomwarePayload: self._encryption_enabled = config["encryption"]["enabled"] self._readme_enabled = config["other_behaviors"]["readme"] - target_directories = config["encryption"]["directories"] - self._target_dir = Path( - expand_path( - target_directories["windows_target_dir"] - if is_windows_os() - else target_directories["linux_target_dir"] - ) - ) - + self._target_dir = RansomwarePayload.get_target_dir(config) self._new_file_extension = EXTENSION self._valid_file_extensions_for_encryption = TARGETED_FILE_EXTENSIONS.copy() self._valid_file_extensions_for_encryption.discard(self._new_file_extension) @@ -44,8 +36,22 @@ class RansomwarePayload: self._encryptor = BitflipEncryptor(chunk_size=CHUNK_SIZE) self._telemetry_messenger = telemetry_messenger + @staticmethod + def get_target_dir(config: dict): + target_directories = config["encryption"]["directories"] + if is_windows_os(): + target_dir_field = target_directories["windows_target_dir"] + else: + target_dir_field = target_directories["linux_target_dir"] + + try: + return Path(expand_path(target_dir_field)) + except InvalidPath as e: + LOG.debug(f"Target ransomware dir set to None: {e}") + return None + def run_payload(self): - if self._encryption_enabled: + if self._encryption_enabled and self._target_dir: LOG.info("Running ransomware payload") file_list = self._find_files() self._encrypt_files(file_list) diff --git a/monkey/tests/unit_tests/common/utils/test_common_file_utils.py b/monkey/tests/unit_tests/common/utils/test_common_file_utils.py index 226a403b8..b67341cfe 100644 --- a/monkey/tests/unit_tests/common/utils/test_common_file_utils.py +++ b/monkey/tests/unit_tests/common/utils/test_common_file_utils.py @@ -1,6 +1,8 @@ import os -from common.utils.file_utils import expand_path +import pytest + +from common.utils.file_utils import InvalidPath, expand_path def test_expand_user(patched_home_env): @@ -15,3 +17,8 @@ def test_expand_vars(patched_home_env): expected_path = os.path.join(patched_home_env, "test") assert expand_path(input_path) == expected_path + + +def test_expand_path__empty_path_provided(): + with pytest.raises(InvalidPath): + expand_path("")