Agent: Extract run_worker_threads() from IPScanner and Exploiter

This commit is contained in:
Mike Salvatore 2021-12-14 15:18:50 -05:00
parent 3394629cb2
commit bda192eba9
3 changed files with 17 additions and 20 deletions

View File

@ -8,7 +8,7 @@ from typing import Callable, Dict, List
from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
from .threading_utils import create_daemon_thread from .threading_utils import run_worker_threads
QUEUE_TIMEOUT = 2 QUEUE_TIMEOUT = 2
@ -40,16 +40,9 @@ class Exploiter:
) )
exploit_args = (exploiters_to_run, hosts_to_exploit, results_callback, scan_completed, stop) exploit_args = (exploiters_to_run, hosts_to_exploit, results_callback, scan_completed, stop)
run_worker_threads(
# TODO: This functionality is also used in IPScanner and can be generalized. Extract it. target=self._exploit_hosts_on_queue, args=exploit_args, num_workers=self._num_workers
exploiter_threads = [] )
for i in range(0, self._num_workers):
t = create_daemon_thread(target=self._exploit_hosts_on_queue, args=exploit_args)
t.start()
exploiter_threads.append(t)
for t in exploiter_threads:
t.join()
def _exploit_hosts_on_queue( def _exploit_hosts_on_queue(
self, self,

View File

@ -14,7 +14,7 @@ from infection_monkey.i_puppet import (
) )
from . import IPScanResults from . import IPScanResults
from .threading_utils import create_daemon_thread from .threading_utils import run_worker_threads
logger = logging.getLogger() logger = logging.getLogger()
@ -35,14 +35,7 @@ class IPScanner:
ips.put(ip) ips.put(ip)
scan_ips_args = (ips, options, results_callback, stop) scan_ips_args = (ips, options, results_callback, stop)
scan_threads = [] run_worker_threads(target=self._scan_ips, args=scan_ips_args, num_workers=self._num_workers)
for i in range(0, self._num_workers):
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
t.start()
scan_threads.append(t)
for t in scan_threads:
t.join()
def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event): def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")

View File

@ -2,5 +2,16 @@ from threading import Thread
from typing import Callable, Tuple from typing import Callable, Tuple
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
worker_threads = []
for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args)
t.start()
worker_threads.append(t)
for t in worker_threads:
t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()): def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True) return Thread(target=target, args=args, daemon=True)