diff --git a/monkey/infection_monkey/exploit/log4shell.py b/monkey/infection_monkey/exploit/log4shell.py index 2eada8a0d..48d331aaa 100644 --- a/monkey/infection_monkey/exploit/log4shell.py +++ b/monkey/infection_monkey/exploit/log4shell.py @@ -1,5 +1,4 @@ import logging -from threading import Thread from time import sleep from common.utils.exploit_enum import ExploitType @@ -45,7 +44,6 @@ class Log4ShellExploiter(WebRCE): self._ldap_server = None self._exploit_class_http_server = None - self._exploit_class_http_server_thread = None self._agent_http_server_thread = None def _exploit_host(self): @@ -83,14 +81,7 @@ class Log4ShellExploiter(WebRCE): self._exploit_class_http_server = ExploitClassHTTPServer( self._class_http_server_ip, self._class_http_server_port, java_class ) - # Setting `daemon=True` to save ourselves some trouble when this is merged to the - # agent-refactor branch. - # TODO: Make a call to `create_daemon_thread()` instead of calling the `Thread()` - # constructor directly after merging to the agent-refactor branch. - self._exploit_class_http_server_thread = Thread( - target=self._exploit_class_http_server.run, daemon=True - ) - self._exploit_class_http_server_thread.start() + self._exploit_class_http_server.run() def _start_ldap_server(self): self._ldap_server = LDAPExploitServer( @@ -99,7 +90,6 @@ class Log4ShellExploiter(WebRCE): http_server_port=self._class_http_server_port, storage_dir=get_monkey_dir_path(), ) - self._ldap_server.run() def _stop_servers(self): diff --git a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py index c10a70d2e..c93f910e6 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py @@ -38,6 +38,13 @@ class ExploitClassHTTPServer: self._initialize_http_handler(java_class) self._server = http.server.HTTPServer((ip, port), HTTPHandler) + # Setting `daemon=True` to save ourselves some trouble when this is merged to the + # agent-refactor branch. + # TODO: Make a call to `create_daemon_thread()` instead of calling the `Thread()` + # constructor directly after merging to the agent-refactor branch. + self._server_thread = threading.Thread( + target=self._server.serve_forever, args=(self._poll_interval,), daemon=True + ) def _initialize_http_handler(self, java_class: bytes): HTTPHandler.java_class = java_class @@ -47,12 +54,23 @@ class ExploitClassHTTPServer: logger.info("Starting ExploitClassHTTPServer") self._class_downloaded.clear() - self._server.serve_forever(self._poll_interval) - logger.debug("The Java Exploit class HTTP server has stopped") + # NOTE: Unlike in LDAPExploitServer, we theoretically don't need to worry about a race + # between when `serve_forever()` is ready to handle requests and when the victim machine + # sends its requests. See + # https://stackoverflow.com/questions/22606480/how-can-i-test-if-python-http-server-httpserver-is-serving-forever + # for more information. + self._server_thread.start() - def stop(self): - logger.debug("Stopping the Java Exploit class HTTP server") - self._server.shutdown() + def stop(self, timeout: float = None): + if self._server_thread.is_alive(): + logger.debug("Stopping the Java Exploit class HTTP server") + self._server.shutdown() + self._server_thread.join(timeout) + + if self._server_thread.is_alive(): + logger.warning("Timed out while waiting for The HTTP exploit server to stop") + else: + logger.debug("The Java Exploit class HTTP server has stopped") def exploit_class_downloaded(self) -> bool: return self._class_downloaded.is_set() diff --git a/monkey/infection_monkey/exploit/log4shell_utils/ldap_server.py b/monkey/infection_monkey/exploit/log4shell_utils/ldap_server.py index f0a4f3e18..0b29fd4cf 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/ldap_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/ldap_server.py @@ -179,6 +179,6 @@ class LDAPExploitServer: self._server_process.join(timeout) if self._server_process.is_alive(): - logger.warning("Timed out while waiting for the LDAP exploit server to stop.") + logger.warning("Timed out while waiting for the LDAP exploit server to stop") else: logger.debug("Successfully stopped the LDAP exploit server") diff --git a/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py b/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py index d111ef825..b22ef41da 100644 --- a/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py +++ b/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py @@ -1,5 +1,3 @@ -import threading - import pytest import requests @@ -25,13 +23,11 @@ def java_class(): @pytest.fixture def server(ip, port, java_class): server = ExploitClassHTTPServer(ip, port, java_class, 0.01) - server_thread = threading.Thread(target=server.run) - server_thread.start() + server.run() yield server server.stop() - server_thread.join() @pytest.fixture