Island: Encrypt credentials in MongoCredentialsRepository

This commit is contained in:
Mike Salvatore 2022-07-12 13:22:22 -04:00
parent bdee3b9d8a
commit cee52ab12c
2 changed files with 111 additions and 27 deletions

View File

@ -1,10 +1,11 @@
from typing import Sequence from typing import Any, Dict, Mapping, Sequence
from pymongo import MongoClient from pymongo import MongoClient
from common.credentials import Credentials from common.credentials import Credentials
from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError
from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
class MongoCredentialsRepository(ICredentialsRepository): class MongoCredentialsRepository(ICredentialsRepository):
@ -12,18 +13,15 @@ class MongoCredentialsRepository(ICredentialsRepository):
Store credentials in a mongo database that can be used to propagate around the network. Store credentials in a mongo database that can be used to propagate around the network.
""" """
def __init__(self, mongo: MongoClient): def __init__(self, mongo: MongoClient, repository_encryptor: ILockableEncryptor):
self._mongo = mongo self._mongo = mongo
self._repository_encryptor = repository_encryptor
def get_configured_credentials(self) -> Sequence[Credentials]: def get_configured_credentials(self) -> Sequence[Credentials]:
return MongoCredentialsRepository._get_credentials_from_collection( return self._get_credentials_from_collection(self._mongo.db.configured_credentials)
self._mongo.db.configured_credentials
)
def get_stolen_credentials(self) -> Sequence[Credentials]: def get_stolen_credentials(self) -> Sequence[Credentials]:
return MongoCredentialsRepository._get_credentials_from_collection( return self._get_credentials_from_collection(self._mongo.db.stolen_credentials)
self._mongo.db.stolen_credentials
)
def get_all_credentials(self) -> Sequence[Credentials]: def get_all_credentials(self) -> Sequence[Credentials]:
configured_credentials = self.get_configured_credentials() configured_credentials = self.get_configured_credentials()
@ -33,14 +31,10 @@ class MongoCredentialsRepository(ICredentialsRepository):
def save_configured_credentials(self, credentials: Sequence[Credentials]): def save_configured_credentials(self, credentials: Sequence[Credentials]):
# TODO: Fix deduplication of Credentials in mongo # TODO: Fix deduplication of Credentials in mongo
MongoCredentialsRepository._save_credentials_to_collection( self._save_credentials_to_collection(credentials, self._mongo.db.configured_credentials)
credentials, self._mongo.db.configured_credentials
)
def save_stolen_credentials(self, credentials: Sequence[Credentials]): def save_stolen_credentials(self, credentials: Sequence[Credentials]):
MongoCredentialsRepository._save_credentials_to_collection( self._save_credentials_to_collection(credentials, self._mongo.db.stolen_credentials)
credentials, self._mongo.db.stolen_credentials
)
def remove_configured_credentials(self): def remove_configured_credentials(self):
MongoCredentialsRepository._remove_credentials_fom_collection( MongoCredentialsRepository._remove_credentials_fom_collection(
@ -56,27 +50,58 @@ class MongoCredentialsRepository(ICredentialsRepository):
self.remove_configured_credentials() self.remove_configured_credentials()
self.remove_stolen_credentials() self.remove_stolen_credentials()
@staticmethod def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]:
def _get_credentials_from_collection(collection) -> Sequence[Credentials]:
try: try:
collection_result = [] collection_result = []
list_collection_result = list(collection.find({})) list_collection_result = list(collection.find({}))
for c in list_collection_result: for encrypted_credentials in list_collection_result:
del c["_id"] del encrypted_credentials["_id"]
collection_result.append(Credentials.from_mapping(c)) plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials)
collection_result.append(Credentials.from_mapping(plaintext_credentials))
return collection_result return collection_result
except Exception as err: except Exception as err:
raise RetrievalError(err) raise RetrievalError(err)
@staticmethod def _save_credentials_to_collection(self, credentials: Sequence[Credentials], collection):
def _save_credentials_to_collection(credentials: Sequence[Credentials], collection):
try: try:
for c in credentials: for c in credentials:
collection.insert_one(Credentials.to_mapping(c)) encrypted_credentials = self._encrypt_credentials_mapping(Credentials.to_mapping(c))
collection.insert_one(encrypted_credentials)
except Exception as err: except Exception as err:
raise StorageError(err) raise StorageError(err)
# NOTE: The encryption/decryption is complicated and also full of mostly duplicated code. Rather
# than spend the effort to improve them now, we can revisit them when we resolve #2072.
# Resolving #2072 will make it easier to simplify these methods and remove duplication.
def _encrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
encrypted_mapping: Dict[str, Any] = {}
for secret_or_identity, credentials_components in mapping.items():
encrypted_mapping[secret_or_identity] = []
for component in credentials_components:
encrypted_component = {}
for key, value in component.items():
encrypted_component[key] = self._repository_encryptor.encrypt(value.encode())
encrypted_mapping[secret_or_identity].append(encrypted_component)
return encrypted_mapping
def _decrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
encrypted_mapping: Dict[str, Any] = {}
for secret_or_identity, credentials_components in mapping.items():
encrypted_mapping[secret_or_identity] = []
for component in credentials_components:
encrypted_component = {}
for key, value in component.items():
encrypted_component[key] = self._repository_encryptor.decrypt(value).decode()
encrypted_mapping[secret_or_identity].append(encrypted_component)
return encrypted_mapping
@staticmethod @staticmethod
def _remove_credentials_fom_collection(collection): def _remove_credentials_fom_collection(collection):
try: try:

