diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 90a7b6078..d11de18fc 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -292,7 +292,12 @@ class ControlClient(object): proxy_class = HTTPConnectProxy target_addr, target_port = None, None - return tunnel.MonkeyTunnel(proxy_class, target_addr=target_addr, target_port=target_port) + return tunnel.MonkeyTunnel( + proxy_class, + keep_tunnel_open_time=WormConfiguration.keep_tunnel_open_time, + target_addr=target_addr, + target_port=target_port, + ) @staticmethod def get_pba_file(filename): diff --git a/monkey/infection_monkey/tunnel.py b/monkey/infection_monkey/tunnel.py index 4aa90e80f..a78d2bf45 100644 --- a/monkey/infection_monkey/tunnel.py +++ b/monkey/infection_monkey/tunnel.py @@ -2,7 +2,7 @@ import logging import socket import struct import time -from threading import Thread +from threading import Event, Thread from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, local_ips @@ -109,10 +109,18 @@ def quit_tunnel(address, timeout=DEFAULT_TIMEOUT): class MonkeyTunnel(Thread): - def __init__(self, proxy_class, target_addr=None, target_port=None, timeout=DEFAULT_TIMEOUT): + def __init__( + self, + proxy_class, + keep_tunnel_open_time, + target_addr=None, + target_port=None, + timeout=DEFAULT_TIMEOUT, + ): self._target_addr = target_addr self._target_port = target_port self._proxy_class = proxy_class + self._keep_tunnel_open_time = keep_tunnel_open_time self._broad_sock = None self._timeout = timeout self._stopped = False @@ -121,6 +129,7 @@ class MonkeyTunnel(Thread): super(MonkeyTunnel, self).__init__() self.daemon = True self.l_ips = None + self._wait_for_exploited_machines = Event() def run(self): self._broad_sock = _set_multicast_socket(self._timeout) @@ -195,5 +204,17 @@ class MonkeyTunnel(Thread): ip_match = get_interface_to_target(ip) return "%s:%d" % (ip_match, self.local_port) + def set_wait_for_exploited_machines(self): + self._wait_for_exploited_machines.set() + def stop(self): + self._wait_for_exploited_machine_connection() self._stopped = True + + def _wait_for_exploited_machine_connection(self): + if self._wait_for_exploited_machines.is_set(): + logger.info( + f"Waiting {self._keep_tunnel_open_time} seconds for exploited machines to connect " + "to the tunnel." + ) + time.sleep(self._keep_tunnel_open_time)