diff --git a/infection_monkey/network/mssql_fingerprint.py b/infection_monkey/network/mssql_fingerprint.py index bbe762e27..51f5e25b0 100644 --- a/infection_monkey/network/mssql_fingerprint.py +++ b/infection_monkey/network/mssql_fingerprint.py @@ -3,7 +3,6 @@ import socket from model.host import VictimHost from network import HostFinger -from .tools import struct_unpack_tracker, struct_unpack_tracker_string LOG = logging.getLogger(__name__) @@ -11,6 +10,12 @@ LOG = logging.getLogger(__name__) class MSSQLFingerprint(HostFinger): + # Class related consts + SQL_BROWSER_DEFAULT_PORT = 1434 + BUFFER_SIZE = 4096 + TIMEOUT = 5 + SERVICE_NAME = 'mssql' + def __init__(self): self._config = __import__('config').WormConfiguration @@ -23,19 +28,16 @@ class MSSQLFingerprint(HostFinger): Discovered server information written to the Host info struct. """ + assert isinstance(host, VictimHost) + # Create a UDP socket and sets a timeout sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.settimeout(timeout) - server_address = (str(host), browser_port) + sock.settimeout(self.TIMEOUT) + server_address = (str(host), self.SQL_BROWSER_DEFAULT_PORT) - if instance_name: - # The message is a CLNT_UCAST_INST packet to get a single instance - # https://msdn.microsoft.com/en-us/library/cc219746.aspx - message = '\x04{0}\x00'.format(instance_name) - else: - # The message is a CLNT_UCAST_EX packet to get all instances - # https://msdn.microsoft.com/en-us/library/cc219745.aspx - message = '\x03' + # The message is a CLNT_UCAST_EX packet to get all instances + # https://msdn.microsoft.com/en-us/library/cc219745.aspx + message = '\x03' # Encode the message as a bytesarray message = message.encode() @@ -43,23 +45,22 @@ class MSSQLFingerprint(HostFinger): # send data and receive response results = [] try: - logging.info('Sending message to requested host: {0}, {1}'.format(host, message)) + LOG.info('Sending message to requested host: {0}, {1}'.format(host, message)) sock.sendto(message, server_address) - data, server = sock.recvfrom(buffer_size) + data, server = sock.recvfrom(self.BUFFER_SIZE) except socket.timeout: - logging.error('Socket timeout reached, maybe browser service on host: {0} doesnt exist'.format(host)) + LOG.error('Socket timeout reached, maybe browser service on host: {0} doesnt exist'.format(host)) + sock.close() return results + host.services[self.SERVICE_NAME] = {} + # Loop through the server data for server in data[3:].decode().split(';;'): - server_info = OrderedDict() instance_info = server.split(';') - if len(instance_info) > 1: for i in range(1, len(instance_info), 2): - server_info[instance_info[i - 1]] = instance_info[i] - - results.append(server_info) + host.services[self.SERVICE_NAME][instance_info[i - 1]] = instance_info[i] # Close the socket sock.close()