diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index a8ec9751b..62af733e7 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,8 +1,7 @@ import logging import select import socket -from itertools import zip_longest -from typing import Dict, List, Set +from typing import Iterable, Mapping from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service @@ -11,26 +10,28 @@ from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) -def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]: - ports_scan = {} +def scan_tcp_ports( + host: str, ports_to_scan: Iterable[int], timeout: float +) -> Mapping[int, PortScanData]: + open_ports = _check_tcp_ports(host, ports_to_scan, timeout) - open_ports_data = _check_tcp_ports(host, ports, timeout) - - open_ports = set(open_ports_data["open_ports"]) - banners = open_ports_data["banners"] - - for port, banner in zip_longest(ports, banners, fillvalue=None): - ports_scan[port] = _build_port_scan_data(port, open_ports, banner) - - return ports_scan + return _build_port_scan_data(ports_to_scan, open_ports) -def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData: - if port in open_ports: - service = tcp_port_to_service(port) - return PortScanData(port, PortStatus.OPEN, banner, service) - else: - return _get_closed_port_data(port) +def _build_port_scan_data( + ports_to_scan: Iterable[int], open_ports: Mapping[int, str] +) -> Mapping[int, PortScanData]: + port_scan_data = {} + for port in ports_to_scan: + if port in open_ports: + service = tcp_port_to_service(port) + banner = open_ports[port] + + port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service) + else: + port_scan_data[port] = _get_closed_port_data(port) + + return port_scan_data def _get_closed_port_data(port: int) -> PortScanData: @@ -38,26 +39,29 @@ def _get_closed_port_data(port: int) -> PortScanData: def _check_tcp_ports( - ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT -) -> Dict[str, List]: + ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT +) -> Mapping[int, str]: """ Checks whether any of the given ports are open on a target IP. :param ip: IP of host to attack - :param ports: List of ports to attack. Must not be empty. + :param ports_to_scan: An iterable of ports to scan. Must not be empty. :param timeout: Amount of time to wait for connection - :return: Dict with list of open ports and list of banners. + :return: Mapping where the key is an open port and the value is the banner + :rtype: Mapping """ - open_ports_data = {"open_ports": [], "banners": []} - - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))] + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))] for s in sockets: s.setblocking(False) possible_ports = [] connected_ports_sockets = [] + open_ports = {} + try: - logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports))) - for sock, port in zip(sockets, ports): + logger.debug( + "Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan)) + ) + for sock, port in zip(sockets, ports_to_scan): err = sock.connect_ex((ip, port)) if err == 0: # immediate connect connected_ports_sockets.append((port, sock)) @@ -96,7 +100,7 @@ def _check_tcp_ports( % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])) ) - banners = [] + open_ports = {port: "" for port, _ in connected_ports_sockets} if len(connected_ports_sockets) != 0: readable_sockets, _, _ = select.select( [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining @@ -105,20 +109,17 @@ def _check_tcp_ports( # decodable byte string. for port, sock in connected_ports_sockets: if sock in readable_sockets: - banners.append(sock.recv(BANNER_READ).decode(errors="ignore")) + open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore") else: - banners.append("") + open_ports[port] = "" # try to cleanup for s in possible_ports: s[1].shutdown(socket.SHUT_RDWR) s[1].close() - open_ports_data["open_ports"] = [port for port, _ in connected_ports_sockets] - open_ports_data["banners"] = banners - except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) finally: - return open_ports_data + return open_ports