Agent: Move AbstractAgentEvent serialization to HTTPIslandAPIClient

This commit is contained in:
Mike Salvatore 2022-09-20 09:45:00 -04:00
parent 34a4d81336
commit 17cb77cfdd
4 changed files with 111 additions and 22 deletions

View File

@ -3,7 +3,7 @@ import queue
import threading import threading
from time import sleep from time import sleep
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_event_serializers import AgentEventSerializerRegistry
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.utils.threading import create_daemon_thread from infection_monkey.utils.threading import create_daemon_thread
@ -33,14 +33,9 @@ class AgentEventForwarder:
self._batching_agent_event_forwarder.stop() self._batching_agent_event_forwarder.stop()
def send_event(self, event: AbstractAgentEvent): def send_event(self, event: AbstractAgentEvent):
serialized_event = self._serialize_event(event) self._batching_agent_event_forwarder.add_event_to_queue(event)
self._batching_agent_event_forwarder.add_event_to_queue(serialized_event)
logger.debug(f"Sending event of type {type(event).__name__} to the Island") logger.debug(f"Sending event of type {type(event).__name__} to the Island")
def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable:
serializer = self._agent_event_serializer_registry[event.__class__]
return serializer.serialize(event)
class BatchingAgentEventForwarder: class BatchingAgentEventForwarder:
""" """
@ -62,7 +57,7 @@ class BatchingAgentEventForwarder:
) )
self._batch_and_send_thread.start() self._batch_and_send_thread.start()
def add_event_to_queue(self, serialized_event: JSONSerializable): def add_event_to_queue(self, serialized_event: AbstractAgentEvent):
self._queue.put(serialized_event) self._queue.put(serialized_event)
def _manage_event_batches(self): def _manage_event_batches(self):

View File

@ -1,11 +1,12 @@
import functools import functools
import logging import logging
from typing import Sequence from typing import List, Sequence
import requests import requests
from common.agent_event_serializers import JSONSerializable from common.agent_events import AbstractAgentEvent
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from . import ( from . import (
IIslandAPIClient, IIslandAPIClient,
@ -49,7 +50,11 @@ class HTTPIslandAPIClient(IIslandAPIClient):
""" """
@handle_island_errors @handle_island_errors
def __init__(self, island_server: str): def __init__(
self,
island_server: str,
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{island_server}/api?action=is-up", f"https://{island_server}/api?action=is-up",
verify=False, verify=False,
@ -60,6 +65,8 @@ class HTTPIslandAPIClient(IIslandAPIClient):
self._island_server = island_server self._island_server = island_server
self._api_url = f"https://{self._island_server}/api" self._api_url = f"https://{self._island_server}/api"
self._agent_event_serializer_registry = agent_event_serializer_registry
@handle_island_errors @handle_island_errors
def send_log(self, log_contents: str): def send_log(self, log_contents: str):
response = requests.post( # noqa: DUO123 response = requests.post( # noqa: DUO123
@ -85,9 +92,21 @@ class HTTPIslandAPIClient(IIslandAPIClient):
def send_events(self, events: Sequence[JSONSerializable]): def send_events(self, events: Sequence[JSONSerializable]):
response = requests.post( # noqa: DUO123 response = requests.post( # noqa: DUO123
f"{self._api_url}/agent-events", f"{self._api_url}/agent-events",
json=events, json=self._serialize_events(events),
verify=False, verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
serialized_events: List[JSONSerializable] = []
try:
for e in events:
serializer = self._agent_event_serializer_registry[e.__class__]
serialized_events.append(serializer.serialize(e))
except Exception as err:
raise IslandAPIRequestError(err)
return serialized_events

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Sequence from typing import Sequence
from common.agent_event_serializers import JSONSerializable from common.agent_events import AbstractAgentEvent
class IIslandAPIClient(ABC): class IIslandAPIClient(ABC):
@ -59,11 +59,11 @@ class IIslandAPIClient(ABC):
""" """
@abstractmethod @abstractmethod
def send_events(self, events: Sequence[JSONSerializable]): def send_events(self, events: Sequence[AbstractAgentEvent]):
""" """
Send a sequence of Agent events to the Island Send a sequence of agent events to the Island
:param events: A sequence of Agent events :param events: A sequence of agent events
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island :raises IslandAPIConnectionError: If the client cannot successfully connect to the island
:raises IslandAPIRequestError: If an error occurs while attempting to connect to the :raises IslandAPIRequestError: If an error occurs while attempting to connect to the
island due to an issue in the request sent from the client island due to an issue in the request sent from the client

View File

@ -1,6 +1,12 @@
import pytest import pytest
from uuid import UUID
import requests import requests
import requests_mock import requests_mock
from common.agent_events import AbstractAgentEvent
from common.agent_event_serializers import (
AgentEventSerializerRegistry,
PydanticAgentEventSerializer,
)
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClient, HTTPIslandAPIClient,
@ -19,6 +25,8 @@ ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"actual_error, expected_error", "actual_error, expected_error",
@ -121,6 +129,69 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
island_api_client.get_pba_file(filename=PBA_FILE) island_api_client.get_pba_file(filename=PBA_FILE)
class Event1(AbstractAgentEvent):
a: int
class Event2(AbstractAgentEvent):
b: str
class Event3(AbstractAgentEvent):
c: int
@pytest.fixture
def agent_event_serializer_registry():
agent_event_serializer_registry = AgentEventSerializerRegistry()
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
return agent_event_serializer_registry
def test_island_api_client_send_events__serialization(agent_event_serializer_registry):
events_to_send = [
Event1(source=AGENT_ID, timestamp=0, a=1),
Event2(source=AGENT_ID, timestamp=0, b="hello"),
]
expected_json = [
{
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
"target": None,
"timestamp": 0.0,
"tags": [],
"a": 1,
"type": "Event1",
},
{
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
"target": None,
"timestamp": 0.0,
"tags": [],
"b": "hello",
"type": "Event2",
},
]
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
m.post(ISLAND_SEND_EVENTS_URI)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
island_api_client.send_events(events=events_to_send)
assert m.last_request.json() == expected_json
def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(IslandAPIRequestError):
m.post(ISLAND_SEND_EVENTS_URI)
island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"actual_error, expected_error", "actual_error, expected_error",
[ [
@ -129,14 +200,16 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
(Exception, IslandAPIError), (Exception, IslandAPIError),
], ],
) )
def test_island_api_client__send_events(actual_error, expected_error): def test_island_api_client__send_events(
actual_error, expected_error, agent_event_serializer_registry
):
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get(ISLAND_URI) m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER) island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(expected_error): with pytest.raises(expected_error):
m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error)
island_api_client.send_events(events="some_data") island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -146,11 +219,13 @@ def test_island_api_client__send_events(actual_error, expected_error):
(501, IslandAPIRequestFailedError), (501, IslandAPIRequestFailedError),
], ],
) )
def test_island_api_client_send_events__status_code(status_code, expected_error): def test_island_api_client_send_events__status_code(
status_code, expected_error, agent_event_serializer_registry
):
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get(ISLAND_URI) m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER) island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(expected_error): with pytest.raises(expected_error):
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
island_api_client.send_events(events="some_data") island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])