diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index 521e7f720..29c38674a 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -7,7 +7,7 @@ from .monkey import Monkey from .monkey_ttl import MonkeyTtl from .pba_results import PbaResults from monkey_island.cc.models.report.report import Report -from .simulation import Simulation, SimulationSchema, IslandMode +from .simulation import Simulation, IslandMode from .user_credentials import UserCredentials from .machine import Machine, MachineID from .communication_type import CommunicationType diff --git a/monkey/monkey_island/cc/models/simulation.py b/monkey/monkey_island/cc/models/simulation.py index 88b3ca25d..d04bee76b 100644 --- a/monkey/monkey_island/cc/models/simulation.py +++ b/monkey/monkey_island/cc/models/simulation.py @@ -1,10 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass from enum import Enum -from marshmallow import Schema, post_load -from marshmallow_enum import EnumField +from common.base_models import InfectionMonkeyBaseModel class IslandMode(Enum): @@ -13,14 +11,5 @@ class IslandMode(Enum): ADVANCED = "advanced" -@dataclass(frozen=True) -class Simulation: +class Simulation(InfectionMonkeyBaseModel): mode: IslandMode = IslandMode.UNSET - - -class SimulationSchema(Schema): - mode = EnumField(IslandMode) - - @post_load - def _make_simulation(self, data, **kwargs): - return Simulation(**data) diff --git a/monkey/monkey_island/cc/repository/file_simulation_repository.py b/monkey/monkey_island/cc/repository/file_simulation_repository.py index 0b4843fb6..5f82095c9 100644 --- a/monkey/monkey_island/cc/repository/file_simulation_repository.py +++ b/monkey/monkey_island/cc/repository/file_simulation_repository.py @@ -1,8 +1,7 @@ -import dataclasses import io from monkey_island.cc import repository -from monkey_island.cc.models import IslandMode, Simulation, SimulationSchema +from monkey_island.cc.models import IslandMode, Simulation from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError SIMULATION_STATE_FILE_NAME = "simulation_state.json" @@ -11,21 +10,20 @@ SIMULATION_STATE_FILE_NAME = "simulation_state.json" class FileSimulationRepository(ISimulationRepository): def __init__(self, file_repository: IFileRepository): self._file_repository = file_repository - self._simulation_schema = SimulationSchema() def get_simulation(self) -> Simulation: try: with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f: simulation_json = f.read().decode() - return self._simulation_schema.loads(simulation_json) + return Simulation.parse_raw(simulation_json) except repository.FileNotFoundError: return Simulation() except Exception as err: raise RetrievalError(f"Error retrieving the simulation state: {err}") def save_simulation(self, simulation: Simulation): - simulation_json = self._simulation_schema.dumps(simulation) + simulation_json = simulation.json() self._file_repository.save_file( SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode()) @@ -35,6 +33,4 @@ class FileSimulationRepository(ISimulationRepository): return self.get_simulation().mode def set_mode(self, mode: IslandMode): - old_simulation = self.get_simulation() - new_simulation = dataclasses.replace(old_simulation, mode=mode) - self.save_simulation(new_simulation) + self.save_simulation(Simulation(mode=mode)) diff --git a/monkey/tests/monkey_island/in_memory_simulation_configuration.py b/monkey/tests/monkey_island/in_memory_simulation_configuration.py index b2b90fb40..61b1064e2 100644 --- a/monkey/tests/monkey_island/in_memory_simulation_configuration.py +++ b/monkey/tests/monkey_island/in_memory_simulation_configuration.py @@ -1,5 +1,3 @@ -import dataclasses - from monkey_island.cc.models import IslandMode, Simulation from monkey_island.cc.repository import ISimulationRepository @@ -18,4 +16,4 @@ class InMemorySimulationRepository(ISimulationRepository): return self._simulation.mode def set_mode(self, mode: IslandMode): - self._simulation = dataclasses.replace(self._simulation, mode=mode) + self._simulation = Simulation(mode=mode) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py index d01775b98..6136aa5ac 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_simulation_repository.py @@ -12,7 +12,7 @@ def simulation_repository(): @pytest.mark.parametrize("mode", list(IslandMode)) def test_save_simulation(simulation_repository, mode): - simulation = Simulation(mode) + simulation = Simulation(mode=mode) simulation_repository.save_simulation(simulation) assert simulation_repository.get_simulation() == simulation