Island: Refactor MongoCredentialsRepository

* Remove code duplication
* Init with PyMongo object
This commit is contained in:
Ilija Lazoroski 2022-07-08 19:20:25 +02:00
parent c808d50948
commit c48b38fb01
2 changed files with 78 additions and 70 deletions

View File

@ -1,7 +1,8 @@
from typing import Sequence from typing import Sequence
from flask_pymongo import PyMongo
from common.credentials import Credentials from common.credentials import Credentials
from monkey_island.cc.database import mongo
from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError
from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository from monkey_island.cc.repository.i_credentials_repository import ICredentialsRepository
@ -11,27 +12,23 @@ class MongoCredentialsRepository(ICredentialsRepository):
Store credentials in a mongo database that can be used to propagate around the network. Store credentials in a mongo database that can be used to propagate around the network.
""" """
def __init__(self, mongo_db: PyMongo):
self._mongo = mongo_db
def get_configured_credentials(self) -> Sequence[Credentials]: def get_configured_credentials(self) -> Sequence[Credentials]:
try: try:
configured_credentials = []
list_configured_credentials = list(mongo.db.configured_credentials.find({}))
for c in list_configured_credentials:
del c["_id"]
configured_credentials.append(Credentials.from_mapping(c))
return configured_credentials return MongoCredentialsRepository._get_credentials_from_collection(
self._mongo.db.configured_credentials
)
except Exception as err: except Exception as err:
raise RetrievalError(err) raise RetrievalError(err)
def get_stolen_credentials(self) -> Sequence[Credentials]: def get_stolen_credentials(self) -> Sequence[Credentials]:
try: try:
stolen_credentials = [] return MongoCredentialsRepository._get_credentials_from_collection(
list_stolen_credentials = list(mongo.db.stolen_credentials.find({})) self._mongo.db.stolen_credentials
for c in list_stolen_credentials: )
del c["_id"]
stolen_credentials.append(Credentials.from_mapping(c))
return stolen_credentials
except Exception as err: except Exception as err:
raise RetrievalError(err) raise RetrievalError(err)
@ -47,28 +44,30 @@ class MongoCredentialsRepository(ICredentialsRepository):
def save_configured_credentials(self, credentials: Sequence[Credentials]): def save_configured_credentials(self, credentials: Sequence[Credentials]):
# TODO: Fix deduplication of Credentials in mongo # TODO: Fix deduplication of Credentials in mongo
try: try:
for c in credentials: MongoCredentialsRepository._save_credentials_to_collection(
mongo.db.configured_credentials.insert_one(Credentials.to_mapping(c)) credentials, self._mongo.db.configured_credentials
)
except Exception as err: except Exception as err:
raise StorageError(err) raise StorageError(err)
def save_stolen_credentials(self, credentials: Sequence[Credentials]): def save_stolen_credentials(self, credentials: Sequence[Credentials]):
# TODO: Fix deduplication of Credentials in mongo # TODO: Fix deduplication of Credentials in mongo
try: try:
for c in credentials: MongoCredentialsRepository._save_credentials_to_collection(
mongo.db.stolen_credentials.insert_one(Credentials.to_mapping(c)) credentials, self._mongo.db.stolen_credentials
)
except Exception as err: except Exception as err:
raise StorageError(err) raise StorageError(err)
def remove_configured_credentials(self): def remove_configured_credentials(self):
try: try:
mongo.db.configured_credentials.delete_many({}) MongoCredentialsRepository._delete_collection(self._mongo.db.configured_credentials)
except Exception as err: except Exception as err:
raise RemovalError(err) raise RemovalError(err)
def remove_stolen_credentials(self): def remove_stolen_credentials(self):
try: try:
mongo.db.stolen_credentials.delete_many({}) MongoCredentialsRepository._delete_collection(self._mongo.db.stolen_credentials)
except Exception as err: except Exception as err:
raise RemovalError(err) raise RemovalError(err)
@ -78,3 +77,22 @@ class MongoCredentialsRepository(ICredentialsRepository):
self.remove_stolen_credentials() self.remove_stolen_credentials()
except RemovalError as err: except RemovalError as err:
raise err raise err
@staticmethod
def _get_credentials_from_collection(collection) -> Sequence[Credentials]:
collection_result = []
list_collection_result = list(collection.find({}))
for c in list_collection_result:
del c["_id"]
collection_result.append(Credentials.from_mapping(c))
return collection_result
@staticmethod
def _save_credentials_to_collection(credentials: Sequence[Credentials], collection):
for c in credentials:
collection.insert_one(Credentials.to_mapping(c))
@staticmethod
def _delete_collection(collection):
collection.delete_many({})

View File

