diff --git a/monkey/infection_monkey/network/scan_target_generator.py b/monkey/infection_monkey/network/scan_target_generator.py index cdcfbdb31..862e37aef 100644 --- a/monkey/infection_monkey/network/scan_target_generator.py +++ b/monkey/infection_monkey/network/scan_target_generator.py @@ -12,6 +12,7 @@ def compile_scan_target_list( ) -> List[str]: scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan) + _remove_local_ips(scan_targets, local_ips) _remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_target_list = list(scan_targets) @@ -30,10 +31,18 @@ def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]: return scan_targets +def _remove_local_ips(scan_targets: Set[str], local_ips: List[str]): + _remove_ips_from_scan_targets(scan_targets, local_ips) + + def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]): - for blocked_ip in blocked_ips: + _remove_ips_from_scan_targets(scan_targets, blocked_ips) + + +def _remove_ips_from_scan_targets(scan_targets: Set[str], ips_to_remove: List[str]): + for ip in ips_to_remove: try: - scan_targets.remove(blocked_ip) + scan_targets.remove(ip) except KeyError: - # We don't need to remove the blocked ip if it's already missing from the scan_targets + # We don't need to remove the ip if it's already missing from the scan_targets pass diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py index 9e1b5fc0b..089644187 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py @@ -103,3 +103,55 @@ def test_only_ip_blocklisted(ranges_to_scan): ) assert len(scan_targets) == 0 + + +def test_local_ips_removed_from_targets(): + local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] + + scan_targets = compile_scan_target_list( + local_ips=local_ips, + ranges_to_scan=["10.0.0.0/24"], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=False, + ) + + assert len(scan_targets) == 252 + for ip in local_ips: + assert ip not in scan_targets + + +@pytest.mark.parametrize("ranges_to_scan", [["10.0.0.5"], []]) +def test_only_scan_ip_is_local(ranges_to_scan): + local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] + + scan_targets = compile_scan_target_list( + local_ips=local_ips, + ranges_to_scan=ranges_to_scan, + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=False, + ) + + assert len(scan_targets) == 0 + + +def test_local_ips_and_blocked_ips_removed_from_targets(): + local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] + blocked_ips = ["10.0.0.63", "192.168.1.77", "0.0.0.0"] + + scan_targets = compile_scan_target_list( + local_ips=local_ips, + ranges_to_scan=["10.0.0.0/24", "192.168.1.0/24"], + inaccessible_subnets=[], + blocklisted_ips=blocked_ips, + enable_local_network_scan=False, + ) + + assert len(scan_targets) == (2 * (256 - 1)) - len(local_ips) - (len(blocked_ips) - 1) + + for ip in local_ips: + assert ip not in scan_targets + + for ip in blocked_ips: + assert ip not in scan_targets