From d31fd2c811e7213ec711f526e8937f142941277b Mon Sep 17 00:00:00 2001
From: Mike Salvatore <mike.s.salvatore@gmail.com>
Date: Wed, 24 Nov 2021 12:51:33 -0500
Subject: [PATCH] Agent: Improve Windows signal handler

---
 .../infection_monkey/utils/signal_handler.py  | 31 ++++++++++++++-----
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/monkey/infection_monkey/utils/signal_handler.py b/monkey/infection_monkey/utils/signal_handler.py
index d75b08f10..87d965398 100644
--- a/monkey/infection_monkey/utils/signal_handler.py
+++ b/monkey/infection_monkey/utils/signal_handler.py
@@ -12,22 +12,39 @@ class StopSignalHandler:
     def __init__(self, master: IMaster):
         self._master = master
 
-    def __call__(self, signum, _=None):
+    def handle_posix_signals(self, signum, _):
+        self._handle_signal(signum)
+        # Windows signal handlers must return boolean. Only raising this exception for POSIX
+        # signals.
+        raise PlannedShutdownException("Monkey Agent got an interrupt signal")
+
+    def handle_windows_signals(self, signum):
+        import win32con
+
+        # TODO: This signal handler gets called for a CTRL_CLOSE_EVENT, but the system immediately
+        #       kills the process after the handler returns. After the master is implemented and the
+        #       setup/teardown of the Agent is fully refactored, revisit this signal handler and
+        #       modify as necessary to more gracefully handle CTRL_CLOSE_EVENT signals.
+        if signum in {win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT, win32con.CTRL_CLOSE_EVENT}:
+            self._handle_signal(signum)
+            return True
+
+        return False
+
+    def _handle_signal(self, signum):
         logger.info(f"The Monkey Agent received signal {signum}")
         self._master.terminate()
-        raise PlannedShutdownException("Monkey Agent got an interrupt signal")
 
 
 def register_signal_handlers(master: IMaster):
     stop_signal_handler = StopSignalHandler(master)
-    signal.signal(signal.SIGINT, stop_signal_handler)
-    signal.signal(signal.SIGTERM, stop_signal_handler)
 
     if is_windows_os():
         import win32api
 
-        signal.signal(signal.SIGBREAK, stop_signal_handler)
-
         # CTRL_CLOSE_EVENT signal has a timeout of 5000ms,
         # after that OS will forcefully kill the process
-        win32api.SetConsoleCtrlHandler(stop_signal_handler, True)
+        win32api.SetConsoleCtrlHandler(stop_signal_handler.handle_windows_signals, True)
+    else:
+        signal.signal(signal.SIGINT, stop_signal_handler.handle_posix_signals)
+        signal.signal(signal.SIGTERM, stop_signal_handler.handle_posix_signals)