diff --git a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py index 50c9635d9..a06eb3f52 100644 --- a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py +++ b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from unittest.mock import MagicMock import pytest from pubsub import pub -from common.event_queue.pypubsub_event_queue import PyPubSubEventQueue +from common.event_queue.pypubsub_event_queue import ( + INTERNAL_ALL_EVENT_TYPES_TOPIC, + PyPubSubEventQueue, +) from common.events import AbstractEvent EVENT_TAG_1 = "event tag 1" @@ -19,41 +21,57 @@ class EventType(AbstractEvent): tags = [EVENT_TAG_1, EVENT_TAG_2] -@pytest.fixture(autouse=True) -def wrap_pypubsub_functions(): - # This is done so that we can use `.call_count` in the tests. - pub.sendMessage = MagicMock(side_effect=pub.sendMessage) - - pypubsub_event_queue = PyPubSubEventQueue(pub) +subscriber_1_calls = subscriber_2_calls = subscriber_1 = subscriber_2 = None + + +@pytest.fixture(autouse=True, scope="function") +def reset_subscribers(): + global subscriber_1, subscriber_2, subscriber_1_calls, subscriber_2_calls + subscriber_1_calls = [] + subscriber_2_calls = [] + subscriber_1 = lambda event, topic=pub.AUTO_TOPIC: subscriber_1_calls.append(topic.getName()) + subscriber_2 = lambda event, topic=pub.AUTO_TOPIC: subscriber_2_calls.append(topic.getName()) + + +def test_topic_subscription(): + pypubsub_event_queue.subscribe_type(EventType, subscriber_1) + pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber_1) + pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber_2) + pypubsub_event_queue.publish(EventType) + + assert subscriber_1_calls == [EventType.__name__, EVENT_TAG_2] + assert subscriber_2_calls == [EVENT_TAG_1] + def test_subscribe_all(): - subscriber = MagicMock() + subscriber_calls = [] + subscriber = lambda topic=pub.AUTO_TOPIC: subscriber_calls.append(topic.getName()) pypubsub_event_queue.subscribe_all(subscriber) pypubsub_event_queue.publish(EventType) - assert pub.sendMessage.call_count == 3 - assert subscriber.call_count == 3 + assert subscriber_calls == [ + EventType.__name__, + INTERNAL_ALL_EVENT_TYPES_TOPIC, + EVENT_TAG_1, + EVENT_TAG_2, + ] def test_subscribe_types(): - subscriber = MagicMock() - - pypubsub_event_queue.subscribe_type(EventType, subscriber) + pypubsub_event_queue.subscribe_type(EventType, subscriber_1) pypubsub_event_queue.publish(EventType) - assert pub.sendMessage.call_count == 3 - assert subscriber.call_count == 1 + assert subscriber_1_calls == [EventType.__name__] + assert subscriber_2_calls == [] def test_subscribe_tags(): - subscriber = MagicMock() - - pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber) - pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber) + pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber_1) + pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber_2) pypubsub_event_queue.publish(EventType) - assert pub.sendMessage.call_count == 3 - assert subscriber.call_count == 2 + assert subscriber_1_calls == [EVENT_TAG_1] + assert subscriber_2_calls == [EVENT_TAG_2]