Agent: Fix FingerprintData mypy issues

This commit is contained in:
Kekoa Kaaikala 2022-09-21 20:38:19 +00:00
parent e40d061091
commit 9d37a38994
15 changed files with 87 additions and 85 deletions

View File

@ -193,7 +193,7 @@ class SingleIpRange(NetworkRange):
:return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com) :return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com)
""" """
# The most common use case is to enter ip/range into "Scan IP/subnet list" # The most common use case is to enter ip/range into "Scan IP/subnet list"
domain_name = None domain_name = ""
if " " in string_: if " " in string_:
raise ValueError(f'"{string_}" is not a valid IP address or domain name.') raise ValueError(f'"{string_}" is not a valid IP address or domain name.')

View File

@ -2,7 +2,7 @@ import logging
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import List, Sequence from typing import Dict, List, Sequence
from common.agent_configuration import ( from common.agent_configuration import (
ExploitationConfiguration, ExploitationConfiguration,
@ -154,11 +154,12 @@ class Propagator:
victim_host.services[psd.service] = {} victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)" victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None: victim_host.services[psd.service]["banner"] = psd.banner
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod @staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData): def _process_fingerprinter_results(
victim_host: VictimHost, fingerprint_data: Dict[str, FingerprintData]
):
for fd in fingerprint_data.values(): for fd in fingerprint_data.values():
# TODO: This logic preserves the existing behavior prior to introducing IMaster and # TODO: This logic preserves the existing behavior prior to introducing IMaster and
# IPuppet, but it is possibly flawed. Different fingerprinters may detect # IPuppet, but it is possibly flawed. Different fingerprinters may detect
@ -167,7 +168,7 @@ class Propagator:
if fd.os_type is not None: if fd.os_type is not None:
victim_host.os["type"] = fd.os_type victim_host.os["type"] = fd.os_type
if ("version" not in victim_host.os) and (fd.os_version is not None): if ("version" not in victim_host.os) and (fd.os_version):
victim_host.os["version"] = fd.os_version victim_host.os["version"] = fd.os_version
for service, details in fd.services.items(): for service, details in fd.services.items():

View File

@ -31,10 +31,10 @@ class ElasticSearchFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
services = {} services: Dict[str, Any] = {}
if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN): if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
try: try:
elasticsearch_info = _query_elasticsearch(host) elasticsearch_info = _query_elasticsearch(host)
@ -42,7 +42,7 @@ class ElasticSearchFingerprinter(IFingerprinter):
except Exception as ex: except Exception as ex:
logger.debug(f"Did not detect an ElasticSearch cluster: {ex}") logger.debug(f"Did not detect an ElasticSearch cluster: {ex}")
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_elasticsearch(host: str) -> Dict[str, Any]: def _query_elasticsearch(host: str) -> Dict[str, Any]:

View File

@ -1,6 +1,6 @@
import logging import logging
from contextlib import closing from contextlib import closing
from typing import Any, Dict, Iterable, Optional, Set, Tuple from typing import Any, Dict, Iterable, Mapping, Optional, Set, Tuple
from requests import head from requests import head
from requests.exceptions import ConnectionError, Timeout from requests.exceptions import ConnectionError, Timeout
@ -46,7 +46,7 @@ class HTTPFingerprinter(IFingerprinter):
"data": (server_header_contents, ssl), "data": (server_header_contents, ssl),
} }
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]: def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]:
@ -71,7 +71,7 @@ def _get_server_from_headers(url: str) -> Optional[str]:
return None return None
def _get_http_headers(url: str) -> Optional[Dict[str, Any]]: def _get_http_headers(url: str) -> Optional[Mapping[str, Any]]:
try: try:
logger.debug(f"Sending request for headers to {url}") logger.debug(f"Sending request for headers to {url}")
with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123 with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123

View File

