Merge pull request #1709 from guardicore/1601-fix-check-tcp-ports-bugs

Minor changes to TCP scanning
This commit is contained in:
Mike Salvatore 2022-02-10 12:23:29 -05:00 committed by GitHub
commit 5a8c072d6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 83 deletions

View File

@ -2,126 +2,145 @@ import logging
import select
import socket
import time
from itertools import zip_longest
from typing import Dict, List, Set
from typing import Iterable, Mapping, Tuple
from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
SLEEP_BETWEEN_POLL = 0.5
from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__)
def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]:
ports_scan = {}
open_ports, banners = _check_tcp_ports(host, ports, timeout)
open_ports = set(open_ports)
for port, banner in zip_longest(ports, banners, fillvalue=None):
ports_scan[port] = _build_port_scan_data(port, open_ports, banner)
return ports_scan
POLL_INTERVAL = 0.5
def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData:
if port in open_ports:
service = tcp_port_to_service(port)
return PortScanData(port, PortStatus.OPEN, banner, service)
else:
return _get_closed_port_data(port)
def scan_tcp_ports(
host: str, ports_to_scan: Iterable[int], timeout: float
) -> Mapping[int, PortScanData]:
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
return _build_port_scan_data(ports_to_scan, open_ports)
def _build_port_scan_data(
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
) -> Mapping[int, PortScanData]:
port_scan_data = {}
for port in ports_to_scan:
if port in open_ports:
service = tcp_port_to_service(port)
banner = open_ports[port]
port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service)
else:
port_scan_data[port] = _get_closed_port_data(port)
return port_scan_data
def _get_closed_port_data(port: int) -> PortScanData:
return PortScanData(port, PortStatus.CLOSED, None, None)
def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT):
def _check_tcp_ports(
ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT
) -> Mapping[int, str]:
"""
Checks whether any of the given ports are open on a target IP.
:param ip: IP of host to attack
:param ports: List of ports to attack. Must not be empty.
:param ports_to_scan: An iterable of ports to scan. Must not be empty.
:param timeout: Amount of time to wait for connection
:return: List of open ports.
:return: Mapping where the key is an open port and the value is the banner
:rtype: Mapping
"""
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))]
# CR: Don't use list comprehensions if you don't need a list
[s.setblocking(False) for s in sockets]
possible_ports = []
connected_ports_sockets = []
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))]
for s in sockets:
s.setblocking(False)
possible_ports = set()
connected_ports = set()
open_ports = {}
try:
logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports)))
for sock, port in zip(sockets, ports):
logger.debug(
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
)
for sock, port in zip(sockets, ports_to_scan):
err = sock.connect_ex((ip, port))
if err == 0: # immediate connect
connected_ports_sockets.append((port, sock))
possible_ports.append((port, sock))
continue
# BUG: I don't think a socket will ever connect successfully if this error is raised.
# From the documentation: "Resource temporarily unavailable... It is a nonfatal
# error, **and the operation should be retried later**." (emphasis mine). If the
# operation is not retried later, I don't see the point in appending this to
# possible_ports.
if err == 10035: # WSAEWOULDBLOCK is valid, see
# https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396
possible_ports.append((port, sock))
continue
if err == 115: # EINPROGRESS 115 /* Operation now in progress */
possible_ports.append((port, sock))
continue
logger.warning("Failed to connect to port %s, error code is %d", port, err)
connected_ports.add((port, sock))
possible_ports.add((port, sock))
elif err == 10035: # WSAEWOULDBLOCK is valid.
# https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect
# says, "Use the select function to determine the completion of the connection
# request by checking to see if the socket is writable," which is being done below.
possible_ports.add((port, sock))
elif err == 115: # EINPROGRESS 115 /* Operation now in progress */
possible_ports.add((port, sock))
else:
logger.warning("Failed to connect to port %s, error code is %d", port, err)
if len(possible_ports) != 0:
timeout = int(round(timeout)) # clamp to integer, to avoid checking input
sockets_to_try = possible_ports[:]
# BUG: If any sockets were added to connected_ports_sockets on line 94, this would
# remove them.
connected_ports_sockets = []
while (timeout >= 0) and sockets_to_try:
sockets_to_try = possible_ports.copy()
timer = Timer()
timer.set(timeout)
while (not timer.is_expired()) and sockets_to_try:
# The call to select() may return sockets that are writeable but not actually
# connected. Adding this sleep prevents excessive looping.
time.sleep(min(POLL_INTERVAL, timer.time_remaining))
sock_objects = [s[1] for s in sockets_to_try]
# BUG: Since timeout is 0, this could block indefinitely
_, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0)
_, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining)
for s in writeable_sockets:
try: # actual test
connected_ports_sockets.append((s.getpeername()[1], s))
connected_ports.add((s.getpeername()[1], s))
except socket.error: # bad socket, select didn't filter it properly
pass
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
if sockets_to_try:
time.sleep(SLEEP_BETWEEN_POLL)
timeout -= SLEEP_BETWEEN_POLL
sockets_to_try = sockets_to_try - connected_ports
logger.debug(
"On host %s discovered the following ports %s"
% (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets]))
% (str(ip), ",".join([str(s[0]) for s in connected_ports]))
)
banners = []
if len(connected_ports_sockets) != 0:
open_ports = {port: "" for port, _ in connected_ports}
if len(connected_ports) != 0:
readable_sockets, _, _ = select.select(
[s[1] for s in connected_ports_sockets], [], [], 0
[s[1] for s in connected_ports], [], [], timer.time_remaining
)
# read first BANNER_READ bytes. We ignore errors because service might not send a
# decodable byte string.
# CR: Because of how black formats this, it is difficult to parse. Refactor to be
# easier to read.
# TODO: Rework the return of this function. Consider using dictionary
banners = [
sock.recv(BANNER_READ).decode(errors="ignore")
if sock in readable_sockets
else ""
for port, sock in connected_ports_sockets
]
pass
# try to cleanup
# CR: Evaluate whether or not we should call shutdown() before close() on each socket.
[s[1].close() for s in possible_ports]
return [port for port, sock in connected_ports_sockets], banners
else:
return [], []
for port, sock in connected_ports:
if sock in readable_sockets:
open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore")
else:
open_ports[port] = ""
except socket.error as exc:
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
return [], []
_clean_up_sockets(possible_ports, connected_ports)
return open_ports
def _clean_up_sockets(
possible_ports: Iterable[Tuple[int, socket.socket]],
connected_ports_sockets: Iterable[Tuple[int, socket.socket]],
):
# Only call shutdown() on sockets we know to be connected
for port, s in connected_ports_sockets:
try:
s.shutdown(socket.SHUT_RDWR)
except socket.error as exc:
logger.warning(f"Error occurred while shutting down socket on port {port}: {exc}")
# Call close() for all sockets
for port, s in possible_ports:
try:
s.close()
except socket.error as exc:
logger.warning(f"Error occurred while closing socket on port {port}: {exc}")

