From f55a3e483b8f761d00218ad62f8f2a02d96c4380 Mon Sep 17 00:00:00 2001
From: Daniel Goldberg <danielg@guardicore.com>
Date: Fri, 27 Sep 2019 18:10:59 +0300
Subject: [PATCH] Changed VictimHostGenerator to accept the local addresses
 rather than generating them itself. Changed UTs to be independent.

---
 .../model/victim_host_generator.py            |  7 ++---
 .../model/victim_host_generator_test.py       | 28 ++++++++++---------
 .../network/network_scanner.py                |  2 +-
 3 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/monkey/infection_monkey/model/victim_host_generator.py b/monkey/infection_monkey/model/victim_host_generator.py
index 1309278c8..1e9eba9c2 100644
--- a/monkey/infection_monkey/model/victim_host_generator.py
+++ b/monkey/infection_monkey/model/victim_host_generator.py
@@ -1,12 +1,11 @@
 from infection_monkey.model.host import VictimHost
-from infection_monkey.network.info import local_ips
 
 
 class VictimHostGenerator(object):
-    def __init__(self, network_ranges, blocked_ips):
-        self._ip_addresses = local_ips()
+    def __init__(self, network_ranges, blocked_ips, same_machine_ips):
         self.blocked_ips = blocked_ips
         self.ranges = network_ranges
+        self.local_addresses = same_machine_ips
 
     def generate_victims(self, chunk_size):
         """
@@ -39,7 +38,7 @@ class VictimHostGenerator(object):
             yield victim
 
     def is_ip_scannable(self, ip_address):
-        if ip_address in self._ip_addresses:
+        if ip_address in self.local_addresses:
             return False
         if ip_address in self.blocked_ips:
             return False
diff --git a/monkey/infection_monkey/model/victim_host_generator_test.py b/monkey/infection_monkey/model/victim_host_generator_test.py
index 2414bd794..102014d45 100644
--- a/monkey/infection_monkey/model/victim_host_generator_test.py
+++ b/monkey/infection_monkey/model/victim_host_generator_test.py
@@ -6,39 +6,41 @@ from common.network.network_range import CidrRange, SingleIpRange
 class VictimHostGeneratorTester(TestCase):
 
     def setUp(self):
-        self.test_ranges = [CidrRange("10.0.0.0/28", False),  # this gives us 15 hosts
-                            SingleIpRange('41.50.13.37'),
-                            SingleIpRange('localhost')
-                            ]
-        self.generator = VictimHostGenerator(self.test_ranges, '10.0.0.1')
-        self.generator._ip_addresses = []  # test later on
+        self.cidr_range = CidrRange("10.0.0.0/28", False)  # this gives us 15 hosts
+        self.local_host_range = SingleIpRange('localhost')
+        self.random_single_ip_range = SingleIpRange('41.50.13.37')
 
     def test_chunking(self):
         chunk_size = 3
         # current test setup is 15+1+1-1 hosts
-        victims = self.generator.generate_victims(chunk_size)
+        test_ranges = [self.cidr_range, self.local_host_range, self.random_single_ip_range]
+        generator = VictimHostGenerator(test_ranges, '10.0.0.1', [])
+        victims = generator.generate_victims(chunk_size)
         for i in range(5):  # quickly check the equally sided chunks
             self.assertEqual(len(victims.next()), chunk_size)
         victim_chunk_last = victims.next()
         self.assertEqual(len(victim_chunk_last), 1)
 
     def test_remove_blocked_ip(self):
-        victims = list(self.generator.generate_victims_from_range(self.test_ranges[0]))
+        generator = VictimHostGenerator(self.cidr_range, ['10.0.0.1'], [])
+
+        victims = list(generator.generate_victims_from_range(self.cidr_range))
         self.assertEqual(len(victims), 14)  # 15 minus the 1 we blocked
 
     def test_remove_local_ips(self):
-        self.generator._ip_addresses = ['127.0.0.1']
-        victims = list(self.generator.generate_victims_from_range(self.test_ranges[-1]))
+        generator = VictimHostGenerator([], [], [])
+        generator.local_addresses = ['127.0.0.1']
+        victims = list(generator.generate_victims_from_range(self.local_host_range))
         self.assertEqual(len(victims), 0)  # block the local IP
 
     def test_generate_domain_victim(self):
         # domain name victim
-        self.generator._ip_addresses = []
-        victims = list(self.generator.generate_victims_from_range(self.test_ranges[-1]))
+        generator = VictimHostGenerator([], [], [])  # dummy object
+        victims = list(generator.generate_victims_from_range(self.local_host_range))
         self.assertEqual(len(victims), 1)
         self.assertEqual(victims[0].domain_name, 'localhost')
 
         # don't generate for other victims
-        victims = list(self.generator.generate_victims_from_range(self.test_ranges[1]))
+        victims = list(generator.generate_victims_from_range(self.random_single_ip_range))
         self.assertEqual(len(victims), 1)
         self.assertEqual(victims[0].domain_name, '')
diff --git a/monkey/infection_monkey/network/network_scanner.py b/monkey/infection_monkey/network/network_scanner.py
index 3e447b85e..9452a3fb8 100644
--- a/monkey/infection_monkey/network/network_scanner.py
+++ b/monkey/infection_monkey/network/network_scanner.py
@@ -78,7 +78,7 @@ class NetworkScanner(object):
         # Because we are using this to spread out IO heavy tasks, we can probably go a lot higher than CPU core size
         # But again, balance
         pool = Pool(ITERATION_BLOCK_SIZE)
-        victim_generator = VictimHostGenerator(self._ranges, WormConfiguration.blocked_ips)
+        victim_generator = VictimHostGenerator(self._ranges, WormConfiguration.blocked_ips, local_ips())
 
         victims_count = 0
         for victim_chunk in victim_generator.generate_victims(ITERATION_BLOCK_SIZE):