UT: Add unit tests for ExploitClassHTTPServer
This commit is contained in:
parent
c2f3042442
commit
a3cc641101
|
@ -3,6 +3,8 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429
|
||||
|
||||
|
||||
class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
|
@ -19,7 +21,9 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||
|
||||
def do_GET(self):
|
||||
if HTTPHandler.class_downloaded:
|
||||
self.send_error(429, "Java exploit class has already been downloaded")
|
||||
self.send_error(
|
||||
HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded"
|
||||
)
|
||||
return
|
||||
|
||||
HTTPHandler._set_class_downloaded()
|
||||
|
@ -32,20 +36,20 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||
|
||||
|
||||
class ExploitClassHTTPServer:
|
||||
def __init__(self, ip: str, port: int, java_class: bytes):
|
||||
def __init__(self, ip: str, port: int, java_class: bytes, poll_interval: float = 0.5):
|
||||
logger.debug(f"The Java Exploit class will be served at {ip}:{port}")
|
||||
|
||||
HTTPHandler.java_class = java_class
|
||||
HTTPHandler.reset()
|
||||
|
||||
self._server = http.server.HTTPServer((ip, port), HTTPHandler)
|
||||
self._server.socket.settimeout(0.5)
|
||||
self._poll_interval = poll_interval
|
||||
|
||||
def run(self):
|
||||
logger.debug("Starting ExploitClassHTTPServer")
|
||||
HTTPHandler.reset()
|
||||
|
||||
self._server.serve_forever()
|
||||
self._server.serve_forever(self._poll_interval)
|
||||
logger.debug("The Java Exploit class HTTP server has stopped")
|
||||
|
||||
def stop(self):
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import threading
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from infection_monkey.exploit.log4shell_utils import ExploitClassHTTPServer
|
||||
from infection_monkey.network.info import get_free_tcp_port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ip():
|
||||
return "127.0.0.1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def port():
|
||||
return get_free_tcp_port()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def java_class():
|
||||
return b"\xde\xad\xbe\xef"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server(ip, port, java_class):
|
||||
server = ExploitClassHTTPServer(ip, port, java_class, 0.01)
|
||||
server_thread = threading.Thread(target=server.run)
|
||||
server_thread.start()
|
||||
|
||||
yield server
|
||||
|
||||
server.stop()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def exploit_url(ip, port):
|
||||
return f"http://{ip}:{port}/Exploit"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("server")
|
||||
def test_only_single_download_allowed(exploit_url, java_class):
|
||||
response_1 = requests.get(exploit_url)
|
||||
assert response_1.status_code == 200
|
||||
assert response_1.content == java_class
|
||||
|
||||
response_2 = requests.get(exploit_url)
|
||||
assert response_2.status_code == 429
|
||||
assert response_2.content != java_class
|
||||
|
||||
|
||||
def test_exploit_class_downloded(server, exploit_url):
|
||||
assert not server.exploit_class_downloaded()
|
||||
|
||||
requests.get(exploit_url)
|
||||
|
||||
assert server.exploit_class_downloaded()
|
Loading…
Reference in New Issue