diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index 52e6f6873..925a60dd5 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -4,3 +4,5 @@ from .i_agent_binary_repository import IAgentBinaryRepository from .agent_binary_repository import AgentBinaryRepository from .i_agent_configuration_repository import IAgentConfigurationRepository from .file_agent_configuration_repository import FileAgentConfigurationRepository +from .i_simulation_repository import ISimulationRepository +from .file_simulation_repository import FileSimulationRepository diff --git a/monkey/monkey_island/cc/repository/file_simulation_repository.py b/monkey/monkey_island/cc/repository/file_simulation_repository.py new file mode 100644 index 000000000..fed7ce0eb --- /dev/null +++ b/monkey/monkey_island/cc/repository/file_simulation_repository.py @@ -0,0 +1,41 @@ +import dataclasses +import io + +from monkey_island.cc import repository +from monkey_island.cc.models import Simulation, SimulationSchema +from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError +from monkey_island.cc.services.mode.mode_enum import IslandModeEnum + +SIMULATION_STATE_FILE_NAME = "simulation_state.json" + + +class FileSimulationRepository(ISimulationRepository): + def __init__(self, file_repository: IFileRepository): + self._file_repository = file_repository + self._simulation_schema = SimulationSchema() + + def save_simulation(self, simulation: Simulation): + simulation_json = self._simulation_schema.dumps(simulation) + + self._file_repository.save_file( + SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode()) + ) + + def get_simulation(self) -> Simulation: + try: + with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f: + simulation_json = f.read().decode() + + return self._simulation_schema.loads(simulation_json) + except repository.FileNotFoundError: + return Simulation() + except Exception as err: + raise RetrievalError(f"Error retrieving the simulation state: {err}") + + def get_mode(self) -> IslandModeEnum: + return self.get_simulation().mode + + def set_mode(self, mode: IslandModeEnum): + old_simulation = self.get_simulation() + new_simulation = dataclasses.replace(old_simulation, mode=mode) + self.save_simulation(new_simulation) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py new file mode 100644 index 000000000..b5b0f63da --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py @@ -0,0 +1,52 @@ +import pytest +from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository + +from monkey_island.cc.models import Simulation +from monkey_island.cc.repository import FileSimulationRepository, RetrievalError +from monkey_island.cc.services.mode.mode_enum import IslandModeEnum + + +@pytest.fixture +def simulation_repository(): + return FileSimulationRepository(SingleFileRepository()) + + +def test_save_simulation(simulation_repository): + simulation = Simulation(IslandModeEnum.RANSOMWARE) + + old_simulation = simulation_repository.get_simulation() + simulation_repository.save_simulation(simulation) + new_simulation = simulation_repository.get_simulation() + + assert old_simulation != simulation + assert new_simulation == simulation + + +def test_get_default_simulation(simulation_repository): + default_simulation = Simulation() + + assert simulation_repository.get_simulation() == default_simulation + + +def test_set_mode(simulation_repository): + simulation_repository.set_mode(IslandModeEnum.ADVANCED) + + assert simulation_repository.get_mode() == IslandModeEnum.ADVANCED + + +def test_get_mode_default(simulation_repository): + assert simulation_repository.get_mode() == IslandModeEnum.UNSET + + +def test_get_simulation_retrieval_error(): + simulation_repository = FileSimulationRepository(OpenErrorFileRepository()) + + with pytest.raises(RetrievalError): + simulation_repository.get_simulation() + + +def test_get_mode_retrieval_error(): + simulation_repository = FileSimulationRepository(OpenErrorFileRepository()) + + with pytest.raises(RetrievalError): + simulation_repository.get_mode()