diff --git a/monkey/infection_monkey/network/network_scanner.py b/monkey/infection_monkey/network/network_scanner.py index 68abcb786..b751e99d7 100644 --- a/monkey/infection_monkey/network/network_scanner.py +++ b/monkey/infection_monkey/network/network_scanner.py @@ -21,19 +21,24 @@ SCAN_DELAY = 0 ITERATION_BLOCK_SIZE = 5 -def _grouper(iterables, chunk_size): +def generate_victims(net_ranges, chunk_size): """ - Goes over an iterable using chunks - :param iterables: a sequence of iterable objects - :param chunk_size: Chunk size, last chunk may be smaller + Generates VictimHosts in chunks from all the netranges + :param net_ranges: Iterable of network ranges + :param chunk_size: Maximum size of each chunk :return: """ - iterable = itertools.chain(*iterables) - while True: - group = tuple(itertools.islice(iterable, chunk_size)) - if not group: - break - yield group + chunk = [] + for net_range in net_ranges: + for address in net_range: + if hasattr(net_range, 'domain_name'): + victim = VictimHost(address, net_range.domain_name) + else: + victim = VictimHost(address) + chunk.append(victim) + if len(chunk) == chunk_size: + yield chunk + yield chunk class NetworkScanner(object): @@ -94,16 +99,9 @@ class NetworkScanner(object): """ pool = Pool() victims_count = 0 - for network_chunk in _grouper(self._ranges, ITERATION_BLOCK_SIZE): - LOG.debug("Scanning for potential victims in chunk %r", network_chunk) - victim_chunk = [] - for address in network_chunk: - # if hasattr(net_range, 'domain_name'): - # victim = VictimHost(address, net_range.domain_name) - # else: - victim = VictimHost(address) + for victim_chunk in generate_victims(self._ranges, ITERATION_BLOCK_SIZE): + LOG.debug("Scanning for potential victims in chunk %r", victim_chunk) - victim_chunk.append(victim) # skip self IP addresses victim_chunk = [x for x in victim_chunk if x.ip_addr not in self._ip_addresses] # skip IPs marked as blocked @@ -133,6 +131,7 @@ class NetworkScanner(object): # time.sleep uses seconds, while config is in milliseconds time.sleep(WormConfiguration.tcp_scan_interval / float(1000)) + @staticmethod def _is_any_ip_in_subnet(ip_addresses, subnet_str): for ip_address in ip_addresses: @@ -140,6 +139,7 @@ class NetworkScanner(object): return True return False + def scan_machine(self, victim): """ Scans specific machine using given scanner @@ -153,5 +153,6 @@ class NetworkScanner(object): else: return None + def on_island(self, server): return bool([x for x in self._ip_addresses if x in server])