forked from p15670423/monkey
Island: Refactor simulation to use pydantic
This commit is contained in:
parent
6fe501195b
commit
f0b50b254e
|
@ -7,7 +7,7 @@ from .monkey import Monkey
|
||||||
from .monkey_ttl import MonkeyTtl
|
from .monkey_ttl import MonkeyTtl
|
||||||
from .pba_results import PbaResults
|
from .pba_results import PbaResults
|
||||||
from monkey_island.cc.models.report.report import Report
|
from monkey_island.cc.models.report.report import Report
|
||||||
from .simulation import Simulation, SimulationSchema, IslandMode
|
from .simulation import Simulation, IslandMode
|
||||||
from .user_credentials import UserCredentials
|
from .user_credentials import UserCredentials
|
||||||
from .machine import Machine, MachineID
|
from .machine import Machine, MachineID
|
||||||
from .communication_type import CommunicationType
|
from .communication_type import CommunicationType
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from marshmallow import Schema, post_load
|
from common.base_models import InfectionMonkeyBaseModel
|
||||||
from marshmallow_enum import EnumField
|
|
||||||
|
|
||||||
|
|
||||||
class IslandMode(Enum):
|
class IslandMode(Enum):
|
||||||
|
@ -13,14 +11,5 @@ class IslandMode(Enum):
|
||||||
ADVANCED = "advanced"
|
ADVANCED = "advanced"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
class Simulation(InfectionMonkeyBaseModel):
|
||||||
class Simulation:
|
|
||||||
mode: IslandMode = IslandMode.UNSET
|
mode: IslandMode = IslandMode.UNSET
|
||||||
|
|
||||||
|
|
||||||
class SimulationSchema(Schema):
|
|
||||||
mode = EnumField(IslandMode)
|
|
||||||
|
|
||||||
@post_load
|
|
||||||
def _make_simulation(self, data, **kwargs):
|
|
||||||
return Simulation(**data)
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import dataclasses
|
|
||||||
import io
|
import io
|
||||||
|
|
||||||
from monkey_island.cc import repository
|
from monkey_island.cc import repository
|
||||||
from monkey_island.cc.models import IslandMode, Simulation, SimulationSchema
|
from monkey_island.cc.models import IslandMode, Simulation
|
||||||
from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError
|
from monkey_island.cc.repository import IFileRepository, ISimulationRepository, RetrievalError
|
||||||
|
|
||||||
SIMULATION_STATE_FILE_NAME = "simulation_state.json"
|
SIMULATION_STATE_FILE_NAME = "simulation_state.json"
|
||||||
|
@ -11,21 +10,20 @@ SIMULATION_STATE_FILE_NAME = "simulation_state.json"
|
||||||
class FileSimulationRepository(ISimulationRepository):
|
class FileSimulationRepository(ISimulationRepository):
|
||||||
def __init__(self, file_repository: IFileRepository):
|
def __init__(self, file_repository: IFileRepository):
|
||||||
self._file_repository = file_repository
|
self._file_repository = file_repository
|
||||||
self._simulation_schema = SimulationSchema()
|
|
||||||
|
|
||||||
def get_simulation(self) -> Simulation:
|
def get_simulation(self) -> Simulation:
|
||||||
try:
|
try:
|
||||||
with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f:
|
with self._file_repository.open_file(SIMULATION_STATE_FILE_NAME) as f:
|
||||||
simulation_json = f.read().decode()
|
simulation_json = f.read().decode()
|
||||||
|
|
||||||
return self._simulation_schema.loads(simulation_json)
|
return Simulation.parse_raw(simulation_json)
|
||||||
except repository.FileNotFoundError:
|
except repository.FileNotFoundError:
|
||||||
return Simulation()
|
return Simulation()
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise RetrievalError(f"Error retrieving the simulation state: {err}")
|
raise RetrievalError(f"Error retrieving the simulation state: {err}")
|
||||||
|
|
||||||
def save_simulation(self, simulation: Simulation):
|
def save_simulation(self, simulation: Simulation):
|
||||||
simulation_json = self._simulation_schema.dumps(simulation)
|
simulation_json = simulation.json()
|
||||||
|
|
||||||
self._file_repository.save_file(
|
self._file_repository.save_file(
|
||||||
SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode())
|
SIMULATION_STATE_FILE_NAME, io.BytesIO(simulation_json.encode())
|
||||||
|
@ -35,6 +33,4 @@ class FileSimulationRepository(ISimulationRepository):
|
||||||
return self.get_simulation().mode
|
return self.get_simulation().mode
|
||||||
|
|
||||||
def set_mode(self, mode: IslandMode):
|
def set_mode(self, mode: IslandMode):
|
||||||
old_simulation = self.get_simulation()
|
self.save_simulation(Simulation(mode=mode))
|
||||||
new_simulation = dataclasses.replace(old_simulation, mode=mode)
|
|
||||||
self.save_simulation(new_simulation)
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
from monkey_island.cc.models import IslandMode, Simulation
|
from monkey_island.cc.models import IslandMode, Simulation
|
||||||
from monkey_island.cc.repository import ISimulationRepository
|
from monkey_island.cc.repository import ISimulationRepository
|
||||||
|
|
||||||
|
@ -18,4 +16,4 @@ class InMemorySimulationRepository(ISimulationRepository):
|
||||||
return self._simulation.mode
|
return self._simulation.mode
|
||||||
|
|
||||||
def set_mode(self, mode: IslandMode):
|
def set_mode(self, mode: IslandMode):
|
||||||
self._simulation = dataclasses.replace(self._simulation, mode=mode)
|
self._simulation = Simulation(mode=mode)
|
||||||
|
|
|
@ -12,7 +12,7 @@ def simulation_repository():
|
||||||
|
|
||||||
@pytest.mark.parametrize("mode", list(IslandMode))
|
@pytest.mark.parametrize("mode", list(IslandMode))
|
||||||
def test_save_simulation(simulation_repository, mode):
|
def test_save_simulation(simulation_repository, mode):
|
||||||
simulation = Simulation(mode)
|
simulation = Simulation(mode=mode)
|
||||||
simulation_repository.save_simulation(simulation)
|
simulation_repository.save_simulation(simulation)
|
||||||
|
|
||||||
assert simulation_repository.get_simulation() == simulation
|
assert simulation_repository.get_simulation() == simulation
|
||||||
|
|
Loading…
Reference in New Issue