Agent: Add interruptable_iter() generator

This commit is contained in:
Mike Salvatore 2022-01-24 08:43:57 -05:00
parent f8ea2e06ac
commit df42d0752a
2 changed files with 71 additions and 3 deletions

View File

@ -1,5 +1,8 @@
from threading import Thread
from typing import Callable, Tuple
import logging
from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple
logger = logging.getLogger(__name__)
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
@ -15,3 +18,27 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)
def interruptable_iter(
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
) -> Any:
"""
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
every loop.
:param Iterable iterator: An iterator that will be made interruptable.
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
iterator's items from being processed.
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
is `None` (default), then no message is logged.
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
"""
for i in iterator:
if interrupt.is_set():
if log_message:
logger.log(log_level, log_message)
break
yield i

View File

@ -1,6 +1,47 @@
from infection_monkey.master.threading_utils import create_daemon_thread
import logging
from threading import Event
from infection_monkey.master.threading_utils import create_daemon_thread, interruptable_iter
def test_create_daemon_thread():
thread = create_daemon_thread(lambda: None)
assert thread.daemon
def test_interruptable_iter():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
if i == 3:
interrupt.set()
assert items_from_iterator == [0, 1, 2, 3]
def test_interruptable_iter_not_interrupted():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
assert items_from_iterator == [0, 1, 2, 3, 4]
def test_interruptable_iter_interrupted_before_used():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
)
interrupt.set()
for i in test_iterator:
items_from_iterator.append(i)
assert not items_from_iterator