Agent: Raise custom control client exception

Move stop agent timeout to a constant, make custom control
client exception and raise it, reset failed stop after successfull
connection.
This commit is contained in:
Ilija Lazoroski 2021-12-15 13:36:34 +01:00
parent f299e61b20
commit 72a5e94111
3 changed files with 16 additions and 5 deletions

View File

@ -56,3 +56,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
class InvalidConfigurationError(Exception): class InvalidConfigurationError(Exception):
""" Raise when configuration is invalid """ """ Raise when configuration is invalid """
class ControlClientConnectionError(Exception):
"""Raise when unable to connect to control client"""

View File

@ -3,6 +3,7 @@ import threading
import time import time
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable, Dict, List, Tuple
from common.utils.exceptions import ControlClientConnectionError
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
from infection_monkey.i_puppet import IPuppet from infection_monkey.i_puppet import IPuppet
@ -20,6 +21,7 @@ CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC /
SHUTDOWN_TIMEOUT = 5 SHUTDOWN_TIMEOUT = 5
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads
STOP_AGENT_TIMEOUT = 5
logger = logging.getLogger() logger = logging.getLogger()
@ -45,6 +47,7 @@ class AutomatedMaster(IMaster):
self._stop = threading.Event() self._stop = threading.Event()
self._master_thread = create_daemon_thread(target=self._run_master_thread) self._master_thread = create_daemon_thread(target=self._run_master_thread)
self._simulation_thread = create_daemon_thread(target=self._run_simulation) self._simulation_thread = create_daemon_thread(target=self._run_simulation)
self._failed_stop = 0
def start(self): def start(self):
logger.info("Starting automated breach and attack simulation") logger.info("Starting automated breach and attack simulation")
@ -96,10 +99,11 @@ class AutomatedMaster(IMaster):
try: try:
if self._control_channel.should_agent_stop(): if self._control_channel.should_agent_stop():
logger.debug('Received the "stop" signal from the Island') logger.debug('Received the "stop" signal from the Island')
self._failed_stop = 0
self._stop.set() self._stop.set()
except Exception as e: except ControlClientConnectionError as e:
self._failed_stop += 1 self._failed_stop += 1
if self._failed_stop > 5: if self._failed_stop > STOP_AGENT_TIMEOUT:
logger.error(f"An error occurred while trying to check for agent stop: {e}") logger.error(f"An error occurred while trying to check for agent stop: {e}")
self._stop.set() self._stop.set()
@ -110,7 +114,7 @@ class AutomatedMaster(IMaster):
try: try:
config = self._control_channel.get_config()["config"] config = self._control_channel.get_config()["config"]
except Exception as e: except ControlClientConnectionError as e:
logger.error(f"An error occurred while fetching configuration: {e}") logger.error(f"An error occurred while fetching configuration: {e}")
return return

View File

@ -4,6 +4,7 @@ import logging
import requests import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.utils.exceptions import ControlClientConnectionError
from infection_monkey.config import WormConfiguration from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
@ -37,7 +38,9 @@ class ControlChannel(IControlChannel):
response = json.loads(response.content.decode()) response = json.loads(response.content.decode())
return response["stop_agent"] return response["stop_agent"]
except Exception as e: except Exception as e:
raise Exception(f"An error occurred while trying to connect to server. {e}") raise ControlClientConnectionError(
f"An error occurred while trying to connect to server: {e}"
)
def get_config(self) -> dict: def get_config(self) -> dict:
try: try:
@ -50,7 +53,7 @@ class ControlChannel(IControlChannel):
return json.loads(response.content.decode()) return json.loads(response.content.decode())
except Exception as exc: except Exception as exc:
raise Exception( raise ControlClientConnectionError(
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc "Error connecting to control server %s: %s", WormConfiguration.current_server, exc
) )