diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index 454196ee8..5dd43a32f 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -1,10 +1,11 @@ -from typing import Sequence +from typing import Any, Dict, Mapping, Sequence from pymongo import MongoClient from common.credentials import Credentials from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository +from monkey_island.cc.server_utils.encryption import ILockableEncryptor class MongoCredentialsRepository(ICredentialsRepository): @@ -12,18 +13,15 @@ class MongoCredentialsRepository(ICredentialsRepository): Store credentials in a mongo database that can be used to propagate around the network. """ - def __init__(self, mongo: MongoClient): + def __init__(self, mongo: MongoClient, repository_encryptor: ILockableEncryptor): self._mongo = mongo + self._repository_encryptor = repository_encryptor def get_configured_credentials(self) -> Sequence[Credentials]: - return MongoCredentialsRepository._get_credentials_from_collection( - self._mongo.db.configured_credentials - ) + return self._get_credentials_from_collection(self._mongo.db.configured_credentials) def get_stolen_credentials(self) -> Sequence[Credentials]: - return MongoCredentialsRepository._get_credentials_from_collection( - self._mongo.db.stolen_credentials - ) + return self._get_credentials_from_collection(self._mongo.db.stolen_credentials) def get_all_credentials(self) -> Sequence[Credentials]: configured_credentials = self.get_configured_credentials() @@ -33,14 +31,10 @@ class MongoCredentialsRepository(ICredentialsRepository): def save_configured_credentials(self, credentials: Sequence[Credentials]): # TODO: Fix deduplication of Credentials in mongo - MongoCredentialsRepository._save_credentials_to_collection( - credentials, self._mongo.db.configured_credentials - ) + self._save_credentials_to_collection(credentials, self._mongo.db.configured_credentials) def save_stolen_credentials(self, credentials: Sequence[Credentials]): - MongoCredentialsRepository._save_credentials_to_collection( - credentials, self._mongo.db.stolen_credentials - ) + self._save_credentials_to_collection(credentials, self._mongo.db.stolen_credentials) def remove_configured_credentials(self): MongoCredentialsRepository._remove_credentials_fom_collection( @@ -56,27 +50,58 @@ class MongoCredentialsRepository(ICredentialsRepository): self.remove_configured_credentials() self.remove_stolen_credentials() - @staticmethod - def _get_credentials_from_collection(collection) -> Sequence[Credentials]: + def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]: try: collection_result = [] list_collection_result = list(collection.find({})) - for c in list_collection_result: - del c["_id"] - collection_result.append(Credentials.from_mapping(c)) + for encrypted_credentials in list_collection_result: + del encrypted_credentials["_id"] + plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials) + collection_result.append(Credentials.from_mapping(plaintext_credentials)) return collection_result except Exception as err: raise RetrievalError(err) - @staticmethod - def _save_credentials_to_collection(credentials: Sequence[Credentials], collection): + def _save_credentials_to_collection(self, credentials: Sequence[Credentials], collection): try: for c in credentials: - collection.insert_one(Credentials.to_mapping(c)) + encrypted_credentials = self._encrypt_credentials_mapping(Credentials.to_mapping(c)) + collection.insert_one(encrypted_credentials) except Exception as err: raise StorageError(err) + # NOTE: The encryption/decryption is complicated and also full of mostly duplicated code. Rather + # than spend the effort to improve them now, we can revisit them when we resolve #2072. + # Resolving #2072 will make it easier to simplify these methods and remove duplication. + def _encrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]: + encrypted_mapping: Dict[str, Any] = {} + + for secret_or_identity, credentials_components in mapping.items(): + encrypted_mapping[secret_or_identity] = [] + for component in credentials_components: + encrypted_component = {} + for key, value in component.items(): + encrypted_component[key] = self._repository_encryptor.encrypt(value.encode()) + + encrypted_mapping[secret_or_identity].append(encrypted_component) + + return encrypted_mapping + + def _decrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]: + encrypted_mapping: Dict[str, Any] = {} + + for secret_or_identity, credentials_components in mapping.items(): + encrypted_mapping[secret_or_identity] = [] + for component in credentials_components: + encrypted_component = {} + for key, value in component.items(): + encrypted_component[key] = self._repository_encryptor.decrypt(value).decode() + + encrypted_mapping[secret_or_identity].append(encrypted_component) + + return encrypted_mapping + @staticmethod def _remove_credentials_fom_collection(collection): try: diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index e8776b440..c3d55940d 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -1,8 +1,11 @@ +from unittest.mock import MagicMock + import mongomock import pytest from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from monkey_island.cc.repository import MongoCredentialsRepository +from monkey_island.cc.server_utils.encryption import ILockableEncryptor USER1 = "test_user_1" USER2 = "test_user_2" @@ -36,11 +39,27 @@ STOLEN_CREDENTIALS = [CREDENTIALS_OBJECT_2] CREDENTIALS_LIST = [CREDENTIALS_OBJECT_1, CREDENTIALS_OBJECT_2] -@pytest.fixture -def mongo_repository(): - mongo = mongomock.MongoClient() +def reverse(data: bytes) -> bytes: + return bytes(reversed(data)) - return MongoCredentialsRepository(mongo) + +@pytest.fixture +def repository_encryptor(): + repository_encryptor = MagicMock(spec=ILockableEncryptor) + repository_encryptor.encrypt = MagicMock(side_effect=reverse) + repository_encryptor.decrypt = MagicMock(side_effect=reverse) + + return repository_encryptor + + +@pytest.fixture +def mongo_client(): + return mongomock.MongoClient() + + +@pytest.fixture +def mongo_repository(mongo_client, repository_encryptor): + return MongoCredentialsRepository(mongo_client, repository_encryptor) def test_mongo_repository_get_configured(mongo_repository): @@ -74,7 +93,7 @@ def test_mongo_repository_configured(mongo_repository): def test_mongo_repository_stolen(mongo_repository): mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) actual_stolen_credentials = mongo_repository.get_stolen_credentials() - assert actual_stolen_credentials == STOLEN_CREDENTIALS + assert sorted(actual_stolen_credentials) == sorted(STOLEN_CREDENTIALS) mongo_repository.remove_stolen_credentials() actual_stolen_credentials = mongo_repository.get_stolen_credentials() @@ -92,3 +111,43 @@ def test_mongo_repository_all(mongo_repository): assert mongo_repository.get_all_credentials() == [] assert mongo_repository.get_stolen_credentials() == [] assert mongo_repository.get_configured_credentials() == [] + + +# NOTE: The following tests are complicated, but they work. Rather than spend the effort to improve +# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to +# simplify these tests. +def test_configured_secrets_encrypted(mongo_repository, mongo_client): + mongo_repository.save_configured_credentials([CREDENTIALS_OBJECT_2]) + check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2) + + +def test_stolen_secrets_encrypted(mongo_repository, mongo_client): + mongo_repository.save_stolen_credentials([CREDENTIALS_OBJECT_2]) + check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2) + + +def check_if_stored_credentials_encrypted(mongo_client, original_credentials): + raw_credentials = get_all_credentials_in_mongo(mongo_client) + original_credentials_mapping = Credentials.to_mapping(original_credentials) + for rc in raw_credentials: + for identity_or_secret, credentials_components in rc.items(): + for component in credentials_components: + for key, value in component.items(): + assert ( + original_credentials_mapping[identity_or_secret][0].get(key, None) != value + ) + + +def get_all_credentials_in_mongo(mongo_client): + encrypted_credentials = [] + + # Loop through all databases and collections and search for credentials. We don't want the tests + # to assume anything about the internal workings of the repository. + for db in mongo_client.list_database_names(): + for collection in mongo_client[db].list_collection_names(): + mongo_credentials = mongo_client[db][collection].find({}) + for mc in mongo_credentials: + del mc["_id"] + encrypted_credentials.append(mc) + + return encrypted_credentials