From 325c4368deb2e5340f7dddc8f63661d0d54045eb Mon Sep 17 00:00:00 2001
From: vakaris_zilius <vakarisz@yahoo.com>
Date: Mon, 21 Mar 2022 16:09:43 +0000
Subject: [PATCH] Agent: Remove unnecessary interrupts from log4shell

---
 monkey/infection_monkey/exploit/log4shell.py | 20 +++++---------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/monkey/infection_monkey/exploit/log4shell.py b/monkey/infection_monkey/exploit/log4shell.py
index 95dc773f4..d2e3ef2a5 100644
--- a/monkey/infection_monkey/exploit/log4shell.py
+++ b/monkey/infection_monkey/exploit/log4shell.py
@@ -18,6 +18,7 @@ from infection_monkey.network.info import get_free_tcp_port
 from infection_monkey.network.tools import get_interface_to_target
 from infection_monkey.utils.commands import build_monkey_commandline
 from infection_monkey.utils.monkey_dir import get_monkey_dir_path
+from infection_monkey.utils.threading import interruptable_iter
 from infection_monkey.utils.timer import Timer
 
 logger = logging.getLogger(__name__)
@@ -42,6 +43,8 @@ class Log4ShellExploiter(WebRCE):
         self._start_servers()
         try:
             self.exploit(None, None)
+            if self._is_interrupted():
+                self._set_interrupted()
             return self.exploit_result
         finally:
             self._stop_servers()
@@ -133,11 +136,8 @@ class Log4ShellExploiter(WebRCE):
         # Try to exploit all services,
         # because we don't know which services are running and on which ports
         for exploit in get_log4shell_service_exploiters():
-            for port in self._open_ports:
-
-                if self._is_interrupted():
-                    self._set_interrupted()
-                    return self.exploit_result
+            intr_ports = interruptable_iter(self._open_ports, self.interrupt)
+            for port in intr_ports:
 
                 logger.debug(
                     f'Attempting Log4Shell exploit on for service "{exploit.service_name}"'
@@ -151,10 +151,6 @@ class Log4ShellExploiter(WebRCE):
                         f"potential {exploit.service_name} service: {ex}"
                     )
 
-                if self._is_interrupted():
-                    self._set_interrupted()
-                    return self.exploit_result
-
                 if self._wait_for_victim():
                     self.exploit_info["vulnerable_service"] = {
                         "service_name": exploit.service_name,
@@ -168,9 +164,6 @@ class Log4ShellExploiter(WebRCE):
         if victim_called_back:
             self._wait_for_victim_to_download_agent()
 
-        if self._is_interrupted():
-            return False
-
         return victim_called_back
 
     def _wait_for_victim_to_download_java_bytecode(self) -> bool:
@@ -196,8 +189,5 @@ class Log4ShellExploiter(WebRCE):
             if self._agent_http_server_thread.downloads > 0:
                 break
 
-            if self._is_interrupted():
-                return
-
             # TODO: if the http server got an error we're waiting for nothing here
             time.sleep(1)