From a04a6a3cea0c5f6b06f5626eae3fc4a33c74662f Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 22 Sep 2022 13:35:16 -0400 Subject: [PATCH] Island: Implement AgentSignalsService.get_signals() --- .../cc/services/agent_signals_service.py | 27 +++- .../cc/services/test_agent_signals_service.py | 119 ++++++++++++++++++ 2 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py diff --git a/monkey/monkey_island/cc/services/agent_signals_service.py b/monkey/monkey_island/cc/services/agent_signals_service.py index 4a3c4ef74..23a15dba9 100644 --- a/monkey/monkey_island/cc/services/agent_signals_service.py +++ b/monkey/monkey_island/cc/services/agent_signals_service.py @@ -1,13 +1,17 @@ from datetime import datetime +from typing import Optional from common.types import AgentID -from monkey_island.cc.repository import ISimulationRepository from monkey_island.cc.models import AgentSignals +from monkey_island.cc.repository import IAgentRepository, ISimulationRepository class AgentSignalsService: - def __init__(self, simulation_repository: ISimulationRepository): + def __init__( + self, simulation_repository: ISimulationRepository, agent_repository: IAgentRepository + ): self._simulation_repository = simulation_repository + self._agent_repository = agent_repository def get_signals(self, agent_id: AgentID) -> AgentSignals: """ @@ -16,7 +20,24 @@ class AgentSignalsService: :param agent_id: The ID of the agent whose signals need to be retrieved :return: Signals sent to the relevant agent """ - return AgentSignals(timestamp=datetime.now()) + terminate_timestamp = self._get_terminate_signal_timestamp(agent_id) + return AgentSignals(terminate=terminate_timestamp) + + def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]: + simulation = self._simulation_repository.get_simulation() + terminate_all_signal_time = simulation.terminate_signal_time + if terminate_all_signal_time is None: + return None + + agent = self._agent_repository.get_agent_by_id(agent_id) + if agent.start_time <= terminate_all_signal_time: + return terminate_all_signal_time + + progenitor = self._agent_repository.get_progenitor(agent) + if progenitor.start_time <= terminate_all_signal_time: + return terminate_all_signal_time + + return None def on_terminate_agents_signal(self, timestamp: datetime): """ diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py new file mode 100644 index 000000000..16aab943f --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py @@ -0,0 +1,119 @@ +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +from common.types import AgentID +from monkey_island.cc.models import Agent, Simulation +from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError +from monkey_island.cc.services import AgentSignalsService + +AGENT_1 = Agent( + id=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"), + machine_id=1, + start_time=100, + parent_id=None, +) + +AGENT_2 = Agent( + id=UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10"), + machine_id=2, + start_time=200, + parent_id=AGENT_1.id, +) + +AGENT_3 = Agent( + id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"), + machine_id=3, + start_time=300, + parent_id=AGENT_2.id, +) +AGENTS = [AGENT_1, AGENT_2, AGENT_3] + + +@pytest.fixture +def mock_simulation_repository() -> IAgentRepository: + return MagicMock(spec=ISimulationRepository) + + +@pytest.fixture(scope="session") +def mock_agent_repository() -> IAgentRepository: + def get_agent_by_id(agent_id: AgentID) -> Agent: + for agent in AGENTS: + if agent.id == agent_id: + return agent + + raise UnknownRecordError(str(agent_id)) + + agent_repository = MagicMock(spec=IAgentRepository) + agent_repository.get_progenitor = MagicMock(return_value=AGENT_1) + agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id) + + return agent_repository + + +@pytest.fixture +def agent_signals_service(mock_simulation_repository, mock_agent_repository) -> AgentSignalsService: + return AgentSignalsService(mock_simulation_repository, mock_agent_repository) + + +@pytest.mark.parametrize("agent", AGENTS) +def test_terminate_is_none( + agent, + agent_signals_service: AgentSignalsService, + mock_simulation_repository: ISimulationRepository, +): + mock_simulation_repository.get_simulation = MagicMock( + return_value=Simulation(terminate_signal_time=None) + ) + + signals = agent_signals_service.get_signals(agent.id) + assert signals.terminate is None + + +@pytest.mark.parametrize("agent", AGENTS) +def test_agent_started_before_terminate( + agent, + agent_signals_service: AgentSignalsService, + mock_simulation_repository: ISimulationRepository, +): + TERMINATE_TIMESTAMP = 400 + mock_simulation_repository.get_simulation = MagicMock( + return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP) + ) + + signals = agent_signals_service.get_signals(agent.id) + + assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP + + +@pytest.mark.parametrize("agent", AGENTS) +def test_agent_started_after_terminate( + agent, + agent_signals_service: AgentSignalsService, + mock_simulation_repository: ISimulationRepository, +): + TERMINATE_TIMESTAMP = 50 + mock_simulation_repository.get_simulation = MagicMock( + return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP) + ) + + signals = agent_signals_service.get_signals(agent.id) + + assert signals.terminate is None + + +@pytest.mark.parametrize("agent", AGENTS) +def test_progenitor_started_before_terminate( + agent, + agent_signals_service: AgentSignalsService, + mock_simulation_repository: ISimulationRepository, +): + TERMINATE_TIMESTAMP = 150 + mock_simulation_repository.get_simulation = MagicMock( + return_value=Simulation(terminate_signal_time=TERMINATE_TIMESTAMP) + ) + + signals = agent_signals_service.get_signals(agent.id) + + assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP