forked from p15670423/monkey
Agent: Add interruptible_function decorator
This commit is contained in:
parent
7c6ba2e276
commit
20e3b20cb5
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
from itertools import count
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Iterable, Tuple
|
||||
from typing import Any, Callable, Iterable, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,3 +54,56 @@ def interruptible_iter(
|
|||
break
|
||||
|
||||
yield i
|
||||
|
||||
|
||||
def interruptible_function(*, msg: Optional[str] = None, default_return_value: Any = None):
|
||||
"""
|
||||
This decorator allows a function to be skipped if an interrupt (threading.Event) is set. This is
|
||||
useful for interrupting running code without introducing duplicate `if` checks at the beginning
|
||||
of each function.
|
||||
|
||||
Note: It is required that the decorated function accept a keyword-only argument named
|
||||
"interrupt".
|
||||
|
||||
Example:
|
||||
def run_algorithm(*inputs, interrupt: threading.Event):
|
||||
return_value = do_action_1(inputs[1], interrupt=interrupt)
|
||||
return_value = do_action_2(return_value + inputs[2], interrupt=interrupt)
|
||||
return_value = do_action_3(return_value + inputs[3], interrupt=interrupt)
|
||||
|
||||
return return_value
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 1", default_return_value=0)
|
||||
def do_action_1(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
|
||||
def do_action_2(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
|
||||
def do_action_2(input, *, interrupt: threading.Event):
|
||||
# Process input
|
||||
...
|
||||
|
||||
:param str msg: A message to log at the debug level if an interrupt is detected. Defaults to
|
||||
None.
|
||||
:param Any default_return_value: A value to return if the wrapped function is not run. Defaults
|
||||
to None.
|
||||
"""
|
||||
|
||||
def _decorator(fn):
|
||||
@wraps(fn)
|
||||
def _wrapper(*args, interrupt: Event, **kwargs):
|
||||
if interrupt.is_set():
|
||||
if msg:
|
||||
logger.debug(msg)
|
||||
return default_return_value
|
||||
|
||||
return fn(*args, interrupt=interrupt, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
return _decorator
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from threading import Event, current_thread
|
||||
from typing import Any
|
||||
|
||||
from infection_monkey.utils.threading import (
|
||||
create_daemon_thread,
|
||||
interruptible_function,
|
||||
interruptible_iter,
|
||||
run_worker_threads,
|
||||
)
|
||||
|
@ -73,3 +75,55 @@ def test_worker_thread_names():
|
|||
assert "B-01" in thread_names
|
||||
assert "B-02" in thread_names
|
||||
assert len(thread_names) == 6
|
||||
|
||||
|
||||
class MockFunction:
|
||||
def __init__(self):
|
||||
self._call_count = 0
|
||||
|
||||
@property
|
||||
def call_count(self):
|
||||
return self._call_count
|
||||
|
||||
@property
|
||||
def return_value(self):
|
||||
return 42
|
||||
|
||||
def __call__(self, *_, interrupt: Event) -> Any:
|
||||
self._call_count += 1
|
||||
|
||||
return self.return_value
|
||||
|
||||
|
||||
def test_interruptible_decorator_calls_decorated_function():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function()(fn)
|
||||
|
||||
return_value = int_fn(interrupt=Event())
|
||||
|
||||
assert return_value == fn.return_value
|
||||
assert fn.call_count == 1
|
||||
|
||||
|
||||
def test_interruptible_decorator_skips_decorated_function():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function()(fn)
|
||||
interrupt = Event()
|
||||
interrupt.set()
|
||||
|
||||
return_value = int_fn(interrupt=interrupt)
|
||||
|
||||
assert return_value is None
|
||||
assert fn.call_count == 0
|
||||
|
||||
|
||||
def test_interruptible_decorator_returns_default_value_on_interrupt():
|
||||
fn = MockFunction()
|
||||
int_fn = interruptible_function(default_return_value=777)(fn)
|
||||
interrupt = Event()
|
||||
interrupt.set()
|
||||
|
||||
return_value = int_fn(interrupt=interrupt)
|
||||
|
||||
assert return_value == 777
|
||||
assert fn.call_count == 0
|
||||
|
|
Loading…
Reference in New Issue