View File

@ -25,7 +25,18 @@ class Timer:
TIMEOUT_SEC, False otherwise
:rtype: bool
"""
return (time.time() - self._start_time) >= self._timeout_sec
return self.time_remaining == 0
@property
def time_remaining(self) -> float:
"""
Return the amount of time remaining until the timer expires.
:return: The number of seconds until the timer expires. If the timer is expired, this
function returns 0 (it will never return a negative number).
:rtype: float
"""
time_remaining = self._timeout_sec - (time.time() - self._start_time)
return max(time_remaining, 0)
def reset(self):
"""

View File

@ -0,0 +1,54 @@
import pytest
from infection_monkey.i_puppet import PortStatus
from infection_monkey.network import scan_tcp_ports
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"}
@pytest.fixture
def patch_check_tcp_ports(monkeypatch, open_ports_data):
monkeypatch.setattr(
"infection_monkey.network.tcp_scanner._check_tcp_ports",
lambda *_: open_ports_data,
)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
closed_ports = [8080, 143, 445]
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
assert len(port_scan_data) == 6
for port in open_ports_data.keys():
assert port_scan_data[port].port == port
assert port_scan_data[port].status == PortStatus.OPEN
assert port_scan_data[port].banner == open_ports_data.get(port)
for port in closed_ports:
assert port_scan_data[port].port == port
assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None
@pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
assert len(port_scan_data) == 6
for port in open_ports_data:
assert port_scan_data[port].port == port
assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0)
assert len(port_scan_data) == 0

View File

@ -67,3 +67,28 @@ def test_timer_reset(start_time, set_current_time, timeout):
set_current_time(start_time + (2 * timeout))
assert t.is_expired()
def test_time_remaining(start_time, set_current_time):
timeout = 5
t = Timer()
t.set(timeout)
assert t.time_remaining == timeout
set_current_time(start_time + 2)
assert t.time_remaining == 3
def test_time_remaining_is_zero(start_time, set_current_time):
timeout = 5
t = Timer()
t.set(timeout)
set_current_time(start_time + timeout)
assert t.time_remaining == 0
set_current_time(start_time + (2 * timeout))
assert t.time_remaining == 0