@ -1,7 +1,7 @@
import errno import errno
import logging import logging
import socket import socket
from typing import Any, Dict, Optional from typing import Any, Dict
from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData
@ -32,10 +32,10 @@ class MSSQLFingerprinter(IFingerprinter):
except Exception as ex: except Exception as ex:
logger.debug(f"Did not detect an MSSQL server: {ex}") logger.debug(f"Did not detect an MSSQL server: {ex}")
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_mssql_for_instance_data(host: str) -> Optional[bytes]: def _query_mssql_for_instance_data(host: str) -> bytes:
# Create a UDP socket and sets a timeout # Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(_MSSQL_SOCKET_TIMEOUT) sock.settimeout(_MSSQL_SOCKET_TIMEOUT)
@ -44,22 +44,20 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
# The message is a CLNT_UCAST_EX packet to get all instances # The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx # https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03" message_str = "\x03"
# Encode the message as a bytes array # Encode the message as a bytes array
message = message.encode() message = message_str.encode()
# send data and receive response # send data and receive response
try: try:
logger.info(f"Sending message to requested host: {host}, {message}") logger.info(f"Sending message to requested host: {host}, {message_str}")
sock.sendto(message, server_address) sock.sendto(message, server_address)
data, _ = sock.recvfrom(_BUFFER_SIZE) data, _ = sock.recvfrom(_BUFFER_SIZE)
return data return data
except socket.timeout as err: except socket.timeout as err:
logger.debug( logger.debug(f"Socket timeout reached, maybe browser service on host: {host} doesn't exist")
f"Socket timeout reached, maybe browser service on host: {host} doesnt " "exist"
)
raise err raise err
except socket.error as err: except socket.error as err:
if err.errno == errno.ECONNRESET: if err.errno == errno.ECONNRESET:
@ -78,7 +76,7 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _get_services_from_server_data(data: bytes) -> Dict[str, Any]: def _get_services_from_server_data(data: bytes) -> Dict[str, Any]:
services = {MSSQL_SERVICE: {}} services: Dict[str, Any] = {MSSQL_SERVICE: {}}
services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME
services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT

View File

@ -26,10 +26,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces)) scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces))
if inaccessible_subnets: if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets( inaccessible_subnet_addresses = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces inaccessible_subnets, local_network_interfaces
) )
scan_targets.extend(inaccessible_subnets) scan_targets.extend(inaccessible_subnet_addresses)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces) scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -44,7 +44,7 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
for target in targets: for target in targets:
domain_name = target.domain domain_name = target.domain
ip = target.ip ip = target.ip
if ip not in reverse_dns or (reverse_dns[ip] is None and domain_name is not None): if ip not in reverse_dns or (not reverse_dns[ip] and domain_name):
reverse_dns[ip] = domain_name reverse_dns[ip] = domain_name
return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()] return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()]
@ -55,7 +55,7 @@ def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
if hasattr(range_obj, "domain_name"): if hasattr(range_obj, "domain_name"):
addresses.append(NetworkAddress(address, range_obj.domain_name)) addresses.append(NetworkAddress(address, range_obj.domain_name))
else: else:
addresses.append(NetworkAddress(address, None)) addresses.append(NetworkAddress(address, ""))
return addresses return addresses

View File

@ -144,16 +144,16 @@ class SMBFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
services = {} services: Dict = {}
smb_service = { smb_service = {
"display_name": DISPLAY_NAME, "display_name": DISPLAY_NAME,
"port": SMB_PORT, "port": SMB_PORT,
} }
os_type = None os_type = None
os_version = None os_version = ""
if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN): if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}") logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}")

View File

