diff --git a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py index 424a511f7..6a44b5cff 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import Union from Crypto import Random # noqa: DUO133 # nosec: B413 @@ -9,49 +10,62 @@ from .i_encryptor import IEncryptor from .key_based_encryptor import KeyBasedEncryptor from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor -_KEY_FILENAME = "mongo_key.bin" -_BLOCK_SIZE = 32 - _encryptor: Union[None, IEncryptor] = None -def _load_existing_key(key_file_path: str, secret: str) -> KeyBasedEncryptor: - with open(key_file_path, "rb") as f: - encrypted_key = f.read() - cipher_key = PasswordBasedBytesEncryptor(secret).decrypt(encrypted_key) - return KeyBasedEncryptor(cipher_key) +class DataStoreEncryptor(IEncryptor): + _KEY_LENGTH_BYTES = 32 + + def __init__(self, secret: str, key_file_path: Path): + self._key_file_path = key_file_path + self._password_based_encryptor = PasswordBasedBytesEncryptor(secret) + self._key_based_encryptor = self._initialize_key_based_encryptor() + + def _initialize_key_based_encryptor(self): + if os.path.exists(self._key_file_path): + return self._load_existing_key() + + return self._create_new_key() + + def _load_existing_key(self) -> KeyBasedEncryptor: + with open(self._key_file_path, "rb") as f: + encrypted_key = f.read() + + plaintext_key = self._password_based_encryptor.decrypt(encrypted_key) + return KeyBasedEncryptor(plaintext_key) + + def _create_new_key(self) -> KeyBasedEncryptor: + plaintext_key = Random.new().read(DataStoreEncryptor._KEY_LENGTH_BYTES) + + encrypted_key = self._password_based_encryptor.encrypt(plaintext_key) + with open_new_securely_permissioned_file(self._key_file_path, "wb") as f: + f.write(encrypted_key) + + return KeyBasedEncryptor(plaintext_key) + + def encrypt(self, plaintext: str) -> str: + return self._key_based_encryptor.encrypt(plaintext) + + def decrypt(self, ciphertext: str): + return self._key_based_encryptor.decrypt(ciphertext) + + def erase_key(self): + if self._key_file_path.is_file(): + self._key_file_path.unlink() -def _create_new_key(key_file_path: str, secret: str) -> KeyBasedEncryptor: - cipher_key = _get_random_bytes() - encrypted_key = PasswordBasedBytesEncryptor(secret).encrypt(cipher_key) - with open_new_securely_permissioned_file(key_file_path, "wb") as f: - f.write(encrypted_key) - return KeyBasedEncryptor(cipher_key) +def remove_old_datastore_key(): + if _encryptor: + _encryptor.erase_key() -def _get_random_bytes() -> bytes: - return Random.new().read(_BLOCK_SIZE) - - -def remove_old_datastore_key(key_file_dir: str): - key_file_path = _get_key_file_path(key_file_dir) - if os.path.isfile(key_file_path): - os.remove(key_file_path) - - -def initialize_datastore_encryptor(key_file_dir: str, secret: str): +def initialize_datastore_encryptor( + key_file_dir: str, secret: str, key_file_name: str = "mongo_key.bin" +): global _encryptor - key_file_path = _get_key_file_path(key_file_dir) - if os.path.exists(key_file_path): - _encryptor = _load_existing_key(key_file_path, secret) - else: - _encryptor = _create_new_key(key_file_path, secret) - - -def _get_key_file_path(key_file_dir: str) -> str: - return os.path.join(key_file_dir, _KEY_FILENAME) + key_file_path = Path(key_file_dir) / key_file_name + _encryptor = DataStoreEncryptor(secret, key_file_path) def get_datastore_encryptor() -> IEncryptor: diff --git a/monkey/monkey_island/cc/services/authentication.py b/monkey/monkey_island/cc/services/authentication.py index 9d3d3baa7..2d7940055 100644 --- a/monkey/monkey_island/cc/services/authentication.py +++ b/monkey/monkey_island/cc/services/authentication.py @@ -22,7 +22,7 @@ class AuthenticationService: @staticmethod def reset_datastore_encryptor(username: str, password: str): - remove_old_datastore_key(AuthenticationService.KEY_FILE_DIRECTORY) + remove_old_datastore_key() AuthenticationService._init_encryptor_from_credentials(username, password) @staticmethod diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py index f24040c22..7d5616185 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py @@ -10,10 +10,24 @@ from monkey_island.cc.server_utils.encryption import ( PLAINTEXT = "Hello, Monkey!" MOCK_SECRET = "53CR31" +KEY_FILENAME = "test_key.bin" + + +@pytest.fixture(autouse=True) +def cleanup_encryptor(): + yield + data_store_encryptor._encryptor = None + + +@pytest.fixture +def key_file(tmp_path): + return tmp_path / KEY_FILENAME + @pytest.mark.slow -@pytest.mark.usefixtures("uses_encryptor") -def test_encryption(data_for_tests_dir): +def test_encryption(tmp_path): + initialize_datastore_encryptor(tmp_path, MOCK_SECRET, KEY_FILENAME) + encrypted_data = get_datastore_encryptor().encrypt(PLAINTEXT) assert encrypted_data != PLAINTEXT @@ -21,45 +35,35 @@ def test_encryption(data_for_tests_dir): assert decrypted_data == PLAINTEXT -@pytest.fixture -def cleanup_encryptor(): - yield - data_store_encryptor._encryptor = None - - -@pytest.mark.usefixtures("cleanup_encryptor") -@pytest.fixture -def initialized_encryptor_dir(tmpdir): - initialize_datastore_encryptor(tmpdir, MOCK_SECRET) - return tmpdir +@pytest.mark.slow +def test_key_creation(key_file, tmp_path): + assert not key_file.is_file() + initialize_datastore_encryptor(tmp_path, MOCK_SECRET, KEY_FILENAME) + assert key_file.is_file() @pytest.mark.slow -def test_key_creation(initialized_encryptor_dir): - assert (initialized_encryptor_dir / data_store_encryptor._KEY_FILENAME).isfile() +def test_key_removal(key_file, tmp_path): + initialize_datastore_encryptor(tmp_path, MOCK_SECRET, KEY_FILENAME) + assert key_file.is_file() + + remove_old_datastore_key() + assert not key_file.is_file() -@pytest.mark.slow -def test_key_removal(initialized_encryptor_dir): - remove_old_datastore_key(initialized_encryptor_dir) - assert not (initialized_encryptor_dir / data_store_encryptor._KEY_FILENAME).isfile() - - -def test_key_removal__no_key(tmpdir): - assert not (tmpdir / data_store_encryptor._KEY_FILENAME).isfile() +def test_key_removal__no_key(key_file): + assert not key_file.is_file() # Make sure no error thrown when we try to remove an non-existing key - remove_old_datastore_key(tmpdir) - data_store_encryptor._factory = None + remove_old_datastore_key() -@pytest.mark.slow -@pytest.mark.usefixtures("cleanup_encryptor") -def test_key_file_encryption(tmpdir, monkeypatch): - monkeypatch.setattr(data_store_encryptor, "_get_random_bytes", lambda: PLAINTEXT.encode()) - initialize_datastore_encryptor(tmpdir, MOCK_SECRET) - key_file_path = data_store_encryptor._get_key_file_path(tmpdir) - key_file_contents = open(key_file_path, "rb").read() - assert not key_file_contents == PLAINTEXT.encode() +def test_key_removal__no_key_2(key_file, tmp_path): + assert not key_file.is_file() + initialize_datastore_encryptor(tmp_path, MOCK_SECRET, KEY_FILENAME) + assert key_file.is_file() - key_based_encryptor = data_store_encryptor._load_existing_key(key_file_path, MOCK_SECRET) - assert key_based_encryptor._key == PLAINTEXT.encode() + key_file.unlink() + assert not key_file.is_file() + + # Make sure no error thrown when we try to remove an non-existing key + get_datastore_encryptor().erase_key()