agent: Remove start/stop from BatchingTelemetryMessenger

My original plan was to start a thread in __init__() and stop the thread
when __del__() was called. Since the running thread (object) contains a
reference to the BatchingTelemetryMessenger object that launched it, the
destructor will not be called until the thread is stopped. This
resulted in adding a stop() method (fadd978) followed by adding a
start() method (1d066c8e).

By using an inner class to run the thread, we enable the class to be
used as originally intended, reducing the burden on the user of this
class. The thread is now started on construction and stopped on
destruction. The user can remain blissfully unaware that anything
resembling threading is going in, and can use the
BatchingTelemetryMessenger just like any other ITelemetryMessenger.
This commit is contained in:
Mike Salvatore 2021-06-28 12:21:57 -04:00
parent 1d066c8e6d
commit 2f62a14fbf
3 changed files with 60 additions and 57 deletions

View File

@ -474,12 +474,10 @@ class InfectionMonkey(object):
def run_ransomware(): def run_ransomware():
telemetry_messenger = LegacyTelemetryMessengerAdapter() telemetry_messenger = LegacyTelemetryMessengerAdapter()
batching_telemetry_messenger = BatchingTelemetryMessenger(telemetry_messenger) batching_telemetry_messenger = BatchingTelemetryMessenger(telemetry_messenger)
batching_telemetry_messenger.start()
try: try:
RansomewarePayload( RansomewarePayload(
WormConfiguration.ransomware, batching_telemetry_messenger WormConfiguration.ransomware, batching_telemetry_messenger
).run_payload() ).run_payload()
except Exception as ex: except Exception as ex:
LOG.error(f"An unexpected error occurred while running the ransomware payload: {ex}") LOG.error(f"An unexpected error occurred while running the ransomware payload: {ex}")
finally:
batching_telemetry_messenger.stop()

View File

@ -18,68 +18,77 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
""" """
def __init__(self, telemetry_messenger: ITelemetryMessenger, period=DEFAULT_PERIOD): def __init__(self, telemetry_messenger: ITelemetryMessenger, period=DEFAULT_PERIOD):
self._telemetry_messenger = telemetry_messenger
self._period = period
self._should_run_batch_thread = True
self._queue: queue.Queue[ITelem] = queue.Queue() self._queue: queue.Queue[ITelem] = queue.Queue()
# TODO: Create a "timer" or "countdown" class and inject an object instead of self._thread = BatchingTelemetryMessenger._BatchingTelemetryMessengerThread(
# using time.time() self._queue, telemetry_messenger, period
self._last_sent_time = time.time() )
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
self._thread.start()
def __del__(self): def __del__(self):
self.stop() self._thread.stop()
def start(self):
self._should_run_batch_thread = True
self._manage_telemetry_batches_thread = threading.Thread(
target=self._manage_telemetry_batches
)
self._manage_telemetry_batches_thread.start()
def stop(self):
self._should_run_batch_thread = False
self._manage_telemetry_batches_thread.join()
def send_telemetry(self, telemetry: ITelem): def send_telemetry(self, telemetry: ITelem):
self._queue.put(telemetry) self._queue.put(telemetry)
def _manage_telemetry_batches(self): class _BatchingTelemetryMessengerThread:
self._reset() def __init__(self, queue: queue.Queue, telemetry_messenger: ITelemetryMessenger, period):
self._queue: queue.Queue[ITelem] = queue
self._telemetry_messenger = telemetry_messenger
self._period = period
while self._should_run_batch_thread: self._should_run_batch_thread = True
try: # TODO: Create a "timer" or "countdown" class and inject an object instead of
telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD) # using time.time()
self._last_sent_time = time.time()
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
if isinstance(telemetry, IBatchableTelem): def start(self):
self._add_telemetry_to_batch(telemetry) self._should_run_batch_thread = True
else: self._manage_telemetry_batches_thread = threading.Thread(
self._telemetry_messenger.send_telemetry(telemetry) target=self._manage_telemetry_batches
except queue.Empty: )
pass self._manage_telemetry_batches_thread.start()
if self._period_elapsed(): def stop(self):
self._send_telemetry_batches() self._should_run_batch_thread = False
self._reset() self._manage_telemetry_batches_thread.join()
self._send_telemetry_batches() def _manage_telemetry_batches(self):
self._reset()
def _reset(self): while self._should_run_batch_thread:
self._last_sent_time = time.time() try:
self._telemetry_batches = {} telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD)
def _add_telemetry_to_batch(self, new_telemetry: IBatchableTelem): if isinstance(telemetry, IBatchableTelem):
telem_category = new_telemetry.telem_category self._add_telemetry_to_batch(telemetry)
else:
self._telemetry_messenger.send_telemetry(telemetry)
except queue.Empty:
pass
if telem_category in self._telemetry_batches: if self._period_elapsed():
self._telemetry_batches[telem_category].add_telemetry_to_batch(new_telemetry) self._send_telemetry_batches()
else: self._reset()
self._telemetry_batches[telem_category] = new_telemetry
def _period_elapsed(self): self._send_telemetry_batches()
return (time.time() - self._last_sent_time) > self._period
def _send_telemetry_batches(self): def _reset(self):
for batchable_telemetry in self._telemetry_batches.values(): self._last_sent_time = time.time()
self._telemetry_messenger.send_telemetry(batchable_telemetry) self._telemetry_batches = {}
def _add_telemetry_to_batch(self, new_telemetry: IBatchableTelem):
telem_category = new_telemetry.telem_category
if telem_category in self._telemetry_batches:
self._telemetry_batches[telem_category].add_telemetry_to_batch(new_telemetry)
else:
self._telemetry_batches[telem_category] = new_telemetry
def _period_elapsed(self):
return (time.time() - self._last_sent_time) > self._period
def _send_telemetry_batches(self):
for batchable_telemetry in self._telemetry_batches.values():
self._telemetry_messenger.send_telemetry(batchable_telemetry)

View File

@ -56,11 +56,7 @@ class BatchableTelemStub(BatchableTelemMixin, BaseTelem, IBatchableTelem):
@pytest.fixture @pytest.fixture
def batching_telemetry_messenger(monkeypatch, telemetry_messenger_spy): def batching_telemetry_messenger(monkeypatch, telemetry_messenger_spy):
patch_time(monkeypatch, 0) patch_time(monkeypatch, 0)
btm = BatchingTelemetryMessenger(telemetry_messenger_spy, period=0.001) return BatchingTelemetryMessenger(telemetry_messenger_spy, period=0.001)
btm.start()
yield btm
btm.stop()
def test_send_immediately(batching_telemetry_messenger, telemetry_messenger_spy): def test_send_immediately(batching_telemetry_messenger, telemetry_messenger_spy):