diff --git a/monkey/infection_monkey/network/elasticsearch_fingerprinter.py b/monkey/infection_monkey/network/elasticsearch_fingerprinter.py index 8ec2e3890..6670c3621 100644 --- a/monkey/infection_monkey/network/elasticsearch_fingerprinter.py +++ b/monkey/infection_monkey/network/elasticsearch_fingerprinter.py @@ -1,48 +1,64 @@ -import json import logging from contextlib import closing +from typing import Any, Dict import requests -from requests.exceptions import ConnectionError, Timeout -import infection_monkey.config from common.common_consts.network_consts import ES_SERVICE -from infection_monkey.network.HostFinger import HostFinger +from infection_monkey.i_puppet import ( + FingerprintData, + IFingerprinter, + PingScanData, + PortScanData, + PortStatus, +) ES_PORT = 9200 ES_HTTP_TIMEOUT = 5 logger = logging.getLogger(__name__) -class ElasticFinger(HostFinger): +class ElasticSearchFingerprinter(IFingerprinter): """ Fingerprints elastic search clusters, only on port 9200 """ - _SCANNED_SERVICE = "Elastic search" + def get_host_fingerprint( + self, + host: str, + _ping_scan_data: PingScanData, + port_scan_data: Dict[int, PortScanData], + _options: Dict, + ) -> FingerprintData: + services = {} - def __init__(self): - self._config = infection_monkey.config.WormConfiguration + if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN): + return FingerprintData(None, None, services) - def get_host_fingerprint(self, host): - """ - Returns elasticsearch metadata - :param host: - :return: Success/failure, data is saved in the host struct - """ try: - url = "http://%s:%s/" % (host.ip_addr, ES_PORT) - with closing(requests.get(url, timeout=ES_HTTP_TIMEOUT)) as req: - data = json.loads(req.text) - self.init_service(host.services, ES_SERVICE, ES_PORT) - host.services[ES_SERVICE]["cluster_name"] = data["cluster_name"] - host.services[ES_SERVICE]["name"] = data["name"] - host.services[ES_SERVICE]["version"] = data["version"]["number"] - return True - except Timeout: - logger.debug("Got timeout while trying to read header information") - except ConnectionError: # Someone doesn't like us - logger.debug("Unknown connection error") - except KeyError: - logger.debug("Failed parsing the ElasticSearch JSOn response") - return False + elasticsearch_info = _query_elasticsearch(host) + services[ES_SERVICE] = _get_service_from_query_info(elasticsearch_info) + except Exception as ex: + logger.debug(f"Did not detect an ElasticSearch cluster: {ex}") + + return FingerprintData(None, None, services) + + +def _query_elasticsearch(host: str) -> Dict[str, Any]: + url = "http://%s:%s/" % (host, ES_PORT) + logger.debug(f"Sending request to {url}") + with closing(requests.get(url, timeout=ES_HTTP_TIMEOUT)) as response: + return response.json() + + +def _get_service_from_query_info(elasticsearch_info: Dict[str, Any]) -> Dict[str, Any]: + try: + return { + "display_name": "ElasticSearch", + "port": ES_PORT, + "cluster_name": elasticsearch_info["cluster_name"], + "name": elasticsearch_info["name"], + "version": elasticsearch_info["version"]["number"], + } + except KeyError as ke: + raise Exception(f"Unable to find the key {ke} in the server's response") from ke diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_elasticsearch_fingerprinter.py b/monkey/tests/unit_tests/infection_monkey/network/test_elasticsearch_fingerprinter.py new file mode 100644 index 000000000..f15afa60e --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/test_elasticsearch_fingerprinter.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from common.common_consts.network_consts import ES_SERVICE +from infection_monkey.i_puppet import PortScanData, PortStatus +from infection_monkey.network.elasticsearch_fingerprinter import ES_PORT, ElasticSearchFingerprinter + +PORT_SCAN_DATA_OPEN = {ES_PORT: PortScanData(ES_PORT, PortStatus.OPEN, "", f"tcp-{ES_PORT}")} +PORT_SCAN_DATA_CLOSED = {ES_PORT: PortScanData(ES_PORT, PortStatus.CLOSED, "", f"tcp-{ES_PORT}")} +PORT_SCAN_DATA_MISSING = { + 80: PortScanData(80, PortStatus.OPEN, "", "tcp-80"), + 8080: PortScanData(8080, PortStatus.OPEN, "", "tcp-8080"), +} + + +@pytest.fixture +def fingerprinter(): + return ElasticSearchFingerprinter() + + +def test_successful(monkeypatch, fingerprinter): + successful_server_response = { + "cluster_name": "test cluster", + "name": "test name", + "version": {"number": "1.0.0"}, + } + monkeypatch.setattr( + "infection_monkey.network.elasticsearch_fingerprinter._query_elasticsearch", + lambda _: successful_server_response, + ) + + fingerprint_data = fingerprinter.get_host_fingerprint( + "127.0.0.1", None, PORT_SCAN_DATA_OPEN, {} + ) + + assert fingerprint_data.os_type is None + assert fingerprint_data.os_version is None + assert len(fingerprint_data.services.keys()) == 1 + + es_service = fingerprint_data.services[ES_SERVICE] + + assert es_service["cluster_name"] == successful_server_response["cluster_name"] + assert es_service["version"] == successful_server_response["version"]["number"] + assert es_service["name"] == successful_server_response["name"] + + +@pytest.mark.parametrize("port_scan_data", [PORT_SCAN_DATA_CLOSED, PORT_SCAN_DATA_MISSING]) +def test_fingerprinting_skipped_if_port_closed(monkeypatch, fingerprinter, port_scan_data): + mock_query_elasticsearch = MagicMock() + monkeypatch.setattr( + "infection_monkey.network.elasticsearch_fingerprinter._query_elasticsearch", + mock_query_elasticsearch, + ) + + fingerprint_data = fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, {}) + + assert not mock_query_elasticsearch.called + assert fingerprint_data.os_type is None + assert fingerprint_data.os_version is None + assert len(fingerprint_data.services.keys()) == 0 + + +@pytest.mark.parametrize( + "mock_query_function", + [ + MagicMock(side_effect=Exception("test exception")), + MagicMock(return_value={"unexpected_key": "unexpected_value"}), + ], +) +def test_no_response_from_server(monkeypatch, fingerprinter, mock_query_function): + monkeypatch.setattr( + "infection_monkey.network.elasticsearch_fingerprinter._query_elasticsearch", + mock_query_function, + ) + + fingerprint_data = fingerprinter.get_host_fingerprint( + "127.0.0.1", None, PORT_SCAN_DATA_OPEN, {} + ) + + assert fingerprint_data.os_type is None + assert fingerprint_data.os_version is None + assert len(fingerprint_data.services.keys()) == 0