diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 6cd1d86a1..65e34df3c 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,8 +1,9 @@ import functools import json import logging +from datetime import datetime from pprint import pformat -from typing import List, Sequence +from typing import List, Optional, Sequence import requests @@ -146,19 +147,6 @@ class HTTPIslandAPIClient(IIslandAPIClient): ) response.raise_for_status() - @handle_island_errors - @convert_json_error_to_island_api_error - def should_agent_stop(self, agent_id: str) -> bool: - url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}" - response = requests.get( # noqa: DUO123 - url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - - return response.json()["stop_agent"] - @handle_island_errors @convert_json_error_to_island_api_error def get_config(self) -> AgentConfiguration: @@ -199,6 +187,18 @@ class HTTPIslandAPIClient(IIslandAPIClient): return serialized_events + @handle_island_errors + @convert_json_error_to_island_api_error + def get_agent_signals(self, agent_id: str) -> Optional[datetime]: + url = f"{self._api_url}/agent-signals/{agent_id}" + response = requests.get( # noqa: DUO123 + url, + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + return response.json()["terminate"] + class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): def __init__( diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 8ecd98b49..c2a4dc899 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Sequence +from datetime import datetime +from typing import Optional, Sequence from common import AgentRegistrationData, OperatingSystem from common.agent_configuration import AgentConfiguration @@ -107,19 +108,6 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If the command timed out """ - @abstractmethod - def should_agent_stop(self, agent_id: str) -> bool: - """ - Check with the island to see if the agent should stop - - :param agent_id: The agent identifier for the agent to check - :raises IslandAPIConnectionError: If the client could not connect to the island - :raises IslandAPIRequestError: If there was a problem with the client request - :raises IslandAPIRequestFailedError: If the server experienced an error - :raises IslandAPITimeoutError: If the command timed out - :return: True if the agent should stop, otherwise False - """ - @abstractmethod def get_config(self) -> AgentConfiguration: """ @@ -143,3 +131,16 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If the command timed out :return: Credentials """ + + @abstractmethod + def get_agent_signals(self, agent_id: str) -> Optional[datetime]: + """ + Gets an agent's signals from the island + + :param agent_id: ID of the agent whose signals should be retrieved + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: The relevant agent's terminate signal's timestamp + """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 48b827f63..947d6c0da 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -36,7 +36,7 @@ class ControlChannel(IControlChannel): if not self._control_channel_server: logger.error("Agent should stop because it can't connect to the C&C server.") return True - return self._island_api_client.should_agent_stop(self._agent_id) + return self._island_api_client.get_agent_signals(self._agent_id) is not None @handle_island_api_errors def get_config(self) -> AgentConfiguration: diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 03117b006..9505e6649 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -33,6 +33,8 @@ AGENT_REGISTRATION = AgentRegistrationData( network_interfaces=[], ) +TIMESTAMP = 123456789 + ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" @@ -42,6 +44,7 @@ ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents" ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}" ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration" ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials" +ISLAND_GET_AGENT_SIGNALS = f"https://{SERVER}/api/agent-signals/{AGENT_ID}" class Event1(AbstractAgentEvent): @@ -325,52 +328,6 @@ def test_island_api_client_register_agent__status_code( island_api_client.register_agent(AGENT_REGISTRATION) -@pytest.mark.parametrize( - "actual_error, expected_error", - [ - (requests.exceptions.ConnectionError, IslandAPIConnectionError), - (TimeoutError, IslandAPITimeoutError), - ], -) -def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error): - with requests_mock.Mocker() as m: - m.get(ISLAND_URI) - island_api_client.connect(SERVER) - - with pytest.raises(expected_error): - m.get(ISLAND_AGENT_STOP_URI, exc=actual_error) - island_api_client.should_agent_stop(AGENT_ID) - - -@pytest.mark.parametrize( - "status_code, expected_error", - [ - (401, IslandAPIRequestError), - (501, IslandAPIRequestFailedError), - ], -) -def test_island_api_client_should_agent_stop__status_code( - island_api_client, status_code, expected_error -): - with requests_mock.Mocker() as m: - m.get(ISLAND_URI) - island_api_client.connect(SERVER) - - with pytest.raises(expected_error): - m.get(ISLAND_AGENT_STOP_URI, status_code=status_code) - island_api_client.should_agent_stop(AGENT_ID) - - -def test_island_api_client_should_agent_stop__bad_json(island_api_client): - with requests_mock.Mocker() as m: - m.get(ISLAND_URI) - island_api_client.connect(SERVER) - - with pytest.raises(IslandAPIRequestFailedError): - m.get(ISLAND_AGENT_STOP_URI, content=b"bad") - island_api_client.should_agent_stop(AGENT_ID) - - @pytest.mark.parametrize( "actual_error, expected_error", [ @@ -461,3 +418,61 @@ def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_ with pytest.raises(IslandAPIRequestFailedError): m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad") island_api_client.get_credentials_for_propagation() + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__get_agent_signals(island_api_client, actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_AGENT_SIGNALS, exc=actual_error) + island_api_client.get_agent_signals(agent_id=AGENT_ID) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_get_agent_signals__status_code( + island_api_client, status_code, expected_error +): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_AGENT_SIGNALS, status_code=status_code) + island_api_client.get_agent_signals(agent_id=AGENT_ID) + + +@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None]) +def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp}) + actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID) + + assert actual_terminate_timestamp == expected_timestamp + + +def test_island_api_client_get_agent_signals__bad_json(island_api_client): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(IslandAPIError): + m.get(ISLAND_GET_AGENT_SIGNALS, json={"bogus": "vogus"}) + island_api_client.get_agent_signals(agent_id=AGENT_ID) diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index 658635615..1da0d0713 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -35,14 +35,14 @@ def control_channel(island_api_client) -> ControlChannel: def test_control_channel__should_agent_stop(control_channel, island_api_client): control_channel.should_agent_stop() - assert island_api_client.should_agent_stop.called_once() + assert island_api_client.get_agent_signals.called_once() @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) def test_control_channel__should_agent_stop_raises_error( control_channel, island_api_client, api_error ): - island_api_client.should_agent_stop.side_effect = api_error() + island_api_client.get_agent_signals.side_effect = api_error() with pytest.raises(IslandCommunicationError): control_channel.should_agent_stop()