diff --git a/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py b/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py index 4d34012d8..9dc051666 100644 --- a/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py +++ b/monkey/infection_monkey/telemetry/messengers/batching_telemetry_messenger.py @@ -1,11 +1,11 @@ import queue import threading -import time from typing import Dict from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger +from infection_monkey.utils.timer import Timer DEFAULT_PERIOD = 5 WAKES_PER_PERIOD = 4 @@ -40,8 +40,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger): self._period = period self._should_run_batch_thread = True - # TODO: Replace with infection_monkey.utils.timer.Timer - self._last_sent_time = time.time() self._telemetry_batches: Dict[str, IBatchableTelem] = {} self._manage_telemetry_batches_thread = None @@ -59,21 +57,20 @@ class BatchingTelemetryMessenger(ITelemetryMessenger): self._manage_telemetry_batches_thread = None def _manage_telemetry_batches(self): - self._reset() + timer = Timer() + timer.set(self._period) + self._telemetry_batches = {} while self._should_run_batch_thread: self._process_next_telemetry() - if self._period_elapsed(): + if timer.is_expired(): self._send_telemetry_batches() - self._reset() + timer.reset() + self._telemetry_batches = {} self._send_remaining_telemetry_batches() - def _reset(self): - self._last_sent_time = time.time() - self._telemetry_batches = {} - def _process_next_telemetry(self): try: telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD) @@ -93,9 +90,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger): else: self._telemetry_batches[telem_category] = new_telemetry - def _period_elapsed(self): - return (time.time() - self._last_sent_time) > self._period - def _send_remaining_telemetry_batches(self): while not self._queue.empty(): self._process_next_telemetry() diff --git a/monkey/infection_monkey/tunnel.py b/monkey/infection_monkey/tunnel.py index b0f778534..26368bff6 100644 --- a/monkey/infection_monkey/tunnel.py +++ b/monkey/infection_monkey/tunnel.py @@ -8,6 +8,7 @@ from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.transport.base import get_last_serve_time +from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) @@ -181,8 +182,9 @@ class MonkeyTunnel(Thread): # wait till all of the tunnel clients has been disconnected, or no one used the tunnel in # QUIT_TIMEOUT seconds - # TODO: Replace with infection_monkey.utils.timer.Timer - while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT): + timer = Timer() + timer.set(self._calculate_timeout()) + while self._clients and not timer.is_expired(): try: search, address = self._broad_sock.recvfrom(BUFFER_READ) if b"-" == search: @@ -191,11 +193,19 @@ class MonkeyTunnel(Thread): except socket.timeout: continue + timer.set(self._calculate_timeout()) + logger.info("Closing tunnel") self._broad_sock.close() proxy.stop() proxy.join() + def _calculate_timeout(self) -> float: + try: + return QUIT_TIMEOUT - (time.time() - get_last_serve_time()) + except TypeError: # get_last_serve_time() may return None + return 0.0 + def get_tunnel_for_ip(self, ip: str): if not self.local_port: diff --git a/monkey/infection_monkey/utils/timer.py b/monkey/infection_monkey/utils/timer.py index 2ed17d551..6095d466e 100644 --- a/monkey/infection_monkey/utils/timer.py +++ b/monkey/infection_monkey/utils/timer.py @@ -13,7 +13,8 @@ class Timer: def set(self, timeout_sec: float): """ Set a timer - :param float timeout_sec: A fractional number of seconds to set the timeout for. + :param float timeout_sec: A nonnegative floating point number expressing the number of + seconds to set the timeout for. """ self._timeout_sec = timeout_sec self._start_time = time.time()