diff --git a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py index 22d9e4961..4c0eca054 100644 --- a/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py +++ b/monkey/infection_monkey/exploit/log4shell_utils/exploit_class_http_server.py @@ -3,6 +3,8 @@ import logging logger = logging.getLogger(__name__) +HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429 + class HTTPHandler(http.server.BaseHTTPRequestHandler): @@ -19,7 +21,9 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler): def do_GET(self): if HTTPHandler.class_downloaded: - self.send_error(429, "Java exploit class has already been downloaded") + self.send_error( + HTTP_TOO_MANY_REQUESTS_ERROR_CODE, "Java exploit class has already been downloaded" + ) return HTTPHandler._set_class_downloaded() @@ -32,20 +36,20 @@ class HTTPHandler(http.server.BaseHTTPRequestHandler): class ExploitClassHTTPServer: - def __init__(self, ip: str, port: int, java_class: bytes): + def __init__(self, ip: str, port: int, java_class: bytes, poll_interval: float = 0.5): logger.debug(f"The Java Exploit class will be served at {ip}:{port}") HTTPHandler.java_class = java_class HTTPHandler.reset() self._server = http.server.HTTPServer((ip, port), HTTPHandler) - self._server.socket.settimeout(0.5) + self._poll_interval = poll_interval def run(self): logger.debug("Starting ExploitClassHTTPServer") HTTPHandler.reset() - self._server.serve_forever() + self._server.serve_forever(self._poll_interval) logger.debug("The Java Exploit class HTTP server has stopped") def stop(self): diff --git a/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py b/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py new file mode 100644 index 000000000..d111ef825 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/exploit/log4shell_utils/test_exploit_class_http_server.py @@ -0,0 +1,58 @@ +import threading + +import pytest +import requests + +from infection_monkey.exploit.log4shell_utils import ExploitClassHTTPServer +from infection_monkey.network.info import get_free_tcp_port + + +@pytest.fixture +def ip(): + return "127.0.0.1" + + +@pytest.fixture +def port(): + return get_free_tcp_port() + + +@pytest.fixture +def java_class(): + return b"\xde\xad\xbe\xef" + + +@pytest.fixture +def server(ip, port, java_class): + server = ExploitClassHTTPServer(ip, port, java_class, 0.01) + server_thread = threading.Thread(target=server.run) + server_thread.start() + + yield server + + server.stop() + server_thread.join() + + +@pytest.fixture +def exploit_url(ip, port): + return f"http://{ip}:{port}/Exploit" + + +@pytest.mark.usefixtures("server") +def test_only_single_download_allowed(exploit_url, java_class): + response_1 = requests.get(exploit_url) + assert response_1.status_code == 200 + assert response_1.content == java_class + + response_2 = requests.get(exploit_url) + assert response_2.status_code == 429 + assert response_2.content != java_class + + +def test_exploit_class_downloded(server, exploit_url): + assert not server.exploit_class_downloaded() + + requests.get(exploit_url) + + assert server.exploit_class_downloaded()