forked from p34709852/monkey
Agent: Add interruptable_iter() generator
This commit is contained in:
parent
f8ea2e06ac
commit
df42d0752a
|
@ -1,5 +1,8 @@
|
||||||
from threading import Thread
|
import logging
|
||||||
from typing import Callable, Tuple
|
from threading import Event, Thread
|
||||||
|
from typing import Any, Callable, Iterable, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
||||||
|
@ -15,3 +18,27 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker
|
||||||
|
|
||||||
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
||||||
return Thread(target=target, args=args, daemon=True)
|
return Thread(target=target, args=args, daemon=True)
|
||||||
|
|
||||||
|
|
||||||
|
def interruptable_iter(
|
||||||
|
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
|
||||||
|
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
|
||||||
|
every loop.
|
||||||
|
:param Iterable iterator: An iterator that will be made interruptable.
|
||||||
|
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
|
||||||
|
iterator's items from being processed.
|
||||||
|
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
|
||||||
|
is `None` (default), then no message is logged.
|
||||||
|
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
|
||||||
|
"""
|
||||||
|
for i in iterator:
|
||||||
|
if interrupt.is_set():
|
||||||
|
if log_message:
|
||||||
|
logger.log(log_level, log_message)
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
yield i
|
||||||
|
|
|
@ -1,6 +1,47 @@
|
||||||
from infection_monkey.master.threading_utils import create_daemon_thread
|
import logging
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
from infection_monkey.master.threading_utils import create_daemon_thread, interruptable_iter
|
||||||
|
|
||||||
|
|
||||||
def test_create_daemon_thread():
|
def test_create_daemon_thread():
|
||||||
thread = create_daemon_thread(lambda: None)
|
thread = create_daemon_thread(lambda: None)
|
||||||
assert thread.daemon
|
assert thread.daemon
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
|
||||||
|
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
if i == 3:
|
||||||
|
interrupt.set()
|
||||||
|
|
||||||
|
assert items_from_iterator == [0, 1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter_not_interrupted():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
|
||||||
|
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
|
||||||
|
assert items_from_iterator == [0, 1, 2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter_interrupted_before_used():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(
|
||||||
|
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt.set()
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
|
||||||
|
assert not items_from_iterator
|
||||||
|
|
Loading…
Reference in New Issue