diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index f92fcc946..e5e338c4e 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -15,6 +15,7 @@ from common.common_consts.timeouts import ( MEDIUM_REQUEST_TIMEOUT, SHORT_REQUEST_TIMEOUT, ) +from common.credentials import Credentials from . import ( AbstractIslandAPIClientFactory, @@ -198,6 +199,31 @@ class HTTPIslandAPIClient(IIslandAPIClient): except json.JSONDecodeError as e: raise IslandAPIRequestFailedError(e) + def get_credentials_for_propagation(self, island_server: str) -> Sequence[Credentials]: + try: + response = requests.get( # noqa: DUO123 + f"https://{island_server}/api/propagation-credentials", + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + return [Credentials(**credentials) for credentials in response.json()] + except ( + requests.exceptions.ConnectionError, + requests.exceptions.TooManyRedirects, + ) as e: + raise IslandAPIConnectionError(e) + except requests.exceptions.Timeout as e: + raise IslandAPITimeoutError(e) + except requests.exceptions.HTTPError as e: + if e.errno >= 500: + raise IslandAPIRequestFailedError(e) + else: + raise IslandAPIRequestError(e) + except json.JSONDecodeError as e: + raise IslandAPIRequestFailedError(e) + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = [] diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 064511a73..409a1b265 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -6,6 +6,7 @@ from common.agent_events import AbstractAgentEvent from common import AgentRegistrationData from common.agent_configuration import AgentConfiguration +from common.credentials import Credentials class IIslandAPIClient(ABC): @@ -132,3 +133,16 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If the command timed out :return: Agent configuration """ + + @abstractmethod + def get_credentials_for_propagation(self, island_server: str) -> Sequence[Credentials]: + """ + Get credentials from the island + + :param island_server: The server to query + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: Credentials + """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index d01a6a233..cce8791e1 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -2,12 +2,10 @@ import logging from typing import Optional, Sequence from uuid import UUID -import requests from urllib3 import disable_warnings from common import AgentRegistrationData from common.agent_configuration import AgentConfiguration -from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.credentials import Credentials from common.network.network_utils import get_network_interfaces from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError @@ -76,23 +74,14 @@ class ControlChannel(IControlChannel): raise IslandCommunicationError(e) def get_credentials_for_propagation(self) -> Sequence[Credentials]: - propagation_credentials_url = ( - f"https://{self._control_channel_server}/api/propagation-credentials" - ) try: - response = requests.get( # noqa: DUO123 - propagation_credentials_url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, + return self._island_api_client.get_credentials_for_propagation( + self._control_channel_server ) - response.raise_for_status() - - return [Credentials(**credentials) for credentials in response.json()] except ( - requests.exceptions.JSONDecodeError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, ) as e: raise IslandCommunicationError(e) diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index dbb5ffb75..ff5c276f3 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -14,6 +14,12 @@ from infection_monkey.master.control_channel import ControlChannel SERVER = "server" AGENT_ID = "agent" +CONTROL_CHANNEL_API_ERRORS = [ + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, +] @pytest.fixture @@ -96,17 +102,24 @@ def test_control_channel__get_config(control_channel, island_api_client): assert island_api_client.get_config.called_once() -@pytest.mark.parametrize( - "api_error", - [ - IslandAPIConnectionError, - IslandAPIRequestError, - IslandAPIRequestFailedError, - IslandAPITimeoutError, - ], -) +@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error): island_api_client.get_config.side_effect = api_error() with pytest.raises(IslandCommunicationError): control_channel.get_config() + + +def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client): + control_channel.get_credentials_for_propagation() + assert island_api_client.get_credentials_for_propagation.called_once() + + +@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) +def test_control_channel__get_credentials_for_propagation_raises_error( + control_channel, island_api_client, api_error +): + island_api_client.get_credentials_for_propagation.side_effect = api_error() + + with pytest.raises(IslandCommunicationError): + control_channel.get_credentials_for_propagation()