Agent: Refactor MSSQL fingerprinter

* Refactor code to conform to the IFingerprinter interface
* Non-structured server response will return empty Fingerprint data
* Rename mssql_fingerprint to mssql_fingerprinter
* Unit tests
This commit is contained in:
Ilija Lazoroski 2022-02-09 12:27:00 +01:00
parent f0602edffb
commit e6f5b6113f
4 changed files with 204 additions and 92 deletions

View File

@ -185,7 +185,6 @@ class InfectionMonkey:
@staticmethod
def _build_puppet() -> IPuppet:
puppet = Puppet()
puppet.load_plugin("elastic", ElasticSearchFingerprinter(), PluginType.FINGERPRINTER)
puppet.load_plugin("http", HTTPFingerprinter(), PluginType.FINGERPRINTER)

View File

@ -1,91 +0,0 @@
import errno
import logging
import socket
import infection_monkey.config
from infection_monkey.network.HostFinger import HostFinger
logger = logging.getLogger(__name__)
class MSSQLFinger(HostFinger):
# Class related consts
SQL_BROWSER_DEFAULT_PORT = 1434
BUFFER_SIZE = 4096
TIMEOUT = 5
_SCANNED_SERVICE = "MSSQL"
def __init__(self):
self._config = infection_monkey.config.WormConfiguration
def get_host_fingerprint(self, host):
"""Gets Microsoft SQL Server instance information by querying the SQL Browser service.
:arg:
host (VictimHost): The MS-SSQL Server to query for information.
:returns:
Discovered server information written to the Host info struct.
True if success, False otherwise.
"""
# Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(self.TIMEOUT)
server_address = (str(host.ip_addr), self.SQL_BROWSER_DEFAULT_PORT)
# The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03"
# Encode the message as a bytesarray
message = message.encode()
# send data and receive response
try:
logger.info("Sending message to requested host: {0}, {1}".format(host, message))
sock.sendto(message, server_address)
data, server = sock.recvfrom(self.BUFFER_SIZE)
except socket.timeout:
logger.info(
"Socket timeout reached, maybe browser service on host: {0} doesnt "
"exist".format(host)
)
sock.close()
return False
except socket.error as e:
if e.errno == errno.ECONNRESET:
logger.info(
"Connection was forcibly closed by the remote host. The host: {0} is "
"rejecting the packet.".format(host)
)
else:
logger.error(
"An unknown socket error occurred while trying the mssql fingerprint, "
"closing socket.",
exc_info=True,
)
sock.close()
return False
self.init_service(
host.services, self._SCANNED_SERVICE, MSSQLFinger.SQL_BROWSER_DEFAULT_PORT
)
# Loop through the server data
instances_list = data[3:].decode().split(";;")
logger.info("{0} MSSQL instances found".format(len(instances_list)))
for instance in instances_list:
instance_info = instance.split(";")
if len(instance_info) > 1:
host.services[self._SCANNED_SERVICE][instance_info[1]] = {}
for i in range(1, len(instance_info), 2):
# Each instance's info is nested under its own name, if there are multiple
# instances
# each will appear under its own name
host.services[self._SCANNED_SERVICE][instance_info[1]][
instance_info[i - 1]
] = instance_info[i]
# Close the socket
sock.close()
return True

View File

