Agent: Encapsulate parallelism in ExploitClassHTTPServer

This commit is contained in:
Mike Salvatore 2022-01-18 08:54:02 -05:00
parent 212fb3a653
commit 63085273a9
4 changed files with 26 additions and 22 deletions

View File

@ -1,5 +1,4 @@
import logging import logging
from threading import Thread
from time import sleep from time import sleep
from common.utils.exploit_enum import ExploitType from common.utils.exploit_enum import ExploitType
@ -45,7 +44,6 @@ class Log4ShellExploiter(WebRCE):
self._ldap_server = None self._ldap_server = None
self._exploit_class_http_server = None self._exploit_class_http_server = None
self._exploit_class_http_server_thread = None
self._agent_http_server_thread = None self._agent_http_server_thread = None
def _exploit_host(self): def _exploit_host(self):
@ -83,14 +81,7 @@ class Log4ShellExploiter(WebRCE):
self._exploit_class_http_server = ExploitClassHTTPServer( self._exploit_class_http_server = ExploitClassHTTPServer(
self._class_http_server_ip, self._class_http_server_port, java_class self._class_http_server_ip, self._class_http_server_port, java_class
) )
# Setting `daemon=True` to save ourselves some trouble when this is merged to the self._exploit_class_http_server.run()
# agent-refactor branch.
# TODO: Make a call to `create_daemon_thread()` instead of calling the `Thread()`
# constructor directly after merging to the agent-refactor branch.
self._exploit_class_http_server_thread = Thread(
target=self._exploit_class_http_server.run, daemon=True
)
self._exploit_class_http_server_thread.start()
def _start_ldap_server(self): def _start_ldap_server(self):
self._ldap_server = LDAPExploitServer( self._ldap_server = LDAPExploitServer(
@ -99,7 +90,6 @@ class Log4ShellExploiter(WebRCE):
http_server_port=self._class_http_server_port, http_server_port=self._class_http_server_port,
storage_dir=get_monkey_dir_path(), storage_dir=get_monkey_dir_path(),
) )
self._ldap_server.run() self._ldap_server.run()
def _stop_servers(self): def _stop_servers(self):

View File

@ -38,6 +38,13 @@ class ExploitClassHTTPServer:
self._initialize_http_handler(java_class) self._initialize_http_handler(java_class)
self._server = http.server.HTTPServer((ip, port), HTTPHandler) self._server = http.server.HTTPServer((ip, port), HTTPHandler)
# Setting `daemon=True` to save ourselves some trouble when this is merged to the
# agent-refactor branch.
# TODO: Make a call to `create_daemon_thread()` instead of calling the `Thread()`
# constructor directly after merging to the agent-refactor branch.
self._server_thread = threading.Thread(
target=self._server.serve_forever, args=(self._poll_interval,), daemon=True
)
def _initialize_http_handler(self, java_class: bytes): def _initialize_http_handler(self, java_class: bytes):
HTTPHandler.java_class = java_class HTTPHandler.java_class = java_class
@ -47,12 +54,23 @@ class ExploitClassHTTPServer:
logger.info("Starting ExploitClassHTTPServer") logger.info("Starting ExploitClassHTTPServer")
self._class_downloaded.clear() self._class_downloaded.clear()
self._server.serve_forever(self._poll_interval) # NOTE: Unlike in LDAPExploitServer, we theoretically don't need to worry about a race
logger.debug("The Java Exploit class HTTP server has stopped") # between when `serve_forever()` is ready to handle requests and when the victim machine
# sends its requests. See
# https://stackoverflow.com/questions/22606480/how-can-i-test-if-python-http-server-httpserver-is-serving-forever
# for more information.
self._server_thread.start()
def stop(self): def stop(self, timeout: float = None):
if self._server_thread.is_alive():
logger.debug("Stopping the Java Exploit class HTTP server") logger.debug("Stopping the Java Exploit class HTTP server")
self._server.shutdown() self._server.shutdown()
self._server_thread.join(timeout)
if self._server_thread.is_alive():
logger.warning("Timed out while waiting for The HTTP exploit server to stop")
else:
logger.debug("The Java Exploit class HTTP server has stopped")
def exploit_class_downloaded(self) -> bool: def exploit_class_downloaded(self) -> bool:
return self._class_downloaded.is_set() return self._class_downloaded.is_set()

View File

@ -179,6 +179,6 @@ class LDAPExploitServer:
self._server_process.join(timeout) self._server_process.join(timeout)
if self._server_process.is_alive(): if self._server_process.is_alive():
logger.warning("Timed out while waiting for the LDAP exploit server to stop.") logger.warning("Timed out while waiting for the LDAP exploit server to stop")
else: else:
logger.debug("Successfully stopped the LDAP exploit server") logger.debug("Successfully stopped the LDAP exploit server")

View File

@ -1,5 +1,3 @@
import threading
import pytest import pytest
import requests import requests
@ -25,13 +23,11 @@ def java_class():
@pytest.fixture @pytest.fixture
def server(ip, port, java_class): def server(ip, port, java_class):
server = ExploitClassHTTPServer(ip, port, java_class, 0.01) server = ExploitClassHTTPServer(ip, port, java_class, 0.01)
server_thread = threading.Thread(target=server.run) server.run()
server_thread.start()
yield server yield server
server.stop() server.stop()
server_thread.join()
@pytest.fixture @pytest.fixture