forked from p15670423/monkey
Island: Refactor simulation to use pydantic
This commit is contained in:
parent
6fe501195b
commit
f0b50b254e
|
@ -7,7 +7,7 @@ from .monkey import Monkey
|
|||
from .monkey_ttl import MonkeyTtl
|
||||
from .pba_results import PbaResults
|
||||
from monkey_island.cc.models.report.report import Report
|
||||
from .simulation import Simulation, SimulationSchema, IslandMode
|
||||
from .simulation import Simulation, IslandMode
|
||||
from .user_credentials import UserCredentials
|
||||
from .machine import Machine, MachineID
|
||||
from .communication_type import CommunicationType
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from marshmallow import Schema, post_load
|
||||
from marshmallow_enum import EnumField
|
||||
from common.base_models import InfectionMonkeyBaseModel
|
||||
|
||||
|
||||
class IslandMode(Enum):
|
||||
|
@ -13,14 +11,5 @@ class IslandMode(Enum):
|
|||
ADVANCED = "advanced"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Simulation:
|
||||
class Simulation(InfectionMonkeyBaseModel):
|
||||
mode: IslandMode = IslandMode.UNSET
|
||||
|
||||
|
||||
class SimulationSchema(Schema):
|
||||
mode = EnumField(IslandMode)
|
||||
|
||||
@post_load
|
||||
def _make_simulation(self, data, **kwargs):
|
||||
return Simulation(**data)
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import dataclasses
|
||||
import io
|
||||
|
||||
from monkey_island.cc import repository
|
||||
from monkey_island.cc.models import IslandMode, Simulation, SimulationSchema
|
||||
from monkey_island.cc.models import IslandMode, Simulation
|
||||
from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError
|
||||
|
||||
SIMULATION_STATE_FILE_NAME = "simulation_state.json"
|
||||
|
@ -11,21 +10,20 @@ SIMULATION_STATE_FILE_NAME = "simulation_state.json"
|
|||
class FileSimulationRepository(ISimulationRepository):
|
||||
def __init__(self, file_repository: IFileRepository):
|
||||
self._file_repository = file_repository
|
||||
self._simulation_schema = SimulationSchema()
|
||||
|
||||
def get_simulation(self) -> Simulation:
|
||||
try:
|
||||
with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f:
|
||||
simulation_json = f.read().decode()
|
||||
|
||||
return self._simulation_schema.loads(simulation_json)
|
||||
return Simulation.parse_raw(simulation_json)
|
||||
except repository.FileNotFoundError:
|
||||
return Simulation()
|
||||
except Exception as err:
|
||||
raise RetrievalError(f"Error retrieving the simulation state: {err}")
|
||||
|
||||
def save_simulation(self, simulation: Simulation):
|
||||
simulation_json = self._simulation_schema.dumps(simulation)
|
||||
simulation_json = simulation.json()
|
||||
|
||||
self._file_repository.save_file(
|
||||
SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode())
|
||||
|
@ -35,6 +33,4 @@ class FileSimulationRepository(ISimulationRepository):
|
|||
return self.get_simulation().mode
|
||||
|
||||
def set_mode(self, mode: IslandMode):
|
||||
old_simulation = self.get_simulation()
|
||||
new_simulation = dataclasses.replace(old_simulation, mode=mode)
|
||||
self.save_simulation(new_simulation)
|
||||
self.save_simulation(Simulation(mode=mode))
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import dataclasses
|
||||
|
||||
from monkey_island.cc.models import IslandMode, Simulation
|
||||
from monkey_island.cc.repository import ISimulationRepository
|
||||
|
||||
|
@ -18,4 +16,4 @@ class InMemorySimulationRepository(ISimulationRepository):
|
|||
return self._simulation.mode
|
||||
|
||||
def set_mode(self, mode: IslandMode):
|
||||
self._simulation = dataclasses.replace(self._simulation, mode=mode)
|
||||
self._simulation = Simulation(mode=mode)
|
||||
|
|
|
@ -12,7 +12,7 @@ def simulation_repository():
|
|||
|
||||
@pytest.mark.parametrize("mode", list(IslandMode))
|
||||
def test_save_simulation(simulation_repository, mode):
|
||||
simulation = Simulation(mode)
|
||||
simulation = Simulation(mode=mode)
|
||||
simulation_repository.save_simulation(simulation)
|
||||
|
||||
assert simulation_repository.get_simulation() == simulation
|
||||
|
|
Loading…
Reference in New Issue