island: Return file descriptor when creating secure file

This commit is contained in:
Shreya 2021-06-15 18:55:15 +05:30
parent 6b4a0906c0
commit 14371f3fba
2 changed files with 25 additions and 22 deletions

View File

@ -6,7 +6,7 @@ import os
from Crypto import Random # noqa: DUO133 # nosec: B413 from Crypto import Random # noqa: DUO133 # nosec: B413
from Crypto.Cipher import AES # noqa: DUO133 # nosec: B413 from Crypto.Cipher import AES # noqa: DUO133 # nosec: B413
from monkey_island.cc.server_utils.file_utils import create_secure_file from monkey_island.cc.server_utils.file_utils import get_file_descriptor_for_new_secure_file
__author__ = "itay.mizeretz" __author__ = "itay.mizeretz"
@ -23,12 +23,11 @@ class Encryptor:
if os.path.exists(password_file): if os.path.exists(password_file):
self._load_existing_key(password_file) self._load_existing_key(password_file)
else: else:
create_secure_file(path=password_file)
self._init_key(password_file) self._init_key(password_file)
def _init_key(self, password_file): def _init_key(self, password_file):
self._cipher_key = Random.new().read(self._BLOCK_SIZE) self._cipher_key = Random.new().read(self._BLOCK_SIZE)
with open(password_file, "wb") as f: with open(get_file_descriptor_for_new_secure_file(path=password_file)) as f:
f.write(self._cipher_key) f.write(self._cipher_key)
def _load_existing_key(self, password_file): def _load_existing_key(self, password_file):

View File

@ -2,7 +2,6 @@ import logging
import os import os
import platform import platform
import stat import stat
from pathlib import Path
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -55,25 +54,30 @@ def _create_secure_directory_windows(path: str):
raise ex raise ex
def create_secure_file(path: str): def get_file_descriptor_for_new_secure_file(path: str):
if not os.path.isfile(path): if not os.path.isfile(path):
if is_windows_os(): if is_windows_os():
_create_secure_file_windows(path) return _get_file_descriptor_for_new_secure_file_windows(path)
else: else:
_create_secure_file_linux(path) return _get_file_descriptor_for_new_secure_file_linux(path)
def _create_secure_file_linux(path: str): def _get_file_descriptor_for_new_secure_file_linux(path: str):
try: try:
mode = stat.S_IRUSR | stat.S_IWUSR mode = stat.S_IRUSR | stat.S_IWUSR
Path(path).touch(mode=mode, exist_ok=False) flags = (
os.O_RDWR | os.O_CREAT | os.O_EXCL
) # read/write, create new, throw error if file exists
fd = os.open(path, flags, mode)
return fd
except Exception as ex: except Exception as ex:
LOG.error(f'Could not create a file at "{path}": {str(ex)}') LOG.error(f'Could not create a file at "{path}": {str(ex)}')
raise ex raise ex
def _create_secure_file_windows(path: str): def _get_file_descriptor_for_new_secure_file_windows(path: str):
try: try:
file_access = win32file.GENERIC_READ | win32file.GENERIC_WRITE file_access = win32file.GENERIC_READ | win32file.GENERIC_WRITE
file_sharing = ( file_sharing = (
@ -86,8 +90,7 @@ def _create_secure_file_windows(path: str):
file_creation = win32file.CREATE_NEW # fails if file exists file_creation = win32file.CREATE_NEW # fails if file exists
file_attributes = win32file.FILE_FLAG_BACKUP_SEMANTICS file_attributes = win32file.FILE_FLAG_BACKUP_SEMANTICS
win32file.CloseHandle( fd = win32file.CreateFile(
win32file.CreateFile(
path, path,
file_access, file_access,
file_sharing, file_sharing,
@ -98,7 +101,8 @@ def _create_secure_file_windows(path: str):
None, "" None, ""
), # https://stackoverflow.com/questions/46800142/in-python-with-pywin32-win32job-the-createjobobject-function-how-do-i-pass-nu # noqa: E501 ), # https://stackoverflow.com/questions/46800142/in-python-with-pywin32-win32job-the-createjobobject-function-how-do-i-pass-nu # noqa: E501
) )
)
return fd
except Exception as ex: except Exception as ex:
LOG.error(f'Could not create a file at "{path}": {str(ex)}') LOG.error(f'Could not create a file at "{path}": {str(ex)}')