UT: Improve code quality of credentials encryption/decryption tests

This commit is contained in:
Mike Salvatore 2022-07-14 14:29:24 -04:00
parent 0687b010ff
commit e9dc8d88e7
1 changed files with 40 additions and 12 deletions

View File

@ -1,7 +1,11 @@
from typing import Any, Iterable, List, Mapping, Sequence
from unittest.mock import MagicMock from unittest.mock import MagicMock
import mongomock import mongomock
import pytest import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
@ -118,36 +122,60 @@ def test_mongo_repository_all(mongo_repository):
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to # them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
# simplify these tests. # simplify these tests.
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) @pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials): def test_configured_secrets_encrypted(
mongo_repository: MongoCredentialsRepository,
mongo_client: MongoClient,
credentials: Sequence[Credentials],
):
mongo_repository.save_configured_credentials([credentials]) mongo_repository.save_configured_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials) check_if_stored_credentials_encrypted(mongo_client, credentials)
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS) @pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials): def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
mongo_repository.save_stolen_credentials([credentials]) mongo_repository.save_stolen_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials) check_if_stored_credentials_encrypted(mongo_client, credentials)
def check_if_stored_credentials_encrypted(mongo_client, original_credentials): def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
raw_credentials = get_all_credentials_in_mongo(mongo_client)
original_credentials_mapping = Credentials.to_mapping(original_credentials) original_credentials_mapping = Credentials.to_mapping(original_credentials)
raw_credentials = get_all_credentials_in_mongo(mongo_client)
for rc in raw_credentials: for rc in raw_credentials:
for identity_or_secret, credentials_component in rc.items(): for identity_or_secret, credentials_component in rc.items():
for key, value in credentials_component.items(): for key, value in credentials_component.items():
assert original_credentials_mapping[identity_or_secret][key] != value assert original_credentials_mapping[identity_or_secret][key] != value.decode()
def get_all_credentials_in_mongo(mongo_client): def get_all_credentials_in_mongo(
mongo_client: MongoClient,
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
encrypted_credentials = [] encrypted_credentials = []
# Loop through all databases and collections and search for credentials. We don't want the tests # Loop through all databases and collections and search for credentials. We don't want the tests
# to assume anything about the internal workings of the repository. # to assume anything about the internal workings of the repository.
for db in mongo_client.list_database_names(): for collection in get_all_collections_in_mongo(mongo_client):
for collection in mongo_client[db].list_collection_names(): mongo_credentials = collection.find({})
mongo_credentials = mongo_client[db][collection].find({})
for mc in mongo_credentials: for mc in mongo_credentials:
del mc["_id"] del mc["_id"]
encrypted_credentials.append(mc) encrypted_credentials.append(mc)
return encrypted_credentials return encrypted_credentials
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections: List[Collection] = []
databases = get_all_databases_in_mongo(mongo_client)
for db in databases:
collections.extend(get_all_collections_in_database(db))
return collections
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
return (db[collection_name] for collection_name in db.list_collection_names())