Agent: Move AbstractAgentEvent serialization to HTTPIslandAPIClient
This commit is contained in:
parent
34a4d81336
commit
17cb77cfdd
|
@ -3,7 +3,7 @@ import queue
|
|||
import threading
|
||||
from time import sleep
|
||||
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from infection_monkey.island_api_client import IIslandAPIClient
|
||||
from infection_monkey.utils.threading import create_daemon_thread
|
||||
|
@ -33,14 +33,9 @@ class AgentEventForwarder:
|
|||
self._batching_agent_event_forwarder.stop()
|
||||
|
||||
def send_event(self, event: AbstractAgentEvent):
|
||||
serialized_event = self._serialize_event(event)
|
||||
self._batching_agent_event_forwarder.add_event_to_queue(serialized_event)
|
||||
self._batching_agent_event_forwarder.add_event_to_queue(event)
|
||||
logger.debug(f"Sending event of type {type(event).__name__} to the Island")
|
||||
|
||||
def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable:
|
||||
serializer = self._agent_event_serializer_registry[event.__class__]
|
||||
return serializer.serialize(event)
|
||||
|
||||
|
||||
class BatchingAgentEventForwarder:
|
||||
"""
|
||||
|
@ -62,7 +57,7 @@ class BatchingAgentEventForwarder:
|
|||
)
|
||||
self._batch_and_send_thread.start()
|
||||
|
||||
def add_event_to_queue(self, serialized_event: JSONSerializable):
|
||||
def add_event_to_queue(self, serialized_event: AbstractAgentEvent):
|
||||
self._queue.put(serialized_event)
|
||||
|
||||
def _manage_event_batches(self):
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import functools
|
||||
import logging
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import requests
|
||||
|
||||
from common.agent_event_serializers import JSONSerializable
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
|
||||
from . import (
|
||||
IIslandAPIClient,
|
||||
|
@ -49,7 +50,11 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
"""
|
||||
|
||||
@handle_island_errors
|
||||
def __init__(self, island_server: str):
|
||||
def __init__(
|
||||
self,
|
||||
island_server: str,
|
||||
agent_event_serializer_registry: AgentEventSerializerRegistry = None,
|
||||
):
|
||||
response = requests.get( # noqa: DUO123
|
||||
f"https://{island_server}/api?action=is-up",
|
||||
verify=False,
|
||||
|
@ -60,6 +65,8 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
self._island_server = island_server
|
||||
self._api_url = f"https://{self._island_server}/api"
|
||||
|
||||
self._agent_event_serializer_registry = agent_event_serializer_registry
|
||||
|
||||
@handle_island_errors
|
||||
def send_log(self, log_contents: str):
|
||||
response = requests.post( # noqa: DUO123
|
||||
|
@ -85,9 +92,21 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
def send_events(self, events: Sequence[JSONSerializable]):
|
||||
response = requests.post( # noqa: DUO123
|
||||
f"{self._api_url}/agent-events",
|
||||
json=events,
|
||||
json=self._serialize_events(events),
|
||||
verify=False,
|
||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
||||
serialized_events: List[JSONSerializable] = []
|
||||
|
||||
try:
|
||||
for e in events:
|
||||
serializer = self._agent_event_serializer_registry[e.__class__]
|
||||
serialized_events.append(serializer.serialize(e))
|
||||
except Exception as err:
|
||||
raise IslandAPIRequestError(err)
|
||||
|
||||
return serialized_events
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence
|
||||
|
||||
from common.agent_event_serializers import JSONSerializable
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
|
||||
|
||||
class IIslandAPIClient(ABC):
|
||||
|
@ -59,11 +59,11 @@ class IIslandAPIClient(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_events(self, events: Sequence[JSONSerializable]):
|
||||
def send_events(self, events: Sequence[AbstractAgentEvent]):
|
||||
"""
|
||||
Send a sequence of Agent events to the Island
|
||||
Send a sequence of agent events to the Island
|
||||
|
||||
:param events: A sequence of Agent events
|
||||
:param events: A sequence of agent events
|
||||
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||
:raises IslandAPIRequestError: If an error occurs while attempting to connect to the
|
||||
island due to an issue in the request sent from the client
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
import pytest
|
||||
from uuid import UUID
|
||||
import requests
|
||||
import requests_mock
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.agent_event_serializers import (
|
||||
AgentEventSerializerRegistry,
|
||||
PydanticAgentEventSerializer,
|
||||
)
|
||||
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
|
@ -19,6 +25,8 @@ ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
|||
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
||||
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
||||
|
||||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
|
@ -121,6 +129,69 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
|
|||
island_api_client.get_pba_file(filename=PBA_FILE)
|
||||
|
||||
|
||||
class Event1(AbstractAgentEvent):
|
||||
a: int
|
||||
|
||||
|
||||
class Event2(AbstractAgentEvent):
|
||||
b: str
|
||||
|
||||
|
||||
class Event3(AbstractAgentEvent):
|
||||
c: int
|
||||
|
||||
@pytest.fixture
|
||||
def agent_event_serializer_registry():
|
||||
agent_event_serializer_registry = AgentEventSerializerRegistry()
|
||||
agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1)
|
||||
agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2)
|
||||
|
||||
return agent_event_serializer_registry
|
||||
|
||||
|
||||
def test_island_api_client_send_events__serialization(agent_event_serializer_registry):
|
||||
events_to_send = [
|
||||
Event1(source=AGENT_ID, timestamp=0, a=1),
|
||||
Event2(source=AGENT_ID, timestamp=0, b="hello"),
|
||||
]
|
||||
expected_json = [
|
||||
{
|
||||
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
|
||||
"target": None,
|
||||
"timestamp": 0.0,
|
||||
"tags": [],
|
||||
"a": 1,
|
||||
"type": "Event1",
|
||||
},
|
||||
{
|
||||
"source": "80988359-a1cd-42a2-9b47-5b94b37cd673",
|
||||
"target": None,
|
||||
"timestamp": 0.0,
|
||||
"tags": [],
|
||||
"b": "hello",
|
||||
"type": "Event2",
|
||||
},
|
||||
]
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
m.post(ISLAND_SEND_EVENTS_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
|
||||
island_api_client.send_events(events=events_to_send)
|
||||
|
||||
assert m.last_request.json() == expected_json
|
||||
|
||||
def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
|
||||
with pytest.raises(IslandAPIRequestError):
|
||||
m.post(ISLAND_SEND_EVENTS_URI)
|
||||
island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
|
@ -129,14 +200,16 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error
|
|||
(Exception, IslandAPIError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__send_events(actual_error, expected_error):
|
||||
def test_island_api_client__send_events(
|
||||
actual_error, expected_error, agent_event_serializer_registry
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error)
|
||||
island_api_client.send_events(events="some_data")
|
||||
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -146,11 +219,13 @@ def test_island_api_client__send_events(actual_error, expected_error):
|
|||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_send_events__status_code(status_code, expected_error):
|
||||
def test_island_api_client_send_events__status_code(
|
||||
status_code, expected_error, agent_event_serializer_registry
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER)
|
||||
island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
||||
island_api_client.send_events(events="some_data")
|
||||
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||
|
|
Loading…
Reference in New Issue