diff --git a/monkey/infection_monkey/network/scan_target_generator.py b/monkey/infection_monkey/network/scan_target_generator.py index 862e37aef..1f4e44e86 100644 --- a/monkey/infection_monkey/network/scan_target_generator.py +++ b/monkey/infection_monkey/network/scan_target_generator.py @@ -1,10 +1,17 @@ +from collections import namedtuple from typing import List, Set from common.network.network_range import NetworkRange +# TODO: Convert to class and validate the format of the address and netmask +# Example: address="192.168.1.1", netmask="/24" +NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask")) + +# TODO: Validate all parameters +# TODO: Implement inaccessible_subnets def compile_scan_target_list( - local_ips: List[str], + local_network_interfaces: List[NetworkInterface], ranges_to_scan: List[str], inaccessible_subnets: List[str], blocklisted_ips: List[str], @@ -12,7 +19,10 @@ def compile_scan_target_list( ) -> List[str]: scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan) - _remove_local_ips(scan_targets, local_ips) + if enable_local_network_scan: + scan_targets.update(_get_ips_to_scan_from_local_interface(local_network_interfaces)) + + _remove_interface_ips(scan_targets, local_network_interfaces) _remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_target_list = list(scan_targets) @@ -31,8 +41,14 @@ def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]: return scan_targets -def _remove_local_ips(scan_targets: Set[str], local_ips: List[str]): - _remove_ips_from_scan_targets(scan_targets, local_ips) +def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]: + ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces] + return _get_ips_from_ranges_to_scan(ranges) + + +def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterface]): + interface_ips = [interface.address for interface in interfaces] + _remove_ips_from_scan_targets(scan_targets, interface_ips) def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]): diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py index 089644187..cf69b7a30 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_scan_target_generator.py @@ -1,11 +1,14 @@ import pytest -from infection_monkey.network.scan_target_generator import compile_scan_target_list +from infection_monkey.network.scan_target_generator import ( + NetworkInterface, + compile_scan_target_list, +) def compile_ranges_only(ranges): return compile_scan_target_list( - local_ips=[], + local_network_interfaces=[], ranges_to_scan=ranges, inaccessible_subnets=[], blocklisted_ips=[], @@ -78,7 +81,7 @@ def test_blocklisted_ips(): blocklisted_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] scan_targets = compile_scan_target_list( - local_ips=[], + local_network_interfaces=[], ranges_to_scan=["10.0.0.0/24"], inaccessible_subnets=[], blocklisted_ips=blocklisted_ips, @@ -95,7 +98,7 @@ def test_only_ip_blocklisted(ranges_to_scan): blocklisted_ips = ["10.0.0.5"] scan_targets = compile_scan_target_list( - local_ips=[], + local_network_interfaces=[], ranges_to_scan=ranges_to_scan, inaccessible_subnets=[], blocklisted_ips=blocklisted_ips, @@ -105,11 +108,16 @@ def test_only_ip_blocklisted(ranges_to_scan): assert len(scan_targets) == 0 -def test_local_ips_removed_from_targets(): - local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] +def test_local_network_interface_ips_removed_from_targets(): + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + NetworkInterface("10.0.0.32", "/24"), + NetworkInterface("10.0.0.119", "/24"), + NetworkInterface("192.168.1.33", "/24"), + ] scan_targets = compile_scan_target_list( - local_ips=local_ips, + local_network_interfaces=local_network_interfaces, ranges_to_scan=["10.0.0.0/24"], inaccessible_subnets=[], blocklisted_ips=[], @@ -117,16 +125,21 @@ def test_local_ips_removed_from_targets(): ) assert len(scan_targets) == 252 - for ip in local_ips: - assert ip not in scan_targets + for interface in local_network_interfaces: + assert interface.address not in scan_targets @pytest.mark.parametrize("ranges_to_scan", [["10.0.0.5"], []]) def test_only_scan_ip_is_local(ranges_to_scan): - local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + NetworkInterface("10.0.0.32", "/24"), + NetworkInterface("10.0.0.119", "/24"), + NetworkInterface("192.168.1.33", "/24"), + ] scan_targets = compile_scan_target_list( - local_ips=local_ips, + local_network_interfaces=local_network_interfaces, ranges_to_scan=ranges_to_scan, inaccessible_subnets=[], blocklisted_ips=[], @@ -136,22 +149,155 @@ def test_only_scan_ip_is_local(ranges_to_scan): assert len(scan_targets) == 0 -def test_local_ips_and_blocked_ips_removed_from_targets(): - local_ips = ["10.0.0.5", "10.0.0.32", "10.0.0.119", "192.168.1.33"] +def test_local_network_interface_ips_and_blocked_ips_removed_from_targets(): + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + NetworkInterface("10.0.0.32", "/24"), + NetworkInterface("10.0.0.119", "/24"), + NetworkInterface("192.168.1.33", "/24"), + ] blocked_ips = ["10.0.0.63", "192.168.1.77", "0.0.0.0"] scan_targets = compile_scan_target_list( - local_ips=local_ips, + local_network_interfaces=local_network_interfaces, ranges_to_scan=["10.0.0.0/24", "192.168.1.0/24"], inaccessible_subnets=[], blocklisted_ips=blocked_ips, enable_local_network_scan=False, ) - assert len(scan_targets) == (2 * (256 - 1)) - len(local_ips) - (len(blocked_ips) - 1) + assert len(scan_targets) == (2 * (256 - 1)) - len(local_network_interfaces) - ( + len(blocked_ips) - 1 + ) - for ip in local_ips: - assert ip not in scan_targets + for interface in local_network_interfaces: + assert interface.address not in scan_targets for ip in blocked_ips: assert ip not in scan_targets + + +def test_local_subnet_added(): + local_network_interfaces = [NetworkInterface("10.0.0.5", "/24")] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=[], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=True, + ) + + assert len(scan_targets) == 254 + + for ip in range(0, 5): + assert f"10.0.0.{ip} in scan_targets" + for ip in range(6, 255): + assert f"10.0.0.{ip} in scan_targets" + + +def test_multiple_local_subnets_added(): + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + NetworkInterface("172.33.66.99", "/24"), + ] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=[], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=True, + ) + + assert len(scan_targets) == 2 * (255 - 1) + + for ip in range(0, 5): + assert f"10.0.0.{ip} in scan_targets" + for ip in range(6, 255): + assert f"10.0.0.{ip} in scan_targets" + + for ip in range(0, 99): + assert f"172.33.66.{ip} in scan_targets" + for ip in range(100, 255): + assert f"172.33.66.{ip} in scan_targets" + + +def test_blocklisted_ips_missing_from_local_subnets(): + local_network_interfaces = [ + NetworkInterface("10.0.0.5", "/24"), + NetworkInterface("172.33.66.99", "/24"), + ] + blocklisted_ips = ["10.0.0.12", "10.0.0.13", "172.33.66.25"] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=[], + inaccessible_subnets=[], + blocklisted_ips=blocklisted_ips, + enable_local_network_scan=True, + ) + + assert len(scan_targets) == 2 * (255 - 1) - len(blocklisted_ips) + + for ip in blocklisted_ips: + assert ip not in scan_targets + + +def test_local_subnets_and_ranges_added(): + local_network_interfaces = [NetworkInterface("10.0.0.5", "/24")] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=["172.33.66.40/30"], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=True, + ) + + assert len(scan_targets) == 254 + 3 + + for ip in range(0, 5): + assert f"10.0.0.{ip} in scan_targets" + for ip in range(6, 255): + assert f"10.0.0.{ip} in scan_targets" + + for ip in range(40, 43): + assert f"172.33.66.{ip} in scan_targets" + + +def test_local_network_interfaces_specified_but_disabled(): + local_network_interfaces = [NetworkInterface("10.0.0.5", "/24")] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=["172.33.66.40/30"], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=False, + ) + + assert len(scan_targets) == 3 + + for ip in range(40, 43): + assert f"172.33.66.{ip} in scan_targets" + + +def test_local_network_interfaces_subnet_masks(): + local_network_interfaces = [ + NetworkInterface("172.60.145.109", "/30"), + NetworkInterface("172.60.145.144", "/30"), + ] + + scan_targets = compile_scan_target_list( + local_network_interfaces=local_network_interfaces, + ranges_to_scan=[], + inaccessible_subnets=[], + blocklisted_ips=[], + enable_local_network_scan=True, + ) + + assert len(scan_targets) == 4 + + for ip in [108, 110, 145, 146]: + assert f"172.60.145.{ip}" in scan_targets