diff --git a/monkey/monkey_island/cc/server_utils/encryptor.py b/monkey/monkey_island/cc/server_utils/encryptor.py index 657a06e87..06a973526 100644 --- a/monkey/monkey_island/cc/server_utils/encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryptor.py @@ -6,7 +6,7 @@ import os from Crypto import Random # noqa: DUO133 # nosec: B413 from Crypto.Cipher import AES # noqa: DUO133 # nosec: B413 -from monkey_island.cc.server_utils.file_utils import get_file_descriptor_for_new_secure_file +from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file __author__ = "itay.mizeretz" @@ -27,8 +27,7 @@ class Encryptor: def _init_key(self, password_file_path: str): self._cipher_key = Random.new().read(self._BLOCK_SIZE) - get_file_descriptor_for_new_secure_file(path=password_file_path) - with open(password_file_path, "wb") as f: + with open_new_securely_permissioned_file(password_file_path, "wb") as f: f.write(self._cipher_key) def _load_existing_key(self, password_file): diff --git a/monkey/monkey_island/cc/server_utils/file_utils.py b/monkey/monkey_island/cc/server_utils/file_utils.py index 995501c09..e429eb464 100644 --- a/monkey/monkey_island/cc/server_utils/file_utils.py +++ b/monkey/monkey_island/cc/server_utils/file_utils.py @@ -2,6 +2,8 @@ import logging import os import platform import stat +from contextlib import contextmanager +from typing import Generator LOG = logging.getLogger(__name__) @@ -54,11 +56,15 @@ def _create_secure_directory_windows(path: str): raise ex -def get_file_descriptor_for_new_secure_file(path: str) -> int: +@contextmanager +def open_new_securely_permissioned_file(path: str, mode: str = "w") -> Generator: if is_windows_os(): - return _get_file_descriptor_for_new_secure_file_windows(path) + fd = _get_file_descriptor_for_new_secure_file_windows(path) else: - return _get_file_descriptor_for_new_secure_file_linux(path) + fd = _get_file_descriptor_for_new_secure_file_linux(path) + + with open(fd, mode) as f: + yield f def _get_file_descriptor_for_new_secure_file_linux(path: str) -> int: @@ -79,8 +85,12 @@ def _get_file_descriptor_for_new_secure_file_linux(path: str) -> int: def _get_file_descriptor_for_new_secure_file_windows(path: str) -> int: try: file_access = win32file.GENERIC_READ | win32file.GENERIC_WRITE - # subsequent open operations on the object will succeed only if read access is requested + + # Enables other processes to open this file with read-only access. + # Attempts by other processes to open the file for writing while this + # process still holds it open will fail. file_sharing = win32file.FILE_SHARE_READ + security_attributes = win32security.SECURITY_ATTRIBUTES() security_attributes.SECURITY_DESCRIPTOR = ( windows_permissions.get_security_descriptor_for_owner_only_perms() @@ -88,7 +98,7 @@ def _get_file_descriptor_for_new_secure_file_windows(path: str) -> int: file_creation = win32file.CREATE_NEW # fails if file exists file_attributes = win32file.FILE_FLAG_BACKUP_SEMANTICS - fd = win32file.CreateFile( + handle = win32file.CreateFile( path, file_access, file_sharing, @@ -98,13 +108,15 @@ def _get_file_descriptor_for_new_secure_file_windows(path: str) -> int: _get_null_value_for_win32(), ) - return fd + detached_handle = handle.Detach() + + return win32file._open_osfhandle(detached_handle, os.O_RDWR) except Exception as ex: LOG.error(f'Could not create a file at "{path}": {str(ex)}') raise ex -def _get_null_value_for_win32() -> None: +def _get_null_value_for_win32(): # https://stackoverflow.com/questions/46800142/in-python-with-pywin32-win32job-the-createjobobject-function-how-do-i-pass-nu # noqa: E501 return win32job.CreateJobObject(None, "") diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/test_file_utils.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/test_file_utils.py index 894f1e6b3..9a9ada29d 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/test_file_utils.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/test_file_utils.py @@ -6,13 +6,12 @@ import pytest from monkey_island.cc.server_utils.file_utils import ( create_secure_directory, expand_path, - get_file_descriptor_for_new_secure_file, is_windows_os, + open_new_securely_permissioned_file, ) if is_windows_os(): import win32api - import win32file import win32security FULL_CONTROL = 2032127 @@ -99,22 +98,26 @@ def test_create_secure_directory__perm_windows(test_path): assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER -def test_get_file_descriptor_for_new_secure_file__already_exists(test_path): +def test_open_new_securely_permissioned_file__already_exists(test_path): os.close(os.open(test_path, os.O_CREAT, stat.S_IRWXU)) assert os.path.isfile(test_path) with pytest.raises(Exception): - get_file_descriptor_for_new_secure_file(test_path) + with open_new_securely_permissioned_file(test_path): + pass -def test_get_file_descriptor_for_new_secure_file__no_parent_dir(test_path_nested): +def test_open_new_securely_permissioned_file__no_parent_dir(test_path_nested): with pytest.raises(Exception): - get_file_descriptor_for_new_secure_file(test_path_nested) + with open_new_securely_permissioned_file(test_path_nested): + pass @pytest.mark.skipif(is_windows_os(), reason="Tests Posix (not Windows) permissions.") -def test_get_file_descriptor_for_new_secure_file__perm_linux(test_path): - os.close(get_file_descriptor_for_new_secure_file(test_path)) +def test_open_new_securely_permissioned_file__perm_linux(test_path): + with open_new_securely_permissioned_file(test_path): + pass + st = os.stat(test_path) expected_mode = stat.S_IRUSR | stat.S_IWUSR @@ -124,8 +127,9 @@ def test_get_file_descriptor_for_new_secure_file__perm_linux(test_path): @pytest.mark.skipif(not is_windows_os(), reason="Tests Windows (not Posix) permissions.") -def test_get_file_descriptor_for_new_secure_file__perm_windows(test_path): - win32file.CloseHandle(get_file_descriptor_for_new_secure_file(test_path)) +def test_open_new_securely_permissioned_file__perm_windows(test_path): + with open_new_securely_permissioned_file(test_path): + pass acl, user_sid = _get_acl_and_sid_from_path(test_path) @@ -141,3 +145,12 @@ def test_get_file_descriptor_for_new_secure_file__perm_windows(test_path): assert ace_sid == user_sid assert ace_permissions == FULL_CONTROL and ace_access_mode == ACE_ACCESS_MODE_GRANT_ACCESS assert ace_inheritance == ACE_INHERIT_OBJECT_AND_CONTAINER + + +def test_open_new_securely_permissioned_file__write(test_path): + TEST_STR = b"Hello World" + with open_new_securely_permissioned_file(test_path, "wb") as f: + f.write(TEST_STR) + + with open(test_path, "rb") as f: + assert f.read() == TEST_STR