forked from p15670423/monkey
Merge pull request #1709 from guardicore/1601-fix-check-tcp-ports-bugs
Minor changes to TCP scanning
This commit is contained in:
commit
5a8c072d6a
|
@ -2,126 +2,145 @@ import logging
|
|||
import select
|
||||
import socket
|
||||
import time
|
||||
from itertools import zip_longest
|
||||
from typing import Dict, List, Set
|
||||
from typing import Iterable, Mapping, Tuple
|
||||
|
||||
from infection_monkey.i_puppet import PortScanData, PortStatus
|
||||
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
||||
|
||||
SLEEP_BETWEEN_POLL = 0.5
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]:
|
||||
ports_scan = {}
|
||||
|
||||
open_ports, banners = _check_tcp_ports(host, ports, timeout)
|
||||
open_ports = set(open_ports)
|
||||
|
||||
for port, banner in zip_longest(ports, banners, fillvalue=None):
|
||||
ports_scan[port] = _build_port_scan_data(port, open_ports, banner)
|
||||
|
||||
return ports_scan
|
||||
POLL_INTERVAL = 0.5
|
||||
|
||||
|
||||
def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData:
|
||||
if port in open_ports:
|
||||
service = tcp_port_to_service(port)
|
||||
return PortScanData(port, PortStatus.OPEN, banner, service)
|
||||
else:
|
||||
return _get_closed_port_data(port)
|
||||
def scan_tcp_ports(
|
||||
host: str, ports_to_scan: Iterable[int], timeout: float
|
||||
) -> Mapping[int, PortScanData]:
|
||||
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
|
||||
|
||||
return _build_port_scan_data(ports_to_scan, open_ports)
|
||||
|
||||
|
||||
def _build_port_scan_data(
|
||||
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
|
||||
) -> Mapping[int, PortScanData]:
|
||||
port_scan_data = {}
|
||||
for port in ports_to_scan:
|
||||
if port in open_ports:
|
||||
service = tcp_port_to_service(port)
|
||||
banner = open_ports[port]
|
||||
|
||||
port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service)
|
||||
else:
|
||||
port_scan_data[port] = _get_closed_port_data(port)
|
||||
|
||||
return port_scan_data
|
||||
|
||||
|
||||
def _get_closed_port_data(port: int) -> PortScanData:
|
||||
return PortScanData(port, PortStatus.CLOSED, None, None)
|
||||
|
||||
|
||||
def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT):
|
||||
def _check_tcp_ports(
|
||||
ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT
|
||||
) -> Mapping[int, str]:
|
||||
"""
|
||||
Checks whether any of the given ports are open on a target IP.
|
||||
:param ip: IP of host to attack
|
||||
:param ports: List of ports to attack. Must not be empty.
|
||||
:param ports_to_scan: An iterable of ports to scan. Must not be empty.
|
||||
:param timeout: Amount of time to wait for connection
|
||||
:return: List of open ports.
|
||||
:return: Mapping where the key is an open port and the value is the banner
|
||||
:rtype: Mapping
|
||||
"""
|
||||
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))]
|
||||
# CR: Don't use list comprehensions if you don't need a list
|
||||
[s.setblocking(False) for s in sockets]
|
||||
possible_ports = []
|
||||
connected_ports_sockets = []
|
||||
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))]
|
||||
for s in sockets:
|
||||
s.setblocking(False)
|
||||
|
||||
possible_ports = set()
|
||||
connected_ports = set()
|
||||
open_ports = {}
|
||||
|
||||
try:
|
||||
logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports)))
|
||||
for sock, port in zip(sockets, ports):
|
||||
logger.debug(
|
||||
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
|
||||
)
|
||||
for sock, port in zip(sockets, ports_to_scan):
|
||||
err = sock.connect_ex((ip, port))
|
||||
if err == 0: # immediate connect
|
||||
connected_ports_sockets.append((port, sock))
|
||||
possible_ports.append((port, sock))
|
||||
continue
|
||||
# BUG: I don't think a socket will ever connect successfully if this error is raised.
|
||||
# From the documentation: "Resource temporarily unavailable... It is a nonfatal
|
||||
# error, **and the operation should be retried later**." (emphasis mine). If the
|
||||
# operation is not retried later, I don't see the point in appending this to
|
||||
# possible_ports.
|
||||
if err == 10035: # WSAEWOULDBLOCK is valid, see
|
||||
# https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396
|
||||
possible_ports.append((port, sock))
|
||||
continue
|
||||
if err == 115: # EINPROGRESS 115 /* Operation now in progress */
|
||||
possible_ports.append((port, sock))
|
||||
continue
|
||||
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
||||
connected_ports.add((port, sock))
|
||||
possible_ports.add((port, sock))
|
||||
elif err == 10035: # WSAEWOULDBLOCK is valid.
|
||||
# https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect
|
||||
# says, "Use the select function to determine the completion of the connection
|
||||
# request by checking to see if the socket is writable," which is being done below.
|
||||
possible_ports.add((port, sock))
|
||||
elif err == 115: # EINPROGRESS 115 /* Operation now in progress */
|
||||
possible_ports.add((port, sock))
|
||||
else:
|
||||
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
||||
|
||||
if len(possible_ports) != 0:
|
||||
timeout = int(round(timeout)) # clamp to integer, to avoid checking input
|
||||
sockets_to_try = possible_ports[:]
|
||||
# BUG: If any sockets were added to connected_ports_sockets on line 94, this would
|
||||
# remove them.
|
||||
connected_ports_sockets = []
|
||||
while (timeout >= 0) and sockets_to_try:
|
||||
sockets_to_try = possible_ports.copy()
|
||||
|
||||
timer = Timer()
|
||||
timer.set(timeout)
|
||||
|
||||
while (not timer.is_expired()) and sockets_to_try:
|
||||
# The call to select() may return sockets that are writeable but not actually
|
||||
# connected. Adding this sleep prevents excessive looping.
|
||||
time.sleep(min(POLL_INTERVAL, timer.time_remaining))
|
||||
|
||||
sock_objects = [s[1] for s in sockets_to_try]
|
||||
|
||||
# BUG: Since timeout is 0, this could block indefinitely
|
||||
_, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0)
|
||||
_, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining)
|
||||
for s in writeable_sockets:
|
||||
try: # actual test
|
||||
connected_ports_sockets.append((s.getpeername()[1], s))
|
||||
connected_ports.add((s.getpeername()[1], s))
|
||||
except socket.error: # bad socket, select didn't filter it properly
|
||||
pass
|
||||
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
|
||||
if sockets_to_try:
|
||||
time.sleep(SLEEP_BETWEEN_POLL)
|
||||
timeout -= SLEEP_BETWEEN_POLL
|
||||
|
||||
sockets_to_try = sockets_to_try - connected_ports
|
||||
|
||||
logger.debug(
|
||||
"On host %s discovered the following ports %s"
|
||||
% (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets]))
|
||||
% (str(ip), ",".join([str(s[0]) for s in connected_ports]))
|
||||
)
|
||||
banners = []
|
||||
if len(connected_ports_sockets) != 0:
|
||||
|
||||
open_ports = {port: "" for port, _ in connected_ports}
|
||||
if len(connected_ports) != 0:
|
||||
readable_sockets, _, _ = select.select(
|
||||
[s[1] for s in connected_ports_sockets], [], [], 0
|
||||
[s[1] for s in connected_ports], [], [], timer.time_remaining
|
||||
)
|
||||
# read first BANNER_READ bytes. We ignore errors because service might not send a
|
||||
# decodable byte string.
|
||||
# CR: Because of how black formats this, it is difficult to parse. Refactor to be
|
||||
# easier to read.
|
||||
|
||||
# TODO: Rework the return of this function. Consider using dictionary
|
||||
banners = [
|
||||
sock.recv(BANNER_READ).decode(errors="ignore")
|
||||
if sock in readable_sockets
|
||||
else ""
|
||||
for port, sock in connected_ports_sockets
|
||||
]
|
||||
pass
|
||||
# try to cleanup
|
||||
# CR: Evaluate whether or not we should call shutdown() before close() on each socket.
|
||||
[s[1].close() for s in possible_ports]
|
||||
return [port for port, sock in connected_ports_sockets], banners
|
||||
else:
|
||||
return [], []
|
||||
for port, sock in connected_ports:
|
||||
if sock in readable_sockets:
|
||||
open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore")
|
||||
else:
|
||||
open_ports[port] = ""
|
||||
|
||||
except socket.error as exc:
|
||||
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
|
||||
return [], []
|
||||
|
||||
_clean_up_sockets(possible_ports, connected_ports)
|
||||
|
||||
return open_ports
|
||||
|
||||
|
||||
def _clean_up_sockets(
|
||||
possible_ports: Iterable[Tuple[int, socket.socket]],
|
||||
connected_ports_sockets: Iterable[Tuple[int, socket.socket]],
|
||||
):
|
||||
# Only call shutdown() on sockets we know to be connected
|
||||
for port, s in connected_ports_sockets:
|
||||
try:
|
||||
s.shutdown(socket.SHUT_RDWR)
|
||||
except socket.error as exc:
|
||||
logger.warning(f"Error occurred while shutting down socket on port {port}: {exc}")
|
||||
|
||||
# Call close() for all sockets
|
||||
for port, s in possible_ports:
|
||||
try:
|
||||
s.close()
|
||||
except socket.error as exc:
|
||||
logger.warning(f"Error occurred while closing socket on port {port}: {exc}")
|
||||
|
|
|
@ -25,7 +25,18 @@ class Timer:
|
|||
TIMEOUT_SEC, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
return (time.time() - self._start_time) >= self._timeout_sec
|
||||
return self.time_remaining == 0
|
||||
|
||||
@property
|
||||
def time_remaining(self) -> float:
|
||||
"""
|
||||
Return the amount of time remaining until the timer expires.
|
||||
:return: The number of seconds until the timer expires. If the timer is expired, this
|
||||
function returns 0 (it will never return a negative number).
|
||||
:rtype: float
|
||||
"""
|
||||
time_remaining = self._timeout_sec - (time.time() - self._start_time)
|
||||
return max(time_remaining, 0)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import pytest
|
||||
|
||||
from infection_monkey.i_puppet import PortStatus
|
||||
from infection_monkey.network import scan_tcp_ports
|
||||
|
||||
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
|
||||
|
||||
OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_check_tcp_ports(monkeypatch, open_ports_data):
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.tcp_scanner._check_tcp_ports",
|
||||
lambda *_: open_ports_data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
|
||||
def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||
closed_ports = [8080, 143, 445]
|
||||
|
||||
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
|
||||
|
||||
assert len(port_scan_data) == 6
|
||||
for port in open_ports_data.keys():
|
||||
assert port_scan_data[port].port == port
|
||||
assert port_scan_data[port].status == PortStatus.OPEN
|
||||
assert port_scan_data[port].banner == open_ports_data.get(port)
|
||||
|
||||
for port in closed_ports:
|
||||
assert port_scan_data[port].port == port
|
||||
assert port_scan_data[port].status == PortStatus.CLOSED
|
||||
assert port_scan_data[port].banner is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("open_ports_data", [{}])
|
||||
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||
|
||||
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
|
||||
|
||||
assert len(port_scan_data) == 6
|
||||
for port in open_ports_data:
|
||||
assert port_scan_data[port].port == port
|
||||
assert port_scan_data[port].status == PortStatus.CLOSED
|
||||
assert port_scan_data[port].banner is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
|
||||
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||
|
||||
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0)
|
||||
|
||||
assert len(port_scan_data) == 0
|
|
@ -67,3 +67,28 @@ def test_timer_reset(start_time, set_current_time, timeout):
|
|||
|
||||
set_current_time(start_time + (2 * timeout))
|
||||
assert t.is_expired()
|
||||
|
||||
|
||||
def test_time_remaining(start_time, set_current_time):
|
||||
timeout = 5
|
||||
|
||||
t = Timer()
|
||||
t.set(timeout)
|
||||
|
||||
assert t.time_remaining == timeout
|
||||
|
||||
set_current_time(start_time + 2)
|
||||
assert t.time_remaining == 3
|
||||
|
||||
|
||||
def test_time_remaining_is_zero(start_time, set_current_time):
|
||||
timeout = 5
|
||||
|
||||
t = Timer()
|
||||
t.set(timeout)
|
||||
|
||||
set_current_time(start_time + timeout)
|
||||
assert t.time_remaining == 0
|
||||
|
||||
set_current_time(start_time + (2 * timeout))
|
||||
assert t.time_remaining == 0
|
||||
|
|
Loading…
Reference in New Issue