diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index a7576f258..20ed730a8 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -32,7 +32,7 @@ from infection_monkey.telemetry.tunnel_telem import TunnelTelem from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.monkey_dir import get_monkey_dir_path, remove_monkey_dir from infection_monkey.utils.monkey_log_path import get_monkey_log_path -from infection_monkey.utils.signal_handler import register_signal_handlers +from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers from infection_monkey.windows_upgrader import WindowsUpgrader logger = logging.getLogger(__name__) @@ -215,6 +215,8 @@ class InfectionMonkey: if self._master: self._master.cleanup() + reset_signal_handlers() + if self._monkey_inbound_tunnel: self._monkey_inbound_tunnel.stop() self._monkey_inbound_tunnel.join() diff --git a/monkey/infection_monkey/utils/signal_handler.py b/monkey/infection_monkey/utils/signal_handler.py index 831b31441..202c27489 100644 --- a/monkey/infection_monkey/utils/signal_handler.py +++ b/monkey/infection_monkey/utils/signal_handler.py @@ -1,5 +1,6 @@ import logging import signal +from typing import Optional from infection_monkey.i_master import IMaster from infection_monkey.utils.environment import is_windows_os @@ -7,18 +8,25 @@ from infection_monkey.utils.environment import is_windows_os logger = logging.getLogger(__name__) +_signal_handler = None + + class StopSignalHandler: def __init__(self, master: IMaster): self._master = master - def handle_posix_signals(self, signum: int, _): - self._handle_signal(signum, False) + # Windows won't let us correctly deregister a method, but Callables and closures work. + def __call__(self, signum: int, *args) -> Optional[bool]: + if is_windows_os(): + return self._handle_windows_signals(signum) + else: + self._handle_posix_signals(signum, args) - def handle_windows_signals(self, signum: int): + def _handle_windows_signals(self, signum: int) -> bool: import win32con if signum in {win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT}: - self._handle_signal(signum, False) + self._terminate_master(signum, False) return True if signum == win32con.CTRL_CLOSE_EVENT: @@ -26,25 +34,44 @@ class StopSignalHandler: # Calling self._handle_signal() with block=True to give the master a chance to # gracefully shut down. Note that the OS has a timeout that will forcefully kill the # process if this handler hasn't returned in time. - self._handle_signal(signum, True) + self._terminate_master(signum, True) return True return False - def _handle_signal(self, signum: int, block: bool): + def _handle_posix_signals(self, signum: int, *_): + self._terminate_master(signum, False) + + def _terminate_master(self, signum: int, block: bool): logger.info(f"The Monkey Agent received signal {signum}") self._master.terminate(block) def register_signal_handlers(master: IMaster): - stop_signal_handler = StopSignalHandler(master) + global _signal_handler + _signal_handler = StopSignalHandler(master) if is_windows_os(): import win32api # CTRL_CLOSE_EVENT signal has a timeout of 5000ms, # after that OS will forcefully kill the process - win32api.SetConsoleCtrlHandler(stop_signal_handler.handle_windows_signals, True) + win32api.SetConsoleCtrlHandler(_signal_handler, True) else: - signal.signal(signal.SIGINT, stop_signal_handler.handle_posix_signals) - signal.signal(signal.SIGTERM, stop_signal_handler.handle_posix_signals) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + +def reset_signal_handlers(): + """ + Resets the signal handlers back to the default handlers provided by Python + """ + global _signal_handler + + if is_windows_os(): + import win32api + + win32api.SetConsoleCtrlHandler(_signal_handler, False) + else: + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL)