From 3aff182d479f3e21b46b5a562d79ae872df4d422 Mon Sep 17 00:00:00 2001
From: Shreya Malviya <shreya.malviya@gmail.com>
Date: Wed, 10 Aug 2022 14:53:46 +0530
Subject: [PATCH] UT: Simplify PyPubSubEventQueue's tests

---
 .../event_queue/test_pypubsub_event_queue.py  | 62 ++++++++++++-------
 1 file changed, 40 insertions(+), 22 deletions(-)

diff --git a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py
index 50c9635d9..a06eb3f52 100644
--- a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py
+++ b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_event_queue.py
@@ -1,10 +1,12 @@
 from dataclasses import dataclass
-from unittest.mock import MagicMock
 
 import pytest
 from pubsub import pub
 
-from common.event_queue.pypubsub_event_queue import PyPubSubEventQueue
+from common.event_queue.pypubsub_event_queue import (
+    INTERNAL_ALL_EVENT_TYPES_TOPIC,
+    PyPubSubEventQueue,
+)
 from common.events import AbstractEvent
 
 EVENT_TAG_1 = "event tag 1"
@@ -19,41 +21,57 @@ class EventType(AbstractEvent):
     tags = [EVENT_TAG_1, EVENT_TAG_2]
 
 
-@pytest.fixture(autouse=True)
-def wrap_pypubsub_functions():
-    # This is done so that we can use `.call_count` in the tests.
-    pub.sendMessage = MagicMock(side_effect=pub.sendMessage)
-
-
 pypubsub_event_queue = PyPubSubEventQueue(pub)
 
+subscriber_1_calls = subscriber_2_calls = subscriber_1 = subscriber_2 = None
+
+
+@pytest.fixture(autouse=True, scope="function")
+def reset_subscribers():
+    global subscriber_1, subscriber_2, subscriber_1_calls, subscriber_2_calls
+    subscriber_1_calls = []
+    subscriber_2_calls = []
+    subscriber_1 = lambda event, topic=pub.AUTO_TOPIC: subscriber_1_calls.append(topic.getName())
+    subscriber_2 = lambda event, topic=pub.AUTO_TOPIC: subscriber_2_calls.append(topic.getName())
+
+
+def test_topic_subscription():
+    pypubsub_event_queue.subscribe_type(EventType, subscriber_1)
+    pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber_1)
+    pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber_2)
+    pypubsub_event_queue.publish(EventType)
+
+    assert subscriber_1_calls == [EventType.__name__, EVENT_TAG_2]
+    assert subscriber_2_calls == [EVENT_TAG_1]
+
 
 def test_subscribe_all():
-    subscriber = MagicMock()
+    subscriber_calls = []
+    subscriber = lambda topic=pub.AUTO_TOPIC: subscriber_calls.append(topic.getName())
 
     pypubsub_event_queue.subscribe_all(subscriber)
     pypubsub_event_queue.publish(EventType)
 
-    assert pub.sendMessage.call_count == 3
-    assert subscriber.call_count == 3
+    assert subscriber_calls == [
+        EventType.__name__,
+        INTERNAL_ALL_EVENT_TYPES_TOPIC,
+        EVENT_TAG_1,
+        EVENT_TAG_2,
+    ]
 
 
 def test_subscribe_types():
-    subscriber = MagicMock()
-
-    pypubsub_event_queue.subscribe_type(EventType, subscriber)
+    pypubsub_event_queue.subscribe_type(EventType, subscriber_1)
     pypubsub_event_queue.publish(EventType)
 
-    assert pub.sendMessage.call_count == 3
-    assert subscriber.call_count == 1
+    assert subscriber_1_calls == [EventType.__name__]
+    assert subscriber_2_calls == []
 
 
 def test_subscribe_tags():
-    subscriber = MagicMock()
-
-    pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber)
-    pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber)
+    pypubsub_event_queue.subscribe_tag(EVENT_TAG_1, subscriber_1)
+    pypubsub_event_queue.subscribe_tag(EVENT_TAG_2, subscriber_2)
     pypubsub_event_queue.publish(EventType)
 
-    assert pub.sendMessage.call_count == 3
-    assert subscriber.call_count == 2
+    assert subscriber_1_calls == [EVENT_TAG_1]
+    assert subscriber_2_calls == [EVENT_TAG_2]