Merge branch 'improve-mongo-credentials-repo' into develop

This commit is contained in:
Mike Salvatore 2022-09-15 12:05:04 -04:00
commit cbbd10dd71
4 changed files with 84 additions and 14 deletions

View File

@ -1,4 +1,4 @@
from abc import ABC from abc import ABC, abstractmethod
from typing import Sequence from typing import Sequence
from common.credentials import Credentials from common.credentials import Credentials
@ -13,6 +13,7 @@ class ICredentialsRepository(ABC):
a simulation. a simulation.
""" """
@abstractmethod
def get_configured_credentials(self) -> Sequence[Credentials]: def get_configured_credentials(self) -> Sequence[Credentials]:
""" """
Retrieve credentials that were configured. Retrieve credentials that were configured.
@ -21,8 +22,8 @@ class ICredentialsRepository(ABC):
credentials credentials
:return: Sequence of configured credentials :return: Sequence of configured credentials
""" """
pass
@abstractmethod
def get_stolen_credentials(self) -> Sequence[Credentials]: def get_stolen_credentials(self) -> Sequence[Credentials]:
""" """
Retrieve credentials that were stolen during a simulation. Retrieve credentials that were stolen during a simulation.
@ -31,8 +32,8 @@ class ICredentialsRepository(ABC):
credentials credentials
:return: Sequence of stolen credentials :return: Sequence of stolen credentials
""" """
pass
@abstractmethod
def get_all_credentials(self) -> Sequence[Credentials]: def get_all_credentials(self) -> Sequence[Credentials]:
""" """
Retrieve all credentials in the repository. Retrieve all credentials in the repository.
@ -41,8 +42,8 @@ class ICredentialsRepository(ABC):
credentials credentials
:return: Sequence of stolen and configured credentials :return: Sequence of stolen and configured credentials
""" """
pass
@abstractmethod
def save_configured_credentials(self, credentials: Sequence[Credentials]): def save_configured_credentials(self, credentials: Sequence[Credentials]):
""" """
Save credentials that were configured. Save credentials that were configured.
@ -50,8 +51,8 @@ class ICredentialsRepository(ABC):
:param credentials: Configured Credentials to store in the repository :param credentials: Configured Credentials to store in the repository
:raises StorageError: If an error is encountered while attempting to store the credentials :raises StorageError: If an error is encountered while attempting to store the credentials
""" """
pass
@abstractmethod
def save_stolen_credentials(self, credentials: Sequence[Credentials]): def save_stolen_credentials(self, credentials: Sequence[Credentials]):
""" """
Save credentials that were stolen during a simulation. Save credentials that were stolen during a simulation.
@ -59,28 +60,35 @@ class ICredentialsRepository(ABC):
:param credentials: Stolen Credentials to store in the repository :param credentials: Stolen Credentials to store in the repository
:raises StorageError: If an error is encountered while attempting to store the credentials :raises StorageError: If an error is encountered while attempting to store the credentials
""" """
pass
@abstractmethod
def remove_configured_credentials(self): def remove_configured_credentials(self):
""" """
Remove credentials that were configured from the repository. Remove credentials that were configured from the repository.
:raises RemovalError: If an error is encountered while attempting to remove the credentials :raises RemovalError: If an error is encountered while attempting to remove the credentials
""" """
pass
@abstractmethod
def remove_stolen_credentials(self): def remove_stolen_credentials(self):
""" """
Remove stolen credentials from the repository. Remove stolen credentials from the repository.
:raises RemovalError: If an error is encountered while attempting to remove the credentials :raises RemovalError: If an error is encountered while attempting to remove the credentials
""" """
pass
@abstractmethod
def remove_all_credentials(self): def remove_all_credentials(self):
""" """
Remove all credentials in the repository. Remove all credentials in the repository.
:raises RemovalError: If an error is encountered while attempting to remove the credentials :raises RemovalError: If an error is encountered while attempting to remove the credentials
""" """
pass
@abstractmethod
def reset(self):
"""
An alias for remove_all_credentials()
:raises RemovalError: If an error is encountered while attempting to remove the credentials
"""

View File

