From 644f3628a5061d8809ca9bf3a9df08c1202adf03 Mon Sep 17 00:00:00 2001 From: Kekoa Kaaikala Date: Tue, 20 Sep 2022 13:47:12 +0000 Subject: [PATCH] UT: Move mongo functions into a module --- .../monkey_island/cc/repository/mongo.py | 24 +++++++++++++++++ .../test_mongo_credentials_repository.py | 26 ++----------------- 2 files changed, 26 insertions(+), 24 deletions(-) create mode 100644 monkey/tests/unit_tests/monkey_island/cc/repository/mongo.py diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/mongo.py b/monkey/tests/unit_tests/monkey_island/cc/repository/mongo.py new file mode 100644 index 000000000..26fa6340b --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/mongo.py @@ -0,0 +1,24 @@ +from typing import Iterable + +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.database import Database + + +def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]: + collections = [ + collection + for db in get_all_databases_in_mongo(mongo_client) + for collection in get_all_collections_in_database(db) + ] + + assert len(collections) > 0 + return collections + + +def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]: + return (mongo_client[db_name] for db_name in mongo_client.list_database_names()) + + +def get_all_collections_in_database(db: Database) -> Iterable[Collection]: + return (db[collection_name] for collection_name in db.list_collection_names()) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index a88cba05c..bf745bead 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -4,9 +4,8 @@ from unittest.mock import MagicMock import mongomock import pytest from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.database import Database from tests.data_for_tests.propagation_credentials import CREDENTIALS +from tests.unit_tests.monkey_island.cc.repository.mongo import get_all_collections_in_mongo from common.credentials import Credentials from monkey_island.cc.repository import ( @@ -166,9 +165,7 @@ def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_cr assert "***" not in value.decode() -def get_all_credentials_in_mongo( - mongo_client: MongoClient, -) -> Iterable[Mapping[str, Mapping[str, Any]]]: +def get_all_credentials_in_mongo(mongo_client: MongoClient) -> Iterable[Mapping[str, Any]]: encrypted_credentials = [] # Loop through all databases and collections and search for credentials. We don't want the tests @@ -180,22 +177,3 @@ def get_all_credentials_in_mongo( encrypted_credentials.append(mc) return encrypted_credentials - - -def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]: - collections = [ - collection - for db in get_all_databases_in_mongo(mongo_client) - for collection in get_all_collections_in_database(db) - ] - - assert len(collections) > 0 - return collections - - -def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]: - return (mongo_client[db_name] for db_name in mongo_client.list_database_names()) - - -def get_all_collections_in_database(db: Database) -> Iterable[Collection]: - return (db[collection_name] for collection_name in db.list_collection_names())