Island: Implement AgentSignalsService.get_signals()

This commit is contained in:
Mike Salvatore 2022-09-22 13:35:16 -04:00
parent 8e45a71a15
commit a04a6a3cea
2 changed files with 143 additions and 3 deletions

View File

@ -1,13 +1,17 @@
from datetime import datetime
from typing import Optional
from common.types import AgentID
from monkey_island.cc.repository import ISimulationRepository
from monkey_island.cc.models import AgentSignals
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
class AgentSignalsService:
def __init__(self, simulation_repository: ISimulationRepository):
def __init__(
self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository
):
self._simulation_repository = simulation_repository
self._agent_repository = agent_repository
def get_signals(self, agent_id: AgentID) -> AgentSignals:
"""
@ -16,7 +20,24 @@ class AgentSignalsService:
:param agent_id: The ID of the agent whose signals need to be retrieved
:return: Signals sent to the relevant agent
"""
return AgentSignals(timestamp=datetime.now())
terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
return AgentSignals(terminate=terminate_timestamp)
def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
simulation = self._simulation_repository.get_simulation()
terminate_all_signal_time = simulation.terminate_signal_time
if terminate_all_signal_time is None:
return None
agent = self._agent_repository.get_agent_by_id(agent_id)
if agent.start_time <= terminate_all_signal_time:
return terminate_all_signal_time
progenitor = self._agent_repository.get_progenitor(agent)
if progenitor.start_time <= terminate_all_signal_time:
return terminate_all_signal_time
return None
def on_terminate_agents_signal(self, timestamp: datetime):
"""

View File

@ -0,0 +1,119 @@
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from common.types import AgentID
from monkey_island.cc.models import Agent, Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError
from monkey_island.cc.services import AgentSignalsService
AGENT_1 = Agent(
id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),
machine_id=1,
start_time=100,
parent_id=None,
)
AGENT_2 = Agent(
id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"),
machine_id=2,
start_time=200,
parent_id=AGENT_1.id,
)
AGENT_3 = Agent(
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
machine_id=3,
start_time=300,
parent_id=AGENT_2.id,
)
AGENTS = [AGENT_1, AGENT_2, AGENT_3]
@pytest.fixture
def mock_simulation_repository() -> IAgentRepository:
return MagicMock(spec=ISimulationRepository)
@pytest.fixture(scope="session")
def mock_agent_repository() -> IAgentRepository:
def get_agent_by_id(agent_id: AgentID) -> Agent:
for agent in AGENTS:
if agent.id == agent_id:
return agent
raise UnknownRecordError(str(agent_id))
agent_repository = MagicMock(spec=IAgentRepository)
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
return agent_repository
@pytest.fixture
def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService:
return AgentSignalsService(mock_simulation_repository, mock_agent_repository)
@pytest.mark.parametrize("agent", AGENTS)
def test_terminate_is_none(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=None)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate is None
@pytest.mark.parametrize("agent", AGENTS)
def test_agent_started_before_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 400
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP
@pytest.mark.parametrize("agent", AGENTS)
def test_agent_started_after_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 50
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate is None
@pytest.mark.parametrize("agent", AGENTS)
def test_progenitor_started_before_terminate(
agent,
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
TERMINATE_TIMESTAMP = 150
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP)
)
signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP