forked from p15670423/monkey
Island: Implement AgentSignalsService.get_signals()
This commit is contained in:
parent
8e45a71a15
commit
a04a6a3cea
|
@ -1,13 +1,17 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from common.types import AgentID
|
from common.types import AgentID
|
||||||
from monkey_island.cc.repository import ISimulationRepository
|
|
||||||
from monkey_island.cc.models import AgentSignals
|
from monkey_island.cc.models import AgentSignals
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
|
||||||
|
|
||||||
|
|
||||||
class AgentSignalsService:
|
class AgentSignalsService:
|
||||||
def __init__(self, simulation_repository: ISimulationRepository):
|
def __init__(
|
||||||
|
self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository
|
||||||
|
):
|
||||||
self._simulation_repository = simulation_repository
|
self._simulation_repository = simulation_repository
|
||||||
|
self._agent_repository = agent_repository
|
||||||
|
|
||||||
def get_signals(self, agent_id: AgentID) -> AgentSignals:
|
def get_signals(self, agent_id: AgentID) -> AgentSignals:
|
||||||
"""
|
"""
|
||||||
|
@ -16,7 +20,24 @@ class AgentSignalsService:
|
||||||
:param agent_id: The ID of the agent whose signals need to be retrieved
|
:param agent_id: The ID of the agent whose signals need to be retrieved
|
||||||
:return: Signals sent to the relevant agent
|
:return: Signals sent to the relevant agent
|
||||||
"""
|
"""
|
||||||
return AgentSignals(timestamp=datetime.now())
|
terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
|
||||||
|
return AgentSignals(terminate=terminate_timestamp)
|
||||||
|
|
||||||
|
def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
|
||||||
|
simulation = self._simulation_repository.get_simulation()
|
||||||
|
terminate_all_signal_time = simulation.terminate_signal_time
|
||||||
|
if terminate_all_signal_time is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent = self._agent_repository.get_agent_by_id(agent_id)
|
||||||
|
if agent.start_time <= terminate_all_signal_time:
|
||||||
|
return terminate_all_signal_time
|
||||||
|
|
||||||
|
progenitor = self._agent_repository.get_progenitor(agent)
|
||||||
|
if progenitor.start_time <= terminate_all_signal_time:
|
||||||
|
return terminate_all_signal_time
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def on_terminate_agents_signal(self, timestamp: datetime):
|
def on_terminate_agents_signal(self, timestamp: datetime):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.types import AgentID
|
||||||
|
from monkey_island.cc.models import Agent, Simulation
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError
|
||||||
|
from monkey_island.cc.services import AgentSignalsService
|
||||||
|
|
||||||
|
AGENT_1 = Agent(
|
||||||
|
id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),
|
||||||
|
machine_id=1,
|
||||||
|
start_time=100,
|
||||||
|
parent_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_2 = Agent(
|
||||||
|
id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"),
|
||||||
|
machine_id=2,
|
||||||
|
start_time=200,
|
||||||
|
parent_id=AGENT_1.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_3 = Agent(
|
||||||
|
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
|
||||||
|
machine_id=3,
|
||||||
|
start_time=300,
|
||||||
|
parent_id=AGENT_2.id,
|
||||||
|
)
|
||||||
|
AGENTS = [AGENT_1, AGENT_2, AGENT_3]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_simulation_repository() -> IAgentRepository:
|
||||||
|
return MagicMock(spec=ISimulationRepository)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mock_agent_repository() -> IAgentRepository:
|
||||||
|
def get_agent_by_id(agent_id: AgentID) -> Agent:
|
||||||
|
for agent in AGENTS:
|
||||||
|
if agent.id == agent_id:
|
||||||
|
return agent
|
||||||
|
|
||||||
|
raise UnknownRecordError(str(agent_id))
|
||||||
|
|
||||||
|
agent_repository = MagicMock(spec=IAgentRepository)
|
||||||
|
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
|
||||||
|
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
|
||||||
|
|
||||||
|
return agent_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService:
|
||||||
|
return AgentSignalsService(mock_simulation_repository, mock_agent_repository)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_terminate_is_none(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
assert signals.terminate is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_agent_started_before_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 400
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_agent_started_after_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 50
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent", AGENTS)
|
||||||
|
def test_progenitor_started_before_terminate(
|
||||||
|
agent,
|
||||||
|
agent_signals_service: AgentSignalsService,
|
||||||
|
mock_simulation_repository: ISimulationRepository,
|
||||||
|
):
|
||||||
|
TERMINATE_TIMESTAMP = 150
|
||||||
|
mock_simulation_repository.get_simulation = MagicMock(
|
||||||
|
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals = agent_signals_service.get_signals(agent.id)
|
||||||
|
|
||||||
|
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
Loading…
Reference in New Issue