diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index 1da0d0713..efc52f79f 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -1,8 +1,10 @@ +from typing import Optional from unittest.mock import MagicMock import pytest -from infection_monkey.i_control_channel import IslandCommunicationError +from common import AgentSignals +from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.island_api_client import ( IIslandAPIClient, IslandAPIConnectionError, @@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel: return ControlChannel(SERVER, AGENT_ID, island_api_client) -def test_control_channel__should_agent_stop(control_channel, island_api_client): - control_channel.should_agent_stop() - assert island_api_client.get_agent_signals.called_once() +@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)]) +def test_control_channel__should_agent_stop( + control_channel: IControlChannel, + island_api_client: IIslandAPIClient, + signal_time: Optional[int], + expected_should_stop: bool, +): + island_api_client.get_agent_signals = MagicMock( + return_value=AgentSignals(terminate=signal_time) + ) + assert control_channel.should_agent_stop() is expected_should_stop @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)