Agent: Fix mypy issues in scan_target_generator.py
This commit is contained in:
parent
1bf610a4a8
commit
311c294033
|
@ -4,7 +4,7 @@ import random
|
|||
import socket
|
||||
import struct
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ class NetworkRange(object, metaclass=ABCMeta):
|
|||
return SingleIpRange(ip_address=address_str)
|
||||
|
||||
@staticmethod
|
||||
def filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]:
|
||||
def filter_invalid_ranges(ranges: Iterable[str], error_msg: str) -> List[str]:
|
||||
valid_ranges = []
|
||||
for target_range in ranges:
|
||||
try:
|
||||
|
|
|
@ -4,7 +4,7 @@ import struct
|
|||
from dataclasses import dataclass
|
||||
from random import shuffle # noqa: DUO102
|
||||
from threading import Lock
|
||||
from typing import Dict, Set
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import netifaces
|
||||
import psutil
|
||||
|
@ -25,7 +25,7 @@ RTF_REJECT = 0x0200
|
|||
@dataclass
|
||||
class NetworkAddress:
|
||||
ip: str
|
||||
domain: str
|
||||
domain: Optional[str]
|
||||
|
||||
|
||||
def get_host_subnets():
|
||||
|
|
|
@ -2,15 +2,19 @@ import itertools
|
|||
import logging
|
||||
import socket
|
||||
from ipaddress import IPv4Interface
|
||||
from typing import Dict, List, Sequence
|
||||
from typing import Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
|
||||
from infection_monkey.network import NetworkAddress
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: We can probably reduce code and save ourselves some trouble if we use IPv4Address and
|
||||
# IPv4Network. See https://docs.python.org/3/library/ipaddress.html
|
||||
|
||||
@runtime_checkable
|
||||
class HasDomain(Protocol):
|
||||
domain_name: str
|
||||
|
||||
|
||||
def compile_scan_target_list(
|
||||
|
@ -26,10 +30,10 @@ def compile_scan_target_list(
|
|||
scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces))
|
||||
|
||||
if inaccessible_subnets:
|
||||
inaccessible_subnets = _get_segmentation_check_targets(
|
||||
other_targets = _get_segmentation_check_targets(
|
||||
inaccessible_subnets, local_network_interfaces
|
||||
)
|
||||
scan_targets.extend(inaccessible_subnets)
|
||||
scan_targets.extend(other_targets)
|
||||
|
||||
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
|
||||
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
|
||||
|
@ -39,8 +43,8 @@ def compile_scan_target_list(
|
|||
return scan_targets
|
||||
|
||||
|
||||
def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]:
|
||||
reverse_dns: Dict[str, str] = {}
|
||||
def _remove_redundant_targets(targets: Sequence[NetworkAddress]) -> List[NetworkAddress]:
|
||||
reverse_dns: Dict[str, Optional[str]] = {}
|
||||
for target in targets:
|
||||
domain_name = target.domain
|
||||
ip = target.ip
|
||||
|
@ -52,14 +56,14 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
|
|||
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
|
||||
addresses = []
|
||||
for address in range_obj:
|
||||
if hasattr(range_obj, "domain_name"):
|
||||
if isinstance(range_obj, HasDomain):
|
||||
addresses.append(NetworkAddress(address, range_obj.domain_name))
|
||||
else:
|
||||
addresses.append(NetworkAddress(address, None))
|
||||
return addresses
|
||||
|
||||
|
||||
def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAddress]:
|
||||
def _get_ips_from_subnets_to_scan(subnets_to_scan: Iterable[str]) -> List[NetworkAddress]:
|
||||
ranges_to_scan = NetworkRange.filter_invalid_ranges(
|
||||
subnets_to_scan, "Bad network range input for targets to scan:"
|
||||
)
|
||||
|
@ -68,7 +72,7 @@ def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAdd
|
|||
return _get_ips_from_ranges_to_scan(network_ranges)
|
||||
|
||||
|
||||
def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[NetworkAddress]:
|
||||
def _get_ips_from_ranges_to_scan(network_ranges: Iterable[NetworkRange]) -> List[NetworkAddress]:
|
||||
scan_targets = []
|
||||
|
||||
for _range in network_ranges:
|
||||
|
@ -77,8 +81,8 @@ def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[Net
|
|||
|
||||
|
||||
def _get_ips_to_scan_from_interface(
|
||||
interfaces: List[IPv4Interface],
|
||||
) -> List[NetworkAddress]:
|
||||
interfaces: Sequence[IPv4Interface],
|
||||
) -> Sequence[NetworkAddress]:
|
||||
ranges = [str(interface) for interface in interfaces]
|
||||
|
||||
ranges = NetworkRange.filter_invalid_ranges(
|
||||
|
@ -88,14 +92,14 @@ def _get_ips_to_scan_from_interface(
|
|||
|
||||
|
||||
def _remove_interface_ips(
|
||||
scan_targets: List[NetworkAddress], interfaces: List[IPv4Interface]
|
||||
scan_targets: Sequence[NetworkAddress], interfaces: Iterable[IPv4Interface]
|
||||
) -> List[NetworkAddress]:
|
||||
interface_ips = [str(interface.ip) for interface in interfaces]
|
||||
return _remove_ips_from_scan_targets(scan_targets, interface_ips)
|
||||
|
||||
|
||||
def _remove_blocklisted_ips(
|
||||
scan_targets: List[NetworkAddress], blocked_ips: List[str]
|
||||
scan_targets: Sequence[NetworkAddress], blocked_ips: Sequence[str]
|
||||
) -> List[NetworkAddress]:
|
||||
filtered_blocked_ips = NetworkRange.filter_invalid_ranges(
|
||||
blocked_ips, "Invalid blocked IP provided:"
|
||||
|
@ -106,15 +110,15 @@ def _remove_blocklisted_ips(
|
|||
|
||||
|
||||
def _remove_ips_from_scan_targets(
|
||||
scan_targets: List[NetworkAddress], ips_to_remove: List[str]
|
||||
scan_targets: Sequence[NetworkAddress], ips_to_remove: Iterable[str]
|
||||
) -> List[NetworkAddress]:
|
||||
ips_to_remove_set = set(ips_to_remove)
|
||||
return [address for address in scan_targets if address.ip not in ips_to_remove_set]
|
||||
|
||||
|
||||
def _get_segmentation_check_targets(
|
||||
inaccessible_subnets: List[str], local_interfaces: List[IPv4Interface]
|
||||
) -> List[NetworkAddress]:
|
||||
inaccessible_subnets: Iterable[str], local_interfaces: Iterable[IPv4Interface]
|
||||
) -> Sequence[NetworkAddress]:
|
||||
ips_to_scan = []
|
||||
local_ips = [str(interface.ip) for interface in local_interfaces]
|
||||
|
||||
|
@ -134,17 +138,17 @@ def _get_segmentation_check_targets(
|
|||
return ips_to_scan
|
||||
|
||||
|
||||
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]:
|
||||
def _convert_to_range_object(subnets: Iterable[str]) -> List[NetworkRange]:
|
||||
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
|
||||
|
||||
|
||||
def _is_segmentation_check_required(
|
||||
local_ips: List[str], subnet1: NetworkRange, subnet2: NetworkRange
|
||||
local_ips: Sequence[str], subnet1: NetworkRange, subnet2: NetworkRange
|
||||
):
|
||||
return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2)
|
||||
|
||||
|
||||
def _is_any_ip_in_subnet(ip_addresses: List[str], subnet: NetworkRange):
|
||||
def _is_any_ip_in_subnet(ip_addresses: Iterable[str], subnet: NetworkRange):
|
||||
for ip_address in ip_addresses:
|
||||
if subnet.is_in_range(ip_address):
|
||||
return True
|
||||
|
|
Loading…
Reference in New Issue