diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py index f280ad4f3..f253255d3 100644 --- a/monkey/infection_monkey/agent_event_forwarder.py +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -3,18 +3,15 @@ import queue import threading from time import sleep -import requests - -from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable +from common.agent_event_serializers import AgentEventSerializerRegistry from common.agent_events import AbstractAgentEvent -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.utils.threading import create_daemon_thread logger = logging.getLogger(__name__) DEFAULT_TIME_PERIOD_SECONDS = 5 -AGENT_EVENTS_API_URL = "https://%s/api/agent-events" class AgentEventForwarder: @@ -23,27 +20,21 @@ class AgentEventForwarder: """ def __init__( - self, server_address: str, agent_event_serializer_registry: AgentEventSerializerRegistry + self, + island_api_client: IIslandAPIClient, + agent_event_serializer_registry: AgentEventSerializerRegistry, ): - self._server_address = server_address self._agent_event_serializer_registry = agent_event_serializer_registry - self._batching_agent_event_forwarder = BatchingAgentEventForwarder(self._server_address) + self._batching_agent_event_forwarder = BatchingAgentEventForwarder(island_api_client) self._batching_agent_event_forwarder.start() def __del__(self): self._batching_agent_event_forwarder.stop() def send_event(self, event: AbstractAgentEvent): - serialized_event = self._serialize_event(event) - self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) - logger.debug( - f"Sending event of type {type(event).__name__} to the Island at {self._server_address}" - ) - - def _serialize_event(self, event: AbstractAgentEvent): - serializer = self._agent_event_serializer_registry[event.__class__] - return serializer.serialize(event) + self._batching_agent_event_forwarder.add_event_to_queue(event) + logger.debug(f"Sending event of type {type(event).__name__} to the Island") class BatchingAgentEventForwarder: @@ -51,8 +42,10 @@ class BatchingAgentEventForwarder: Handles the batching and sending of the Agent's events to the Island """ - def __init__(self, server_address: str, time_period: int = DEFAULT_TIME_PERIOD_SECONDS): - self._server_address = server_address + def __init__( + self, island_api_client: IIslandAPIClient, time_period: int = DEFAULT_TIME_PERIOD_SECONDS + ): + self._island_api_client = island_api_client self._time_period = time_period self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue() @@ -64,7 +57,7 @@ class BatchingAgentEventForwarder: ) self._batch_and_send_thread.start() - def add_event_to_queue(self, serialized_event: JSONSerializable): + def add_event_to_queue(self, serialized_event: AbstractAgentEvent): self._queue.put(serialized_event) def _manage_event_batches(self): @@ -84,18 +77,10 @@ class BatchingAgentEventForwarder: events.append(self._queue.get(block=False)) try: - logger.debug(f"Sending Agent events to Island at {self._server_address}: {events}") - requests.post( # noqa: DUO123 - AGENT_EVENTS_API_URL % (self._server_address,), - json=events, - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - except Exception as exc: - logger.warning( - f"Exception caught when connecting to the Island at {self._server_address}" - f": {exc}" - ) + logger.debug(f"Sending Agent events to Island: {events}") + self._island_api_client.send_events(events) + except Exception as err: + logger.warning(f"Exception caught when connecting to the Island: {err}") def _send_remaining_events(self): self._send_events_to_island() diff --git a/monkey/infection_monkey/island_api_client/__init__.py b/monkey/infection_monkey/island_api_client/__init__.py index ec513e774..0eb69612e 100644 --- a/monkey/infection_monkey/island_api_client/__init__.py +++ b/monkey/infection_monkey/island_api_client/__init__.py @@ -6,4 +6,5 @@ from .island_api_client_errors import ( IslandAPITimeoutError, ) from .i_island_api_client import IIslandAPIClient -from .http_island_api_client import HTTPIslandAPIClient +from .abstract_island_api_client_factory import AbstractIslandAPIClientFactory +from .http_island_api_client import HTTPIslandAPIClient, HTTPIslandAPIClientFactory diff --git a/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py b/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py new file mode 100644 index 000000000..2a74dcd96 --- /dev/null +++ b/monkey/infection_monkey/island_api_client/abstract_island_api_client_factory.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +from . import IIslandAPIClient + + +class AbstractIslandAPIClientFactory(ABC): + @abstractmethod + def create_island_api_client(self) -> IIslandAPIClient: + """ + Create an IIslandAPIClient + + :return: A concrete instance of an IIslandAPIClient + """ diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 3b2adfe61..a0bebc5a0 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,12 +1,16 @@ import functools import logging +from typing import List, Sequence import requests from common import OperatingSystem +from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable +from common.agent_events import AbstractAgentEvent from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from . import ( + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -34,6 +38,8 @@ def handle_island_errors(fn): raise IslandAPIError(err) except TimeoutError as err: raise IslandAPITimeoutError(err) + except IslandAPIError as err: + raise err except Exception as err: raise IslandAPIError(err) @@ -45,8 +51,17 @@ class HTTPIslandAPIClient(IIslandAPIClient): A client for the Island's HTTP API """ + def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + @handle_island_errors - def __init__(self, island_server: str): + def connect( + self, + island_server: str, + ): response = requests.get( # noqa: DUO123 f"https://{island_server}/api?action=is-up", verify=False, @@ -89,3 +104,37 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() return response.content + + @handle_island_errors + def send_events(self, events: Sequence[JSONSerializable]): + response = requests.post( # noqa: DUO123 + f"{self._api_url}/agent-events", + json=self._serialize_events(events), + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + response.raise_for_status() + + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: + serialized_events: List[JSONSerializable] = [] + + try: + for e in events: + serializer = self._agent_event_serializer_registry[e.__class__] + serialized_events.append(serializer.serialize(e)) + except Exception as err: + raise IslandAPIRequestError(err) + + return serialized_events + + +class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): + def __init__( + self, + agent_event_serializer_registry: AgentEventSerializerRegistry = None, + ): + self._agent_event_serializer_registry = agent_event_serializer_registry + + def create_island_api_client(self): + return HTTPIslandAPIClient(self._agent_event_serializer_registry) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index fea93d1dc..a15bfb671 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Sequence + +from common.agent_events import AbstractAgentEvent class IIslandAPIClient(ABC): @@ -8,9 +10,9 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def __init__(self, island_server: str): + def connect(self, island_server: str): """ - Construct an island API client and connect it to the island + Connectto the island's API :param island_server: The socket address of the API :raises IslandAPIConnectionError: If the client cannot successfully connect to the island @@ -71,4 +73,21 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the agent binary + + """ + + @abstractmethod + def send_events(self, events: Sequence[AbstractAgentEvent]): + """ + Send a sequence of agent events to the Island + + :param events: A sequence of agent events + :raises IslandAPIConnectionError: If the client cannot successfully connect to the island + :raises IslandAPIRequestError: If an error occurs while attempting to connect to the + island due to an issue in the request sent from the client + :raises IslandAPIRequestFailedError: If an error occurs while attempting to connect to the + island due to an error on the server + :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island + :raises IslandAPIError: If an unexpected error occurs while attempting to send events to + the island """ diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 21b115735..50ed3a3cb 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -45,7 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.i_puppet import IPuppet, PluginType -from infection_monkey.island_api_client import IIslandAPIClient +from infection_monkey.island_api_client import HTTPIslandAPIClientFactory, IIslandAPIClient from infection_monkey.master import AutomatedMaster from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory @@ -107,9 +107,12 @@ logging.getLogger("urllib3").setLevel(logging.INFO) class InfectionMonkey: def __init__(self, args): logger.info("Monkey is initializing...") + self._singleton = SystemSingleton() self._opts = self._get_arguments(args) + self._agent_event_serializer_registry = self._setup_agent_event_serializers() + server, self._island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) @@ -141,7 +144,9 @@ class InfectionMonkey: # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - server_clients = find_available_island_apis(self._opts.servers) + server_clients = find_available_island_apis( + self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) + ) server, island_api_client = self._select_server(server_clients) @@ -206,8 +211,6 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - self._agent_event_serializer_registry = self._setup_agent_event_serializers() - self._control_channel = ControlChannel(self._control_client.server_address, GUID) self._control_channel.register_agent(self._opts.parent) @@ -247,7 +250,7 @@ class InfectionMonkey: ) event_queue = PyPubSubAgentEventQueue(Publisher()) - InfectionMonkey._subscribe_events( + self._subscribe_events( event_queue, propagation_credentials_repository, self._control_client.server_address, @@ -273,8 +276,8 @@ class InfectionMonkey: propagation_credentials_repository, ) - @staticmethod def _subscribe_events( + self, event_queue: IAgentEventQueue, propagation_credentials_repository: IPropagationCredentialsRepository, server_address: str, @@ -287,7 +290,7 @@ class InfectionMonkey: ), ) event_queue.subscribe_all_events( - AgentEventForwarder(server_address, agent_event_serializer_registry).send_event + AgentEventForwarder(self._island_api_client, agent_event_serializer_registry).send_event ) def _build_puppet( diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 451fd65e7..92c2fb767 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -7,7 +7,7 @@ from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.island_api_client import ( - HTTPIslandAPIClient, + AbstractIslandAPIClientFactory, IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, @@ -27,7 +27,9 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]: +def find_available_island_apis( + servers: Iterable[str], island_api_client_factory: AbstractIslandAPIClientFactory +) -> Mapping[str, Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} @@ -35,7 +37,7 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ run_worker_threads( _find_island_server, "FindIslandServer", - args=(server_iterator, server_results), + args=(server_iterator, server_results, island_api_client_factory), num_workers=NUM_FIND_SERVER_WORKERS, ) @@ -43,18 +45,25 @@ def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[ def _find_island_server( - servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]] + servers: Iterator[str], + server_status: MutableMapping[str, Optional[IIslandAPIClient]], + island_api_client_factory: AbstractIslandAPIClientFactory, ): with suppress(StopIteration): server = next(servers) - server_status[server] = _check_if_island_server(server) + server_status[server] = _check_if_island_server(server, island_api_client_factory) -def _check_if_island_server(server: str) -> IIslandAPIClient: +def _check_if_island_server( + server: str, island_api_client_factory: AbstractIslandAPIClientFactory +) -> IIslandAPIClient: logger.debug(f"Trying to connect to server: {server}") try: - return HTTPIslandAPIClient(server) + client = island_api_client_factory.create_island_api_client() + client.connect(server) + + return client except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") except IslandAPITimeoutError as err: diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 22ad87161..1d0ac17fc 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -1,8 +1,15 @@ +from uuid import UUID + import pytest import requests import requests_mock from common import OperatingSystem +from common.agent_event_serializers import ( + AgentEventSerializerRegistry, + PydanticAgentEventSerializer, +) +from common.agent_events import AbstractAgentEvent from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, @@ -20,6 +27,35 @@ ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" +ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" + +AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") + + +class Event1(AbstractAgentEvent): + a: int + + +class Event2(AbstractAgentEvent): + b: str + + +class Event3(AbstractAgentEvent): + c: int + + +@pytest.fixture +def agent_event_serializer_registry(): + agent_event_serializer_registry = AgentEventSerializerRegistry() + agent_event_serializer_registry[Event1] = PydanticAgentEventSerializer(Event1) + agent_event_serializer_registry[Event2] = PydanticAgentEventSerializer(Event2) + + return agent_event_serializer_registry + + +@pytest.fixture +def island_api_client(agent_event_serializer_registry): + return HTTPIslandAPIClient(agent_event_serializer_registry) @pytest.mark.parametrize( @@ -30,12 +66,12 @@ ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" (Exception, IslandAPIError), ], ) -def test_island_api_client(actual_error, expected_error): +def test_island_api_client(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, exc=actual_error) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -45,12 +81,12 @@ def test_island_api_client(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client__status_code(status_code, expected_error): +def test_island_api_client__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI, status_code=status_code) with pytest.raises(expected_error): - HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) @pytest.mark.parametrize( @@ -61,10 +97,10 @@ def test_island_api_client__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__send_log(actual_error, expected_error): +def test_island_api_client__send_log(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, exc=actual_error) @@ -78,10 +114,10 @@ def test_island_api_client__send_log(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_send_log__status_code(status_code, expected_error): +def test_island_api_client_send_log__status_code(island_api_client, status_code, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.post(ISLAND_SEND_LOG_URI, status_code=status_code) @@ -96,10 +132,10 @@ def test_island_api_client_send_log__status_code(status_code, expected_error): (Exception, IslandAPIError), ], ) -def test_island_api_client__get_pba_file(actual_error, expected_error): +def test_island_api_client__get_pba_file(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, exc=actual_error) @@ -113,10 +149,12 @@ def test_island_api_client__get_pba_file(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client_get_pba_file__status_code(status_code, expected_error): +def test_island_api_client_get_pba_file__status_code( + island_api_client, status_code, expected_error +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_PBA_FILE_URI, status_code=status_code) @@ -131,10 +169,10 @@ def test_island_api_client_get_pba_file__status_code(status_code, expected_error (Exception, IslandAPIError), ], ) -def test_island_api_client__get_agent_binary(actual_error, expected_error): +def test_island_api_client__get_agent_binary(island_api_client, actual_error, expected_error): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_AGENT_BINARY_URI, exc=actual_error) @@ -148,11 +186,92 @@ def test_island_api_client__get_agent_binary(actual_error, expected_error): (501, IslandAPIRequestFailedError), ], ) -def test_island_api_client__get_agent_binary_status_code(status_code, expected_error): +def test_island_api_client__get_agent_binary_status_code( + island_api_client, status_code, expected_error +): with requests_mock.Mocker() as m: m.get(ISLAND_URI) - island_api_client = HTTPIslandAPIClient(SERVER) + island_api_client.connect(SERVER) with pytest.raises(expected_error): m.get(ISLAND_GET_AGENT_BINARY_URI, status_code=status_code) island_api_client.get_agent_binary(operating_system=OperatingSystem.WINDOWS) + + +def test_island_api_client_send_events__serialization(island_api_client): + events_to_send = [ + Event1(source=AGENT_ID, timestamp=0, a=1), + Event2(source=AGENT_ID, timestamp=0, b="hello"), + ] + expected_json = [ + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "a": 1, + "type": "Event1", + }, + { + "source": "80988359-a1cd-42a2-9b47-5b94b37cd673", + "target": None, + "timestamp": 0.0, + "tags": [], + "b": "hello", + "type": "Event2", + }, + ] + + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client.connect(SERVER) + + island_api_client.send_events(events=events_to_send) + + assert m.last_request.json() == expected_json + + +def test_island_api_client_send_events__serialization_failed(island_api_client): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(IslandAPIRequestError): + m.post(ISLAND_SEND_EVENTS_URI) + island_api_client.send_events(events=[Event3(source=AGENT_ID, c=1)]) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + (Exception, IslandAPIError), + ], +) +def test_island_api_client__send_events(island_api_client, actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_EVENTS_URI, exc=actual_error) + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_send_events__status_code(island_api_client, status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) + island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index dd00ebc94..affb61aff 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,7 +1,12 @@ import pytest import requests_mock -from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError +from common.agent_event_serializers import AgentEventSerializerRegistry +from infection_monkey.island_api_client import ( + HTTPIslandAPIClientFactory, + IIslandAPIClient, + IslandAPIConnectionError, +) from infection_monkey.network.relay.utils import find_available_island_apis SERVER_1 = "1.1.1.1:12312" @@ -13,6 +18,11 @@ SERVER_4 = "4.4.4.4:5000" servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] +@pytest.fixture +def island_api_client_factory(): + return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry()) + + @pytest.mark.parametrize( "expected_available_servers, server_response_pairs", [ @@ -24,12 +34,14 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ), ], ) -def test_find_available_island_apis(expected_available_servers, server_response_pairs): +def test_find_available_island_apis( + expected_available_servers, server_response_pairs, island_api_client_factory +): with requests_mock.Mocker() as mock: for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert len(available_apis) == len(server_response_pairs) @@ -40,14 +52,14 @@ def test_find_available_island_apis(expected_available_servers, server_response_ assert island_api_client is None -def test_find_available_island_apis__multiple_successes(): +def test_find_available_island_apis__multiple_successes(island_api_client_factory): available_servers = [SERVER_2, SERVER_3] with requests_mock.Mocker() as mock: mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) for server in available_servers: mock.get(f"https://{server}/api?action=is-up", text="") - available_apis = find_available_island_apis(servers) + available_apis = find_available_island_apis(servers, island_api_client_factory) assert available_apis[SERVER_1] is None assert available_apis[SERVER_4] is None diff --git a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py index 718f11de9..aa2d1381a 100644 --- a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py +++ b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py @@ -1,16 +1,22 @@ import time +from unittest.mock import MagicMock import pytest -import requests_mock -from infection_monkey.agent_event_forwarder import AGENT_EVENTS_API_URL, BatchingAgentEventForwarder +from infection_monkey.agent_event_forwarder import BatchingAgentEventForwarder +from infection_monkey.island_api_client import IIslandAPIClient SERVER = "1.1.1.1:9999" @pytest.fixture -def event_sender(): - return BatchingAgentEventForwarder(SERVER, time_period=0.001) +def mock_island_api_client(): + return MagicMock(spec=IIslandAPIClient) + + +@pytest.fixture +def event_sender(mock_island_api_client): + return BatchingAgentEventForwarder(mock_island_api_client, time_period=0.001) # NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to @@ -18,35 +24,29 @@ def event_sender(): # BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not. -def test_send_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 - event_sender.stop() + event_sender.add_event_to_queue({}) + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 2 + + event_sender.stop() -def test_send_remaining_events(event_sender): - with requests_mock.Mocker() as mock: - mock.post(AGENT_EVENTS_API_URL % SERVER) - - event_sender.start() - - for _ in range(5): - event_sender.add_event_to_queue({}) - time.sleep(0.02) - assert mock.call_count == 1 +def test_send_remaining_events(event_sender, mock_island_api_client): + event_sender.start() + for _ in range(5): event_sender.add_event_to_queue({}) - event_sender.stop() - assert mock.call_count == 2 + time.sleep(0.05) + assert mock_island_api_client.send_events.call_count == 1 + + event_sender.add_event_to_queue({}) + event_sender.stop() + assert mock_island_api_client.send_events.call_count == 2