Agent: Fix mypy issues in scan_target_generator.py

This commit is contained in:
Kekoa Kaaikala 2022-09-26 20:12:57 +00:00
parent 1bf610a4a8
commit 311c294033
3 changed files with 28 additions and 24 deletions

View File

@ -4,7 +4,7 @@ import random
import socket
import struct
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
from typing import Iterable, List, Tuple
logger = logging.getLogger(__name__)
@ -58,7 +58,7 @@ class NetworkRange(object, metaclass=ABCMeta):
return SingleIpRange(ip_address=address_str)
@staticmethod
def filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]:
def filter_invalid_ranges(ranges: Iterable[str], error_msg: str) -> List[str]:
valid_ranges = []
for target_range in ranges:
try:

View File

@ -4,7 +4,7 @@ import struct
from dataclasses import dataclass
from random import shuffle # noqa: DUO102
from threading import Lock
from typing import Dict, Set
from typing import Dict, Optional, Set
import netifaces
import psutil
@ -25,7 +25,7 @@ RTF_REJECT = 0x0200
@dataclass
class NetworkAddress:
ip: str
domain: str
domain: Optional[str]
def get_host_subnets():

View File

@ -2,15 +2,19 @@ import itertools
import logging
import socket
from ipaddress import IPv4Interface
from typing import Dict, List, Sequence
from typing import Dict, Iterable, List, Optional, Sequence
from typing_extensions import Protocol, runtime_checkable
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
from infection_monkey.network import NetworkAddress
logger = logging.getLogger(__name__)
# TODO: We can probably reduce code and save ourselves some trouble if we use IPv4Address and
# IPv4Network. See https://docs.python.org/3/library/ipaddress.html
@runtime_checkable
class HasDomain(Protocol):
domain_name: str
def compile_scan_target_list(
@ -26,10 +30,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces))
if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets(
other_targets = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces
)
scan_targets.extend(inaccessible_subnets)
scan_targets.extend(other_targets)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -39,8 +43,8 @@ def compile_scan_target_list(
return scan_targets
def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]:
reverse_dns: Dict[str, str] = {}
def _remove_redundant_targets(targets: Sequence[NetworkAddress]) -> List[NetworkAddress]:
reverse_dns: Dict[str, Optional[str]] = {}
for target in targets:
domain_name = target.domain
ip = target.ip
@ -52,14 +56,14 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
addresses = []
for address in range_obj:
if hasattr(range_obj, "domain_name"):
if isinstance(range_obj, HasDomain):
addresses.append(NetworkAddress(address, range_obj.domain_name))
else:
addresses.append(NetworkAddress(address, None))
return addresses
def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAddress]:
def _get_ips_from_subnets_to_scan(subnets_to_scan: Iterable[str]) -> List[NetworkAddress]:
ranges_to_scan = NetworkRange.filter_invalid_ranges(
subnets_to_scan, "Bad network range input for targets to scan:"
)
@ -68,7 +72,7 @@ def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAdd
return _get_ips_from_ranges_to_scan(network_ranges)
def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[NetworkAddress]:
def _get_ips_from_ranges_to_scan(network_ranges: Iterable[NetworkRange]) -> List[NetworkAddress]:
scan_targets = []
for _range in network_ranges:
@ -77,8 +81,8 @@ def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[Net
def _get_ips_to_scan_from_interface(
interfaces: List[IPv4Interface],
) -> List[NetworkAddress]:
interfaces: Sequence[IPv4Interface],
) -> Sequence[NetworkAddress]:
ranges = [str(interface) for interface in interfaces]
ranges = NetworkRange.filter_invalid_ranges(
@ -88,14 +92,14 @@ def _get_ips_to_scan_from_interface(
def _remove_interface_ips(
scan_targets: List[NetworkAddress], interfaces: List[IPv4Interface]
scan_targets: Sequence[NetworkAddress], interfaces: Iterable[IPv4Interface]
) -> List[NetworkAddress]:
interface_ips = [str(interface.ip) for interface in interfaces]
return _remove_ips_from_scan_targets(scan_targets, interface_ips)
def _remove_blocklisted_ips(
scan_targets: List[NetworkAddress], blocked_ips: List[str]
scan_targets: Sequence[NetworkAddress], blocked_ips: Sequence[str]
) -> List[NetworkAddress]:
filtered_blocked_ips = NetworkRange.filter_invalid_ranges(
blocked_ips, "Invalid blocked IP provided:"
@ -106,15 +110,15 @@ def _remove_blocklisted_ips(
def _remove_ips_from_scan_targets(
scan_targets: List[NetworkAddress], ips_to_remove: List[str]
scan_targets: Sequence[NetworkAddress], ips_to_remove: Iterable[str]
) -> List[NetworkAddress]:
ips_to_remove_set = set(ips_to_remove)
return [address for address in scan_targets if address.ip not in ips_to_remove_set]
def _get_segmentation_check_targets(
inaccessible_subnets: List[str], local_interfaces: List[IPv4Interface]
) -> List[NetworkAddress]:
inaccessible_subnets: Iterable[str], local_interfaces: Iterable[IPv4Interface]
) -> Sequence[NetworkAddress]:
ips_to_scan = []
local_ips = [str(interface.ip) for interface in local_interfaces]
@ -134,17 +138,17 @@ def _get_segmentation_check_targets(
return ips_to_scan
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]:
def _convert_to_range_object(subnets: Iterable[str]) -> List[NetworkRange]:
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
def _is_segmentation_check_required(
local_ips: List[str], subnet1: NetworkRange, subnet2: NetworkRange
local_ips: Sequence[str], subnet1: NetworkRange, subnet2: NetworkRange
):
return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2)
def _is_any_ip_in_subnet(ip_addresses: List[str], subnet: NetworkRange):
def _is_any_ip_in_subnet(ip_addresses: Iterable[str], subnet: NetworkRange):
for ip_address in ip_addresses:
if subnet.is_in_range(ip_address):
return True