Agent, UT: Implement domain names in scan_target_generator.py

Change the ip strings to NetworkAddress named tuple that has ip and domain name. This tuple better describes the target and is necessary because VictimHost uses domain names
This commit is contained in:
vakarisz 2021-12-16 12:03:40 +02:00
parent cabadeb7d1
commit 549eb5d389
4 changed files with 93 additions and 52 deletions

View File

@ -178,7 +178,7 @@ class SingleIpRange(NetworkRange):
:return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com) :return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com)
""" """
# The most common use case is to enter ip/range into "Scan IP/subnet list" # The most common use case is to enter ip/range into "Scan IP/subnet list"
domain_name = "" domain_name = None
# Try casting user's input as IP # Try casting user's input as IP
try: try:

View File

@ -1,12 +1,12 @@
import itertools import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import List, Set from typing import List
from common.network.network_range import InvalidNetworkRangeError, NetworkRange from common.network.network_range import InvalidNetworkRangeError, NetworkRange
NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask")) NetworkInterface = namedtuple("NetworkInterface", ("address", "netmask"))
NetworkAddress = namedtuple("NetworkAddress", ("ip", "domain"))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,73 +17,101 @@ def compile_scan_target_list(
inaccessible_subnets: List[str], inaccessible_subnets: List[str],
blocklisted_ips: List[str], blocklisted_ips: List[str],
enable_local_network_scan: bool, enable_local_network_scan: bool,
) -> List[str]: ) -> List[NetworkAddress]:
scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan) scan_targets = _get_ips_from_ranges_to_scan(ranges_to_scan)
if enable_local_network_scan: if enable_local_network_scan:
scan_targets.update(_get_ips_to_scan_from_local_interface(local_network_interfaces)) scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces))
if inaccessible_subnets: if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets( inaccessible_subnets = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces inaccessible_subnets, local_network_interfaces
) )
scan_targets.update(inaccessible_subnets) scan_targets.extend(inaccessible_subnets)
_remove_interface_ips(scan_targets, local_network_interfaces) scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
_remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
scan_targets = _remove_redundant_targets(scan_targets)
scan_targets.sort()
scan_target_list = list(scan_targets) return scan_targets
scan_target_list.sort()
return scan_target_list
def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> Set[str]: def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddress]:
scan_targets = set() target_dict = {}
for target in targets:
domain_name = target.domain
ip = target.ip
if ip not in target_dict or (target_dict[ip] is None and domain_name is not None):
target_dict[ip] = domain_name
return [NetworkAddress(key, value) for (key, value) in target_dict.items()]
def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
addresses = []
for address in range_obj:
if hasattr(range_obj, "domain_name"):
addresses.append(NetworkAddress(address, range_obj.domain_name))
else:
addresses.append(NetworkAddress(address, None))
return addresses
def _get_ips_from_ranges_to_scan(ranges_to_scan: List[str]) -> List[NetworkAddress]:
scan_targets = []
ranges_to_scan = _filter_invalid_ranges( ranges_to_scan = _filter_invalid_ranges(
ranges_to_scan, "Bad network range input for targets to scan:" ranges_to_scan, "Bad network range input for targets to scan:"
) )
network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan] network_ranges = [NetworkRange.get_range_obj(_range) for _range in ranges_to_scan]
for _range in network_ranges: for _range in network_ranges:
scan_targets.update(set(_range)) scan_targets.extend(_range_to_addresses(_range))
return scan_targets return scan_targets
def _get_ips_to_scan_from_local_interface(interfaces: List[NetworkInterface]) -> Set[str]: def _get_ips_to_scan_from_local_interface(
interfaces: List[NetworkInterface],
) -> List[NetworkAddress]:
ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces] ranges = [f"{interface.address}{interface.netmask}" for interface in interfaces]
ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:") ranges = _filter_invalid_ranges(ranges, "Local network interface returns an invalid IP:")
return _get_ips_from_ranges_to_scan(ranges) return _get_ips_from_ranges_to_scan(ranges)
def _remove_interface_ips(scan_targets: Set[str], interfaces: List[NetworkInterface]): def _remove_interface_ips(
scan_targets: List[NetworkAddress], interfaces: List[NetworkInterface]
) -> List[NetworkAddress]:
interface_ips = [interface.address for interface in interfaces] interface_ips = [interface.address for interface in interfaces]
_remove_ips_from_scan_targets(scan_targets, interface_ips) return _remove_ips_from_scan_targets(scan_targets, interface_ips)
def _remove_blocklisted_ips(scan_targets: Set[str], blocked_ips: List[str]): def _remove_blocklisted_ips(
scan_targets: List[NetworkAddress], blocked_ips: List[str]
) -> List[NetworkAddress]:
filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:") filtered_blocked_ips = _filter_invalid_ranges(blocked_ips, "Invalid blocked IP provided:")
if not len(filtered_blocked_ips) == len(blocked_ips): if not len(filtered_blocked_ips) == len(blocked_ips):
raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.") raise InvalidNetworkRangeError("Received an invalid blocked IP. Aborting just in case.")
_remove_ips_from_scan_targets(scan_targets, blocked_ips) return _remove_ips_from_scan_targets(scan_targets, filtered_blocked_ips)
def _remove_ips_from_scan_targets(scan_targets: Set[str], ips_to_remove: List[str]): def _remove_ips_from_scan_targets(
scan_targets: List[NetworkAddress], ips_to_remove: List[str]
) -> List[NetworkAddress]:
for ip in ips_to_remove: for ip in ips_to_remove:
try: try:
scan_targets.remove(ip) scan_targets = [address for address in scan_targets if address.ip != ip]
except KeyError: except KeyError:
# We don't need to remove the ip if it's already missing from the scan_targets # We don't need to remove the ip if it's already missing from the scan_targets
pass pass
return scan_targets
def _get_segmentation_check_targets( def _get_segmentation_check_targets(
inaccessible_subnets: List[str], local_interfaces: List[NetworkInterface] inaccessible_subnets: List[str], local_interfaces: List[NetworkInterface]
): ) -> List[NetworkAddress]:
subnets_to_scan = set() subnets_to_scan = []
local_ips = [interface.address for interface in local_interfaces] local_ips = [interface.address for interface in local_interfaces]
local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ") local_ips = _filter_invalid_ranges(local_ips, "Invalid local IP found: ")
@ -97,7 +125,7 @@ def _get_segmentation_check_targets(
for (subnet1, subnet2) in subnet_pairs: for (subnet1, subnet2) in subnet_pairs:
if _is_segmentation_check_required(local_ips, subnet1, subnet2): if _is_segmentation_check_required(local_ips, subnet1, subnet2):
ips = _get_ips_from_ranges_to_scan(subnet2) ips = _get_ips_from_ranges_to_scan(subnet2)
subnets_to_scan.update(ips) subnets_to_scan.extend(ips)
return subnets_to_scan return subnets_to_scan

View File

@ -39,8 +39,3 @@ class TestVictimHostGenerator(TestCase):
victims = list(generator.generate_victims_from_range(self.local_host_range)) victims = list(generator.generate_victims_from_range(self.local_host_range))
self.assertEqual(len(victims), 1) self.assertEqual(len(victims), 1)
self.assertEqual(victims[0].domain_name, "localhost") self.assertEqual(victims[0].domain_name, "localhost")
# don't generate for other victims
victims = list(generator.generate_victims_from_range(self.random_single_ip_range))
self.assertEqual(len(victims), 1)
self.assertEqual(victims[0].domain_name, "")

View File

@ -1,7 +1,7 @@
from itertools import chain from itertools import chain
import pytest import pytest
from network.scan_target_generator import _filter_invalid_ranges from network.scan_target_generator import NetworkAddress, _filter_invalid_ranges
from common.network.network_range import InvalidNetworkRangeError from common.network.network_range import InvalidNetworkRangeError
from infection_monkey.network.scan_target_generator import ( from infection_monkey.network.scan_target_generator import (
@ -26,7 +26,7 @@ def test_single_subnet():
assert len(scan_targets) == 255 assert len(scan_targets) == 255
for i in range(0, 255): for i in range(0, 255):
assert f"10.0.0.{i}" in scan_targets assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
@pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"]) @pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"])
@ -35,8 +35,8 @@ def test_single_ip(single_ip):
scan_targets = compile_ranges_only([single_ip]) scan_targets = compile_ranges_only([single_ip])
assert len(scan_targets) == 1 assert len(scan_targets) == 1
assert "10.0.0.2" in scan_targets assert NetworkAddress("10.0.0.2", None) in scan_targets
assert "10.0.0.2" == scan_targets[0] assert NetworkAddress("10.0.0.2", None) == scan_targets[0]
def test_multiple_subnet(): def test_multiple_subnet():
@ -45,10 +45,10 @@ def test_multiple_subnet():
assert len(scan_targets) == 262 assert len(scan_targets) == 262
for i in range(0, 255): for i in range(0, 255):
assert f"10.0.0.{i}" in scan_targets assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
for i in range(8, 15): for i in range(8, 15):
assert f"192.168.56.{i}" in scan_targets assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
def test_middle_of_range_subnet(): def test_middle_of_range_subnet():
@ -57,7 +57,7 @@ def test_middle_of_range_subnet():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert f"192.168.56.{i}" in scan_targets assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -70,7 +70,7 @@ def test_ip_range(ip_range):
assert len(scan_targets) == 9 assert len(scan_targets) == 9
for i in range(25, 34): for i in range(25, 34):
assert f"192.168.56.{i}" in scan_targets assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
def test_no_duplicates(): def test_no_duplicates():
@ -79,7 +79,7 @@ def test_no_duplicates():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert f"192.168.56.{i}" in scan_targets assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
def test_blocklisted_ips(): def test_blocklisted_ips():
@ -134,6 +134,24 @@ def test_local_network_interface_ips_removed_from_targets():
assert interface.address not in scan_targets assert interface.address not in scan_targets
def test_no_redundant_targets():
local_network_interfaces = [
NetworkInterface("10.0.0.5", "/24"),
]
scan_targets = compile_scan_target_list(
local_network_interfaces=local_network_interfaces,
ranges_to_scan=["127.0.0.0", "127.0.0.1", "localhost"],
inaccessible_subnets=[],
blocklisted_ips=[],
enable_local_network_scan=False,
)
assert len(scan_targets) == 2
assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets
assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets
@pytest.mark.parametrize("ranges_to_scan", [["10.0.0.5"], []]) @pytest.mark.parametrize("ranges_to_scan", [["10.0.0.5"], []])
def test_only_scan_ip_is_local(ranges_to_scan): def test_only_scan_ip_is_local(ranges_to_scan):
local_network_interfaces = [ local_network_interfaces = [
@ -196,7 +214,7 @@ def test_local_subnet_added():
assert len(scan_targets) == 254 assert len(scan_targets) == 254
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert f"10.0.0.{ip} in scan_targets" assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
def test_multiple_local_subnets_added(): def test_multiple_local_subnets_added():
@ -216,10 +234,10 @@ def test_multiple_local_subnets_added():
assert len(scan_targets) == 2 * (255 - 1) assert len(scan_targets) == 2 * (255 - 1)
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert f"10.0.0.{ip} in scan_targets" assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
for ip in chain(range(0, 99), range(100, 255)): for ip in chain(range(0, 99), range(100, 255)):
assert f"172.33.66.{ip} in scan_targets" assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
def test_blocklisted_ips_missing_from_local_subnets(): def test_blocklisted_ips_missing_from_local_subnets():
@ -257,12 +275,12 @@ def test_local_subnets_and_ranges_added():
assert len(scan_targets) == 254 + 3 assert len(scan_targets) == 254 + 3
for ip in range(0, 5): for ip in range(0, 5):
assert f"10.0.0.{ip} in scan_targets" assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
for ip in range(6, 255): for ip in range(6, 255):
assert f"10.0.0.{ip} in scan_targets" assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
for ip in range(40, 43): for ip in range(40, 43):
assert f"172.33.66.{ip} in scan_targets" assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
def test_local_network_interfaces_specified_but_disabled(): def test_local_network_interfaces_specified_but_disabled():
@ -279,7 +297,7 @@ def test_local_network_interfaces_specified_but_disabled():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in range(40, 43): for ip in range(40, 43):
assert f"172.33.66.{ip} in scan_targets" assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
def test_local_network_interfaces_subnet_masks(): def test_local_network_interfaces_subnet_masks():
@ -299,7 +317,7 @@ def test_local_network_interfaces_subnet_masks():
assert len(scan_targets) == 4 assert len(scan_targets) == 4
for ip in [108, 110, 145, 146]: for ip in [108, 110, 145, 146]:
assert f"172.60.145.{ip}" in scan_targets assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
def test_segmentation_targets(): def test_segmentation_targets():
@ -318,7 +336,7 @@ def test_segmentation_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [144, 145, 146]: for ip in [144, 145, 146]:
assert f"172.60.145.{ip}" in scan_targets assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
def test_segmentation_clash_with_blocked(): def test_segmentation_clash_with_blocked():
@ -361,7 +379,7 @@ def test_segmentation_clash_with_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert f"172.60.145.{ip}" in scan_targets assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
def test_segmentation_one_network(): def test_segmentation_one_network():
@ -428,7 +446,7 @@ def test_invalid_inputs():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert f"172.60.145.{ip}" in scan_targets assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
def test_range_filtering(): def test_range_filtering():
@ -454,7 +472,7 @@ def test_range_filtering():
"172.60.9.109 - 172.60.1.109", "172.60.9.109 - 172.60.1.109",
"172.60.9.109- 172.60.1.109", "172.60.9.109- 172.60.1.109",
"0.0.0.0", "0.0.0.0",
"localhost" "localhost",
] ]
invalid_ranges.extend(valid_ranges) invalid_ranges.extend(valid_ranges)