@ -0,0 +1,102 @@
import errno
import logging
import socket
from typing import Any, Dict, Optional
from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData
MSSQL_SERVICE = "MSSQL"
SQL_BROWSER_DEFAULT_PORT = 1434
_BUFFER_SIZE = 4096
_MSSQL_SOCKET_TIMEOUT = 5
logger = logging.getLogger(__name__)
class MSSQLFingerprinter(IFingerprinter):
def get_host_fingerprint(
self,
host: str,
_: PingScanData,
port_scan_data: Dict[int, PortScanData],
options: Dict,
):
"""Gets Microsoft SQL Server instance information by querying the SQL Browser service."""
services = {}
try:
data = _query_mssql_for_instance_data(host)
services = _get_services_from_server_data(data)
except Exception as ex:
logger.debug(f"Did not detect an MSSQL server: {ex}")
return FingerprintData(None, None, services)
def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
# Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(_MSSQL_SOCKET_TIMEOUT)
server_address = (host, SQL_BROWSER_DEFAULT_PORT)
# The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03"
# Encode the message as a bytes array
message = message.encode()
# send data and receive response
try:
logger.info(f"Sending message to requested host: {host}, {message}")
sock.sendto(message, server_address)
data, _ = sock.recvfrom(_BUFFER_SIZE)
return data
except socket.timeout as err:
logger.debug(
f"Socket timeout reached, maybe browser service on host: {host} doesnt " "exist"
)
raise err
except socket.error as err:
if err.errno == errno.ECONNRESET:
error_message = (
f"Connection was forcibly closed by the remote host. The host: {host} is "
"rejecting the packet."
)
else:
error_message = (
"An unknown socket error occurred while trying the mssql fingerprint, "
"closing socket."
)
raise Exception(error_message) from err
finally:
sock.close()
def _get_services_from_server_data(data: bytes) -> Dict[str, Any]:
services = {MSSQL_SERVICE: {}}
services[MSSQL_SERVICE]["display_name"] = MSSQL_SERVICE
services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT
# Loop through the server data
mssql_instances = filter(lambda i: i != "", data[3:].decode().split(";;"))
for instance in mssql_instances:
instance_info = instance.split(";")
if len(instance_info) > 1:
services[MSSQL_SERVICE][instance_info[1]] = {}
for i in range(1, len(instance_info), 2):
# Each instance's info is nested under its own name, if there are multiple
# instances
# each will appear under its own name
services[MSSQL_SERVICE][instance_info[1]][instance_info[i - 1]] = instance_info[i]
logger.debug(f"Found MSSQL instance: {instance}")
if len(services[MSSQL_SERVICE].keys()) == 2:
services = {}
return services

View File

@ -0,0 +1,102 @@
import socket
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network.mssql_fingerprinter import (
MSSQL_SERVICE,
SQL_BROWSER_DEFAULT_PORT,
MSSQLFingerprinter,
)
PORT_SCAN_DATA_BOGUS = {
80: PortScanData(80, PortStatus.OPEN, "", "tcp-80"),
8080: PortScanData(8080, PortStatus.OPEN, "", "tcp-8080"),
}
@pytest.fixture
def fingerprinter():
return MSSQLFingerprinter()
def test_mssql_fingerprint_successful(monkeypatch, fingerprinter):
successful_service_response = {
"ServerName": "BogusVogus",
"InstanceName": "GhostServer",
"IsClustered": "No",
"Version": "11.1.1111.111",
"tcp": "1433",
"np": "blah_blah",
}
successful_server_response = (
b"\x05y\x00ServerName;BogusVogus;InstanceName;GhostServer;"
b"IsClustered;No;Version;11.1.1111.111;tcp;1433;np;blah_blah;;"
)
monkeypatch.setattr(
"infection_monkey.network.mssql_fingerprinter._query_mssql_for_instance_data",
lambda _: successful_server_response,
)
fingerprint_data = fingerprinter.get_host_fingerprint(
"127.0.0.1", None, PORT_SCAN_DATA_BOGUS, {}
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 1
# Each mssql instance is under his name
assert len(fingerprint_data.services["MSSQL"].keys()) == 3
assert fingerprint_data.services["MSSQL"]["display_name"] == MSSQL_SERVICE
assert fingerprint_data.services["MSSQL"]["port"] == SQL_BROWSER_DEFAULT_PORT
mssql_service = fingerprint_data.services["MSSQL"]["BogusVogus"]
assert len(mssql_service.keys()) == len(successful_service_response.keys())
for key, value in successful_service_response.items():
assert mssql_service[key] == value
@pytest.mark.parametrize(
"mock_query_function",
[
MagicMock(side_effect=socket.timeout),
MagicMock(side_effect=socket.error),
MagicMock(side_effect=Exception),
],
)
def test_mssql_no_response_from_server(monkeypatch, fingerprinter, mock_query_function):
monkeypatch.setattr(
"infection_monkey.network.mssql_fingerprinter._query_mssql_for_instance_data",
mock_query_function,
)
fingerprint_data = fingerprinter.get_host_fingerprint(
"127.0.0.1", None, PORT_SCAN_DATA_BOGUS, {}
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 0
def test_mssql_wrong_response_from_server(monkeypatch, fingerprinter):
mangled_server_response = (
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
b"Pellentesque ultrices ornare libero, ;;"
)
monkeypatch.setattr(
"infection_monkey.network.mssql_fingerprinter._query_mssql_for_instance_data",
lambda _: mangled_server_response,
)
fingerprint_data = fingerprinter.get_host_fingerprint(
"127.0.0.1", None, PORT_SCAN_DATA_BOGUS, {}
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 0