Agent: Fix mypy issues in scan_target_generator.py

This commit is contained in:
Kekoa Kaaikala 2022-09-26 20:12:57 +00:00
parent 1bf610a4a8
commit 311c294033
3 changed files with 28 additions and 24 deletions

View File

@ -4,7 +4,7 @@ import random
import socket import socket
import struct import struct
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import List, Tuple from typing import Iterable, List, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,7 +58,7 @@ class NetworkRange(object, metaclass=ABCMeta):
return SingleIpRange(ip_address=address_str) return SingleIpRange(ip_address=address_str)
@staticmethod @staticmethod
def filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]: def filter_invalid_ranges(ranges: Iterable[str], error_msg: str) -> List[str]:
valid_ranges = [] valid_ranges = []
for target_range in ranges: for target_range in ranges:
try: try:

View File

@ -4,7 +4,7 @@ import struct
from dataclasses import dataclass from dataclasses import dataclass
from random import shuffle # noqa: DUO102 from random import shuffle # noqa: DUO102
from threading import Lock from threading import Lock
from typing import Dict, Set from typing import Dict, Optional, Set
import netifaces import netifaces
import psutil import psutil
@ -25,7 +25,7 @@ RTF_REJECT = 0x0200
@dataclass @dataclass
class NetworkAddress: class NetworkAddress:
ip: str ip: str
domain: str domain: Optional[str]
def get_host_subnets(): def get_host_subnets():

View File

