forked from p15670423/monkey
Merge pull request #1709 from guardicore/1601-fix-check-tcp-ports-bugs
Minor changes to TCP scanning
This commit is contained in:
commit
5a8c072d6a
|
@ -2,126 +2,145 @@ import logging
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from itertools import zip_longest
|
from typing import Iterable, Mapping, Tuple
|
||||||
from typing import Dict, List, Set
|
|
||||||
|
|
||||||
from infection_monkey.i_puppet import PortScanData, PortStatus
|
from infection_monkey.i_puppet import PortScanData, PortStatus
|
||||||
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
SLEEP_BETWEEN_POLL = 0.5
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
POLL_INTERVAL = 0.5
|
||||||
def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]:
|
|
||||||
ports_scan = {}
|
|
||||||
|
|
||||||
open_ports, banners = _check_tcp_ports(host, ports, timeout)
|
|
||||||
open_ports = set(open_ports)
|
|
||||||
|
|
||||||
for port, banner in zip_longest(ports, banners, fillvalue=None):
|
|
||||||
ports_scan[port] = _build_port_scan_data(port, open_ports, banner)
|
|
||||||
|
|
||||||
return ports_scan
|
|
||||||
|
|
||||||
|
|
||||||
def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData:
|
def scan_tcp_ports(
|
||||||
if port in open_ports:
|
host: str, ports_to_scan: Iterable[int], timeout: float
|
||||||
service = tcp_port_to_service(port)
|
) -> Mapping[int, PortScanData]:
|
||||||
return PortScanData(port, PortStatus.OPEN, banner, service)
|
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
|
||||||
else:
|
|
||||||
return _get_closed_port_data(port)
|
return _build_port_scan_data(ports_to_scan, open_ports)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_port_scan_data(
|
||||||
|
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
|
||||||
|
) -> Mapping[int, PortScanData]:
|
||||||
|
port_scan_data = {}
|
||||||
|
for port in ports_to_scan:
|
||||||
|
if port in open_ports:
|
||||||
|
service = tcp_port_to_service(port)
|
||||||
|
banner = open_ports[port]
|
||||||
|
|
||||||
|
port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service)
|
||||||
|
else:
|
||||||
|
port_scan_data[port] = _get_closed_port_data(port)
|
||||||
|
|
||||||
|
return port_scan_data
|
||||||
|
|
||||||
|
|
||||||
def _get_closed_port_data(port: int) -> PortScanData:
|
def _get_closed_port_data(port: int) -> PortScanData:
|
||||||
return PortScanData(port, PortStatus.CLOSED, None, None)
|
return PortScanData(port, PortStatus.CLOSED, None, None)
|
||||||
|
|
||||||
|
|
||||||
def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT):
|
def _check_tcp_ports(
|
||||||
|
ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT
|
||||||
|
) -> Mapping[int, str]:
|
||||||
"""
|
"""
|
||||||
Checks whether any of the given ports are open on a target IP.
|
Checks whether any of the given ports are open on a target IP.
|
||||||
:param ip: IP of host to attack
|
:param ip: IP of host to attack
|
||||||
:param ports: List of ports to attack. Must not be empty.
|
:param ports_to_scan: An iterable of ports to scan. Must not be empty.
|
||||||
:param timeout: Amount of time to wait for connection
|
:param timeout: Amount of time to wait for connection
|
||||||
:return: List of open ports.
|
:return: Mapping where the key is an open port and the value is the banner
|
||||||
|
:rtype: Mapping
|
||||||
"""
|
"""
|
||||||
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))]
|
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))]
|
||||||
# CR: Don't use list comprehensions if you don't need a list
|
for s in sockets:
|
||||||
[s.setblocking(False) for s in sockets]
|
s.setblocking(False)
|
||||||
possible_ports = []
|
|
||||||
connected_ports_sockets = []
|
possible_ports = set()
|
||||||
|
connected_ports = set()
|
||||||
|
open_ports = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports)))
|
logger.debug(
|
||||||
for sock, port in zip(sockets, ports):
|
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
|
||||||
|
)
|
||||||
|
for sock, port in zip(sockets, ports_to_scan):
|
||||||
err = sock.connect_ex((ip, port))
|
err = sock.connect_ex((ip, port))
|
||||||
if err == 0: # immediate connect
|
if err == 0: # immediate connect
|
||||||
connected_ports_sockets.append((port, sock))
|
connected_ports.add((port, sock))
|
||||||
possible_ports.append((port, sock))
|
possible_ports.add((port, sock))
|
||||||
continue
|
elif err == 10035: # WSAEWOULDBLOCK is valid.
|
||||||
# BUG: I don't think a socket will ever connect successfully if this error is raised.
|
# https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect
|
||||||
# From the documentation: "Resource temporarily unavailable... It is a nonfatal
|
# says, "Use the select function to determine the completion of the connection
|
||||||
# error, **and the operation should be retried later**." (emphasis mine). If the
|
# request by checking to see if the socket is writable," which is being done below.
|
||||||
# operation is not retried later, I don't see the point in appending this to
|
possible_ports.add((port, sock))
|
||||||
# possible_ports.
|
elif err == 115: # EINPROGRESS 115 /* Operation now in progress */
|
||||||
if err == 10035: # WSAEWOULDBLOCK is valid, see
|
possible_ports.add((port, sock))
|
||||||
# https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396
|
else:
|
||||||
possible_ports.append((port, sock))
|
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
||||||
continue
|
|
||||||
if err == 115: # EINPROGRESS 115 /* Operation now in progress */
|
|
||||||
possible_ports.append((port, sock))
|
|
||||||
continue
|
|
||||||
logger.warning("Failed to connect to port %s, error code is %d", port, err)
|
|
||||||
|
|
||||||
if len(possible_ports) != 0:
|
if len(possible_ports) != 0:
|
||||||
timeout = int(round(timeout)) # clamp to integer, to avoid checking input
|
sockets_to_try = possible_ports.copy()
|
||||||
sockets_to_try = possible_ports[:]
|
|
||||||
# BUG: If any sockets were added to connected_ports_sockets on line 94, this would
|
timer = Timer()
|
||||||
# remove them.
|
timer.set(timeout)
|
||||||
connected_ports_sockets = []
|
|
||||||
while (timeout >= 0) and sockets_to_try:
|
while (not timer.is_expired()) and sockets_to_try:
|
||||||
|
# The call to select() may return sockets that are writeable but not actually
|
||||||
|
# connected. Adding this sleep prevents excessive looping.
|
||||||
|
time.sleep(min(POLL_INTERVAL, timer.time_remaining))
|
||||||
|
|
||||||
sock_objects = [s[1] for s in sockets_to_try]
|
sock_objects = [s[1] for s in sockets_to_try]
|
||||||
|
|
||||||
# BUG: Since timeout is 0, this could block indefinitely
|
_, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining)
|
||||||
_, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0)
|
|
||||||
for s in writeable_sockets:
|
for s in writeable_sockets:
|
||||||
try: # actual test
|
try: # actual test
|
||||||
connected_ports_sockets.append((s.getpeername()[1], s))
|
connected_ports.add((s.getpeername()[1], s))
|
||||||
except socket.error: # bad socket, select didn't filter it properly
|
except socket.error: # bad socket, select didn't filter it properly
|
||||||
pass
|
pass
|
||||||
sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets]
|
|
||||||
if sockets_to_try:
|
sockets_to_try = sockets_to_try - connected_ports
|
||||||
time.sleep(SLEEP_BETWEEN_POLL)
|
|
||||||
timeout -= SLEEP_BETWEEN_POLL
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"On host %s discovered the following ports %s"
|
"On host %s discovered the following ports %s"
|
||||||
% (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets]))
|
% (str(ip), ",".join([str(s[0]) for s in connected_ports]))
|
||||||
)
|
)
|
||||||
banners = []
|
|
||||||
if len(connected_ports_sockets) != 0:
|
open_ports = {port: "" for port, _ in connected_ports}
|
||||||
|
if len(connected_ports) != 0:
|
||||||
readable_sockets, _, _ = select.select(
|
readable_sockets, _, _ = select.select(
|
||||||
[s[1] for s in connected_ports_sockets], [], [], 0
|
[s[1] for s in connected_ports], [], [], timer.time_remaining
|
||||||
)
|
)
|
||||||
# read first BANNER_READ bytes. We ignore errors because service might not send a
|
# read first BANNER_READ bytes. We ignore errors because service might not send a
|
||||||
# decodable byte string.
|
# decodable byte string.
|
||||||
# CR: Because of how black formats this, it is difficult to parse. Refactor to be
|
for port, sock in connected_ports:
|
||||||
# easier to read.
|
if sock in readable_sockets:
|
||||||
|
open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore")
|
||||||
# TODO: Rework the return of this function. Consider using dictionary
|
else:
|
||||||
banners = [
|
open_ports[port] = ""
|
||||||
sock.recv(BANNER_READ).decode(errors="ignore")
|
|
||||||
if sock in readable_sockets
|
|
||||||
else ""
|
|
||||||
for port, sock in connected_ports_sockets
|
|
||||||
]
|
|
||||||
pass
|
|
||||||
# try to cleanup
|
|
||||||
# CR: Evaluate whether or not we should call shutdown() before close() on each socket.
|
|
||||||
[s[1].close() for s in possible_ports]
|
|
||||||
return [port for port, sock in connected_ports_sockets], banners
|
|
||||||
else:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
except socket.error as exc:
|
except socket.error as exc:
|
||||||
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
|
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
|
||||||
return [], []
|
|
||||||
|
_clean_up_sockets(possible_ports, connected_ports)
|
||||||
|
|
||||||
|
return open_ports
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_up_sockets(
|
||||||
|
possible_ports: Iterable[Tuple[int, socket.socket]],
|
||||||
|
connected_ports_sockets: Iterable[Tuple[int, socket.socket]],
|
||||||
|
):
|
||||||
|
# Only call shutdown() on sockets we know to be connected
|
||||||
|
for port, s in connected_ports_sockets:
|
||||||
|
try:
|
||||||
|
s.shutdown(socket.SHUT_RDWR)
|
||||||
|
except socket.error as exc:
|
||||||
|
logger.warning(f"Error occurred while shutting down socket on port {port}: {exc}")
|
||||||
|
|
||||||
|
# Call close() for all sockets
|
||||||
|
for port, s in possible_ports:
|
||||||
|
try:
|
||||||
|
s.close()
|
||||||
|
except socket.error as exc:
|
||||||
|
logger.warning(f"Error occurred while closing socket on port {port}: {exc}")
|
||||||
|
|
|
@ -25,7 +25,18 @@ class Timer:
|
||||||
TIMEOUT_SEC, False otherwise
|
TIMEOUT_SEC, False otherwise
|
||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
return (time.time() - self._start_time) >= self._timeout_sec
|
return self.time_remaining == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_remaining(self) -> float:
|
||||||
|
"""
|
||||||
|
Return the amount of time remaining until the timer expires.
|
||||||
|
:return: The number of seconds until the timer expires. If the timer is expired, this
|
||||||
|
function returns 0 (it will never return a negative number).
|
||||||
|
:rtype: float
|
||||||
|
"""
|
||||||
|
time_remaining = self._timeout_sec - (time.time() - self._start_time)
|
||||||
|
return max(time_remaining, 0)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.i_puppet import PortStatus
|
||||||
|
from infection_monkey.network import scan_tcp_ports
|
||||||
|
|
||||||
|
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
|
||||||
|
|
||||||
|
OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_check_tcp_ports(monkeypatch, open_ports_data):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"infection_monkey.network.tcp_scanner._check_tcp_ports",
|
||||||
|
lambda *_: open_ports_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
|
||||||
|
def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||||
|
closed_ports = [8080, 143, 445]
|
||||||
|
|
||||||
|
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
|
||||||
|
|
||||||
|
assert len(port_scan_data) == 6
|
||||||
|
for port in open_ports_data.keys():
|
||||||
|
assert port_scan_data[port].port == port
|
||||||
|
assert port_scan_data[port].status == PortStatus.OPEN
|
||||||
|
assert port_scan_data[port].banner == open_ports_data.get(port)
|
||||||
|
|
||||||
|
for port in closed_ports:
|
||||||
|
assert port_scan_data[port].port == port
|
||||||
|
assert port_scan_data[port].status == PortStatus.CLOSED
|
||||||
|
assert port_scan_data[port].banner is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("open_ports_data", [{}])
|
||||||
|
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||||
|
|
||||||
|
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
|
||||||
|
|
||||||
|
assert len(port_scan_data) == 6
|
||||||
|
for port in open_ports_data:
|
||||||
|
assert port_scan_data[port].port == port
|
||||||
|
assert port_scan_data[port].status == PortStatus.CLOSED
|
||||||
|
assert port_scan_data[port].banner is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
|
||||||
|
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data):
|
||||||
|
|
||||||
|
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0)
|
||||||
|
|
||||||
|
assert len(port_scan_data) == 0
|
|
@ -67,3 +67,28 @@ def test_timer_reset(start_time, set_current_time, timeout):
|
||||||
|
|
||||||
set_current_time(start_time + (2 * timeout))
|
set_current_time(start_time + (2 * timeout))
|
||||||
assert t.is_expired()
|
assert t.is_expired()
|
||||||
|
|
||||||
|
|
||||||
|
def test_time_remaining(start_time, set_current_time):
|
||||||
|
timeout = 5
|
||||||
|
|
||||||
|
t = Timer()
|
||||||
|
t.set(timeout)
|
||||||
|
|
||||||
|
assert t.time_remaining == timeout
|
||||||
|
|
||||||
|
set_current_time(start_time + 2)
|
||||||
|
assert t.time_remaining == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_time_remaining_is_zero(start_time, set_current_time):
|
||||||
|
timeout = 5
|
||||||
|
|
||||||
|
t = Timer()
|
||||||
|
t.set(timeout)
|
||||||
|
|
||||||
|
set_current_time(start_time + timeout)
|
||||||
|
assert t.time_remaining == 0
|
||||||
|
|
||||||
|
set_current_time(start_time + (2 * timeout))
|
||||||
|
assert t.time_remaining == 0
|
||||||
|
|
Loading…
Reference in New Issue