forked from p15670423/monkey
UT: Improve code quality of credentials encryption/decryption tests
This commit is contained in:
parent
0687b010ff
commit
e9dc8d88e7
|
@ -1,7 +1,11 @@
|
|||
from typing import Any, Iterable, List, Mapping, Sequence
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
|
||||
|
||||
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
|
||||
|
@ -118,36 +122,60 @@ def test_mongo_repository_all(mongo_repository):
|
|||
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
|
||||
# simplify these tests.
|
||||
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
||||
def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials):
|
||||
def test_configured_secrets_encrypted(
|
||||
mongo_repository: MongoCredentialsRepository,
|
||||
mongo_client: MongoClient,
|
||||
credentials: Sequence[Credentials],
|
||||
):
|
||||
mongo_repository.save_configured_credentials([credentials])
|
||||
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
||||
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials):
|
||||
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
|
||||
mongo_repository.save_stolen_credentials([credentials])
|
||||
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
||||
|
||||
|
||||
def check_if_stored_credentials_encrypted(mongo_client, original_credentials):
|
||||
raw_credentials = get_all_credentials_in_mongo(mongo_client)
|
||||
def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
|
||||
original_credentials_mapping = Credentials.to_mapping(original_credentials)
|
||||
raw_credentials = get_all_credentials_in_mongo(mongo_client)
|
||||
|
||||
for rc in raw_credentials:
|
||||
for identity_or_secret, credentials_component in rc.items():
|
||||
for key, value in credentials_component.items():
|
||||
assert original_credentials_mapping[identity_or_secret][key] != value
|
||||
assert original_credentials_mapping[identity_or_secret][key] != value.decode()
|
||||
|
||||
|
||||
def get_all_credentials_in_mongo(mongo_client):
|
||||
def get_all_credentials_in_mongo(
|
||||
mongo_client: MongoClient,
|
||||
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
|
||||
encrypted_credentials = []
|
||||
|
||||
# Loop through all databases and collections and search for credentials. We don't want the tests
|
||||
# to assume anything about the internal workings of the repository.
|
||||
for db in mongo_client.list_database_names():
|
||||
for collection in mongo_client[db].list_collection_names():
|
||||
mongo_credentials = mongo_client[db][collection].find({})
|
||||
for mc in mongo_credentials:
|
||||
del mc["_id"]
|
||||
encrypted_credentials.append(mc)
|
||||
for collection in get_all_collections_in_mongo(mongo_client):
|
||||
mongo_credentials = collection.find({})
|
||||
for mc in mongo_credentials:
|
||||
del mc["_id"]
|
||||
encrypted_credentials.append(mc)
|
||||
|
||||
return encrypted_credentials
|
||||
|
||||
|
||||
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
|
||||
collections: List[Collection] = []
|
||||
|
||||
databases = get_all_databases_in_mongo(mongo_client)
|
||||
for db in databases:
|
||||
collections.extend(get_all_collections_in_database(db))
|
||||
|
||||
return collections
|
||||
|
||||
|
||||
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
|
||||
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
|
||||
|
||||
|
||||
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
|
||||
return (db[collection_name] for collection_name in db.list_collection_names())
|
||||
|
|
Loading…
Reference in New Issue