diff --git a/monkey/common/event_serializers/pydantic_event_serializer.py b/monkey/common/event_serializers/pydantic_event_serializer.py index 263e3463b..2291f3de9 100644 --- a/monkey/common/event_serializers/pydantic_event_serializer.py +++ b/monkey/common/event_serializers/pydantic_event_serializer.py @@ -2,8 +2,9 @@ import logging from typing import Generic, Type, TypeVar from common.events import AbstractAgentEvent +from common.utils.code_utils import del_key -from . import IEventSerializer, JSONSerializable +from . import EVENT_TYPE_FIELD, IEventSerializer, JSONSerializable logger = logging.getLogger(__name__) @@ -18,7 +19,19 @@ class PydanticEventSerializer(IEventSerializer, Generic[T]): if not isinstance(event, self._event_class): raise TypeError(f"Event object must be of type: {self._event_class.__name__}") - return event.dict(simplify=True) + event_dict = event.dict(simplify=True) + event_dict[EVENT_TYPE_FIELD] = type(event).__name__ + + return event_dict def deserialize(self, serialized_event: JSONSerializable) -> T: - return self._event_class(**serialized_event) + if not isinstance(serialized_event, dict): + raise TypeError( + "Serialized pydantic events must be a dictionary, but got {type(serialized_event)}" + ) + + # pydantic serialized events will always be dicts with a copy() method + event_dict = serialized_event.copy() # type: ignore[union-attr] + del_key(event_dict, EVENT_TYPE_FIELD) + + return self._event_class(**event_dict) diff --git a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py index 94bf3fc14..4958f0303 100644 --- a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py +++ b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py @@ -5,7 +5,7 @@ from uuid import UUID import pytest from pydantic import Field -from common.event_serializers import IEventSerializer, PydanticEventSerializer +from common.event_serializers import EVENT_TYPE_FIELD, IEventSerializer, PydanticEventSerializer from common.events import AbstractAgentEvent AGENT_ID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5") @@ -52,3 +52,10 @@ def test_pydanitc_event_serializer__de_serialize(pydantic_event_serializer): assert type(serialized_event) != type(deserialized_object) assert deserialized_object == pydantic_event + + +def test_pydanitc_event_serializer__serialize_inclued_type(pydantic_event_serializer): + pydantic_event = PydanticEvent(source=AGENT_ID, some_field="some_field") + + serialized_event = pydantic_event_serializer.serialize(pydantic_event) + assert serialized_event[EVENT_TYPE_FIELD] == PydanticEvent.__name__