diff --git a/monkey/common/configuration/__init__.py b/monkey/common/configuration/__init__.py index fc1f3c84d..fb52b54e2 100644 --- a/monkey/common/configuration/__init__.py +++ b/monkey/common/configuration/__init__.py @@ -12,4 +12,5 @@ from .agent_sub_configurations import ( ) from .default_agent_configuration import ( DEFAULT_AGENT_CONFIGURATION, + DEFAULT_RANSOMWARE_AGENT_CONFIGURATION, ) diff --git a/monkey/common/configuration/default_agent_configuration.py b/monkey/common/configuration/default_agent_configuration.py index 251676017..9045429ca 100644 --- a/monkey/common/configuration/default_agent_configuration.py +++ b/monkey/common/configuration/default_agent_configuration.py @@ -1,3 +1,5 @@ +import dataclasses + from . import AgentConfiguration from .agent_sub_configurations import ( CustomPBAConfiguration, @@ -112,3 +114,7 @@ DEFAULT_AGENT_CONFIGURATION = AgentConfiguration( payloads=PAYLOAD_CONFIGURATION, propagation=PROPAGATION_CONFIGURATION, ) + +DEFAULT_RANSOMWARE_AGENT_CONFIGURATION = dataclasses.replace( + DEFAULT_AGENT_CONFIGURATION, post_breach_actions=[] +) diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index c293ae2e7..bf2addf63 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -8,3 +8,4 @@ from .monkey_ttl import MonkeyTtl from .pba_results import PbaResults from monkey_island.cc.models.report.report import Report from .stolen_credentials import StolenCredentials +from .simulation import Simulation, SimulationSchema, IslandMode diff --git a/monkey/monkey_island/cc/models/island_mode_model.py b/monkey/monkey_island/cc/models/island_mode_model.py deleted file mode 100644 index 8d6aab74a..000000000 --- a/monkey/monkey_island/cc/models/island_mode_model.py +++ /dev/null @@ -1,7 +0,0 @@ -from mongoengine import Document, StringField - - -class IslandMode(Document): - COLLECTION_NAME = "island_mode" - - mode = StringField() diff --git a/monkey/monkey_island/cc/models/simulation.py b/monkey/monkey_island/cc/models/simulation.py new file mode 100644 index 000000000..88b3ca25d --- /dev/null +++ b/monkey/monkey_island/cc/models/simulation.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from marshmallow import Schema, post_load +from marshmallow_enum import EnumField + + +class IslandMode(Enum): + UNSET = "unset" + RANSOMWARE = "ransomware" + ADVANCED = "advanced" + + +@dataclass(frozen=True) +class Simulation: + mode: IslandMode = IslandMode.UNSET + + +class SimulationSchema(Schema): + mode = EnumField(IslandMode) + + @post_load + def _make_simulation(self, data, **kwargs): + return Simulation(**data) diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index 52e6f6873..925a60dd5 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -4,3 +4,5 @@ from .i_agent_binary_repository import IAgentBinaryRepository from .agent_binary_repository import AgentBinaryRepository from .i_agent_configuration_repository import IAgentConfigurationRepository from .file_agent_configuration_repository import FileAgentConfigurationRepository +from .i_simulation_repository import ISimulationRepository +from .file_simulation_repository import FileSimulationRepository diff --git a/monkey/monkey_island/cc/repository/file_simulation_repository.py b/monkey/monkey_island/cc/repository/file_simulation_repository.py new file mode 100644 index 000000000..0b4843fb6 --- /dev/null +++ b/monkey/monkey_island/cc/repository/file_simulation_repository.py @@ -0,0 +1,40 @@ +import dataclasses +import io + +from monkey_island.cc import repository +from monkey_island.cc.models import IslandMode, Simulation, SimulationSchema +from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError + +SIMULATION_STATE_FILE_NAME = "simulation_state.json" + + +class FileSimulationRepository(ISimulationRepository): + def __init__(self, file_repository: IFileRepository): + self._file_repository = file_repository + self._simulation_schema = SimulationSchema() + + def get_simulation(self) -> Simulation: + try: + with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f: + simulation_json = f.read().decode() + + return self._simulation_schema.loads(simulation_json) + except repository.FileNotFoundError: + return Simulation() + except Exception as err: + raise RetrievalError(f"Error retrieving the simulation state: {err}") + + def save_simulation(self, simulation: Simulation): + simulation_json = self._simulation_schema.dumps(simulation) + + self._file_repository.save_file( + SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode()) + ) + + def get_mode(self) -> IslandMode: + return self.get_simulation().mode + + def set_mode(self, mode: IslandMode): + old_simulation = self.get_simulation() + new_simulation = dataclasses.replace(old_simulation, mode=mode) + self.save_simulation(new_simulation) diff --git a/monkey/monkey_island/cc/repository/i_simulation_repository.py b/monkey/monkey_island/cc/repository/i_simulation_repository.py index 3a3854fb0..94bef50c4 100644 --- a/monkey/monkey_island/cc/repository/i_simulation_repository.py +++ b/monkey/monkey_island/cc/repository/i_simulation_repository.py @@ -1,11 +1,48 @@ -from abc import ABC +from abc import ABC, abstractmethod + +from monkey_island.cc.models import IslandMode, Simulation class ISimulationRepository(ABC): - # TODO define simulation object. It should contain metadata about simulation, - # like start, end times, mode and last forced stop of all monkeys - def save_simulation(self, simulation: Simulation): # noqa: F821 + @abstractmethod + def get_simulation(self) -> Simulation: + """ + Get the simulation state + + :raises RetrievalError: If the simulation state could not be retrieved + """ + pass - def get_simulation(self): + @abstractmethod + def save_simulation(self, simulation: Simulation): + """ + Save the simulation state + + :param simulation: The simulation state + :raises StorageError: If the simulation states could not be saved + """ + + pass + + @abstractmethod + def get_mode(self) -> IslandMode: + """ + Get's the island's current mode + + :return The island's current mode + :raises RetrievalError: If the mode could not be retrieved + """ + + pass + + @abstractmethod + def set_mode(self, mode: IslandMode): + """ + Set the island's mode + + :param mode: The island's new mode + :raises StorageError: If the mode could not be saved + """ + pass diff --git a/monkey/monkey_island/cc/resources/island_mode.py b/monkey/monkey_island/cc/resources/island_mode.py index 5d7343459..3273f150e 100644 --- a/monkey/monkey_island/cc/resources/island_mode.py +++ b/monkey/monkey_island/cc/resources/island_mode.py @@ -3,11 +3,10 @@ import logging from flask import make_response, request +from monkey_island.cc.models import IslandMode as IslandModeEnum from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.request_authentication import jwt_required -from monkey_island.cc.services.config_manipulator import update_config_on_mode_set -from monkey_island.cc.services.mode.island_mode_service import get_mode, set_mode -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum +from monkey_island.cc.services import IslandModeService logger = logging.getLogger(__name__) @@ -16,20 +15,16 @@ class IslandMode(AbstractResource): # API Spec: Instead of POST, this could just be PATCH urls = ["/api/island-mode"] + def __init__(self, island_mode_service: IslandModeService): + self._island_mode_service = island_mode_service + @jwt_required def post(self): try: body = json.loads(request.data) - mode_str = body.get("mode") + mode = IslandModeEnum(body.get("mode")) - mode = IslandModeEnum(mode_str) - set_mode(mode) - - if not update_config_on_mode_set(mode): - logger.error( - "Could not apply configuration changes per mode. " - "Using default advanced configuration." - ) + self._island_mode_service.set_mode(mode) return make_response({}, 200) except (AttributeError, json.decoder.JSONDecodeError): @@ -39,5 +34,5 @@ class IslandMode(AbstractResource): @jwt_required def get(self): - island_mode = get_mode() - return make_response({"mode": island_mode}, 200) + island_mode = self._island_mode_service.get_mode() + return make_response({"mode": island_mode.value}, 200) diff --git a/monkey/monkey_island/cc/services/__init__.py b/monkey/monkey_island/cc/services/__init__.py index 66bb1ea65..bfecc7588 100644 --- a/monkey/monkey_island/cc/services/__init__.py +++ b/monkey/monkey_island/cc/services/__init__.py @@ -2,3 +2,4 @@ from .authentication.authentication_service import AuthenticationService from .authentication.json_file_user_datastore import JsonFileUserDatastore from .aws import AWSService +from .island_mode_service import IslandModeService diff --git a/monkey/monkey_island/cc/services/config.py b/monkey/monkey_island/cc/services/config.py index 3cb5ce09f..99f5a589c 100644 --- a/monkey/monkey_island/cc/services/config.py +++ b/monkey/monkey_island/cc/services/config.py @@ -23,10 +23,7 @@ from monkey_island.cc.server_utils.encryption import ( encrypt_dict, get_datastore_encryptor, ) -from monkey_island.cc.services.config_manipulator import update_config_per_mode from monkey_island.cc.services.config_schema.config_schema import SCHEMA -from monkey_island.cc.services.mode.island_mode_service import get_mode -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum from monkey_island.cc.services.post_breach_files import PostBreachFilesService logger = logging.getLogger(__name__) @@ -250,14 +247,6 @@ class ConfigService: @staticmethod def reset_config(): PostBreachFilesService.remove_PBA_files() - config = ConfigService.get_default_config(True) - - mode = get_mode() - if mode == IslandModeEnum.UNSET.value: - ConfigService.update_config(config, should_encrypt=False) - else: - update_config_per_mode(mode, config, should_encrypt=False) - logger.info("Monkey config reset was called") @staticmethod diff --git a/monkey/monkey_island/cc/services/config_manipulator.py b/monkey/monkey_island/cc/services/config_manipulator.py deleted file mode 100644 index 41f555be4..000000000 --- a/monkey/monkey_island/cc/services/config_manipulator.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict - -import dpath.util - -import monkey_island.cc.services.config as config_service -from monkey_island.cc.services.config_manipulators import MANIPULATOR_PER_MODE -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum - - -def update_config_on_mode_set(mode: IslandModeEnum) -> bool: - config = config_service.ConfigService.get_config() - return update_config_per_mode(mode.value, config, True) - - -def update_config_per_mode(mode: str, config: Dict, should_encrypt: bool) -> bool: - config = _set_default_config_values_per_mode(mode, config) - return config_service.ConfigService.update_config( - config_json=config, should_encrypt=should_encrypt - ) - - -def _set_default_config_values_per_mode(mode: str, config: Dict) -> Dict: - config_manipulator = MANIPULATOR_PER_MODE[mode] - config = _apply_config_manipulator(config, config_manipulator) - return config - - -def _apply_config_manipulator(config: Dict, config_manipulator: Dict): - for path, value in config_manipulator.items(): - dpath.util.set(config, path, value, ".") - return config diff --git a/monkey/monkey_island/cc/services/config_manipulators.py b/monkey/monkey_island/cc/services/config_manipulators.py deleted file mode 100644 index 947a32971..000000000 --- a/monkey/monkey_island/cc/services/config_manipulators.py +++ /dev/null @@ -1,7 +0,0 @@ -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum - -MANIPULATOR_PER_MODE = { - IslandModeEnum.ADVANCED.value: {}, - IslandModeEnum.RANSOMWARE.value: {"monkey.post_breach.post_breach_actions": []}, - IslandModeEnum.UNSET.value: {}, -} diff --git a/monkey/monkey_island/cc/services/database.py b/monkey/monkey_island/cc/services/database.py index 46b5e0ffd..3dd512026 100644 --- a/monkey/monkey_island/cc/services/database.py +++ b/monkey/monkey_island/cc/services/database.py @@ -6,7 +6,6 @@ from monkey_island.cc.database import mongo from monkey_island.cc.models import Config from monkey_island.cc.models.agent_controls import AgentControls from monkey_island.cc.models.attack.attack_mitigations import AttackMitigations -from monkey_island.cc.models.island_mode_model import IslandMode from monkey_island.cc.services.config import ConfigService logger = logging.getLogger(__name__) @@ -33,7 +32,7 @@ class Database(object): @staticmethod def _should_drop(collection: str, drop_config: bool) -> bool: if not drop_config: - if collection == IslandMode.COLLECTION_NAME or collection == Config.COLLECTION_NAME: + if collection == Config.COLLECTION_NAME: return False return ( not collection.startswith("system.") diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index 922c3654b..28a88558c 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -3,19 +3,25 @@ from pathlib import Path from common import DIContainer from common.aws import AWSInstance -from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration +from common.configuration import ( + DEFAULT_AGENT_CONFIGURATION, + DEFAULT_RANSOMWARE_AGENT_CONFIGURATION, + AgentConfiguration, +) from common.utils.file_utils import get_binary_io_sha256_hash from monkey_island.cc.repository import ( AgentBinaryRepository, FileAgentConfigurationRepository, + FileSimulationRepository, IAgentBinaryRepository, IAgentConfigurationRepository, IFileRepository, + ISimulationRepository, LocalStorageFileRepository, RetrievalError, ) from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH -from monkey_island.cc.services import AWSService +from monkey_island.cc.services import AWSService, IslandModeService from monkey_island.cc.services.post_breach_files import PostBreachFilesService from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService @@ -29,22 +35,10 @@ AGENT_BINARIES_PATH = Path(MONKEY_ISLAND_ABS_PATH) / "cc" / "binaries" def initialize_services(data_dir: Path) -> DIContainer: container = DIContainer() - - container.register_convention(Path, "data_dir", data_dir) - container.register_convention( - AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION - ) + _register_conventions(container, data_dir) container.register_instance(AWSInstance, AWSInstance()) - - container.register_instance( - IFileRepository, LocalStorageFileRepository(data_dir / "runtime_data") - ) - container.register_instance(AWSService, container.resolve(AWSService)) - container.register_instance(IAgentBinaryRepository, _build_agent_binary_repository()) - container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService)) - container.register_instance( - IAgentConfigurationRepository, container.resolve(FileAgentConfigurationRepository) - ) + _register_repositories(container, data_dir) + _register_services(container) # This is temporary until we get DI all worked out. PostBreachFilesService.initialize(container.resolve(IFileRepository)) @@ -54,6 +48,29 @@ def initialize_services(data_dir: Path) -> DIContainer: return container +def _register_conventions(container: DIContainer, data_dir: Path): + container.register_convention(Path, "data_dir", data_dir) + container.register_convention( + AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION + ) + container.register_convention( + AgentConfiguration, + "default_ransomware_agent_configuration", + DEFAULT_RANSOMWARE_AGENT_CONFIGURATION, + ) + + +def _register_repositories(container: DIContainer, data_dir: Path): + container.register_instance( + IFileRepository, LocalStorageFileRepository(data_dir / "runtime_data") + ) + container.register_instance(IAgentBinaryRepository, _build_agent_binary_repository()) + container.register_instance( + IAgentConfigurationRepository, container.resolve(FileAgentConfigurationRepository) + ) + container.register_instance(ISimulationRepository, container.resolve(FileSimulationRepository)) + + def _build_agent_binary_repository(): file_repository = LocalStorageFileRepository(AGENT_BINARIES_PATH) agent_binary_repository = AgentBinaryRepository(file_repository) @@ -85,3 +102,9 @@ def _log_agent_binary_hashes(agent_binary_repository: IAgentBinaryRepository): for os, binary_sha256_hash in agent_hashes.items(): logger.info(f"{os} agent: SHA-256 hash: {binary_sha256_hash}") + + +def _register_services(container: DIContainer): + container.register_instance(AWSService, container.resolve(AWSService)) + container.register_instance(LocalMonkeyRunService, container.resolve(LocalMonkeyRunService)) + container.register_instance(IslandModeService, container.resolve(IslandModeService)) diff --git a/monkey/monkey_island/cc/services/island_mode_service.py b/monkey/monkey_island/cc/services/island_mode_service.py new file mode 100644 index 000000000..2d65d0067 --- /dev/null +++ b/monkey/monkey_island/cc/services/island_mode_service.py @@ -0,0 +1,46 @@ +from common.configuration import AgentConfiguration +from monkey_island.cc.models import IslandMode +from monkey_island.cc.repository import IAgentConfigurationRepository, ISimulationRepository + + +class IslandModeService: + def __init__( + self, + _agent_configuration_repository: IAgentConfigurationRepository, + simulation_repository: ISimulationRepository, + default_agent_configuration: AgentConfiguration, + default_ransomware_agent_configuration: AgentConfiguration, + ): + self._agent_configuration_repository = _agent_configuration_repository + self._simulation_repository = simulation_repository + self._default_agent_configuration = default_agent_configuration + self._default_ransomware_agent_configuration = default_ransomware_agent_configuration + + def get_mode(self): + """ + Get's the island's current mode + + :return The island's current mode + :raises RetrievalError: If the mode could not be retrieved + """ + return self._simulation_repository.get_mode() + + def set_mode(self, mode: IslandMode): + """ + Set the island's mode + + :param mode: The island's new mode + :raises StorageError: If the mode could not be saved + """ + self._simulation_repository.set_mode(mode) + self._set_configuration(mode) + + def _set_configuration(self, mode: IslandMode): + if mode == IslandMode.RANSOMWARE: + self._agent_configuration_repository.store_configuration( + self._default_ransomware_agent_configuration + ) + else: + self._agent_configuration_repository.store_configuration( + self._default_agent_configuration + ) diff --git a/monkey/monkey_island/cc/services/mode/__init__.py b/monkey/monkey_island/cc/services/mode/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/monkey/monkey_island/cc/services/mode/island_mode_service.py b/monkey/monkey_island/cc/services/mode/island_mode_service.py deleted file mode 100644 index 32a3d943d..000000000 --- a/monkey/monkey_island/cc/services/mode/island_mode_service.py +++ /dev/null @@ -1,17 +0,0 @@ -from monkey_island.cc.models.island_mode_model import IslandMode -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum - - -def set_mode(mode: IslandModeEnum): - IslandMode.drop_collection() - island_mode_model = IslandMode() - island_mode_model.mode = mode.value - island_mode_model.save() - - -def get_mode() -> str: - if IslandMode.objects: - mode = IslandMode.objects[0].mode - return mode - else: - return IslandModeEnum.UNSET.value diff --git a/monkey/monkey_island/cc/services/mode/mode_enum.py b/monkey/monkey_island/cc/services/mode/mode_enum.py deleted file mode 100644 index d6258bbff..000000000 --- a/monkey/monkey_island/cc/services/mode/mode_enum.py +++ /dev/null @@ -1,7 +0,0 @@ -from enum import Enum - - -class IslandModeEnum(Enum): - UNSET = "unset" - RANSOMWARE = "ransomware" - ADVANCED = "advanced" diff --git a/monkey/tests/monkey_island/__init__.py b/monkey/tests/monkey_island/__init__.py index 8f3654a6e..935752c37 100644 --- a/monkey/tests/monkey_island/__init__.py +++ b/monkey/tests/monkey_island/__init__.py @@ -2,3 +2,4 @@ from .single_file_repository import SingleFileRepository from .mock_file_repository import MockFileRepository, FILE_CONTENTS, FILE_NAME from .open_error_file_repository import OpenErrorFileRepository from .in_memory_agent_configuration_repository import InMemoryAgentConfigurationRepository +from .in_memory_simulation_configuration import InMemorySimulationRepository diff --git a/monkey/tests/monkey_island/in_memory_simulation_configuration.py b/monkey/tests/monkey_island/in_memory_simulation_configuration.py new file mode 100644 index 000000000..b2b90fb40 --- /dev/null +++ b/monkey/tests/monkey_island/in_memory_simulation_configuration.py @@ -0,0 +1,21 @@ +import dataclasses + +from monkey_island.cc.models import IslandMode, Simulation +from monkey_island.cc.repository import ISimulationRepository + + +class InMemorySimulationRepository(ISimulationRepository): + def __init__(self): + self._simulation = Simulation() + + def get_simulation(self) -> Simulation: + return self._simulation + + def save_simulation(self, simulation: Simulation): + self._simulation = simulation + + def get_mode(self) -> IslandMode: + return self._simulation.mode + + def set_mode(self, mode: IslandMode): + self._simulation = dataclasses.replace(self._simulation, mode=mode) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py new file mode 100644 index 000000000..d01775b98 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py @@ -0,0 +1,48 @@ +import pytest +from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository + +from monkey_island.cc.models import IslandMode, Simulation +from monkey_island.cc.repository import FileSimulationRepository, RetrievalError + + +@pytest.fixture +def simulation_repository(): + return FileSimulationRepository(SingleFileRepository()) + + +@pytest.mark.parametrize("mode", list(IslandMode)) +def test_save_simulation(simulation_repository, mode): + simulation = Simulation(mode) + simulation_repository.save_simulation(simulation) + + assert simulation_repository.get_simulation() == simulation + + +def test_get_default_simulation(simulation_repository): + default_simulation = Simulation() + + assert simulation_repository.get_simulation() == default_simulation + + +def test_set_mode(simulation_repository): + simulation_repository.set_mode(IslandMode.ADVANCED) + + assert simulation_repository.get_mode() == IslandMode.ADVANCED + + +def test_get_mode_default(simulation_repository): + assert simulation_repository.get_mode() == IslandMode.UNSET + + +def test_get_simulation_retrieval_error(): + simulation_repository = FileSimulationRepository(OpenErrorFileRepository()) + + with pytest.raises(RetrievalError): + simulation_repository.get_simulation() + + +def test_get_mode_retrieval_error(): + simulation_repository = FileSimulationRepository(OpenErrorFileRepository()) + + with pytest.raises(RetrievalError): + simulation_repository.get_mode() diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py index 2e603864d..780553eaa 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py @@ -1,28 +1,41 @@ import json +from unittest.mock import MagicMock import pytest -from tests.utils import raise_ +from tests.common import StubDIContainer +from tests.monkey_island import InMemorySimulationRepository -from monkey_island.cc.models.island_mode_model import IslandMode -from monkey_island.cc.resources import island_mode as island_mode_resource +from monkey_island.cc.models import IslandMode +from monkey_island.cc.repository import RetrievalError from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResource -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum +from monkey_island.cc.services import IslandModeService -@pytest.fixture(scope="function") -def uses_database(): - IslandMode.objects().delete() +class MockIslandModeService(IslandModeService): + def __init__(self): + self._simulation_repository = InMemorySimulationRepository() + + def get_mode(self) -> IslandMode: + return self._simulation_repository.get_mode() + + def set_mode(self, mode: IslandMode): + self._simulation_repository.set_mode(mode) + + +@pytest.fixture +def flask_client(build_flask_client): + container = StubDIContainer() + container.register_instance(IslandModeService, MockIslandModeService()) + + with build_flask_client(container) as flask_client: + yield flask_client @pytest.mark.parametrize( "mode", - [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value, IslandModeEnum.UNSET.value], + [IslandMode.RANSOMWARE.value, IslandMode.ADVANCED.value, IslandMode.UNSET.value], ) -def test_island_mode_post(flask_client, mode, monkeypatch): - monkeypatch.setattr( - "monkey_island.cc.resources.island_mode.update_config_on_mode_set", - lambda _: None, - ) +def test_island_mode_post(flask_client, mode): resp = flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True ) @@ -42,19 +55,25 @@ def test_island_mode_post__invalid_json(flask_client, invalid_json): assert resp.status_code == 400 -def test_island_mode_post__internal_server_error(monkeypatch, flask_client): - monkeypatch.setattr(island_mode_resource, "set_mode", lambda x: raise_(Exception())) +def test_island_mode_post__internal_server_error(build_flask_client): + mock_island_mode_service = MagicMock(spec=IslandModeService) + mock_island_mode_service.set_mode = MagicMock(side_effect=RetrievalError) + + container = StubDIContainer() + container.register_instance(IslandModeService, mock_island_mode_service) + + with build_flask_client(container) as flask_client: + resp = flask_client.post( + IslandModeResource.urls[0], + data=json.dumps({"mode": IslandMode.RANSOMWARE.value}), + follow_redirects=True, + ) - resp = flask_client.post( - IslandModeResource.urls[0], - data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}), - follow_redirects=True, - ) assert resp.status_code == 500 -@pytest.mark.parametrize("mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value]) -def test_island_mode_endpoint(flask_client, uses_database, mode): +@pytest.mark.parametrize("mode", [IslandMode.RANSOMWARE.value, IslandMode.ADVANCED.value]) +def test_island_mode_endpoint(flask_client, mode): flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True ) @@ -63,10 +82,10 @@ def test_island_mode_endpoint(flask_client, uses_database, mode): assert json.loads(resp.data)["mode"] == mode -def test_island_mode_endpoint__invalid_mode(flask_client, uses_database): +def test_island_mode_endpoint__invalid_mode(flask_client): resp_post = flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": "bogus_mode"}), follow_redirects=True ) resp_get = flask_client.get(IslandModeResource.urls[0], follow_redirects=True) assert resp_post.status_code == 422 - assert json.loads(resp_get.data)["mode"] == IslandModeEnum.UNSET.value + assert json.loads(resp_get.data)["mode"] == IslandMode.UNSET.value diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_config_manipulator.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_config_manipulator.py deleted file mode 100644 index 403b5aee3..000000000 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_config_manipulator.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest - -from monkey_island.cc.services.config_manipulator import update_config_on_mode_set -from monkey_island.cc.services.mode.mode_enum import IslandModeEnum - - -@pytest.mark.slow -@pytest.mark.usefixtures("uses_encryptor") -def test_update_config_on_mode_set_advanced(config, monkeypatch): - monkeypatch.setattr("monkey_island.cc.services.config.ConfigService.get_config", lambda: config) - monkeypatch.setattr( - "monkey_island.cc.services.config.ConfigService.update_config", - lambda config_json, should_encrypt: config_json, - ) - - mode = IslandModeEnum.ADVANCED - manipulated_config = update_config_on_mode_set(mode) - assert manipulated_config == config - - -@pytest.mark.slow -@pytest.mark.usefixtures("uses_encryptor") -def test_update_config_on_mode_set_ransomware(config, monkeypatch): - monkeypatch.setattr("monkey_island.cc.services.config.ConfigService.get_config", lambda: config) - monkeypatch.setattr( - "monkey_island.cc.services.config.ConfigService.update_config", - lambda config_json, should_encrypt: config_json, - ) - - mode = IslandModeEnum.RANSOMWARE - manipulated_config = update_config_on_mode_set(mode) - assert manipulated_config["monkey"]["post_breach"]["post_breach_actions"] == [] diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_island_mode_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_island_mode_service.py new file mode 100644 index 000000000..feb125b49 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_island_mode_service.py @@ -0,0 +1,42 @@ +import pytest +from tests.monkey_island import InMemoryAgentConfigurationRepository, InMemorySimulationRepository + +from common.configuration import DEFAULT_AGENT_CONFIGURATION, DEFAULT_RANSOMWARE_AGENT_CONFIGURATION +from monkey_island.cc.models import IslandMode +from monkey_island.cc.services import IslandModeService + + +@pytest.fixture +def agent_configuration_repository(): + return InMemoryAgentConfigurationRepository() + + +@pytest.fixture +def island_mode_service(agent_configuration_repository): + return IslandModeService( + agent_configuration_repository, + InMemorySimulationRepository(), + DEFAULT_AGENT_CONFIGURATION, + DEFAULT_RANSOMWARE_AGENT_CONFIGURATION, + ) + + +@pytest.mark.parametrize("mode", list(IslandMode)) +def test_set_mode(island_mode_service, mode): + island_mode_service.set_mode(mode) + assert island_mode_service.get_mode() == mode + + +@pytest.mark.parametrize( + "mode, expected_config", + [ + (IslandMode.UNSET, DEFAULT_AGENT_CONFIGURATION), + (IslandMode.ADVANCED, DEFAULT_AGENT_CONFIGURATION), + (IslandMode.RANSOMWARE, DEFAULT_RANSOMWARE_AGENT_CONFIGURATION), + ], +) +def test_set_mode_sets_config( + island_mode_service, agent_configuration_repository, mode, expected_config +): + island_mode_service.set_mode(mode) + assert agent_configuration_repository.get_configuration() == expected_config diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 126f6c596..22ddb97d5 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -195,7 +195,9 @@ _make_icmp_scan_configuration # unused method (monkey/common/configuration/agen _make_tcp_scan_configuration # unused method (monkey/common/configuration/agent_configuration.py:122) _make_network_scan_configuration # unused method (monkey/common/configuration/agent_configuration.py:110) _make_propagation_configuration # unused method (monkey/common/configuration/agent_configuration.py:167) -_make_agent_configuration # unused method (monkey/common/configuration/agent_configuration.py:192) + +# Models +_make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py:19 # TODO DELETE AFTER RESOURCE REFACTORING