From 4226cb5b9ee4bb6a31f4c1bd71d626a499424a4e Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Mon, 11 Jul 2022 12:25:08 +0200 Subject: [PATCH] Island: Move error handling to private methods in MongoCredentialsRepository --- .../mongo_credentials_repository.py | 97 ++++++++----------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py index b0acf405e..4d9566420 100644 --- a/monkey/monkey_island/cc/repository/mongo_credentials_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_credentials_repository.py @@ -1,6 +1,6 @@ from typing import Sequence -from flask_pymongo import PyMongo +from pymongo import MongoClient from common.credentials import Credentials from monkey_island.cc.repository import RemovalError, RetrievalError, StorageError @@ -12,87 +12,70 @@ class MongoCredentialsRepository(ICredentialsRepository): Store credentials in a mongo database that can be used to propagate around the network. """ - def __init__(self, mongo_db: PyMongo): - self._mongo = mongo_db + def __init__(self, mongo: MongoClient): + self._mongo = mongo def get_configured_credentials(self) -> Sequence[Credentials]: - try: - - return MongoCredentialsRepository._get_credentials_from_collection( - self._mongo.db.configured_credentials - ) - except Exception as err: - raise RetrievalError(err) + return MongoCredentialsRepository._get_credentials_from_collection( + self._mongo.db.configured_credentials + ) def get_stolen_credentials(self) -> Sequence[Credentials]: - try: - return MongoCredentialsRepository._get_credentials_from_collection( - self._mongo.db.stolen_credentials - ) - except Exception as err: - raise RetrievalError(err) + return MongoCredentialsRepository._get_credentials_from_collection( + self._mongo.db.stolen_credentials + ) def get_all_credentials(self) -> Sequence[Credentials]: - try: - configured_credentials = self.get_configured_credentials() - stolen_credentials = self.get_stolen_credentials() + configured_credentials = self.get_configured_credentials() + stolen_credentials = self.get_stolen_credentials() - return [*configured_credentials, *stolen_credentials] - except RetrievalError as err: - raise err + return [*configured_credentials, *stolen_credentials] def save_configured_credentials(self, credentials: Sequence[Credentials]): # TODO: Fix deduplication of Credentials in mongo - try: - MongoCredentialsRepository._save_credentials_to_collection( - credentials, self._mongo.db.configured_credentials - ) - except Exception as err: - raise StorageError(err) + MongoCredentialsRepository._save_credentials_to_collection( + credentials, self._mongo.db.configured_credentials + ) def save_stolen_credentials(self, credentials: Sequence[Credentials]): - # TODO: Fix deduplication of Credentials in mongo - try: - MongoCredentialsRepository._save_credentials_to_collection( - credentials, self._mongo.db.stolen_credentials - ) - except Exception as err: - raise StorageError(err) + MongoCredentialsRepository._save_credentials_to_collection( + credentials, self._mongo.db.stolen_credentials + ) def remove_configured_credentials(self): - try: - MongoCredentialsRepository._delete_collection(self._mongo.db.configured_credentials) - except Exception as err: - raise RemovalError(err) + MongoCredentialsRepository._delete_collection(self._mongo.db.configured_credentials) def remove_stolen_credentials(self): - try: - MongoCredentialsRepository._delete_collection(self._mongo.db.stolen_credentials) - except Exception as err: - raise RemovalError(err) + MongoCredentialsRepository._delete_collection(self._mongo.db.stolen_credentials) def remove_all_credentials(self): - try: - self.remove_configured_credentials() - self.remove_stolen_credentials() - except RemovalError as err: - raise err + self.remove_configured_credentials() + self.remove_stolen_credentials() @staticmethod def _get_credentials_from_collection(collection) -> Sequence[Credentials]: - collection_result = [] - list_collection_result = list(collection.find({})) - for c in list_collection_result: - del c["_id"] - collection_result.append(Credentials.from_mapping(c)) + try: + collection_result = [] + list_collection_result = list(collection.find({})) + for c in list_collection_result: + del c["_id"] + collection_result.append(Credentials.from_mapping(c)) - return collection_result + return collection_result + except Exception as err: + raise RetrievalError(err) @staticmethod def _save_credentials_to_collection(credentials: Sequence[Credentials], collection): - for c in credentials: - collection.insert_one(Credentials.to_mapping(c)) + try: + for c in credentials: + collection.insert_one(Credentials.to_mapping(c)) + except Exception as err: + raise StorageError(err) @staticmethod def _delete_collection(collection): - collection.delete_many({}) + try: + collection.delete_many({}) + except RemovalError as err: + raise err