diff --git a/monkey/infection_monkey/i_control_channel.py b/monkey/infection_monkey/i_control_channel.py index eb1a4d5b2..33539417c 100644 --- a/monkey/infection_monkey/i_control_channel.py +++ b/monkey/infection_monkey/i_control_channel.py @@ -25,3 +25,7 @@ class IControlChannel(metaclass=abc.ABCMeta): :rtype: dict """ pass + + +class IslandCommunicationError(Exception): + """Raise when unable to connect to control client""" diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index 75a5a5dbf..8c95d529b 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -3,7 +3,7 @@ import threading import time from typing import Any, Callable, Dict, List, Tuple -from infection_monkey.i_control_channel import IControlChannel +from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_master import IMaster from infection_monkey.i_puppet import IPuppet from infection_monkey.model import VictimHostFactory @@ -20,6 +20,8 @@ CHECK_FOR_TERMINATE_INTERVAL_SEC = CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC / SHUTDOWN_TIMEOUT = 5 NUM_SCAN_THREADS = 16 # TODO: Adjust this to the optimal number of scan threads NUM_EXPLOIT_THREADS = 4 # TODO: Adjust this to the optimal number of exploit threads +CHECK_FOR_STOP_AGENT_COUNT = 5 +CHECK_FOR_CONFIG_COUNT = 3 logger = logging.getLogger() @@ -85,23 +87,46 @@ class AutomatedMaster(IMaster): while self._master_thread_should_run(): if timer.is_expired(): - # TODO: Handle exceptions in _check_for_stop() once - # ControlChannel.should_agent_stop() is refactored. self._check_for_stop() timer.reset() time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC) + @staticmethod + def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int): + tries = 0 + while tries < max_tries: + try: + return fn() + except IslandCommunicationError as e: + tries += 1 + logger.debug(f"{e}. Retries left: {max_tries-tries}") + if tries >= max_tries: + raise e + def _check_for_stop(self): - if self._control_channel.should_agent_stop(): - logger.debug('Received the "stop" signal from the Island') + try: + stop = AutomatedMaster._try_communicate_with_island( + self._control_channel.should_agent_stop, CHECK_FOR_STOP_AGENT_COUNT + ) + if stop: + logger.info('Received the "stop" signal from the Island') + self._stop.set() + except IslandCommunicationError as e: + logger.error(f"An error occurred while trying to check for agent stop: {e}") self._stop.set() def _master_thread_should_run(self): return (not self._stop.is_set()) and self._simulation_thread.is_alive() def _run_simulation(self): - config = self._control_channel.get_config()["config"] + try: + config = AutomatedMaster._try_communicate_with_island( + self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT + )["config"] + except IslandCommunicationError as e: + logger.error(f"An error occurred while fetching configuration: {e}") + return system_info_collector_thread = create_daemon_thread( target=self._run_plugins, @@ -137,14 +162,6 @@ class AutomatedMaster(IMaster): pba_thread.join() - # TODO: This code is just for testing in development. Remove when - # implementation of AutomatedMaster is finished. - while True: - time.sleep(2) - logger.debug("Simulation thread is finished sleeping") - if self._stop.is_set(): - break - def _collect_system_info(self, collector: str): system_info_telemetry = {} system_info_telemetry[collector] = self._puppet.run_sys_info_collector(collector) diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 5fdd03942..52b565d55 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -6,7 +6,7 @@ import requests from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from infection_monkey.config import WormConfiguration from infection_monkey.control import ControlClient -from infection_monkey.i_control_channel import IControlChannel +from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError requests.packages.urllib3.disable_warnings() @@ -33,14 +33,18 @@ class ControlChannel(IControlChannel): proxies=ControlClient.proxies, timeout=SHORT_REQUEST_TIMEOUT, ) + response.raise_for_status() response = json.loads(response.content.decode()) return response["stop_agent"] - except Exception as e: - # TODO: Evaluate how this exception is handled; don't just log and ignore it. - logger.error(f"An error occurred while trying to connect to server. {e}") - - return True + except ( + json.JSONDecodeError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.TooManyRedirects, + requests.exceptions.HTTPError, + ) as e: + raise IslandCommunicationError(e) def get_config(self) -> dict: try: @@ -50,15 +54,17 @@ class ControlChannel(IControlChannel): proxies=ControlClient.proxies, timeout=SHORT_REQUEST_TIMEOUT, ) + response.raise_for_status() return json.loads(response.content.decode()) - except Exception as exc: - # TODO: Evaluate how this exception is handled; don't just log and ignore it. - logger.warning( - "Error connecting to control server %s: %s", WormConfiguration.current_server, exc - ) - - return {} + except ( + json.JSONDecodeError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.TooManyRedirects, + requests.exceptions.HTTPError, + ) as e: + raise IslandCommunicationError(e) def get_credentials_for_propagation(self) -> dict: try: @@ -68,11 +74,15 @@ class ControlChannel(IControlChannel): proxies=ControlClient.proxies, timeout=SHORT_REQUEST_TIMEOUT, ) + response.raise_for_status() response = json.loads(response.content.decode())["propagation_credentials"] return response - except Exception as e: - # TODO: Evaluate how this exception is handled; don't just log and ignore it. - logger.error(f"An error occurred while trying to connect to server. {e}") - - return {} + except ( + json.JSONDecodeError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.TooManyRedirects, + requests.exceptions.HTTPError, + ) as e: + raise IslandCommunicationError(e) diff --git a/monkey/tests/data_for_tests/monkey_configs/automated_master_config.json b/monkey/tests/data_for_tests/monkey_configs/automated_master_config.json new file mode 100644 index 000000000..4a7816301 --- /dev/null +++ b/monkey/tests/data_for_tests/monkey_configs/automated_master_config.json @@ -0,0 +1,112 @@ +{ + "config": { + "propagation": { + "network_scan": { + "tcp": { + "timeout_ms": 3000, + "ports": [ + 22, + 2222, + 445, + 135, + 3389, + 80, + 8080, + 443, + 8008, + 3306, + 7001, + 8088, + 9200 + ] + }, + "icmp": { + "timeout_ms": 1000 + }, + "fingerprinters": [ + "SMBFinger", + "SSHFinger", + "HTTPFinger", + "MySQLFinger", + "MSSQLFinger", + "ElasticFinger" + ] + }, + "targets": { + "blocked_ips": ["192.168.1.1", "192.168.1.100"], + "inaccessible_subnets": ["10.0.0.0/24", "10.0.10.0/24"], + "local_network_scan": true, + "subnet_scan_list": [ + "192.168.1.50", + "192.168.56.0/24", + "10.0.33.0/30", + "10.0.0.1", + "10.0.0.2" + ] + }, + "exploiters": { + "brute_force": [ + {"name": "MSSQLExploiter", "propagator": true}, + {"name": "PowerShellExploiter", "propagator": true}, + {"name": "SmbExploiter", "propagator": true}, + {"name": "SSHExploiter", "propagator": true}, + {"name": "WmiExploiter", "propagator": true} + ], + "vulnerability": [ + {"name": "DrupalExploiter", "propagator": true}, + {"name": "ElasticGroovyExploiter", "propagator": true}, + {"name": "HadoopExploiter", "propagator": true}, + {"name": "ShellShockExploiter", "propagator": true}, + {"name": "Struts2Exploiter", "propagator": true}, + {"name": "WebLogicExploiter", "propagator": true}, + {"name": "ZerologonExploiter", "propagator": false} + ] + } + }, + "PBA_linux_filename": "", + "PBA_windows_filename": "", + "command_servers": ["10.197.94.72:5000"], + "current_server": "localhost:5000", + "custom_PBA_linux_cmd": "", + "custom_PBA_windows_cmd": "", + "depth": 2, + "dropper_set_date": true, + "exploit_lm_hash_list": ["DEADBEEF", "FACADE"], + "exploit_ntlm_hash_list": ["BEADED", "ACCEDE", "DECADE"], + "exploit_password_list": ["p1", "p2", "p3"], + "exploit_ssh_keys": "hidden", + "exploit_user_list": ["u1", "u2", "u3"], + "exploiter_classes": [], + "max_depth": 2, + "post_breach_actions": { + "CommunicateAsBackdoorUser": {}, + "ModifyShellStartupFiles": {}, + "HiddenFiles": {}, + "TrapCommand": {}, + "ChangeSetuidSetgid": {}, + "ScheduleJobs": {}, + "Timestomping": {}, + "AccountDiscovery": {}, + "Custom": { + "linux_command": "chmod u+x my_exec && ./my_exec", + "windows_cmd": "powershell test_driver.ps1", + "linux_filename": "my_exec", + "windows_filename": "test_driver.ps1" + } + }, + "payloads": { + "ransomware": { + "encryption": { + "directories": {"linux_target_dir": "", "windows_target_dir": ""}, + "enabled": true + }, + "other_behaviors": {"readme": true} + } + }, + "system_info_collector_classes": [ + "AwsCollector", + "ProcessListCollector", + "MimikatzCollector" + ] + } +} diff --git a/monkey/tests/unit_tests/conftest.py b/monkey/tests/unit_tests/conftest.py index 3099263b0..5b83a7e69 100644 --- a/monkey/tests/unit_tests/conftest.py +++ b/monkey/tests/unit_tests/conftest.py @@ -1,6 +1,7 @@ -import os +import json import sys from pathlib import Path +from typing import Callable, Dict import pytest from _pytest.monkeypatch import MonkeyPatch @@ -11,7 +12,7 @@ sys.path.insert(0, MONKEY_BASE_PATH) @pytest.fixture(scope="session") def data_for_tests_dir(pytestconfig): - return Path(os.path.join(pytestconfig.rootdir, "monkey", "tests", "data_for_tests")) + return Path(pytestconfig.rootdir) / "monkey" / "tests" / "data_for_tests" @pytest.fixture(scope="session") @@ -39,3 +40,17 @@ def monkeypatch_session(): monkeypatch_ = MonkeyPatch() yield monkeypatch_ monkeypatch_.undo() + + +@pytest.fixture +def monkey_configs_dir(data_for_tests_dir) -> Path: + return data_for_tests_dir / "monkey_configs" + + +@pytest.fixture +def load_monkey_config(data_for_tests_dir) -> Callable[[str], Dict]: + def inner(filename: str) -> Dict: + config_path = data_for_tests_dir / "monkey_configs" / filename + return json.loads(open(config_path, "r").read()) + + return inner diff --git a/monkey/tests/unit_tests/infection_monkey/conftest.py b/monkey/tests/unit_tests/infection_monkey/conftest.py index 533572f98..14a193112 100644 --- a/monkey/tests/unit_tests/infection_monkey/conftest.py +++ b/monkey/tests/unit_tests/infection_monkey/conftest.py @@ -15,3 +15,8 @@ class TelemetryMessengerSpy(ITelemetryMessenger): @pytest.fixture def telemetry_messenger_spy(): return TelemetryMessengerSpy() + + +@pytest.fixture +def automated_master_config(load_monkey_config): + return load_monkey_config("automated_master_config.json") diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_automated_master.py b/monkey/tests/unit_tests/infection_monkey/master/test_automated_master.py index 0584ca1cd..378eab883 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_automated_master.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_automated_master.py @@ -1,4 +1,14 @@ +import time +from unittest.mock import MagicMock + from infection_monkey.master import AutomatedMaster +from infection_monkey.master.automated_master import ( + CHECK_FOR_CONFIG_COUNT, + CHECK_FOR_STOP_AGENT_COUNT, +) +from infection_monkey.master.control_channel import IslandCommunicationError + +INTERVAL = 0.001 def test_terminate_without_start(): @@ -6,3 +16,52 @@ def test_terminate_without_start(): # Test that call to terminate does not raise exception m.terminate() + + +def test_stop_if_cant_get_config_from_island(monkeypatch): + cc = MagicMock() + cc.should_agent_stop = MagicMock(return_value=False) + cc.get_config = MagicMock( + side_effect=IslandCommunicationError("Failed to communicate with island") + ) + + monkeypatch.setattr( + "infection_monkey.master.automated_master.CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC", + INTERVAL, + ) + monkeypatch.setattr( + "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL + ) + m = AutomatedMaster(None, None, None, cc) + m.start() + + assert cc.get_config.call_count == CHECK_FOR_CONFIG_COUNT + + +# NOTE: This test is a little bit brittle, and probably needs too much knowlegde of the internals +# of AutomatedMaster. For now, it works and it runs quickly. In the future, if we find that +# this test isn't valuable or it starts causing issues, we can just remove it. +def test_stop_if_cant_get_stop_signal_from_island(monkeypatch, automated_master_config): + cc = MagicMock() + cc.should_agent_stop = MagicMock( + side_effect=IslandCommunicationError("Failed to communicate with island") + ) + # Ensure that should_agent_stop times out before get_config() returns to prevent the + # Propagator's sub-threads from hanging + cc.get_config = MagicMock( + return_value=automated_master_config, + side_effect=lambda: time.sleep(INTERVAL * (CHECK_FOR_STOP_AGENT_COUNT + 1)), + ) + + monkeypatch.setattr( + "infection_monkey.master.automated_master.CHECK_ISLAND_FOR_STOP_COMMAND_INTERVAL_SEC", + INTERVAL, + ) + monkeypatch.setattr( + "infection_monkey.master.automated_master.CHECK_FOR_TERMINATE_INTERVAL_SEC", INTERVAL + ) + + m = AutomatedMaster(None, None, None, cc) + m.start() + + assert cc.should_agent_stop.call_count == CHECK_FOR_STOP_AGENT_COUNT diff --git a/monkey/tests/unit_tests/monkey_island/cc/conftest.py b/monkey/tests/unit_tests/monkey_island/cc/conftest.py index 5777b3492..ba5a2c66e 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/conftest.py +++ b/monkey/tests/unit_tests/monkey_island/cc/conftest.py @@ -1,38 +1,21 @@ # Without these imports pytests can't use fixtures, # because they are not found import json -from typing import Dict import pytest from tests.unit_tests.monkey_island.cc.mongomock_fixtures import * # noqa: F401,F403,E402 -from tests.unit_tests.monkey_island.cc.server_utils.encryption.test_password_based_encryption import ( # noqa: E501 - FLAT_PLAINTEXT_MONKEY_CONFIG_FILENAME, - MONKEY_CONFIGS_DIR_PATH, - STANDARD_PLAINTEXT_MONKEY_CONFIG_FILENAME, -) from monkey_island.cc.server_utils.encryption import unlock_datastore_encryptor -@pytest.fixture -def load_monkey_config(data_for_tests_dir) -> Dict: - def inner(filename: str) -> Dict: - config_path = ( - data_for_tests_dir / MONKEY_CONFIGS_DIR_PATH / FLAT_PLAINTEXT_MONKEY_CONFIG_FILENAME - ) - return json.loads(open(config_path, "r").read()) - - return inner - - @pytest.fixture def monkey_config(load_monkey_config): - return load_monkey_config(STANDARD_PLAINTEXT_MONKEY_CONFIG_FILENAME) + return load_monkey_config("monkey_config_standard.json") @pytest.fixture def flat_monkey_config(load_monkey_config): - return load_monkey_config(FLAT_PLAINTEXT_MONKEY_CONFIG_FILENAME) + return load_monkey_config("flat_config.json") @pytest.fixture diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_password_based_encryption.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_password_based_encryption.py index ce0b46705..038b17ec1 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_password_based_encryption.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_password_based_encryption.py @@ -13,9 +13,6 @@ from monkey_island.cc.server_utils.encryption import ( # Mark all tests in this module as slow pytestmark = pytest.mark.slow -MONKEY_CONFIGS_DIR_PATH = "monkey_configs" -STANDARD_PLAINTEXT_MONKEY_CONFIG_FILENAME = "monkey_config_standard.json" -FLAT_PLAINTEXT_MONKEY_CONFIG_FILENAME = "flat_config.json" PASSWORD = "hello123" INCORRECT_PASSWORD = "goodbye321"