diff --git a/monkey/common/event_serializers/__init__.py b/monkey/common/event_serializers/__init__.py index 2b60471b1..1adb5e3dd 100644 --- a/monkey/common/event_serializers/__init__.py +++ b/monkey/common/event_serializers/__init__.py @@ -1,2 +1,3 @@ -from .i_event_serialize import IEventSerializer +from .i_event_serialize import IEventSerializer, JSONSerializable from .event_serializer_registry import EventSerializerRegistry +from .pydantic_event_serializer import PydanticEventSerializer diff --git a/monkey/common/event_serializers/i_event_serialize.py b/monkey/common/event_serializers/i_event_serialize.py index 75d95be3c..a092946ab 100644 --- a/monkey/common/event_serializers/i_event_serialize.py +++ b/monkey/common/event_serializers/i_event_serialize.py @@ -3,8 +3,14 @@ from typing import Dict, List, Union from common.events import AbstractAgentEvent -JSONSerializable = Union[ - Dict[str, "JSONSerializable"], List["JSONSerializable"], int, str, float, bool, None +JSONSerializable = Union[ # type: ignore[misc] + Dict[str, "JSONSerializable"], # type: ignore[misc] + List["JSONSerializable"], # type: ignore[misc] + int, + str, + float, + bool, + None, ] diff --git a/monkey/common/event_serializers/pydantic_event_serializer.py b/monkey/common/event_serializers/pydantic_event_serializer.py new file mode 100644 index 000000000..263e3463b --- /dev/null +++ b/monkey/common/event_serializers/pydantic_event_serializer.py @@ -0,0 +1,24 @@ +import logging +from typing import Generic, Type, TypeVar + +from common.events import AbstractAgentEvent + +from . import IEventSerializer, JSONSerializable + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=AbstractAgentEvent) + + +class PydanticEventSerializer(IEventSerializer, Generic[T]): + def __init__(self, event_class: Type[T]): + self._event_class = event_class + + def serialize(self, event: T) -> JSONSerializable: + if not isinstance(event, self._event_class): + raise TypeError(f"Event object must be of type: {self._event_class.__name__}") + + return event.dict(simplify=True) + + def deserialize(self, serialized_event: JSONSerializable) -> T: + return self._event_class(**serialized_event) diff --git a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py new file mode 100644 index 000000000..e3cb7a561 --- /dev/null +++ b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py @@ -0,0 +1,49 @@ +from abc import ABC +from dataclasses import dataclass, field + +import pytest + +from common.base_models import InfectionMonkeyBaseModel +from common.event_serializers import IEventSerializer, PydanticEventSerializer +from common.events import AbstractAgentEvent + + +@dataclass(frozen=True) +class NotAgentEvent(ABC): + some_field: int + other_field: float + + +@dataclass(frozen=True) +class SomeAgentEvent(AbstractAgentEvent): + bogus: int = field(default_factory=int) + + +class PydanticEvent(InfectionMonkeyBaseModel): + some_field: str + + +@pytest.fixture +def pydantic_event_serializer() -> IEventSerializer: + return PydanticEventSerializer(PydanticEvent) + + +@pytest.mark.parametrize("event", [NotAgentEvent(1, 2.0), SomeAgentEvent(2)]) +def test_pydantic_event_serializer__serialize_wrong_type(pydantic_event_serializer, event): + with pytest.raises(TypeError): + pydantic_event_serializer.serialize(event) + + +def test_pydantic_event_serializer__deserialize_wrong_type(pydantic_event_serializer): + with pytest.raises(TypeError): + pydantic_event_serializer.deserialize("bla") + + +def test_pydanitc_event_serializer__de_serialize(pydantic_event_serializer): + pydantic_event = PydanticEvent(some_field="some_field") + + serialized_event = pydantic_event_serializer.serialize(pydantic_event) + deserialized_object = pydantic_event_serializer.deserialize(serialized_event) + + assert type(serialized_event) != type(deserialized_object) + assert deserialized_object == pydantic_event diff --git a/vulture_allowlist.py b/vulture_allowlist.py index a8369b7da..c77390784 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -298,6 +298,7 @@ serialize event deserialize serialized_event +PydanticEventSerializer # pydantic base models underscore_attrs_are_private