UT: Improve code quality of credentials encryption/decryption tests
This commit is contained in:
parent
0687b010ff
commit
e9dc8d88e7
|
@ -1,7 +1,11 @@
|
||||||
|
from typing import Any, Iterable, List, Mapping, Sequence
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import mongomock
|
import mongomock
|
||||||
import pytest
|
import pytest
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.collection import Collection
|
||||||
|
from pymongo.database import Database
|
||||||
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
|
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
|
||||||
|
|
||||||
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
|
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
|
||||||
|
@ -118,36 +122,60 @@ def test_mongo_repository_all(mongo_repository):
|
||||||
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
|
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
|
||||||
# simplify these tests.
|
# simplify these tests.
|
||||||
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
||||||
def test_configured_secrets_encrypted(mongo_repository, mongo_client, credentials):
|
def test_configured_secrets_encrypted(
|
||||||
|
mongo_repository: MongoCredentialsRepository,
|
||||||
|
mongo_client: MongoClient,
|
||||||
|
credentials: Sequence[Credentials],
|
||||||
|
):
|
||||||
mongo_repository.save_configured_credentials([credentials])
|
mongo_repository.save_configured_credentials([credentials])
|
||||||
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
|
||||||
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials):
|
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
|
||||||
mongo_repository.save_stolen_credentials([credentials])
|
mongo_repository.save_stolen_credentials([credentials])
|
||||||
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
check_if_stored_credentials_encrypted(mongo_client, credentials)
|
||||||
|
|
||||||
|
|
||||||
def check_if_stored_credentials_encrypted(mongo_client, original_credentials):
|
def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
|
||||||
raw_credentials = get_all_credentials_in_mongo(mongo_client)
|
|
||||||
original_credentials_mapping = Credentials.to_mapping(original_credentials)
|
original_credentials_mapping = Credentials.to_mapping(original_credentials)
|
||||||
|
raw_credentials = get_all_credentials_in_mongo(mongo_client)
|
||||||
|
|
||||||
for rc in raw_credentials:
|
for rc in raw_credentials:
|
||||||
for identity_or_secret, credentials_component in rc.items():
|
for identity_or_secret, credentials_component in rc.items():
|
||||||
for key, value in credentials_component.items():
|
for key, value in credentials_component.items():
|
||||||
assert original_credentials_mapping[identity_or_secret][key] != value
|
assert original_credentials_mapping[identity_or_secret][key] != value.decode()
|
||||||
|
|
||||||
|
|
||||||
def get_all_credentials_in_mongo(mongo_client):
|
def get_all_credentials_in_mongo(
|
||||||
|
mongo_client: MongoClient,
|
||||||
|
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
|
||||||
encrypted_credentials = []
|
encrypted_credentials = []
|
||||||
|
|
||||||
# Loop through all databases and collections and search for credentials. We don't want the tests
|
# Loop through all databases and collections and search for credentials. We don't want the tests
|
||||||
# to assume anything about the internal workings of the repository.
|
# to assume anything about the internal workings of the repository.
|
||||||
for db in mongo_client.list_database_names():
|
for collection in get_all_collections_in_mongo(mongo_client):
|
||||||
for collection in mongo_client[db].list_collection_names():
|
mongo_credentials = collection.find({})
|
||||||
mongo_credentials = mongo_client[db][collection].find({})
|
for mc in mongo_credentials:
|
||||||
for mc in mongo_credentials:
|
del mc["_id"]
|
||||||
del mc["_id"]
|
encrypted_credentials.append(mc)
|
||||||
encrypted_credentials.append(mc)
|
|
||||||
|
|
||||||
return encrypted_credentials
|
return encrypted_credentials
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
|
||||||
|
collections: List[Collection] = []
|
||||||
|
|
||||||
|
databases = get_all_databases_in_mongo(mongo_client)
|
||||||
|
for db in databases:
|
||||||
|
collections.extend(get_all_collections_in_database(db))
|
||||||
|
|
||||||
|
return collections
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
|
||||||
|
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
|
||||||
|
return (db[collection_name] for collection_name in db.list_collection_names())
|
||||||
|
|
Loading…
Reference in New Issue