forked from p15670423/monkey
Island: Use new IslandModeService in IslandMode resource
This commit is contained in:
parent
50a982672e
commit
8f7e7f98e7
|
@ -5,8 +5,7 @@ from flask import make_response, request
|
|||
|
||||
from monkey_island.cc.resources.AbstractResource import AbstractResource
|
||||
from monkey_island.cc.resources.request_authentication import jwt_required
|
||||
from monkey_island.cc.services.config_manipulator import update_config_on_mode_set
|
||||
from monkey_island.cc.services.mode.island_mode_service import get_mode, set_mode
|
||||
from monkey_island.cc.services import IslandModeService
|
||||
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -16,20 +15,16 @@ class IslandMode(AbstractResource):
|
|||
# API Spec: Instead of POST, this could just be PATCH
|
||||
urls = ["/api/island-mode"]
|
||||
|
||||
def __init__(self, island_mode_service: IslandModeService):
|
||||
self._island_mode_service = island_mode_service
|
||||
|
||||
@jwt_required
|
||||
def post(self):
|
||||
try:
|
||||
body = json.loads(request.data)
|
||||
mode_str = body.get("mode")
|
||||
mode = IslandModeEnum(body.get("mode"))
|
||||
|
||||
mode = IslandModeEnum(mode_str)
|
||||
set_mode(mode)
|
||||
|
||||
if not update_config_on_mode_set(mode):
|
||||
logger.error(
|
||||
"Could not apply configuration changes per mode. "
|
||||
"Using default advanced configuration."
|
||||
)
|
||||
self._island_mode_service.set_mode(mode)
|
||||
|
||||
return make_response({}, 200)
|
||||
except (AttributeError, json.decoder.JSONDecodeError):
|
||||
|
@ -39,5 +34,5 @@ class IslandMode(AbstractResource):
|
|||
|
||||
@jwt_required
|
||||
def get(self):
|
||||
island_mode = get_mode()
|
||||
return make_response({"mode": island_mode}, 200)
|
||||
island_mode = self._island_mode_service.get_mode()
|
||||
return make_response({"mode": island_mode.value}, 200)
|
||||
|
|
|
@ -2,3 +2,4 @@ from .single_file_repository import SingleFileRepository
|
|||
from .mock_file_repository import MockFileRepository, FILE_CONTENTS, FILE_NAME
|
||||
from .open_error_file_repository import OpenErrorFileRepository
|
||||
from .in_memory_agent_configuration_repository import InMemoryAgentConfigurationRepository
|
||||
from .in_memory_simulation_configuration import InMemorySimulationRepository
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import dataclasses
|
||||
|
||||
from monkey_island.cc.models import Simulation
|
||||
from monkey_island.cc.repository import ISimulationRepository
|
||||
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
|
||||
|
||||
|
||||
class InMemorySimulationRepository(ISimulationRepository):
|
||||
def __init__(self):
|
||||
self._simulation = Simulation()
|
||||
|
||||
def get_simulation(self) -> Simulation:
|
||||
return self._simulation
|
||||
|
||||
def save_simulation(self, simulation: Simulation):
|
||||
self._simulation = simulation
|
||||
|
||||
def get_mode(self) -> IslandModeEnum:
|
||||
return self._simulation.mode
|
||||
|
||||
def set_mode(self, mode: IslandModeEnum):
|
||||
self._simulation = dataclasses.replace(self._simulation, mode=mode)
|
|
@ -1,28 +1,32 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from tests.utils import raise_
|
||||
from tests.common import StubDIContainer
|
||||
from tests.monkey_island import InMemorySimulationRepository
|
||||
|
||||
from monkey_island.cc.models.island_mode_model import IslandMode
|
||||
from monkey_island.cc.resources import island_mode as island_mode_resource
|
||||
from monkey_island.cc.repository import ISimulationRepository, RetrievalError
|
||||
from monkey_island.cc.resources.island_mode import IslandMode as IslandModeResource
|
||||
from monkey_island.cc.services import IslandModeService
|
||||
from monkey_island.cc.services.mode.mode_enum import IslandModeEnum
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def uses_database():
|
||||
IslandMode.objects().delete()
|
||||
@pytest.fixture
|
||||
def flask_client(build_flask_client):
|
||||
container = StubDIContainer()
|
||||
|
||||
container.register(ISimulationRepository, InMemorySimulationRepository)
|
||||
container.register_instance(IslandModeService, container.resolve(IslandModeService))
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
yield flask_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
[IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value, IslandModeEnum.UNSET.value],
|
||||
)
|
||||
def test_island_mode_post(flask_client, mode, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"monkey_island.cc.resources.island_mode.update_config_on_mode_set",
|
||||
lambda _: None,
|
||||
)
|
||||
def test_island_mode_post(flask_client, mode):
|
||||
resp = flask_client.post(
|
||||
IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True
|
||||
)
|
||||
|
@ -42,19 +46,25 @@ def test_island_mode_post__invalid_json(flask_client, invalid_json):
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_island_mode_post__internal_server_error(monkeypatch, flask_client):
|
||||
monkeypatch.setattr(island_mode_resource, "set_mode", lambda x: raise_(Exception()))
|
||||
def test_island_mode_post__internal_server_error(build_flask_client):
|
||||
mock_island_mode_service = MagicMock(spec=IslandModeService)
|
||||
mock_island_mode_service.set_mode = MagicMock(side_effect=RetrievalError)
|
||||
|
||||
container = StubDIContainer()
|
||||
container.register_instance(IslandModeService, mock_island_mode_service)
|
||||
|
||||
with build_flask_client(container) as flask_client:
|
||||
resp = flask_client.post(
|
||||
IslandModeResource.urls[0],
|
||||
data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}),
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
resp = flask_client.post(
|
||||
IslandModeResource.urls[0],
|
||||
data=json.dumps({"mode": IslandModeEnum.RANSOMWARE.value}),
|
||||
follow_redirects=True,
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", [IslandModeEnum.RANSOMWARE.value, IslandModeEnum.ADVANCED.value])
|
||||
def test_island_mode_endpoint(flask_client, uses_database, mode):
|
||||
def test_island_mode_endpoint(flask_client, mode):
|
||||
flask_client.post(
|
||||
IslandModeResource.urls[0], data=json.dumps({"mode": mode}), follow_redirects=True
|
||||
)
|
||||
|
@ -63,7 +73,7 @@ def test_island_mode_endpoint(flask_client, uses_database, mode):
|
|||
assert json.loads(resp.data)["mode"] == mode
|
||||
|
||||
|
||||
def test_island_mode_endpoint__invalid_mode(flask_client, uses_database):
|
||||
def test_island_mode_endpoint__invalid_mode(flask_client):
|
||||
resp_post = flask_client.post(
|
||||
IslandModeResource.urls[0], data=json.dumps({"mode": "bogus_mode"}), follow_redirects=True
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue