Island: Use new IslandModeService in IslandMode resource

This commit is contained in:
Mike Salvatore 2022-07-01 11:31:04 -04:00
parent 50a982672e
commit 8f7e7f98e7
4 changed files with 61 additions and 33 deletions

View File

@ -5,8 +5,7 @@ from flask import make_response, request
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required from monkey_island.cc.resources.request_authentication import jwt_required
from monkey_island.cc.services.config_manipulator import update_config_on_mode_set from monkey_island.cc.services import IslandModeService
from monkey_island.cc.services.mode.island_mode_service import get_mode, set_mode
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,20 +15,16 @@ class IslandMode(AbstractResource):
# API Spec: Instead of POST, this could just be PATCH # API Spec: Instead of POST, this could just be PATCH
urls = ["/api/island-mode"] urls = ["/api/island-mode"]
def __init__(self, island_mode_service: IslandModeService):
self._island_mode_service = island_mode_service
@jwt_required @jwt_required
def post(self): def post(self):
try: try:
body = json.loads(request.data) body = json.loads(request.data)
mode_str = body.get("mode") mode = IslandModeEnum(body.get("mode"))
mode = IslandModeEnum(mode_str) self._island_mode_service.set_mode(mode)
set_mode(mode)
if not update_config_on_mode_set(mode):
logger.error(
"Could not apply configuration changes per mode. "
"Using default advanced configuration."
)
return make_response({}, 200) return make_response({}, 200)
except (AttributeError, json.decoder.JSONDecodeError): except (AttributeError, json.decoder.JSONDecodeError):
@ -39,5 +34,5 @@ class IslandMode(AbstractResource):
@jwt_required @jwt_required
def get(self): def get(self):
island_mode = get_mode() island_mode = self._island_mode_service.get_mode()
return make_response({"mode": island_mode}, 200) return make_response({"mode": island_mode.value}, 200)

View File

@ -2,3 +2,4 @@ from .single_file_repository import SingleFileRepository
from .mock_file_repository import MockFileRepository, FILE_CONTENTS, FILE_NAME from .mock_file_repository import MockFileRepository, FILE_CONTENTS, FILE_NAME
from .open_error_file_repository import OpenErrorFileRepository from .open_error_file_repository import OpenErrorFileRepository
from .in_memory_agent_configuration_repository import InMemoryAgentConfigurationRepository from .in_memory_agent_configuration_repository import InMemoryAgentConfigurationRepository
from .in_memory_simulation_configuration import InMemorySimulationRepository

View File

@ -0,0 +1,22 @@
import dataclasses
from monkey_island.cc.models import Simulation
from monkey_island.cc.repository import ISimulationRepository
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
class InMemorySimulationRepository(ISimulationRepository):
def __init__(self):
self._simulation = Simulation()
def get_simulation(self) -> Simulation:
return self._simulation
def save_simulation(self, simulation: Simulation):
self._simulation = simulation
def get_mode(self) -> IslandModeEnum:
return self._simulation.mode
def set_mode(self, mode: IslandModeEnum):
self._simulation = dataclasses.replace(self._simulation, mode=mode)

View File

@ -1,28 +1,32 @@
import json import json
from unittest.mock import MagicMock
import pytest import pytest
from tests.utils import raise_ from tests.common import StubDIContainer
from tests.monkey_island import InMemorySimulationRepository
from monkey_island.cc.models.island_mode_model import IslandMode from monkey_island.cc.repository import ISimulationRepository, RetrievalError
from monkey_island.cc.resources import island_mode as island_mode_resource
from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResource from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResource
from monkey_island.cc.services import IslandModeService
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
@pytest.fixture(scope="function") @pytest.fixture
def uses_database(): def flask_client(build_flask_client):
IslandMode.objects().delete() container = StubDIContainer()
container.register(ISimulationRepository, InMemorySimulationRepository)
container.register_instance(IslandModeService, container.resolve(IslandModeService))
with build_flask_client(container) as flask_client:
yield flask_client
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mode", "mode",
[IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value, IslandModeEnum.UNSET.value], [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value, IslandModeEnum.UNSET.value],
) )
def test_island_mode_post(flask_client, mode, monkeypatch): def test_island_mode_post(flask_client, mode):
monkeypatch.setattr(
"monkey_island.cc.resources.island_mode.update_config_on_mode_set",
lambda _: None,
)
resp = flask_client.post( resp = flask_client.post(
IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True
) )
@ -42,19 +46,25 @@ def test_island_mode_post__invalid_json(flask_client, invalid_json):
assert resp.status_code == 400 assert resp.status_code == 400
def test_island_mode_post__internal_server_error(monkeypatch, flask_client): def test_island_mode_post__internal_server_error(build_flask_client):
monkeypatch.setattr(island_mode_resource, "set_mode", lambda x: raise_(Exception())) mock_island_mode_service = MagicMock(spec=IslandModeService)
mock_island_mode_service.set_mode = MagicMock(side_effect=RetrievalError)
container = StubDIContainer()
container.register_instance(IslandModeService, mock_island_mode_service)
with build_flask_client(container) as flask_client:
resp = flask_client.post(
IslandModeResource.urls[0],
data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}),
follow_redirects=True,
)
resp = flask_client.post(
IslandModeResource.urls[0],
data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}),
follow_redirects=True,
)
assert resp.status_code == 500 assert resp.status_code == 500
@pytest.mark.parametrize("mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value]) @pytest.mark.parametrize("mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value])
def test_island_mode_endpoint(flask_client, uses_database, mode): def test_island_mode_endpoint(flask_client, mode):
flask_client.post( flask_client.post(
IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True
) )
@ -63,7 +73,7 @@ def test_island_mode_endpoint(flask_client, uses_database, mode):
assert json.loads(resp.data)["mode"] == mode assert json.loads(resp.data)["mode"] == mode
def test_island_mode_endpoint__invalid_mode(flask_client, uses_database): def test_island_mode_endpoint__invalid_mode(flask_client):
resp_post = flask_client.post( resp_post = flask_client.post(
IslandModeResource.urls[0], data=json.dumps({"mode": "bogus_mode"}), follow_redirects=True IslandModeResource.urls[0], data=json.dumps({"mode": "bogus_mode"}), follow_redirects=True
) )