diff --git a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py index 4c0eca054..ba72412b4 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py @@ -1,5 +1,6 @@ import http.server import logging +import threading logger = logging.getLogger(__name__) @@ -9,24 +10,16 @@ HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429 class HTTPHandler(http.server.BaseHTTPRequestHandler): java_class: bytes - class_downloaded = False - - @classmethod - def reset(cls): - cls.class_downloaded = False - - @classmethod - def _set_class_downloaded(cls): - cls.class_downloaded = True + class_downloaded: threading.Event def do_GET(self): - if HTTPHandler.class_downloaded: + if HTTPHandler.class_downloaded.is_set(): self.send_error( HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded" ) return - HTTPHandler._set_class_downloaded() + HTTPHandler.class_downloaded.set() logger.info("Java class servergot a GET request!") self.send_response(200) self.send_header("Content-type", "application/octet-stream") @@ -38,16 +31,17 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler): class ExploitClassHTTPServer: def __init__(self, ip: str, port: int, java_class: bytes, poll_interval: float = 0.5): logger.debug(f"The Java Exploit class will be served at {ip}:{port}") + self._class_downloaded = threading.Event() HTTPHandler.java_class = java_class - HTTPHandler.reset() + HTTPHandler.class_downloaded = self._class_downloaded self._server = http.server.HTTPServer((ip, port), HTTPHandler) self._poll_interval = poll_interval def run(self): logger.debug("Starting ExploitClassHTTPServer") - HTTPHandler.reset() + self._class_downloaded.clear() self._server.serve_forever(self._poll_interval) logger.debug("The Java Exploit class HTTP server has stopped") @@ -57,4 +51,4 @@ class ExploitClassHTTPServer: self._server.shutdown() def exploit_class_downloaded(self) -> bool: - return HTTPHandler.class_downloaded + return self._class_downloaded.is_set()