diff --git a/monkey/monkey_island/cc/server_utils/encryptor.py b/monkey/monkey_island/cc/server_utils/encryptor.py index abeb34dc3..4f376532c 100644 --- a/monkey/monkey_island/cc/server_utils/encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryptor.py @@ -6,7 +6,7 @@ import os from Crypto import Random # noqa: DUO133 # nosec: B413 from Crypto.Cipher import AES # noqa: DUO133 # nosec: B413 -from monkey_island.cc.server_utils.file_utils import create_secure_file +from monkey_island.cc.server_utils.file_utils import get_file_descriptor_for_new_secure_file __author__ = "itay.mizeretz" @@ -23,12 +23,11 @@ class Encryptor: if os.path.exists(password_file): self._load_existing_key(password_file) else: - create_secure_file(path=password_file) self._init_key(password_file) def _init_key(self, password_file): self._cipher_key = Random.new().read(self._BLOCK_SIZE) - with open(password_file, "wb") as f: + with open(get_file_descriptor_for_new_secure_file(path=password_file)) as f: f.write(self._cipher_key) def _load_existing_key(self, password_file): diff --git a/monkey/monkey_island/cc/server_utils/file_utils.py b/monkey/monkey_island/cc/server_utils/file_utils.py index cb94065e4..8681bcc54 100644 --- a/monkey/monkey_island/cc/server_utils/file_utils.py +++ b/monkey/monkey_island/cc/server_utils/file_utils.py @@ -2,7 +2,6 @@ import logging import os import platform import stat -from pathlib import Path LOG = logging.getLogger(__name__) @@ -55,25 +54,30 @@ def _create_secure_directory_windows(path: str): raise ex -def create_secure_file(path: str): +def get_file_descriptor_for_new_secure_file(path: str): if not os.path.isfile(path): if is_windows_os(): - _create_secure_file_windows(path) + return _get_file_descriptor_for_new_secure_file_windows(path) else: - _create_secure_file_linux(path) + return _get_file_descriptor_for_new_secure_file_linux(path) -def _create_secure_file_linux(path: str): +def _get_file_descriptor_for_new_secure_file_linux(path: str): try: mode = stat.S_IRUSR | stat.S_IWUSR - Path(path).touch(mode=mode, exist_ok=False) + flags = ( + os.O_RDWR | os.O_CREAT | os.O_EXCL + ) # read/write, create new, throw error if file exists + fd = os.open(path, flags, mode) + + return fd except Exception as ex: LOG.error(f'Could not create a file at "{path}": {str(ex)}') raise ex -def _create_secure_file_windows(path: str): +def _get_file_descriptor_for_new_secure_file_windows(path: str): try: file_access = win32file.GENERIC_READ | win32file.GENERIC_WRITE file_sharing = ( @@ -86,20 +90,20 @@ def _create_secure_file_windows(path: str): file_creation = win32file.CREATE_NEW # fails if file exists file_attributes = win32file.FILE_FLAG_BACKUP_SEMANTICS - win32file.CloseHandle( - win32file.CreateFile( - path, - file_access, - file_sharing, - security_attributes, - file_creation, - file_attributes, - win32job.CreateJobObject( - None, "" - ), # https://stackoverflow.com/questions/46800142/in-python-with-pywin32-win32job-the-createjobobject-function-how-do-i-pass-nu # noqa: E501 - ) + fd = win32file.CreateFile( + path, + file_access, + file_sharing, + security_attributes, + file_creation, + file_attributes, + win32job.CreateJobObject( + None, "" + ), # https://stackoverflow.com/questions/46800142/in-python-with-pywin32-win32job-the-createjobobject-function-how-do-i-pass-nu # noqa: E501 ) + return fd + except Exception as ex: LOG.error(f'Could not create a file at "{path}": {str(ex)}') raise ex