Agent: Change _check_tcp_ports() to return Mapping[int, str]

This commit is contained in:
Mike Salvatore 2022-02-10 09:32:14 -05:00 committed by Ilija Lazoroski
parent d3dd6ffeb0
commit a53b611759
1 changed files with 36 additions and 35 deletions

View File

@ -1,8 +1,7 @@
import logging
import select
import socket
from itertools import zip_longest
from typing import Dict, List, Set
from typing import Iterable, Mapping
from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
@ -11,26 +10,28 @@ from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__)
def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]:
ports_scan = {}
def scan_tcp_ports(
host: str, ports_to_scan: Iterable[int], timeout: float
) -> Mapping[int, PortScanData]:
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
open_ports_data = _check_tcp_ports(host, ports, timeout)
open_ports = set(open_ports_data["open_ports"])
banners = open_ports_data["banners"]
for port, banner in zip_longest(ports, banners, fillvalue=None):
ports_scan[port] = _build_port_scan_data(port, open_ports, banner)
return ports_scan
return _build_port_scan_data(ports_to_scan, open_ports)
def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData:
if port in open_ports:
service = tcp_port_to_service(port)
return PortScanData(port, PortStatus.OPEN, banner, service)
else:
return _get_closed_port_data(port)
def _build_port_scan_data(
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
) -> Mapping[int, PortScanData]:
port_scan_data = {}
for port in ports_to_scan:
if port in open_ports:
service = tcp_port_to_service(port)
banner = open_ports[port]
port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service)
else:
port_scan_data[port] = _get_closed_port_data(port)
return port_scan_data
def _get_closed_port_data(port: int) -> PortScanData:
@ -38,26 +39,29 @@ def _get_closed_port_data(port: int) -> PortScanData:
def _check_tcp_ports(
ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
) -> Dict[str, List]:
ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT
) -> Mapping[int, str]:
"""
Checks whether any of the given ports are open on a target IP.
:param ip: IP of host to attack
:param ports: List of ports to attack. Must not be empty.
:param ports_to_scan: An iterable of ports to scan. Must not be empty.
:param timeout: Amount of time to wait for connection
:return: Dict with list of open ports and list of banners.
:return: Mapping where the key is an open port and the value is the banner
:rtype: Mapping
"""
open_ports_data = {"open_ports": [], "banners": []}
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))]
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))]
for s in sockets:
s.setblocking(False)
possible_ports = []
connected_ports_sockets = []
open_ports = {}
try:
logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports)))
for sock, port in zip(sockets, ports):
logger.debug(
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
)
for sock, port in zip(sockets, ports_to_scan):
err = sock.connect_ex((ip, port))
if err == 0: # immediate connect
connected_ports_sockets.append((port, sock))
@ -96,7 +100,7 @@ def _check_tcp_ports(
% (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets]))
)
banners = []
open_ports = {port: "" for port, _ in connected_ports_sockets}
if len(connected_ports_sockets) != 0:
readable_sockets, _, _ = select.select(
[s[1] for s in connected_ports_sockets], [], [], timer.time_remaining
@ -105,20 +109,17 @@ def _check_tcp_ports(
# decodable byte string.
for port, sock in connected_ports_sockets:
if sock in readable_sockets:
banners.append(sock.recv(BANNER_READ).decode(errors="ignore"))
open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore")
else:
banners.append("")
open_ports[port] = ""
# try to cleanup
for s in possible_ports:
s[1].shutdown(socket.SHUT_RDWR)
s[1].close()
open_ports_data["open_ports"] = [port for port, _ in connected_ports_sockets]
open_ports_data["banners"] = banners
except socket.error as exc:
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
finally:
return open_ports_data
return open_ports