diff --git a/monkey/infection_monkey/master/exploiter.py b/monkey/infection_monkey/master/exploiter.py index 3f732ffa3..0dfc8869f 100644 --- a/monkey/infection_monkey/master/exploiter.py +++ b/monkey/infection_monkey/master/exploiter.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.model import VictimHost -from .threading_utils import create_daemon_thread +from .threading_utils import run_worker_threads QUEUE_TIMEOUT = 2 @@ -40,16 +40,9 @@ class Exploiter: ) exploit_args = (exploiters_to_run, hosts_to_exploit, results_callback, scan_completed, stop) - - # TODO: This functionality is also used in IPScanner and can be generalized. Extract it. - exploiter_threads = [] - for i in range(0, self._num_workers): - t = create_daemon_thread(target=self._exploit_hosts_on_queue, args=exploit_args) - t.start() - exploiter_threads.append(t) - - for t in exploiter_threads: - t.join() + run_worker_threads( + target=self._exploit_hosts_on_queue, args=exploit_args, num_workers=self._num_workers + ) def _exploit_hosts_on_queue( self, diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 450ff3006..9e5851e7b 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -14,7 +14,7 @@ from infection_monkey.i_puppet import ( ) from . import IPScanResults -from .threading_utils import create_daemon_thread +from .threading_utils import run_worker_threads logger = logging.getLogger() @@ -35,14 +35,7 @@ class IPScanner: ips.put(ip) scan_ips_args = (ips, options, results_callback, stop) - scan_threads = [] - for i in range(0, self._num_workers): - t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args) - t.start() - scan_threads.append(t) - - for t in scan_threads: - t.join() + run_worker_threads(target=self._scan_ips, args=scan_ips_args, num_workers=self._num_workers) def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event): logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") diff --git a/monkey/infection_monkey/master/threading_utils.py b/monkey/infection_monkey/master/threading_utils.py index 56cf4a459..dbcc67984 100644 --- a/monkey/infection_monkey/master/threading_utils.py +++ b/monkey/infection_monkey/master/threading_utils.py @@ -2,5 +2,16 @@ from threading import Thread from typing import Callable, Tuple +def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): + worker_threads = [] + for i in range(0, num_workers): + t = create_daemon_thread(target=target, args=args) + t.start() + worker_threads.append(t) + + for t in worker_threads: + t.join() + + def create_daemon_thread(target: Callable[..., None], args: Tuple = ()): return Thread(target=target, args=args, daemon=True)