diff --git a/monkey/common/event_serializers/pydantic_event_serializer.py b/monkey/common/event_serializers/pydantic_event_serializer.py index 63332f00c..fab4f8c06 100644 --- a/monkey/common/event_serializers/pydantic_event_serializer.py +++ b/monkey/common/event_serializers/pydantic_event_serializer.py @@ -1,21 +1,22 @@ import logging -from typing import Type +from typing import Type, TypeVar -from common.base_models import InfectionMonkeyBaseModel from common.events import AbstractAgentEvent from . import IEventSerializer, JSONSerializable logger = logging.getLogger(__name__) +T = TypeVar("T", bound=AbstractAgentEvent) + class PydanticEventSerializer(IEventSerializer): - def __init__(self, event_class: Type[AbstractAgentEvent]): + def __init__(self, event_class: Type[T]): self._event_class = event_class - def serialize(self, event: AbstractAgentEvent) -> JSONSerializable: + def serialize(self, event: T) -> JSONSerializable: if not issubclass(event.__class__, self._event_class): - raise TypeError(f"Event object must be of type: {InfectionMonkeyBaseModel.__name__}") + raise TypeError(f"Event object must be of type: {self._event_class.__name__}") try: return event.dict() @@ -24,5 +25,5 @@ class PydanticEventSerializer(IEventSerializer): return None - def deserialize(self, serialized_event: JSONSerializable) -> AbstractAgentEvent: - return self._event_class.parse_obj(serialized_event) + def deserialize(self, serialized_event: JSONSerializable) -> T: + return self._event_class(**serialized_event) diff --git a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py index ce38f2680..e3cb7a561 100644 --- a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py +++ b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py @@ -2,10 +2,9 @@ from abc import ABC from dataclasses import dataclass, field import pytest -from pydantic import ValidationError from common.base_models import InfectionMonkeyBaseModel -from common.event_serializers import PydanticEventSerializer +from common.event_serializers import IEventSerializer, PydanticEventSerializer from common.events import AbstractAgentEvent @@ -25,7 +24,7 @@ class PydanticEvent(InfectionMonkeyBaseModel): @pytest.fixture -def pydantic_event_serializer(): +def pydantic_event_serializer() -> IEventSerializer: return PydanticEventSerializer(PydanticEvent) @@ -36,7 +35,7 @@ def test_pydantic_event_serializer__serialize_wrong_type(pydantic_event_serializ def test_pydantic_event_serializer__deserialize_wrong_type(pydantic_event_serializer): - with pytest.raises(ValidationError): + with pytest.raises(TypeError): pydantic_event_serializer.deserialize("bla")