Agent: Add interruptible_function decorator

This commit is contained in:
Mike Salvatore 2022-03-25 10:50:10 -04:00
parent 7c6ba2e276
commit 20e3b20cb5
2 changed files with 109 additions and 1 deletions

View File

@ -1,7 +1,8 @@
import logging
from functools import wraps
from itertools import count
from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple
from typing import Any, Callable, Iterable, Optional, Tuple
logger = logging.getLogger(__name__)
@ -53,3 +54,56 @@ def interruptible_iter(
break
yield i
def interruptible_function(*, msg: Optional[str] = None, default_return_value: Any = None):
"""
This decorator allows a function to be skipped if an interrupt (threading.Event) is set. This is
useful for interrupting running code without introducing duplicate `if` checks at the beginning
of each function.
Note: It is required that the decorated function accept a keyword-only argument named
"interrupt".
Example:
def run_algorithm(*inputs, interrupt: threading.Event):
return_value = do_action_1(inputs[1], interrupt=interrupt)
return_value = do_action_2(return_value + inputs[2], interrupt=interrupt)
return_value = do_action_3(return_value + inputs[3], interrupt=interrupt)
return return_value
@interruptible_function(msg="Interrupt detected, skipping action 1", default_return_value=0)
def do_action_1(input, *, interrupt: threading.Event):
# Process input
...
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
def do_action_2(input, *, interrupt: threading.Event):
# Process input
...
@interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0)
def do_action_2(input, *, interrupt: threading.Event):
# Process input
...
:param str msg: A message to log at the debug level if an interrupt is detected. Defaults to
None.
:param Any default_return_value: A value to return if the wrapped function is not run. Defaults
to None.
"""
def _decorator(fn):
@wraps(fn)
def _wrapper(*args, interrupt: Event, **kwargs):
if interrupt.is_set():
if msg:
logger.debug(msg)
return default_return_value
return fn(*args, interrupt=interrupt, **kwargs)
return _wrapper
return _decorator

View File

@ -1,8 +1,10 @@
import logging
from threading import Event, current_thread
from typing import Any
from infection_monkey.utils.threading import (
create_daemon_thread,
interruptible_function,
interruptible_iter,
run_worker_threads,
)
@ -73,3 +75,55 @@ def test_worker_thread_names():
assert "B-01" in thread_names
assert "B-02" in thread_names
assert len(thread_names) == 6
class MockFunction:
def __init__(self):
self._call_count = 0
@property
def call_count(self):
return self._call_count
@property
def return_value(self):
return 42
def __call__(self, *_, interrupt: Event) -> Any:
self._call_count += 1
return self.return_value
def test_interruptible_decorator_calls_decorated_function():
fn = MockFunction()
int_fn = interruptible_function()(fn)
return_value = int_fn(interrupt=Event())
assert return_value == fn.return_value
assert fn.call_count == 1
def test_interruptible_decorator_skips_decorated_function():
fn = MockFunction()
int_fn = interruptible_function()(fn)
interrupt = Event()
interrupt.set()
return_value = int_fn(interrupt=interrupt)
assert return_value is None
assert fn.call_count == 0
def test_interruptible_decorator_returns_default_value_on_interrupt():
fn = MockFunction()
int_fn = interruptible_function(default_return_value=777)(fn)
interrupt = Event()
interrupt.set()
return_value = int_fn(interrupt=interrupt)
assert return_value == 777
assert fn.call_count == 0