forked from p15670423/monkey
UT: Add unit tests for ExploitClassHTTPServer
This commit is contained in:
parent
c2f3042442
commit
a3cc641101
monkey
infection_monkey/exploit/log4shell_utils
tests/unit_tests/infection_monkey/exploit/log4shell_utils
|
@ -3,6 +3,8 @@ import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429
|
||||||
|
|
||||||
|
|
||||||
class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
||||||
|
|
||||||
|
@ -19,7 +21,9 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if HTTPHandler.class_downloaded:
|
if HTTPHandler.class_downloaded:
|
||||||
self.send_error(429, "Java exploit class has already been downloaded")
|
self.send_error(
|
||||||
|
HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
HTTPHandler._set_class_downloaded()
|
HTTPHandler._set_class_downloaded()
|
||||||
|
@ -32,20 +36,20 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler):
|
||||||
|
|
||||||
|
|
||||||
class ExploitClassHTTPServer:
|
class ExploitClassHTTPServer:
|
||||||
def __init__(self, ip: str, port: int, java_class: bytes):
|
def __init__(self, ip: str, port: int, java_class: bytes, poll_interval: float = 0.5):
|
||||||
logger.debug(f"The Java Exploit class will be served at {ip}:{port}")
|
logger.debug(f"The Java Exploit class will be served at {ip}:{port}")
|
||||||
|
|
||||||
HTTPHandler.java_class = java_class
|
HTTPHandler.java_class = java_class
|
||||||
HTTPHandler.reset()
|
HTTPHandler.reset()
|
||||||
|
|
||||||
self._server = http.server.HTTPServer((ip, port), HTTPHandler)
|
self._server = http.server.HTTPServer((ip, port), HTTPHandler)
|
||||||
self._server.socket.settimeout(0.5)
|
self._poll_interval = poll_interval
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
logger.debug("Starting ExploitClassHTTPServer")
|
logger.debug("Starting ExploitClassHTTPServer")
|
||||||
HTTPHandler.reset()
|
HTTPHandler.reset()
|
||||||
|
|
||||||
self._server.serve_forever()
|
self._server.serve_forever(self._poll_interval)
|
||||||
logger.debug("The Java Exploit class HTTP server has stopped")
|
logger.debug("The Java Exploit class HTTP server has stopped")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from infection_monkey.exploit.log4shell_utils import ExploitClassHTTPServer
|
||||||
|
from infection_monkey.network.info import get_free_tcp_port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ip():
|
||||||
|
return "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def port():
|
||||||
|
return get_free_tcp_port()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def java_class():
|
||||||
|
return b"\xde\xad\xbe\xef"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def server(ip, port, java_class):
|
||||||
|
server = ExploitClassHTTPServer(ip, port, java_class, 0.01)
|
||||||
|
server_thread = threading.Thread(target=server.run)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
yield server
|
||||||
|
|
||||||
|
server.stop()
|
||||||
|
server_thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def exploit_url(ip, port):
|
||||||
|
return f"http://{ip}:{port}/Exploit"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("server")
|
||||||
|
def test_only_single_download_allowed(exploit_url, java_class):
|
||||||
|
response_1 = requests.get(exploit_url)
|
||||||
|
assert response_1.status_code == 200
|
||||||
|
assert response_1.content == java_class
|
||||||
|
|
||||||
|
response_2 = requests.get(exploit_url)
|
||||||
|
assert response_2.status_code == 429
|
||||||
|
assert response_2.content != java_class
|
||||||
|
|
||||||
|
|
||||||
|
def test_exploit_class_downloded(server, exploit_url):
|
||||||
|
assert not server.exploit_class_downloaded()
|
||||||
|
|
||||||
|
requests.get(exploit_url)
|
||||||
|
|
||||||
|
assert server.exploit_class_downloaded()
|
Loading…
Reference in New Issue