forked from p15670423/monkey
Merge pull request #1856 from guardicore/1826-timer-todos
Resolve `Timer` TODOs
This commit is contained in:
commit
7b3b17251a
monkey/infection_monkey
|
@ -1,11 +1,11 @@
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
|
from infection_monkey.telemetry.i_batchable_telem import IBatchableTelem
|
||||||
from infection_monkey.telemetry.i_telem import ITelem
|
from infection_monkey.telemetry.i_telem import ITelem
|
||||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
DEFAULT_PERIOD = 5
|
DEFAULT_PERIOD = 5
|
||||||
WAKES_PER_PERIOD = 4
|
WAKES_PER_PERIOD = 4
|
||||||
|
@ -40,8 +40,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
|
||||||
self._period = period
|
self._period = period
|
||||||
|
|
||||||
self._should_run_batch_thread = True
|
self._should_run_batch_thread = True
|
||||||
# TODO: Replace with infection_monkey.utils.timer.Timer
|
|
||||||
self._last_sent_time = time.time()
|
|
||||||
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
|
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
|
||||||
|
|
||||||
self._manage_telemetry_batches_thread = None
|
self._manage_telemetry_batches_thread = None
|
||||||
|
@ -59,21 +57,20 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
|
||||||
self._manage_telemetry_batches_thread = None
|
self._manage_telemetry_batches_thread = None
|
||||||
|
|
||||||
def _manage_telemetry_batches(self):
|
def _manage_telemetry_batches(self):
|
||||||
self._reset()
|
timer = Timer()
|
||||||
|
timer.set(self._period)
|
||||||
|
self._telemetry_batches = {}
|
||||||
|
|
||||||
while self._should_run_batch_thread:
|
while self._should_run_batch_thread:
|
||||||
self._process_next_telemetry()
|
self._process_next_telemetry()
|
||||||
|
|
||||||
if self._period_elapsed():
|
if timer.is_expired():
|
||||||
self._send_telemetry_batches()
|
self._send_telemetry_batches()
|
||||||
self._reset()
|
timer.reset()
|
||||||
|
self._telemetry_batches = {}
|
||||||
|
|
||||||
self._send_remaining_telemetry_batches()
|
self._send_remaining_telemetry_batches()
|
||||||
|
|
||||||
def _reset(self):
|
|
||||||
self._last_sent_time = time.time()
|
|
||||||
self._telemetry_batches = {}
|
|
||||||
|
|
||||||
def _process_next_telemetry(self):
|
def _process_next_telemetry(self):
|
||||||
try:
|
try:
|
||||||
telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD)
|
telemetry = self._queue.get(block=True, timeout=self._period / WAKES_PER_PERIOD)
|
||||||
|
@ -93,9 +90,6 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
|
||||||
else:
|
else:
|
||||||
self._telemetry_batches[telem_category] = new_telemetry
|
self._telemetry_batches[telem_category] = new_telemetry
|
||||||
|
|
||||||
def _period_elapsed(self):
|
|
||||||
return (time.time() - self._last_sent_time) > self._period
|
|
||||||
|
|
||||||
def _send_remaining_telemetry_batches(self):
|
def _send_remaining_telemetry_batches(self):
|
||||||
while not self._queue.empty():
|
while not self._queue.empty():
|
||||||
self._process_next_telemetry()
|
self._process_next_telemetry()
|
||||||
|
|
|
@ -8,6 +8,7 @@ from infection_monkey.network.firewall import app as firewall
|
||||||
from infection_monkey.network.info import get_free_tcp_port, local_ips
|
from infection_monkey.network.info import get_free_tcp_port, local_ips
|
||||||
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
|
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
|
||||||
from infection_monkey.transport.base import get_last_serve_time
|
from infection_monkey.transport.base import get_last_serve_time
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -181,8 +182,9 @@ class MonkeyTunnel(Thread):
|
||||||
|
|
||||||
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
|
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
|
||||||
# QUIT_TIMEOUT seconds
|
# QUIT_TIMEOUT seconds
|
||||||
# TODO: Replace with infection_monkey.utils.timer.Timer
|
timer = Timer()
|
||||||
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
|
timer.set(self._calculate_timeout())
|
||||||
|
while self._clients and not timer.is_expired():
|
||||||
try:
|
try:
|
||||||
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
||||||
if b"-" == search:
|
if b"-" == search:
|
||||||
|
@ -191,11 +193,19 @@ class MonkeyTunnel(Thread):
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
timer.set(self._calculate_timeout())
|
||||||
|
|
||||||
logger.info("Closing tunnel")
|
logger.info("Closing tunnel")
|
||||||
self._broad_sock.close()
|
self._broad_sock.close()
|
||||||
proxy.stop()
|
proxy.stop()
|
||||||
proxy.join()
|
proxy.join()
|
||||||
|
|
||||||
|
def _calculate_timeout(self) -> float:
|
||||||
|
try:
|
||||||
|
return QUIT_TIMEOUT - (time.time() - get_last_serve_time())
|
||||||
|
except TypeError: # get_last_serve_time() may return None
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def get_tunnel_for_ip(self, ip: str):
|
def get_tunnel_for_ip(self, ip: str):
|
||||||
|
|
||||||
if not self.local_port:
|
if not self.local_port:
|
||||||
|
|
|
@ -13,7 +13,8 @@ class Timer:
|
||||||
def set(self, timeout_sec: float):
|
def set(self, timeout_sec: float):
|
||||||
"""
|
"""
|
||||||
Set a timer
|
Set a timer
|
||||||
:param float timeout_sec: A fractional number of seconds to set the timeout for.
|
:param float timeout_sec: A nonnegative floating point number expressing the number of
|
||||||
|
seconds to set the timeout for.
|
||||||
"""
|
"""
|
||||||
self._timeout_sec = timeout_sec
|
self._timeout_sec = timeout_sec
|
||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
|
|
Loading…
Reference in New Issue