diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 407b6562e..97c5404d0 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,4 +1,5 @@ import functools +import json import logging from typing import List, Sequence @@ -139,6 +140,33 @@ class HTTPIslandAPIClient(IIslandAPIClient): except requests.exceptions.Timeout as e: raise IslandAPITimeoutError(e) + def should_agent_stop(self, island_server: str, agent_id: str) -> bool: + try: + url = f"https://{island_server}/api/monkey-control" f"/needs-to-stop/{agent_id}" + response = requests.get( # noqa: DUO123 + url, + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + json_response = json.loads(response.content.decode()) + return json_response["stop_agent"] + except ( + requests.exceptions.ConnectionError, + requests.exceptions.TooManyRedirects, + ) as e: + raise IslandAPIConnectionError(e) + except requests.exceptions.Timeout as e: + raise IslandAPITimeoutError(e) + except requests.exceptions.HTTPError as e: + if e.errno >= 500: + raise IslandAPIRequestFailedError(e) + else: + raise IslandAPIRequestError(e) + except json.JSONDecodeError as e: + raise IslandAPIRequestFailedError(e) + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = [] diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index cc32555dd..cbfbbf306 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -104,3 +104,17 @@ class IIslandAPIClient(ABC): :raises IslandAPIConnectionError: If the client could not connect to the island :raises IslandAPITimeoutError: If the command timed out """ + + @abstractmethod + def should_agent_stop(self, island_server: str, agent_id: str) -> bool: + """ + Check with the island to see if the agent should stop + + :param island_server: The server to query + :param agent_id: The agent identifier for the agent to check + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: True if the agent should stop, otherwise False + """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index cd4496a9d..f61091e8f 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -16,6 +16,8 @@ from infection_monkey.i_control_channel import IControlChannel, IslandCommunicat from infection_monkey.island_api_client import ( IIslandAPIClient, IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, IslandAPITimeoutError, ) from infection_monkey.utils import agent_process @@ -53,25 +55,14 @@ class ControlChannel(IControlChannel): logger.error("Agent should stop because it can't connect to the C&C server.") return True try: - url = ( - f"https://{self._control_channel_server}/api/monkey-control" - f"/needs-to-stop/{self._agent_id}" + return self._island_api_client.should_agent_stop( + self._control_channel_server, self._agent_id ) - response = requests.get( # noqa: DUO123 - url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - - json_response = json.loads(response.content.decode()) - return json_response["stop_agent"] except ( - json.JSONDecodeError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, ) as e: raise IslandCommunicationError(e) diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index 75a3eb149..534ef157b 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -6,10 +6,15 @@ from infection_monkey.i_control_channel import IslandCommunicationError from infection_monkey.island_api_client import ( IIslandAPIClient, IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, IslandAPITimeoutError, ) from infection_monkey.master.control_channel import ControlChannel +SERVER = "server" +AGENT_ID = "agent" + @pytest.fixture def island_api_client() -> IIslandAPIClient: @@ -19,7 +24,7 @@ def island_api_client() -> IIslandAPIClient: @pytest.fixture def control_channel(island_api_client) -> ControlChannel: - return ControlChannel("server", "agent-id", island_api_client) + return ControlChannel(SERVER, AGENT_ID, island_api_client) def test_control_channel__register_agent(control_channel, island_api_client): @@ -43,3 +48,44 @@ def test_control_channel__register_agent_raises_on_timeout_error( with pytest.raises(IslandCommunicationError): control_channel.register_agent() + + +def test_control_channel__should_agent_stop(control_channel, island_api_client): + control_channel.should_agent_stop() + assert island_api_client.should_agent_stop.called_once() + + +def test_control_channel__should_agent_stop_raises_on_connection_error( + control_channel, island_api_client +): + island_api_client.should_agent_stop.side_effect = IslandAPIConnectionError() + + with pytest.raises(IslandCommunicationError): + control_channel.should_agent_stop() + + +def test_control_channel__should_agent_stop_raises_on_timeout_error( + control_channel, island_api_client +): + island_api_client.should_agent_stop.side_effect = IslandAPITimeoutError() + + with pytest.raises(IslandCommunicationError): + control_channel.should_agent_stop() + + +def test_control_channel__should_agent_stop_raises_on_request_error( + control_channel, island_api_client +): + island_api_client.should_agent_stop.side_effect = IslandAPIRequestError() + + with pytest.raises(IslandCommunicationError): + control_channel.should_agent_stop() + + +def test_control_channel__should_agent_stop_raises_on_request_failed_error( + control_channel, island_api_client +): + island_api_client.should_agent_stop.side_effect = IslandAPIRequestFailedError() + + with pytest.raises(IslandCommunicationError): + control_channel.should_agent_stop()