UT: Improve code quality of credentials encryption/decryption tests

This commit is contained in:
Mike Salvatore 2022-07-14 14:29:24 -04:00
parent 0687b010ff
commit e9dc8d88e7
1 changed files with 40 additions and 12 deletions

View File

@ -1,7 +1,11 @@
from typing import Any, Iterable, List, Mapping, Sequence
from unittest.mock import MagicMock
import mongomock
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
@ -118,36 +122,60 @@ def test_mongo_repository_all(mongo_repository):
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
# simplify these tests.
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials):
def test_configured_secrets_encrypted(
mongo_repository: MongoCredentialsRepository,
mongo_client: MongoClient,
credentials: Sequence[Credentials],
):
mongo_repository.save_configured_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials)
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials):
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
mongo_repository.save_stolen_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials)
def check_if_stored_credentials_encrypted(mongo_client, original_credentials):
raw_credentials = get_all_credentials_in_mongo(mongo_client)
def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
original_credentials_mapping = Credentials.to_mapping(original_credentials)
raw_credentials = get_all_credentials_in_mongo(mongo_client)
for rc in raw_credentials:
for identity_or_secret, credentials_component in rc.items():
for key, value in credentials_component.items():
assert original_credentials_mapping[identity_or_secret][key] != value
assert original_credentials_mapping[identity_or_secret][key] != value.decode()
def get_all_credentials_in_mongo(mongo_client):
def get_all_credentials_in_mongo(
mongo_client: MongoClient,
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
encrypted_credentials = []
# Loop through all databases and collections and search for credentials. We don't want the tests
# to assume anything about the internal workings of the repository.
for db in mongo_client.list_database_names():
for collection in mongo_client[db].list_collection_names():
mongo_credentials = mongo_client[db][collection].find({})
for mc in mongo_credentials:
del mc["_id"]
encrypted_credentials.append(mc)
for collection in get_all_collections_in_mongo(mongo_client):
mongo_credentials = collection.find({})
for mc in mongo_credentials:
del mc["_id"]
encrypted_credentials.append(mc)
return encrypted_credentials
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections: List[Collection] = []
databases = get_all_databases_in_mongo(mongo_client)
for db in databases:
collections.extend(get_all_collections_in_database(db))
return collections
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
return (db[collection_name] for collection_name in db.list_collection_names())