diff --git a/monkey/infection_monkey/tunnel.py b/monkey/infection_monkey/tunnel.py index b0f778534..55349dfd6 100644 --- a/monkey/infection_monkey/tunnel.py +++ b/monkey/infection_monkey/tunnel.py @@ -8,6 +8,7 @@ from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.tools import check_tcp_port, get_interface_to_target from infection_monkey.transport.base import get_last_serve_time +from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) @@ -181,8 +182,9 @@ class MonkeyTunnel(Thread): # wait till all of the tunnel clients has been disconnected, or no one used the tunnel in # QUIT_TIMEOUT seconds - # TODO: Replace with infection_monkey.utils.timer.Timer - while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT): + timer = Timer() + timer.set(self._calculate_timeout()) + while self._clients and not timer.is_expired(): try: search, address = self._broad_sock.recvfrom(BUFFER_READ) if b"-" == search: @@ -191,11 +193,16 @@ class MonkeyTunnel(Thread): except socket.timeout: continue + timer.set(self._calculate_timeout()) + logger.info("Closing tunnel") self._broad_sock.close() proxy.stop() proxy.join() + def _calculate_timeout(self): + return QUIT_TIMEOUT - (time.time() - get_last_serve_time()) + def get_tunnel_for_ip(self, ip: str): if not self.local_port: