From fef6350871bd46cbed39d4a074c29975c2c74e03 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 30 Sep 2021 13:13:26 -0400 Subject: [PATCH] Tests: Reduced code duplication in database initializer tests --- .../setup/mongo/test_database_initializer.py | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/monkey/tests/unit_tests/monkey_island/cc/setup/mongo/test_database_initializer.py b/monkey/tests/unit_tests/monkey_island/cc/setup/mongo/test_database_initializer.py index ed20c5ea0..d3ca3fbcc 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/setup/mongo/test_database_initializer.py +++ b/monkey/tests/unit_tests/monkey_island/cc/setup/mongo/test_database_initializer.py @@ -1,4 +1,3 @@ -from pathlib import Path from unittest.mock import MagicMock import mongomock @@ -8,49 +7,59 @@ from monkey_island.cc.setup.mongo.database_initializer import reset_database @pytest.fixture -def fake_mongo(monkeypatch): - mongo = mongomock.MongoClient() - monkeypatch.setattr("monkey_island.cc.setup.mongo.database_initializer.mongo", mongo) - monkeypatch.setattr("monkey_island.cc.services.database.mongo", mongo) - return mongo +def patch_attack_mitigations_path(monkeypatch, data_for_tests_dir): + def inner(file_name): + path = data_for_tests_dir / "mongo_mitigations" / file_name + monkeypatch.setattr( + "monkey_island.cc.setup.mongo.database_initializer.ATTACK_MITIGATION_PATH", path + ) + + return inner + + +@pytest.fixture(scope="module", autouse=True) +def patch_dependencies(monkeypatch_session): + monkeypatch_session.setattr( + "monkey_island.cc.services.config.ConfigService.init_config", lambda: None + ) + monkeypatch_session.setattr( + "monkey_island.cc.services.attack.attack_config.AttackConfig.reset_config", lambda: None + ) + monkeypatch_session.setattr( + "monkey_island.cc.services.database.jsonify", MagicMock(return_value=True) + ) @pytest.fixture -def fake_config(monkeypatch): - monkeypatch.setattr("monkey_island.cc.services.config.ConfigService.init_config", lambda: None) - monkeypatch.setattr("monkey_island.cc.services.attack.attack_config.AttackConfig.reset_config", lambda: None) - monkeypatch.setattr("monkey_island.cc.services.database.jsonify", MagicMock(return_value=True)) +def mock_mongo_client(monkeypatch): + mongo = mongomock.MongoClient() + mongo.db.validate_collection = MagicMock(return_value=True) + + monkeypatch.setattr("monkey_island.cc.setup.mongo.database_initializer.mongo", mongo) + monkeypatch.setattr("monkey_island.cc.services.database.mongo", mongo) + + return mongo -def test_store_mitigations_on_mongo(monkeypatch, data_for_tests_dir, fake_mongo, fake_config): - monkeypatch.setattr( - "monkey_island.cc.setup.mongo.database_initializer.ATTACK_MITIGATION_PATH", - Path(data_for_tests_dir) / "mongo_mitigations" / "attack_mitigations.json", - ) - fake_mongo.db.validate_collection = MagicMock(return_value=True) +def test_store_mitigations_on_mongo(patch_attack_mitigations_path, mock_mongo_client): + patch_attack_mitigations_path("attack_mitigations.json") + reset_database() - assert len(list(fake_mongo.db.attack_mitigations.find({}))) == 3 + assert len(list(mock_mongo_client.db.attack_mitigations.find({}))) == 3 -def test_store_mitigations_on_mongo__invalid_mitigation( - monkeypatch, data_for_tests_dir, fake_mongo, fake_config -): - monkeypatch.setattr( - "monkey_island.cc.setup.mongo.database_initializer.ATTACK_MITIGATION_PATH", - Path(data_for_tests_dir) / "mongo_mitigations" / "invalid_mitigation", - ) - fake_mongo.db.validate_collection = MagicMock(return_value=True) +def test_store_mitigations_on_mongo__invalid_mitigation(patch_attack_mitigations_path): + patch_attack_mitigations_path("invalid_mitigation") + with pytest.raises(Exception): reset_database() -def test_get_all_mitigations(monkeypatch, fake_mongo, fake_config): - fake_mongo.db.validate_collection = MagicMock(return_value=True) - +def test_get_all_mitigations(mock_mongo_client): reset_database() - mitigations = list(fake_mongo.db.attack_mitigations.find({})) + mitigations = list(mock_mongo_client.db.attack_mitigations.find({})) assert len(mitigations) >= 266