@ -21,7 +21,7 @@ class SSHFingerprinter(IFingerprinter):
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
os_type = None os_type = None
os_version = None os_version = ""
services = {} services = {}
for ps_data in port_scan_data.values(): for ps_data in port_scan_data.values():
@ -35,9 +35,9 @@ class SSHFingerprinter(IFingerprinter):
return FingerprintData(os_type, os_version, services) return FingerprintData(os_type, os_version, services)
@staticmethod @staticmethod
def _get_host_os(banner) -> Tuple[Optional[str], Optional[str]]: def _get_host_os(banner) -> Tuple[Optional[OperatingSystem], str]:
os = None os = None
os_version = None os_version = ""
for dist in LINUX_DIST_SSH: for dist in LINUX_DIST_SSH:
if banner.lower().find(dist) != -1: if banner.lower().find(dist) != -1:
os_version = banner.split(" ").pop().strip() os_version = banner.split(" ").pop().strip()

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict, Iterable, List, Sequence from typing import Dict, Iterable, List, Mapping, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT from common.common_consts.timeouts import CONNECTION_TIMEOUT
from common.credentials import Credentials from common.credentials import Credentials
@ -8,6 +8,7 @@ from infection_monkey import network_scanning
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
FingerprintData, FingerprintData,
IFingerprinter,
IPuppet, IPuppet,
PingScanData, PingScanData,
PluginType, PluginType,
@ -18,7 +19,7 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None) EMPTY_FINGERPRINT = FingerprintData(None, "", {})
logger = logging.getLogger() logger = logging.getLogger()
@ -45,7 +46,7 @@ class Puppet(IPuppet):
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]: ) -> Mapping[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout) return network_scanning.scan_tcp_ports(host, ports, timeout)
def fingerprint( def fingerprint(
@ -57,7 +58,9 @@ class Puppet(IPuppet):
options: Dict, options: Dict,
) -> FingerprintData: ) -> FingerprintData:
try: try:
fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER) fingerprinter: IFingerprinter = self._plugin_registry.get_plugin(
name, PluginType.FINGERPRINTER
)
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options) return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options)
except Exception: except Exception:
logger.exception( logger.exception(

View File

@ -154,7 +154,7 @@ def assert_fingerprint_results_no_3(fingerprint_data):
def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert address.ip not in {"10.0.0.1", "10.0.0.3"} assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None assert not address.domain
assert ping_scan_data.response_received is False assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -178,9 +178,9 @@ def test_scan_single_ip(callback, scan_config, stop):
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
addresses = [ addresses = [
NetworkAddress("10.0.0.1", "d1"), NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", "d3"), NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
@ -203,7 +203,7 @@ def test_scan_multiple_ips(callback, scan_config, stop):
@pytest.mark.slow @pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)] addresses = [NetworkAddress(f"10.0.0.{i}", "") for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(addresses, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
@ -223,10 +223,10 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback) stoppable_callback = MagicMock(side_effect=_callback)
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(MockPuppet(), num_workers=2) ns = IPScanner(MockPuppet(), num_workers=2)
@ -251,10 +251,10 @@ def test_interrupt_before_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock() puppet.fingerprint = MagicMock()
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
@ -270,7 +270,7 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
stoppable_fingerprint.barrier.wait() stoppable_fingerprint.barrier.wait()
stop.set() stop.set()
return FingerprintData(None, None, {}) return FingerprintData(None, "", {})
stoppable_fingerprint.barrier = Barrier(2) stoppable_fingerprint.barrier = Barrier(2)
@ -278,10 +278,10 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint) puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)

View File

@ -38,7 +38,7 @@ def test_successful(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
es_service = fingerprint_data.services[ES_SERVICE] es_service = fingerprint_data.services[ES_SERVICE]
@ -60,7 +60,7 @@ def test_fingerprinting_skipped_if_port_closed(monkeypatch, fingerprinter, port_
assert not mock_query_elasticsearch.called assert not mock_query_elasticsearch.called
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -82,5 +82,5 @@ def test_no_response_from_server(monkeypatch, fingerprinter, mock_query_function
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0

View File

@ -63,7 +63,7 @@ def test_fingerprint_only_port_443(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_called_with("https://127.0.0.1:443") mock_get_http_headers.assert_called_with("https://127.0.0.1:443")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"] assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -86,7 +86,7 @@ def test_open_port_no_http_server(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:9200") mock_get_http_headers.assert_any_call("http://127.0.0.1:9200")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -106,7 +106,7 @@ def test_multiple_open_ports(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:8080") mock_get_http_headers.assert_any_call("http://127.0.0.1:8080")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 2 assert len(fingerprint_data.services.keys()) == 2
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"] assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -126,7 +126,7 @@ def test_server_missing_from_http_headers(mock_get_http_headers, http_fingerprin
assert mock_get_http_headers.call_count == 2 assert mock_get_http_headers.call_count == 2
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-1080"]["data"][0] == "" assert fingerprint_data.services["tcp-1080"]["data"][0] == ""

View File

@ -45,7 +45,7 @@ def test_mssql_fingerprint_successful(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
# Each mssql instance is under his name # Each mssql instance is under his name
@ -78,7 +78,7 @@ def test_mssql_no_response_from_server(monkeypatch, fingerprinter, mock_query_fu
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -98,5 +98,5 @@ def test_mssql_wrong_response_from_server(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0

View File

@ -24,7 +24,7 @@ def test_single_subnet():
assert len(scan_targets) == 255 assert len(scan_targets) == 255
for i in range(0, 255): for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
@pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"]) @pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"])
@ -33,8 +33,8 @@ def test_single_ip(single_ip):
scan_targets = compile_ranges_only([single_ip]) scan_targets = compile_ranges_only([single_ip])
assert len(scan_targets) == 1 assert len(scan_targets) == 1
assert NetworkAddress("10.0.0.2", None) in scan_targets assert NetworkAddress("10.0.0.2", "") in scan_targets
assert NetworkAddress("10.0.0.2", None) == scan_targets[0] assert NetworkAddress("10.0.0.2", "") == scan_targets[0]
def test_multiple_subnet(): def test_multiple_subnet():
@ -43,10 +43,10 @@ def test_multiple_subnet():
assert len(scan_targets) == 262 assert len(scan_targets) == 262
for i in range(0, 255): for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
for i in range(8, 15): for i in range(8, 15):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_middle_of_range_subnet(): def test_middle_of_range_subnet():
@ -55,7 +55,7 @@ def test_middle_of_range_subnet():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -68,7 +68,7 @@ def test_ip_range(ip_range):
assert len(scan_targets) == 9 assert len(scan_targets) == 9
for i in range(25, 34): for i in range(25, 34):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_no_duplicates(): def test_no_duplicates():
@ -77,7 +77,7 @@ def test_no_duplicates():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_blocklisted_ips(): def test_blocklisted_ips():
@ -146,7 +146,7 @@ def test_no_redundant_targets():
) )
assert len(scan_targets) == 2 assert len(scan_targets) == 2
assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets assert NetworkAddress(ip="127.0.0.0", domain="") in scan_targets
assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets
@ -212,7 +212,7 @@ def test_local_subnet_added():
assert len(scan_targets) == 254 assert len(scan_targets) == 254
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
def test_multiple_local_subnets_added(): def test_multiple_local_subnets_added():
@ -232,10 +232,10 @@ def test_multiple_local_subnets_added():
assert len(scan_targets) == 2 * (255 - 1) assert len(scan_targets) == 2 * (255 - 1)
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in chain(range(0, 99), range(100, 255)): for ip in chain(range(0, 99), range(100, 255)):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_blocklisted_ips_missing_from_local_subnets(): def test_blocklisted_ips_missing_from_local_subnets():
@ -273,12 +273,12 @@ def test_local_subnets_and_ranges_added():
assert len(scan_targets) == 254 + 3 assert len(scan_targets) == 254 + 3
for ip in range(0, 5): for ip in range(0, 5):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(6, 255): for ip in range(6, 255):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(40, 43): for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_specified_but_disabled(): def test_local_network_interfaces_specified_but_disabled():
@ -295,7 +295,7 @@ def test_local_network_interfaces_specified_but_disabled():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in range(40, 43): for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_subnet_masks(): def test_local_network_interfaces_subnet_masks():
@ -315,7 +315,7 @@ def test_local_network_interfaces_subnet_masks():
assert len(scan_targets) == 4 assert len(scan_targets) == 4
for ip in [108, 110, 145, 146]: for ip in [108, 110, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_targets(): def test_segmentation_targets():
@ -334,7 +334,7 @@ def test_segmentation_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [144, 145, 146]: for ip in [144, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_clash_with_blocked(): def test_segmentation_clash_with_blocked():
@ -377,7 +377,7 @@ def test_segmentation_clash_with_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_one_network(): def test_segmentation_one_network():
@ -443,7 +443,7 @@ def test_invalid_inputs():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_invalid_blocklisted_ip(): def test_invalid_blocklisted_ip():

View File

@ -18,7 +18,7 @@ def test_no_ssh_ports_open(ssh_fingerprinter):
} }
results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None) results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None)
assert results == FingerprintData(None, None, {}) assert results == FingerprintData(None, "", {})
def test_no_os(ssh_fingerprinter): def test_no_os(ssh_fingerprinter):
@ -32,7 +32,7 @@ def test_no_os(ssh_fingerprinter):
assert results == FingerprintData( assert results == FingerprintData(
None, None,
None, "",
{ {
"tcp-22": { "tcp-22": {
"display_name": "SSH", "display_name": "SSH",