Merge pull request #1682 from guardicore/small-code-improvements

Small code improvements
This commit is contained in:
Mike Salvatore 2022-01-26 08:31:55 -05:00 committed by GitHub
commit f478444bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 149 additions and 89 deletions

View File

@ -1,7 +1,7 @@
import logging
import threading
import time
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, Iterable, List, Tuple
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.i_master import IMaster
@ -11,10 +11,10 @@ from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
from infection_monkey.utils.timer import Timer
from . import Exploiter, IPScanner, Propagator
from .threading_utils import create_daemon_thread
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
@ -182,7 +182,7 @@ class AutomatedMaster(IMaster):
command, result = self._puppet.run_pba(name, options)
self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result))
def _can_propagate(self):
def _can_propagate(self) -> bool:
return True
def _run_payload(self, payload: Tuple[str, Dict]):
@ -191,15 +191,14 @@ class AutomatedMaster(IMaster):
self._puppet.run_payload(name, options, self._stop)
def _run_plugins(self, plugin: List[Any], plugin_type: str, callback: Callable[[Any], None]):
def _run_plugins(
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None]
):
logger.info(f"Running {plugin_type}s")
logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run")
for p in plugin:
if self._stop.is_set():
logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s")
return
logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")
interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
for p in interruptable_iter(plugins, self._stop, interrupted_message):
callback(p)
logger.info(f"Finished running {plugin_type}s")

View File

@ -7,8 +7,7 @@ from typing import Callable, Dict, List
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost
from .threading_utils import run_worker_threads
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
QUEUE_TIMEOUT = 2
@ -75,10 +74,7 @@ class Exploiter:
results_callback: Callback,
stop: Event,
):
for exploiter in exploiters_to_run:
if stop.is_set():
break
for exploiter in interruptable_iter(exploiters_to_run, stop):
exploiter_name = exploiter["name"]
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
results_callback(exploiter_name, victim_host, exploiter_results)

View File

@ -13,9 +13,9 @@ from infection_monkey.i_puppet import (
PortStatus,
)
from infection_monkey.network import NetworkAddress
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
from . import IPScanResults
from .threading_utils import run_worker_threads
logger = logging.getLogger()
@ -49,25 +49,23 @@ class IPScanner:
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
try:
while not stop.is_set():
address = addresses.get_nowait()
ip = address.ip
logger.info(f"Scanning {ip}")
logger.info(f"Scanning {address.ip}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
ping_scan_data = self._puppet.ping(ip, icmp_timeout)
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop)
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data):
fingerprinters = options["fingerprinters"]
fingerprint_data = self._run_fingerprinters(
ip, fingerprinters, ping_scan_data, port_scan_data, stop
address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
)
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
@ -87,10 +85,7 @@ class IPScanner:
) -> Dict[int, PortScanData]:
port_scan_data = {}
for p in ports:
if stop.is_set():
break
for p in interruptable_iter(ports, stop):
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data
@ -109,10 +104,7 @@ class IPScanner:
) -> Dict[str, FingerprintData]:
fingerprint_data = {}
for f in fingerprinters:
if stop.is_set():
break
for f in interruptable_iter(fingerprinters, stop):
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
return fingerprint_data

View File

@ -16,9 +16,9 @@ from infection_monkey.network.scan_target_generator import compile_scan_target_l
from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem
from infection_monkey.utils.threading import create_daemon_thread
from . import Exploiter, IPScanner, IPScanResults
from .threading_utils import create_daemon_thread
logger = logging.getLogger()
@ -107,14 +107,13 @@ class Propagator:
victim_host.os["type"] = ping_scan_data.os
@staticmethod
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData) -> bool:
for psd in port_scan_data.values():
if psd.status == PortStatus.OPEN:
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData):
for psd in filter(lambda psd: psd.status == PortStatus.OPEN, port_scan_data.values()):
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData):

View File

@ -1,17 +0,0 @@
from threading import Thread
from typing import Callable, Tuple
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
worker_threads = []
for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args)
t.start()
worker_threads.append(t)
for t in worker_threads:
t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)

View File

@ -28,17 +28,12 @@ class InPlaceFileEncryptor:
def _encrypt_file(self, filepath: Path):
with open(filepath, "rb+") as f:
data = f.read(self._chunk_size)
while data:
num_bytes_read = len(data)
for data in iter(lambda: f.read(self._chunk_size), b""):
encrypted_data = self._encrypt_bytes(data)
f.seek(-num_bytes_read, 1)
f.seek(-len(encrypted_data), 1)
f.write(encrypted_data)
data = f.read(self._chunk_size)
def _add_extension(self, filepath: Path):
new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}")
filepath.rename(new_filepath)

View File

@ -5,6 +5,7 @@ from typing import Callable, List
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.threading import interruptable_iter
from .consts import README_FILE_NAME, README_SRC
from .ransomware_options import RansomwareOptions
@ -53,13 +54,10 @@ class Ransomware:
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
logger.info(f"Encrypting files in {self._target_directory}")
for filepath in file_list:
if interrupt.is_set():
logger.debug(
"Received a stop signal, skipping remaining files for encryption of "
"ransomware payload"
)
return
interrupted_message = (
"Received a stop signal, skipping remaining files for encryption of ransomware payload"
)
for filepath in interruptable_iter(file_list, interrupt, interrupted_message):
try:
logger.debug(f"Encrypting {filepath}")
self._encrypt_file(filepath)
@ -73,11 +71,11 @@ class Ransomware:
self._telemetry_messenger.send_telemetry(encryption_attempt)
def _leave_readme_in_target_directory(self, interrupt: threading.Event):
try:
if interrupt.is_set():
logger.debug("Received a stop signal, skipping leave readme")
return
if interrupt.is_set():
logger.debug("Received a stop signal, skipping leave readme")
return
try:
self._leave_readme(README_SRC, self._readme_file_path)
except Exception as ex:
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")

View File

@ -6,7 +6,9 @@ def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool]]):
def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> List[Path]:
filtered_files = files
for file_filter in file_filters:
filtered_files = [f for f in filtered_files if file_filter(f)]
@ -14,16 +16,16 @@ def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool
return filtered_files
def file_extension_filter(file_extensions: Set):
def inner_filter(f: Path):
def file_extension_filter(file_extensions: Set) -> Callable[[Path], bool]:
def inner_filter(f: Path) -> bool:
return f.suffix in file_extensions
return inner_filter
def is_not_symlink_filter(f: Path):
def is_not_symlink_filter(f: Path) -> bool:
return not f.is_symlink()
def is_not_shortcut_filter(f: Path):
def is_not_shortcut_filter(f: Path) -> bool:
return f.suffix != ".lnk"

View File

@ -0,0 +1,44 @@
import logging
from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple
logger = logging.getLogger(__name__)
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
worker_threads = []
for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args)
t.start()
worker_threads.append(t)
for t in worker_threads:
t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)
def interruptable_iter(
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
) -> Any:
"""
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
every loop.
:param Iterable iterator: An iterator that will be made interruptable.
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
iterator's items from being processed.
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
is `None` (default), then no message is logged.
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
"""
for i in iterator:
if interrupt.is_set():
if log_message:
logger.log(log_level, log_message)
break
yield i

View File

@ -1,5 +1,8 @@
from __future__ import annotations
from types import MappingProxyType as ImmutableMapping
from typing import Mapping
from common.utils.file_utils import expand_path
from monkey_island.cc.server_utils.consts import (
DEFAULT_CRT_PATH,
@ -19,10 +22,7 @@ _LOG_LEVEL = "log_level"
class IslandConfigOptions:
def __init__(self, config_contents: dict = None):
if config_contents is None:
config_contents = {}
def __init__(self, config_contents: Mapping[str, Mapping] = ImmutableMapping({})):
self.data_dir = DEFAULT_DATA_DIR
self.log_level = DEFAULT_LOG_LEVEL
self.start_mongodb = DEFAULT_START_MONGO_DB
@ -33,7 +33,7 @@ class IslandConfigOptions:
self.update(config_contents)
def update(self, config_contents: dict):
def update(self, config_contents: Mapping[str, Mapping]):
self.data_dir = config_contents.get(_DATA_DIR, self.data_dir)
self.log_level = config_contents.get(_LOG_LEVEL, self.log_level)

View File

@ -117,7 +117,12 @@ def test_interrupt_while_encrypting(
mfe.assert_any_call(ransomware_test_data / HELLO_TXT)
def test_no_readme_after_interrupt(ransomware, interrupt, mock_leave_readme):
def test_no_readme_after_interrupt(
ransomware_options, build_ransomware, interrupt, mock_leave_readme
):
ransomware_options.readme_enabled = True
ransomware = build_ransomware(ransomware_options)
interrupt.set()
ransomware.run(interrupt)

View File

@ -0,0 +1,47 @@
import logging
from threading import Event
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
def test_create_daemon_thread():
thread = create_daemon_thread(lambda: None)
assert thread.daemon
def test_interruptable_iter():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
if i == 3:
interrupt.set()
assert items_from_iterator == [0, 1, 2, 3]
def test_interruptable_iter_not_interrupted():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
assert items_from_iterator == [0, 1, 2, 3, 4]
def test_interruptable_iter_interrupted_before_used():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
)
interrupt.set()
for i in test_iterator:
items_from_iterator.append(i)
assert not items_from_iterator