diff --git a/monkey/common/network/network_range.py b/monkey/common/network/network_range.py index 479f5f0d7..b7c8f14a4 100644 --- a/monkey/common/network/network_range.py +++ b/monkey/common/network/network_range.py @@ -9,6 +9,10 @@ from typing import Tuple logger = logging.getLogger(__name__) +class InvalidNetworkRangeError(Exception): + """Raise when invalid network range is provided""" + + class NetworkRange(object, metaclass=ABCMeta): def __init__(self, shuffle=True): self._shuffle = shuffle @@ -53,6 +57,13 @@ class NetworkRange(object, metaclass=ABCMeta): return CidrRange(cidr_range=address_str) return SingleIpRange(ip_address=address_str) + @staticmethod + def validate_range(address_str: str): + try: + NetworkRange.get_range_obj(address_str) + except (ValueError, OSError) as e: + raise InvalidNetworkRangeError(e) + @staticmethod def check_if_range(address_str): if -1 != address_str.find("-"): @@ -178,10 +189,9 @@ class SingleIpRange(NetworkRange): ip = socket.gethostbyname(string_) domain_name = string_ except socket.error: - logger.error( + raise ValueError( "Your specified host: {} is not found as a domain name and" " it's not an IP address".format(string_) ) - return None, string_ # If a string_ was entered instead of IP we presume that it was domain name and translate it return ip, domain_name diff --git a/monkey/infection_monkey/network/scan_target_generator.py b/monkey/infection_monkey/network/scan_target_generator.py index 1e0b4055e..79768c067 100644 --- a/monkey/infection_monkey/network/scan_target_generator.py +++ b/monkey/infection_monkey/network/scan_target_generator.py @@ -1,15 +1,16 @@ import itertools +import logging from collections import namedtuple from typing import List, Set -from common.network.network_range import NetworkRange +from common.network.network_range import InvalidNetworkRangeError, NetworkRange -# TODO: Convert to class and validate the format of the address and netmask -# Example: address="192.168.1.1", netmask="/24" NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask")) -# TODO: Validate all parameters +logger = logging.getLogger(__name__) + + def compile_scan_target_list( local_network_interfaces: List[NetworkInterface], ranges_to_scan: List[str], @@ -17,6 +18,7 @@ def compile_scan_target_list( blocklisted_ips: List[str], enable_local_network_scan: bool, ) -> List[str]: + scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan) if enable_local_network_scan: @@ -40,15 +42,20 @@ def compile_scan_target_list( def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]: scan_targets = set() + ranges_to_scan = _filter_invalid_ranges( + ranges_to_scan, "Bad network range input for targets to scan:" + ) + network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan] for _range in network_ranges: scan_targets.update(set(_range)) - return scan_targets def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]: ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces] + + ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:") return _get_ips_from_ranges_to_scan(ranges) @@ -58,6 +65,9 @@ def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterf def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]): + filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:") + if not len(filtered_blocked_ips) == len(blocked_ips): + raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.") _remove_ips_from_scan_targets(scan_targets, blocked_ips) @@ -76,6 +86,11 @@ def _get_segmentation_check_targets( subnets_to_scan = set() local_ips = [interface.address for interface in local_interfaces] + local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ") + inaccessible_subnets = _filter_invalid_ranges( + inaccessible_subnets, "Invalid segmentation scan target: " + ) + inaccessible_subnets = _convert_to_range_object(inaccessible_subnets) subnet_pairs = itertools.product(inaccessible_subnets, inaccessible_subnets) @@ -87,6 +102,18 @@ def _get_segmentation_check_targets( return subnets_to_scan +def _filter_invalid_ranges(ranges: List[str], error_msg: str) -> List[str]: + filtered = [] + for target_range in ranges: + try: + NetworkRange.validate_range(target_range) + except InvalidNetworkRangeError as e: + logger.error(f"{error_msg} {e}") + continue + filtered.append(target_range) + return filtered + + def _convert_to_range_object(subnets: List[str]) -> List[NetworkRange]: return [NetworkRange.get_range_obj(subnet) for subnet in subnets] diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py index a28ae8275..af194b300 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py @@ -1,7 +1,9 @@ from itertools import chain import pytest +from network.scan_target_generator import _filter_invalid_ranges +from common.network.network_range import InvalidNetworkRangeError from infection_monkey.network.scan_target_generator import ( NetworkInterface, compile_scan_target_list, @@ -399,3 +401,84 @@ def test_segmentation_inaccessible_networks(): ) assert len(scan_targets) == 0 + + +def test_invalid_inputs(): + local_network_interfaces = [ + NetworkInterface("172.60.999.109", "/30"), + NetworkInterface("172.60.145.109", "/30"), + ] + + inaccessible_subnets = [ + "172.60.145.1 - 172.60.145.1111", + "172.60.147.888/30" "172.60.147.8/30", + "172.60.147.148/30", + ] + + targets = ["172.60.145.149/33", "1.-1.1.1", "1.a.2.2", "172.60.145.151/30"] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=targets, + inaccessible_subnets=inaccessible_subnets, + blocklisted_ips=[], + enable_local_network_scan=False, + ) + + assert len(scan_targets) == 3 + + for ip in [148, 149, 150]: + assert f"172.60.145.{ip}" in scan_targets + + +def test_range_filtering(): + invalid_ranges = [ + # Invalid IP segment + "172.60.999.109", + "172.60.-1.109", + "172.60.999.109 - 172.60.1.109", + "172.60.999.109/32", + "172.60.999.109/24", + # Invalid CIDR + "172.60.1.109/33", + "172.60.1.109/-1", + # Typos + "172.60.9.109 -t 172.60.1.109", + "172.60..9.109", + "172.60,9.109", + " 172.60 .9.109 ", + ] + + valid_ranges = [ + " 172.60.9.109 ", + "172.60.9.109 - 172.60.1.109", + "172.60.9.109- 172.60.1.109", + "0.0.0.0", + "localhost" + ] + + invalid_ranges.extend(valid_ranges) + + remaining = _filter_invalid_ranges(invalid_ranges, "Test error:") + for _range in remaining: + assert _range in valid_ranges + assert len(remaining) == len(valid_ranges) + + +def test_invalid_blocklisted_ip(): + local_network_interfaces = [NetworkInterface("172.60.145.109", "/30")] + + inaccessible_subnets = ["172.60.147.8/30", "172.60.147.148/30"] + + targets = ["172.60.145.151/30"] + + blocklisted = ["172.60.145.153", "172.60.145.753"] + + with pytest.raises(InvalidNetworkRangeError): + compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=targets, + inaccessible_subnets=inaccessible_subnets, + blocklisted_ips=blocklisted, + enable_local_network_scan=False, + )