diff --git a/monkey/common/event_queue/__init__.py b/monkey/common/event_queue/__init__.py index 1c2e287d7..63a105a06 100644 --- a/monkey/common/event_queue/__init__.py +++ b/monkey/common/event_queue/__init__.py @@ -2,3 +2,4 @@ from .types import AgentEventSubscriber from .pypubsub_publisher_wrapper import PyPubSubPublisherWrapper from .i_agent_event_queue import IAgentEventQueue from .pypubsub_agent_event_queue import PyPubSubAgentEventQueue +from .locking_agent_event_queue_decorator import LockingAgentEventQueueDecorator diff --git a/monkey/common/event_queue/locking_agent_event_queue_decorator.py b/monkey/common/event_queue/locking_agent_event_queue_decorator.py new file mode 100644 index 000000000..c3aa8097a --- /dev/null +++ b/monkey/common/event_queue/locking_agent_event_queue_decorator.py @@ -0,0 +1,31 @@ +from threading import Lock +from typing import Type + +from common.agent_events import AbstractAgentEvent + +from . import AgentEventSubscriber, IAgentEventQueue + + +class LockingAgentEventQueueDecorator(IAgentEventQueue): + """ + Makes an IAgentEventQueue thread-safe by locking publish() + """ + + def __init__(self, agent_event_queue: IAgentEventQueue, lock: Lock): + self._lock = lock + self._agent_event_queue = agent_event_queue + + def subscribe_all_events(self, subscriber: AgentEventSubscriber): + self._agent_event_queue.subscribe_all_events(subscriber) + + def subscribe_type( + self, event_type: Type[AbstractAgentEvent], subscriber: AgentEventSubscriber + ): + self._agent_event_queue.subscribe_type(event_type, subscriber) + + def subscribe_tag(self, tag: str, subscriber: AgentEventSubscriber): + self._agent_event_queue.subscribe_tag(tag, subscriber) + + def publish(self, event: AbstractAgentEvent): + with self._lock: + self._agent_event_queue.publish(event) diff --git a/monkey/monkey_island/cc/event_queue/__init__.py b/monkey/monkey_island/cc/event_queue/__init__.py index 6eab2b6ac..1a9ec662b 100644 --- a/monkey/monkey_island/cc/event_queue/__init__.py +++ b/monkey/monkey_island/cc/event_queue/__init__.py @@ -1,3 +1,4 @@ from .types import IslandEventSubscriber from .i_island_event_queue import IIslandEventQueue, IslandEventTopic from .pypubsub_island_event_queue import PyPubSubIslandEventQueue +from .locking_island_event_queue_decorator import LockingIslandEventQueueDecorator diff --git a/monkey/monkey_island/cc/event_queue/locking_island_event_queue_decorator.py b/monkey/monkey_island/cc/event_queue/locking_island_event_queue_decorator.py new file mode 100644 index 000000000..200f2f1d7 --- /dev/null +++ b/monkey/monkey_island/cc/event_queue/locking_island_event_queue_decorator.py @@ -0,0 +1,20 @@ +from threading import Lock + +from . import IIslandEventQueue, IslandEventSubscriber, IslandEventTopic + + +class LockingIslandEventQueueDecorator(IIslandEventQueue): + """ + Makes an IIslandEventQueue thread-safe by locking publish() + """ + + def __init__(self, island_event_queue: IIslandEventQueue, lock: Lock): + self._lock = lock + self._island_event_queue = island_event_queue + + def subscribe(self, topic: IslandEventTopic, subscriber: IslandEventSubscriber): + self._island_event_queue.subscribe(topic, subscriber) + + def publish(self, topic: IslandEventTopic, **kwargs): + with self._lock: + self._island_event_queue.publish(topic, **kwargs) diff --git a/monkey/monkey_island/cc/event_queue/pypubsub_island_event_queue.py b/monkey/monkey_island/cc/event_queue/pypubsub_island_event_queue.py index 8f7759072..b2259ce3e 100644 --- a/monkey/monkey_island/cc/event_queue/pypubsub_island_event_queue.py +++ b/monkey/monkey_island/cc/event_queue/pypubsub_island_event_queue.py @@ -10,6 +10,10 @@ logger = logging.getLogger(__name__) class PyPubSubIslandEventQueue(IIslandEventQueue): + """ + Implements IIslandEventQueue using pypubsub + """ + def __init__(self, pypubsub_publisher: Publisher): self._pypubsub_publisher_wrapper = PyPubSubPublisherWrapper(pypubsub_publisher) diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index 1cd020fd8..40839c388 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -1,4 +1,5 @@ import logging +import threading from pathlib import Path from pubsub.core import Publisher @@ -15,9 +16,17 @@ from common.agent_event_serializers import ( register_common_agent_event_serializers, ) from common.aws import AWSInstance -from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue +from common.event_queue import ( + IAgentEventQueue, + LockingAgentEventQueueDecorator, + PyPubSubAgentEventQueue, +) from common.utils.file_utils import get_binary_io_sha256_hash -from monkey_island.cc.event_queue import IIslandEventQueue, PyPubSubIslandEventQueue +from monkey_island.cc.event_queue import ( + IIslandEventQueue, + LockingIslandEventQueueDecorator, + PyPubSubIslandEventQueue, +) from monkey_island.cc.repository import ( AgentBinaryRepository, FileAgentConfigurationRepository, @@ -72,8 +81,7 @@ def initialize_services(container: DIContainer, data_dir: Path): ILockableEncryptor, RepositoryEncryptor(data_dir / REPOSITORY_KEY_FILE_NAME) ) container.register(Publisher, Publisher) - container.register_instance(IAgentEventQueue, container.resolve(PyPubSubAgentEventQueue)) - container.register_instance(IIslandEventQueue, container.resolve(PyPubSubIslandEventQueue)) + _register_event_queues(container) _setup_agent_event_serializers(container) _register_repositories(container, data_dir) @@ -100,6 +108,32 @@ def _register_conventions(container: DIContainer): ) +def _register_event_queues(container: DIContainer): + event_queue_lock = threading.Lock() + + agent_event_queue = container.resolve(PyPubSubAgentEventQueue) + decorated_agent_event_queue = _decorate_agent_event_queue(agent_event_queue, event_queue_lock) + container.register_instance(IAgentEventQueue, decorated_agent_event_queue) + + island_event_queue = container.resolve(PyPubSubIslandEventQueue) + decorated_island_event_queue = _decorate_island_event_queue( + island_event_queue, event_queue_lock + ) + container.register_instance(IIslandEventQueue, decorated_island_event_queue) + + +def _decorate_agent_event_queue( + agent_event_queue: IAgentEventQueue, lock: threading.Lock +) -> IAgentEventQueue: + return LockingAgentEventQueueDecorator(agent_event_queue, lock) + + +def _decorate_island_event_queue( + island_event_queue: IIslandEventQueue, lock: threading.Lock +) -> IIslandEventQueue: + return LockingIslandEventQueueDecorator(island_event_queue, lock) + + def _register_repositories(container: DIContainer, data_dir: Path): container.register_instance( IFileRepository,