Agent: Change _check_tcp_ports() to return Mapping[int, str]
This commit is contained in:
parent
d3dd6ffeb0
commit
a53b611759
|
@ -1,8 +1,7 @@
|
|||
import logging
|
||||
import select
|
||||
import socket
|
||||
from itertools import zip_longest
|
||||
from typing import Dict, List, Set
|
||||
from typing import Iterable, Mapping
|
||||
|
||||
from infection_monkey.i_puppet import PortScanData, PortStatus
|
||||
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
|
||||
|
@ -11,26 +10,28 @@ from infection_monkey.utils.timer import Timer
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]:
|
||||
ports_scan = {}
|
||||
def scan_tcp_ports(
|
||||
host: str, ports_to_scan: Iterable[int], timeout: float
|
||||
) -> Mapping[int, PortScanData]:
|
||||
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
|
||||
|
||||
open_ports_data = _check_tcp_ports(host, ports, timeout)
|
||||
|
||||
open_ports = set(open_ports_data["open_ports"])
|
||||
banners = open_ports_data["banners"]
|
||||
|
||||
for port, banner in zip_longest(ports, banners, fillvalue=None):
|
||||
ports_scan[port] = _build_port_scan_data(port, open_ports, banner)
|
||||
|
||||
return ports_scan
|
||||
return _build_port_scan_data(ports_to_scan, open_ports)
|
||||
|
||||
|
||||
def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData:
|
||||
if port in open_ports:
|
||||
service = tcp_port_to_service(port)
|
||||
return PortScanData(port, PortStatus.OPEN, banner, service)
|
||||
else:
|
||||
return _get_closed_port_data(port)
|
||||
def _build_port_scan_data(
|
||||
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
|
||||
) -> Mapping[int, PortScanData]:
|
||||
port_scan_data = {}
|
||||
for port in ports_to_scan:
|
||||
if port in open_ports:
|
||||
service = tcp_port_to_service(port)
|
||||
banner = open_ports[port]
|
||||
|
||||
port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service)
|
||||
else:
|
||||
port_scan_data[port] = _get_closed_port_data(port)
|
||||
|
||||
return port_scan_data
|
||||
|
||||
|
||||
def _get_closed_port_data(port: int) -> PortScanData:
|
||||
|
@ -38,26 +39,29 @@ def _get_closed_port_data(port: int) -> PortScanData:
|
|||
|
||||
|
||||
def _check_tcp_ports(
|
||||
ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT
|
||||
) -> Dict[str, List]:
|
||||
ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT
|
||||
) -> Mapping[int, str]:
|
||||
"""
|
||||
Checks whether any of the given ports are open on a target IP.
|
||||
:param ip: IP of host to attack
|
||||
:param ports: List of ports to attack. Must not be empty.
|
||||
:param ports_to_scan: An iterable of ports to scan. Must not be empty.
|
||||
:param timeout: Amount of time to wait for connection
|
||||
:return: Dict with list of open ports and list of banners.
|
||||
:return: Mapping where the key is an open port and the value is the banner
|
||||
:rtype: Mapping
|
||||
"""
|
||||
open_ports_data = {"open_ports": [], "banners": []}
|
||||
|
||||
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))]
|
||||
sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))]
|
||||
for s in sockets:
|
||||
s.setblocking(False)
|
||||
|
||||
possible_ports = []
|
||||
connected_ports_sockets = []
|
||||
open_ports = {}
|
||||
|
||||
try:
|
||||
logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports)))
|
||||
for sock, port in zip(sockets, ports):
|
||||
logger.debug(
|
||||
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
|
||||
)
|
||||
for sock, port in zip(sockets, ports_to_scan):
|
||||
err = sock.connect_ex((ip, port))
|
||||
if err == 0: # immediate connect
|
||||
connected_ports_sockets.append((port, sock))
|
||||
|
@ -96,7 +100,7 @@ def _check_tcp_ports(
|
|||
% (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets]))
|
||||
)
|
||||
|
||||
banners = []
|
||||
open_ports = {port: "" for port, _ in connected_ports_sockets}
|
||||
if len(connected_ports_sockets) != 0:
|
||||
readable_sockets, _, _ = select.select(
|
||||
[s[1] for s in connected_ports_sockets], [], [], timer.time_remaining
|
||||
|
@ -105,20 +109,17 @@ def _check_tcp_ports(
|
|||
# decodable byte string.
|
||||
for port, sock in connected_ports_sockets:
|
||||
if sock in readable_sockets:
|
||||
banners.append(sock.recv(BANNER_READ).decode(errors="ignore"))
|
||||
open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore")
|
||||
else:
|
||||
banners.append("")
|
||||
open_ports[port] = ""
|
||||
|
||||
# try to cleanup
|
||||
for s in possible_ports:
|
||||
s[1].shutdown(socket.SHUT_RDWR)
|
||||
s[1].close()
|
||||
|
||||
open_ports_data["open_ports"] = [port for port, _ in connected_ports_sockets]
|
||||
open_ports_data["banners"] = banners
|
||||
|
||||
except socket.error as exc:
|
||||
logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc)
|
||||
|
||||
finally:
|
||||
return open_ports_data
|
||||
return open_ports
|
||||
|
|
Loading…
Reference in New Issue