diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 3d3f66a14..6fdded293 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -2,126 +2,145 @@ import logging import select import socket import time -from itertools import zip_longest -from typing import Dict, List, Set +from typing import Iterable, Mapping, Tuple from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service - -SLEEP_BETWEEN_POLL = 0.5 +from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) - -def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]: - ports_scan = {} - - open_ports, banners = _check_tcp_ports(host, ports, timeout) - open_ports = set(open_ports) - - for port, banner in zip_longest(ports, banners, fillvalue=None): - ports_scan[port] = _build_port_scan_data(port, open_ports, banner) - - return ports_scan +POLL_INTERVAL = 0.5 -def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData: - if port in open_ports: - service = tcp_port_to_service(port) - return PortScanData(port, PortStatus.OPEN, banner, service) - else: - return _get_closed_port_data(port) +def scan_tcp_ports( + host: str, ports_to_scan: Iterable[int], timeout: float +) -> Mapping[int, PortScanData]: + open_ports = _check_tcp_ports(host, ports_to_scan, timeout) + + return _build_port_scan_data(ports_to_scan, open_ports) + + +def _build_port_scan_data( + ports_to_scan: Iterable[int], open_ports: Mapping[int, str] +) -> Mapping[int, PortScanData]: + port_scan_data = {} + for port in ports_to_scan: + if port in open_ports: + service = tcp_port_to_service(port) + banner = open_ports[port] + + port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service) + else: + port_scan_data[port] = _get_closed_port_data(port) + + return port_scan_data def _get_closed_port_data(port: int) -> PortScanData: return PortScanData(port, PortStatus.CLOSED, None, None) -def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT): +def _check_tcp_ports( + ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT +) -> Mapping[int, str]: """ Checks whether any of the given ports are open on a target IP. :param ip: IP of host to attack - :param ports: List of ports to attack. Must not be empty. + :param ports_to_scan: An iterable of ports to scan. Must not be empty. :param timeout: Amount of time to wait for connection - :return: List of open ports. + :return: Mapping where the key is an open port and the value is the banner + :rtype: Mapping """ - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))] - # CR: Don't use list comprehensions if you don't need a list - [s.setblocking(False) for s in sockets] - possible_ports = [] - connected_ports_sockets = [] + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))] + for s in sockets: + s.setblocking(False) + + possible_ports = set() + connected_ports = set() + open_ports = {} + try: - logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports))) - for sock, port in zip(sockets, ports): + logger.debug( + "Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan)) + ) + for sock, port in zip(sockets, ports_to_scan): err = sock.connect_ex((ip, port)) if err == 0: # immediate connect - connected_ports_sockets.append((port, sock)) - possible_ports.append((port, sock)) - continue - # BUG: I don't think a socket will ever connect successfully if this error is raised. - # From the documentation: "Resource temporarily unavailable... It is a nonfatal - # error, **and the operation should be retried later**." (emphasis mine). If the - # operation is not retried later, I don't see the point in appending this to - # possible_ports. - if err == 10035: # WSAEWOULDBLOCK is valid, see - # https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396 - possible_ports.append((port, sock)) - continue - if err == 115: # EINPROGRESS 115 /* Operation now in progress */ - possible_ports.append((port, sock)) - continue - logger.warning("Failed to connect to port %s, error code is %d", port, err) + connected_ports.add((port, sock)) + possible_ports.add((port, sock)) + elif err == 10035: # WSAEWOULDBLOCK is valid. + # https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect + # says, "Use the select function to determine the completion of the connection + # request by checking to see if the socket is writable," which is being done below. + possible_ports.add((port, sock)) + elif err == 115: # EINPROGRESS 115 /* Operation now in progress */ + possible_ports.add((port, sock)) + else: + logger.warning("Failed to connect to port %s, error code is %d", port, err) if len(possible_ports) != 0: - timeout = int(round(timeout)) # clamp to integer, to avoid checking input - sockets_to_try = possible_ports[:] - # BUG: If any sockets were added to connected_ports_sockets on line 94, this would - # remove them. - connected_ports_sockets = [] - while (timeout >= 0) and sockets_to_try: + sockets_to_try = possible_ports.copy() + + timer = Timer() + timer.set(timeout) + + while (not timer.is_expired()) and sockets_to_try: + # The call to select() may return sockets that are writeable but not actually + # connected. Adding this sleep prevents excessive looping. + time.sleep(min(POLL_INTERVAL, timer.time_remaining)) + sock_objects = [s[1] for s in sockets_to_try] - # BUG: Since timeout is 0, this could block indefinitely - _, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0) + _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining) for s in writeable_sockets: try: # actual test - connected_ports_sockets.append((s.getpeername()[1], s)) + connected_ports.add((s.getpeername()[1], s)) except socket.error: # bad socket, select didn't filter it properly pass - sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] - if sockets_to_try: - time.sleep(SLEEP_BETWEEN_POLL) - timeout -= SLEEP_BETWEEN_POLL + + sockets_to_try = sockets_to_try - connected_ports logger.debug( "On host %s discovered the following ports %s" - % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])) + % (str(ip), ",".join([str(s[0]) for s in connected_ports])) ) - banners = [] - if len(connected_ports_sockets) != 0: + + open_ports = {port: "" for port, _ in connected_ports} + if len(connected_ports) != 0: readable_sockets, _, _ = select.select( - [s[1] for s in connected_ports_sockets], [], [], 0 + [s[1] for s in connected_ports], [], [], timer.time_remaining ) # read first BANNER_READ bytes. We ignore errors because service might not send a # decodable byte string. - # CR: Because of how black formats this, it is difficult to parse. Refactor to be - # easier to read. - - # TODO: Rework the return of this function. Consider using dictionary - banners = [ - sock.recv(BANNER_READ).decode(errors="ignore") - if sock in readable_sockets - else "" - for port, sock in connected_ports_sockets - ] - pass - # try to cleanup - # CR: Evaluate whether or not we should call shutdown() before close() on each socket. - [s[1].close() for s in possible_ports] - return [port for port, sock in connected_ports_sockets], banners - else: - return [], [] + for port, sock in connected_ports: + if sock in readable_sockets: + open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore") + else: + open_ports[port] = "" except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) - return [], [] + + _clean_up_sockets(possible_ports, connected_ports) + + return open_ports + + +def _clean_up_sockets( + possible_ports: Iterable[Tuple[int, socket.socket]], + connected_ports_sockets: Iterable[Tuple[int, socket.socket]], +): + # Only call shutdown() on sockets we know to be connected + for port, s in connected_ports_sockets: + try: + s.shutdown(socket.SHUT_RDWR) + except socket.error as exc: + logger.warning(f"Error occurred while shutting down socket on port {port}: {exc}") + + # Call close() for all sockets + for port, s in possible_ports: + try: + s.close() + except socket.error as exc: + logger.warning(f"Error occurred while closing socket on port {port}: {exc}") diff --git a/monkey/infection_monkey/utils/timer.py b/monkey/infection_monkey/utils/timer.py index 366a10b20..2ed17d551 100644 --- a/monkey/infection_monkey/utils/timer.py +++ b/monkey/infection_monkey/utils/timer.py @@ -25,7 +25,18 @@ class Timer: TIMEOUT_SEC, False otherwise :rtype: bool """ - return (time.time() - self._start_time) >= self._timeout_sec + return self.time_remaining == 0 + + @property + def time_remaining(self) -> float: + """ + Return the amount of time remaining until the timer expires. + :return: The number of seconds until the timer expires. If the timer is expired, this + function returns 0 (it will never return a negative number). + :rtype: float + """ + time_remaining = self._timeout_sec - (time.time() - self._start_time) + return max(time_remaining, 0) def reset(self): """ diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py b/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py new file mode 100644 index 000000000..e383e1004 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py @@ -0,0 +1,54 @@ +import pytest + +from infection_monkey.i_puppet import PortStatus +from infection_monkey.network import scan_tcp_ports + +PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] + +OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"} + + +@pytest.fixture +def patch_check_tcp_ports(monkeypatch, open_ports_data): + monkeypatch.setattr( + "infection_monkey.network.tcp_scanner._check_tcp_ports", + lambda *_: open_ports_data, + ) + + +@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) +def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): + closed_ports = [8080, 143, 445] + + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) + + assert len(port_scan_data) == 6 + for port in open_ports_data.keys(): + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.OPEN + assert port_scan_data[port].banner == open_ports_data.get(port) + + for port in closed_ports: + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.CLOSED + assert port_scan_data[port].banner is None + + +@pytest.mark.parametrize("open_ports_data", [{}]) +def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): + + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) + + assert len(port_scan_data) == 6 + for port in open_ports_data: + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.CLOSED + assert port_scan_data[port].banner is None + + +@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) +def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): + + port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) + + assert len(port_scan_data) == 0 diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py b/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py index 5359b8c79..b5291cc0e 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py @@ -67,3 +67,28 @@ def test_timer_reset(start_time, set_current_time, timeout): set_current_time(start_time + (2 * timeout)) assert t.is_expired() + + +def test_time_remaining(start_time, set_current_time): + timeout = 5 + + t = Timer() + t.set(timeout) + + assert t.time_remaining == timeout + + set_current_time(start_time + 2) + assert t.time_remaining == 3 + + +def test_time_remaining_is_zero(start_time, set_current_time): + timeout = 5 + + t = Timer() + t.set(timeout) + + set_current_time(start_time + timeout) + assert t.time_remaining == 0 + + set_current_time(start_time + (2 * timeout)) + assert t.time_remaining == 0