diff --git a/monkey/infection_monkey/agent_event_forwarder.py b/monkey/infection_monkey/agent_event_forwarder.py new file mode 100644 index 000000000..f8e4dfef7 --- /dev/null +++ b/monkey/infection_monkey/agent_event_forwarder.py @@ -0,0 +1,106 @@ +import logging +import queue +import threading +from time import sleep + +import requests + +from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from common.event_serializers import EventSerializerRegistry +from common.event_serializers.i_event_serializer import JSONSerializable +from common.events import AbstractAgentEvent +from infection_monkey.utils.threading import create_daemon_thread + +logger = logging.getLogger(__name__) + + +DEFAULT_TIME_PERIOD_SECONDS = 5 +EVENTS_API_URL = "https://%s/api/events" + + +class AgentEventForwarder: + """ + Sends information about the events carried out by the Agent to the Island in batches + """ + + def __init__( + self, server_address: str, agent_event_serializer_registry: EventSerializerRegistry + ): + self._server_address = server_address + self._agent_event_serializer_registry = agent_event_serializer_registry + + self._batching_agent_event_forwarder = BatchingAgentEventForwarder(self._server_address) + self._batching_agent_event_forwarder.start() + + def __del__(self): + self._batching_agent_event_forwarder.stop() + + def send_event(self, event: AbstractAgentEvent): + serialized_event = self._serialize_event(event) + self._batching_agent_event_forwarder.add_event_to_queue(serialized_event) + logger.debug( + f"Sending event of type {type(event).__name__} to the Island at {self._server_address}" + ) + + def _serialize_event(self, event: AbstractAgentEvent): + serializer = self._agent_event_serializer_registry[event.__class__] + return serializer.serialize(event) + + +class BatchingAgentEventForwarder: + """ + Handles the batching and sending of the Agent's events to the Island + """ + + def __init__(self, server_address: str, time_period: int = DEFAULT_TIME_PERIOD_SECONDS): + self._server_address = server_address + self._time_period = time_period + + self._queue: queue.Queue[AbstractAgentEvent] = queue.Queue() + self._stop_batch_and_send_thread = threading.Event() + + def start(self): + self._batch_and_send_thread = create_daemon_thread( + target=self._manage_event_batches, name="SendEventsToIslandInBatchesThread" + ) + self._batch_and_send_thread.start() + + def add_event_to_queue(self, serialized_event: JSONSerializable): + self._queue.put(serialized_event) + + def _manage_event_batches(self): + while not self._stop_batch_and_send_thread.is_set(): + self._send_events_to_island() + sleep(self._time_period) + + self._send_remaining_events() + + def _send_events_to_island(self): + if self._queue.empty(): + return + + events = [] + + while not self._queue.empty(): + events.append(self._queue.get(block=False)) + + try: + logger.debug(f"Sending events to Island at {self._server_address}: {events}") + requests.post( # noqa: DUO123 + EVENTS_API_URL % (self._server_address,), + json=events, + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + except Exception as exc: + logger.warning( + f"Exception caught when connecting to the Island at {self._server_address}" + f": {exc}" + ) + + def _send_remaining_events(self): + self._send_events_to_island() + + def stop(self): + self._stop_batch_and_send_thread.set() + self._batch_and_send_thread.join() diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index d3b4df8db..dfc34bec2 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -19,6 +19,7 @@ from common.network.network_utils import address_to_ip_port from common.utils.argparse_types import positive_int from common.utils.attack_utils import ScanStatus, UsageEnum from common.version import get_version +from infection_monkey.agent_event_forwarder import AgentEventForwarder from infection_monkey.config import GUID from infection_monkey.control import ControlClient from infection_monkey.credential_collectors import ( @@ -186,7 +187,7 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - _ = self._setup_agent_event_serializers() + self._agent_event_serializer_registry = self._setup_agent_event_serializers() self._control_channel = ControlChannel(self._control_client.server_address, GUID) self._control_channel.register_agent(self._opts.parent) @@ -227,7 +228,12 @@ class InfectionMonkey: ) event_queue = PyPubSubAgentEventQueue(Publisher()) - InfectionMonkey._subscribe_events(event_queue, propagation_credentials_repository) + InfectionMonkey._subscribe_events( + event_queue, + propagation_credentials_repository, + self._control_client.server_address, + self._agent_event_serializer_registry, + ) puppet = self._build_puppet(propagation_credentials_repository, event_queue) @@ -252,6 +258,8 @@ class InfectionMonkey: def _subscribe_events( event_queue: IAgentEventQueue, propagation_credentials_repository: IPropagationCredentialsRepository, + server_address: str, + agent_event_serializer_registry: EventSerializerRegistry, ): event_queue.subscribe_type( CredentialsStolenEvent, @@ -259,6 +267,9 @@ class InfectionMonkey: propagation_credentials_repository ), ) + event_queue.subscribe_all_events( + AgentEventForwarder(server_address, agent_event_serializer_registry).send_event + ) @staticmethod def _get_local_network_interfaces() -> List[IPv4Interface]: diff --git a/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py new file mode 100644 index 000000000..e0bbcb8c0 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/test_agent_event_forwarder.py @@ -0,0 +1,52 @@ +import time + +import pytest +import requests_mock + +from infection_monkey.agent_event_forwarder import EVENTS_API_URL, BatchingAgentEventForwarder + +SERVER = "1.1.1.1:9999" + + +@pytest.fixture +def event_sender(): + return BatchingAgentEventForwarder(SERVER, time_period=0.001) + + +# NOTE: If these tests are too slow or end up being racey, we can redesign AgentEventForwarder to +# handle threading and simply command BatchingAgentEventForwarder when to send events. +# BatchingAgentEventForwarder would have unit tests, but AgentEventForwarder would not. + + +def test_send_events(event_sender): + with requests_mock.Mocker() as mock: + mock.post(EVENTS_API_URL % SERVER) + + event_sender.start() + + for _ in range(5): + event_sender.add_event_to_queue({}) + time.sleep(0.01) + assert mock.call_count == 1 + + event_sender.add_event_to_queue({}) + time.sleep(0.01) + assert mock.call_count == 2 + + event_sender.stop() + + +def test_send_remaining_events(event_sender): + with requests_mock.Mocker() as mock: + mock.post(EVENTS_API_URL % SERVER) + + event_sender.start() + + for _ in range(5): + event_sender.add_event_to_queue({}) + time.sleep(0.01) + assert mock.call_count == 1 + + event_sender.add_event_to_queue({}) + event_sender.stop() + assert mock.call_count == 2