diff --git a/monkey/monkey_island/cc/resources/island_mode.py b/monkey/monkey_island/cc/resources/island_mode.py index 5d7343459..1a9483eec 100644 --- a/monkey/monkey_island/cc/resources/island_mode.py +++ b/monkey/monkey_island/cc/resources/island_mode.py @@ -5,8 +5,7 @@ from flask import make_response, request from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.request_authentication import jwt_required -from monkey_island.cc.services.config_manipulator import update_config_on_mode_set -from monkey_island.cc.services.mode.island_mode_service import get_mode, set_mode +from monkey_island.cc.services import IslandModeService from monkey_island.cc.services.mode.mode_enum import IslandModeEnum logger = logging.getLogger(__name__) @@ -16,20 +15,16 @@ class IslandMode(AbstractResource): # API Spec: Instead of POST, this could just be PATCH urls = ["/api/island-mode"] + def __init__(self, island_mode_service: IslandModeService): + self._island_mode_service = island_mode_service + @jwt_required def post(self): try: body = json.loads(request.data) - mode_str = body.get("mode") + mode = IslandModeEnum(body.get("mode")) - mode = IslandModeEnum(mode_str) - set_mode(mode) - - if not update_config_on_mode_set(mode): - logger.error( - "Could not apply configuration changes per mode. " - "Using default advanced configuration." - ) + self._island_mode_service.set_mode(mode) return make_response({}, 200) except (AttributeError, json.decoder.JSONDecodeError): @@ -39,5 +34,5 @@ class IslandMode(AbstractResource): @jwt_required def get(self): - island_mode = get_mode() - return make_response({"mode": island_mode}, 200) + island_mode = self._island_mode_service.get_mode() + return make_response({"mode": island_mode.value}, 200) diff --git a/monkey/tests/monkey_island/__init__.py b/monkey/tests/monkey_island/__init__.py index 8f3654a6e..935752c37 100644 --- a/monkey/tests/monkey_island/__init__.py +++ b/monkey/tests/monkey_island/__init__.py @@ -2,3 +2,4 @@ from .single_file_repository import SingleFileRepository from .mock_file_repository import MockFileRepository, FILE_CONTENTS, FILE_NAME from .open_error_file_repository import OpenErrorFileRepository from .in_memory_agent_configuration_repository import InMemoryAgentConfigurationRepository +from .in_memory_simulation_configuration import InMemorySimulationRepository diff --git a/monkey/tests/monkey_island/in_memory_simulation_configuration.py b/monkey/tests/monkey_island/in_memory_simulation_configuration.py new file mode 100644 index 000000000..e6e037fcd --- /dev/null +++ b/monkey/tests/monkey_island/in_memory_simulation_configuration.py @@ -0,0 +1,22 @@ +import dataclasses + +from monkey_island.cc.models import Simulation +from monkey_island.cc.repository import ISimulationRepository +from monkey_island.cc.services.mode.mode_enum import IslandModeEnum + + +class InMemorySimulationRepository(ISimulationRepository): + def __init__(self): + self._simulation = Simulation() + + def get_simulation(self) -> Simulation: + return self._simulation + + def save_simulation(self, simulation: Simulation): + self._simulation = simulation + + def get_mode(self) -> IslandModeEnum: + return self._simulation.mode + + def set_mode(self, mode: IslandModeEnum): + self._simulation = dataclasses.replace(self._simulation, mode=mode) diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py index 2e603864d..fec2b6d37 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_island_mode.py @@ -1,28 +1,32 @@ import json +from unittest.mock import MagicMock import pytest -from tests.utils import raise_ +from tests.common import StubDIContainer +from tests.monkey_island import InMemorySimulationRepository -from monkey_island.cc.models.island_mode_model import IslandMode -from monkey_island.cc.resources import island_mode as island_mode_resource +from monkey_island.cc.repository import ISimulationRepository, RetrievalError from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResource +from monkey_island.cc.services import IslandModeService from monkey_island.cc.services.mode.mode_enum import IslandModeEnum -@pytest.fixture(scope="function") -def uses_database(): - IslandMode.objects().delete() +@pytest.fixture +def flask_client(build_flask_client): + container = StubDIContainer() + + container.register(ISimulationRepository, InMemorySimulationRepository) + container.register_instance(IslandModeService, container.resolve(IslandModeService)) + + with build_flask_client(container) as flask_client: + yield flask_client @pytest.mark.parametrize( "mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value, IslandModeEnum.UNSET.value], ) -def test_island_mode_post(flask_client, mode, monkeypatch): - monkeypatch.setattr( - "monkey_island.cc.resources.island_mode.update_config_on_mode_set", - lambda _: None, - ) +def test_island_mode_post(flask_client, mode): resp = flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True ) @@ -42,19 +46,25 @@ def test_island_mode_post__invalid_json(flask_client, invalid_json): assert resp.status_code == 400 -def test_island_mode_post__internal_server_error(monkeypatch, flask_client): - monkeypatch.setattr(island_mode_resource, "set_mode", lambda x: raise_(Exception())) +def test_island_mode_post__internal_server_error(build_flask_client): + mock_island_mode_service = MagicMock(spec=IslandModeService) + mock_island_mode_service.set_mode = MagicMock(side_effect=RetrievalError) + + container = StubDIContainer() + container.register_instance(IslandModeService, mock_island_mode_service) + + with build_flask_client(container) as flask_client: + resp = flask_client.post( + IslandModeResource.urls[0], + data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}), + follow_redirects=True, + ) - resp = flask_client.post( - IslandModeResource.urls[0], - data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}), - follow_redirects=True, - ) assert resp.status_code == 500 @pytest.mark.parametrize("mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value]) -def test_island_mode_endpoint(flask_client, uses_database, mode): +def test_island_mode_endpoint(flask_client, mode): flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True ) @@ -63,7 +73,7 @@ def test_island_mode_endpoint(flask_client, uses_database, mode): assert json.loads(resp.data)["mode"] == mode -def test_island_mode_endpoint__invalid_mode(flask_client, uses_database): +def test_island_mode_endpoint__invalid_mode(flask_client): resp_post = flask_client.post( IslandModeResource.urls[0], data=json.dumps({"mode": "bogus_mode"}), follow_redirects=True )