forked from p15670423/monkey
Agent: Add interruptable_iter() generator
This commit is contained in:
parent
f8ea2e06ac
commit
df42d0752a
|
@ -1,5 +1,8 @@
|
|||
from threading import Thread
|
||||
from typing import Callable, Tuple
|
||||
import logging
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Iterable, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
||||
|
@ -15,3 +18,27 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker
|
|||
|
||||
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
||||
return Thread(target=target, args=args, daemon=True)
|
||||
|
||||
|
||||
def interruptable_iter(
|
||||
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
|
||||
) -> Any:
|
||||
"""
|
||||
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
|
||||
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
|
||||
every loop.
|
||||
:param Iterable iterator: An iterator that will be made interruptable.
|
||||
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
|
||||
iterator's items from being processed.
|
||||
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
|
||||
is `None` (default), then no message is logged.
|
||||
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
|
||||
"""
|
||||
for i in iterator:
|
||||
if interrupt.is_set():
|
||||
if log_message:
|
||||
logger.log(log_level, log_message)
|
||||
|
||||
break
|
||||
|
||||
yield i
|
||||
|
|
|
@ -1,6 +1,47 @@
|
|||
from infection_monkey.master.threading_utils import create_daemon_thread
|
||||
import logging
|
||||
from threading import Event
|
||||
|
||||
from infection_monkey.master.threading_utils import create_daemon_thread, interruptable_iter
|
||||
|
||||
|
||||
def test_create_daemon_thread():
|
||||
thread = create_daemon_thread(lambda: None)
|
||||
assert thread.daemon
|
||||
|
||||
|
||||
def test_interruptable_iter():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
|
||||
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
if i == 3:
|
||||
interrupt.set()
|
||||
|
||||
assert items_from_iterator == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_interruptable_iter_not_interrupted():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
|
||||
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
|
||||
assert items_from_iterator == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_interruptable_iter_interrupted_before_used():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(
|
||||
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
|
||||
)
|
||||
|
||||
interrupt.set()
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
|
||||
assert not items_from_iterator
|
||||
|
|
Loading…
Reference in New Issue