From 9000a01d1dfe15a264048ff4cdbb70c173e66862 Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 20 Sep 2022 12:28:50 +0200 Subject: [PATCH 01/10] Agent: Add send_events to IIslandAPIClient --- .../island_api_client/i_island_api_client.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index fefba9973..cf9670e9f 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,4 +1,7 @@ from abc import ABC, abstractmethod +from typing import Sequence + +from common.agent_event_serializers import JSONSerializable class IIslandAPIClient(ABC): @@ -54,3 +57,19 @@ class IIslandAPIClient(ABC): :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the custom PBA file """ + + @abstractmethod + def send_events(self, events: Sequence[JSONSerializable]): + """ + Send the events to the Island + + :param events: Events that are going to be send to the Island + :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server + :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island + :raises IslandAPIError: If an unexpected error occurs while attempting to send events to + the island + """ From f39007b0ce7312ac8017bd40464ad80e6cb7f6ca Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 20 Sep 2022 12:29:29 +0200 Subject: [PATCH 02/10] Agent: Implement send_events in HTTPIslandAPIClient --- .../http_island_api_client.py | 13 +++++++ .../test_http_island_api_client.py | 36 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 37feb8942..ab444c3e8 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,8 +1,10 @@ import functools import logging +from typing import Sequence import requests +from common.agent_event_serializers import JSONSerializable from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from . import ( @@ -76,3 +78,14 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() return response.content + + @handle_island_errors + def send_events(self, events: Sequence[JSONSerializable]): + response = requests.post( # noqa: DUO123 + f"{self._api_url}/agent-events", + json=events, + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + response.raise_for_status() diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index bd6bfcb41..6562bf153 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -17,6 +17,7 @@ PBA_FILE = "dummy.pba" ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" +ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" @pytest.mark.parametrize( @@ -118,3 +119,38 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) island_api_client.get_pba_file(filename=PBA_FILE) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + (Exception, IslandAPIError), + ], +) +def test_island_api_client__send_events(actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) + island_api_client.send_events(events="some_data") + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_send_events__status_code(status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) + island_api_client.send_events(events="some_data") From b320fba2c886f70a667065f93e55bf57896a22ad Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 20 Sep 2022 12:32:00 +0200 Subject: [PATCH 03/10] Agent: Modify AgentEventForwarder to use IIslandAPIClient --- .../infection_monkey/agent_event_forwarder.py | 40 +++++-------- .../test_agent_event_forwarder.py | 58 +++++++++---------- 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py index f280ad4f3..3bfdf21e9 100644 --- a/monkey/infection_monkey/agent_event_forwarder.py +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -3,18 +3,15 @@ import queue import threading from time import sleep -import requests - from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.utils.threading import create_daemon_thread logger = logging.getLogger(__name__) DEFAULT_TIME_PERIOD_SECONDS = 5 -AGENT_EVENTS_API_URL = "https://%s/api/agent-events" class AgentEventForwarder: @@ -23,12 +20,13 @@ class AgentEventForwarder: """ def __init__( - self, server_address: str, agent_event_serializer_registry: AgentEventSerializerRegistry + self, + island_api_client: IIslandAPIClient, + agent_event_serializer_registry: AgentEventSerializerRegistry, ): - self._server_address = server_address self._agent_event_serializer_registry = agent_event_serializer_registry - self._batching_agent_event_forwarder = BatchingAgentEventForwarder(self._server_address) + self._batching_agent_event_forwarder = BatchingAgentEventForwarder(island_api_client) self._batching_agent_event_forwarder.start() def __del__(self): @@ -37,11 +35,9 @@ class AgentEventForwarder: def send_event(self, event: AbstractAgentEvent): serialized_event = self._serialize_event(event) self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) - logger.debug( - f"Sending event of type {type(event).__name__} to the Island at {self._server_address}" - ) + logger.debug(f"Sending event of type {type(event).__name__} to the Island") - def _serialize_event(self, event: AbstractAgentEvent): + def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable: serializer = self._agent_event_serializer_registry[event.__class__] return serializer.serialize(event) @@ -51,8 +47,10 @@ class BatchingAgentEventForwarder: Handles the batching and sending of the Agent's events to the Island """ - def __init__(self, server_address: str, time_period: int = DEFAULT_TIME_PERIOD_SECONDS): - self._server_address = server_address + def __init__( + self, island_api_client: IIslandAPIClient, time_period: int = DEFAULT_TIME_PERIOD_SECONDS + ): + self._island_api_client = island_api_client self._time_period = time_period self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue() @@ -84,18 +82,10 @@ class BatchingAgentEventForwarder: events.append(self._queue.get(block=False)) try: - logger.debug(f"Sending Agent events to Island at {self._server_address}: {events}") - requests.post( # noqa: DUO123 - AGENT_EVENTS_API_URL % (self._server_address,), - json=events, - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - except Exception as exc: - logger.warning( - f"Exception caught when connecting to the Island at {self._server_address}" - f": {exc}" - ) + logger.debug(f"Sending Agent events to Island: {events}") + self._island_api_client.send_events(events) + except Exception as err: + logger.warning(f"Exception caught when connecting to the Island: {err}") def _send_remaining_events(self): self._send_events_to_island() diff --git a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py index 718f11de9..aa2d1381a 100644 --- a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py +++ b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py @@ -1,16 +1,22 @@ import time +from unittest.mock import MagicMock import pytest -import requests_mock -from infection_monkey.agent_event_forwarder import AGENT_EVENTS_API_URL, BatchingAgentEventForwarder +from infection_monkey.agent_event_forwarder import BatchingAgentEventForwarder +from infection_monkey.island_api_client import IIslandAPIClient SERVER = "1.1.1.1:9999" @pytest.fixture -def event_sender(): - return BatchingAgentEventForwarder(SERVER, time_period=0.001) +def mock_island_api_client(): + return MagicMock(spec=IIslandAPIClient) + + +@pytest.fixture +def event_sender(mock_island_api_client): + return BatchingAgentEventForwarder(mock_island_api_client, time_period=0.001) # NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to @@ -18,35 +24,29 @@ def event_sender(): # BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not. -def test_send_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 - event_sender.stop() + event_sender.add_event_to_queue({}) + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 2 + + event_sender.stop() -def test_send_remaining_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_remaining_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - event_sender.stop() - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 + + event_sender.add_event_to_queue({}) + event_sender.stop() + assert mock_island_api_client.send_events.call_count == 2 From 14592d964e8efc132064ee932a36566ba13fed11 Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 20 Sep 2022 12:52:36 +0200 Subject: [PATCH 04/10] Agent: Pass island_api_client when constructing AgentEventForwarder --- monkey/infection_monkey/monkey.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 96da82225..ba63d87cb 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -248,7 +248,7 @@ class InfectionMonkey: ) event_queue = PyPubSubAgentEventQueue(Publisher()) - InfectionMonkey._subscribe_events( + self._subscribe_events( event_queue, propagation_credentials_repository, self._control_client.server_address, @@ -274,8 +274,8 @@ class InfectionMonkey: propagation_credentials_repository, ) - @staticmethod def _subscribe_events( + self, event_queue: IAgentEventQueue, propagation_credentials_repository: IPropagationCredentialsRepository, server_address: str, @@ -288,7 +288,7 @@ class InfectionMonkey: ), ) event_queue.subscribe_all_events( - AgentEventForwarder(server_address, agent_event_serializer_registry).send_event + AgentEventForwarder(self.island_api_client, agent_event_serializer_registry).send_event ) def _build_puppet( From 8b52ba0686849d3030d58e7aab29f6bd398a0616 Mon Sep 17 00:00:00 2001 From: Shreya Malviya Date: Tue, 20 Sep 2022 18:01:17 +0530 Subject: [PATCH 05/10] Agent: Modify docstring in IIslandAPIClient.send_events() --- .../infection_monkey/island_api_client/i_island_api_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index cf9670e9f..84c3052ca 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -61,9 +61,9 @@ class IIslandAPIClient(ABC): @abstractmethod def send_events(self, events: Sequence[JSONSerializable]): """ - Send the events to the Island + Send a sequence of Agent events to the Island - :param events: Events that are going to be send to the Island + :param events: A sequence of Agent events :raises IslandAPIConnectionError: If the client cannot successfully connect to the island :raises IslandAPIRequestError: If an error occurs while attempting to connect to the island due to an issue in the request sent from the client From 34a4d813361dc10fb26e0953145b35913a439194 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 20 Sep 2022 09:39:20 -0400 Subject: [PATCH 06/10] Agent: Reraise IslandAPIError in handle_island_errors() --- .../island_api_client/http_island_api_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index ab444c3e8..1d816f6b9 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -35,6 +35,8 @@ def handle_island_errors(fn): raise IslandAPIError(err) except TimeoutError as err: raise IslandAPITimeoutError(err) + except IslandAPIError as err: + raise err except Exception as err: raise IslandAPIError(err) From 17cb77cfdd8e2d0f2bf24833f522d639dea7b152 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 20 Sep 2022 09:45:00 -0400 Subject: [PATCH 07/10] Agent: Move AbstractAgentEvent serialization to HTTPIslandAPIClient --- .../infection_monkey/agent_event_forwarder.py | 11 +-- .../http_island_api_client.py | 27 +++++- .../island_api_client/i_island_api_client.py | 8 +- .../test_http_island_api_client.py | 87 +++++++++++++++++-- 4 files changed, 111 insertions(+), 22 deletions(-) diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py index 3bfdf21e9..f253255d3 100644 --- a/monkey/infection_monkey/agent_event_forwarder.py +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -3,7 +3,7 @@ import queue import threading from time import sleep -from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable +from common.agent_event_serializers import AgentEventSerializerRegistry from common.agent_events import AbstractAgentEvent from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.utils.threading import create_daemon_thread @@ -33,14 +33,9 @@ class AgentEventForwarder: self._batching_agent_event_forwarder.stop() def send_event(self, event: AbstractAgentEvent): - serialized_event = self._serialize_event(event) - self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) + self._batching_agent_event_forwarder.add_event_to_queue(event) logger.debug(f"Sending event of type {type(event).__name__} to the Island") - def _serialize_event(self, event: AbstractAgentEvent) -> JSONSerializable: - serializer = self._agent_event_serializer_registry[event.__class__] - return serializer.serialize(event) - class BatchingAgentEventForwarder: """ @@ -62,7 +57,7 @@ class BatchingAgentEventForwarder: ) self._batch_and_send_thread.start() - def add_event_to_queue(self, serialized_event: JSONSerializable): + def add_event_to_queue(self, serialized_event: AbstractAgentEvent): self._queue.put(serialized_event) def _manage_event_batches(self): diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 1d816f6b9..7118e3e37 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,11 +1,12 @@ import functools import logging -from typing import Sequence +from typing import List, Sequence import requests -from common.agent_event_serializers import JSONSerializable +from common.agent_events import AbstractAgentEvent from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from . import ( IIslandAPIClient, @@ -49,7 +50,11 @@ class HTTPIslandAPIClient(IIslandAPIClient): """ @handle_island_errors - def __init__(self, island_server: str): + def __init__( + self, + island_server: str, + agent_event_serializer_registry: AgentEventSerializerRegistry = None, + ): response = requests.get( # noqa: DUO123 f"https://{island_server}/api?action=is-up", verify=False, @@ -60,6 +65,8 @@ class HTTPIslandAPIClient(IIslandAPIClient): self._island_server = island_server self._api_url = f"https://{self._island_server}/api" + self._agent_event_serializer_registry = agent_event_serializer_registry + @handle_island_errors def send_log(self, log_contents: str): response = requests.post( # noqa: DUO123 @@ -85,9 +92,21 @@ class HTTPIslandAPIClient(IIslandAPIClient): def send_events(self, events: Sequence[JSONSerializable]): response = requests.post( # noqa: DUO123 f"{self._api_url}/agent-events", - json=events, + json=self._serialize_events(events), verify=False, timeout=MEDIUM_REQUEST_TIMEOUT, ) response.raise_for_status() + + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: + serialized_events: List[JSONSerializable] = [] + + try: + for e in events: + serializer = self._agent_event_serializer_registry[e.__class__] + serialized_events.append(serializer.serialize(e)) + except Exception as err: + raise IslandAPIRequestError(err) + + return serialized_events diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 84c3052ca..d75f2dea5 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Sequence -from common.agent_event_serializers import JSONSerializable +from common.agent_events import AbstractAgentEvent class IIslandAPIClient(ABC): @@ -59,11 +59,11 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def send_events(self, events: Sequence[JSONSerializable]): + def send_events(self, events: Sequence[AbstractAgentEvent]): """ - Send a sequence of Agent events to the Island + Send a sequence of agent events to the Island - :param events: A sequence of Agent events + :param events: A sequence of agent events :raises IslandAPIConnectionError: If the client cannot successfully connect to the island :raises IslandAPIRequestError: If an error occurs while attempting to connect to the island due to an issue in the request sent from the client diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 6562bf153..607b18e7b 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -1,6 +1,12 @@ import pytest +from uuid import UUID import requests import requests_mock +from common.agent_events import AbstractAgentEvent +from common.agent_event_serializers import ( + AgentEventSerializerRegistry, + PydanticAgentEventSerializer, +) from infection_monkey.island_api_client import ( HTTPIslandAPIClient, @@ -19,6 +25,8 @@ ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" +AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") + @pytest.mark.parametrize( "actual_error, expected_error", @@ -121,6 +129,69 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error island_api_client.get_pba_file(filename=PBA_FILE) +class Event1(AbstractAgentEvent): + a: int + + +class Event2(AbstractAgentEvent): + b: str + + +class Event3(AbstractAgentEvent): + c: int + +@pytest.fixture +def agent_event_serializer_registry(): + agent_event_serializer_registry = AgentEventSerializerRegistry() + agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) + agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) + + return agent_event_serializer_registry + + +def test_island_api_client_send_events__serialization(agent_event_serializer_registry): + events_to_send = [ + Event1(source=AGENT_ID, timestamp=0, a=1), + Event2(source=AGENT_ID, timestamp=0, b="hello"), + ] + expected_json = [ + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "a": 1, + "type": "Event1", + }, + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "b": "hello", + "type": "Event2", + }, + ] + + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + + island_api_client.send_events(events=events_to_send) + + assert m.last_request.json() == expected_json + +def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + + with pytest.raises(IslandAPIRequestError): + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)]) + + @pytest.mark.parametrize( "actual_error, expected_error", [ @@ -129,14 +200,16 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error (Exception, IslandAPIError), ], ) -def test_island_api_client__send_events(actual_error, expected_error): +def test_island_api_client__send_events( + actual_error, expected_error, agent_event_serializer_registry +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) - island_api_client.send_events(events="some_data") + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) @pytest.mark.parametrize( @@ -146,11 +219,13 @@ def test_island_api_client__send_events(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_events__status_code(status_code, expected_error): +def test_island_api_client_send_events__status_code( + status_code, expected_error, agent_event_serializer_registry +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) - island_api_client.send_events(events="some_data") + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) From eea7fc1ee2bf19d29e8bbd12769ecf55005a33d6 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 20 Sep 2022 09:58:15 -0400 Subject: [PATCH 08/10] Agent: Add AbstractIslandAPIClientFactory --- .../infection_monkey/island_api_client/__init__.py | 1 + .../abstract_island_api_client_factory.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) create mode 100644 monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py diff --git a/monkey/infection_monkey/island_api_client/__init__.py b/monkey/infection_monkey/island_api_client/__init__.py index ec513e774..0dd8a7865 100644 --- a/monkey/infection_monkey/island_api_client/__init__.py +++ b/monkey/infection_monkey/island_api_client/__init__.py @@ -6,4 +6,5 @@ from .island_api_client_errors import ( IslandAPITimeoutError, ) from .i_island_api_client import IIslandAPIClient +from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory from .http_island_api_client import HTTPIslandAPIClient diff --git a/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py b/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py new file mode 100644 index 000000000..2a74dcd96 --- /dev/null +++ b/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +from . import IIslandAPIClient + + +class AbstractIslandAPIClientFactory(ABC): + @abstractmethod + def create_island_api_client(self) -> IIslandAPIClient: + """ + Create an IIslandAPIClient + + :return: A concrete instance of an IIslandAPIClient + """ From e9433ad23ba2b5f13b7ce95b2335444cc47e6d41 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 20 Sep 2022 10:05:47 -0400 Subject: [PATCH 09/10] Agent: Initialize _agent_event_serializer_registry in __init__() --- monkey/infection_monkey/monkey.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index ba63d87cb..ab3461c73 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -107,9 +107,12 @@ logging.getLogger("urllib3").setLevel(logging.INFO) class InfectionMonkey: def __init__(self, args): logger.info("Monkey is initializing...") + self._singleton = SystemSingleton() self._opts = self._get_arguments(args) + self._agent_event_serializer_registry = self._setup_agent_event_serializers() + # TODO: Revisit variable names server, island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. @@ -207,8 +210,6 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - self._agent_event_serializer_registry = self._setup_agent_event_serializers() - self._control_channel = ControlChannel(self._control_client.server_address, GUID) self._control_channel.register_agent(self._opts.parent) From 9807c23571ccd1e323bd764890b2f99aaae63334 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 20 Sep 2022 10:23:21 -0400 Subject: [PATCH 10/10] Agent: Add IIslandAPIClient.connect() Different clients may have different dependencies in their constructors. Use connect() instead of __init__() to connect to the Island. Add an AbstractIslandAPIClientFactory and HTTPIslandAPIClientFactory to facilitate this. --- .../island_api_client/__init__.py | 2 +- .../http_island_api_client.py | 25 ++++- .../island_api_client/i_island_api_client.py | 4 +- monkey/infection_monkey/monkey.py | 12 ++- .../infection_monkey/network/relay/utils.py | 23 ++-- .../test_http_island_api_client.py | 100 ++++++++++-------- .../network/relay/test_utils.py | 22 +++- 7 files changed, 116 insertions(+), 72 deletions(-) diff --git a/monkey/infection_monkey/island_api_client/__init__.py b/monkey/infection_monkey/island_api_client/__init__.py index 0dd8a7865..0eb69612e 100644 --- a/monkey/infection_monkey/island_api_client/__init__.py +++ b/monkey/infection_monkey/island_api_client/__init__.py @@ -7,4 +7,4 @@ from .island_api_client_errors import ( ) from .i_island_api_client import IIslandAPIClient from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory -from .http_island_api_client import HTTPIslandAPIClient +from .http_island_api_client import HTTPIslandAPIClient, HTTPIslandAPIClientFactory diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 7118e3e37..e3f23ff81 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -4,11 +4,12 @@ from typing import List, Sequence import requests +from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT -from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from . import ( + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -49,11 +50,16 @@ class HTTPIslandAPIClient(IIslandAPIClient): A client for the Island's HTTP API """ - @handle_island_errors def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + + @handle_island_errors + def connect( self, island_server: str, - agent_event_serializer_registry: AgentEventSerializerRegistry = None, ): response = requests.get( # noqa: DUO123 f"https://{island_server}/api?action=is-up", @@ -65,8 +71,6 @@ class HTTPIslandAPIClient(IIslandAPIClient): self._island_server = island_server self._api_url = f"https://{self._island_server}/api" - self._agent_event_serializer_registry = agent_event_serializer_registry - @handle_island_errors def send_log(self, log_contents: str): response = requests.post( # noqa: DUO123 @@ -110,3 +114,14 @@ class HTTPIslandAPIClient(IIslandAPIClient): raise IslandAPIRequestError(err) return serialized_events + + +class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): + def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry = None, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + + def create_island_api_client(self): + return HTTPIslandAPIClient(self._agent_event_serializer_registry) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index d75f2dea5..61fc98c58 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -10,9 +10,9 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def __init__(self, island_server: str): + def connect(self, island_server: str): """ - Construct an island API client and connect it to the island + Connectto the island's API :param island_server: The socket address of the API :raises IslandAPIConnectionError: If the client cannot successfully connect to the island diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index ab3461c73..49952ceca 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -45,7 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.i_puppet import IPuppet, PluginType -from infection_monkey.island_api_client import IIslandAPIClient +from infection_monkey.island_api_client import HTTPIslandAPIClientFactory, IIslandAPIClient from infection_monkey.master import AutomatedMaster from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory @@ -114,12 +114,12 @@ class InfectionMonkey: self._agent_event_serializer_registry = self._setup_agent_event_serializers() # TODO: Revisit variable names - server, island_api_client = self._connect_to_island_api() + server, self._island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) self._control_client = ControlClient( - server_address=server, island_api_client=island_api_client + server_address=server, island_api_client=self._island_api_client ) # TODO Refactor the telemetry messengers to accept control client @@ -145,7 +145,9 @@ class InfectionMonkey: # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - server_clients = find_available_island_apis(self._opts.servers) + server_clients = find_available_island_apis( + self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) + ) server, island_api_client = self._select_server(server_clients) @@ -289,7 +291,7 @@ class InfectionMonkey: ), ) event_queue.subscribe_all_events( - AgentEventForwarder(self.island_api_client, agent_event_serializer_registry).send_event + AgentEventForwarder(self._island_api_client, agent_event_serializer_registry).send_event ) def _build_puppet( diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 451fd65e7..92c2fb767 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -7,7 +7,7 @@ from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.island_api_client import ( - HTTPIslandAPIClient, + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -27,7 +27,9 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]: +def find_available_island_apis( + servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory +) -> Mapping[str, Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} @@ -35,7 +37,7 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ run_worker_threads( _find_island_server, "FindIslandServer", - args=(server_iterator, server_results), + args=(server_iterator, server_results, island_api_client_factory), num_workers=NUM_FIND_SERVER_WORKERS, ) @@ -43,18 +45,25 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ def _find_island_server( - servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]] + servers: Iterator[str], + server_status: MutableMapping[str, Optional[IIslandAPIClient]], + island_api_client_factory: AbstractIslandAPIClientFactory, ): with suppress(StopIteration): server = next(servers) - server_status[server] = _check_if_island_server(server) + server_status[server] = _check_if_island_server(server, island_api_client_factory) -def _check_if_island_server(server: str) -> IIslandAPIClient: +def _check_if_island_server( + server: str, island_api_client_factory: AbstractIslandAPIClientFactory +) -> IIslandAPIClient: logger.debug(f"Trying to connect to server: {server}") try: - return HTTPIslandAPIClient(server) + client = island_api_client_factory.create_island_api_client() + client.connect(server) + + return client except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") except IslandAPITimeoutError as err: diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 607b18e7b..530713aac 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -1,13 +1,14 @@ -import pytest from uuid import UUID + +import pytest import requests import requests_mock -from common.agent_events import AbstractAgentEvent + from common.agent_event_serializers import ( AgentEventSerializerRegistry, PydanticAgentEventSerializer, ) - +from common.agent_events import AbstractAgentEvent from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, @@ -28,6 +29,32 @@ ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") +class Event1(AbstractAgentEvent): + a: int + + +class Event2(AbstractAgentEvent): + b: str + + +class Event3(AbstractAgentEvent): + c: int + + +@pytest.fixture +def agent_event_serializer_registry(): + agent_event_serializer_registry = AgentEventSerializerRegistry() + agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) + agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) + + return agent_event_serializer_registry + + +@pytest.fixture +def island_api_client(agent_event_serializer_registry): + return HTTPIslandAPIClient(agent_event_serializer_registry) + + @pytest.mark.parametrize( "actual_error, expected_error", [ @@ -36,12 +63,12 @@ AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") (Exception, IslandAPIError), ], ) -def test_island_api_client(actual_error, expected_error): +def test_island_api_client(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, exc=actual_error) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -51,12 +78,12 @@ def test_island_api_client(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client__status_code(status_code, expected_error): +def test_island_api_client__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, status_code=status_code) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -67,10 +94,10 @@ def test_island_api_client__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__send_log(actual_error, expected_error): +def test_island_api_client__send_log(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, exc=actual_error) @@ -84,10 +111,10 @@ def test_island_api_client__send_log(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_log__status_code(status_code, expected_error): +def test_island_api_client_send_log__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, status_code=status_code) @@ -102,10 +129,10 @@ def test_island_api_client_send_log__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__get_pba_file(actual_error, expected_error): +def test_island_api_client__get_pba_file(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error) @@ -119,37 +146,19 @@ def test_island_api_client__get_pba_file(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_get_pba_file__status_code(status_code, expected_error): +def test_island_api_client_get_pba_file__status_code( + island_api_client, status_code, expected_error +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) island_api_client.get_pba_file(filename=PBA_FILE) -class Event1(AbstractAgentEvent): - a: int - - -class Event2(AbstractAgentEvent): - b: str - - -class Event3(AbstractAgentEvent): - c: int - -@pytest.fixture -def agent_event_serializer_registry(): - agent_event_serializer_registry = AgentEventSerializerRegistry() - agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) - agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) - - return agent_event_serializer_registry - - -def test_island_api_client_send_events__serialization(agent_event_serializer_registry): +def test_island_api_client_send_events__serialization(island_api_client): events_to_send = [ Event1(source=AGENT_ID, timestamp=0, a=1), Event2(source=AGENT_ID, timestamp=0, b="hello"), @@ -176,16 +185,17 @@ def test_island_api_client_send_events__serialization(agent_event_serializer_reg with requests_mock.Mocker() as m: m.get(ISLAND_URI) m.post(ISLAND_SEND_EVENTS_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) island_api_client.send_events(events=events_to_send) assert m.last_request.json() == expected_json -def test_island_api_client_send_events__serialization_failed(agent_event_serializer_registry): + +def test_island_api_client_send_events__serialization_failed(island_api_client): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(IslandAPIRequestError): m.post(ISLAND_SEND_EVENTS_URI) @@ -200,12 +210,10 @@ def test_island_api_client_send_events__serialization_failed(agent_event_seriali (Exception, IslandAPIError), ], ) -def test_island_api_client__send_events( - actual_error, expected_error, agent_event_serializer_registry -): +def test_island_api_client__send_events(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) @@ -219,12 +227,10 @@ def test_island_api_client__send_events( (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_events__status_code( - status_code, expected_error, agent_event_serializer_registry -): +def test_island_api_client_send_events__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER, agent_event_serializer_registry) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index dd00ebc94..affb61aff 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,7 +1,12 @@ import pytest import requests_mock -from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError +from common.agent_event_serializers import AgentEventSerializerRegistry +from infection_monkey.island_api_client import ( + HTTPIslandAPIClientFactory, + IIslandAPIClient, + IslandAPIConnectionError, +) from infection_monkey.network.relay.utils import find_available_island_apis SERVER_1 = "1.1.1.1:12312" @@ -13,6 +18,11 @@ SERVER_4 = "4.4.4.4:5000" servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] +@pytest.fixture +def island_api_client_factory(): + return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry()) + + @pytest.mark.parametrize( "expected_available_servers, server_response_pairs", [ @@ -24,12 +34,14 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ), ], ) -def test_find_available_island_apis(expected_available_servers, server_response_pairs): +def test_find_available_island_apis( + expected_available_servers, server_response_pairs, island_api_client_factory +): with requests_mock.Mocker() as mock: for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert len(available_apis) == len(server_response_pairs) @@ -40,14 +52,14 @@ def test_find_available_island_apis(expected_available_servers, server_response_ assert island_api_client is None -def test_find_available_island_apis__multiple_successes(): +def test_find_available_island_apis__multiple_successes(island_api_client_factory): available_servers = [SERVER_2, SERVER_3] with requests_mock.Mocker() as mock: mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) for server in available_servers: mock.get(f"https://{server}/api?action=is-up", text="") - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert available_apis[SERVER_1] is None assert available_apis[SERVER_4] is None