forked from p15670423/monkey
Island: Implement AgentSignalsService.get_signals()
This commit is contained in:
parent
8e45a71a15
commit
a04a6a3cea
|
@ -1,13 +1,17 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from common.types import AgentID
|
||||
from monkey_island.cc.repository import ISimulationRepository
|
||||
from monkey_island.cc.models import AgentSignals
|
||||
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
|
||||
|
||||
|
||||
class AgentSignalsService:
|
||||
def __init__(self, simulation_repository: ISimulationRepository):
|
||||
def __init__(
|
||||
self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository
|
||||
):
|
||||
self._simulation_repository = simulation_repository
|
||||
self._agent_repository = agent_repository
|
||||
|
||||
def get_signals(self, agent_id: AgentID) -> AgentSignals:
|
||||
"""
|
||||
|
@ -16,7 +20,24 @@ class AgentSignalsService:
|
|||
:param agent_id: The ID of the agent whose signals need to be retrieved
|
||||
:return: Signals sent to the relevant agent
|
||||
"""
|
||||
return AgentSignals(timestamp=datetime.now())
|
||||
terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
|
||||
return AgentSignals(terminate=terminate_timestamp)
|
||||
|
||||
def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
|
||||
simulation = self._simulation_repository.get_simulation()
|
||||
terminate_all_signal_time = simulation.terminate_signal_time
|
||||
if terminate_all_signal_time is None:
|
||||
return None
|
||||
|
||||
agent = self._agent_repository.get_agent_by_id(agent_id)
|
||||
if agent.start_time <= terminate_all_signal_time:
|
||||
return terminate_all_signal_time
|
||||
|
||||
progenitor = self._agent_repository.get_progenitor(agent)
|
||||
if progenitor.start_time <= terminate_all_signal_time:
|
||||
return terminate_all_signal_time
|
||||
|
||||
return None
|
||||
|
||||
def on_terminate_agents_signal(self, timestamp: datetime):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from common.types import AgentID
|
||||
from monkey_island.cc.models import Agent, Simulation
|
||||
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError
|
||||
from monkey_island.cc.services import AgentSignalsService
|
||||
|
||||
AGENT_1 = Agent(
|
||||
id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),
|
||||
machine_id=1,
|
||||
start_time=100,
|
||||
parent_id=None,
|
||||
)
|
||||
|
||||
AGENT_2 = Agent(
|
||||
id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"),
|
||||
machine_id=2,
|
||||
start_time=200,
|
||||
parent_id=AGENT_1.id,
|
||||
)
|
||||
|
||||
AGENT_3 = Agent(
|
||||
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
|
||||
machine_id=3,
|
||||
start_time=300,
|
||||
parent_id=AGENT_2.id,
|
||||
)
|
||||
AGENTS = [AGENT_1, AGENT_2, AGENT_3]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_simulation_repository() -> IAgentRepository:
|
||||
return MagicMock(spec=ISimulationRepository)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_agent_repository() -> IAgentRepository:
|
||||
def get_agent_by_id(agent_id: AgentID) -> Agent:
|
||||
for agent in AGENTS:
|
||||
if agent.id == agent_id:
|
||||
return agent
|
||||
|
||||
raise UnknownRecordError(str(agent_id))
|
||||
|
||||
agent_repository = MagicMock(spec=IAgentRepository)
|
||||
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
|
||||
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
|
||||
|
||||
return agent_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService:
|
||||
return AgentSignalsService(mock_simulation_repository, mock_agent_repository)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent", AGENTS)
|
||||
def test_terminate_is_none(
|
||||
agent,
|
||||
agent_signals_service: AgentSignalsService,
|
||||
mock_simulation_repository: ISimulationRepository,
|
||||
):
|
||||
mock_simulation_repository.get_simulation = MagicMock(
|
||||
return_value=Simulation(terminate_signal_time=None)
|
||||
)
|
||||
|
||||
signals = agent_signals_service.get_signals(agent.id)
|
||||
assert signals.terminate is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent", AGENTS)
|
||||
def test_agent_started_before_terminate(
|
||||
agent,
|
||||
agent_signals_service: AgentSignalsService,
|
||||
mock_simulation_repository: ISimulationRepository,
|
||||
):
|
||||
TERMINATE_TIMESTAMP = 400
|
||||
mock_simulation_repository.get_simulation = MagicMock(
|
||||
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||
)
|
||||
|
||||
signals = agent_signals_service.get_signals(agent.id)
|
||||
|
||||
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent", AGENTS)
|
||||
def test_agent_started_after_terminate(
|
||||
agent,
|
||||
agent_signals_service: AgentSignalsService,
|
||||
mock_simulation_repository: ISimulationRepository,
|
||||
):
|
||||
TERMINATE_TIMESTAMP = 50
|
||||
mock_simulation_repository.get_simulation = MagicMock(
|
||||
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||
)
|
||||
|
||||
signals = agent_signals_service.get_signals(agent.id)
|
||||
|
||||
assert signals.terminate is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent", AGENTS)
|
||||
def test_progenitor_started_before_terminate(
|
||||
agent,
|
||||
agent_signals_service: AgentSignalsService,
|
||||
mock_simulation_repository: ISimulationRepository,
|
||||
):
|
||||
TERMINATE_TIMESTAMP = 150
|
||||
mock_simulation_repository.get_simulation = MagicMock(
|
||||
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
|
||||
)
|
||||
|
||||
signals = agent_signals_service.get_signals(agent.id)
|
||||
|
||||
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
|
Loading…
Reference in New Issue