diff --git a/monkey/monkey_island/cc/repository/mongo_event_repository.py b/monkey/monkey_island/cc/repository/mongo_event_repository.py index dbf080a84..fe85b3216 100644 --- a/monkey/monkey_island/cc/repository/mongo_event_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_event_repository.py @@ -1,4 +1,4 @@ -from typing import Any, MutableMapping, Sequence, Type +from typing import Any, Dict, MutableMapping, Sequence, Type from pymongo import MongoClient @@ -28,8 +28,7 @@ class MongoEventRepository(IEventRepository): def get_events(self) -> Sequence[AbstractAgentEvent]: try: - serialized_events = list(self._events_collection.find()) - return list(map(self._deserialize, serialized_events)) + return self._query_events({}) except Exception as err: raise RetrievalError(f"Error retrieving events: {err}") @@ -37,24 +36,19 @@ class MongoEventRepository(IEventRepository): self, event_type: Type[AbstractAgentEvent] ) -> Sequence[AbstractAgentEvent]: try: - serialized_events = list( - self._events_collection.find({EVENT_TYPE_FIELD: event_type.__name__}) - ) - return list(map(self._deserialize, serialized_events)) + return self._query_events({EVENT_TYPE_FIELD: event_type.__name__}) except Exception as err: raise RetrievalError(f"Error retrieving events for type {event_type}: {err}") def get_events_by_tag(self, tag: str) -> Sequence[AbstractAgentEvent]: try: - serialized_events = list(self._events_collection.find({"tags": {"$in": [tag]}})) - return list(map(self._deserialize, serialized_events)) + return self._query_events({"tags": {"$in": [tag]}}) except Exception as err: raise RetrievalError(f"Error retrieving events for tag {tag}: {err}") def get_events_by_source(self, source: AgentID) -> Sequence[AbstractAgentEvent]: try: - serialized_events = list(self._events_collection.find({"source": str(source)})) - return list(map(self._deserialize, serialized_events)) + return self._query_events({"source": str(source)}) except Exception as err: raise RetrievalError(f"Error retrieving events for source {source}: {err}") @@ -65,7 +59,10 @@ class MongoEventRepository(IEventRepository): raise RemovalError(f"Error resetting the repository: {err}") def _deserialize(self, mongo_record: MutableMapping[str, Any]) -> AbstractAgentEvent: - del mongo_record[MONGO_OBJECT_ID_KEY] event_type = mongo_record[EVENT_TYPE_FIELD] serializer = self._serializers[event_type] return serializer.deserialize(mongo_record) + + def _query_events(self, query: Dict[Any, Any]) -> Sequence[AbstractAgentEvent]: + serialized_events = list(self._events_collection.find(query, {MONGO_OBJECT_ID_KEY: False})) + return list(map(self._deserialize, serialized_events))