diff --git a/monkey/infection_monkey/network_scanning/http_fingerprinter.py b/monkey/infection_monkey/network_scanning/http_fingerprinter.py index 8ececc72a..f5dcfac64 100644 --- a/monkey/infection_monkey/network_scanning/http_fingerprinter.py +++ b/monkey/infection_monkey/network_scanning/http_fingerprinter.py @@ -1,8 +1,8 @@ import logging from contextlib import closing -from typing import Dict, Iterable, Optional, Set, Tuple +from typing import Dict, Iterable, Optional, Set, Tuple, Any -from requests import head +from requests import head, Response from requests.exceptions import ConnectionError, Timeout from infection_monkey.i_puppet import ( @@ -25,11 +25,11 @@ class HTTPFingerprinter(IFingerprinter): """ def get_host_fingerprint( - self, - host: str, - _: PingScanData, - port_scan_data: Dict[int, PortScanData], - options: Dict, + self, + host: str, + _: PingScanData, + port_scan_data: Dict[int, PortScanData], + options: Dict, ) -> FingerprintData: services = {} http_ports = set(options.get("http_ports", [])) @@ -55,22 +55,27 @@ def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], O https = f"https://{host}:{port}" for url, ssl in ((https, True), (http, False)): # start with https and downgrade - server_header_contents = _get_server_from_headers(url) + server_header = _get_server_from_headers(url) - if server_header_contents is not None: - return (server_header_contents, ssl) + if server_header is not None: + return server_header, ssl - return (None, None) + return None, None def _get_server_from_headers(url: str) -> Optional[str]: + headers = _get_http_headers(url) + if headers: + return headers.get("Server", "") + + return None + + +def _get_http_headers(url: str) -> Optional[Dict[str, Any]]: try: logger.debug(f"Sending request for headers to {url}") - with closing(head(url, verify=False, timeout=1)) as req: # noqa: DUO123 - server = req.headers.get("Server") - - logger.debug(f'Got server string "{server}" from {url}') - return server + with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123 + return response.headers except Timeout: logger.debug(f"Timeout while requesting headers from {url}") except ConnectionError: # Someone doesn't like us @@ -80,7 +85,7 @@ def _get_server_from_headers(url: str) -> Optional[str]: def _get_open_http_ports( - allowed_http_ports: Set, port_scan_data: Dict[int, PortScanData] + allowed_http_ports: Set, port_scan_data: Dict[int, PortScanData] ) -> Iterable[int]: open_ports = (psd.port for psd in port_scan_data.values() if psd.status == PortStatus.OPEN) return (port for port in open_ports if port in allowed_http_ports) diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_http_fingerprinter.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_http_fingerprinter.py index 8baa97782..20a320048 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_http_fingerprinter.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_http_fingerprinter.py @@ -7,25 +7,27 @@ from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprint OPTIONS = {"http_ports": [80, 443, 8080, 9200]} -PYTHON_SERVER_HEADER = "SimpleHTTP/0.6 Python/3.6.9" -APACHE_SERVER_HEADER = "Apache/Server/Header" +PYTHON_SERVER_HEADER = {"Server": "SimpleHTTP/0.6 Python/3.6.9"} +APACHE_SERVER_HEADER = {"Server": "Apache/Server/Header"} +NO_SERVER_HEADER = {"Not_Server": "No Header for you"} SERVER_HEADERS = { "https://127.0.0.1:443": PYTHON_SERVER_HEADER, "http://127.0.0.1:8080": APACHE_SERVER_HEADER, + "http://127.0.0.1:1080": NO_SERVER_HEADER, } @pytest.fixture -def mock_get_server_from_headers(): - return MagicMock(side_effect=lambda port: SERVER_HEADERS.get(port, None)) +def mock_get_http_headers(): + return MagicMock(side_effect=lambda url: SERVER_HEADERS.get(url, None)) @pytest.fixture(autouse=True) -def patch_get_server_from_headers(monkeypatch, mock_get_server_from_headers): +def patch_get_http_headers(monkeypatch, mock_get_http_headers): monkeypatch.setattr( "infection_monkey.network_scanning.http_fingerprinter._get_server_from_headers", - mock_get_server_from_headers, + mock_get_http_headers, )