Agent: Implement `_should_retry_task` in AutomatedMaster
Add handling of known `requests` exceptions in ControlChannel.
This commit is contained in:
parent
72a5e94111
commit
b53fae038d
|
@ -58,5 +58,5 @@ class InvalidConfigurationError(Exception):
|
|||
""" Raise when configuration is invalid """
|
||||
|
||||
|
||||
class ControlClientConnectionError(Exception):
|
||||
class IslandCommunicationError(Exception):
|
||||
"""Raise when unable to connect to control client"""
|
||||
|
|
|
@ -3,7 +3,7 @@ import threading
|
|||
import time
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
from common.utils.exceptions import ControlClientConnectionError
|
||||
from common.utils.exceptions import IslandCommunicationError
|
||||
from infection_monkey.i_control_channel import IControlChannel
|
||||
from infection_monkey.i_master import IMaster
|
||||
from infection_monkey.i_puppet import IPuppet
|
||||
|
@ -21,7 +21,8 @@ CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC /
|
|||
SHUTDOWN_TIMEOUT = 5
|
||||
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
|
||||
NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads
|
||||
STOP_AGENT_TIMEOUT = 5
|
||||
CHECK_FOR_STOP_AGENT_COUNT = 5
|
||||
CHECK_FOR_CONFIG_COUNT = 1
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -48,6 +49,7 @@ class AutomatedMaster(IMaster):
|
|||
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
||||
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
||||
self._failed_stop = 0
|
||||
self._failed_config = 0
|
||||
|
||||
def start(self):
|
||||
logger.info("Starting automated breach and attack simulation")
|
||||
|
@ -95,26 +97,38 @@ class AutomatedMaster(IMaster):
|
|||
|
||||
time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC)
|
||||
|
||||
def _should_retry_task(self, fn: Callable[[], Any], max_tries: int):
|
||||
tries = 0
|
||||
while tries < max_tries:
|
||||
try:
|
||||
return fn()
|
||||
except IslandCommunicationError as e:
|
||||
tries += 1
|
||||
logger.debug(f"{e}. Retries left: {max_tries-tries}")
|
||||
if tries >= max_tries:
|
||||
raise e
|
||||
|
||||
def _check_for_stop(self):
|
||||
try:
|
||||
if self._control_channel.should_agent_stop():
|
||||
stop = self._should_retry_task(
|
||||
self._control_channel.should_agent_stop, CHECK_FOR_STOP_AGENT_COUNT
|
||||
)
|
||||
if stop:
|
||||
logger.debug('Received the "stop" signal from the Island')
|
||||
self._failed_stop = 0
|
||||
self._stop.set()
|
||||
except ControlClientConnectionError as e:
|
||||
self._failed_stop += 1
|
||||
if self._failed_stop > STOP_AGENT_TIMEOUT:
|
||||
logger.error(f"An error occurred while trying to check for agent stop: {e}")
|
||||
self._stop.set()
|
||||
except IslandCommunicationError as e:
|
||||
logger.error(f"An error occurred while trying to check for agent stop: {e}")
|
||||
self._stop.set()
|
||||
|
||||
def _master_thread_should_run(self):
|
||||
return (not self._stop.is_set()) and self._simulation_thread.is_alive()
|
||||
|
||||
def _run_simulation(self):
|
||||
|
||||
try:
|
||||
config = self._control_channel.get_config()["config"]
|
||||
except ControlClientConnectionError as e:
|
||||
config = self._should_retry_task(
|
||||
self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT
|
||||
)["config"]
|
||||
except IslandCommunicationError as e:
|
||||
logger.error(f"An error occurred while fetching configuration: {e}")
|
||||
return
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import requests
|
||||
|
||||
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
||||
from common.utils.exceptions import ControlClientConnectionError
|
||||
from common.utils.exceptions import IslandCommunicationError
|
||||
from infection_monkey.config import WormConfiguration
|
||||
from infection_monkey.control import ControlClient
|
||||
from infection_monkey.i_control_channel import IControlChannel
|
||||
|
@ -37,10 +37,14 @@ class ControlChannel(IControlChannel):
|
|||
|
||||
response = json.loads(response.content.decode())
|
||||
return response["stop_agent"]
|
||||
except Exception as e:
|
||||
raise ControlClientConnectionError(
|
||||
f"An error occurred while trying to connect to server: {e}"
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
try:
|
||||
|
@ -50,12 +54,17 @@ class ControlChannel(IControlChannel):
|
|||
proxies=ControlClient.proxies,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return json.loads(response.content.decode())
|
||||
except Exception as exc:
|
||||
raise ControlClientConnectionError(
|
||||
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
|
||||
)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
|
||||
def get_credentials_for_propagation(self) -> dict:
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue