From 0e7f171c4a1e1843883c8fcf60d7d590030baec3 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Wed, 9 Feb 2022 19:21:45 -0500 Subject: [PATCH] Agent: Use a Timer in _check_tcp_ports() to simplify logic --- monkey/infection_monkey/network/tcp_scanner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 84453cfa7..9e69d9351 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,14 +1,12 @@ import logging import select import socket -import time from itertools import zip_longest from typing import Dict, List, Set from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service - -SLEEP_BETWEEN_POLL = 0.5 +from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) @@ -71,21 +69,21 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT logger.warning("Failed to connect to port %s, error code is %d", port, err) if len(possible_ports) != 0: - timeout = int(round(timeout)) # clamp to integer, to avoid checking input sockets_to_try = possible_ports.copy() - while (timeout >= 0) and sockets_to_try: + + timer = Timer() + timer.set(timeout) + + while (not timer.is_expired()) and sockets_to_try: sock_objects = [s[1] for s in sockets_to_try] - _, writeable_sockets, _ = select.select([], sock_objects, [], 0) + _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining) for s in writeable_sockets: try: # actual test connected_ports_sockets.append((s.getpeername()[1], s)) except socket.error: # bad socket, select didn't filter it properly pass sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] - if sockets_to_try: - time.sleep(SLEEP_BETWEEN_POLL) - timeout -= SLEEP_BETWEEN_POLL logger.debug( "On host %s discovered the following ports %s" @@ -95,7 +93,7 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT banners = [] if len(connected_ports_sockets) != 0: readable_sockets, _, _ = select.select( - [s[1] for s in connected_ports_sockets], [], [], 0 + [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining ) # read first BANNER_READ bytes. We ignore errors because service might not send a # decodable byte string.