diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index 4ef4acda2..b0acf405e 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -1,7 +1,8 @@ from typing import Sequence +from flask_pymongo import PyMongo + from common.credentials import Credentials -from monkey_island.cc.database import mongo from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository @@ -11,27 +12,23 @@ class MongoCredentialsRepository(ICredentialsRepository): Store credentials in a mongo database that can be used to propagate around the network. """ + def __init__(self, mongo_db: PyMongo): + self._mongo = mongo_db + def get_configured_credentials(self) -> Sequence[Credentials]: try: - configured_credentials = [] - list_configured_credentials = list(mongo.db.configured_credentials.find({})) - for c in list_configured_credentials: - del c["_id"] - configured_credentials.append(Credentials.from_mapping(c)) - return configured_credentials + return MongoCredentialsRepository._get_credentials_from_collection( + self._mongo.db.configured_credentials + ) except Exception as err: raise RetrievalError(err) def get_stolen_credentials(self) -> Sequence[Credentials]: try: - stolen_credentials = [] - list_stolen_credentials = list(mongo.db.stolen_credentials.find({})) - for c in list_stolen_credentials: - del c["_id"] - stolen_credentials.append(Credentials.from_mapping(c)) - - return stolen_credentials + return MongoCredentialsRepository._get_credentials_from_collection( + self._mongo.db.stolen_credentials + ) except Exception as err: raise RetrievalError(err) @@ -47,28 +44,30 @@ class MongoCredentialsRepository(ICredentialsRepository): def save_configured_credentials(self, credentials: Sequence[Credentials]): # TODO: Fix deduplication of Credentials in mongo try: - for c in credentials: - mongo.db.configured_credentials.insert_one(Credentials.to_mapping(c)) + MongoCredentialsRepository._save_credentials_to_collection( + credentials, self._mongo.db.configured_credentials + ) except Exception as err: raise StorageError(err) def save_stolen_credentials(self, credentials: Sequence[Credentials]): # TODO: Fix deduplication of Credentials in mongo try: - for c in credentials: - mongo.db.stolen_credentials.insert_one(Credentials.to_mapping(c)) + MongoCredentialsRepository._save_credentials_to_collection( + credentials, self._mongo.db.stolen_credentials + ) except Exception as err: raise StorageError(err) def remove_configured_credentials(self): try: - mongo.db.configured_credentials.delete_many({}) + MongoCredentialsRepository._delete_collection(self._mongo.db.configured_credentials) except Exception as err: raise RemovalError(err) def remove_stolen_credentials(self): try: - mongo.db.stolen_credentials.delete_many({}) + MongoCredentialsRepository._delete_collection(self._mongo.db.stolen_credentials) except Exception as err: raise RemovalError(err) @@ -78,3 +77,22 @@ class MongoCredentialsRepository(ICredentialsRepository): self.remove_stolen_credentials() except RemovalError as err: raise err + + @staticmethod + def _get_credentials_from_collection(collection) -> Sequence[Credentials]: + collection_result = [] + list_collection_result = list(collection.find({})) + for c in list_collection_result: + del c["_id"] + collection_result.append(Credentials.from_mapping(c)) + + return collection_result + + @staticmethod + def _save_credentials_to_collection(credentials: Sequence[Credentials], collection): + for c in credentials: + collection.insert_one(Credentials.to_mapping(c)) + + @staticmethod + def _delete_collection(collection): + collection.delete_many({}) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index 0a96bc9c6..8148780d3 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -42,101 +42,91 @@ CREDENTIALS_DICT_2 = { ], } +CONFIGURED_CREDENTIALS = [Credentials.from_mapping(CREDENTIALS_DICT_1)] + +STOLEN_CREDENTIALS = [Credentials.from_mapping(CREDENTIALS_DICT_2)] + +CREDENTIALS_LIST = [ + Credentials.from_mapping(CREDENTIALS_DICT_1), + Credentials.from_mapping(CREDENTIALS_DICT_2), +] + @pytest.fixture -def fake_mongo(monkeypatch): +def fake_mongo_repository(monkeypatch): mongo = mongoengine.connection.get_connection() - monkeypatch.setattr("monkey_island.cc.repository.mongo_credentials_repository.mongo", mongo) + return MongoCredentialsRepository(mongo) -def test_mongo_repository_get_configured(fake_mongo): +def test_mongo_repository_get_configured(fake_mongo_repository): - actual_configured_credentials = MongoCredentialsRepository().get_configured_credentials() + actual_configured_credentials = fake_mongo_repository.get_configured_credentials() assert actual_configured_credentials == [] -def test_mongo_repository_get_stolen(fake_mongo): +def test_mongo_repository_get_stolen(fake_mongo_repository): - actual_stolen_credentials = MongoCredentialsRepository().get_stolen_credentials() + actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials() assert actual_stolen_credentials == [] -def test_mongo_repository_get_all(fake_mongo): +def test_mongo_repository_get_all(fake_mongo_repository): - actual_credentials = MongoCredentialsRepository().get_all_credentials() + actual_credentials = fake_mongo_repository.get_all_credentials() assert actual_credentials == [] -def test_mongo_repository_configured(fake_mongo): +def test_mongo_repository_configured(fake_mongo_repository): - credentials = [ - Credentials.from_mapping(CREDENTIALS_DICT_1), - Credentials.from_mapping(CREDENTIALS_DICT_2), - ] + fake_mongo_repository.save_configured_credentials(CREDENTIALS_LIST) - mongo_repository = MongoCredentialsRepository() - mongo_repository.save_configured_credentials(credentials) + actual_configured_credentials = fake_mongo_repository.get_configured_credentials() - actual_configured_credentials = mongo_repository.get_configured_credentials() + assert actual_configured_credentials == CREDENTIALS_LIST - assert actual_configured_credentials == credentials + fake_mongo_repository.remove_configured_credentials() - mongo_repository.remove_configured_credentials() - - actual_configured_credentials = mongo_repository.get_configured_credentials() + actual_configured_credentials = fake_mongo_repository.get_configured_credentials() assert actual_configured_credentials == [] -def test_mongo_repository_stolen(fake_mongo): +def test_mongo_repository_stolen(fake_mongo_repository): - stolen_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_1)] + fake_mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS) + fake_mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) - configured_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_2)] + actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials() - mongo_repository = MongoCredentialsRepository() - mongo_repository.save_configured_credentials(configured_credentials) - mongo_repository.save_stolen_credentials(stolen_credentials) + assert actual_stolen_credentials == STOLEN_CREDENTIALS - actual_stolen_credentials = mongo_repository.get_stolen_credentials() + fake_mongo_repository.remove_stolen_credentials() - assert actual_stolen_credentials == stolen_credentials - - mongo_repository.remove_stolen_credentials() - - actual_stolen_credentials = mongo_repository.get_stolen_credentials() + actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials() assert actual_stolen_credentials == [] # Must remove configured also for the next tests - mongo_repository.remove_configured_credentials() + fake_mongo_repository.remove_configured_credentials() -def test_mongo_repository_all(fake_mongo): +def test_mongo_repository_all(fake_mongo_repository): - configured_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_1)] - stolen_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_2)] - all_credentials = [ - Credentials.from_mapping(CREDENTIALS_DICT_1), - Credentials.from_mapping(CREDENTIALS_DICT_2), - ] + fake_mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS) + fake_mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS) - mongo_repository = MongoCredentialsRepository() - mongo_repository.save_configured_credentials(configured_credentials) - mongo_repository.save_stolen_credentials(stolen_credentials) + actual_credentials = fake_mongo_repository.get_all_credentials() - actual_credentials = mongo_repository.get_all_credentials() + assert actual_credentials == CREDENTIALS_LIST - assert actual_credentials == all_credentials + fake_mongo_repository.remove_all_credentials() - mongo_repository.remove_all_credentials() - - actual_credentials = mongo_repository.get_all_credentials() - actual_stolen_credentials = mongo_repository.get_stolen_credentials() - actual_configured_credentials = mongo_repository.get_configured_credentials() + actual_credentials = fake_mongo_repository.get_all_credentials() + actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials() + actual_configured_credentials = fake_mongo_repository.get_configured_credentials() assert actual_credentials == [] assert actual_stolen_credentials == []