diff --git a/monkey/monkey_island/cc/resources/auth/registration.py b/monkey/monkey_island/cc/resources/auth/registration.py index 0877ee4a3..e6743302f 100644 --- a/monkey/monkey_island/cc/resources/auth/registration.py +++ b/monkey/monkey_island/cc/resources/auth/registration.py @@ -9,7 +9,7 @@ from monkey_island.cc.resources.auth.credential_utils import ( get_secret_from_request, get_user_credentials_from_request, ) -from monkey_island.cc.server_utils.encryption.data_store_encryptor import setup_datastore_key +from monkey_island.cc.server_utils.encryption import remove_old_datastore_key, setup_datastore_key from monkey_island.cc.setup.mongo.database_initializer import reset_database logger = logging.getLogger(__name__) @@ -21,11 +21,11 @@ class Registration(flask_restful.Resource): return {"needs_registration": is_registration_needed} def post(self): - # TODO delete the old key here, before creating new one credentials = get_user_credentials_from_request(request) try: env_singleton.env.try_add_user(credentials) + remove_old_datastore_key() setup_datastore_key(get_secret_from_request(request)) reset_database() return make_response({"error": ""}, 200) diff --git a/monkey/monkey_island/cc/server_utils/encryption/__init__.py b/monkey/monkey_island/cc/server_utils/encryption/__init__.py index 84e6e6252..531659a9e 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/__init__.py +++ b/monkey/monkey_island/cc/server_utils/encryption/__init__.py @@ -9,6 +9,9 @@ from monkey_island.cc.server_utils.encryption.data_store_encryptor import ( DataStoreEncryptor, get_datastore_encryptor, initialize_datastore_encryptor, + remove_old_datastore_key, + setup_datastore_key, + EncryptorNotInitializedError, ) from .dict_encryption.dict_encryptor import ( SensitiveField, diff --git a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py index e5a054080..49d39b505 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py @@ -69,12 +69,28 @@ class EncryptorNotInitializedError(Exception): pass +def encryptor_initialized_key_not_set(f): + def inner_function(*args, **kwargs): + if _encryptor is None: + raise EncryptorNotInitializedError + else: + if not _encryptor.is_key_setup(): + return f(*args, **kwargs) + else: + pass + + return inner_function + + +@encryptor_initialized_key_not_set +def remove_old_datastore_key(): + if os.path.isfile(_encryptor.key_file_path): + os.remove(_encryptor.key_file_path) + + +@encryptor_initialized_key_not_set def setup_datastore_key(secret: str): - if _encryptor is None: - raise EncryptorNotInitializedError - else: - if not _encryptor.is_key_setup(): - _encryptor.init_key(secret) + _encryptor.init_key(secret) def get_datastore_encryptor(): diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py index 22f2e9676..746054841 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py @@ -5,10 +5,13 @@ from tests.unit_tests.monkey_island.cc.conftest import ENCRYPTOR_SECRET from monkey_island.cc.server_utils.encryption import ( DataStoreEncryptor, + EncryptorNotInitializedError, + data_store_encryptor, get_datastore_encryptor, initialize_datastore_encryptor, + remove_old_datastore_key, + setup_datastore_key, ) -from monkey_island.cc.server_utils.encryption.data_store_encryptor import setup_datastore_key PLAINTEXT = "Hello, Monkey!" @@ -22,7 +25,47 @@ def test_encryption(data_for_tests_dir): assert decrypted_data == PLAINTEXT -def test_create_new_password_file(tmpdir): +@pytest.fixture +def initialized_key_dir(tmpdir): initialize_datastore_encryptor(tmpdir) setup_datastore_key(ENCRYPTOR_SECRET) + yield tmpdir + data_store_encryptor._encryptor = None + + +def test_key_creation(initialized_key_dir): + assert os.path.isfile(os.path.join(initialized_key_dir, DataStoreEncryptor._KEY_FILENAME)) + + +def test_key_removal_fails_if_key_initialized(initialized_key_dir): + remove_old_datastore_key() + assert os.path.isfile(os.path.join(initialized_key_dir, DataStoreEncryptor._KEY_FILENAME)) + + +def test_key_removal(initialized_key_dir, monkeypatch): + monkeypatch.setattr(DataStoreEncryptor, "is_key_setup", lambda _: False) + remove_old_datastore_key() + assert not os.path.isfile(os.path.join(initialized_key_dir, DataStoreEncryptor._KEY_FILENAME)) + + +def test_key_removal__no_key(tmpdir): + initialize_datastore_encryptor(tmpdir) + assert not os.path.isfile(os.path.join(tmpdir, DataStoreEncryptor._KEY_FILENAME)) + # Make sure no error thrown when we try to remove an non-existing key + remove_old_datastore_key() + + data_store_encryptor._encryptor = None + + +def test_encryptor_not_initialized(): + with pytest.raises(EncryptorNotInitializedError): + remove_old_datastore_key() + setup_datastore_key() + + +def test_setup_datastore_key(tmpdir): + initialize_datastore_encryptor(tmpdir) + assert not os.path.isfile(os.path.join(tmpdir, DataStoreEncryptor._KEY_FILENAME)) + setup_datastore_key(ENCRYPTOR_SECRET) assert os.path.isfile(os.path.join(tmpdir, DataStoreEncryptor._KEY_FILENAME)) + assert get_datastore_encryptor().is_key_setup()