@ -20,7 +20,7 @@ class MongoCredentialsRepository(ICredentialsRepository):
""" """
def __init__(self, mongo: MongoClient, repository_encryptor: ILockableEncryptor): def __init__(self, mongo: MongoClient, repository_encryptor: ILockableEncryptor):
self._database = mongo.monkeyisland self._database = mongo.monkey_island
self._repository_encryptor = repository_encryptor self._repository_encryptor = repository_encryptor
def get_configured_credentials(self) -> Sequence[Credentials]: def get_configured_credentials(self) -> Sequence[Credentials]:
@ -52,6 +52,9 @@ class MongoCredentialsRepository(ICredentialsRepository):
self.remove_configured_credentials() self.remove_configured_credentials()
self.remove_stolen_credentials() self.remove_stolen_credentials()
def reset(self):
self.remove_all_credentials()
def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]: def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]:
try: try:
collection_result = [] collection_result = []
@ -109,6 +112,6 @@ class MongoCredentialsRepository(ICredentialsRepository):
@staticmethod @staticmethod
def _remove_credentials_fom_collection(collection): def _remove_credentials_fom_collection(collection):
try: try:
collection.delete_many({}) collection.drop()
except RemovalError as err: except Exception as err:
raise err raise RemovalError(f"Error removing credentials: {err}")

View File

@ -33,3 +33,6 @@ class InMemoryCredentialsRepository(ICredentialsRepository):
def remove_all_credentials(self): def remove_all_credentials(self):
self.remove_configured_credentials() self.remove_configured_credentials()
self.remove_stolen_credentials() self.remove_stolen_credentials()
def reset(self):
self.remove_all_credentials()

View File

@ -9,7 +9,13 @@ from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import CREDENTIALS from tests.data_for_tests.propagation_credentials import CREDENTIALS
from common.credentials import Credentials from common.credentials import Credentials
from monkey_island.cc.repository import MongoCredentialsRepository from monkey_island.cc.repository import (
ICredentialsRepository,
MongoCredentialsRepository,
RemovalError,
RetrievalError,
StorageError,
)
from monkey_island.cc.server_utils.encryption import ILockableEncryptor from monkey_island.cc.server_utils.encryption import ILockableEncryptor
CONFIGURED_CREDENTIALS = CREDENTIALS[0:3] CONFIGURED_CREDENTIALS = CREDENTIALS[0:3]
@ -40,6 +46,36 @@ def mongo_repository(mongo_client, repository_encryptor):
return MongoCredentialsRepository(mongo_client, repository_encryptor) return MongoCredentialsRepository(mongo_client, repository_encryptor)
@pytest.fixture
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
mongo_client = MagicMock(spec=mongomock.MongoClient)
mongo_client.monkey_island = MagicMock(spec=mongomock.Database)
mongo_client.monkey_island.stolen_credentials = MagicMock(spec=mongomock.Collection)
mongo_client.monkey_island.configured_credentials = MagicMock(spec=mongomock.Collection)
mongo_client.monkey_island.configured_credentials.find = MagicMock(
side_effect=Exception("some exception")
)
mongo_client.monkey_island.stolen_credentials.find = MagicMock(
side_effect=Exception("some exception")
)
mongo_client.monkey_island.stolen_credentials.insert_one = MagicMock(
side_effect=Exception("some exception")
)
mongo_client.monkey_island.stolen_credentials.drop = MagicMock(
side_effect=Exception("some exception")
)
return mongo_client
@pytest.fixture
def error_raising_credentials_repository(
error_raising_mock_mongo_client: mongomock.MongoClient, repository_encryptor: ILockableEncryptor
) -> ICredentialsRepository:
return MongoCredentialsRepository(error_raising_mock_mongo_client, repository_encryptor)
def test_mongo_repository_get_configured(mongo_repository): def test_mongo_repository_get_configured(mongo_repository):
actual_configured_credentials = mongo_repository.get_configured_credentials() actual_configured_credentials = mongo_repository.get_configured_credentials()
@ -91,6 +127,26 @@ def test_mongo_repository_all(mongo_repository):
assert mongo_repository.get_configured_credentials() == [] assert mongo_repository.get_configured_credentials() == []
def test_mongo_repository_get__retrieval_error(error_raising_credentials_repository):
with pytest.raises(RetrievalError):
error_raising_credentials_repository.get_all_credentials()
def test_mongo_repository_save__storage_error(error_raising_credentials_repository):
with pytest.raises(StorageError):
error_raising_credentials_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
def test_mongo_repository_remove_credentials__removal_error(error_raising_credentials_repository):
with pytest.raises(RemovalError):
error_raising_credentials_repository.remove_stolen_credentials()
def test_mongo_repository_reset__removal_error(error_raising_credentials_repository):
with pytest.raises(RemovalError):
error_raising_credentials_repository.reset()
@pytest.mark.parametrize("credentials", CREDENTIALS) @pytest.mark.parametrize("credentials", CREDENTIALS)
def test_configured_secrets_encrypted( def test_configured_secrets_encrypted(
mongo_repository: MongoCredentialsRepository, mongo_repository: MongoCredentialsRepository,