Merge pull request #1637 from guardicore/1597-implement-automated-master

1597 implement automated master - Part 1
This commit is contained in:
Mike Salvatore 2021-12-03 08:55:41 -05:00 committed by GitHub
commit a1601f120f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 274 additions and 2 deletions

View File

@ -32,6 +32,7 @@ class SSHExploiter(HostExploiter):
self.skip_exist = self._config.skip_exploit_if_file_exist
def log_transfer(self, transferred, total):
# TODO: Replace with infection_monkey.utils.timer.Timer
if time.time() - self._update_timestamp > TRANSFER_UPDATE_RATE:
logger.debug("SFTP transferred: %d bytes, total: %d bytes", transferred, total)
self._update_timestamp = time.time()

View File

@ -0,0 +1,161 @@
import logging
import threading
import time
from typing import Dict, List
from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.telemetry.system_info_telem import SystemInfoTelem
from infection_monkey.utils.timer import Timer
CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC = 5
CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / 5
SHUTDOWN_TIMEOUT = 5
logger = logging.getLogger()
class AutomatedMaster(IMaster):
def __init__(
self,
puppet: IPuppet,
telemetry_messenger: ITelemetryMessenger,
control_channel: IControlChannel,
):
self._puppet = puppet
self._telemetry_messenger = telemetry_messenger
self._control_channel = control_channel
self._stop = threading.Event()
self._master_thread = threading.Thread(target=self._run_master_thread, daemon=True)
self._simulation_thread = threading.Thread(target=self._run_simulation, daemon=True)
def start(self):
logger.info("Starting automated breach and attack simulation")
self._master_thread.start()
self._master_thread.join()
logger.info("The simulation has been shutdown.")
def terminate(self):
logger.info("Stopping automated breach and attack simulation")
self._stop.set()
if self._master_thread.is_alive():
self._master_thread.join()
def _run_master_thread(self):
self._simulation_thread.start()
self._wait_for_master_stop_condition()
logger.debug("Waiting for the simulation thread to stop")
self._simulation_thread.join(SHUTDOWN_TIMEOUT)
if self._simulation_thread.is_alive():
logger.warning("Timed out waiting for the simulation to stop")
# Since the master thread and all child threads are daemon threads, they will be
# forcefully killed when the program exits.
# TODO: Daemon threads to not die when the parent THREAD does, but when the parent
# PROCESS does. This could lead to conflicts between threads that refuse to die
# and the cleanup() function. Come up with a solution.
logger.warning("Forcefully killing the simulation")
def _wait_for_master_stop_condition(self):
timer = Timer()
timer.set(CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC)
while self._master_thread_should_run():
if timer.is_expired():
# TODO: Handle exceptions in _check_for_stop() once
# ControlChannel.should_agent_stop() is refactored.
self._check_for_stop()
timer.reset()
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
def _check_for_stop(self):
if self._control_channel.should_agent_stop():
logger.debug('Received the "stop" signal from the Island')
self._stop.set()
def _master_thread_should_run(self):
return (not self._stop.is_set()) and self._simulation_thread.is_alive()
def _run_simulation(self):
config = self._control_channel.get_config()
system_info_collector_thread = threading.Thread(
target=self._collect_system_info,
args=(config["system_info_collector_classes"],),
daemon=True,
)
pba_thread = threading.Thread(
target=self._run_pbas, args=(config["post_breach_actions"],), daemon=True
)
system_info_collector_thread.start()
pba_thread.start()
# Future stages of the simulation require the output of the system info collectors. Nothing
# requires the output of PBAs, so we don't need to join on that thread here. We will join on
# the PBA thread later in this function to prevent the simulation from ending while PBAs are
# still running.
system_info_collector_thread.join()
if self._can_propagate():
propagation_thread = threading.Thread(
target=self._propagate, args=(config,), daemon=True
)
propagation_thread.start()
propagation_thread.join()
payload_thread = threading.Thread(
target=self._run_payloads, args=(config["payloads"],), daemon=True
)
payload_thread.start()
payload_thread.join()
pba_thread.join()
# TODO: This code is just for testing in development. Remove when
# implementation of AutomatedMaster is finished.
while True:
time.sleep(2)
logger.debug("Simulation thread is finished sleeping")
if self._stop.is_set():
break
def _collect_system_info(self, enabled_collectors: List[str]):
logger.info("Running system info collectors")
for collector in enabled_collectors:
if self._stop.is_set():
logger.debug("Received a stop signal, skipping remaining system info collectors")
break
logger.info(f"Running system info collector: {collector}")
system_info_telemetry = {}
system_info_telemetry[collector] = self._puppet.run_sys_info_collector(collector)
self._telemetry_messenger.send_telemetry(
SystemInfoTelem({"collectors": system_info_telemetry})
)
logger.info("Finished running system info collectors")
def _run_pbas(self, enabled_pbas: List[str]):
pass
def _can_propagate(self):
return True
def _propagate(self, config: Dict):
pass
def _run_payloads(self, enabled_payloads: Dict[str, Dict]):
pass
def cleanup(self):
pass

View File

@ -40,8 +40,7 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
self._period = period
self._should_run_batch_thread = True
# TODO: Create a "timer" or "countdown" class and inject an object instead of
# using time.time()
# TODO: Replace with infection_monkey.utils.timer.Timer
self._last_sent_time = time.time()
self._telemetry_batches: Dict[str, IBatchableTelem] = {}

View File

@ -173,6 +173,7 @@ class MonkeyTunnel(Thread):
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
# QUIT_TIMEOUT seconds
# TODO: Replace with infection_monkey.utils.timer.Timer
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
try:
search, address = self._broad_sock.recvfrom(BUFFER_READ)

View File

@ -0,0 +1,34 @@
import time
class Timer:
"""
A class for checking whether or not a certain amount of time has elapsed.
"""
def __init__(self):
self._timeout_sec = 0
self._start_time = 0
def set(self, timeout_sec: float):
"""
Set a timer
:param float timeout_sec: A fractional number of seconds to set the timeout for.
"""
self._timeout_sec = timeout_sec
self._start_time = time.time()
def is_expired(self):
"""
Check whether or not the timer has expired
:return: True if the elapsed time since set(TIMEOUT_SEC) was called is greater than
TIMEOUT_SEC, False otherwise
:rtype: bool
"""
return (time.time() - self._start_time) >= self._timeout_sec
def reset(self):
"""
Reset the timer without changing the timeout
"""
self._start_time = time.time()

View File

@ -0,0 +1,7 @@
from infection_monkey.master.automated_master import AutomatedMaster
def test_terminate_without_start():
m = AutomatedMaster(None, None, None)
# Test that call to terminate does not raise exception
m.terminate()

View File

@ -0,0 +1,69 @@
import time
import pytest
from infection_monkey.utils.timer import Timer
@pytest.fixture
def start_time(set_current_time):
start_time = 100
set_current_time(start_time)
return start_time
@pytest.fixture
def set_current_time(monkeypatch):
def inner(current_time):
monkeypatch.setattr(time, "time", lambda: current_time)
return inner
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_not_expired(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + (timeout - 0.001))
assert not t.is_expired()
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_expired(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + timeout)
assert t.is_expired()
set_current_time(start_time + timeout + 0.001)
assert t.is_expired()
def test_unset_timer_expired():
t = Timer()
assert t.is_expired()
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_reset(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + timeout)
assert t.is_expired()
t.reset()
assert not t.is_expired()
set_current_time(start_time + (2 * timeout))
assert t.is_expired()