Island: Refactor simulation to use pydantic

This commit is contained in:
vakaris_zilius 2022-09-07 11:44:47 +00:00
parent 6fe501195b
commit f0b50b254e
5 changed files with 9 additions and 26 deletions

View File

@ -7,7 +7,7 @@ from .monkey import Monkey
from .monkey_ttl import MonkeyTtl from .monkey_ttl import MonkeyTtl
from .pba_results import PbaResults from .pba_results import PbaResults
from monkey_island.cc.models.report.report import Report from monkey_island.cc.models.report.report import Report
from .simulation import Simulation, SimulationSchema, IslandMode from .simulation import Simulation, IslandMode
from .user_credentials import UserCredentials from .user_credentials import UserCredentials
from .machine import Machine, MachineID from .machine import Machine, MachineID
from .communication_type import CommunicationType from .communication_type import CommunicationType

View File

@ -1,10 +1,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from enum import Enum from enum import Enum
from marshmallow import Schema, post_load from common.base_models import InfectionMonkeyBaseModel
from marshmallow_enum import EnumField
class IslandMode(Enum): class IslandMode(Enum):
@ -13,14 +11,5 @@ class IslandMode(Enum):
ADVANCED = "advanced" ADVANCED = "advanced"
@dataclass(frozen=True) class Simulation(InfectionMonkeyBaseModel):
class Simulation:
mode: IslandMode = IslandMode.UNSET mode: IslandMode = IslandMode.UNSET
class SimulationSchema(Schema):
mode = EnumField(IslandMode)
@post_load
def _make_simulation(self, data, **kwargs):
return Simulation(**data)

View File

@ -1,8 +1,7 @@
import dataclasses
import io import io
from monkey_island.cc import repository from monkey_island.cc import repository
from monkey_island.cc.models import IslandMode, Simulation, SimulationSchema from monkey_island.cc.models import IslandMode, Simulation
from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError
SIMULATION_STATE_FILE_NAME = "simulation_state.json" SIMULATION_STATE_FILE_NAME = "simulation_state.json"
@ -11,21 +10,20 @@ SIMULATION_STATE_FILE_NAME = "simulation_state.json"
class FileSimulationRepository(ISimulationRepository): class FileSimulationRepository(ISimulationRepository):
def __init__(self, file_repository: IFileRepository): def __init__(self, file_repository: IFileRepository):
self._file_repository = file_repository self._file_repository = file_repository
self._simulation_schema = SimulationSchema()
def get_simulation(self) -> Simulation: def get_simulation(self) -> Simulation:
try: try:
with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f: with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f:
simulation_json = f.read().decode() simulation_json = f.read().decode()
return self._simulation_schema.loads(simulation_json) return Simulation.parse_raw(simulation_json)
except repository.FileNotFoundError: except repository.FileNotFoundError:
return Simulation() return Simulation()
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving the simulation state: {err}") raise RetrievalError(f"Error retrieving the simulation state: {err}")
def save_simulation(self, simulation: Simulation): def save_simulation(self, simulation: Simulation):
simulation_json = self._simulation_schema.dumps(simulation) simulation_json = simulation.json()
self._file_repository.save_file( self._file_repository.save_file(
SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode()) SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode())
@ -35,6 +33,4 @@ class FileSimulationRepository(ISimulationRepository):
return self.get_simulation().mode return self.get_simulation().mode
def set_mode(self, mode: IslandMode): def set_mode(self, mode: IslandMode):
old_simulation = self.get_simulation() self.save_simulation(Simulation(mode=mode))
new_simulation = dataclasses.replace(old_simulation, mode=mode)
self.save_simulation(new_simulation)

View File

@ -1,5 +1,3 @@
import dataclasses
from monkey_island.cc.models import IslandMode, Simulation from monkey_island.cc.models import IslandMode, Simulation
from monkey_island.cc.repository import ISimulationRepository from monkey_island.cc.repository import ISimulationRepository
@ -18,4 +16,4 @@ class InMemorySimulationRepository(ISimulationRepository):
return self._simulation.mode return self._simulation.mode
def set_mode(self, mode: IslandMode): def set_mode(self, mode: IslandMode):
self._simulation = dataclasses.replace(self._simulation, mode=mode) self._simulation = Simulation(mode=mode)

View File

@ -12,7 +12,7 @@ def simulation_repository():
@pytest.mark.parametrize("mode", list(IslandMode)) @pytest.mark.parametrize("mode", list(IslandMode))
def test_save_simulation(simulation_repository, mode): def test_save_simulation(simulation_repository, mode):
simulation = Simulation(mode) simulation = Simulation(mode=mode)
simulation_repository.save_simulation(simulation) simulation_repository.save_simulation(simulation)
assert simulation_repository.get_simulation() == simulation assert simulation_repository.get_simulation() == simulation