Agent: Add ThreadSafeIterator
This commit is contained in:
parent
191ee1a5f9
commit
a9edbb2874
|
@ -1,8 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from threading import Event, Thread
|
from threading import Event, Lock, Thread
|
||||||
from typing import Any, Callable, Iterable, Optional, Tuple
|
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop a running thread."""
|
"""Stop a running thread."""
|
||||||
self._interrupted.set()
|
self._interrupted.set()
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSafeIterator(Iterator[T]):
|
||||||
|
"""Provides a thread-safe iterator that wraps another iterator"""
|
||||||
|
|
||||||
|
def __init__(self, iterator: Iterator[T]):
|
||||||
|
self._lock = Lock()
|
||||||
|
self._iterator = iterator
|
||||||
|
|
||||||
|
def __next__(self) -> T:
|
||||||
|
while True:
|
||||||
|
with self._lock:
|
||||||
|
return next(self._iterator)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
|
from itertools import zip_longest
|
||||||
from threading import Event, current_thread
|
from threading import Event, current_thread
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from infection_monkey.utils.threading import (
|
from infection_monkey.utils.threading import (
|
||||||
|
ThreadSafeIterator,
|
||||||
create_daemon_thread,
|
create_daemon_thread,
|
||||||
interruptible_function,
|
interruptible_function,
|
||||||
interruptible_iter,
|
interruptible_iter,
|
||||||
|
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
|
||||||
|
|
||||||
assert return_value == 777
|
assert return_value == 777
|
||||||
assert fn.call_count == 0
|
assert fn.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_safe_iterator():
|
||||||
|
test_list = [1, 2, 3, 4, 5]
|
||||||
|
tsi = ThreadSafeIterator(test_list.__iter__())
|
||||||
|
|
||||||
|
for actual, expected in zip_longest(tsi, test_list):
|
||||||
|
assert actual == expected
|
||||||
|
|
Loading…
Reference in New Issue