From e9dc8d88e739520ff390da0123005dedd6eb09b4 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 14 Jul 2022 14:29:24 -0400 Subject: [PATCH] UT: Improve code quality of credentials encryption/decryption tests --- .../test_mongo_credentials_repository.py | 52 ++++++++++++++----- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py index 0d1b90801..56453e4e2 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_credentials_repository.py @@ -1,7 +1,11 @@ +from typing import Any, Iterable, List, Mapping, Sequence from unittest.mock import MagicMock import mongomock import pytest +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.database import Database from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username @@ -118,36 +122,60 @@ def test_mongo_repository_all(mongo_repository): # them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to # simplify these tests. @pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) -def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials): +def test_configured_secrets_encrypted( + mongo_repository: MongoCredentialsRepository, + mongo_client: MongoClient, + credentials: Sequence[Credentials], +): mongo_repository.save_configured_credentials([credentials]) check_if_stored_credentials_encrypted(mongo_client, credentials) @pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) -def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials): +def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials): mongo_repository.save_stolen_credentials([credentials]) check_if_stored_credentials_encrypted(mongo_client, credentials) -def check_if_stored_credentials_encrypted(mongo_client, original_credentials): - raw_credentials = get_all_credentials_in_mongo(mongo_client) +def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials): original_credentials_mapping = Credentials.to_mapping(original_credentials) + raw_credentials = get_all_credentials_in_mongo(mongo_client) + for rc in raw_credentials: for identity_or_secret, credentials_component in rc.items(): for key, value in credentials_component.items(): - assert original_credentials_mapping[identity_or_secret][key] != value + assert original_credentials_mapping[identity_or_secret][key] != value.decode() -def get_all_credentials_in_mongo(mongo_client): +def get_all_credentials_in_mongo( + mongo_client: MongoClient, +) -> Iterable[Mapping[str, Mapping[str, Any]]]: encrypted_credentials = [] # Loop through all databases and collections and search for credentials. We don't want the tests # to assume anything about the internal workings of the repository. - for db in mongo_client.list_database_names(): - for collection in mongo_client[db].list_collection_names(): - mongo_credentials = mongo_client[db][collection].find({}) - for mc in mongo_credentials: - del mc["_id"] - encrypted_credentials.append(mc) + for collection in get_all_collections_in_mongo(mongo_client): + mongo_credentials = collection.find({}) + for mc in mongo_credentials: + del mc["_id"] + encrypted_credentials.append(mc) return encrypted_credentials + + +def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]: + collections: List[Collection] = [] + + databases = get_all_databases_in_mongo(mongo_client) + for db in databases: + collections.extend(get_all_collections_in_database(db)) + + return collections + + +def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]: + return (mongo_client[db_name] for db_name in mongo_client.list_database_names()) + + +def get_all_collections_in_database(db: Database) -> Iterable[Collection]: + return (db[collection_name] for collection_name in db.list_collection_names())