Merge pull request #1637 from guardicore/1597-implement-automated-master
1597 implement automated master - Part 1
This commit is contained in:
commit
a1601f120f
|
@ -32,6 +32,7 @@ class SSHExploiter(HostExploiter):
|
||||||
self.skip_exist = self._config.skip_exploit_if_file_exist
|
self.skip_exist = self._config.skip_exploit_if_file_exist
|
||||||
|
|
||||||
def log_transfer(self, transferred, total):
|
def log_transfer(self, transferred, total):
|
||||||
|
# TODO: Replace with infection_monkey.utils.timer.Timer
|
||||||
if time.time() - self._update_timestamp > TRANSFER_UPDATE_RATE:
|
if time.time() - self._update_timestamp > TRANSFER_UPDATE_RATE:
|
||||||
logger.debug("SFTP transferred: %d bytes, total: %d bytes", transferred, total)
|
logger.debug("SFTP transferred: %d bytes, total: %d bytes", transferred, total)
|
||||||
self._update_timestamp = time.time()
|
self._update_timestamp = time.time()
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from infection_monkey.i_control_channel import IControlChannel
|
||||||
|
from infection_monkey.i_master import IMaster
|
||||||
|
from infection_monkey.i_puppet import IPuppet
|
||||||
|
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
|
||||||
|
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
|
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
|
||||||
|
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
|
||||||
|
SHUTDOWN_TIMEOUT = 5
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class AutomatedMaster(IMaster):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
puppet: IPuppet,
|
||||||
|
telemetry_messenger: ITelemetryMessenger,
|
||||||
|
control_channel: IControlChannel,
|
||||||
|
):
|
||||||
|
self._puppet = puppet
|
||||||
|
self._telemetry_messenger = telemetry_messenger
|
||||||
|
self._control_channel = control_channel
|
||||||
|
|
||||||
|
self._stop = threading.Event()
|
||||||
|
self._master_thread = threading.Thread(target=self._run_master_thread, daemon=True)
|
||||||
|
self._simulation_thread = threading.Thread(target=self._run_simulation, daemon=True)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
logger.info("Starting automated breach and attack simulation")
|
||||||
|
self._master_thread.start()
|
||||||
|
self._master_thread.join()
|
||||||
|
logger.info("The simulation has been shutdown.")
|
||||||
|
|
||||||
|
def terminate(self):
|
||||||
|
logger.info("Stopping automated breach and attack simulation")
|
||||||
|
self._stop.set()
|
||||||
|
|
||||||
|
if self._master_thread.is_alive():
|
||||||
|
self._master_thread.join()
|
||||||
|
|
||||||
|
def _run_master_thread(self):
|
||||||
|
self._simulation_thread.start()
|
||||||
|
|
||||||
|
self._wait_for_master_stop_condition()
|
||||||
|
|
||||||
|
logger.debug("Waiting for the simulation thread to stop")
|
||||||
|
self._simulation_thread.join(SHUTDOWN_TIMEOUT)
|
||||||
|
|
||||||
|
if self._simulation_thread.is_alive():
|
||||||
|
logger.warning("Timed out waiting for the simulation to stop")
|
||||||
|
# Since the master thread and all child threads are daemon threads, they will be
|
||||||
|
# forcefully killed when the program exits.
|
||||||
|
# TODO: Daemon threads to not die when the parent THREAD does, but when the parent
|
||||||
|
# PROCESS does. This could lead to conflicts between threads that refuse to die
|
||||||
|
# and the cleanup() function. Come up with a solution.
|
||||||
|
logger.warning("Forcefully killing the simulation")
|
||||||
|
|
||||||
|
def _wait_for_master_stop_condition(self):
|
||||||
|
timer = Timer()
|
||||||
|
timer.set(CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC)
|
||||||
|
|
||||||
|
while self._master_thread_should_run():
|
||||||
|
if timer.is_expired():
|
||||||
|
# TODO: Handle exceptions in _check_for_stop() once
|
||||||
|
# ControlChannel.should_agent_stop() is refactored.
|
||||||
|
self._check_for_stop()
|
||||||
|
timer.reset()
|
||||||
|
|
||||||
|
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
|
||||||
|
|
||||||
|
def _check_for_stop(self):
|
||||||
|
if self._control_channel.should_agent_stop():
|
||||||
|
logger.debug('Received the "stop" signal from the Island')
|
||||||
|
self._stop.set()
|
||||||
|
|
||||||
|
def _master_thread_should_run(self):
|
||||||
|
return (not self._stop.is_set()) and self._simulation_thread.is_alive()
|
||||||
|
|
||||||
|
def _run_simulation(self):
|
||||||
|
config = self._control_channel.get_config()
|
||||||
|
|
||||||
|
system_info_collector_thread = threading.Thread(
|
||||||
|
target=self._collect_system_info,
|
||||||
|
args=(config["system_info_collector_classes"],),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
pba_thread = threading.Thread(
|
||||||
|
target=self._run_pbas, args=(config["post_breach_actions"],), daemon=True
|
||||||
|
)
|
||||||
|
|
||||||
|
system_info_collector_thread.start()
|
||||||
|
pba_thread.start()
|
||||||
|
|
||||||
|
# Future stages of the simulation require the output of the system info collectors. Nothing
|
||||||
|
# requires the output of PBAs, so we don't need to join on that thread here. We will join on
|
||||||
|
# the PBA thread later in this function to prevent the simulation from ending while PBAs are
|
||||||
|
# still running.
|
||||||
|
system_info_collector_thread.join()
|
||||||
|
|
||||||
|
if self._can_propagate():
|
||||||
|
propagation_thread = threading.Thread(
|
||||||
|
target=self._propagate, args=(config,), daemon=True
|
||||||
|
)
|
||||||
|
propagation_thread.start()
|
||||||
|
propagation_thread.join()
|
||||||
|
|
||||||
|
payload_thread = threading.Thread(
|
||||||
|
target=self._run_payloads, args=(config["payloads"],), daemon=True
|
||||||
|
)
|
||||||
|
payload_thread.start()
|
||||||
|
payload_thread.join()
|
||||||
|
|
||||||
|
pba_thread.join()
|
||||||
|
|
||||||
|
# TODO: This code is just for testing in development. Remove when
|
||||||
|
# implementation of AutomatedMaster is finished.
|
||||||
|
while True:
|
||||||
|
time.sleep(2)
|
||||||
|
logger.debug("Simulation thread is finished sleeping")
|
||||||
|
if self._stop.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
def _collect_system_info(self, enabled_collectors: List[str]):
|
||||||
|
logger.info("Running system info collectors")
|
||||||
|
|
||||||
|
for collector in enabled_collectors:
|
||||||
|
if self._stop.is_set():
|
||||||
|
logger.debug("Received a stop signal, skipping remaining system info collectors")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"Running system info collector: {collector}")
|
||||||
|
|
||||||
|
system_info_telemetry = {}
|
||||||
|
system_info_telemetry[collector] = self._puppet.run_sys_info_collector(collector)
|
||||||
|
self._telemetry_messenger.send_telemetry(
|
||||||
|
SystemInfoTelem({"collectors": system_info_telemetry})
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Finished running system info collectors")
|
||||||
|
|
||||||
|
def _run_pbas(self, enabled_pbas: List[str]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _can_propagate(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _propagate(self, config: Dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run_payloads(self, enabled_payloads: Dict[str, Dict]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
pass
|
|
@ -40,8 +40,7 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
|
||||||
self._period = period
|
self._period = period
|
||||||
|
|
||||||
self._should_run_batch_thread = True
|
self._should_run_batch_thread = True
|
||||||
# TODO: Create a "timer" or "countdown" class and inject an object instead of
|
# TODO: Replace with infection_monkey.utils.timer.Timer
|
||||||
# using time.time()
|
|
||||||
self._last_sent_time = time.time()
|
self._last_sent_time = time.time()
|
||||||
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
|
self._telemetry_batches: Dict[str, IBatchableTelem] = {}
|
||||||
|
|
||||||
|
|
|
@ -173,6 +173,7 @@ class MonkeyTunnel(Thread):
|
||||||
|
|
||||||
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
|
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
|
||||||
# QUIT_TIMEOUT seconds
|
# QUIT_TIMEOUT seconds
|
||||||
|
# TODO: Replace with infection_monkey.utils.timer.Timer
|
||||||
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
|
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
|
||||||
try:
|
try:
|
||||||
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
"""
|
||||||
|
A class for checking whether or not a certain amount of time has elapsed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._timeout_sec = 0
|
||||||
|
self._start_time = 0
|
||||||
|
|
||||||
|
def set(self, timeout_sec: float):
|
||||||
|
"""
|
||||||
|
Set a timer
|
||||||
|
:param float timeout_sec: A fractional number of seconds to set the timeout for.
|
||||||
|
"""
|
||||||
|
self._timeout_sec = timeout_sec
|
||||||
|
self._start_time = time.time()
|
||||||
|
|
||||||
|
def is_expired(self):
|
||||||
|
"""
|
||||||
|
Check whether or not the timer has expired
|
||||||
|
:return: True if the elapsed time since set(TIMEOUT_SEC) was called is greater than
|
||||||
|
TIMEOUT_SEC, False otherwise
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return (time.time() - self._start_time) >= self._timeout_sec
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the timer without changing the timeout
|
||||||
|
"""
|
||||||
|
self._start_time = time.time()
|
|
@ -0,0 +1,7 @@
|
||||||
|
from infection_monkey.master.automated_master import AutomatedMaster
|
||||||
|
|
||||||
|
def test_terminate_without_start():
|
||||||
|
m = AutomatedMaster(None, None, None)
|
||||||
|
|
||||||
|
# Test that call to terminate does not raise exception
|
||||||
|
m.terminate()
|
|
@ -0,0 +1,69 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def start_time(set_current_time):
|
||||||
|
start_time = 100
|
||||||
|
set_current_time(start_time)
|
||||||
|
|
||||||
|
return start_time
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def set_current_time(monkeypatch):
|
||||||
|
def inner(current_time):
|
||||||
|
monkeypatch.setattr(time, "time", lambda: current_time)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("timeout"), [5, 1.25])
|
||||||
|
def test_timer_not_expired(start_time, set_current_time, timeout):
|
||||||
|
t = Timer()
|
||||||
|
t.set(timeout)
|
||||||
|
|
||||||
|
assert not t.is_expired()
|
||||||
|
|
||||||
|
set_current_time(start_time + (timeout - 0.001))
|
||||||
|
assert not t.is_expired()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("timeout"), [5, 1.25])
|
||||||
|
def test_timer_expired(start_time, set_current_time, timeout):
|
||||||
|
t = Timer()
|
||||||
|
t.set(timeout)
|
||||||
|
|
||||||
|
assert not t.is_expired()
|
||||||
|
|
||||||
|
set_current_time(start_time + timeout)
|
||||||
|
assert t.is_expired()
|
||||||
|
|
||||||
|
set_current_time(start_time + timeout + 0.001)
|
||||||
|
assert t.is_expired()
|
||||||
|
|
||||||
|
|
||||||
|
def test_unset_timer_expired():
|
||||||
|
t = Timer()
|
||||||
|
|
||||||
|
assert t.is_expired()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("timeout"), [5, 1.25])
|
||||||
|
def test_timer_reset(start_time, set_current_time, timeout):
|
||||||
|
t = Timer()
|
||||||
|
t.set(timeout)
|
||||||
|
|
||||||
|
assert not t.is_expired()
|
||||||
|
|
||||||
|
set_current_time(start_time + timeout)
|
||||||
|
assert t.is_expired()
|
||||||
|
|
||||||
|
t.reset()
|
||||||
|
assert not t.is_expired()
|
||||||
|
|
||||||
|
set_current_time(start_time + (2 * timeout))
|
||||||
|
assert t.is_expired()
|
Loading…
Reference in New Issue