Island: Use EventSerializerRegistry

This commit is contained in:
Kekoa Kaaikala 2022-09-14 17:20:38 +00:00
parent f5e398c175
commit ad5bba9e2f
2 changed files with 58 additions and 25 deletions

View File

@ -1,27 +1,33 @@
from typing import Sequence, Type from typing import Any, MutableMapping, Sequence, Type
from pymongo import MongoClient from pymongo import MongoClient
from common.event_serializers import EVENT_TYPE_FIELD, EventSerializerRegistry
from common.events import AbstractAgentEvent from common.events import AbstractAgentEvent
from common.types import AgentID from common.types import AgentID
from monkey_island.cc.repository import IEventRepository from monkey_island.cc.repository import IEventRepository
from . import RemovalError, RetrievalError, StorageError from . import RemovalError, RetrievalError, StorageError
from .consts import MONGO_OBJECT_ID_KEY
class MongoEventRepository(IEventRepository): class MongoEventRepository(IEventRepository):
def __init__(self, mongo_client: MongoClient): def __init__(self, mongo_client: MongoClient, serializer_registry: EventSerializerRegistry):
self._events_collection = mongo_client.monkey_island.events self._events_collection = mongo_client.monkey_island.events
self._serializers = serializer_registry
def save_event(self, event: AbstractAgentEvent): def save_event(self, event: AbstractAgentEvent):
try: try:
self._events_collection.insert_one(event.dict(simplify=True)) serializer = self._serializers[type(event)]
serialized_event = serializer.serialize(event)
self._events_collection.insert_one(serialized_event)
except Exception as err: except Exception as err:
raise StorageError(err) raise StorageError(err)
def get_events(self) -> Sequence[AbstractAgentEvent]: def get_events(self) -> Sequence[AbstractAgentEvent]:
try: try:
return list(self._events_collection.find()) serialized_events = list(self._events_collection.find())
return list(map(self._deserialize, serialized_events))
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving events: {err}") raise RetrievalError(f"Error retrieving events: {err}")
@ -29,19 +35,22 @@ class MongoEventRepository(IEventRepository):
self, event_type: Type[AbstractAgentEvent] self, event_type: Type[AbstractAgentEvent]
) -> Sequence[AbstractAgentEvent]: ) -> Sequence[AbstractAgentEvent]:
try: try:
return [] serialized_events = list(self._events_collection.find({EVENT_TYPE_FIELD: event_type}))
return list(map(self._deserialize, serialized_events))
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving events for type {event_type}: {err}") raise RetrievalError(f"Error retrieving events for type {event_type}: {err}")
def get_events_by_tag(self, tag: str) -> Sequence[AbstractAgentEvent]: def get_events_by_tag(self, tag: str) -> Sequence[AbstractAgentEvent]:
try: try:
return list(self._events_collection.find({"tags": {"$in": [tag]}})) serialized_events = list(self._events_collection.find({"tags": {"$in": [tag]}}))
return list(map(self._deserialize, serialized_events))
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving events for tag {tag}: {err}") raise RetrievalError(f"Error retrieving events for tag {tag}: {err}")
def get_events_by_source(self, source: AgentID) -> Sequence[AbstractAgentEvent]: def get_events_by_source(self, source: AgentID) -> Sequence[AbstractAgentEvent]:
try: try:
return list(self._events_collection.find({"source": source})) serialized_events = list(self._events_collection.find({"source": source}))
return list(map(self._deserialize, serialized_events))
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving events for source {source}: {err}") raise RetrievalError(f"Error retrieving events for source {source}: {err}")
@ -50,3 +59,9 @@ class MongoEventRepository(IEventRepository):
self._events_collection.drop() self._events_collection.drop()
except Exception as err: except Exception as err:
raise RemovalError(f"Error resetting the repository: {err}") raise RemovalError(f"Error resetting the repository: {err}")
def _deserialize(self, mongo_record: MutableMapping[str, Any]) -> AbstractAgentEvent:
del mongo_record[MONGO_OBJECT_ID_KEY]
event_type = mongo_record[EVENT_TYPE_FIELD]
serializer = self._serializers[event_type]
return serializer.deserialize(mongo_record)

View File

@ -4,9 +4,11 @@ from unittest.mock import MagicMock
import mongomock import mongomock
import pytest import pytest
from pydantic import Field
from monkey.common.events.abstract_agent_event import AbstractAgentEvent from common.event_serializers import EventSerializerRegistry, PydanticEventSerializer
from monkey.monkey_island.cc.repository import ( from common.events import AbstractAgentEvent
from monkey_island.cc.repository import (
IEventRepository, IEventRepository,
MongoEventRepository, MongoEventRepository,
RemovalError, RemovalError,
@ -16,27 +18,41 @@ from monkey.monkey_island.cc.repository import (
class FakeAgentEvent(AbstractAgentEvent): class FakeAgentEvent(AbstractAgentEvent):
data = Field(default=435)
class FakeAgentItemEvent(AbstractAgentEvent):
item: str item: str
EVENTS: List[AbstractAgentEvent] = [ EVENTS: List[AbstractAgentEvent] = [
AbstractAgentEvent(source=uuid.uuid4(), tags={"foo"}), FakeAgentEvent(source=uuid.uuid4(), tags={"foo"}),
AbstractAgentEvent(source=uuid.uuid4(), tags={"foo", "bar"}), FakeAgentEvent(source=uuid.uuid4(), tags={"foo", "bar"}),
AbstractAgentEvent(source=uuid.uuid4(), tags={"bar", "baz"}), FakeAgentEvent(source=uuid.uuid4(), tags={"bar", "baz"}),
FakeAgentEvent(source=uuid.uuid4(), tags={"baz"}, item="blah"), FakeAgentItemEvent(source=uuid.uuid4(), tags={"baz"}, item="blah"),
] ]
@pytest.fixture @pytest.fixture
def mongo_client(): def event_serializer_registry() -> EventSerializerRegistry:
registry = EventSerializerRegistry()
registry[FakeAgentEvent] = PydanticEventSerializer(FakeAgentEvent)
registry[FakeAgentItemEvent] = PydanticEventSerializer(FakeAgentItemEvent)
return registry
@pytest.fixture
def mongo_client(event_serializer_registry):
client = mongomock.MongoClient() client = mongomock.MongoClient()
client.monkey_island.events.insert_many((e.dict(simplify=True) for e in EVENTS)) client.monkey_island.events.insert_many(
(event_serializer_registry[type(e)].serialize(e) for e in EVENTS)
)
return client return client
@pytest.fixture @pytest.fixture
def mongo_repository(mongo_client) -> IEventRepository: def mongo_repository(mongo_client, event_serializer_registry) -> IEventRepository:
return MongoEventRepository(mongo_client) return MongoEventRepository(mongo_client, event_serializer_registry)
@pytest.fixture @pytest.fixture
@ -63,18 +79,20 @@ def error_raising_mongo_client(mongo_client) -> mongomock.MongoClient:
@pytest.fixture @pytest.fixture
def error_raising_mongo_repository(error_raising_mongo_client) -> IEventRepository: def error_raising_mongo_repository(
return MongoEventRepository(error_raising_mongo_client) error_raising_mongo_client, event_serializer_registry
) -> IEventRepository:
return MongoEventRepository(error_raising_mongo_client, event_serializer_registry)
def assert_same_contents(a, b): def assert_same_contents(a, b):
assert len(a) == len(b) assert len(a) == len(b)
difference = set(a) ^ set(b) for item in a:
assert not difference assert item in b
def test_mongo_event_repository__save_event(mongo_repository: IEventRepository): def test_mongo_event_repository__save_event(mongo_repository: IEventRepository):
event = AbstractAgentEvent(source=uuid.uuid4()) event = FakeAgentEvent(source=uuid.uuid4())
mongo_repository.save_event(event) mongo_repository.save_event(event)
events = mongo_repository.get_events() events = mongo_repository.get_events()
@ -84,7 +102,7 @@ def test_mongo_event_repository__save_event(mongo_repository: IEventRepository):
def test_mongo_event_repository__save_event_raises( def test_mongo_event_repository__save_event_raises(
error_raising_mongo_repository: IEventRepository, error_raising_mongo_repository: IEventRepository,
): ):
event = AbstractAgentEvent(source=uuid.uuid4()) event = FakeAgentEvent(source=uuid.uuid4())
with pytest.raises(StorageError): with pytest.raises(StorageError):
error_raising_mongo_repository.save_event(event) error_raising_mongo_repository.save_event(event)
@ -104,7 +122,7 @@ def test_mongo_event_repository__get_events_raises(
def test_mongo_event_repository__get_events_by_type(mongo_repository: IEventRepository): def test_mongo_event_repository__get_events_by_type(mongo_repository: IEventRepository):
events = mongo_repository.get_events_by_type(FakeAgentEvent) events = mongo_repository.get_events_by_type(FakeAgentItemEvent)
expected_events = [EVENTS[3]] expected_events = [EVENTS[3]]
assert_same_contents(events, expected_events) assert_same_contents(events, expected_events)
@ -114,7 +132,7 @@ def test_mongo_event_repository__get_events_by_type_raises(
error_raising_mongo_repository: IEventRepository, error_raising_mongo_repository: IEventRepository,
): ):
with pytest.raises(RetrievalError): with pytest.raises(RetrievalError):
error_raising_mongo_repository.get_events_by_type(FakeAgentEvent) error_raising_mongo_repository.get_events_by_type(FakeAgentItemEvent)
def test_mongo_event_repository__get_events_by_tag(mongo_repository: IEventRepository): def test_mongo_event_repository__get_events_by_tag(mongo_repository: IEventRepository):