forked from p15670423/monkey
Agent, UT: Implement scan target validation
This changes validate scan target inputs and skip invalid ones. If an invalid blocked IP is specified, then an unhandled exception is raised.
This commit is contained in:
parent
59ff3d39ce
commit
cabadeb7d1
|
@ -9,6 +9,10 @@ from typing import Tuple
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidNetworkRangeError(Exception):
|
||||
"""Raise when invalid network range is provided"""
|
||||
|
||||
|
||||
class NetworkRange(object, metaclass=ABCMeta):
|
||||
def __init__(self, shuffle=True):
|
||||
self._shuffle = shuffle
|
||||
|
@ -53,6 +57,13 @@ class NetworkRange(object, metaclass=ABCMeta):
|
|||
return CidrRange(cidr_range=address_str)
|
||||
return SingleIpRange(ip_address=address_str)
|
||||
|
||||
@staticmethod
|
||||
def validate_range(address_str: str):
|
||||
try:
|
||||
NetworkRange.get_range_obj(address_str)
|
||||
except (ValueError, OSError) as e:
|
||||
raise InvalidNetworkRangeError(e)
|
||||
|
||||
@staticmethod
|
||||
def check_if_range(address_str):
|
||||
if -1 != address_str.find("-"):
|
||||
|
@ -178,10 +189,9 @@ class SingleIpRange(NetworkRange):
|
|||
ip = socket.gethostbyname(string_)
|
||||
domain_name = string_
|
||||
except socket.error:
|
||||
logger.error(
|
||||
raise ValueError(
|
||||
"Your specified host: {} is not found as a domain name and"
|
||||
" it's not an IP address".format(string_)
|
||||
)
|
||||
return None, string_
|
||||
# If a string_ was entered instead of IP we presume that it was domain name and translate it
|
||||
return ip, domain_name
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import List, Set
|
||||
|
||||
from common.network.network_range import NetworkRange
|
||||
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
|
||||
|
||||
# TODO: Convert to class and validate the format of the address and netmask
|
||||
# Example: address="192.168.1.1", netmask="/24"
|
||||
NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask"))
|
||||
|
||||
|
||||
# TODO: Validate all parameters
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compile_scan_target_list(
|
||||
local_network_interfaces: List[NetworkInterface],
|
||||
ranges_to_scan: List[str],
|
||||
|
@ -17,6 +18,7 @@ def compile_scan_target_list(
|
|||
blocklisted_ips: List[str],
|
||||
enable_local_network_scan: bool,
|
||||
) -> List[str]:
|
||||
|
||||
scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan)
|
||||
|
||||
if enable_local_network_scan:
|
||||
|
@ -40,15 +42,20 @@ def compile_scan_target_list(
|
|||
def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]:
|
||||
scan_targets = set()
|
||||
|
||||
ranges_to_scan = _filter_invalid_ranges(
|
||||
ranges_to_scan, "Bad network range input for targets to scan:"
|
||||
)
|
||||
|
||||
network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan]
|
||||
for _range in network_ranges:
|
||||
scan_targets.update(set(_range))
|
||||
|
||||
return scan_targets
|
||||
|
||||
|
||||
def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]:
|
||||
ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces]
|
||||
|
||||
ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:")
|
||||
return _get_ips_from_ranges_to_scan(ranges)
|
||||
|
||||
|
||||
|
@ -58,6 +65,9 @@ def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterf
|
|||
|
||||
|
||||
def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]):
|
||||
filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:")
|
||||
if not len(filtered_blocked_ips) == len(blocked_ips):
|
||||
raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.")
|
||||
_remove_ips_from_scan_targets(scan_targets, blocked_ips)
|
||||
|
||||
|
||||
|
@ -76,6 +86,11 @@ def _get_segmentation_check_targets(
|
|||
subnets_to_scan = set()
|
||||
local_ips = [interface.address for interface in local_interfaces]
|
||||
|
||||
local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ")
|
||||
inaccessible_subnets = _filter_invalid_ranges(
|
||||
inaccessible_subnets, "Invalid segmentation scan target: "
|
||||
)
|
||||
|
||||
inaccessible_subnets = _convert_to_range_object(inaccessible_subnets)
|
||||
subnet_pairs = itertools.product(inaccessible_subnets, inaccessible_subnets)
|
||||
|
||||
|
@ -87,6 +102,18 @@ def _get_segmentation_check_targets(
|
|||
return subnets_to_scan
|
||||
|
||||
|
||||
def _filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]:
|
||||
filtered = []
|
||||
for target_range in ranges:
|
||||
try:
|
||||
NetworkRange.validate_range(target_range)
|
||||
except InvalidNetworkRangeError as e:
|
||||
logger.error(f"{error_msg} {e}")
|
||||
continue
|
||||
filtered.append(target_range)
|
||||
return filtered
|
||||
|
||||
|
||||
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]:
|
||||
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
from network.scan_target_generator import _filter_invalid_ranges
|
||||
|
||||
from common.network.network_range import InvalidNetworkRangeError
|
||||
from infection_monkey.network.scan_target_generator import (
|
||||
NetworkInterface,
|
||||
compile_scan_target_list,
|
||||
|
@ -399,3 +401,84 @@ def test_segmentation_inaccessible_networks():
|
|||
)
|
||||
|
||||
assert len(scan_targets) == 0
|
||||
|
||||
|
||||
def test_invalid_inputs():
|
||||
local_network_interfaces = [
|
||||
NetworkInterface("172.60.999.109", "/30"),
|
||||
NetworkInterface("172.60.145.109", "/30"),
|
||||
]
|
||||
|
||||
inaccessible_subnets = [
|
||||
"172.60.145.1 - 172.60.145.1111",
|
||||
"172.60.147.888/30" "172.60.147.8/30",
|
||||
"172.60.147.148/30",
|
||||
]
|
||||
|
||||
targets = ["172.60.145.149/33", "1.-1.1.1", "1.a.2.2", "172.60.145.151/30"]
|
||||
|
||||
scan_targets = compile_scan_target_list(
|
||||
local_network_interfaces=local_network_interfaces,
|
||||
ranges_to_scan=targets,
|
||||
inaccessible_subnets=inaccessible_subnets,
|
||||
blocklisted_ips=[],
|
||||
enable_local_network_scan=False,
|
||||
)
|
||||
|
||||
assert len(scan_targets) == 3
|
||||
|
||||
for ip in [148, 149, 150]:
|
||||
assert f"172.60.145.{ip}" in scan_targets
|
||||
|
||||
|
||||
def test_range_filtering():
|
||||
invalid_ranges = [
|
||||
# Invalid IP segment
|
||||
"172.60.999.109",
|
||||
"172.60.-1.109",
|
||||
"172.60.999.109 - 172.60.1.109",
|
||||
"172.60.999.109/32",
|
||||
"172.60.999.109/24",
|
||||
# Invalid CIDR
|
||||
"172.60.1.109/33",
|
||||
"172.60.1.109/-1",
|
||||
# Typos
|
||||
"172.60.9.109 -t 172.60.1.109",
|
||||
"172.60..9.109",
|
||||
"172.60,9.109",
|
||||
" 172.60 .9.109 ",
|
||||
]
|
||||
|
||||
valid_ranges = [
|
||||
" 172.60.9.109 ",
|
||||
"172.60.9.109 - 172.60.1.109",
|
||||
"172.60.9.109- 172.60.1.109",
|
||||
"0.0.0.0",
|
||||
"localhost"
|
||||
]
|
||||
|
||||
invalid_ranges.extend(valid_ranges)
|
||||
|
||||
remaining = _filter_invalid_ranges(invalid_ranges, "Test error:")
|
||||
for _range in remaining:
|
||||
assert _range in valid_ranges
|
||||
assert len(remaining) == len(valid_ranges)
|
||||
|
||||
|
||||
def test_invalid_blocklisted_ip():
|
||||
local_network_interfaces = [NetworkInterface("172.60.145.109", "/30")]
|
||||
|
||||
inaccessible_subnets = ["172.60.147.8/30", "172.60.147.148/30"]
|
||||
|
||||
targets = ["172.60.145.151/30"]
|
||||
|
||||
blocklisted = ["172.60.145.153", "172.60.145.753"]
|
||||
|
||||
with pytest.raises(InvalidNetworkRangeError):
|
||||
compile_scan_target_list(
|
||||
local_network_interfaces=local_network_interfaces,
|
||||
ranges_to_scan=targets,
|
||||
inaccessible_subnets=inaccessible_subnets,
|
||||
blocklisted_ips=blocklisted,
|
||||
enable_local_network_scan=False,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue