diff --git a/monkey/monkey_island/cc/repository/mongo_event_repository.py b/monkey/monkey_island/cc/repository/mongo_event_repository.py index 6708e3341..772c0a3fb 100644 --- a/monkey/monkey_island/cc/repository/mongo_event_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_event_repository.py @@ -1,27 +1,33 @@ -from typing import Sequence, Type +from typing import Any, MutableMapping, Sequence, Type from pymongo import MongoClient +from common.event_serializers import EVENT_TYPE_FIELD, EventSerializerRegistry from common.events import AbstractAgentEvent from common.types import AgentID from monkey_island.cc.repository import IEventRepository from . import RemovalError, RetrievalError, StorageError +from .consts import MONGO_OBJECT_ID_KEY class MongoEventRepository(IEventRepository): - def __init__(self, mongo_client: MongoClient): + def __init__(self, mongo_client: MongoClient, serializer_registry: EventSerializerRegistry): self._events_collection = mongo_client.monkey_island.events + self._serializers = serializer_registry def save_event(self, event: AbstractAgentEvent): try: - self._events_collection.insert_one(event.dict(simplify=True)) + serializer = self._serializers[type(event)] + serialized_event = serializer.serialize(event) + self._events_collection.insert_one(serialized_event) except Exception as err: raise StorageError(err) def get_events(self) -> Sequence[AbstractAgentEvent]: try: - return list(self._events_collection.find()) + serialized_events = list(self._events_collection.find()) + return list(map(self._deserialize, serialized_events)) except Exception as err: raise RetrievalError(f"Error retrieving events: {err}") @@ -29,19 +35,22 @@ class MongoEventRepository(IEventRepository): self, event_type: Type[AbstractAgentEvent] ) -> Sequence[AbstractAgentEvent]: try: - return [] + serialized_events = list(self._events_collection.find({EVENT_TYPE_FIELD: event_type})) + return list(map(self._deserialize, serialized_events)) except Exception as err: raise RetrievalError(f"Error retrieving events for type {event_type}: {err}") def get_events_by_tag(self, tag: str) -> Sequence[AbstractAgentEvent]: try: - return list(self._events_collection.find({"tags": {"$in": [tag]}})) + serialized_events = list(self._events_collection.find({"tags": {"$in": [tag]}})) + return list(map(self._deserialize, serialized_events)) except Exception as err: raise RetrievalError(f"Error retrieving events for tag {tag}: {err}") def get_events_by_source(self, source: AgentID) -> Sequence[AbstractAgentEvent]: try: - return list(self._events_collection.find({"source": source})) + serialized_events = list(self._events_collection.find({"source": source})) + return list(map(self._deserialize, serialized_events)) except Exception as err: raise RetrievalError(f"Error retrieving events for source {source}: {err}") @@ -50,3 +59,9 @@ class MongoEventRepository(IEventRepository): self._events_collection.drop() except Exception as err: raise RemovalError(f"Error resetting the repository: {err}") + + def _deserialize(self, mongo_record: MutableMapping[str, Any]) -> AbstractAgentEvent: + del mongo_record[MONGO_OBJECT_ID_KEY] + event_type = mongo_record[EVENT_TYPE_FIELD] + serializer = self._serializers[event_type] + return serializer.deserialize(mongo_record) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py index cedd43005..ac5a02f34 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py @@ -4,9 +4,11 @@ from unittest.mock import MagicMock import mongomock import pytest +from pydantic import Field -from monkey.common.events.abstract_agent_event import AbstractAgentEvent -from monkey.monkey_island.cc.repository import ( +from common.event_serializers import EventSerializerRegistry, PydanticEventSerializer +from common.events import AbstractAgentEvent +from monkey_island.cc.repository import ( IEventRepository, MongoEventRepository, RemovalError, @@ -16,27 +18,41 @@ from monkey.monkey_island.cc.repository import ( class FakeAgentEvent(AbstractAgentEvent): + data = Field(default=435) + + +class FakeAgentItemEvent(AbstractAgentEvent): item: str EVENTS: List[AbstractAgentEvent] = [ - AbstractAgentEvent(source=uuid.uuid4(), tags={"foo"}), - AbstractAgentEvent(source=uuid.uuid4(), tags={"foo", "bar"}), - AbstractAgentEvent(source=uuid.uuid4(), tags={"bar", "baz"}), - FakeAgentEvent(source=uuid.uuid4(), tags={"baz"}, item="blah"), + FakeAgentEvent(source=uuid.uuid4(), tags={"foo"}), + FakeAgentEvent(source=uuid.uuid4(), tags={"foo", "bar"}), + FakeAgentEvent(source=uuid.uuid4(), tags={"bar", "baz"}), + FakeAgentItemEvent(source=uuid.uuid4(), tags={"baz"}, item="blah"), ] @pytest.fixture -def mongo_client(): +def event_serializer_registry() -> EventSerializerRegistry: + registry = EventSerializerRegistry() + registry[FakeAgentEvent] = PydanticEventSerializer(FakeAgentEvent) + registry[FakeAgentItemEvent] = PydanticEventSerializer(FakeAgentItemEvent) + return registry + + +@pytest.fixture +def mongo_client(event_serializer_registry): client = mongomock.MongoClient() - client.monkey_island.events.insert_many((e.dict(simplify=True) for e in EVENTS)) + client.monkey_island.events.insert_many( + (event_serializer_registry[type(e)].serialize(e) for e in EVENTS) + ) return client @pytest.fixture -def mongo_repository(mongo_client) -> IEventRepository: - return MongoEventRepository(mongo_client) +def mongo_repository(mongo_client, event_serializer_registry) -> IEventRepository: + return MongoEventRepository(mongo_client, event_serializer_registry) @pytest.fixture @@ -63,18 +79,20 @@ def error_raising_mongo_client(mongo_client) -> mongomock.MongoClient: @pytest.fixture -def error_raising_mongo_repository(error_raising_mongo_client) -> IEventRepository: - return MongoEventRepository(error_raising_mongo_client) +def error_raising_mongo_repository( + error_raising_mongo_client, event_serializer_registry +) -> IEventRepository: + return MongoEventRepository(error_raising_mongo_client, event_serializer_registry) def assert_same_contents(a, b): assert len(a) == len(b) - difference = set(a) ^ set(b) - assert not difference + for item in a: + assert item in b def test_mongo_event_repository__save_event(mongo_repository: IEventRepository): - event = AbstractAgentEvent(source=uuid.uuid4()) + event = FakeAgentEvent(source=uuid.uuid4()) mongo_repository.save_event(event) events = mongo_repository.get_events() @@ -84,7 +102,7 @@ def test_mongo_event_repository__save_event(mongo_repository: IEventRepository): def test_mongo_event_repository__save_event_raises( error_raising_mongo_repository: IEventRepository, ): - event = AbstractAgentEvent(source=uuid.uuid4()) + event = FakeAgentEvent(source=uuid.uuid4()) with pytest.raises(StorageError): error_raising_mongo_repository.save_event(event) @@ -104,7 +122,7 @@ def test_mongo_event_repository__get_events_raises( def test_mongo_event_repository__get_events_by_type(mongo_repository: IEventRepository): - events = mongo_repository.get_events_by_type(FakeAgentEvent) + events = mongo_repository.get_events_by_type(FakeAgentItemEvent) expected_events = [EVENTS[3]] assert_same_contents(events, expected_events) @@ -114,7 +132,7 @@ def test_mongo_event_repository__get_events_by_type_raises( error_raising_mongo_repository: IEventRepository, ): with pytest.raises(RetrievalError): - error_raising_mongo_repository.get_events_by_type(FakeAgentEvent) + error_raising_mongo_repository.get_events_by_type(FakeAgentItemEvent) def test_mongo_event_repository__get_events_by_tag(mongo_repository: IEventRepository):