View File

@ -1,8 +1,11 @@
from unittest.mock import MagicMock
import mongomock import mongomock
import pytest import pytest
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
from monkey_island.cc.repository import MongoCredentialsRepository from monkey_island.cc.repository import MongoCredentialsRepository
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
USER1 = "test_user_1" USER1 = "test_user_1"
USER2 = "test_user_2" USER2 = "test_user_2"
@ -36,11 +39,27 @@ STOLEN_CREDENTIALS = [CREDENTIALS_OBJECT_2]
CREDENTIALS_LIST = [CREDENTIALS_OBJECT_1, CREDENTIALS_OBJECT_2] CREDENTIALS_LIST = [CREDENTIALS_OBJECT_1, CREDENTIALS_OBJECT_2]
@pytest.fixture def reverse(data: bytes) -> bytes:
def mongo_repository(): return bytes(reversed(data))
mongo = mongomock.MongoClient()
return MongoCredentialsRepository(mongo)
@pytest.fixture
def repository_encryptor():
repository_encryptor = MagicMock(spec=ILockableEncryptor)
repository_encryptor.encrypt = MagicMock(side_effect=reverse)
repository_encryptor.decrypt = MagicMock(side_effect=reverse)
return repository_encryptor
@pytest.fixture
def mongo_client():
return mongomock.MongoClient()
@pytest.fixture
def mongo_repository(mongo_client, repository_encryptor):
return MongoCredentialsRepository(mongo_client, repository_encryptor)
def test_mongo_repository_get_configured(mongo_repository): def test_mongo_repository_get_configured(mongo_repository):
@ -74,7 +93,7 @@ def test_mongo_repository_configured(mongo_repository):
def test_mongo_repository_stolen(mongo_repository): def test_mongo_repository_stolen(mongo_repository):
mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
actual_stolen_credentials = mongo_repository.get_stolen_credentials() actual_stolen_credentials = mongo_repository.get_stolen_credentials()
assert actual_stolen_credentials == STOLEN_CREDENTIALS assert sorted(actual_stolen_credentials) == sorted(STOLEN_CREDENTIALS)
mongo_repository.remove_stolen_credentials() mongo_repository.remove_stolen_credentials()
actual_stolen_credentials = mongo_repository.get_stolen_credentials() actual_stolen_credentials = mongo_repository.get_stolen_credentials()
@ -92,3 +111,43 @@ def test_mongo_repository_all(mongo_repository):
assert mongo_repository.get_all_credentials() == [] assert mongo_repository.get_all_credentials() == []
assert mongo_repository.get_stolen_credentials() == [] assert mongo_repository.get_stolen_credentials() == []
assert mongo_repository.get_configured_credentials() == [] assert mongo_repository.get_configured_credentials() == []
# NOTE: The following tests are complicated, but they work. Rather than spend the effort to improve
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
# simplify these tests.
def test_configured_secrets_encrypted(mongo_repository, mongo_client):
mongo_repository.save_configured_credentials([CREDENTIALS_OBJECT_2])
check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client):
mongo_repository.save_stolen_credentials([CREDENTIALS_OBJECT_2])
check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2)
def check_if_stored_credentials_encrypted(mongo_client, original_credentials):
raw_credentials = get_all_credentials_in_mongo(mongo_client)
original_credentials_mapping = Credentials.to_mapping(original_credentials)
for rc in raw_credentials:
for identity_or_secret, credentials_components in rc.items():
for component in credentials_components:
for key, value in component.items():
assert (
original_credentials_mapping[identity_or_secret][0].get(key, None) != value
)
def get_all_credentials_in_mongo(mongo_client):
encrypted_credentials = []
# Loop through all databases and collections and search for credentials. We don't want the tests
# to assume anything about the internal workings of the repository.
for db in mongo_client.list_database_names():
for collection in mongo_client[db].list_collection_names():
mongo_credentials = mongo_client[db][collection].find({})
for mc in mongo_credentials:
del mc["_id"]
encrypted_credentials.append(mc)
return encrypted_credentials