Agent: Move AbstractAgentEvent serialization to HTTPIslandAPIClient

This commit is contained in:
Mike Salvatore 2022-09-20 09:45:00 -04:00
parent 34a4d81336
commit 17cb77cfdd
4 changed files with 111 additions and 22 deletions

View File

@ -3,7 +3,7 @@ import queue
import threading
from time import sleep
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_event_serializers import AgentEventSerializerRegistry
from common.agent_events import AbstractAgentEvent
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.utils.threading import create_daemon_thread
@ -33,14 +33,9 @@ class AgentEventForwarder:
self._batching_agent_event_forwarder.stop()
def send_event(self, event: AbstractAgentEvent):
serialized_event = self._serialize_event(event)
self._batching_agent_event_forwarder.add_event_to_queue(serialized_event)
self._batching_agent_event_forwarder.add_event_to_queue(event)
logger.debug(f"Sending event of type {type(event).__name__} to the Island")
def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable:
serializer = self._agent_event_serializer_registry[event.__class__]
return serializer.serialize(event)
class BatchingAgentEventForwarder:
"""
@ -62,7 +57,7 @@ class BatchingAgentEventForwarder:
)
self._batch_and_send_thread.start()
def add_event_to_queue(self, serialized_event: JSONSerializable):
def add_event_to_queue(self, serialized_event: AbstractAgentEvent):
self._queue.put(serialized_event)
def _manage_event_batches(self):

View File

@ -1,11 +1,12 @@
import functools
import logging
from typing import Sequence
from typing import List, Sequence
import requests
from common.agent_event_serializers import JSONSerializable
from common.agent_events import AbstractAgentEvent
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from . import (
IIslandAPIClient,
@ -49,7 +50,11 @@ class HTTPIslandAPIClient(IIslandAPIClient):
"""
@handle_island_errors
def __init__(self, island_server: str):
def __init__(
self,
island_server: str,
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
):
response = requests.get( # noqa: DUO123
f"https://{island_server}/api?action=is-up",
verify=False,
@ -60,6 +65,8 @@ class HTTPIslandAPIClient(IIslandAPIClient):
self._island_server = island_server
self._api_url = f"https://{self._island_server}/api"
self._agent_event_serializer_registry = agent_event_serializer_registry
@handle_island_errors
def send_log(self, log_contents: str):
response = requests.post( # noqa: DUO123
@ -85,9 +92,21 @@ class HTTPIslandAPIClient(IIslandAPIClient):
def send_events(self, events: Sequence[JSONSerializable]):
response = requests.post( # noqa: DUO123
f"{self._api_url}/agent-events",
json=events,
json=self._serialize_events(events),
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
response.raise_for_status()
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
serialized_events: List[JSONSerializable] = []
try:
for e in events:
serializer = self._agent_event_serializer_registry[e.__class__]
serialized_events.append(serializer.serialize(e))
except Exception as err:
raise IslandAPIRequestError(err)
return serialized_events

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Sequence
from common.agent_event_serializers import JSONSerializable
from common.agent_events import AbstractAgentEvent
class IIslandAPIClient(ABC):
@ -59,11 +59,11 @@ class IIslandAPIClient(ABC):
"""
@abstractmethod
def send_events(self, events: Sequence[JSONSerializable]):
def send_events(self, events: Sequence[AbstractAgentEvent]):
"""
Send a sequence of Agent events to the Island
Send a sequence of agent events to the Island
:param events: A sequence of Agent events
:param events: A sequence of agent events
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
:raises IslandAPIRequestError: If an error occurs while attempting to connect to the
island due to an issue in the request sent from the client

View File

@ -1,6 +1,12 @@
import pytest
from uuid import UUID
import requests
import requests_mock
from common.agent_events import AbstractAgentEvent
from common.agent_event_serializers import (
AgentEventSerializerRegistry,
PydanticAgentEventSerializer,
)
from infection_monkey.island_api_client import (
HTTPIslandAPIClient,
@ -19,6 +25,8 @@ ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
@pytest.mark.parametrize(
"actual_error, expected_error",
@ -121,6 +129,69 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
island_api_client.get_pba_file(filename=PBA_FILE)
class Event1(AbstractAgentEvent):
a: int
class Event2(AbstractAgentEvent):
b: str
class Event3(AbstractAgentEvent):
c: int
@pytest.fixture
def agent_event_serializer_registry():
agent_event_serializer_registry = AgentEventSerializerRegistry()
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
return agent_event_serializer_registry
def test_island_api_client_send_events__serialization(agent_event_serializer_registry):
events_to_send = [
Event1(source=AGENT_ID, timestamp=0, a=1),
Event2(source=AGENT_ID, timestamp=0, b="hello"),
]
expected_json = [
{
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
"target": None,
"timestamp": 0.0,
"tags": [],
"a": 1,
"type": "Event1",
},
{
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
"target": None,
"timestamp": 0.0,
"tags": [],
"b": "hello",
"type": "Event2",
},
]
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
m.post(ISLAND_SEND_EVENTS_URI)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
island_api_client.send_events(events=events_to_send)
assert m.last_request.json() == expected_json
def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(IslandAPIRequestError):
m.post(ISLAND_SEND_EVENTS_URI)
island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)])
@pytest.mark.parametrize(
"actual_error, expected_error",
[
@ -129,14 +200,16 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
(Exception, IslandAPIError),
],
)
def test_island_api_client__send_events(actual_error, expected_error):
def test_island_api_client__send_events(
actual_error, expected_error, agent_event_serializer_registry
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(expected_error):
m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error)
island_api_client.send_events(events="some_data")
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
@pytest.mark.parametrize(
@ -146,11 +219,13 @@ def test_island_api_client__send_events(actual_error, expected_error):
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_send_events__status_code(status_code, expected_error):
def test_island_api_client_send_events__status_code(
status_code, expected_error, agent_event_serializer_registry
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client = HTTPIslandAPIClient(SERVER)
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
with pytest.raises(expected_error):
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
island_api_client.send_events(events="some_data")
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])