Agent: Use a Timer in _check_tcp_ports() to simplify logic

This commit is contained in:
Mike Salvatore 2022-02-09 19:21:45 -05:00 committed by Ilija Lazoroski
parent e981ead150
commit 0e7f171c4a
1 changed files with 8 additions and 10 deletions

View File

@ -1,14 +1,12 @@
import logging import logging
import select import select
import socket import socket
import time
from itertools import zip_longest from itertools import zip_longest
from typing import Dict, List, Set from typing import Dict, List, Set
from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
from infection_monkey.utils.timer import Timer
SLEEP_BETWEEN_POLL = 0.5
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,21 +69,21 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
logger.warning("Failed to connect to port %s, error code is %d", port, err) logger.warning("Failed to connect to port %s, error code is %d", port, err)
if len(possible_ports) != 0: if len(possible_ports) != 0:
timeout = int(round(timeout)) # clamp to integer, to avoid checking input
sockets_to_try = possible_ports.copy() sockets_to_try = possible_ports.copy()
while (timeout >= 0) and sockets_to_try:
timer = Timer()
timer.set(timeout)
while (not timer.is_expired()) and sockets_to_try:
sock_objects = [s[1] for s in sockets_to_try] sock_objects = [s[1] for s in sockets_to_try]
_, writeable_sockets, _ = select.select([], sock_objects, [], 0) _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining)
for s in writeable_sockets: for s in writeable_sockets:
try: # actual test try: # actual test
connected_ports_sockets.append((s.getpeername()[1], s)) connected_ports_sockets.append((s.getpeername()[1], s))
except socket.error: # bad socket, select didn't filter it properly except socket.error: # bad socket, select didn't filter it properly
pass pass
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
if sockets_to_try:
time.sleep(SLEEP_BETWEEN_POLL)
timeout -= SLEEP_BETWEEN_POLL
logger.debug( logger.debug(
"On host %s discovered the following ports %s" "On host %s discovered the following ports %s"
@ -95,7 +93,7 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
banners = [] banners = []
if len(connected_ports_sockets) != 0: if len(connected_ports_sockets) != 0:
readable_sockets, _, _ = select.select( readable_sockets, _, _ = select.select(
[s[1] for s in connected_ports_sockets], [], [], 0 [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining
) )
# read first BANNER_READ bytes. We ignore errors because service might not send a # read first BANNER_READ bytes. We ignore errors because service might not send a
# decodable byte string. # decodable byte string.