diff --git a/infection_monkey/network/tools.py b/infection_monkey/network/tools.py index 303b0dd8f..43dd7286c 100644 --- a/infection_monkey/network/tools.py +++ b/infection_monkey/network/tools.py @@ -8,6 +8,7 @@ DEFAULT_TIMEOUT = 10 BANNER_READ = 1024 LOG = logging.getLogger(__name__) +SLEEP_BETWEEN_POLL = 0.5 def struct_unpack_tracker(data, index, fmt): @@ -126,15 +127,24 @@ def check_tcp_ports(ip, ports, timeout=DEFAULT_TIMEOUT, get_banner=False): LOG.warning("Failed to connect to port %s, error code is %d", port, err) if len(possible_ports) != 0: - time.sleep(timeout) - sock_objects = [s[1] for s in possible_ports] - # first filter - _, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0) - for s in writeable_sockets: - try: # actual test - connected_ports_sockets.append((s.getpeername()[1], s)) - except socket.error: # bad socket, select didn't filter it properly - pass + timeout = int(round(timeout)) # clamp to integer, to avoid checking input + time_left = timeout + sockets_to_try = possible_ports[:] + connected_ports_sockets = [] + while (time_left >= 0) and len(sockets_to_try): + sock_objects = [s[1] for s in sockets_to_try] + + _, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0) + for s in writeable_sockets: + try: # actual test + connected_ports_sockets.append((s.getpeername()[1], s)) + except socket.error: # bad socket, select didn't filter it properly + pass + sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] + if sockets_to_try: + time.sleep(SLEEP_BETWEEN_POLL) + timeout -= SLEEP_BETWEEN_POLL + LOG.debug( "On host %s discovered the following ports %s" % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])))