Merge pull request #1682 from guardicore/small-code-improvements

Small code improvements
This commit is contained in:
Mike Salvatore 2022-01-26 08:31:55 -05:00 committed by GitHub
commit f478444bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 149 additions and 89 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
import threading import threading
import time import time
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, Iterable, List, Tuple
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
@ -11,10 +11,10 @@ from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.post_breach_telem import PostBreachTelem from infection_monkey.telemetry.post_breach_telem import PostBreachTelem
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
from infection_monkey.utils.timer import Timer from infection_monkey.utils.timer import Timer
from . import Exploiter, IPScanner, Propagator from . import Exploiter, IPScanner, Propagator
from .threading_utils import create_daemon_thread
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5 CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5 CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
@ -182,7 +182,7 @@ class AutomatedMaster(IMaster):
command, result = self._puppet.run_pba(name, options) command, result = self._puppet.run_pba(name, options)
self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result)) self._telemetry_messenger.send_telemetry(PostBreachTelem(name, command, result))
def _can_propagate(self): def _can_propagate(self) -> bool:
return True return True
def _run_payload(self, payload: Tuple[str, Dict]): def _run_payload(self, payload: Tuple[str, Dict]):
@ -191,15 +191,14 @@ class AutomatedMaster(IMaster):
self._puppet.run_payload(name, options, self._stop) self._puppet.run_payload(name, options, self._stop)
def _run_plugins(self, plugin: List[Any], plugin_type: str, callback: Callable[[Any], None]): def _run_plugins(
self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None]
):
logger.info(f"Running {plugin_type}s") logger.info(f"Running {plugin_type}s")
logger.debug(f"Found {len(plugin)} {plugin_type}(s) to run") logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run")
for p in plugin:
if self._stop.is_set():
logger.debug(f"Received a stop signal, skipping remaining {plugin_type}s")
return
interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
for p in interruptable_iter(plugins, self._stop, interrupted_message):
callback(p) callback(p)
logger.info(f"Finished running {plugin_type}s") logger.info(f"Finished running {plugin_type}s")

View File

@ -7,8 +7,7 @@ from typing import Callable, Dict, List
from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
from .threading_utils import run_worker_threads
QUEUE_TIMEOUT = 2 QUEUE_TIMEOUT = 2
@ -75,10 +74,7 @@ class Exploiter:
results_callback: Callback, results_callback: Callback,
stop: Event, stop: Event,
): ):
for exploiter in exploiters_to_run: for exploiter in interruptable_iter(exploiters_to_run, stop):
if stop.is_set():
break
exploiter_name = exploiter["name"] exploiter_name = exploiter["name"]
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop) exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
results_callback(exploiter_name, victim_host, exploiter_results) results_callback(exploiter_name, victim_host, exploiter_results)

View File

@ -13,9 +13,9 @@ from infection_monkey.i_puppet import (
PortStatus, PortStatus,
) )
from infection_monkey.network import NetworkAddress from infection_monkey.network import NetworkAddress
from infection_monkey.utils.threading import interruptable_iter, run_worker_threads
from . import IPScanResults from . import IPScanResults
from .threading_utils import run_worker_threads
logger = logging.getLogger() logger = logging.getLogger()
@ -49,25 +49,23 @@ class IPScanner:
self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event
): ):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
try: try:
while not stop.is_set(): while not stop.is_set():
address = addresses.get_nowait() address = addresses.get_nowait()
ip = address.ip logger.info(f"Scanning {address.ip}")
logger.info(f"Scanning {ip}")
icmp_timeout = options["icmp"]["timeout_ms"] / 1000 ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
ping_scan_data = self._puppet.ping(ip, icmp_timeout) port_scan_data = self._scan_tcp_ports(address.ip, tcp_ports, tcp_timeout, stop)
tcp_timeout = options["tcp"]["timeout_ms"] / 1000
tcp_ports = options["tcp"]["ports"]
port_scan_data = self._scan_tcp_ports(ip, tcp_ports, tcp_timeout, stop)
fingerprint_data = {} fingerprint_data = {}
if IPScanner.port_scan_found_open_port(port_scan_data): if IPScanner.port_scan_found_open_port(port_scan_data):
fingerprinters = options["fingerprinters"] fingerprinters = options["fingerprinters"]
fingerprint_data = self._run_fingerprinters( fingerprint_data = self._run_fingerprinters(
ip, fingerprinters, ping_scan_data, port_scan_data, stop address.ip, fingerprinters, ping_scan_data, port_scan_data, stop
) )
scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data) scan_results = IPScanResults(ping_scan_data, port_scan_data, fingerprint_data)
@ -87,10 +85,7 @@ class IPScanner:
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
port_scan_data = {} port_scan_data = {}
for p in ports: for p in interruptable_iter(ports, stop):
if stop.is_set():
break
port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout) port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, timeout)
return port_scan_data return port_scan_data
@ -109,10 +104,7 @@ class IPScanner:
) -> Dict[str, FingerprintData]: ) -> Dict[str, FingerprintData]:
fingerprint_data = {} fingerprint_data = {}
for f in fingerprinters: for f in interruptable_iter(fingerprinters, stop):
if stop.is_set():
break
fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data) fingerprint_data[f] = self._puppet.fingerprint(f, ip, ping_scan_data, port_scan_data)
return fingerprint_data return fingerprint_data

View File

@ -16,9 +16,9 @@ from infection_monkey.network.scan_target_generator import compile_scan_target_l
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.scan_telem import ScanTelem from infection_monkey.telemetry.scan_telem import ScanTelem
from infection_monkey.utils.threading import create_daemon_thread
from . import Exploiter, IPScanner, IPScanResults from . import Exploiter, IPScanner, IPScanResults
from .threading_utils import create_daemon_thread
logger = logging.getLogger() logger = logging.getLogger()
@ -107,14 +107,13 @@ class Propagator:
victim_host.os["type"] = ping_scan_data.os victim_host.os["type"] = ping_scan_data.os
@staticmethod @staticmethod
def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData) -> bool: def _process_tcp_scan_results(victim_host: VictimHost, port_scan_data: PortScanData):
for psd in port_scan_data.values(): for psd in filter(lambda psd: psd.status == PortStatus.OPEN, port_scan_data.values()):
if psd.status == PortStatus.OPEN: victim_host.services[psd.service] = {}
victim_host.services[psd.service] = {} victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["display_name"] = "unknown(TCP)" victim_host.services[psd.service]["port"] = psd.port
victim_host.services[psd.service]["port"] = psd.port if psd.banner is not None:
if psd.banner is not None: victim_host.services[psd.service]["banner"] = psd.banner
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod @staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData): def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData):

View File

@ -1,17 +0,0 @@
from threading import Thread
from typing import Callable, Tuple
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
worker_threads = []
for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args)
t.start()
worker_threads.append(t)
for t in worker_threads:
t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)

View File

@ -28,17 +28,12 @@ class InPlaceFileEncryptor:
def _encrypt_file(self, filepath: Path): def _encrypt_file(self, filepath: Path):
with open(filepath, "rb+") as f: with open(filepath, "rb+") as f:
data = f.read(self._chunk_size) for data in iter(lambda: f.read(self._chunk_size), b""):
while data:
num_bytes_read = len(data)
encrypted_data = self._encrypt_bytes(data) encrypted_data = self._encrypt_bytes(data)
f.seek(-num_bytes_read, 1) f.seek(-len(encrypted_data), 1)
f.write(encrypted_data) f.write(encrypted_data)
data = f.read(self._chunk_size)
def _add_extension(self, filepath: Path): def _add_extension(self, filepath: Path):
new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}") new_filepath = filepath.with_suffix(f"{filepath.suffix}{self._new_file_extension}")
filepath.rename(new_filepath) filepath.rename(new_filepath)

View File

