UT: Add unit tests for ExploitClassHTTPServer

This commit is contained in:
Mike Salvatore 2022-01-12 19:11:25 -05:00
parent c2f3042442
commit a3cc641101
2 changed files with 66 additions and 4 deletions

View File

@ -3,6 +3,8 @@ import logging
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429
class HTTPHandler(http.server.BaseHTTPRequestHandler):
@ -19,7 +21,9 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
if HTTPHandler.class_downloaded:
self.send_error(429, "Java exploit class has already been downloaded")
self.send_error(
HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded"
)
return
HTTPHandler._set_class_downloaded()
@ -32,20 +36,20 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
class ExploitClassHTTPServer:
def __init__(self, ip: str, port: int, java_class: bytes):
def __init__(self, ip: str, port: int, java_class: bytes, poll_interval: float = 0.5):
logger.debug(f"The Java Exploit class will be served at {ip}:{port}")
HTTPHandler.java_class = java_class
HTTPHandler.reset()
self._server = http.server.HTTPServer((ip, port), HTTPHandler)
self._server.socket.settimeout(0.5)
self._poll_interval = poll_interval
def run(self):
logger.debug("Starting ExploitClassHTTPServer")
HTTPHandler.reset()
self._server.serve_forever()
self._server.serve_forever(self._poll_interval)
logger.debug("The Java Exploit class HTTP server has stopped")
def stop(self):

View File

@ -0,0 +1,58 @@
import threading
import pytest
import requests
from infection_monkey.exploit.log4shell_utils import ExploitClassHTTPServer
from infection_monkey.network.info import get_free_tcp_port
@pytest.fixture
def ip():
return "127.0.0.1"
@pytest.fixture
def port():
return get_free_tcp_port()
@pytest.fixture
def java_class():
return b"\xde\xad\xbe\xef"
@pytest.fixture
def server(ip, port, java_class):
server = ExploitClassHTTPServer(ip, port, java_class, 0.01)
server_thread = threading.Thread(target=server.run)
server_thread.start()
yield server
server.stop()
server_thread.join()
@pytest.fixture
def exploit_url(ip, port):
return f"http://{ip}:{port}/Exploit"
@pytest.mark.usefixtures("server")
def test_only_single_download_allowed(exploit_url, java_class):
response_1 = requests.get(exploit_url)
assert response_1.status_code == 200
assert response_1.content == java_class
response_2 = requests.get(exploit_url)
assert response_2.status_code == 429
assert response_2.content != java_class
def test_exploit_class_downloded(server, exploit_url):
assert not server.exploit_class_downloaded()
requests.get(exploit_url)
assert server.exploit_class_downloaded()