Merge pull request #1856 from guardicore/1826-timer-todos

Resolve `Timer` TODOs
This commit is contained in:
VakarisZ 2022-04-08 09:32:20 +03:00 committed by GitHub
commit 7b3b17251a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 16 deletions

View File

@ -1,11 +1,11 @@
import queue import queue
import threading import threading
import time
from typing import Dict from typing import Dict
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.timer import Timer
DEFAULT_PERIOD = 5 DEFAULT_PERIOD = 5
WAKES_PER_PERIOD = 4 WAKES_PER_PERIOD = 4
@ -40,8 +40,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
self._period = period self._period = period
self._should_run_batch_thread = True self._should_run_batch_thread = True
# TODO: Replace with infection_monkey.utils.timer.Timer
self._last_sent_time = time.time()
self._telemetry_batches: Dict[str, IBatchableTelem] = {} self._telemetry_batches: Dict[str, IBatchableTelem] = {}
self._manage_telemetry_batches_thread = None self._manage_telemetry_batches_thread = None
@ -59,21 +57,20 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
self._manage_telemetry_batches_thread = None self._manage_telemetry_batches_thread = None
def _manage_telemetry_batches(self): def _manage_telemetry_batches(self):
self._reset() timer = Timer()
timer.set(self._period)
self._telemetry_batches = {}
while self._should_run_batch_thread: while self._should_run_batch_thread:
self._process_next_telemetry() self._process_next_telemetry()
if self._period_elapsed(): if timer.is_expired():
self._send_telemetry_batches() self._send_telemetry_batches()
self._reset() timer.reset()
self._telemetry_batches = {}
self._send_remaining_telemetry_batches() self._send_remaining_telemetry_batches()
def _reset(self):
self._last_sent_time = time.time()
self._telemetry_batches = {}
def _process_next_telemetry(self): def _process_next_telemetry(self):
try: try:
telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD) telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD)
@ -93,9 +90,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
else: else:
self._telemetry_batches[telem_category] = new_telemetry self._telemetry_batches[telem_category] = new_telemetry
def _period_elapsed(self):
return (time.time() - self._last_sent_time) > self._period
def _send_remaining_telemetry_batches(self): def _send_remaining_telemetry_batches(self):
while not self._queue.empty(): while not self._queue.empty():
self._process_next_telemetry() self._process_next_telemetry()

View File

@ -8,6 +8,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
from infection_monkey.transport.base import get_last_serve_time from infection_monkey.transport.base import get_last_serve_time
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -181,8 +182,9 @@ class MonkeyTunnel(Thread):
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in # wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
# QUIT_TIMEOUT seconds # QUIT_TIMEOUT seconds
# TODO: Replace with infection_monkey.utils.timer.Timer timer = Timer()
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT): timer.set(self._calculate_timeout())
while self._clients and not timer.is_expired():
try: try:
search, address = self._broad_sock.recvfrom(BUFFER_READ) search, address = self._broad_sock.recvfrom(BUFFER_READ)
if b"-" == search: if b"-" == search:
@ -191,11 +193,19 @@ class MonkeyTunnel(Thread):
except socket.timeout: except socket.timeout:
continue continue
timer.set(self._calculate_timeout())
logger.info("Closing tunnel") logger.info("Closing tunnel")
self._broad_sock.close() self._broad_sock.close()
proxy.stop() proxy.stop()
proxy.join() proxy.join()
def _calculate_timeout(self) -> float:
try:
return QUIT_TIMEOUT - (time.time() - get_last_serve_time())
except TypeError: # get_last_serve_time() may return None
return 0.0
def get_tunnel_for_ip(self, ip: str): def get_tunnel_for_ip(self, ip: str):
if not self.local_port: if not self.local_port:

View File

@ -13,7 +13,8 @@ class Timer:
def set(self, timeout_sec: float): def set(self, timeout_sec: float):
""" """
Set a timer Set a timer
:param float timeout_sec: A fractional number of seconds to set the timeout for. :param float timeout_sec: A nonnegative floating point number expressing the number of
seconds to set the timeout for.
""" """
self._timeout_sec = timeout_sec self._timeout_sec = timeout_sec
self._start_time = time.time() self._start_time = time.time()