forked from p15670423/monkey
Agent: Fix mypy issues in scan_target_generator.py
This commit is contained in:
parent
1bf610a4a8
commit
311c294033
|
@ -4,7 +4,7 @@ import random
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import List, Tuple
|
from typing import Iterable, List, Tuple
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class NetworkRange(object, metaclass=ABCMeta):
|
||||||
return SingleIpRange(ip_address=address_str)
|
return SingleIpRange(ip_address=address_str)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]:
|
def filter_invalid_ranges(ranges: Iterable[str], error_msg: str) -> List[str]:
|
||||||
valid_ranges = []
|
valid_ranges = []
|
||||||
for target_range in ranges:
|
for target_range in ranges:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -4,7 +4,7 @@ import struct
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import shuffle # noqa: DUO102
|
from random import shuffle # noqa: DUO102
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Dict, Set
|
from typing import Dict, Optional, Set
|
||||||
|
|
||||||
import netifaces
|
import netifaces
|
||||||
import psutil
|
import psutil
|
||||||
|
@ -25,7 +25,7 @@ RTF_REJECT = 0x0200
|
||||||
@dataclass
|
@dataclass
|
||||||
class NetworkAddress:
|
class NetworkAddress:
|
||||||
ip: str
|
ip: str
|
||||||
domain: str
|
domain: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
def get_host_subnets():
|
def get_host_subnets():
|
||||||
|
|
|
@ -2,15 +2,19 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import IPv4Interface
|
from ipaddress import IPv4Interface
|
||||||
from typing import Dict, List, Sequence
|
from typing import Dict, Iterable, List, Optional, Sequence
|
||||||
|
|
||||||
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
|
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
|
||||||
from infection_monkey.network import NetworkAddress
|
from infection_monkey.network import NetworkAddress
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: We can probably reduce code and save ourselves some trouble if we use IPv4Address and
|
|
||||||
# IPv4Network. See https://docs.python.org/3/library/ipaddress.html
|
@runtime_checkable
|
||||||
|
class HasDomain(Protocol):
|
||||||
|
domain_name: str
|
||||||
|
|
||||||
|
|
||||||
def compile_scan_target_list(
|
def compile_scan_target_list(
|
||||||
|
@ -26,10 +30,10 @@ def compile_scan_target_list(
|
||||||
scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces))
|
scan_targets.extend(_get_ips_to_scan_from_interface(local_network_interfaces))
|
||||||
|
|
||||||
if inaccessible_subnets:
|
if inaccessible_subnets:
|
||||||
inaccessible_subnets = _get_segmentation_check_targets(
|
other_targets = _get_segmentation_check_targets(
|
||||||
inaccessible_subnets, local_network_interfaces
|
inaccessible_subnets, local_network_interfaces
|
||||||
)
|
)
|
||||||
scan_targets.extend(inaccessible_subnets)
|
scan_targets.extend(other_targets)
|
||||||
|
|
||||||
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
|
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
|
||||||
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
|
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
|
||||||
|
@ -39,8 +43,8 @@ def compile_scan_target_list(
|
||||||
return scan_targets
|
return scan_targets
|
||||||
|
|
||||||
|
|
||||||
def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]:
|
def _remove_redundant_targets(targets: Sequence[NetworkAddress]) -> List[NetworkAddress]:
|
||||||
reverse_dns: Dict[str, str] = {}
|
reverse_dns: Dict[str, Optional[str]] = {}
|
||||||
for target in targets:
|
for target in targets:
|
||||||
domain_name = target.domain
|
domain_name = target.domain
|
||||||
ip = target.ip
|
ip = target.ip
|
||||||
|
@ -52,14 +56,14 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
|
||||||
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
|
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
|
||||||
addresses = []
|
addresses = []
|
||||||
for address in range_obj:
|
for address in range_obj:
|
||||||
if hasattr(range_obj, "domain_name"):
|
if isinstance(range_obj, HasDomain):
|
||||||
addresses.append(NetworkAddress(address, range_obj.domain_name))
|
addresses.append(NetworkAddress(address, range_obj.domain_name))
|
||||||
else:
|
else:
|
||||||
addresses.append(NetworkAddress(address, None))
|
addresses.append(NetworkAddress(address, None))
|
||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
|
|
||||||
def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAddress]:
|
def _get_ips_from_subnets_to_scan(subnets_to_scan: Iterable[str]) -> List[NetworkAddress]:
|
||||||
ranges_to_scan = NetworkRange.filter_invalid_ranges(
|
ranges_to_scan = NetworkRange.filter_invalid_ranges(
|
||||||
subnets_to_scan, "Bad network range input for targets to scan:"
|
subnets_to_scan, "Bad network range input for targets to scan:"
|
||||||
)
|
)
|
||||||
|
@ -68,7 +72,7 @@ def _get_ips_from_subnets_to_scan(subnets_to_scan: List[str]) -> List[NetworkAdd
|
||||||
return _get_ips_from_ranges_to_scan(network_ranges)
|
return _get_ips_from_ranges_to_scan(network_ranges)
|
||||||
|
|
||||||
|
|
||||||
def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[NetworkAddress]:
|
def _get_ips_from_ranges_to_scan(network_ranges: Iterable[NetworkRange]) -> List[NetworkAddress]:
|
||||||
scan_targets = []
|
scan_targets = []
|
||||||
|
|
||||||
for _range in network_ranges:
|
for _range in network_ranges:
|
||||||
|
@ -77,8 +81,8 @@ def _get_ips_from_ranges_to_scan(network_ranges: List[NetworkRange]) -> List[Net
|
||||||
|
|
||||||
|
|
||||||
def _get_ips_to_scan_from_interface(
|
def _get_ips_to_scan_from_interface(
|
||||||
interfaces: List[IPv4Interface],
|
interfaces: Sequence[IPv4Interface],
|
||||||
) -> List[NetworkAddress]:
|
) -> Sequence[NetworkAddress]:
|
||||||
ranges = [str(interface) for interface in interfaces]
|
ranges = [str(interface) for interface in interfaces]
|
||||||
|
|
||||||
ranges = NetworkRange.filter_invalid_ranges(
|
ranges = NetworkRange.filter_invalid_ranges(
|
||||||
|
@ -88,14 +92,14 @@ def _get_ips_to_scan_from_interface(
|
||||||
|
|
||||||
|
|
||||||
def _remove_interface_ips(
|
def _remove_interface_ips(
|
||||||
scan_targets: List[NetworkAddress], interfaces: List[IPv4Interface]
|
scan_targets: Sequence[NetworkAddress], interfaces: Iterable[IPv4Interface]
|
||||||
) -> List[NetworkAddress]:
|
) -> List[NetworkAddress]:
|
||||||
interface_ips = [str(interface.ip) for interface in interfaces]
|
interface_ips = [str(interface.ip) for interface in interfaces]
|
||||||
return _remove_ips_from_scan_targets(scan_targets, interface_ips)
|
return _remove_ips_from_scan_targets(scan_targets, interface_ips)
|
||||||
|
|
||||||
|
|
||||||
def _remove_blocklisted_ips(
|
def _remove_blocklisted_ips(
|
||||||
scan_targets: List[NetworkAddress], blocked_ips: List[str]
|
scan_targets: Sequence[NetworkAddress], blocked_ips: Sequence[str]
|
||||||
) -> List[NetworkAddress]:
|
) -> List[NetworkAddress]:
|
||||||
filtered_blocked_ips = NetworkRange.filter_invalid_ranges(
|
filtered_blocked_ips = NetworkRange.filter_invalid_ranges(
|
||||||
blocked_ips, "Invalid blocked IP provided:"
|
blocked_ips, "Invalid blocked IP provided:"
|
||||||
|
@ -106,15 +110,15 @@ def _remove_blocklisted_ips(
|
||||||
|
|
||||||
|
|
||||||
def _remove_ips_from_scan_targets(
|
def _remove_ips_from_scan_targets(
|
||||||
scan_targets: List[NetworkAddress], ips_to_remove: List[str]
|
scan_targets: Sequence[NetworkAddress], ips_to_remove: Iterable[str]
|
||||||
) -> List[NetworkAddress]:
|
) -> List[NetworkAddress]:
|
||||||
ips_to_remove_set = set(ips_to_remove)
|
ips_to_remove_set = set(ips_to_remove)
|
||||||
return [address for address in scan_targets if address.ip not in ips_to_remove_set]
|
return [address for address in scan_targets if address.ip not in ips_to_remove_set]
|
||||||
|
|
||||||
|
|
||||||
def _get_segmentation_check_targets(
|
def _get_segmentation_check_targets(
|
||||||
inaccessible_subnets: List[str], local_interfaces: List[IPv4Interface]
|
inaccessible_subnets: Iterable[str], local_interfaces: Iterable[IPv4Interface]
|
||||||
) -> List[NetworkAddress]:
|
) -> Sequence[NetworkAddress]:
|
||||||
ips_to_scan = []
|
ips_to_scan = []
|
||||||
local_ips = [str(interface.ip) for interface in local_interfaces]
|
local_ips = [str(interface.ip) for interface in local_interfaces]
|
||||||
|
|
||||||
|
@ -134,17 +138,17 @@ def _get_segmentation_check_targets(
|
||||||
return ips_to_scan
|
return ips_to_scan
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]:
|
def _convert_to_range_object(subnets: Iterable[str]) -> List[NetworkRange]:
|
||||||
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
|
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
|
||||||
|
|
||||||
|
|
||||||
def _is_segmentation_check_required(
|
def _is_segmentation_check_required(
|
||||||
local_ips: List[str], subnet1: NetworkRange, subnet2: NetworkRange
|
local_ips: Sequence[str], subnet1: NetworkRange, subnet2: NetworkRange
|
||||||
):
|
):
|
||||||
return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2)
|
return _is_any_ip_in_subnet(local_ips, subnet1) and not _is_any_ip_in_subnet(local_ips, subnet2)
|
||||||
|
|
||||||
|
|
||||||
def _is_any_ip_in_subnet(ip_addresses: List[str], subnet: NetworkRange):
|
def _is_any_ip_in_subnet(ip_addresses: Iterable[str], subnet: NetworkRange):
|
||||||
for ip_address in ip_addresses:
|
for ip_address in ip_addresses:
|
||||||
if subnet.is_in_range(ip_address):
|
if subnet.is_in_range(ip_address):
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in New Issue