diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index bc304d2d8..b31c21550 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -121,7 +121,9 @@ class AutomatedMaster(IMaster): # system_info_collector_thread.join() if self._can_propagate(): - propagation_thread = create_daemon_thread(target=self._propagate, args=(config,)) + propagation_thread = create_daemon_thread( + target=self._propagate, args=(config["propagation"],) + ) propagation_thread.start() propagation_thread.join() @@ -160,14 +162,12 @@ class AutomatedMaster(IMaster): return True # TODO: Refactor propagation into its own class - def _propagate(self, config: Dict): + def _propagate(self, propagation_config: Dict): logger.info("Attempting to propagate") self._hosts_to_exploit = Queue() - scan_thread = create_daemon_thread( - target=self._scan_network, args=(config["network_scan"],) - ) + scan_thread = create_daemon_thread(target=self._scan_network, args=(propagation_config,)) exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,)) scan_thread.start() @@ -178,19 +178,14 @@ class AutomatedMaster(IMaster): logger.info("Finished attempting to propagate") - def _scan_network(self, scan_config: Dict): + def _scan_network(self, propagation_config: Dict): logger.info("Starting network scan") # TODO: Generate list of IPs to scan ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)] - self._ip_scanner.scan( - ips_to_scan, - scan_config["icmp"], - scan_config["tcp"], - self._handle_scanned_host, - self._stop, - ) + scan_config = propagation_config["network_scan"] + self._ip_scanner.scan(ips_to_scan, scan_config, self._handle_scanned_host, self._stop) logger.info("Finished network scan") diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 61329ef5d..4f438ccf3 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -20,26 +20,14 @@ class IPScanner: self._puppet = puppet self._num_workers = num_workers - def scan( - self, - ips: List[str], - icmp_config: Dict, - tcp_config: Dict, - report_results_callback: Callback, - stop: Event, - ): - # Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty. - ips_to_scan = Queue() - for ip in ips: - ips_to_scan.put(ip) + def scan(self, ips_to_scan: List[str], options: Dict, results_callback: Callback, stop: Event): + # Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the + # queue is empty. + ips = Queue() + for ip in ips_to_scan: + ips.put(ip) - scan_ips_args = ( - ips_to_scan, - icmp_config, - tcp_config, - report_results_callback, - stop, - ) + scan_ips_args = (ips, options, results_callback, stop) scan_threads = [] for i in range(0, self._num_workers): t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args) @@ -49,27 +37,20 @@ class IPScanner: for t in scan_threads: t.join() - def _scan_ips( - self, - ips_to_scan: Queue, - icmp_config: Dict, - tcp_config: Dict, - report_results_callback: Callback, - stop: Event, - ): + def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event): logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") try: while not stop.is_set(): - ip = ips_to_scan.get_nowait() + ip = ips.get_nowait() logger.info(f"Scanning {ip}") victim_host = VictimHost(ip) - self._ping_ip(ip, victim_host, icmp_config) - self._scan_tcp_ports(ip, victim_host, tcp_config, stop) + self._ping_ip(ip, victim_host, options["icmp"]) + self._scan_tcp_ports(ip, victim_host, options["tcp"], stop) - report_results_callback(victim_host) + results_callback(victim_host) except queue.Empty: logger.debug( diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py index f73b5f39a..186d85be1 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py @@ -18,24 +18,22 @@ class MockPuppet(MockPuppet): @pytest.fixture -def tcp_scan_config(): +def scan_config(): return { - "timeout_ms": 3000, - "ports": [ - 22, - 445, - 3389, - 443, - 8008, - 3306, - ], - } - - -@pytest.fixture -def icmp_scan_config(): - return { - "timeout_ms": 1000, + "tcp": { + "timeout_ms": 3000, + "ports": [ + 22, + 445, + 3389, + 443, + 8008, + 3306, + ], + }, + "icmp": { + "timeout_ms": 1000, + }, } @@ -80,22 +78,22 @@ def assert_host_down(victim_host): assert len(victim_host.services.keys()) == 0 -def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop): +def test_scan_single_ip(callback, scan_config, stop): ips = ["10.0.0.1"] ns = IPScanner(MockPuppet(), num_workers=1) - ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + ns.scan(ips, scan_config, callback, stop) callback.assert_called_once() assert_dot_1(callback.call_args_list[0][0][0]) -def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop): +def test_scan_multiple_ips(callback, scan_config, stop): ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ns = IPScanner(MockPuppet(), num_workers=4) - ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + ns.scan(ips, scan_config, callback, stop) assert callback.call_count == 4 @@ -105,7 +103,16 @@ def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop): assert_host_down(callback.call_args_list[3][0][0]) -def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop): +def test_scan_lots_of_ips(callback, scan_config, stop): + ips = [f"10.0.0.{i}" for i in range(0, 255)] + + ns = IPScanner(MockPuppet(), num_workers=4) + ns.scan(ips, scan_config, callback, stop) + + assert callback.call_count == 255 + + +def test_stop_after_callback(scan_config, stop): def _callback(_): # Block all threads here until 2 threads reach this barrier, then set stop # and test that niether thread continues to scan. @@ -119,12 +126,12 @@ def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop): ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ns = IPScanner(MockPuppet(), num_workers=2) - ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop) + ns.scan(ips, scan_config, stopable_callback, stop) assert stopable_callback.call_count == 2 -def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop): +def test_interrupt_port_scanning(callback, scan_config, stop): def stopable_scan_tcp_port(port, _, __): # Block all threads here until 2 threads reach this barrier, then set stop # and test that niether thread scans any more ports @@ -141,6 +148,6 @@ def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, st ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ns = IPScanner(puppet, num_workers=2) - ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) + ns.scan(ips, scan_config, callback, stop) assert puppet.scan_tcp_port.call_count == 2