@ -2,15 +2,19 @@ import itertools
import logging import logging
import socket import socket
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from typing import Dict, List, Sequence from typing import Dict, Iterable, List, Optional, Sequence
from typing_extensions import Protocol, runtime_checkable
from common.network.network_range import InvalidNetworkRangeError, NetworkRange from common.network.network_range import InvalidNetworkRangeError, NetworkRange
from infection_monkey.network import NetworkAddress from infection_monkey.network import NetworkAddress
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: We can probably reduce code and save ourselves some trouble if we use IPv4Address and
# IPv4Network. See https://docs.python.org/3/library/ipaddress.html @runtime_checkable
class HasDomain(Protocol):
domain_name: str
def compile_scan_target_list( def compile_scan_target_list(
@ -26,10 +30,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces)) scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces))
if inaccessible_subnets: if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets( other_targets = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces inaccessible_subnets, local_network_interfaces
) )
scan_targets.extend(inaccessible_subnets) scan_targets.extend(other_targets)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces) scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -39,8 +43,8 @@ def compile_scan_target_list(
return scan_targets return scan_targets
def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]: def _remove_redundant_targets(targets: Sequence[NetworkAddress]) -> List[NetworkAddress]:
reverse_dns: Dict[str, str] = {} reverse_dns: Dict[str, Optional[str]] = {}
for target in targets: for target in targets:
domain_name = target.domain domain_name = target.domain
ip = target.ip ip = target.ip
@ -52,14 +56,14 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]: def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
addresses = [] addresses = []
for address in range_obj: for address in range_obj:
if hasattr(range_obj, "domain_name"): if isinstance(range_obj, HasDomain):
addresses.append(NetworkAddress(address, range_obj.domain_name)) addresses.append(NetworkAddress(address, range_obj.domain_name))
else: else:
addresses.append(NetworkAddress(address, None)) addresses.append(NetworkAddress(address, None))
return addresses return addresses
def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAddress]: def _get_ips_from_subnets_to_scan(subnets_to_scan: Iterable[str]) -> List[NetworkAddress]:
ranges_to_scan = NetworkRange.filter_invalid_ranges( ranges_to_scan = NetworkRange.filter_invalid_ranges(
subnets_to_scan, "Bad network range input for targets to scan:" subnets_to_scan, "Bad network range input for targets to scan:"
) )
@ -68,7 +72,7 @@ def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAdd
return _get_ips_from_ranges_to_scan(network_ranges) return _get_ips_from_ranges_to_scan(network_ranges)
def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[NetworkAddress]: def _get_ips_from_ranges_to_scan(network_ranges: Iterable[NetworkRange]) -> List[NetworkAddress]:
scan_targets = [] scan_targets = []
for _range in network_ranges: for _range in network_ranges:
@ -77,8 +81,8 @@ def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[Net
def _get_ips_to_scan_from_interface( def _get_ips_to_scan_from_interface(
interfaces: List[IPv4Interface], interfaces: Sequence[IPv4Interface],
) -> List[NetworkAddress]: ) -> Sequence[NetworkAddress]:
ranges = [str(interface) for interface in interfaces] ranges = [str(interface) for interface in interfaces]
ranges = NetworkRange.filter_invalid_ranges( ranges = NetworkRange.filter_invalid_ranges(
@ -88,14 +92,14 @@ def _get_ips_to_scan_from_interface(
def _remove_interface_ips( def _remove_interface_ips(
scan_targets: List[NetworkAddress], interfaces: List[IPv4Interface] scan_targets: Sequence[NetworkAddress], interfaces: Iterable[IPv4Interface]
) -> List[NetworkAddress]: ) -> List[NetworkAddress]:
interface_ips = [str(interface.ip) for interface in interfaces] interface_ips = [str(interface.ip) for interface in interfaces]
return _remove_ips_from_scan_targets(scan_targets, interface_ips) return _remove_ips_from_scan_targets(scan_targets, interface_ips)
def _remove_blocklisted_ips( def _remove_blocklisted_ips(
scan_targets: List[NetworkAddress], blocked_ips: List[str] scan_targets: Sequence[NetworkAddress], blocked_ips: Sequence[str]
) -> List[NetworkAddress]: ) -> List[NetworkAddress]:
filtered_blocked_ips = NetworkRange.filter_invalid_ranges( filtered_blocked_ips = NetworkRange.filter_invalid_ranges(
blocked_ips, "Invalid blocked IP provided:" blocked_ips, "Invalid blocked IP provided:"
@ -106,15 +110,15 @@ def _remove_blocklisted_ips(
def _remove_ips_from_scan_targets( def _remove_ips_from_scan_targets(
scan_targets: List[NetworkAddress], ips_to_remove: List[str] scan_targets: Sequence[NetworkAddress], ips_to_remove: Iterable[str]
) -> List[NetworkAddress]: ) -> List[NetworkAddress]:
ips_to_remove_set = set(ips_to_remove) ips_to_remove_set = set(ips_to_remove)
return [address for address in scan_targets if address.ip not in ips_to_remove_set] return [address for address in scan_targets if address.ip not in ips_to_remove_set]
def _get_segmentation_check_targets( def _get_segmentation_check_targets(
inaccessible_subnets: List[str], local_interfaces: List[IPv4Interface] inaccessible_subnets: Iterable[str], local_interfaces: Iterable[IPv4Interface]
) -> List[NetworkAddress]: ) -> Sequence[NetworkAddress]:
ips_to_scan = [] ips_to_scan = []
local_ips = [str(interface.ip) for interface in local_interfaces] local_ips = [str(interface.ip) for interface in local_interfaces]
@ -134,17 +138,17 @@ def _get_segmentation_check_targets(
return ips_to_scan return ips_to_scan
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]: def _convert_to_range_object(subnets: Iterable[str]) -> List[NetworkRange]:
return [NetworkRange.get_range_obj(subnet) for subnet in subnets] return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
def _is_segmentation_check_required( def _is_segmentation_check_required(
local_ips: List[str], subnet1: NetworkRange, subnet2: NetworkRange local_ips: Sequence[str], subnet1: NetworkRange, subnet2: NetworkRange
): ):
return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2) return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2)
def _is_any_ip_in_subnet(ip_addresses: List[str], subnet: NetworkRange): def _is_any_ip_in_subnet(ip_addresses: Iterable[str], subnet: NetworkRange):
for ip_address in ip_addresses: for ip_address in ip_addresses:
if subnet.is_in_range(ip_address): if subnet.is_in_range(ip_address):
return True return True