Agent: Implement should retry task in automated master

Add handling of known requests exceptions in ControlClient.
This commit is contained in:
Ilija Lazoroski 2021-12-15 15:43:38 +01:00
parent 72a5e94111
commit b53fae038d
3 changed files with 45 additions and 22 deletions

View File

@ -58,5 +58,5 @@ class InvalidConfigurationError(Exception):
""" Raise when configuration is invalid """ """ Raise when configuration is invalid """
class ControlClientConnectionError(Exception): class IslandCommunicationError(Exception):
"""Raise when unable to connect to control client""" """Raise when unable to connect to control client"""

View File

@ -3,7 +3,7 @@ import threading
import time import time
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, List, Tuple
from common.utils.exceptions import ControlClientConnectionError from common.utils.exceptions import IslandCommunicationError
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet from infection_monkey.i_puppet import IPuppet
@ -21,7 +21,8 @@ CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC /
SHUTDOWN_TIMEOUT = 5 SHUTDOWN_TIMEOUT = 5
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads
STOP_AGENT_TIMEOUT = 5 CHECK_FOR_STOP_AGENT_COUNT = 5
CHECK_FOR_CONFIG_COUNT = 1
logger = logging.getLogger() logger = logging.getLogger()
@ -48,6 +49,7 @@ class AutomatedMaster(IMaster):
self._master_thread = create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(target=self._run_master_thread)
self._simulation_thread = create_daemon_thread(target=self._run_simulation) self._simulation_thread = create_daemon_thread(target=self._run_simulation)
self._failed_stop = 0 self._failed_stop = 0
self._failed_config = 0
def start(self): def start(self):
logger.info("Starting automated breach and attack simulation") logger.info("Starting automated breach and attack simulation")
@ -95,15 +97,26 @@ class AutomatedMaster(IMaster):
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC) time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
def _should_retry_task(self, fn: Callable[[], Any], max_tries: int):
tries = 0
while tries < max_tries:
try:
return fn()
except IslandCommunicationError as e:
tries += 1
logger.debug(f"{e}. Retries left: {max_tries-tries}")
if tries >= max_tries:
raise e
def _check_for_stop(self): def _check_for_stop(self):
try: try:
if self._control_channel.should_agent_stop(): stop = self._should_retry_task(
self._control_channel.should_agent_stop, CHECK_FOR_STOP_AGENT_COUNT
)
if stop:
logger.debug('Received the "stop" signal from the Island') logger.debug('Received the "stop" signal from the Island')
self._failed_stop = 0
self._stop.set() self._stop.set()
except ControlClientConnectionError as e: except IslandCommunicationError as e:
self._failed_stop += 1
if self._failed_stop > STOP_AGENT_TIMEOUT:
logger.error(f"An error occurred while trying to check for agent stop: {e}") logger.error(f"An error occurred while trying to check for agent stop: {e}")
self._stop.set() self._stop.set()
@ -111,10 +124,11 @@ class AutomatedMaster(IMaster):
return (not self._stop.is_set()) and self._simulation_thread.is_alive() return (not self._stop.is_set()) and self._simulation_thread.is_alive()
def _run_simulation(self): def _run_simulation(self):
try: try:
config = self._control_channel.get_config()["config"] config = self._should_retry_task(
except ControlClientConnectionError as e: self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT
)["config"]
except IslandCommunicationError as e:
logger.error(f"An error occurred while fetching configuration: {e}") logger.error(f"An error occurred while fetching configuration: {e}")
return return

View File

@ -4,7 +4,7 @@ import logging
import requests import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.utils.exceptions import ControlClientConnectionError from common.utils.exceptions import IslandCommunicationError
from infection_monkey.config import WormConfiguration from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
@ -37,10 +37,14 @@ class ControlChannel(IControlChannel):
response = json.loads(response.content.decode()) response = json.loads(response.content.decode())
return response["stop_agent"] return response["stop_agent"]
except Exception as e: except (
raise ControlClientConnectionError( json.JSONDecodeError,
f"An error occurred while trying to connect to server: {e}" requests.exceptions.ConnectionError,
) requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)
def get_config(self) -> dict: def get_config(self) -> dict:
try: try:
@ -50,12 +54,17 @@ class ControlChannel(IControlChannel):
proxies=ControlClient.proxies, proxies=ControlClient.proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status()
return json.loads(response.content.decode()) return json.loads(response.content.decode())
except Exception as exc: except (
raise ControlClientConnectionError( json.JSONDecodeError,
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc requests.exceptions.ConnectionError,
) requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)
def get_credentials_for_propagation(self) -> dict: def get_credentials_for_propagation(self) -> dict:
try: try: