diff --git a/monkey/monkey_island/cc/services/__init__.py b/monkey/monkey_island/cc/services/__init__.py index c73aff356..8fdb5fc77 100644 --- a/monkey/monkey_island/cc/services/__init__.py +++ b/monkey/monkey_island/cc/services/__init__.py @@ -1,4 +1,4 @@ +from .agent_signals_service import AgentSignalsService from .authentication_service import AuthenticationService from .aws import AWSService -from .agent_signals_service import AgentSignalsService diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py index 16aab943f..1beb53940 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_agent_signals_service.py @@ -4,7 +4,7 @@ from uuid import UUID import pytest from common.types import AgentID -from monkey_island.cc.models import Agent, Simulation +from monkey_island.cc.models import Agent, IslandMode, Simulation from monkey_island.cc.repository import IAgentRepository, ISimulationRepository, UnknownRecordError from monkey_island.cc.services import AgentSignalsService @@ -117,3 +117,28 @@ def test_progenitor_started_before_terminate( signals = agent_signals_service.get_signals(agent.id) assert signals.terminate.timestamp() == TERMINATE_TIMESTAMP + + +def test_on_terminate_agents_signal__stores_timestamp( + agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository +): + timestamp = 100 + mock_simulation_repository.get_simulation = MagicMock(return_value=Simulation()) + agent_signals_service.on_terminate_agents_signal(timestamp) + + expected_value = Simulation(terminate_signal_time=timestamp) + assert mock_simulation_repository.save_simulation.called_once_with(expected_value) + + +def test_on_terminate_agents_signal__updates_timestamp( + agent_signals_service: AgentSignalsService, mock_simulation_repository: ISimulationRepository +): + timestamp = 100 + mock_simulation_repository.get_simulation = MagicMock( + return_value=Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=50) + ) + + agent_signals_service.on_terminate_agents_signal(timestamp) + + expected_value = Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=timestamp) + assert mock_simulation_repository.save_simulation.called_once_with(expected_value)