@ -5,6 +5,7 @@ from typing import Callable, List
from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem from infection_monkey.telemetry.file_encryption_telem import FileEncryptionTelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.threading import interruptable_iter
from .consts import README_FILE_NAME, README_SRC from .consts import README_FILE_NAME, README_SRC
from .ransomware_options import RansomwareOptions from .ransomware_options import RansomwareOptions
@ -53,13 +54,10 @@ class Ransomware:
def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event): def _encrypt_files(self, file_list: List[Path], interrupt: threading.Event):
logger.info(f"Encrypting files in {self._target_directory}") logger.info(f"Encrypting files in {self._target_directory}")
for filepath in file_list: interrupted_message = (
if interrupt.is_set(): "Received a stop signal, skipping remaining files for encryption of ransomware payload"
logger.debug( )
"Received a stop signal, skipping remaining files for encryption of " for filepath in interruptable_iter(file_list, interrupt, interrupted_message):
"ransomware payload"
)
return
try: try:
logger.debug(f"Encrypting {filepath}") logger.debug(f"Encrypting {filepath}")
self._encrypt_file(filepath) self._encrypt_file(filepath)
@ -73,11 +71,11 @@ class Ransomware:
self._telemetry_messenger.send_telemetry(encryption_attempt) self._telemetry_messenger.send_telemetry(encryption_attempt)
def _leave_readme_in_target_directory(self, interrupt: threading.Event): def _leave_readme_in_target_directory(self, interrupt: threading.Event):
try: if interrupt.is_set():
if interrupt.is_set(): logger.debug("Received a stop signal, skipping leave readme")
logger.debug("Received a stop signal, skipping leave readme") return
return
try:
self._leave_readme(README_SRC, self._readme_file_path) self._leave_readme(README_SRC, self._readme_file_path)
except Exception as ex: except Exception as ex:
logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}") logger.warning(f"An error occurred while attempting to leave a README.txt file: {ex}")

View File

