diff --git a/monkey/common/__init__.py b/monkey/common/__init__.py index 4898b2df8..69190bce5 100644 --- a/monkey/common/__init__.py +++ b/monkey/common/__init__.py @@ -7,3 +7,4 @@ from .operating_system import OperatingSystem from . import types from . import base_models from .agent_registration_data import AgentRegistrationData +from .agent_signals import AgentSignals diff --git a/monkey/monkey_island/cc/models/agent_signals.py b/monkey/common/agent_signals.py similarity index 71% rename from monkey/monkey_island/cc/models/agent_signals.py rename to monkey/common/agent_signals.py index 37af7b4c1..c351e8823 100644 --- a/monkey/monkey_island/cc/models/agent_signals.py +++ b/monkey/common/agent_signals.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional -from common.base_models import InfectionMonkeyBaseModel +from .base_models import InfectionMonkeyBaseModel class AgentSignals(InfectionMonkeyBaseModel): diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 65e34df3c..09fcbf762 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,13 +1,12 @@ import functools import json import logging -from datetime import datetime from pprint import pformat -from typing import List, Optional, Sequence +from typing import List, Sequence import requests -from common import AgentRegistrationData, OperatingSystem +from common import AgentRegistrationData, AgentSignals, OperatingSystem from common.agent_configuration import AgentConfiguration from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent @@ -189,7 +188,7 @@ class HTTPIslandAPIClient(IIslandAPIClient): @handle_island_errors @convert_json_error_to_island_api_error - def get_agent_signals(self, agent_id: str) -> Optional[datetime]: + def get_agent_signals(self, agent_id: str) -> AgentSignals: url = f"{self._api_url}/agent-signals/{agent_id}" response = requests.get( # noqa: DUO123 url, @@ -197,7 +196,7 @@ class HTTPIslandAPIClient(IIslandAPIClient): timeout=SHORT_REQUEST_TIMEOUT, ) response.raise_for_status() - return response.json()["terminate"] + return AgentSignals(**response.json()) class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index c2a4dc899..2cf1e72ab 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from datetime import datetime from typing import Optional, Sequence -from common import AgentRegistrationData, OperatingSystem +from common import AgentRegistrationData, AgentSignals, OperatingSystem from common.agent_configuration import AgentConfiguration from common.agent_events import AbstractAgentEvent from common.credentials import Credentials @@ -133,7 +132,7 @@ class IIslandAPIClient(ABC): """ @abstractmethod - def get_agent_signals(self, agent_id: str) -> Optional[datetime]: + def get_agent_signals(self, agent_id: str) -> AgentSignals: """ Gets an agent's signals from the island @@ -142,5 +141,5 @@ class IIslandAPIClient(ABC): :raises IslandAPIRequestError: If there was a problem with the client request :raises IslandAPIRequestFailedError: If the server experienced an error :raises IslandAPITimeoutError: If the command timed out - :return: The relevant agent's terminate signal's timestamp + :return: The relevant agent's signals """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 947d6c0da..ad2081065 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -36,7 +36,8 @@ class ControlChannel(IControlChannel): if not self._control_channel_server: logger.error("Agent should stop because it can't connect to the C&C server.") return True - return self._island_api_client.get_agent_signals(self._agent_id) is not None + agent_signals = self._island_api_client.get_agent_signals(self._agent_id) + return agent_signals.terminate is not None @handle_island_api_errors def get_config(self) -> AgentConfiguration: diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index ca4078faa..65e63fe14 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -15,4 +15,3 @@ from .communication_type import CommunicationType from .node import Node from common.types import AgentID from .agent import Agent -from .agent_signals import AgentSignals diff --git a/monkey/monkey_island/cc/services/agent_signals_service.py b/monkey/monkey_island/cc/services/agent_signals_service.py index 473a10066..bf0b3e61f 100644 --- a/monkey/monkey_island/cc/services/agent_signals_service.py +++ b/monkey/monkey_island/cc/services/agent_signals_service.py @@ -2,8 +2,9 @@ import logging from datetime import datetime from typing import Optional +from common.agent_signals import AgentSignals from common.types import AgentID -from monkey_island.cc.models import AgentSignals, Simulation +from monkey_island.cc.models import Simulation from monkey_island.cc.repository import IAgentRepository, ISimulationRepository logger = logging.getLogger(__name__) diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 9505e6649..a861a4849 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -4,7 +4,7 @@ import pytest import requests import requests_mock -from common import OperatingSystem +from common import AgentSignals, OperatingSystem from common.agent_event_serializers import ( AgentEventSerializerRegistry, PydanticAgentEventSerializer, @@ -456,16 +456,17 @@ def test_island_api_client_get_agent_signals__status_code( island_api_client.get_agent_signals(agent_id=AGENT_ID) -@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None]) -def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp): +@pytest.mark.parametrize("timestamp", [TIMESTAMP, None]) +def test_island_api_client_get_agent_signals(island_api_client, timestamp): + expected_agent_signals = AgentSignals(terminate=timestamp) with requests_mock.Mocker() as m: m.get(ISLAND_URI) island_api_client.connect(SERVER) - m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp}) - actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID) + m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp}) + actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID) - assert actual_terminate_timestamp == expected_timestamp + assert actual_agent_signals == expected_agent_signals def test_island_api_client_get_agent_signals__bad_json(island_api_client): diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index 1da0d0713..efc52f79f 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -1,8 +1,10 @@ +from typing import Optional from unittest.mock import MagicMock import pytest -from infection_monkey.i_control_channel import IslandCommunicationError +from common import AgentSignals +from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.island_api_client import ( IIslandAPIClient, IslandAPIConnectionError, @@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel: return ControlChannel(SERVER, AGENT_ID, island_api_client) -def test_control_channel__should_agent_stop(control_channel, island_api_client): - control_channel.should_agent_stop() - assert island_api_client.get_agent_signals.called_once() +@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)]) +def test_control_channel__should_agent_stop( + control_channel: IControlChannel, + island_api_client: IIslandAPIClient, + signal_time: Optional[int], + expected_should_stop: bool, +): + island_api_client.get_agent_signals = MagicMock( + return_value=AgentSignals(terminate=signal_time) + ) + assert control_channel.should_agent_stop() is expected_should_stop @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/agent_signals/test_agent_signals.py b/monkey/tests/unit_tests/monkey_island/cc/resources/agent_signals/test_agent_signals.py index e7a20a9c9..723ac1afe 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/agent_signals/test_agent_signals.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/agent_signals/test_agent_signals.py @@ -5,7 +5,7 @@ from uuid import UUID import pytest from tests.common import StubDIContainer -from monkey_island.cc.models import AgentSignals as Signals +from common.agent_signals import AgentSignals as Signals from monkey_island.cc.repository import RetrievalError, StorageError from monkey_island.cc.services import AgentSignalsService