Agent: Fix FingerprintData mypy issues

This commit is contained in:
Kekoa Kaaikala 2022-09-21 20:38:19 +00:00
parent e40d061091
commit 9d37a38994
15 changed files with 87 additions and 85 deletions

View File

@ -193,7 +193,7 @@ class SingleIpRange(NetworkRange):
:return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com)
"""
# The most common use case is to enter ip/range into "Scan IP/subnet list"
domain_name = None
domain_name = ""
if " " in string_:
raise ValueError(f'"{string_}" is not a valid IP address or domain name.')

View File

@ -2,7 +2,7 @@ import logging
from ipaddress import IPv4Interface
from queue import Queue
from threading import Event
from typing import List, Sequence
from typing import Dict, List, Sequence
from common.agent_configuration import (
ExploitationConfiguration,
@ -154,11 +154,12 @@ class Propagator:
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData):
def _process_fingerprinter_results(
victim_host: VictimHost, fingerprint_data: Dict[str, FingerprintData]
):
for fd in fingerprint_data.values():
# TODO: This logic preserves the existing behavior prior to introducing IMaster and
# IPuppet, but it is possibly flawed. Different fingerprinters may detect
@ -167,7 +168,7 @@ class Propagator:
if fd.os_type is not None:
victim_host.os["type"] = fd.os_type
if ("version" not in victim_host.os) and (fd.os_version is not None):
if ("version" not in victim_host.os) and (fd.os_version):
victim_host.os["version"] = fd.os_version
for service, details in fd.services.items():

View File

@ -31,10 +31,10 @@ class ElasticSearchFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData],
_options: Dict,
) -> FingerprintData:
services = {}
services: Dict[str, Any] = {}
if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
try:
elasticsearch_info = _query_elasticsearch(host)
@ -42,7 +42,7 @@ class ElasticSearchFingerprinter(IFingerprinter):
except Exception as ex:
logger.debug(f"Did not detect an ElasticSearch cluster: {ex}")
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_elasticsearch(host: str) -> Dict[str, Any]:

View File

@ -1,6 +1,6 @@
import logging
from contextlib import closing
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import Any, Dict, Iterable, Mapping, Optional, Set, Tuple
from requests import head
from requests.exceptions import ConnectionError, Timeout
@ -46,7 +46,7 @@ class HTTPFingerprinter(IFingerprinter):
"data": (server_header_contents, ssl),
}
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]:
@ -71,7 +71,7 @@ def _get_server_from_headers(url: str) -> Optional[str]:
return None
def _get_http_headers(url: str) -> Optional[Dict[str, Any]]:
def _get_http_headers(url: str) -> Optional[Mapping[str, Any]]:
try:
logger.debug(f"Sending request for headers to {url}")
with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123

View File

@ -1,7 +1,7 @@
import errno
import logging
import socket
from typing import Any, Dict, Optional
from typing import Any, Dict
from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData
@ -32,10 +32,10 @@ class MSSQLFingerprinter(IFingerprinter):
except Exception as ex:
logger.debug(f"Did not detect an MSSQL server: {ex}")
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _query_mssql_for_instance_data(host: str) -> bytes:
# Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(_MSSQL_SOCKET_TIMEOUT)
@ -44,22 +44,20 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
# The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03"
message_str = "\x03"
# Encode the message as a bytes array
message = message.encode()
message = message_str.encode()
# send data and receive response
try:
logger.info(f"Sending message to requested host: {host}, {message}")
logger.info(f"Sending message to requested host: {host}, {message_str}")
sock.sendto(message, server_address)
data, _ = sock.recvfrom(_BUFFER_SIZE)
return data
except socket.timeout as err:
logger.debug(
f"Socket timeout reached, maybe browser service on host: {host} doesnt " "exist"
)
logger.debug(f"Socket timeout reached, maybe browser service on host: {host} doesn't exist")
raise err
except socket.error as err:
if err.errno == errno.ECONNRESET:
@ -78,7 +76,7 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _get_services_from_server_data(data: bytes) -> Dict[str, Any]:
services = {MSSQL_SERVICE: {}}
services: Dict[str, Any] = {MSSQL_SERVICE: {}}
services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME
services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT

View File

@ -26,10 +26,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces))
if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets(
inaccessible_subnet_addresses = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces
)
scan_targets.extend(inaccessible_subnets)
scan_targets.extend(inaccessible_subnet_addresses)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -44,7 +44,7 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
for target in targets:
domain_name = target.domain
ip = target.ip
if ip not in reverse_dns or (reverse_dns[ip] is None and domain_name is not None):
if ip not in reverse_dns or (not reverse_dns[ip] and domain_name):
reverse_dns[ip] = domain_name
return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()]
@ -55,7 +55,7 @@ def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
if hasattr(range_obj, "domain_name"):
addresses.append(NetworkAddress(address, range_obj.domain_name))
else:
addresses.append(NetworkAddress(address, None))
addresses.append(NetworkAddress(address, ""))
return addresses

View File

@ -144,16 +144,16 @@ class SMBFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData],
_options: Dict,
) -> FingerprintData:
services = {}
services: Dict = {}
smb_service = {
"display_name": DISPLAY_NAME,
"port": SMB_PORT,
}
os_type = None
os_version = None
os_version = ""
if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}")

View File

@ -21,7 +21,7 @@ class SSHFingerprinter(IFingerprinter):
_options: Dict,
) -> FingerprintData:
os_type = None
os_version = None
os_version = ""
services = {}
for ps_data in port_scan_data.values():
@ -35,9 +35,9 @@ class SSHFingerprinter(IFingerprinter):
return FingerprintData(os_type, os_version, services)
@staticmethod
def _get_host_os(banner) -> Tuple[Optional[str], Optional[str]]:
def _get_host_os(banner) -> Tuple[Optional[OperatingSystem], str]:
os = None
os_version = None
os_version = ""
for dist in LINUX_DIST_SSH:
if banner.lower().find(dist) != -1:
os_version = banner.split(" ").pop().strip()

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict, Iterable, List, Sequence
from typing import Dict, Iterable, List, Mapping, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT
from common.credentials import Credentials
@ -8,6 +8,7 @@ from infection_monkey import network_scanning
from infection_monkey.i_puppet import (
ExploiterResultData,
FingerprintData,
IFingerprinter,
IPuppet,
PingScanData,
PluginType,
@ -18,7 +19,7 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None)
EMPTY_FINGERPRINT = FingerprintData(None, "", {})
logger = logging.getLogger()
@ -45,7 +46,7 @@ class Puppet(IPuppet):
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]:
) -> Mapping[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout)
def fingerprint(
@ -57,7 +58,9 @@ class Puppet(IPuppet):
options: Dict,
) -> FingerprintData:
try:
fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER)
fingerprinter: IFingerprinter = self._plugin_registry.get_plugin(
name, PluginType.FINGERPRINTER
)
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options)
except Exception:
logger.exception(

View File

@ -154,7 +154,7 @@ def assert_fingerprint_results_no_3(fingerprint_data):
def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None
assert not address.domain
assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6
@ -178,9 +178,9 @@ def test_scan_single_ip(callback, scan_config, stop):
def test_scan_multiple_ips(callback, scan_config, stop):
addresses = [
NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(MockPuppet(), num_workers=4)
@ -203,7 +203,7 @@ def test_scan_multiple_ips(callback, scan_config, stop):
@pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop):
addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)]
addresses = [NetworkAddress(f"10.0.0.{i}", "") for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(addresses, scan_config, callback, stop)
@ -223,10 +223,10 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback)
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(MockPuppet(), num_workers=2)
@ -251,10 +251,10 @@ def test_interrupt_before_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock()
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(puppet, num_workers=2)
@ -270,7 +270,7 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
stoppable_fingerprint.barrier.wait()
stop.set()
return FingerprintData(None, None, {})
return FingerprintData(None, "", {})
stoppable_fingerprint.barrier = Barrier(2)
@ -278,10 +278,10 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(puppet, num_workers=2)

View File

@ -38,7 +38,7 @@ def test_successful(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
es_service = fingerprint_data.services[ES_SERVICE]
@ -60,7 +60,7 @@ def test_fingerprinting_skipped_if_port_closed(monkeypatch, fingerprinter, port_
assert not mock_query_elasticsearch.called
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -82,5 +82,5 @@ def test_no_response_from_server(monkeypatch, fingerprinter, mock_query_function
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0

View File

@ -63,7 +63,7 @@ def test_fingerprint_only_port_443(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_called_with("https://127.0.0.1:443")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -86,7 +86,7 @@ def test_open_port_no_http_server(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:9200")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -106,7 +106,7 @@ def test_multiple_open_ports(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:8080")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 2
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -126,7 +126,7 @@ def test_server_missing_from_http_headers(mock_get_http_headers, http_fingerprin
assert mock_get_http_headers.call_count == 2
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-1080"]["data"][0] == ""

View File

@ -45,7 +45,7 @@ def test_mssql_fingerprint_successful(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
# Each mssql instance is under his name
@ -78,7 +78,7 @@ def test_mssql_no_response_from_server(monkeypatch, fingerprinter, mock_query_fu
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -98,5 +98,5 @@ def test_mssql_wrong_response_from_server(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0

View File

@ -24,7 +24,7 @@ def test_single_subnet():
assert len(scan_targets) == 255
for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
@pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"])
@ -33,8 +33,8 @@ def test_single_ip(single_ip):
scan_targets = compile_ranges_only([single_ip])
assert len(scan_targets) == 1
assert NetworkAddress("10.0.0.2", None) in scan_targets
assert NetworkAddress("10.0.0.2", None) == scan_targets[0]
assert NetworkAddress("10.0.0.2", "") in scan_targets
assert NetworkAddress("10.0.0.2", "") == scan_targets[0]
def test_multiple_subnet():
@ -43,10 +43,10 @@ def test_multiple_subnet():
assert len(scan_targets) == 262
for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
for i in range(8, 15):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_middle_of_range_subnet():
@ -55,7 +55,7 @@ def test_middle_of_range_subnet():
assert len(scan_targets) == 7
for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
@pytest.mark.parametrize(
@ -68,7 +68,7 @@ def test_ip_range(ip_range):
assert len(scan_targets) == 9
for i in range(25, 34):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_no_duplicates():
@ -77,7 +77,7 @@ def test_no_duplicates():
assert len(scan_targets) == 7
for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_blocklisted_ips():
@ -146,7 +146,7 @@ def test_no_redundant_targets():
)
assert len(scan_targets) == 2
assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets
assert NetworkAddress(ip="127.0.0.0", domain="") in scan_targets
assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets
@ -212,7 +212,7 @@ def test_local_subnet_added():
assert len(scan_targets) == 254
for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
def test_multiple_local_subnets_added():
@ -232,10 +232,10 @@ def test_multiple_local_subnets_added():
assert len(scan_targets) == 2 * (255 - 1)
for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in chain(range(0, 99), range(100, 255)):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_blocklisted_ips_missing_from_local_subnets():
@ -273,12 +273,12 @@ def test_local_subnets_and_ranges_added():
assert len(scan_targets) == 254 + 3
for ip in range(0, 5):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(6, 255):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_specified_but_disabled():
@ -295,7 +295,7 @@ def test_local_network_interfaces_specified_but_disabled():
assert len(scan_targets) == 3
for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_subnet_masks():
@ -315,7 +315,7 @@ def test_local_network_interfaces_subnet_masks():
assert len(scan_targets) == 4
for ip in [108, 110, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_targets():
@ -334,7 +334,7 @@ def test_segmentation_targets():
assert len(scan_targets) == 3
for ip in [144, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_clash_with_blocked():
@ -377,7 +377,7 @@ def test_segmentation_clash_with_targets():
assert len(scan_targets) == 3
for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_one_network():
@ -443,7 +443,7 @@ def test_invalid_inputs():
assert len(scan_targets) == 3
for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_invalid_blocklisted_ip():

View File

@ -18,7 +18,7 @@ def test_no_ssh_ports_open(ssh_fingerprinter):
}
results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None)
assert results == FingerprintData(None, None, {})
assert results == FingerprintData(None, "", {})
def test_no_os(ssh_fingerprinter):
@ -32,7 +32,7 @@ def test_no_os(ssh_fingerprinter):
assert results == FingerprintData(
None,
None,
"",
{
"tcp-22": {
"display_name": "SSH",