Agent, UT: Implement scan target validation

This changes validate scan target inputs and skip invalid ones. If an invalid blocked IP is specified, then an unhandled exception is raised.
This commit is contained in:
vakarisz 2021-12-15 13:06:07 +02:00
parent 59ff3d39ce
commit cabadeb7d1
3 changed files with 127 additions and 7 deletions

View File

@ -9,6 +9,10 @@ from typing import Tuple
logger = logging.getLogger(__name__)
class InvalidNetworkRangeError(Exception):
"""Raise when invalid network range is provided"""
class NetworkRange(object, metaclass=ABCMeta):
def __init__(self, shuffle=True):
self._shuffle = shuffle
@ -53,6 +57,13 @@ class NetworkRange(object, metaclass=ABCMeta):
return CidrRange(cidr_range=address_str)
return SingleIpRange(ip_address=address_str)
@staticmethod
def validate_range(address_str: str):
try:
NetworkRange.get_range_obj(address_str)
except (ValueError, OSError) as e:
raise InvalidNetworkRangeError(e)
@staticmethod
def check_if_range(address_str):
if -1 != address_str.find("-"):
@ -178,10 +189,9 @@ class SingleIpRange(NetworkRange):
ip = socket.gethostbyname(string_)
domain_name = string_
except socket.error:
logger.error(
raise ValueError(
"Your specified host: {} is not found as a domain name and"
" it's not an IP address".format(string_)
)
return None, string_
# If a string_ was entered instead of IP we presume that it was domain name and translate it
return ip, domain_name

View File

@ -1,15 +1,16 @@
import itertools
import logging
from collections import namedtuple
from typing import List, Set
from common.network.network_range import NetworkRange
from common.network.network_range import InvalidNetworkRangeError, NetworkRange
# TODO: Convert to class and validate the format of the address and netmask
# Example: address="192.168.1.1", netmask="/24"
NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask"))
# TODO: Validate all parameters
logger = logging.getLogger(__name__)
def compile_scan_target_list(
local_network_interfaces: List[NetworkInterface],
ranges_to_scan: List[str],
@ -17,6 +18,7 @@ def compile_scan_target_list(
blocklisted_ips: List[str],
enable_local_network_scan: bool,
) -> List[str]:
scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan)
if enable_local_network_scan:
@ -40,15 +42,20 @@ def compile_scan_target_list(
def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]:
scan_targets = set()
ranges_to_scan = _filter_invalid_ranges(
ranges_to_scan, "Bad network range input for targets to scan:"
)
network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan]
for _range in network_ranges:
scan_targets.update(set(_range))
return scan_targets
def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]:
ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces]
ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:")
return _get_ips_from_ranges_to_scan(ranges)
@ -58,6 +65,9 @@ def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterf
def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]):
filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:")
if not len(filtered_blocked_ips) == len(blocked_ips):
raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.")
_remove_ips_from_scan_targets(scan_targets, blocked_ips)
@ -76,6 +86,11 @@ def _get_segmentation_check_targets(
subnets_to_scan = set()
local_ips = [interface.address for interface in local_interfaces]
local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ")
inaccessible_subnets = _filter_invalid_ranges(
inaccessible_subnets, "Invalid segmentation scan target: "
)
inaccessible_subnets = _convert_to_range_object(inaccessible_subnets)
subnet_pairs = itertools.product(inaccessible_subnets, inaccessible_subnets)
@ -87,6 +102,18 @@ def _get_segmentation_check_targets(
return subnets_to_scan
def _filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]:
filtered = []
for target_range in ranges:
try:
NetworkRange.validate_range(target_range)
except InvalidNetworkRangeError as e:
logger.error(f"{error_msg} {e}")
continue
filtered.append(target_range)
return filtered
def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]:
return [NetworkRange.get_range_obj(subnet) for subnet in subnets]

View File

@ -1,7 +1,9 @@
from itertools import chain
import pytest
from network.scan_target_generator import _filter_invalid_ranges
from common.network.network_range import InvalidNetworkRangeError
from infection_monkey.network.scan_target_generator import (
NetworkInterface,
compile_scan_target_list,
@ -399,3 +401,84 @@ def test_segmentation_inaccessible_networks():
)
assert len(scan_targets) == 0
def test_invalid_inputs():
local_network_interfaces = [
NetworkInterface("172.60.999.109", "/30"),
NetworkInterface("172.60.145.109", "/30"),
]
inaccessible_subnets = [
"172.60.145.1 - 172.60.145.1111",
"172.60.147.888/30" "172.60.147.8/30",
"172.60.147.148/30",
]
targets = ["172.60.145.149/33", "1.-1.1.1", "1.a.2.2", "172.60.145.151/30"]
scan_targets = compile_scan_target_list(
local_network_interfaces=local_network_interfaces,
ranges_to_scan=targets,
inaccessible_subnets=inaccessible_subnets,
blocklisted_ips=[],
enable_local_network_scan=False,
)
assert len(scan_targets) == 3
for ip in [148, 149, 150]:
assert f"172.60.145.{ip}" in scan_targets
def test_range_filtering():
invalid_ranges = [
# Invalid IP segment
"172.60.999.109",
"172.60.-1.109",
"172.60.999.109 - 172.60.1.109",
"172.60.999.109/32",
"172.60.999.109/24",
# Invalid CIDR
"172.60.1.109/33",
"172.60.1.109/-1",
# Typos
"172.60.9.109 -t 172.60.1.109",
"172.60..9.109",
"172.60,9.109",
" 172.60 .9.109 ",
]
valid_ranges = [
" 172.60.9.109 ",
"172.60.9.109 - 172.60.1.109",
"172.60.9.109- 172.60.1.109",
"0.0.0.0",
"localhost"
]
invalid_ranges.extend(valid_ranges)
remaining = _filter_invalid_ranges(invalid_ranges, "Test error:")
for _range in remaining:
assert _range in valid_ranges
assert len(remaining) == len(valid_ranges)
def test_invalid_blocklisted_ip():
local_network_interfaces = [NetworkInterface("172.60.145.109", "/30")]
inaccessible_subnets = ["172.60.147.8/30", "172.60.147.148/30"]
targets = ["172.60.145.151/30"]
blocklisted = ["172.60.145.153", "172.60.145.753"]
with pytest.raises(InvalidNetworkRangeError):
compile_scan_target_list(
local_network_interfaces=local_network_interfaces,
ranges_to_scan=targets,
inaccessible_subnets=inaccessible_subnets,
blocklisted_ips=blocklisted,
enable_local_network_scan=False,
)