diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py index 3bfdf21e9..f253255d3 100644 --- a/monkey/infection_monkey/agent_event_forwarder.py +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -3,7 +3,7 @@ import queue import threading from time import sleep -from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable +from common.agent_event_serializers import AgentEventSerializerRegistry from common.agent_events import AbstractAgentEvent from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.utils.threading import create_daemon_thread @@ -33,14 +33,9 @@ class AgentEventForwarder: self._batching_agent_event_forwarder.stop() def send_event(self, event: AbstractAgentEvent): - serialized_event = self._serialize_event(event) - self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) + self._batching_agent_event_forwarder.add_event_to_queue(event) logger.debug(f"Sending event of type {type(event).__name__} to the Island") - def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable: - serializer = self._agent_event_serializer_registry[event.__class__] - return serializer.serialize(event) - class BatchingAgentEventForwarder: """ @@ -62,7 +57,7 @@ class BatchingAgentEventForwarder: ) self._batch_and_send_thread.start() - def add_event_to_queue(self, serialized_event: JSONSerializable): + def add_event_to_queue(self, serialized_event: AbstractAgentEvent): self._queue.put(serialized_event) def _manage_event_batches(self): diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 1d816f6b9..7118e3e37 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,11 +1,12 @@ import functools import logging -from typing import Sequence +from typing import List, Sequence import requests -from common.agent_event_serializers import JSONSerializable +from common.agent_events import AbstractAgentEvent from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from . import ( IIslandAPIClient, @@ -49,7 +50,11 @@ class HTTPIslandAPIClient(IIslandAPIClient): """ @handle_island_errors - def __init__(self, island_server: str): + def __init__( + self, + island_server: str, + agent_event_serializer_registry: AgentEventSerializerRegistry = None, + ): response = requests.get( # noqa: DUO123 f"https://{island_server}/api?action=is-up", verify=False, @@ -60,6 +65,8 @@ class HTTPIslandAPIClient(IIslandAPIClient): self._island_server = island_server self._api_url = f"https://{self._island_server}/api" + self._agent_event_serializer_registry = agent_event_serializer_registry + @handle_island_errors def send_log(self, log_contents: str): response = requests.post( # noqa: DUO123 @@ -85,9 +92,21 @@ class HTTPIslandAPIClient(IIslandAPIClient): def send_events(self, events: Sequence[JSONSerializable]): response = requests.post( # noqa: DUO123 f"{self._api_url}/agent-events", - json=events, + json=self._serialize_events(events), verify=False, timeout=MEDIUM_REQUEST_TIMEOUT, ) response.raise_for_status() + + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: + serialized_events: List[JSONSerializable] = [] + + try: + for e in events: + serializer = self._agent_event_serializer_registry[e.__class__] + serialized_events.append(serializer.serialize(e)) + except Exception as err: + raise IslandAPIRequestError(err) + + return serialized_events diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 84c3052ca..d75f2dea5 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Sequence -from common.agent_event_serializers import JSONSerializable +from common.agent_events import AbstractAgentEvent class IIslandAPIClient(ABC): @@ -59,11 +59,11 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def send_events(self, events: Sequence[JSONSerializable]): + def send_events(self, events: Sequence[AbstractAgentEvent]): """ - Send a sequence of Agent events to the Island + Send a sequence of agent events to the Island - :param events: A sequence of Agent events + :param events: A sequence of agent events :raises IslandAPIConnectionError: If the client cannot successfully connect to the island :raises IslandAPIRequestError: If an error occurs while attempting to connect to the island due to an issue in the request sent from the client diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 6562bf153..607b18e7b 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -1,6 +1,12 @@ import pytest +from uuid import UUID import requests import requests_mock +from common.agent_events import AbstractAgentEvent +from common.agent_event_serializers import ( + AgentEventSerializerRegistry, + PydanticAgentEventSerializer, +) from infection_monkey.island_api_client import ( HTTPIslandAPIClient, @@ -19,6 +25,8 @@ ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" +AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") + @pytest.mark.parametrize( "actual_error, expected_error", @@ -121,6 +129,69 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error island_api_client.get_pba_file(filename=PBA_FILE) +class Event1(AbstractAgentEvent): + a: int + + +class Event2(AbstractAgentEvent): + b: str + + +class Event3(AbstractAgentEvent): + c: int + +@pytest.fixture +def agent_event_serializer_registry(): + agent_event_serializer_registry = AgentEventSerializerRegistry() + agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) + agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) + + return agent_event_serializer_registry + + +def test_island_api_client_send_events__serialization(agent_event_serializer_registry): + events_to_send = [ + Event1(source=AGENT_ID, timestamp=0, a=1), + Event2(source=AGENT_ID, timestamp=0, b="hello"), + ] + expected_json = [ + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "a": 1, + "type": "Event1", + }, + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "b": "hello", + "type": "Event2", + }, + ] + + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + + island_api_client.send_events(events=events_to_send) + + assert m.last_request.json() == expected_json + +def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + + with pytest.raises(IslandAPIRequestError): + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)]) + + @pytest.mark.parametrize( "actual_error, expected_error", [ @@ -129,14 +200,16 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error (Exception, IslandAPIError), ], ) -def test_island_api_client__send_events(actual_error, expected_error): +def test_island_api_client__send_events( + actual_error, expected_error, agent_event_serializer_registry +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) - island_api_client.send_events(events="some_data") + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) @pytest.mark.parametrize( @@ -146,11 +219,13 @@ def test_island_api_client__send_events(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_events__status_code(status_code, expected_error): +def test_island_api_client_send_events__status_code( + status_code, expected_error, agent_event_serializer_registry +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) - island_api_client.send_events(events="some_data") + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])