forked from p34709852/monkey
Agent: Use a Timer in _check_tcp_ports() to simplify logic
This commit is contained in:
parent
e981ead150
commit
0e7f171c4a
|
@ -1,14 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
import time
|
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Dict, List, Set
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
from infection_monkey.i_puppet import PortScanData, PortStatus
|
from infection_monkey.i_puppet import PortScanData, PortStatus
|
||||||
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
SLEEP_BETWEEN_POLL = 0.5
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -71,21 +69,21 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
|
||||||
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
||||||
|
|
||||||
if len(possible_ports) != 0:
|
if len(possible_ports) != 0:
|
||||||
timeout = int(round(timeout)) # clamp to integer, to avoid checking input
|
|
||||||
sockets_to_try = possible_ports.copy()
|
sockets_to_try = possible_ports.copy()
|
||||||
while (timeout >= 0) and sockets_to_try:
|
|
||||||
|
timer = Timer()
|
||||||
|
timer.set(timeout)
|
||||||
|
|
||||||
|
while (not timer.is_expired()) and sockets_to_try:
|
||||||
sock_objects = [s[1] for s in sockets_to_try]
|
sock_objects = [s[1] for s in sockets_to_try]
|
||||||
|
|
||||||
_, writeable_sockets, _ = select.select([], sock_objects, [], 0)
|
_, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining)
|
||||||
for s in writeable_sockets:
|
for s in writeable_sockets:
|
||||||
try: # actual test
|
try: # actual test
|
||||||
connected_ports_sockets.append((s.getpeername()[1], s))
|
connected_ports_sockets.append((s.getpeername()[1], s))
|
||||||
except socket.error: # bad socket, select didn't filter it properly
|
except socket.error: # bad socket, select didn't filter it properly
|
||||||
pass
|
pass
|
||||||
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
|
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
|
||||||
if sockets_to_try:
|
|
||||||
time.sleep(SLEEP_BETWEEN_POLL)
|
|
||||||
timeout -= SLEEP_BETWEEN_POLL
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"On host %s discovered the following ports %s"
|
"On host %s discovered the following ports %s"
|
||||||
|
@ -95,7 +93,7 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
|
||||||
banners = []
|
banners = []
|
||||||
if len(connected_ports_sockets) != 0:
|
if len(connected_ports_sockets) != 0:
|
||||||
readable_sockets, _, _ = select.select(
|
readable_sockets, _, _ = select.select(
|
||||||
[s[1] for s in connected_ports_sockets], [], [], 0
|
[s[1] for s in connected_ports_sockets], [], [], timer.time_remaining
|
||||||
)
|
)
|
||||||
# read first BANNER_READ bytes. We ignore errors because service might not send a
|
# read first BANNER_READ bytes. We ignore errors because service might not send a
|
||||||
# decodable byte string.
|
# decodable byte string.
|
||||||
|
|
Loading…
Reference in New Issue