@ -6,7 +6,9 @@ def get_all_regular_files_in_directory(dir_path: Path) -> List[Path]:
return filter_files(dir_path.iterdir(), [lambda f: f.is_file()]) return filter_files(dir_path.iterdir(), [lambda f: f.is_file()])
def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool]]): def filter_files(
files: Iterable[Path], file_filters: Iterable[Callable[[Path], bool]]
) -> List[Path]:
filtered_files = files filtered_files = files
for file_filter in file_filters: for file_filter in file_filters:
filtered_files = [f for f in filtered_files if file_filter(f)] filtered_files = [f for f in filtered_files if file_filter(f)]
@ -14,16 +16,16 @@ def filter_files(files: Iterable[Path], file_filters: List[Callable[[Path], bool
return filtered_files return filtered_files
def file_extension_filter(file_extensions: Set): def file_extension_filter(file_extensions: Set) -> Callable[[Path], bool]:
def inner_filter(f: Path): def inner_filter(f: Path) -> bool:
return f.suffix in file_extensions return f.suffix in file_extensions
return inner_filter return inner_filter
def is_not_symlink_filter(f: Path): def is_not_symlink_filter(f: Path) -> bool:
return not f.is_symlink() return not f.is_symlink()
def is_not_shortcut_filter(f: Path): def is_not_shortcut_filter(f: Path) -> bool:
return f.suffix != ".lnk" return f.suffix != ".lnk"

View File

@ -0,0 +1,44 @@
import logging
from threading import Event, Thread
from typing import Any, Callable, Iterable, Tuple
logger = logging.getLogger(__name__)
def run_worker_threads(target: Callable[..., None], args: Tuple = (), num_workers: int = 2):
worker_threads = []
for i in range(0, num_workers):
t = create_daemon_thread(target=target, args=args)
t.start()
worker_threads.append(t)
for t in worker_threads:
t.join()
def create_daemon_thread(target: Callable[..., None], args: Tuple = ()):
return Thread(target=target, args=args, daemon=True)
def interruptable_iter(
iterator: Iterable, interrupt: Event, log_message: str = None, log_level: int = logging.DEBUG
) -> Any:
"""
Wraps an iterator so that the iterator can be interrupted if the `interrupt` Event is set. This
is a convinient way to make loops interruptable and avoids the need to add an `if` to each and
every loop.
:param Iterable iterator: An iterator that will be made interruptable.
:param Event interrupt: A `threading.Event` that, if set, will prevent the remainder of the
iterator's items from being processed.
:param str log_message: A message to be logged if the iterator is interrupted. If `log_message`
is `None` (default), then no message is logged.
:param int log_level: The log level at which to log `log_message`, defaults to `logging.DEBUG`.
"""
for i in iterator:
if interrupt.is_set():
if log_message:
logger.log(log_level, log_message)
break
yield i

View File

@ -1,5 +1,8 @@
from __future__ import annotations from __future__ import annotations
from types import MappingProxyType as ImmutableMapping
from typing import Mapping
from common.utils.file_utils import expand_path from common.utils.file_utils import expand_path
from monkey_island.cc.server_utils.consts import ( from monkey_island.cc.server_utils.consts import (
DEFAULT_CRT_PATH, DEFAULT_CRT_PATH,
@ -19,10 +22,7 @@ _LOG_LEVEL = "log_level"
class IslandConfigOptions: class IslandConfigOptions:
def __init__(self, config_contents: dict = None): def __init__(self, config_contents: Mapping[str, Mapping] = ImmutableMapping({})):
if config_contents is None:
config_contents = {}
self.data_dir = DEFAULT_DATA_DIR self.data_dir = DEFAULT_DATA_DIR
self.log_level = DEFAULT_LOG_LEVEL self.log_level = DEFAULT_LOG_LEVEL
self.start_mongodb = DEFAULT_START_MONGO_DB self.start_mongodb = DEFAULT_START_MONGO_DB
@ -33,7 +33,7 @@ class IslandConfigOptions:
self.update(config_contents) self.update(config_contents)
def update(self, config_contents: dict): def update(self, config_contents: Mapping[str, Mapping]):
self.data_dir = config_contents.get(_DATA_DIR, self.data_dir) self.data_dir = config_contents.get(_DATA_DIR, self.data_dir)
self.log_level = config_contents.get(_LOG_LEVEL, self.log_level) self.log_level = config_contents.get(_LOG_LEVEL, self.log_level)

View File

@ -117,7 +117,12 @@ def test_interrupt_while_encrypting(
mfe.assert_any_call(ransomware_test_data / HELLO_TXT) mfe.assert_any_call(ransomware_test_data / HELLO_TXT)
def test_no_readme_after_interrupt(ransomware, interrupt, mock_leave_readme): def test_no_readme_after_interrupt(
ransomware_options, build_ransomware, interrupt, mock_leave_readme
):
ransomware_options.readme_enabled = True
ransomware = build_ransomware(ransomware_options)
interrupt.set() interrupt.set()
ransomware.run(interrupt) ransomware.run(interrupt)

View File

@ -0,0 +1,47 @@
import logging
from threading import Event
from infection_monkey.utils.threading import create_daemon_thread, interruptable_iter
def test_create_daemon_thread():
thread = create_daemon_thread(lambda: None)
assert thread.daemon
def test_interruptable_iter():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 10), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
if i == 3:
interrupt.set()
assert items_from_iterator == [0, 1, 2, 3]
def test_interruptable_iter_not_interrupted():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(range(0, 5), interrupt, "Test iterator was interrupted")
for i in test_iterator:
items_from_iterator.append(i)
assert items_from_iterator == [0, 1, 2, 3, 4]
def test_interruptable_iter_interrupted_before_used():
interrupt = Event()
items_from_iterator = []
test_iterator = interruptable_iter(
range(0, 5), interrupt, "Test iterator was interrupted", logging.INFO
)
interrupt.set()
for i in test_iterator:
items_from_iterator.append(i)
assert not items_from_iterator