forked from p15670423/monkey
Agent: Raise custom control client exception
Move stop agent timeout to a constant, make custom control client exception and raise it, reset failed stop after successfull connection.
This commit is contained in:
parent
f299e61b20
commit
72a5e94111
|
@ -56,3 +56,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
|
||||||
|
|
||||||
class InvalidConfigurationError(Exception):
|
class InvalidConfigurationError(Exception):
|
||||||
""" Raise when configuration is invalid """
|
""" Raise when configuration is invalid """
|
||||||
|
|
||||||
|
|
||||||
|
class ControlClientConnectionError(Exception):
|
||||||
|
"""Raise when unable to connect to control client"""
|
||||||
|
|
|
@ -3,6 +3,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Tuple
|
from typing import Any, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
|
from common.utils.exceptions import ControlClientConnectionError
|
||||||
from infection_monkey.i_control_channel import IControlChannel
|
from infection_monkey.i_control_channel import IControlChannel
|
||||||
from infection_monkey.i_master import IMaster
|
from infection_monkey.i_master import IMaster
|
||||||
from infection_monkey.i_puppet import IPuppet
|
from infection_monkey.i_puppet import IPuppet
|
||||||
|
@ -20,6 +21,7 @@ CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC /
|
||||||
SHUTDOWN_TIMEOUT = 5
|
SHUTDOWN_TIMEOUT = 5
|
||||||
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
|
NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads
|
||||||
NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads
|
NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads
|
||||||
|
STOP_AGENT_TIMEOUT = 5
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -45,6 +47,7 @@ class AutomatedMaster(IMaster):
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
self._master_thread = create_daemon_thread(target=self._run_master_thread)
|
||||||
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
self._simulation_thread = create_daemon_thread(target=self._run_simulation)
|
||||||
|
self._failed_stop = 0
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
logger.info("Starting automated breach and attack simulation")
|
logger.info("Starting automated breach and attack simulation")
|
||||||
|
@ -96,10 +99,11 @@ class AutomatedMaster(IMaster):
|
||||||
try:
|
try:
|
||||||
if self._control_channel.should_agent_stop():
|
if self._control_channel.should_agent_stop():
|
||||||
logger.debug('Received the "stop" signal from the Island')
|
logger.debug('Received the "stop" signal from the Island')
|
||||||
|
self._failed_stop = 0
|
||||||
self._stop.set()
|
self._stop.set()
|
||||||
except Exception as e:
|
except ControlClientConnectionError as e:
|
||||||
self._failed_stop += 1
|
self._failed_stop += 1
|
||||||
if self._failed_stop > 5:
|
if self._failed_stop > STOP_AGENT_TIMEOUT:
|
||||||
logger.error(f"An error occurred while trying to check for agent stop: {e}")
|
logger.error(f"An error occurred while trying to check for agent stop: {e}")
|
||||||
self._stop.set()
|
self._stop.set()
|
||||||
|
|
||||||
|
@ -110,7 +114,7 @@ class AutomatedMaster(IMaster):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = self._control_channel.get_config()["config"]
|
config = self._control_channel.get_config()["config"]
|
||||||
except Exception as e:
|
except ControlClientConnectionError as e:
|
||||||
logger.error(f"An error occurred while fetching configuration: {e}")
|
logger.error(f"An error occurred while fetching configuration: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import logging
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
||||||
|
from common.utils.exceptions import ControlClientConnectionError
|
||||||
from infection_monkey.config import WormConfiguration
|
from infection_monkey.config import WormConfiguration
|
||||||
from infection_monkey.control import ControlClient
|
from infection_monkey.control import ControlClient
|
||||||
from infection_monkey.i_control_channel import IControlChannel
|
from infection_monkey.i_control_channel import IControlChannel
|
||||||
|
@ -37,7 +38,9 @@ class ControlChannel(IControlChannel):
|
||||||
response = json.loads(response.content.decode())
|
response = json.loads(response.content.decode())
|
||||||
return response["stop_agent"]
|
return response["stop_agent"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while trying to connect to server. {e}")
|
raise ControlClientConnectionError(
|
||||||
|
f"An error occurred while trying to connect to server: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict:
|
def get_config(self) -> dict:
|
||||||
try:
|
try:
|
||||||
|
@ -50,7 +53,7 @@ class ControlChannel(IControlChannel):
|
||||||
|
|
||||||
return json.loads(response.content.decode())
|
return json.loads(response.content.decode())
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise Exception(
|
raise ControlClientConnectionError(
|
||||||
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
|
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue