Agent: Add ThreadSafeIterator

This commit is contained in:
Mike Salvatore 2022-09-12 09:55:04 -04:00
parent 191ee1a5f9
commit a9edbb2874
2 changed files with 28 additions and 2 deletions

View File

@ -1,8 +1,8 @@
import logging import logging
from functools import wraps from functools import wraps
from itertools import count from itertools import count
from threading import Event, Thread from threading import Event, Lock, Thread
from typing import Any, Callable, Iterable, Optional, Tuple from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
def stop(self): def stop(self):
"""Stop a running thread.""" """Stop a running thread."""
self._interrupted.set() self._interrupted.set()
T = TypeVar("T")
class ThreadSafeIterator(Iterator[T]):
"""Provides a thread-safe iterator that wraps another iterator"""
def __init__(self, iterator: Iterator[T]):
self._lock = Lock()
self._iterator = iterator
def __next__(self) -> T:
while True:
with self._lock:
return next(self._iterator)

View File

@ -1,8 +1,10 @@
import logging import logging
from itertools import zip_longest
from threading import Event, current_thread from threading import Event, current_thread
from typing import Any from typing import Any
from infection_monkey.utils.threading import ( from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread, create_daemon_thread,
interruptible_function, interruptible_function,
interruptible_iter, interruptible_iter,
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
assert return_value == 777 assert return_value == 777
assert fn.call_count == 0 assert fn.call_count == 0
def test_thread_safe_iterator():
test_list = [1, 2, 3, 4, 5]
tsi = ThreadSafeIterator(test_list.__iter__())
for actual, expected in zip_longest(tsi, test_list):
assert actual == expected