diff --git a/monkey/infection_monkey/master/threading_utils.py b/monkey/infection_monkey/master/threading_utils.py index dbcc67984..9ceec895f 100644 --- a/monkey/infection_monkey/master/threading_utils.py +++ b/monkey/infection_monkey/master/threading_utils.py @@ -1,5 +1,8 @@ -from threading import Thread -from typing import Callable, Tuple +import logging +from threading import Event, Thread +from typing import Any, Callable, Iterable, Tuple + +logger = logging.getLogger(__name__) def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): @@ -15,3 +18,27 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker def create_daemon_thread(target: Callable[..., None], args: Tuple = ()): return Thread(target=target, args=args, daemon=True) + + +def interruptable_iter( + iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG +) -> Any: + """ + Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This + is a convinient way to make loops interruptable and avoids the need to add an `if` to each and + every loop. + :param Iterable iterator: An iterator that will be made interruptable. + :param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the + iterator's items from being processed. + :param str log_message: A message to be logged if the iterator is interrupted. If `log_message` + is `None` (default), then no message is logged. + :param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`. + """ + for i in iterator: + if interrupt.is_set(): + if log_message: + logger.log(log_level, log_message) + + break + + yield i diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_threading_utils.py b/monkey/tests/unit_tests/infection_monkey/master/test_threading_utils.py index 73fd7bad9..11d4fdf61 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_threading_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_threading_utils.py @@ -1,6 +1,47 @@ -from infection_monkey.master.threading_utils import create_daemon_thread +import logging +from threading import Event + +from infection_monkey.master.threading_utils import create_daemon_thread, interruptable_iter def test_create_daemon_thread(): thread = create_daemon_thread(lambda: None) assert thread.daemon + + +def test_interruptable_iter(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted") + + for i in test_iterator: + items_from_iterator.append(i) + if i == 3: + interrupt.set() + + assert items_from_iterator == [0, 1, 2, 3] + + +def test_interruptable_iter_not_interrupted(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted") + + for i in test_iterator: + items_from_iterator.append(i) + + assert items_from_iterator == [0, 1, 2, 3, 4] + + +def test_interruptable_iter_interrupted_before_used(): + interrupt = Event() + items_from_iterator = [] + test_iterator = interruptable_iter( + range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO + ) + + interrupt.set() + for i in test_iterator: + items_from_iterator.append(i) + + assert not items_from_iterator