From 20e3b20cb5b2d67e378738fd2aa10ae90bcb4be8 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Fri, 25 Mar 2022 10:50:10 -0400 Subject: [PATCH] Agent: Add interruptible_function decorator --- monkey/infection_monkey/utils/threading.py | 56 ++++++++++++++++++- .../infection_monkey/utils/test_threading.py | 54 ++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index c7c1f7d58..0443978e6 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,7 +1,8 @@ import logging +from functools import wraps from itertools import count from threading import Event, Thread -from typing import Any, Callable, Iterable, Tuple +from typing import Any, Callable, Iterable, Optional, Tuple logger = logging.getLogger(__name__) @@ -53,3 +54,56 @@ def interruptible_iter( break yield i + + +def interruptible_function(*, msg: Optional[str] = None, default_return_value: Any = None): + """ + This decorator allows a function to be skipped if an interrupt (threading.Event) is set. This is + useful for interrupting running code without introducing duplicate `if` checks at the beginning + of each function. + + Note: It is required that the decorated function accept a keyword-only argument named + "interrupt". + + Example: + def run_algorithm(*inputs, interrupt: threading.Event): + return_value = do_action_1(inputs[1], interrupt=interrupt) + return_value = do_action_2(return_value + inputs[2], interrupt=interrupt) + return_value = do_action_3(return_value + inputs[3], interrupt=interrupt) + + return return_value + + @interruptible_function(msg="Interrupt detected, skipping action 1", default_return_value=0) + def do_action_1(input, *, interrupt: threading.Event): + # Process input + ... + + @interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0) + def do_action_2(input, *, interrupt: threading.Event): + # Process input + ... + + @interruptible_function(msg="Interrupt detected, skipping action 2", default_return_value=0) + def do_action_2(input, *, interrupt: threading.Event): + # Process input + ... + + :param str msg: A message to log at the debug level if an interrupt is detected. Defaults to + None. + :param Any default_return_value: A value to return if the wrapped function is not run. Defaults + to None. + """ + + def _decorator(fn): + @wraps(fn) + def _wrapper(*args, interrupt: Event, **kwargs): + if interrupt.is_set(): + if msg: + logger.debug(msg) + return default_return_value + + return fn(*args, interrupt=interrupt, **kwargs) + + return _wrapper + + return _decorator diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 7e04a1455..96a289096 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,8 +1,10 @@ import logging from threading import Event, current_thread +from typing import Any from infection_monkey.utils.threading import ( create_daemon_thread, + interruptible_function, interruptible_iter, run_worker_threads, ) @@ -73,3 +75,55 @@ def test_worker_thread_names(): assert "B-01" in thread_names assert "B-02" in thread_names assert len(thread_names) == 6 + + +class MockFunction: + def __init__(self): + self._call_count = 0 + + @property + def call_count(self): + return self._call_count + + @property + def return_value(self): + return 42 + + def __call__(self, *_, interrupt: Event) -> Any: + self._call_count += 1 + + return self.return_value + + +def test_interruptible_decorator_calls_decorated_function(): + fn = MockFunction() + int_fn = interruptible_function()(fn) + + return_value = int_fn(interrupt=Event()) + + assert return_value == fn.return_value + assert fn.call_count == 1 + + +def test_interruptible_decorator_skips_decorated_function(): + fn = MockFunction() + int_fn = interruptible_function()(fn) + interrupt = Event() + interrupt.set() + + return_value = int_fn(interrupt=interrupt) + + assert return_value is None + assert fn.call_count == 0 + + +def test_interruptible_decorator_returns_default_value_on_interrupt(): + fn = MockFunction() + int_fn = interruptible_function(default_return_value=777)(fn) + interrupt = Event() + interrupt.set() + + return_value = int_fn(interrupt=interrupt) + + assert return_value == 777 + assert fn.call_count == 0