diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index be28aa0b1..6d8b28253 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,8 +1,8 @@ import logging from functools import wraps from itertools import count -from threading import Event, Thread -from typing import Any, Callable, Iterable, Optional, Tuple +from threading import Event, Lock, Thread +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar logger = logging.getLogger(__name__) @@ -116,3 +116,19 @@ class InterruptableThreadMixin: def stop(self): """Stop a running thread.""" self._interrupted.set() + + +T = TypeVar("T") + + +class ThreadSafeIterator(Iterator[T]): + """Provides a thread-safe iterator that wraps another iterator""" + + def __init__(self, iterator: Iterator[T]): + self._lock = Lock() + self._iterator = iterator + + def __next__(self) -> T: + while True: + with self._lock: + return next(self._iterator) diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 96a289096..05b813b66 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,8 +1,10 @@ import logging +from itertools import zip_longest from threading import Event, current_thread from typing import Any from infection_monkey.utils.threading import ( + ThreadSafeIterator, create_daemon_thread, interruptible_function, interruptible_iter, @@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt(): assert return_value == 777 assert fn.call_count == 0 + + +def test_thread_safe_iterator(): + test_list = [1, 2, 3, 4, 5] + tsi = ThreadSafeIterator(test_list.__iter__()) + + for actual, expected in zip_longest(tsi, test_list): + assert actual == expected