diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index a414b5eb0..63e9ac214 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -27,3 +27,4 @@ from .mongo_machine_repository import MongoMachineRepository from .mongo_agent_repository import MongoAgentRepository from .mongo_node_repository import MongoNodeRepository from .stubbed_event_repository import StubbedEventRepository +from .mongo_event_repository import MongoEventRepository diff --git a/monkey/monkey_island/cc/repository/mongo_event_repository.py b/monkey/monkey_island/cc/repository/mongo_event_repository.py new file mode 100644 index 000000000..6614e165b --- /dev/null +++ b/monkey/monkey_island/cc/repository/mongo_event_repository.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, MutableMapping, Sequence, Type + +from pymongo import MongoClient + +from common.event_serializers import EVENT_TYPE_FIELD, EventSerializerRegistry +from common.events import AbstractAgentEvent +from common.types import AgentID +from monkey_island.cc.repository import IEventRepository + +from . import RemovalError, RetrievalError, StorageError +from .consts import MONGO_OBJECT_ID_KEY + + +class MongoEventRepository(IEventRepository): + """A repository for storing and retrieving events in MongoDB""" + + def __init__(self, mongo_client: MongoClient, serializer_registry: EventSerializerRegistry): + self._events_collection = mongo_client.monkey_island.events + self._serializers = serializer_registry + + def save_event(self, event: AbstractAgentEvent): + try: + serializer = self._serializers[type(event)] + serialized_event = serializer.serialize(event) + self._events_collection.insert_one(serialized_event) + except Exception as err: + raise StorageError(f"Error saving event: {err}") + + def get_events(self) -> Sequence[AbstractAgentEvent]: + try: + return self._query_events({}) + except Exception as err: + raise RetrievalError(f"Error retrieving events: {err}") + + def get_events_by_type( + self, event_type: Type[AbstractAgentEvent] + ) -> Sequence[AbstractAgentEvent]: + try: + return self._query_events({EVENT_TYPE_FIELD: event_type.__name__}) + except Exception as err: + raise RetrievalError(f"Error retrieving events for type {event_type}: {err}") + + def get_events_by_tag(self, tag: str) -> Sequence[AbstractAgentEvent]: + try: + return self._query_events({"tags": {"$in": [tag]}}) + except Exception as err: + raise RetrievalError(f"Error retrieving events for tag {tag}: {err}") + + def get_events_by_source(self, source: AgentID) -> Sequence[AbstractAgentEvent]: + try: + return self._query_events({"source": str(source)}) + except Exception as err: + raise RetrievalError(f"Error retrieving events for source {source}: {err}") + + def reset(self): + try: + self._events_collection.drop() + except Exception as err: + raise RemovalError(f"Error resetting the repository: {err}") + + def _deserialize(self, mongo_record: MutableMapping[str, Any]) -> AbstractAgentEvent: + event_type = mongo_record[EVENT_TYPE_FIELD] + serializer = self._serializers[event_type] + return serializer.deserialize(mongo_record) + + def _query_events(self, query: Dict[Any, Any]) -> Sequence[AbstractAgentEvent]: + serialized_events = self._events_collection.find(query, {MONGO_OBJECT_ID_KEY: False}) + return list(map(self._deserialize, serialized_events)) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py new file mode 100644 index 000000000..64de7edc8 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_event_repository.py @@ -0,0 +1,172 @@ +import uuid +from typing import List +from unittest.mock import MagicMock + +import mongomock +import pytest +from pydantic import Field + +from common.event_serializers import EventSerializerRegistry, PydanticEventSerializer +from common.events import AbstractAgentEvent +from monkey_island.cc.repository import ( + IEventRepository, + MongoEventRepository, + RemovalError, + RetrievalError, + StorageError, +) + + +class FakeAgentEvent(AbstractAgentEvent): + data = Field(default=435) + + +class FakeAgentItemEvent(AbstractAgentEvent): + item: str + + +EVENTS: List[AbstractAgentEvent] = [ + FakeAgentEvent(source=uuid.uuid4(), tags={"foo"}), + FakeAgentEvent(source=uuid.uuid4(), tags={"foo", "bar"}), + FakeAgentEvent(source=uuid.uuid4(), tags={"bar", "baz"}), + FakeAgentItemEvent(source=uuid.uuid4(), tags={"baz"}, item="blah"), +] + + +@pytest.fixture +def event_serializer_registry() -> EventSerializerRegistry: + registry = EventSerializerRegistry() + registry[FakeAgentEvent] = PydanticEventSerializer(FakeAgentEvent) + registry[FakeAgentItemEvent] = PydanticEventSerializer(FakeAgentItemEvent) + return registry + + +@pytest.fixture +def mongo_client(event_serializer_registry): + client = mongomock.MongoClient() + client.monkey_island.events.insert_many( + (event_serializer_registry[type(e)].serialize(e) for e in EVENTS) + ) + return client + + +@pytest.fixture +def mongo_repository(mongo_client, event_serializer_registry) -> IEventRepository: + return MongoEventRepository(mongo_client, event_serializer_registry) + + +@pytest.fixture +def error_raising_mongo_client(mongo_client) -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + mongo_client.monkey_island.events = MagicMock(spec=mongomock.Collection) + + mongo_client.monkey_island.events.find = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.events.insert_one = MagicMock( + side_effect=Exception("some exception") + ) + mongo_client.monkey_island.events.drop = MagicMock(side_effect=Exception("some exception")) + + return mongo_client + + +@pytest.fixture +def error_raising_mongo_repository( + error_raising_mongo_client, event_serializer_registry +) -> IEventRepository: + return MongoEventRepository(error_raising_mongo_client, event_serializer_registry) + + +def assert_same_contents(a, b): + assert len(a) == len(b) + for item in a: + assert item in b + + +def test_mongo_event_repository__save_event(mongo_repository: IEventRepository): + event = FakeAgentEvent(source=uuid.uuid4()) + mongo_repository.save_event(event) + events = mongo_repository.get_events() + + assert event in events + + +def test_mongo_event_repository__save_event_raises( + error_raising_mongo_repository: IEventRepository, +): + event = FakeAgentEvent(source=uuid.uuid4()) + + with pytest.raises(StorageError): + error_raising_mongo_repository.save_event(event) + + +def test_mongo_event_repository__get_events(mongo_repository: IEventRepository): + events = mongo_repository.get_events() + + assert_same_contents(events, EVENTS) + + +def test_mongo_event_repository__get_events_raises( + error_raising_mongo_repository: IEventRepository, +): + with pytest.raises(RetrievalError): + error_raising_mongo_repository.get_events() + + +def test_mongo_event_repository__get_events_by_type(mongo_repository: IEventRepository): + events = mongo_repository.get_events_by_type(FakeAgentItemEvent) + + expected_events = [EVENTS[3]] + assert_same_contents(events, expected_events) + + +def test_mongo_event_repository__get_events_by_type_raises( + error_raising_mongo_repository: IEventRepository, +): + with pytest.raises(RetrievalError): + error_raising_mongo_repository.get_events_by_type(FakeAgentItemEvent) + + +def test_mongo_event_repository__get_events_by_tag(mongo_repository: IEventRepository): + events = mongo_repository.get_events_by_tag("bar") + + expected_events = [EVENTS[1], EVENTS[2]] + assert_same_contents(events, expected_events) + + +def test_mongo_event_repository__get_events_by_tag_raises( + error_raising_mongo_repository: IEventRepository, +): + with pytest.raises(RetrievalError): + error_raising_mongo_repository.get_events_by_tag("bar") + + +def test_mongo_event_repository__get_events_by_source(mongo_repository: IEventRepository): + source_event = EVENTS[2] + events = mongo_repository.get_events_by_source(source_event.source) + + expected_events = [source_event] + assert_same_contents(events, expected_events) + + +def test_mongo_event_repository__get_events_by_source_raises( + error_raising_mongo_repository: IEventRepository, +): + with pytest.raises(RetrievalError): + source_event = EVENTS[2] + error_raising_mongo_repository.get_events_by_source(source_event.source) + + +def test_mongo_event_repository__reset(mongo_repository: IEventRepository): + initial_events = mongo_repository.get_events() + assert initial_events + + mongo_repository.reset() + events = mongo_repository.get_events() + + assert not events + + +def test_mongo_event_repository__reset_raises(error_raising_mongo_repository: IEventRepository): + with pytest.raises(RemovalError): + error_raising_mongo_repository.reset() diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 6aa14abad..23ec0f252 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -312,6 +312,7 @@ IEventRepository.save_event IEventRepository.get_events_by_type IEventRepository.get_events_by_tag IEventRepository.get_events_by_source +MongoEventRepository # pydantic base models