forked from p34709852/monkey
UT: Move mongo functions into a module
This commit is contained in:
parent
2d03e497e9
commit
644f3628a5
|
@ -0,0 +1,24 @@
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.collection import Collection
|
||||||
|
from pymongo.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
|
||||||
|
collections = [
|
||||||
|
collection
|
||||||
|
for db in get_all_databases_in_mongo(mongo_client)
|
||||||
|
for collection in get_all_collections_in_database(db)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(collections) > 0
|
||||||
|
return collections
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
|
||||||
|
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
|
||||||
|
return (db[collection_name] for collection_name in db.list_collection_names())
|
|
@ -4,9 +4,8 @@ from unittest.mock import MagicMock
|
||||||
import mongomock
|
import mongomock
|
||||||
import pytest
|
import pytest
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from pymongo.collection import Collection
|
|
||||||
from pymongo.database import Database
|
|
||||||
from tests.data_for_tests.propagation_credentials import CREDENTIALS
|
from tests.data_for_tests.propagation_credentials import CREDENTIALS
|
||||||
|
from tests.unit_tests.monkey_island.cc.repository.mongo import get_all_collections_in_mongo
|
||||||
|
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
|
@ -166,9 +165,7 @@ def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_cr
|
||||||
assert "***" not in value.decode()
|
assert "***" not in value.decode()
|
||||||
|
|
||||||
|
|
||||||
def get_all_credentials_in_mongo(
|
def get_all_credentials_in_mongo(mongo_client: MongoClient) -> Iterable[Mapping[str, Any]]:
|
||||||
mongo_client: MongoClient,
|
|
||||||
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
|
|
||||||
encrypted_credentials = []
|
encrypted_credentials = []
|
||||||
|
|
||||||
# Loop through all databases and collections and search for credentials. We don't want the tests
|
# Loop through all databases and collections and search for credentials. We don't want the tests
|
||||||
|
@ -180,22 +177,3 @@ def get_all_credentials_in_mongo(
|
||||||
encrypted_credentials.append(mc)
|
encrypted_credentials.append(mc)
|
||||||
|
|
||||||
return encrypted_credentials
|
return encrypted_credentials
|
||||||
|
|
||||||
|
|
||||||
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
|
|
||||||
collections = [
|
|
||||||
collection
|
|
||||||
for db in get_all_databases_in_mongo(mongo_client)
|
|
||||||
for collection in get_all_collections_in_database(db)
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(collections) > 0
|
|
||||||
return collections
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
|
|
||||||
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
|
|
||||||
return (db[collection_name] for collection_name in db.list_collection_names())
|
|
||||||
|
|
Loading…
Reference in New Issue