diff --git a/monkey/monkey_island/cc/server_utils/encryption/__init__.py b/monkey/monkey_island/cc/server_utils/encryption/__init__.py index fa9692ca3..1c3e422db 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/__init__.py +++ b/monkey/monkey_island/cc/server_utils/encryption/__init__.py @@ -8,7 +8,7 @@ from .encryptor_factory import ( FactoryNotInitializedError, remove_old_datastore_key, get_encryptor_factory, - get_secret_from_credentials, + _get_secret_from_credentials, initialize_encryptor_factory, ) from .data_store_encryptor import initialize_datastore_encryptor, get_datastore_encryptor diff --git a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py index f58401783..949102c84 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py @@ -1,49 +1,17 @@ from __future__ import annotations -import os - # PyCrypto is deprecated, but we use pycryptodome, which uses the exact same imports but # is maintained. from typing import Union -from Crypto import Random # noqa: DUO133 # nosec: B413 - -from monkey_island.cc.server_utils.encryption import FactoryNotInitializedError, KeyBasedEncryptor -from monkey_island.cc.server_utils.encryption.encryptor_factory import ( - get_encryptor_factory, - get_secret_from_credentials, -) -from monkey_island.cc.server_utils.encryption.password_based_bytes_encryption import ( - PasswordBasedBytesEncryptor, -) -from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file +from monkey_island.cc.server_utils.encryption import KeyBasedEncryptor _encryptor: Union[None, DataStoreEncryptor] = None class DataStoreEncryptor: - _BLOCK_SIZE = 32 - - def __init__(self, key_file_path: str, secret: str): - if os.path.exists(key_file_path): - self._key_based_encryptor = DataStoreEncryptor._load_existing_key(key_file_path, secret) - else: - self._key_based_encryptor = DataStoreEncryptor._create_new_key(key_file_path, secret) - - @staticmethod - def _load_existing_key(key_file_path: str, secret: str): - with open(key_file_path, "rb") as f: - encrypted_key = f.read() - cipher_key = PasswordBasedBytesEncryptor(secret).decrypt(encrypted_key) - return KeyBasedEncryptor(cipher_key) - - @staticmethod - def _create_new_key(key_file_path: str, secret: str): - cipher_key = Random.new().read(DataStoreEncryptor._BLOCK_SIZE) - encrypted_key = PasswordBasedBytesEncryptor(secret).encrypt(cipher_key) - with open_new_securely_permissioned_file(key_file_path, "wb") as f: - f.write(encrypted_key) - return KeyBasedEncryptor(cipher_key) + def __init__(self, key_based_encryptor: KeyBasedEncryptor): + self._key_based_encryptor = key_based_encryptor def enc(self, message: str): return self._key_based_encryptor.encrypt(message) @@ -52,14 +20,10 @@ class DataStoreEncryptor: return self._key_based_encryptor.decrypt(enc_message) -def initialize_datastore_encryptor(username: str, password: str): +def initialize_datastore_encryptor(key_based_encryptor: KeyBasedEncryptor): global _encryptor - factory = get_encryptor_factory() - if not factory: - raise FactoryNotInitializedError - secret = get_secret_from_credentials(username, password) - _encryptor = DataStoreEncryptor(factory.key_file_path, secret) + _encryptor = DataStoreEncryptor(key_based_encryptor) def get_datastore_encryptor(): diff --git a/monkey/monkey_island/cc/server_utils/encryption/encryptor_factory.py b/monkey/monkey_island/cc/server_utils/encryption/encryptor_factory.py index 67d3e6ee4..0ae3e70a6 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/encryptor_factory.py +++ b/monkey/monkey_island/cc/server_utils/encryption/encryptor_factory.py @@ -1,38 +1,75 @@ -from __future__ import annotations - import os -from ctypes import Union -_factory: Union[None, EncryptorFactory] = None +from Crypto import Random + +from monkey_island.cc.server_utils.encryption import ( + KeyBasedEncryptor, + initialize_datastore_encryptor, +) +from monkey_island.cc.server_utils.encryption.password_based_bytes_encryption import ( + PasswordBasedBytesEncryptor, +) +from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file + +_KEY_FILENAME = "mongo_key.bin" +_BLOCK_SIZE = 32 class EncryptorFactory: + def __init__(self): + self.key_file_path = None + self.secret = None - _KEY_FILENAME = "mongo_key.bin" + def set_key_file_path(self, key_file_path: str): + self.key_file_path = key_file_path - def __init__(self, key_file_dir: str): - self.key_file_path = os.path.join(key_file_dir, self._KEY_FILENAME) + def set_secret(self, username: str, password: str): + self.secret = _get_secret_from_credentials(username, password) + + def initialize_encryptor(self): + if os.path.exists(self.key_file_path): + key_based_encryptor = _load_existing_key(self.key_file_path, self.secret) + else: + key_based_encryptor = _create_new_key(self.key_file_path, self.secret) + initialize_datastore_encryptor(key_based_encryptor) -class FactoryNotInitializedError(Exception): +class KeyPathNotSpecifiedError(Exception): pass -def get_secret_from_credentials(username: str, password: str) -> str: +def _load_existing_key(key_file_path: str, secret: str): + with open(key_file_path, "rb") as f: + encrypted_key = f.read() + cipher_key = PasswordBasedBytesEncryptor(secret).decrypt(encrypted_key) + return KeyBasedEncryptor(cipher_key) + + +def _create_new_key(key_file_path: str, secret: str): + cipher_key = _get_random_bytes() + encrypted_key = PasswordBasedBytesEncryptor(secret).encrypt(cipher_key) + with open_new_securely_permissioned_file(key_file_path, "wb") as f: + f.write(encrypted_key) + return KeyBasedEncryptor(cipher_key) + + +def _get_random_bytes() -> bytes: + return Random.new().read(_BLOCK_SIZE) + + +def _get_secret_from_credentials(username: str, password: str) -> str: return f"{username}:{password}" def remove_old_datastore_key(): - if _factory is None: - raise FactoryNotInitializedError + if not _factory.key_file_path: + raise KeyPathNotSpecifiedError if os.path.isfile(_factory.key_file_path): os.remove(_factory.key_file_path) -def initialize_encryptor_factory(key_file_dir: str): - global _factory - _factory = EncryptorFactory(key_file_dir) - - def get_encryptor_factory(): return _factory + + +_factory = EncryptorFactory() diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py index e8901862b..1d4d23273 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py @@ -10,6 +10,7 @@ from monkey_island.cc.server_utils.encryption import ( initialize_encryptor_factory, remove_old_datastore_key, ) +from monkey_island.cc.server_utils.encryption.data_store_encryptor import DataStoreEncryptor from monkey_island.cc.server_utils.encryption.encryptor_factory import EncryptorFactory PLAINTEXT = "Hello, Monkey!" @@ -37,7 +38,7 @@ def test_key_creation(initialized_key_dir): assert (initialized_key_dir / EncryptorFactory._KEY_FILENAME).isfile() -def test_key_removal(initialized_key_dir, monkeypatch): +def test_key_removal(initialized_key_dir): remove_old_datastore_key() assert not (initialized_key_dir / EncryptorFactory._KEY_FILENAME).isfile() @@ -61,3 +62,7 @@ def test_initialize_encryptor(tmpdir): assert not (tmpdir / EncryptorFactory._KEY_FILENAME).isfile() initialize_datastore_encryptor(MOCK_USERNAME, MOCK_PASSWORD) assert (tmpdir / EncryptorFactory._KEY_FILENAME).isfile() + + +def test_key_file_encryption(tmpdir, monkeypatch): + monkeypatch(DataStoreEncryptor._)