UT: Move mongo functions into a module

This commit is contained in:
Kekoa Kaaikala 2022-09-20 13:47:12 +00:00
parent 2d03e497e9
commit 644f3628a5
2 changed files with 26 additions and 24 deletions

View File

@ -0,0 +1,24 @@
from typing import Iterable
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections = [
collection
for db in get_all_databases_in_mongo(mongo_client)
for collection in get_all_collections_in_database(db)
]
assert len(collections) > 0
return collections
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
return (db[collection_name] for collection_name in db.list_collection_names())

View File

@ -4,9 +4,8 @@ from unittest.mock import MagicMock
import mongomock
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import CREDENTIALS
from tests.unit_tests.monkey_island.cc.repository.mongo import get_all_collections_in_mongo
from common.credentials import Credentials
from monkey_island.cc.repository import (
@ -166,9 +165,7 @@ def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_cr
assert "***" not in value.decode()
def get_all_credentials_in_mongo(
mongo_client: MongoClient,
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
def get_all_credentials_in_mongo(mongo_client: MongoClient) -> Iterable[Mapping[str, Any]]:
encrypted_credentials = []
# Loop through all databases and collections and search for credentials. We don't want the tests
@ -180,22 +177,3 @@ def get_all_credentials_in_mongo(
encrypted_credentials.append(mc)
return encrypted_credentials
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections = [
collection
for db in get_all_databases_in_mongo(mongo_client)
for collection in get_all_collections_in_database(db)
]
assert len(collections) > 0
return collections
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
return (db[collection_name] for collection_name in db.list_collection_names())