From f5ef660bd26e271a490e0270f0c09a52c51a88f6 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 7 Feb 2022 10:26:20 -0500 Subject: [PATCH] Agent: Refactor HTTPFinger to conform to IFingerprinter interface * Remove dependency on Plugin, HostFinger, and WormConfiguration * Improve readability * Reduce unnecessary HTTP requests by using the PortScanData to only query ports we know are open. --- monkey/infection_monkey/network/httpfinger.py | 95 +++++++++++++------ 1 file changed, 64 insertions(+), 31 deletions(-) diff --git a/monkey/infection_monkey/network/httpfinger.py b/monkey/infection_monkey/network/httpfinger.py index 99e9deaab..5c242a204 100644 --- a/monkey/infection_monkey/network/httpfinger.py +++ b/monkey/infection_monkey/network/httpfinger.py @@ -1,47 +1,80 @@ import logging +from contextlib import closing +from typing import Dict, Iterable, Optional, Set, Tuple -import infection_monkey.config -from infection_monkey.network.HostFinger import HostFinger +from requests import head +from requests.exceptions import ConnectionError, Timeout + +from infection_monkey.i_puppet import ( + FingerprintData, + IFingerprinter, + PingScanData, + PortScanData, + PortStatus, +) logger = logging.getLogger(__name__) -class HTTPFinger(HostFinger): +class HTTPFinger(IFingerprinter): """ Goal is to recognise HTTP servers, where what we currently care about is apache. """ - _SCANNED_SERVICE = "HTTP" + def get_host_fingerprint( + self, + host: str, + ping_scan_data: PingScanData, + port_scan_data: Dict[int, PortScanData], + options: Dict, + ): + services = {} + http_ports = set(options.get("http_ports", [])) + ports_to_fingerprint = _get_open_http_ports(http_ports, port_scan_data) - def __init__(self): - self._config = infection_monkey.config.WormConfiguration - self.HTTP = [(port, str(port)) for port in self._config.HTTP_PORTS] + for port in ports_to_fingerprint: + server_header_contents, ssl = _query_potential_http_server(host, port) - def get_host_fingerprint(self, host): - from contextlib import closing + if server_header_contents is not None: + services[f"tcp-{port}"] = { + "display_name": "HTTP", + "port": port, + "name": "http", + "data": (server_header_contents, ssl), + } - from requests import head - from requests.exceptions import ConnectionError, Timeout + return FingerprintData(None, None, services) - for port in self.HTTP: - # check both http and https - http = "http://" + host.ip_addr + ":" + port[1] - https = "https://" + host.ip_addr + ":" + port[1] - # try http, we don't optimise for 443 - for url in (https, http): # start with https and downgrade - try: - with closing(head(url, verify=False, timeout=1)) as req: # noqa: DUO123 - server = req.headers.get("Server") - ssl = True if "https://" in url else False - self.init_service(host.services, ("tcp-" + port[1]), port[0]) - host.services["tcp-" + port[1]]["name"] = "http" - host.services["tcp-" + port[1]]["data"] = (server, ssl) - logger.info("Port %d is open on host %s " % (port[0], host)) - break # https will be the same on the same port - except Timeout: - logger.debug(f"Timeout while requesting headers from {url}") - except ConnectionError: # Someone doesn't like us - logger.debug(f"Connection error while requesting headers from {url}") +def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]: + # check both http and https + http = f"http://{host}:{port}" + https = f"https://{host}:{port}" - return True + # try http, we don't optimise for 443 + for url, ssl in ((https, True), (http, False)): # start with https and downgrade + server_header_contents = _get_server_from_headers(url) + + if server_header_contents is not None: + return (server_header_contents, ssl) + + return (None, None) + + +def _get_server_from_headers(url: str) -> Optional[str]: + try: + with closing(head(url, verify=False, timeout=1)) as req: # noqa: DUO123 + return req.headers.get("Server") + except Timeout: + logger.debug(f"Timeout while requesting headers from {url}") + except ConnectionError: # Someone doesn't like us + logger.debug(f"Connection error while requesting headers from {url}") + + return None + + +def _get_open_http_ports( + allowed_http_ports: Set, port_scan_data: Dict[int, PortScanData] +) -> Iterable[int]: + open_ports = (psd.port for psd in port_scan_data.values() if psd.status == PortStatus.Open) + return (port for port in open_ports if port in allowed_http_ports)