diff --git a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py index 9b82de1d6..6b956529b 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py @@ -14,15 +14,25 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler): java_class: bytes class_downloaded: threading.Event + download_lock: threading.Lock + + @classmethod + def initialize(cls, java_class: bytes, class_downloaded: threading.Event): + cls.java_class = java_class + cls.class_downloaded = class_downloaded + cls.download_lock = threading.Lock() def do_GET(self): - if HTTPHandler.class_downloaded.is_set(): - self.send_error( - HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded" - ) - return + with HTTPHandler.download_lock: + if HTTPHandler.class_downloaded.is_set(): + self.send_error( + HTTP_TOO_MANY_REQUESTS_ERROR_CODE, + "Java exploit class has already been downloaded", + ) + return + + HTTPHandler.class_downloaded.set() - HTTPHandler.class_downloaded.set() logger.info("Java class servergot a GET request!") self.send_response(200) self.send_header("Content-type", "application/octet-stream") @@ -52,7 +62,7 @@ class ExploitClassHTTPServer: self._class_downloaded = threading.Event() self._poll_interval = poll_interval - self._initialize_http_handler(java_class) + HTTPHandler.initialize(java_class, self._class_downloaded) self._server = http.server.HTTPServer((ip, port), HTTPHandler) # Setting `daemon=True` to save ourselves some trouble when this is merged to the @@ -63,10 +73,6 @@ class ExploitClassHTTPServer: target=self._server.serve_forever, args=(self._poll_interval,), daemon=True ) - def _initialize_http_handler(self, java_class: bytes): - HTTPHandler.java_class = java_class - HTTPHandler.class_downloaded = self._class_downloaded - def run(self): """ Runs the HTTP server in the background and blocks until the server has started.