From b320fba2c886f70a667065f93e55bf57896a22ad Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 20 Sep 2022 12:32:00 +0200 Subject: [PATCH] Agent: Modify AgentEventForwarder to use IIslandAPIClient --- .../infection_monkey/agent_event_forwarder.py | 40 +++++-------- .../test_agent_event_forwarder.py | 58 +++++++++---------- 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py index f280ad4f3..3bfdf21e9 100644 --- a/monkey/infection_monkey/agent_event_forwarder.py +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -3,18 +3,15 @@ import queue import threading from time import sleep -import requests - from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.utils.threading import create_daemon_thread logger = logging.getLogger(__name__) DEFAULT_TIME_PERIOD_SECONDS = 5 -AGENT_EVENTS_API_URL = "https://%s/api/agent-events" class AgentEventForwarder: @@ -23,12 +20,13 @@ class AgentEventForwarder: """ def __init__( - self, server_address: str, agent_event_serializer_registry: AgentEventSerializerRegistry + self, + island_api_client: IIslandAPIClient, + agent_event_serializer_registry: AgentEventSerializerRegistry, ): - self._server_address = server_address self._agent_event_serializer_registry = agent_event_serializer_registry - self._batching_agent_event_forwarder = BatchingAgentEventForwarder(self._server_address) + self._batching_agent_event_forwarder = BatchingAgentEventForwarder(island_api_client) self._batching_agent_event_forwarder.start() def __del__(self): @@ -37,11 +35,9 @@ class AgentEventForwarder: def send_event(self, event: AbstractAgentEvent): serialized_event = self._serialize_event(event) self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) - logger.debug( - f"Sending event of type {type(event).__name__} to the Island at {self._server_address}" - ) + logger.debug(f"Sending event of type {type(event).__name__} to the Island") - def _serialize_event(self, event: AbstractAgentEvent): + def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable: serializer = self._agent_event_serializer_registry[event.__class__] return serializer.serialize(event) @@ -51,8 +47,10 @@ class BatchingAgentEventForwarder: Handles the batching and sending of the Agent's events to the Island """ - def __init__(self, server_address: str, time_period: int = DEFAULT_TIME_PERIOD_SECONDS): - self._server_address = server_address + def __init__( + self, island_api_client: IIslandAPIClient, time_period: int = DEFAULT_TIME_PERIOD_SECONDS + ): + self._island_api_client = island_api_client self._time_period = time_period self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue() @@ -84,18 +82,10 @@ class BatchingAgentEventForwarder: events.append(self._queue.get(block=False)) try: - logger.debug(f"Sending Agent events to Island at {self._server_address}: {events}") - requests.post( # noqa: DUO123 - AGENT_EVENTS_API_URL % (self._server_address,), - json=events, - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - except Exception as exc: - logger.warning( - f"Exception caught when connecting to the Island at {self._server_address}" - f": {exc}" - ) + logger.debug(f"Sending Agent events to Island: {events}") + self._island_api_client.send_events(events) + except Exception as err: + logger.warning(f"Exception caught when connecting to the Island: {err}") def _send_remaining_events(self): self._send_events_to_island() diff --git a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py index 718f11de9..aa2d1381a 100644 --- a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py +++ b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py @@ -1,16 +1,22 @@ import time +from unittest.mock import MagicMock import pytest -import requests_mock -from infection_monkey.agent_event_forwarder import AGENT_EVENTS_API_URL, BatchingAgentEventForwarder +from infection_monkey.agent_event_forwarder import BatchingAgentEventForwarder +from infection_monkey.island_api_client import IIslandAPIClient SERVER = "1.1.1.1:9999" @pytest.fixture -def event_sender(): - return BatchingAgentEventForwarder(SERVER, time_period=0.001) +def mock_island_api_client(): + return MagicMock(spec=IIslandAPIClient) + + +@pytest.fixture +def event_sender(mock_island_api_client): + return BatchingAgentEventForwarder(mock_island_api_client, time_period=0.001) # NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to @@ -18,35 +24,29 @@ def event_sender(): # BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not. -def test_send_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 - event_sender.stop() + event_sender.add_event_to_queue({}) + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 2 + + event_sender.stop() -def test_send_remaining_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_remaining_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - event_sender.stop() - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 + + event_sender.add_event_to_queue({}) + event_sender.stop() + assert mock_island_api_client.send_events.call_count == 2