Agent: Add optional name to create_daemon_thread and run_worker_threads

This commit is contained in:
Mike Salvatore 2022-03-09 08:52:10 -05:00 committed by vakarisz
parent b34c287238
commit f9a7672767
2 changed files with 45 additions and 7 deletions

View File

@ -1,14 +1,19 @@
import logging import logging
from itertools import count
from threading import Event, Thread from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple from typing import Any, Callable, Iterable, Optional, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): def run_worker_threads(
target: Callable[..., None], name_prefix: str = None, args: Tuple = (), num_workers: int = 2
):
worker_threads = [] worker_threads = []
counter = run_worker_threads.counters.setdefault(name_prefix, count(start=1))
for i in range(0, num_workers): for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args) name = None if name_prefix is None else f"{name_prefix}-{next(counter)}"
t = create_daemon_thread(target=target, name=name, args=args)
t.start() t.start()
worker_threads.append(t) worker_threads.append(t)
@ -16,8 +21,13 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker
t.join() t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()) -> Thread: run_worker_threads.counters = {}
return Thread(target=target, args=args, daemon=True)
def create_daemon_thread(
target: Callable[..., None], name: Optional[str] = None, args: Tuple = ()
) -> Thread:
return Thread(target=target, name=name, args=args, daemon=True)
def interruptable_iter( def interruptable_iter(

View File

@ -1,7 +1,11 @@
import logging import logging
from threading import Event from threading import Event, current_thread
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter from infection_monkey.utils.threading import (
create_daemon_thread,
interruptable_iter,
run_worker_threads,
)
def test_create_daemon_thread(): def test_create_daemon_thread():
@ -9,6 +13,11 @@ def test_create_daemon_thread():
assert thread.daemon assert thread.daemon
def test_create_daemon_thread_naming():
thread = create_daemon_thread(lambda: None, name="test")
assert thread.name == "test"
def test_interruptable_iter(): def test_interruptable_iter():
interrupt = Event() interrupt = Event()
items_from_iterator = [] items_from_iterator = []
@ -45,3 +54,22 @@ def test_interruptable_iter_interrupted_before_used():
items_from_iterator.append(i) items_from_iterator.append(i)
assert not items_from_iterator assert not items_from_iterator
def test_worker_thread_names():
thread_names = set()
def add_thread_name_to_list():
thread_names.add(current_thread().name)
run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2)
run_worker_threads(target=add_thread_name_to_list, name_prefix="B", num_workers=2)
run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2)
assert "A-1" in thread_names
assert "A-2" in thread_names
assert "A-3" in thread_names
assert "A-4" in thread_names
assert "B-1" in thread_names
assert "B-2" in thread_names
assert len(thread_names) == 6