diff --git a/monkey/monkey_island/cc/repository/i_credentials_repository.py b/monkey/monkey_island/cc/repository/i_credentials_repository.py index 381782533..9affc52ea 100644 --- a/monkey/monkey_island/cc/repository/i_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/i_credentials_repository.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod from typing import Sequence from common.credentials import Credentials @@ -13,6 +13,7 @@ class ICredentialsRepository(ABC): a simulation. """ + @abstractmethod def get_configured_credentials(self) -> Sequence[Credentials]: """ Retrieve credentials that were configured. @@ -21,8 +22,8 @@ class ICredentialsRepository(ABC): credentials :return: Sequence of configured credentials """ - pass + @abstractmethod def get_stolen_credentials(self) -> Sequence[Credentials]: """ Retrieve credentials that were stolen during a simulation. @@ -31,8 +32,8 @@ class ICredentialsRepository(ABC): credentials :return: Sequence of stolen credentials """ - pass + @abstractmethod def get_all_credentials(self) -> Sequence[Credentials]: """ Retrieve all credentials in the repository. @@ -41,8 +42,8 @@ class ICredentialsRepository(ABC): credentials :return: Sequence of stolen and configured credentials """ - pass + @abstractmethod def save_configured_credentials(self, credentials: Sequence[Credentials]): """ Save credentials that were configured. @@ -50,8 +51,8 @@ class ICredentialsRepository(ABC): :param credentials: Configured Credentials to store in the repository :raises StorageError: If an error is encountered while attempting to store the credentials """ - pass + @abstractmethod def save_stolen_credentials(self, credentials: Sequence[Credentials]): """ Save credentials that were stolen during a simulation. @@ -59,28 +60,35 @@ class ICredentialsRepository(ABC): :param credentials: Stolen Credentials to store in the repository :raises StorageError: If an error is encountered while attempting to store the credentials """ - pass + @abstractmethod def remove_configured_credentials(self): """ Remove credentials that were configured from the repository. :raises RemovalError: If an error is encountered while attempting to remove the credentials """ - pass + @abstractmethod def remove_stolen_credentials(self): """ Remove stolen credentials from the repository. :raises RemovalError: If an error is encountered while attempting to remove the credentials """ - pass + @abstractmethod def remove_all_credentials(self): """ Remove all credentials in the repository. :raises RemovalError: If an error is encountered while attempting to remove the credentials """ - pass + + @abstractmethod + def reset(self): + """ + An alias for remove_all_credentials() + + :raises RemovalError: If an error is encountered while attempting to remove the credentials + """ diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index 3fdc306a8..438ce48b6 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -20,7 +20,7 @@ class MongoCredentialsRepository(ICredentialsRepository): """ def __init__(self, mongo: MongoClient, repository_encryptor: ILockableEncryptor): - self._database = mongo.monkeyisland + self._database = mongo.monkey_island self._repository_encryptor = repository_encryptor def get_configured_credentials(self) -> Sequence[Credentials]: @@ -52,6 +52,9 @@ class MongoCredentialsRepository(ICredentialsRepository): self.remove_configured_credentials() self.remove_stolen_credentials() + def reset(self): + self.remove_all_credentials() + def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]: try: collection_result = [] @@ -109,6 +112,6 @@ class MongoCredentialsRepository(ICredentialsRepository): @staticmethod def _remove_credentials_fom_collection(collection): try: - collection.delete_many({}) - except RemovalError as err: - raise err + collection.drop() + except Exception as err: + raise RemovalError(f"Error removing credentials: {err}") diff --git a/monkey/tests/monkey_island/in_memory_credentials_repository.py b/monkey/tests/monkey_island/in_memory_credentials_repository.py index 6eb8155e8..389a31e08 100644 --- a/monkey/tests/monkey_island/in_memory_credentials_repository.py +++ b/monkey/tests/monkey_island/in_memory_credentials_repository.py @@ -33,3 +33,6 @@ class InMemoryCredentialsRepository(ICredentialsRepository): def remove_all_credentials(self): self.remove_configured_credentials() self.remove_stolen_credentials() + + def reset(self): + self.remove_all_credentials() diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index ffd3c3bfa..f921c9df5 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -9,7 +9,13 @@ from pymongo.database import Database from tests.data_for_tests.propagation_credentials import CREDENTIALS from common.credentials import Credentials -from monkey_island.cc.repository import MongoCredentialsRepository +from monkey_island.cc.repository import ( + ICredentialsRepository, + MongoCredentialsRepository, + RemovalError, + RetrievalError, + StorageError, +) from monkey_island.cc.server_utils.encryption import ILockableEncryptor CONFIGURED_CREDENTIALS = CREDENTIALS[0:3] @@ -40,6 +46,36 @@ def mongo_repository(mongo_client, repository_encryptor): return MongoCredentialsRepository(mongo_client, repository_encryptor) +@pytest.fixture +def error_raising_mock_mongo_client() -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + mongo_client.monkey_island.stolen_credentials = MagicMock(spec=mongomock.Collection) + mongo_client.monkey_island.configured_credentials = MagicMock(spec=mongomock.Collection) + + mongo_client.monkey_island.configured_credentials.find = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.stolen_credentials.find = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.stolen_credentials.insert_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.stolen_credentials.drop = MagicMock( + side_effect=Exception("some exception") + ) + + return mongo_client + + +@pytest.fixture +def error_raising_credentials_repository( + error_raising_mock_mongo_client: mongomock.MongoClient, repository_encryptor: ILockableEncryptor +) -> ICredentialsRepository: + return MongoCredentialsRepository(error_raising_mock_mongo_client, repository_encryptor) + + def test_mongo_repository_get_configured(mongo_repository): actual_configured_credentials = mongo_repository.get_configured_credentials() @@ -91,6 +127,26 @@ def test_mongo_repository_all(mongo_repository): assert mongo_repository.get_configured_credentials() == [] +def test_mongo_repository_get__retrieval_error(error_raising_credentials_repository): + with pytest.raises(RetrievalError): + error_raising_credentials_repository.get_all_credentials() + + +def test_mongo_repository_save__storage_error(error_raising_credentials_repository): + with pytest.raises(StorageError): + error_raising_credentials_repository.save_stolen_credentials(STOLEN_CREDENTIALS) + + +def test_mongo_repository_remove_credentials__removal_error(error_raising_credentials_repository): + with pytest.raises(RemovalError): + error_raising_credentials_repository.remove_stolen_credentials() + + +def test_mongo_repository_reset__removal_error(error_raising_credentials_repository): + with pytest.raises(RemovalError): + error_raising_credentials_repository.reset() + + @pytest.mark.parametrize("credentials", CREDENTIALS) def test_configured_secrets_encrypted( mongo_repository: MongoCredentialsRepository,