diff --git a/monkey/common/network/network_range.py b/monkey/common/network/network_range.py index 1dfd46aa8..df32801e4 100644 --- a/monkey/common/network/network_range.py +++ b/monkey/common/network/network_range.py @@ -4,7 +4,7 @@ import random import socket import struct from abc import ABCMeta, abstractmethod -from typing import List, Tuple +from typing import Iterable, List, Tuple logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ class NetworkRange(object, metaclass=ABCMeta): return SingleIpRange(ip_address=address_str) @staticmethod - def filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]: + def filter_invalid_ranges(ranges: Iterable[str], error_msg: str) -> List[str]: valid_ranges = [] for target_range in ranges: try: diff --git a/monkey/infection_monkey/network/info.py b/monkey/infection_monkey/network/info.py index 8cc12038d..916c642c4 100644 --- a/monkey/infection_monkey/network/info.py +++ b/monkey/infection_monkey/network/info.py @@ -4,7 +4,7 @@ import struct from dataclasses import dataclass from random import shuffle # noqa: DUO102 from threading import Lock -from typing import Dict, Set +from typing import Dict, Optional, Set import netifaces import psutil @@ -25,7 +25,7 @@ RTF_REJECT = 0x0200 @dataclass class NetworkAddress: ip: str - domain: str + domain: Optional[str] def get_host_subnets(): diff --git a/monkey/infection_monkey/network_scanning/scan_target_generator.py b/monkey/infection_monkey/network_scanning/scan_target_generator.py index d561dfd69..fa4034792 100644 --- a/monkey/infection_monkey/network_scanning/scan_target_generator.py +++ b/monkey/infection_monkey/network_scanning/scan_target_generator.py @@ -2,15 +2,19 @@ import itertools import logging import socket from ipaddress import IPv4Interface -from typing import Dict, List, Sequence +from typing import Dict, Iterable, List, Optional, Sequence + +from typing_extensions import Protocol, runtime_checkable from common.network.network_range import InvalidNetworkRangeError, NetworkRange from infection_monkey.network import NetworkAddress logger = logging.getLogger(__name__) -# TODO: We can probably reduce code and save ourselves some trouble if we use IPv4Address and -# IPv4Network. See https://docs.python.org/3/library/ipaddress.html + +@runtime_checkable +class HasDomain(Protocol): + domain_name: str def compile_scan_target_list( @@ -26,10 +30,10 @@ def compile_scan_target_list( scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces)) if inaccessible_subnets: - inaccessible_subnets = _get_segmentation_check_targets( + other_targets = _get_segmentation_check_targets( inaccessible_subnets, local_network_interfaces ) - scan_targets.extend(inaccessible_subnets) + scan_targets.extend(other_targets) scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces) scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips) @@ -39,8 +43,8 @@ def compile_scan_target_list( return scan_targets -def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]: - reverse_dns: Dict[str, str] = {} +def _remove_redundant_targets(targets: Sequence[NetworkAddress]) -> List[NetworkAddress]: + reverse_dns: Dict[str, Optional[str]] = {} for target in targets: domain_name = target.domain ip = target.ip @@ -52,14 +56,14 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]: addresses = [] for address in range_obj: - if hasattr(range_obj, "domain_name"): + if isinstance(range_obj, HasDomain): addresses.append(NetworkAddress(address, range_obj.domain_name)) else: addresses.append(NetworkAddress(address, None)) return addresses -def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAddress]: +def _get_ips_from_subnets_to_scan(subnets_to_scan: Iterable[str]) -> List[NetworkAddress]: ranges_to_scan = NetworkRange.filter_invalid_ranges( subnets_to_scan, "Bad network range input for targets to scan:" ) @@ -68,7 +72,7 @@ def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAdd return _get_ips_from_ranges_to_scan(network_ranges) -def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[NetworkAddress]: +def _get_ips_from_ranges_to_scan(network_ranges: Iterable[NetworkRange]) -> List[NetworkAddress]: scan_targets = [] for _range in network_ranges: @@ -77,8 +81,8 @@ def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[Net def _get_ips_to_scan_from_interface( - interfaces: List[IPv4Interface], -) -> List[NetworkAddress]: + interfaces: Sequence[IPv4Interface], +) -> Sequence[NetworkAddress]: ranges = [str(interface) for interface in interfaces] ranges = NetworkRange.filter_invalid_ranges( @@ -88,14 +92,14 @@ def _get_ips_to_scan_from_interface( def _remove_interface_ips( - scan_targets: List[NetworkAddress], interfaces: List[IPv4Interface] + scan_targets: Sequence[NetworkAddress], interfaces: Iterable[IPv4Interface] ) -> List[NetworkAddress]: interface_ips = [str(interface.ip) for interface in interfaces] return _remove_ips_from_scan_targets(scan_targets, interface_ips) def _remove_blocklisted_ips( - scan_targets: List[NetworkAddress], blocked_ips: List[str] + scan_targets: Sequence[NetworkAddress], blocked_ips: Sequence[str] ) -> List[NetworkAddress]: filtered_blocked_ips = NetworkRange.filter_invalid_ranges( blocked_ips, "Invalid blocked IP provided:" @@ -106,15 +110,15 @@ def _remove_blocklisted_ips( def _remove_ips_from_scan_targets( - scan_targets: List[NetworkAddress], ips_to_remove: List[str] + scan_targets: Sequence[NetworkAddress], ips_to_remove: Iterable[str] ) -> List[NetworkAddress]: ips_to_remove_set = set(ips_to_remove) return [address for address in scan_targets if address.ip not in ips_to_remove_set] def _get_segmentation_check_targets( - inaccessible_subnets: List[str], local_interfaces: List[IPv4Interface] -) -> List[NetworkAddress]: + inaccessible_subnets: Iterable[str], local_interfaces: Iterable[IPv4Interface] +) -> Sequence[NetworkAddress]: ips_to_scan = [] local_ips = [str(interface.ip) for interface in local_interfaces] @@ -134,17 +138,17 @@ def _get_segmentation_check_targets( return ips_to_scan -def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]: +def _convert_to_range_object(subnets: Iterable[str]) -> List[NetworkRange]: return [NetworkRange.get_range_obj(subnet) for subnet in subnets] def _is_segmentation_check_required( - local_ips: List[str], subnet1: NetworkRange, subnet2: NetworkRange + local_ips: Sequence[str], subnet1: NetworkRange, subnet2: NetworkRange ): return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2) -def _is_any_ip_in_subnet(ip_addresses: List[str], subnet: NetworkRange): +def _is_any_ip_in_subnet(ip_addresses: Iterable[str], subnet: NetworkRange): for ip_address in ip_addresses: if subnet.is_in_range(ip_address): return True