Agent: Pass current depth to AutomatedMaster

This commit is contained in:
Mike Salvatore 2022-03-07 08:46:10 -05:00
parent 7cae4d6dec
commit 524b97078d
3 changed files with 11 additions and 14 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
import threading import threading
import time import time
from typing import Any, Callable, Dict, Iterable, List, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.i_master import IMaster from infection_monkey.i_master import IMaster
@ -30,12 +30,14 @@ logger = logging.getLogger()
class AutomatedMaster(IMaster): class AutomatedMaster(IMaster):
def __init__( def __init__(
self, self,
current_depth: Optional[int],
puppet: IPuppet, puppet: IPuppet,
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
victim_host_factory: VictimHostFactory, victim_host_factory: VictimHostFactory,
control_channel: IControlChannel, control_channel: IControlChannel,
local_network_interfaces: List[NetworkInterface], local_network_interfaces: List[NetworkInterface],
): ):
self._current_depth = current_depth
self._puppet = puppet self._puppet = puppet
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._control_channel = control_channel self._control_channel = control_channel
@ -162,7 +164,9 @@ class AutomatedMaster(IMaster):
# still running. # still running.
credential_collector_thread.join() credential_collector_thread.join()
current_depth = config["depth"] current_depth = self._current_depth if self._current_depth is not None else config["depth"]
logger.info(f"Current depth is {current_depth}")
if self._can_propagate() and current_depth > 0: if self._can_propagate() and current_depth > 0:
self._propagator.propagate(config["propagation"], current_depth, self._stop) self._propagator.propagate(config["propagation"], current_depth, self._stop)

View File

@ -68,6 +68,7 @@ class InfectionMonkey:
# TODO used in propogation phase # TODO used in propogation phase
self._monkey_inbound_tunnel = None self._monkey_inbound_tunnel = None
self.telemetry_messenger = LegacyTelemetryMessengerAdapter() self.telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
@staticmethod @staticmethod
def _get_arguments(args): def _get_arguments(args):
@ -93,7 +94,6 @@ class InfectionMonkey:
logger.info("Monkey is starting...") logger.info("Monkey is starting...")
self._set_propagation_depth(self._opts)
self._add_default_server_to_config(self._opts.server) self._add_default_server_to_config(self._opts.server)
self._connect_to_island() self._connect_to_island()
@ -111,14 +111,6 @@ class InfectionMonkey:
self._setup() self._setup()
self._master.start() self._master.start()
@staticmethod
def _set_propagation_depth(options):
if options.depth is not None:
WormConfiguration._depth_from_commandline = True
WormConfiguration.depth = options.depth
logger.debug("Setting propagation depth from command line")
logger.debug(f"Set propagation depth to {WormConfiguration.depth}")
@staticmethod @staticmethod
def _add_default_server_to_config(default_server: str): def _add_default_server_to_config(default_server: str):
if default_server: if default_server:
@ -179,6 +171,7 @@ class InfectionMonkey:
) )
self._master = AutomatedMaster( self._master = AutomatedMaster(
self._current_depth,
puppet, puppet,
telemetry_messenger, telemetry_messenger,
victim_host_factory, victim_host_factory,

View File

@ -14,7 +14,7 @@ INTERVAL = 0.001
def test_terminate_without_start(): def test_terminate_without_start():
m = AutomatedMaster(None, None, None, MagicMock(), []) m = AutomatedMaster(None, None, None, None, MagicMock(), [])
# Test that call to terminate does not raise exception # Test that call to terminate does not raise exception
m.terminate() m.terminate()
@ -34,7 +34,7 @@ def test_stop_if_cant_get_config_from_island(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc, []) m = AutomatedMaster(None, None, None, None, cc, [])
m.start() m.start()
assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT
@ -73,7 +73,7 @@ def test_stop_if_cant_get_stop_signal_from_island(monkeypatch, sleep_and_return_
"infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL
) )
m = AutomatedMaster(None, None, None, cc, []) m = AutomatedMaster(None, None, None, None, cc, [])
m.start() m.start()
assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT