forked from p15670423/monkey
Merge pull request #1682 from guardicore/small-code-improvements
Small code improvements
This commit is contained in:
commit
f478444bb7
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||
from infection_monkey.i_master import IMaster
|
||||
|
@ -11,10 +11,10 @@ from infection_monkey.network import NetworkInterface
|
|||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
|
||||
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
||||
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
|
||||
from infection_monkey.utils.timer import Timer
|
||||
|
||||
from . import Exploiter, IPScanner, Propagator
|
||||
from .threading_utils import create_daemon_thread
|
||||
|
||||
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
||||
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
|
||||
|
@ -182,7 +182,7 @@ class AutomatedMaster(IMaster):
|
|||
command, result = self._puppet.run_pba(name, options)
|
||||
self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result))
|
||||
|
||||
def _can_propagate(self):
|
||||
def _can_propagate(self) -> bool:
|
||||
return True
|
||||
|
||||
def _run_payload(self, payload: Tuple[str, Dict]):
|
||||
|
@ -191,15 +191,14 @@ class AutomatedMaster(IMaster):
|
|||
|
||||
self._puppet.run_payload(name, options, self._stop)
|
||||
|
||||
def _run_plugins(self, plugin: List[Any], plugin_type: str, callback: Callable[[Any], None]):
|
||||
def _run_plugins(
|
||||
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None]
|
||||
):
|
||||
logger.info(f"Running {plugin_type}s")
|
||||
logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run")
|
||||
|
||||
for p in plugin:
|
||||
if self._stop.is_set():
|
||||
logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s")
|
||||
return
|
||||
logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")
|
||||
|
||||
interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
|
||||
for p in interruptable_iter(plugins, self._stop, interrupted_message):
|
||||
callback(p)
|
||||
|
||||
logger.info(f"Finished running {plugin_type}s")
|
||||
|
|
|
@ -7,8 +7,7 @@ from typing import Callable, Dict, List
|
|||
|
||||
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
|
||||
from infection_monkey.model import VictimHost
|
||||
|
||||
from .threading_utils import run_worker_threads
|
||||
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
|
||||
|
||||
QUEUE_TIMEOUT = 2
|
||||
|
||||
|
@ -75,10 +74,7 @@ class Exploiter:
|
|||
results_callback: Callback,
|
||||
stop: Event,
|
||||
):
|
||||
for exploiter in exploiters_to_run:
|
||||
if stop.is_set():
|
||||
break
|
||||
|
||||
for exploiter in interruptable_iter(exploiters_to_run, stop):
|
||||
exploiter_name = exploiter["name"]
|
||||
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
|
||||
results_callback(exploiter_name, victim_host, exploiter_results)
|
||||
|
|
|
@ -13,9 +13,9 @@ from infection_monkey.i_puppet import (
|
|||
PortStatus,
|
||||
)
|
||||
from infection_monkey.network import NetworkAddress
|
||||
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
|
||||
|
||||
from . import IPScanResults
|
||||
from .threading_utils import run_worker_threads
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -49,25 +49,23 @@ class IPScanner:
|
|||
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
|
||||
):
|
||||
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
|
||||
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
|
||||
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
|
||||
tcp_ports = options["tcp"]["ports"]
|
||||
|
||||
try:
|
||||
while not stop.is_set():
|
||||
address = addresses.get_nowait()
|
||||
ip = address.ip
|
||||
logger.info(f"Scanning {ip}")
|
||||
logger.info(f"Scanning {address.ip}")
|
||||
|
||||
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
|
||||
ping_scan_data = self._puppet.ping(ip, icmp_timeout)
|
||||
|
||||
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
|
||||
tcp_ports = options["tcp"]["ports"]
|
||||
port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop)
|
||||
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
|
||||
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
|
||||
|
||||
fingerprint_data = {}
|
||||
if IPScanner.port_scan_found_open_port(port_scan_data):
|
||||
fingerprinters = options["fingerprinters"]
|
||||
fingerprint_data = self._run_fingerprinters(
|
||||
ip, fingerprinters, ping_scan_data, port_scan_data, stop
|
||||
address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
|
||||
)
|
||||
|
||||
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
|
||||
|
@ -87,10 +85,7 @@ class IPScanner:
|
|||
) -> Dict[int, PortScanData]:
|
||||
port_scan_data = {}
|
||||
|
||||
for p in ports:
|
||||
if stop.is_set():
|
||||
break
|
||||
|
||||
for p in interruptable_iter(ports, stop):
|
||||
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
|
||||
|
||||
return port_scan_data
|
||||
|
@ -109,10 +104,7 @@ class IPScanner:
|
|||
) -> Dict[str, FingerprintData]:
|
||||
fingerprint_data = {}
|
||||
|
||||
for f in fingerprinters:
|
||||
if stop.is_set():
|
||||
break
|
||||
|
||||
for f in interruptable_iter(fingerprinters, stop):
|
||||
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
|
||||
|
||||
return fingerprint_data
|
||||
|
|
|
@ -16,9 +16,9 @@ from infection_monkey.network.scan_target_generator import compile_scan_target_l
|
|||
from infection_monkey.telemetry.exploit_telem import ExploitTelem
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.telemetry.scan_telem import ScanTelem
|
||||
from infection_monkey.utils.threading import create_daemon_thread
|
||||
|
||||
from . import Exploiter, IPScanner, IPScanResults
|
||||
from .threading_utils import create_daemon_thread
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -107,14 +107,13 @@ class Propagator:
|
|||
victim_host.os["type"] = ping_scan_data.os
|
||||
|
||||
@staticmethod
|
||||
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData) -> bool:
|
||||
for psd in port_scan_data.values():
|
||||
if psd.status == PortStatus.OPEN:
|
||||
victim_host.services[psd.service] = {}
|
||||
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
|
||||
victim_host.services[psd.service]["port"] = psd.port
|
||||
if psd.banner is not None:
|
||||
victim_host.services[psd.service]["banner"] = psd.banner
|
||||
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData):
|
||||
for psd in filter(lambda psd: psd.status == PortStatus.OPEN, port_scan_data.values()):
|
||||
victim_host.services[psd.service] = {}
|
||||
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
|
||||
victim_host.services[psd.service]["port"] = psd.port
|
||||
if psd.banner is not None:
|
||||
victim_host.services[psd.service]["banner"] = psd.banner
|
||||
|
||||
@staticmethod
|
||||
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData):
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
from threading import Thread
|
||||
from typing import Callable, Tuple
|
||||
|
||||
|
||||
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
||||
worker_threads = []
|
||||
for i in range(0, num_workers):
|
||||
t = create_daemon_thread(target=target, args=args)
|
||||
t.start()
|
||||
worker_threads.append(t)
|
||||
|
||||
for t in worker_threads:
|
||||
t.join()
|
||||
|
||||
|
||||
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
||||
return Thread(target=target, args=args, daemon=True)
|
|
@ -28,17 +28,12 @@ class InPlaceFileEncryptor:
|
|||
|
||||
def _encrypt_file(self, filepath: Path):
|
||||
with open(filepath, "rb+") as f:
|
||||
data = f.read(self._chunk_size)
|
||||
while data:
|
||||
num_bytes_read = len(data)
|
||||
|
||||
for data in iter(lambda: f.read(self._chunk_size), b""):
|
||||
encrypted_data = self._encrypt_bytes(data)
|
||||
|
||||
f.seek(-num_bytes_read, 1)
|
||||
f.seek(-len(encrypted_data), 1)
|
||||
f.write(encrypted_data)
|
||||
|
||||
data = f.read(self._chunk_size)
|
||||
|
||||
def _add_extension(self, filepath: Path):
|
||||
new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}")
|
||||
filepath.rename(new_filepath)
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Callable, List
|
|||
|
||||
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
|
||||
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||
from infection_monkey.utils.threading import interruptable_iter
|
||||
|
||||
from .consts import README_FILE_NAME, README_SRC
|
||||
from .ransomware_options import RansomwareOptions
|
||||
|
@ -53,13 +54,10 @@ class Ransomware:
|
|||
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
|
||||
logger.info(f"Encrypting files in {self._target_directory}")
|
||||
|
||||
for filepath in file_list:
|
||||
if interrupt.is_set():
|
||||
logger.debug(
|
||||
"Received a stop signal, skipping remaining files for encryption of "
|
||||
"ransomware payload"
|
||||
)
|
||||
return
|
||||
interrupted_message = (
|
||||
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
|
||||
)
|
||||
for filepath in interruptable_iter(file_list, interrupt, interrupted_message):
|
||||
try:
|
||||
logger.debug(f"Encrypting {filepath}")
|
||||
self._encrypt_file(filepath)
|
||||
|
@ -73,11 +71,11 @@ class Ransomware:
|
|||
self._telemetry_messenger.send_telemetry(encryption_attempt)
|
||||
|
||||
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
|
||||
try:
|
||||
if interrupt.is_set():
|
||||
logger.debug("Received a stop signal, skipping leave readme")
|
||||
return
|
||||
if interrupt.is_set():
|
||||
logger.debug("Received a stop signal, skipping leave readme")
|
||||
return
|
||||
|
||||
try:
|
||||
self._leave_readme(README_SRC, self._readme_file_path)
|
||||
except Exception as ex:
|
||||
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")
|
||||
|
|
|
@ -6,7 +6,9 @@ def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
|
|||
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
|
||||
|
||||
|
||||
def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool]]):
|
||||
def filter_files(
|
||||
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
|
||||
) -> List[Path]:
|
||||
filtered_files = files
|
||||
for file_filter in file_filters:
|
||||
filtered_files = [f for f in filtered_files if file_filter(f)]
|
||||
|
@ -14,16 +16,16 @@ def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool
|
|||
return filtered_files
|
||||
|
||||
|
||||
def file_extension_filter(file_extensions: Set):
|
||||
def inner_filter(f: Path):
|
||||
def file_extension_filter(file_extensions: Set) -> Callable[[Path], bool]:
|
||||
def inner_filter(f: Path) -> bool:
|
||||
return f.suffix in file_extensions
|
||||
|
||||
return inner_filter
|
||||
|
||||
|
||||
def is_not_symlink_filter(f: Path):
|
||||
def is_not_symlink_filter(f: Path) -> bool:
|
||||
return not f.is_symlink()
|
||||
|
||||
|
||||
def is_not_shortcut_filter(f: Path):
|
||||
def is_not_shortcut_filter(f: Path) -> bool:
|
||||
return f.suffix != ".lnk"
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
import logging
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Iterable, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
|
||||
worker_threads = []
|
||||
for i in range(0, num_workers):
|
||||
t = create_daemon_thread(target=target, args=args)
|
||||
t.start()
|
||||
worker_threads.append(t)
|
||||
|
||||
for t in worker_threads:
|
||||
t.join()
|
||||
|
||||
|
||||
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
|
||||
return Thread(target=target, args=args, daemon=True)
|
||||
|
||||
|
||||
def interruptable_iter(
|
||||
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
|
||||
) -> Any:
|
||||
"""
|
||||
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
|
||||
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
|
||||
every loop.
|
||||
:param Iterable iterator: An iterator that will be made interruptable.
|
||||
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
|
||||
iterator's items from being processed.
|
||||
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
|
||||
is `None` (default), then no message is logged.
|
||||
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
|
||||
"""
|
||||
for i in iterator:
|
||||
if interrupt.is_set():
|
||||
if log_message:
|
||||
logger.log(log_level, log_message)
|
||||
|
||||
break
|
||||
|
||||
yield i
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType as ImmutableMapping
|
||||
from typing import Mapping
|
||||
|
||||
from common.utils.file_utils import expand_path
|
||||
from monkey_island.cc.server_utils.consts import (
|
||||
DEFAULT_CRT_PATH,
|
||||
|
@ -19,10 +22,7 @@ _LOG_LEVEL = "log_level"
|
|||
|
||||
|
||||
class IslandConfigOptions:
|
||||
def __init__(self, config_contents: dict = None):
|
||||
if config_contents is None:
|
||||
config_contents = {}
|
||||
|
||||
def __init__(self, config_contents: Mapping[str, Mapping] = ImmutableMapping({})):
|
||||
self.data_dir = DEFAULT_DATA_DIR
|
||||
self.log_level = DEFAULT_LOG_LEVEL
|
||||
self.start_mongodb = DEFAULT_START_MONGO_DB
|
||||
|
@ -33,7 +33,7 @@ class IslandConfigOptions:
|
|||
|
||||
self.update(config_contents)
|
||||
|
||||
def update(self, config_contents: dict):
|
||||
def update(self, config_contents: Mapping[str, Mapping]):
|
||||
self.data_dir = config_contents.get(_DATA_DIR, self.data_dir)
|
||||
|
||||
self.log_level = config_contents.get(_LOG_LEVEL, self.log_level)
|
||||
|
|
|
@ -117,7 +117,12 @@ def test_interrupt_while_encrypting(
|
|||
mfe.assert_any_call(ransomware_test_data / HELLO_TXT)
|
||||
|
||||
|
||||
def test_no_readme_after_interrupt(ransomware, interrupt, mock_leave_readme):
|
||||
def test_no_readme_after_interrupt(
|
||||
ransomware_options, build_ransomware, interrupt, mock_leave_readme
|
||||
):
|
||||
ransomware_options.readme_enabled = True
|
||||
ransomware = build_ransomware(ransomware_options)
|
||||
|
||||
interrupt.set()
|
||||
ransomware.run(interrupt)
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
import logging
|
||||
from threading import Event
|
||||
|
||||
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
|
||||
|
||||
|
||||
def test_create_daemon_thread():
|
||||
thread = create_daemon_thread(lambda: None)
|
||||
assert thread.daemon
|
||||
|
||||
|
||||
def test_interruptable_iter():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
|
||||
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
if i == 3:
|
||||
interrupt.set()
|
||||
|
||||
assert items_from_iterator == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_interruptable_iter_not_interrupted():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
|
||||
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
|
||||
assert items_from_iterator == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_interruptable_iter_interrupted_before_used():
|
||||
interrupt = Event()
|
||||
items_from_iterator = []
|
||||
test_iterator = interruptable_iter(
|
||||
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
|
||||
)
|
||||
|
||||
interrupt.set()
|
||||
for i in test_iterator:
|
||||
items_from_iterator.append(i)
|
||||
|
||||
assert not items_from_iterator
|
Loading…
Reference in New Issue