diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index 722aa3463..3da05611f 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -80,30 +80,27 @@ class MongoCredentialsRepository(ICredentialsRepository): def _encrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]: encrypted_mapping: Dict[str, Any] = {} - for secret_or_identity, credentials_components in mapping.items(): - encrypted_mapping[secret_or_identity] = [] - for component in credentials_components: - encrypted_component = {} - for key, value in component.items(): - encrypted_component[key] = self._repository_encryptor.encrypt(value.encode()) + for secret_or_identity, credentials_component in mapping.items(): + encrypted_component = {} + for key, value in credentials_component.items(): + encrypted_component[key] = self._repository_encryptor.encrypt(value.encode()) - encrypted_mapping[secret_or_identity].append(encrypted_component) + encrypted_mapping[secret_or_identity] = encrypted_component return encrypted_mapping def _decrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]: - encrypted_mapping: Dict[str, Any] = {} + decrypted_mapping: Dict[str, Any] = {} - for secret_or_identity, credentials_components in mapping.items(): - encrypted_mapping[secret_or_identity] = [] - for component in credentials_components: - encrypted_component = {} - for key, value in component.items(): - encrypted_component[key] = self._repository_encryptor.decrypt(value).decode() + for secret_or_identity, credentials_component in mapping.items(): + decrypted_mapping[secret_or_identity] = [] + decrypted_component = {} + for key, value in credentials_component.items(): + decrypted_component[key] = self._repository_encryptor.decrypt(value).decode() - encrypted_mapping[secret_or_identity].append(encrypted_component) + decrypted_mapping[secret_or_identity] = decrypted_component - return encrypted_mapping + return decrypted_mapping @staticmethod def _remove_credentials_fom_collection(collection): diff --git a/monkey/tests/data_for_tests/propagation_credentials.py b/monkey/tests/data_for_tests/propagation_credentials.py index 47862b2bc..cea619ca7 100644 --- a/monkey/tests/data_for_tests/propagation_credentials.py +++ b/monkey/tests/data_for_tests/propagation_credentials.py @@ -1,4 +1,4 @@ -from common.credentials import Credentials, LMHash, NTHash, Password, Username +from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username username = "m0nk3y_user" special_username = "m0nk3y.user" @@ -6,8 +6,21 @@ nt_hash = "C1C58F96CDF212B50837BC11A00BE47C" lm_hash = "299BD128C1101FD6299BD128C1101FD6" password_1 = "trytostealthis" password_2 = "password" +PUBLIC_KEY = "MY_PUBLIC_KEY" +PRIVATE_KEY = "MY_PRIVATE_KEY" PROPAGATION_CREDENTIALS_1 = Credentials(identity=Username(username), secret=Password(password_1)) PROPAGATION_CREDENTIALS_2 = Credentials(identity=Username(special_username), secret=LMHash(lm_hash)) PROPAGATION_CREDENTIALS_3 = Credentials(identity=Username(username), secret=NTHash(nt_hash)) PROPAGATION_CREDENTIALS_4 = Credentials(identity=Username(username), secret=Password(password_2)) +PROPAGATION_CREDENTIALS_5 = Credentials( + identity=Username(username), secret=SSHKeypair(PRIVATE_KEY, PUBLIC_KEY) +) + +PROPAGATION_CREDENTIALS = [ + PROPAGATION_CREDENTIALS_1, + PROPAGATION_CREDENTIALS_2, + PROPAGATION_CREDENTIALS_3, + PROPAGATION_CREDENTIALS_4, + PROPAGATION_CREDENTIALS_5, +] diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index c3d55940d..0d1b90801 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock import mongomock import pytest +from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from monkey_island.cc.repository import MongoCredentialsRepository @@ -32,9 +33,9 @@ SECRETS_2 = (Password(PASSWORD2), Password(PASSWORD3)) CREDENTIALS_OBJECT_2 = Credentials(IDENTITIES_2, SECRETS_2) -CONFIGURED_CREDENTIALS = [CREDENTIALS_OBJECT_1] +CONFIGURED_CREDENTIALS = PROPAGATION_CREDENTIALS[0:3] -STOLEN_CREDENTIALS = [CREDENTIALS_OBJECT_2] +STOLEN_CREDENTIALS = PROPAGATION_CREDENTIALS[3:6] CREDENTIALS_LIST = [CREDENTIALS_OBJECT_1, CREDENTIALS_OBJECT_2] @@ -81,9 +82,9 @@ def test_mongo_repository_get_all(mongo_repository): def test_mongo_repository_configured(mongo_repository): - mongo_repository.save_configured_credentials(CREDENTIALS_LIST) + mongo_repository.save_configured_credentials(PROPAGATION_CREDENTIALS) actual_configured_credentials = mongo_repository.get_configured_credentials() - assert actual_configured_credentials == CREDENTIALS_LIST + assert actual_configured_credentials == PROPAGATION_CREDENTIALS mongo_repository.remove_configured_credentials() actual_configured_credentials = mongo_repository.get_configured_credentials() @@ -93,7 +94,7 @@ def test_mongo_repository_configured(mongo_repository): def test_mongo_repository_stolen(mongo_repository): mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) actual_stolen_credentials = mongo_repository.get_stolen_credentials() - assert sorted(actual_stolen_credentials) == sorted(STOLEN_CREDENTIALS) + assert actual_stolen_credentials == STOLEN_CREDENTIALS mongo_repository.remove_stolen_credentials() actual_stolen_credentials = mongo_repository.get_stolen_credentials() @@ -104,7 +105,7 @@ def test_mongo_repository_all(mongo_repository): mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS) mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) actual_credentials = mongo_repository.get_all_credentials() - assert actual_credentials == CREDENTIALS_LIST + assert actual_credentials == PROPAGATION_CREDENTIALS mongo_repository.remove_all_credentials() @@ -116,26 +117,25 @@ def test_mongo_repository_all(mongo_repository): # NOTE: The following tests are complicated, but they work. Rather than spend the effort to improve # them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to # simplify these tests. -def test_configured_secrets_encrypted(mongo_repository, mongo_client): - mongo_repository.save_configured_credentials([CREDENTIALS_OBJECT_2]) - check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2) +@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) +def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials): + mongo_repository.save_configured_credentials([credentials]) + check_if_stored_credentials_encrypted(mongo_client, credentials) -def test_stolen_secrets_encrypted(mongo_repository, mongo_client): - mongo_repository.save_stolen_credentials([CREDENTIALS_OBJECT_2]) - check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2) +@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) +def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials): + mongo_repository.save_stolen_credentials([credentials]) + check_if_stored_credentials_encrypted(mongo_client, credentials) def check_if_stored_credentials_encrypted(mongo_client, original_credentials): raw_credentials = get_all_credentials_in_mongo(mongo_client) original_credentials_mapping = Credentials.to_mapping(original_credentials) for rc in raw_credentials: - for identity_or_secret, credentials_components in rc.items(): - for component in credentials_components: - for key, value in component.items(): - assert ( - original_credentials_mapping[identity_or_secret][0].get(key, None) != value - ) + for identity_or_secret, credentials_component in rc.items(): + for key, value in credentials_component.items(): + assert original_credentials_mapping[identity_or_secret][key] != value def get_all_credentials_in_mongo(mongo_client):