UT: Use nested comprehension in get_all_collections_in_mongo()

This commit is contained in:
Mike Salvatore 2022-07-15 10:05:17 -04:00
parent 3f20b71d25
commit fb11c29208
1 changed files with 8 additions and 6 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Iterable, List, Mapping, Sequence from typing import Any, Iterable, Mapping, Sequence
from unittest.mock import MagicMock from unittest.mock import MagicMock
import mongomock import mongomock
@ -22,6 +22,7 @@ def reverse(data: bytes) -> bytes:
@pytest.fixture @pytest.fixture
def repository_encryptor(): def repository_encryptor():
# NOTE: Tests will fail if any inputs to this mock encryptor are palindromes.
repository_encryptor = MagicMock(spec=ILockableEncryptor) repository_encryptor = MagicMock(spec=ILockableEncryptor)
repository_encryptor.encrypt = MagicMock(side_effect=reverse) repository_encryptor.encrypt = MagicMock(side_effect=reverse)
repository_encryptor.decrypt = MagicMock(side_effect=reverse) repository_encryptor.decrypt = MagicMock(side_effect=reverse)
@ -136,12 +137,13 @@ def get_all_credentials_in_mongo(
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]: def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections: List[Collection] = [] collections = [
collection
databases = get_all_databases_in_mongo(mongo_client) for db in get_all_databases_in_mongo(mongo_client)
for db in databases: for collection in get_all_collections_in_database(db)
collections.extend(get_all_collections_in_database(db)) ]
assert len(collections) > 0
return collections return collections