@ -42,101 +42,91 @@ CREDENTIALS_DICT_2 = {
], ],
} }
CONFIGURED_CREDENTIALS = [Credentials.from_mapping(CREDENTIALS_DICT_1)]
STOLEN_CREDENTIALS = [Credentials.from_mapping(CREDENTIALS_DICT_2)]
CREDENTIALS_LIST = [
Credentials.from_mapping(CREDENTIALS_DICT_1),
Credentials.from_mapping(CREDENTIALS_DICT_2),
]
@pytest.fixture @pytest.fixture
def fake_mongo(monkeypatch): def fake_mongo_repository(monkeypatch):
mongo = mongoengine.connection.get_connection() mongo = mongoengine.connection.get_connection()
monkeypatch.setattr("monkey_island.cc.repository.mongo_credentials_repository.mongo", mongo) return MongoCredentialsRepository(mongo)
def test_mongo_repository_get_configured(fake_mongo): def test_mongo_repository_get_configured(fake_mongo_repository):
actual_configured_credentials = MongoCredentialsRepository().get_configured_credentials() actual_configured_credentials = fake_mongo_repository.get_configured_credentials()
assert actual_configured_credentials == [] assert actual_configured_credentials == []
def test_mongo_repository_get_stolen(fake_mongo): def test_mongo_repository_get_stolen(fake_mongo_repository):
actual_stolen_credentials = MongoCredentialsRepository().get_stolen_credentials() actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials()
assert actual_stolen_credentials == [] assert actual_stolen_credentials == []
def test_mongo_repository_get_all(fake_mongo): def test_mongo_repository_get_all(fake_mongo_repository):
actual_credentials = MongoCredentialsRepository().get_all_credentials() actual_credentials = fake_mongo_repository.get_all_credentials()
assert actual_credentials == [] assert actual_credentials == []
def test_mongo_repository_configured(fake_mongo): def test_mongo_repository_configured(fake_mongo_repository):
credentials = [ fake_mongo_repository.save_configured_credentials(CREDENTIALS_LIST)
Credentials.from_mapping(CREDENTIALS_DICT_1),
Credentials.from_mapping(CREDENTIALS_DICT_2),
]
mongo_repository = MongoCredentialsRepository() actual_configured_credentials = fake_mongo_repository.get_configured_credentials()
mongo_repository.save_configured_credentials(credentials)
actual_configured_credentials = mongo_repository.get_configured_credentials() assert actual_configured_credentials == CREDENTIALS_LIST
assert actual_configured_credentials == credentials fake_mongo_repository.remove_configured_credentials()
mongo_repository.remove_configured_credentials() actual_configured_credentials = fake_mongo_repository.get_configured_credentials()
actual_configured_credentials = mongo_repository.get_configured_credentials()
assert actual_configured_credentials == [] assert actual_configured_credentials == []
def test_mongo_repository_stolen(fake_mongo): def test_mongo_repository_stolen(fake_mongo_repository):
stolen_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_1)] fake_mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS)
fake_mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
configured_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_2)] actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials()
mongo_repository = MongoCredentialsRepository() assert actual_stolen_credentials == STOLEN_CREDENTIALS
mongo_repository.save_configured_credentials(configured_credentials)
mongo_repository.save_stolen_credentials(stolen_credentials)
actual_stolen_credentials = mongo_repository.get_stolen_credentials() fake_mongo_repository.remove_stolen_credentials()
assert actual_stolen_credentials == stolen_credentials actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials()
mongo_repository.remove_stolen_credentials()
actual_stolen_credentials = mongo_repository.get_stolen_credentials()
assert actual_stolen_credentials == [] assert actual_stolen_credentials == []
# Must remove configured also for the next tests # Must remove configured also for the next tests
mongo_repository.remove_configured_credentials() fake_mongo_repository.remove_configured_credentials()
def test_mongo_repository_all(fake_mongo): def test_mongo_repository_all(fake_mongo_repository):
configured_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_1)] fake_mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS)
stolen_credentials = [Credentials.from_mapping(CREDENTIALS_DICT_2)] fake_mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
all_credentials = [
Credentials.from_mapping(CREDENTIALS_DICT_1),
Credentials.from_mapping(CREDENTIALS_DICT_2),
]
mongo_repository = MongoCredentialsRepository() actual_credentials = fake_mongo_repository.get_all_credentials()
mongo_repository.save_configured_credentials(configured_credentials)
mongo_repository.save_stolen_credentials(stolen_credentials)
actual_credentials = mongo_repository.get_all_credentials() assert actual_credentials == CREDENTIALS_LIST
assert actual_credentials == all_credentials fake_mongo_repository.remove_all_credentials()
mongo_repository.remove_all_credentials() actual_credentials = fake_mongo_repository.get_all_credentials()
actual_stolen_credentials = fake_mongo_repository.get_stolen_credentials()
actual_credentials = mongo_repository.get_all_credentials() actual_configured_credentials = fake_mongo_repository.get_configured_credentials()
actual_stolen_credentials = mongo_repository.get_stolen_credentials()
actual_configured_credentials = mongo_repository.get_configured_credentials()
assert actual_credentials == [] assert actual_credentials == []
assert actual_stolen_credentials == [] assert actual_stolen_credentials == []