diff --git a/monkey/common/network/network_range.py b/monkey/common/network/network_range.py index b7c8f14a4..5b1342370 100644 --- a/monkey/common/network/network_range.py +++ b/monkey/common/network/network_range.py @@ -178,7 +178,7 @@ class SingleIpRange(NetworkRange): :return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com) """ # The most common use case is to enter ip/range into "Scan IP/subnet list" - domain_name = "" + domain_name = None # Try casting user's input as IP try: diff --git a/monkey/infection_monkey/network/scan_target_generator.py b/monkey/infection_monkey/network/scan_target_generator.py index 79768c067..927123d48 100644 --- a/monkey/infection_monkey/network/scan_target_generator.py +++ b/monkey/infection_monkey/network/scan_target_generator.py @@ -1,12 +1,12 @@ import itertools import logging from collections import namedtuple -from typing import List, Set +from typing import List from common.network.network_range import InvalidNetworkRangeError, NetworkRange NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask")) - +NetworkAddress = namedtuple("NetworkAddress", ("ip", "domain")) logger = logging.getLogger(__name__) @@ -17,73 +17,101 @@ def compile_scan_target_list( inaccessible_subnets: List[str], blocklisted_ips: List[str], enable_local_network_scan: bool, -) -> List[str]: - +) -> List[NetworkAddress]: scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan) if enable_local_network_scan: - scan_targets.update(_get_ips_to_scan_from_local_interface(local_network_interfaces)) + scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces)) if inaccessible_subnets: inaccessible_subnets = _get_segmentation_check_targets( inaccessible_subnets, local_network_interfaces ) - scan_targets.update(inaccessible_subnets) + scan_targets.extend(inaccessible_subnets) - _remove_interface_ips(scan_targets, local_network_interfaces) - _remove_blocklisted_ips(scan_targets, blocklisted_ips) + scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces) + scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips) + scan_targets = _remove_redundant_targets(scan_targets) + scan_targets.sort() - scan_target_list = list(scan_targets) - scan_target_list.sort() - - return scan_target_list + return scan_targets -def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]: - scan_targets = set() +def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]: + target_dict = {} + for target in targets: + domain_name = target.domain + ip = target.ip + if ip not in target_dict or (target_dict[ip] is None and domain_name is not None): + target_dict[ip] = domain_name + return [NetworkAddress(key, value) for (key, value) in target_dict.items()] + + +def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]: + addresses = [] + for address in range_obj: + if hasattr(range_obj, "domain_name"): + addresses.append(NetworkAddress(address, range_obj.domain_name)) + else: + addresses.append(NetworkAddress(address, None)) + return addresses + + +def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> List[NetworkAddress]: + scan_targets = [] ranges_to_scan = _filter_invalid_ranges( ranges_to_scan, "Bad network range input for targets to scan:" ) network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan] + for _range in network_ranges: - scan_targets.update(set(_range)) + scan_targets.extend(_range_to_addresses(_range)) return scan_targets -def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]: +def _get_ips_to_scan_from_local_interface( + interfaces: List[NetworkInterface], +) -> List[NetworkAddress]: ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces] ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:") return _get_ips_from_ranges_to_scan(ranges) -def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterface]): +def _remove_interface_ips( + scan_targets: List[NetworkAddress], interfaces: List[NetworkInterface] +) -> List[NetworkAddress]: interface_ips = [interface.address for interface in interfaces] - _remove_ips_from_scan_targets(scan_targets, interface_ips) + return _remove_ips_from_scan_targets(scan_targets, interface_ips) -def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]): +def _remove_blocklisted_ips( + scan_targets: List[NetworkAddress], blocked_ips: List[str] +) -> List[NetworkAddress]: filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:") if not len(filtered_blocked_ips) == len(blocked_ips): raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.") - _remove_ips_from_scan_targets(scan_targets, blocked_ips) + return _remove_ips_from_scan_targets(scan_targets, filtered_blocked_ips) -def _remove_ips_from_scan_targets(scan_targets: Set[str], ips_to_remove: List[str]): +def _remove_ips_from_scan_targets( + scan_targets: List[NetworkAddress], ips_to_remove: List[str] +) -> List[NetworkAddress]: for ip in ips_to_remove: try: - scan_targets.remove(ip) + scan_targets = [address for address in scan_targets if address.ip != ip] except KeyError: # We don't need to remove the ip if it's already missing from the scan_targets pass + return scan_targets def _get_segmentation_check_targets( inaccessible_subnets: List[str], local_interfaces: List[NetworkInterface] -): - subnets_to_scan = set() +) -> List[NetworkAddress]: + subnets_to_scan = [] local_ips = [interface.address for interface in local_interfaces] local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ") @@ -97,7 +125,7 @@ def _get_segmentation_check_targets( for (subnet1, subnet2) in subnet_pairs: if _is_segmentation_check_required(local_ips, subnet1, subnet2): ips = _get_ips_from_ranges_to_scan(subnet2) - subnets_to_scan.update(ips) + subnets_to_scan.extend(ips) return subnets_to_scan diff --git a/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_generator.py b/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_generator.py index c60992fee..0133102eb 100644 --- a/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_generator.py +++ b/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_generator.py @@ -39,8 +39,3 @@ class TestVictimHostGenerator(TestCase): victims = list(generator.generate_victims_from_range(self.local_host_range)) self.assertEqual(len(victims), 1) self.assertEqual(victims[0].domain_name, "localhost") - - # don't generate for other victims - victims = list(generator.generate_victims_from_range(self.random_single_ip_range)) - self.assertEqual(len(victims), 1) - self.assertEqual(victims[0].domain_name, "") diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py index af194b300..41600897d 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py @@ -1,7 +1,7 @@ from itertools import chain import pytest -from network.scan_target_generator import _filter_invalid_ranges +from network.scan_target_generator import NetworkAddress, _filter_invalid_ranges from common.network.network_range import InvalidNetworkRangeError from infection_monkey.network.scan_target_generator import ( @@ -26,7 +26,7 @@ def test_single_subnet(): assert len(scan_targets) == 255 for i in range(0, 255): - assert f"10.0.0.{i}" in scan_targets + assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets @pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"]) @@ -35,8 +35,8 @@ def test_single_ip(single_ip): scan_targets = compile_ranges_only([single_ip]) assert len(scan_targets) == 1 - assert "10.0.0.2" in scan_targets - assert "10.0.0.2" == scan_targets[0] + assert NetworkAddress("10.0.0.2", None) in scan_targets + assert NetworkAddress("10.0.0.2", None) == scan_targets[0] def test_multiple_subnet(): @@ -45,10 +45,10 @@ def test_multiple_subnet(): assert len(scan_targets) == 262 for i in range(0, 255): - assert f"10.0.0.{i}" in scan_targets + assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets for i in range(8, 15): - assert f"192.168.56.{i}" in scan_targets + assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets def test_middle_of_range_subnet(): @@ -57,7 +57,7 @@ def test_middle_of_range_subnet(): assert len(scan_targets) == 7 for i in range(0, 7): - assert f"192.168.56.{i}" in scan_targets + assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets @pytest.mark.parametrize( @@ -70,7 +70,7 @@ def test_ip_range(ip_range): assert len(scan_targets) == 9 for i in range(25, 34): - assert f"192.168.56.{i}" in scan_targets + assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets def test_no_duplicates(): @@ -79,7 +79,7 @@ def test_no_duplicates(): assert len(scan_targets) == 7 for i in range(0, 7): - assert f"192.168.56.{i}" in scan_targets + assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets def test_blocklisted_ips(): @@ -134,6 +134,24 @@ def test_local_network_interface_ips_removed_from_targets(): assert interface.address not in scan_targets +def test_no_redundant_targets(): + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + ] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=["127.0.0.0", "127.0.0.1", "localhost"], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=False, + ) + + assert len(scan_targets) == 2 + assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets + assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets + + @pytest.mark.parametrize("ranges_to_scan", [["10.0.0.5"], []]) def test_only_scan_ip_is_local(ranges_to_scan): local_network_interfaces = [ @@ -196,7 +214,7 @@ def test_local_subnet_added(): assert len(scan_targets) == 254 for ip in chain(range(0, 5), range(6, 255)): - assert f"10.0.0.{ip} in scan_targets" + assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets def test_multiple_local_subnets_added(): @@ -216,10 +234,10 @@ def test_multiple_local_subnets_added(): assert len(scan_targets) == 2 * (255 - 1) for ip in chain(range(0, 5), range(6, 255)): - assert f"10.0.0.{ip} in scan_targets" + assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets for ip in chain(range(0, 99), range(100, 255)): - assert f"172.33.66.{ip} in scan_targets" + assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets def test_blocklisted_ips_missing_from_local_subnets(): @@ -257,12 +275,12 @@ def test_local_subnets_and_ranges_added(): assert len(scan_targets) == 254 + 3 for ip in range(0, 5): - assert f"10.0.0.{ip} in scan_targets" + assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets for ip in range(6, 255): - assert f"10.0.0.{ip} in scan_targets" + assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets for ip in range(40, 43): - assert f"172.33.66.{ip} in scan_targets" + assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets def test_local_network_interfaces_specified_but_disabled(): @@ -279,7 +297,7 @@ def test_local_network_interfaces_specified_but_disabled(): assert len(scan_targets) == 3 for ip in range(40, 43): - assert f"172.33.66.{ip} in scan_targets" + assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets def test_local_network_interfaces_subnet_masks(): @@ -299,7 +317,7 @@ def test_local_network_interfaces_subnet_masks(): assert len(scan_targets) == 4 for ip in [108, 110, 145, 146]: - assert f"172.60.145.{ip}" in scan_targets + assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets def test_segmentation_targets(): @@ -318,7 +336,7 @@ def test_segmentation_targets(): assert len(scan_targets) == 3 for ip in [144, 145, 146]: - assert f"172.60.145.{ip}" in scan_targets + assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets def test_segmentation_clash_with_blocked(): @@ -361,7 +379,7 @@ def test_segmentation_clash_with_targets(): assert len(scan_targets) == 3 for ip in [148, 149, 150]: - assert f"172.60.145.{ip}" in scan_targets + assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets def test_segmentation_one_network(): @@ -428,7 +446,7 @@ def test_invalid_inputs(): assert len(scan_targets) == 3 for ip in [148, 149, 150]: - assert f"172.60.145.{ip}" in scan_targets + assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets def test_range_filtering(): @@ -454,7 +472,7 @@ def test_range_filtering(): "172.60.9.109 - 172.60.1.109", "172.60.9.109- 172.60.1.109", "0.0.0.0", - "localhost" + "localhost", ] invalid_ranges.extend(valid_ranges)