diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index 54bc469be..80b688759 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,14 +1,19 @@ import logging +from itertools import count from threading import Event, Thread -from typing import Any, Callable, Iterable, Tuple +from typing import Any, Callable, Iterable, Optional, Tuple logger = logging.getLogger(__name__) -def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): +def run_worker_threads( + target: Callable[..., None], name_prefix: str = None, args: Tuple = (), num_workers: int = 2 +): worker_threads = [] + counter = run_worker_threads.counters.setdefault(name_prefix, count(start=1)) for i in range(0, num_workers): - t = create_daemon_thread(target=target, args=args) + name = None if name_prefix is None else f"{name_prefix}-{next(counter)}" + t = create_daemon_thread(target=target, name=name, args=args) t.start() worker_threads.append(t) @@ -16,8 +21,13 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker t.join() -def create_daemon_thread(target: Callable[..., None], args: Tuple = ()) -> Thread: - return Thread(target=target, args=args, daemon=True) +run_worker_threads.counters = {} + + +def create_daemon_thread( + target: Callable[..., None], name: Optional[str] = None, args: Tuple = () +) -> Thread: + return Thread(target=target, name=name, args=args, daemon=True) def interruptable_iter( diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 659fc7205..8b55cc9b5 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,7 +1,11 @@ import logging -from threading import Event +from threading import Event, current_thread -from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter +from infection_monkey.utils.threading import ( + create_daemon_thread, + interruptable_iter, + run_worker_threads, +) def test_create_daemon_thread(): @@ -9,6 +13,11 @@ def test_create_daemon_thread(): assert thread.daemon +def test_create_daemon_thread_naming(): + thread = create_daemon_thread(lambda: None, name="test") + assert thread.name == "test" + + def test_interruptable_iter(): interrupt = Event() items_from_iterator = [] @@ -45,3 +54,22 @@ def test_interruptable_iter_interrupted_before_used(): items_from_iterator.append(i) assert not items_from_iterator + + +def test_worker_thread_names(): + thread_names = set() + + def add_thread_name_to_list(): + thread_names.add(current_thread().name) + + run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2) + run_worker_threads(target=add_thread_name_to_list, name_prefix="B", num_workers=2) + run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2) + + assert "A-1" in thread_names + assert "A-2" in thread_names + assert "A-3" in thread_names + assert "A-4" in thread_names + assert "B-1" in thread_names + assert "B-2" in thread_names + assert len(thread_names) == 6