diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index bd9062eeb..622c17d7d 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -474,12 +474,10 @@ class InfectionMonkey(object): def run_ransomware(): telemetry_messenger = LegacyTelemetryMessengerAdapter() batching_telemetry_messenger = BatchingTelemetryMessenger(telemetry_messenger) - batching_telemetry_messenger.start() + try: RansomewarePayload( WormConfiguration.ransomware, batching_telemetry_messenger ).run_payload() except Exception as ex: LOG.error(f"An unexpected error occurred while running the ransomware payload: {ex}") - finally: - batching_telemetry_messenger.stop() diff --git a/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py b/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py index f5f21a760..9541d34d1 100644 --- a/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py +++ b/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py @@ -18,68 +18,77 @@ class BatchingTelemetryMessenger(ITelemetryMessenger): """ def __init__(self, telemetry_messenger: ITelemetryMessenger, period=DEFAULT_PERIOD): - self._telemetry_messenger = telemetry_messenger - self._period = period - - self._should_run_batch_thread = True self._queue: queue.Queue[ITelem] = queue.Queue() - # TODO: Create a "timer" or "countdown" class and inject an object instead of - # using time.time() - self._last_sent_time = time.time() - self._telemetry_batches: Dict[str, IBatchableTelem] = {} + self._thread = BatchingTelemetryMessenger._BatchingTelemetryMessengerThread( + self._queue, telemetry_messenger, period + ) + + self._thread.start() def __del__(self): - self.stop() - - def start(self): - self._should_run_batch_thread = True - self._manage_telemetry_batches_thread = threading.Thread( - target=self._manage_telemetry_batches - ) - self._manage_telemetry_batches_thread.start() - - def stop(self): - self._should_run_batch_thread = False - self._manage_telemetry_batches_thread.join() + self._thread.stop() def send_telemetry(self, telemetry: ITelem): self._queue.put(telemetry) - def _manage_telemetry_batches(self): - self._reset() + class _BatchingTelemetryMessengerThread: + def __init__(self, queue: queue.Queue, telemetry_messenger: ITelemetryMessenger, period): + self._queue: queue.Queue[ITelem] = queue + self._telemetry_messenger = telemetry_messenger + self._period = period - while self._should_run_batch_thread: - try: - telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD) + self._should_run_batch_thread = True + # TODO: Create a "timer" or "countdown" class and inject an object instead of + # using time.time() + self._last_sent_time = time.time() + self._telemetry_batches: Dict[str, IBatchableTelem] = {} - if isinstance(telemetry, IBatchableTelem): - self._add_telemetry_to_batch(telemetry) - else: - self._telemetry_messenger.send_telemetry(telemetry) - except queue.Empty: - pass + def start(self): + self._should_run_batch_thread = True + self._manage_telemetry_batches_thread = threading.Thread( + target=self._manage_telemetry_batches + ) + self._manage_telemetry_batches_thread.start() - if self._period_elapsed(): - self._send_telemetry_batches() - self._reset() + def stop(self): + self._should_run_batch_thread = False + self._manage_telemetry_batches_thread.join() - self._send_telemetry_batches() + def _manage_telemetry_batches(self): + self._reset() - def _reset(self): - self._last_sent_time = time.time() - self._telemetry_batches = {} + while self._should_run_batch_thread: + try: + telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD) - def _add_telemetry_to_batch(self, new_telemetry: IBatchableTelem): - telem_category = new_telemetry.telem_category + if isinstance(telemetry, IBatchableTelem): + self._add_telemetry_to_batch(telemetry) + else: + self._telemetry_messenger.send_telemetry(telemetry) + except queue.Empty: + pass - if telem_category in self._telemetry_batches: - self._telemetry_batches[telem_category].add_telemetry_to_batch(new_telemetry) - else: - self._telemetry_batches[telem_category] = new_telemetry + if self._period_elapsed(): + self._send_telemetry_batches() + self._reset() - def _period_elapsed(self): - return (time.time() - self._last_sent_time) > self._period + self._send_telemetry_batches() - def _send_telemetry_batches(self): - for batchable_telemetry in self._telemetry_batches.values(): - self._telemetry_messenger.send_telemetry(batchable_telemetry) + def _reset(self): + self._last_sent_time = time.time() + self._telemetry_batches = {} + + def _add_telemetry_to_batch(self, new_telemetry: IBatchableTelem): + telem_category = new_telemetry.telem_category + + if telem_category in self._telemetry_batches: + self._telemetry_batches[telem_category].add_telemetry_to_batch(new_telemetry) + else: + self._telemetry_batches[telem_category] = new_telemetry + + def _period_elapsed(self): + return (time.time() - self._last_sent_time) > self._period + + def _send_telemetry_batches(self): + for batchable_telemetry in self._telemetry_batches.values(): + self._telemetry_messenger.send_telemetry(batchable_telemetry) diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_batching_telemetry_messenger.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_batching_telemetry_messenger.py index 65eebc582..2de3e0ffe 100644 --- a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_batching_telemetry_messenger.py +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_batching_telemetry_messenger.py @@ -56,11 +56,7 @@ class BatchableTelemStub(BatchableTelemMixin, BaseTelem, IBatchableTelem): @pytest.fixture def batching_telemetry_messenger(monkeypatch, telemetry_messenger_spy): patch_time(monkeypatch, 0) - btm = BatchingTelemetryMessenger(telemetry_messenger_spy, period=0.001) - btm.start() - yield btm - - btm.stop() + return BatchingTelemetryMessenger(telemetry_messenger_spy, period=0.001) def test_send_immediately(batching_telemetry_messenger, telemetry_messenger_spy):