Agent: Add get_credentials_for_propagation to IslandAPIClient

This commit is contained in:
Kekoa Kaaikala 2022-09-19 20:06:43 +00:00
parent d6795492a4
commit b260dcc5cb
4 changed files with 68 additions and 26 deletions

View File

@ -15,6 +15,7 @@ from common.common_consts.timeouts import (
MEDIUM_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT,
SHORT_REQUEST_TIMEOUT, SHORT_REQUEST_TIMEOUT,
) )
from common.credentials import Credentials
from . import ( from . import (
AbstractIslandAPIClientFactory, AbstractIslandAPIClientFactory,
@ -198,6 +199,31 @@ class HTTPIslandAPIClient(IIslandAPIClient):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise IslandAPIRequestFailedError(e) raise IslandAPIRequestFailedError(e)
def get_credentials_for_propagation(self, island_server: str) -> Sequence[Credentials]:
try:
response = requests.get( # noqa: DUO123
f"https://{island_server}/api/propagation-credentials",
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return [Credentials(**credentials) for credentials in response.json()]
except (
requests.exceptions.ConnectionError,
requests.exceptions.TooManyRedirects,
) as e:
raise IslandAPIConnectionError(e)
except requests.exceptions.Timeout as e:
raise IslandAPITimeoutError(e)
except requests.exceptions.HTTPError as e:
if e.errno >= 500:
raise IslandAPIRequestFailedError(e)
else:
raise IslandAPIRequestError(e)
except json.JSONDecodeError as e:
raise IslandAPIRequestFailedError(e)
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
serialized_events: List[JSONSerializable] = [] serialized_events: List[JSONSerializable] = []

View File

@ -6,6 +6,7 @@ from common.agent_events import AbstractAgentEvent
from common import AgentRegistrationData from common import AgentRegistrationData
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.credentials import Credentials
class IIslandAPIClient(ABC): class IIslandAPIClient(ABC):
@ -132,3 +133,16 @@ class IIslandAPIClient(ABC):
:raises IslandAPITimeoutError: If the command timed out :raises IslandAPITimeoutError: If the command timed out
:return: Agent configuration :return: Agent configuration
""" """
@abstractmethod
def get_credentials_for_propagation(self, island_server: str) -> Sequence[Credentials]:
"""
Get credentials from the island
:param island_server: The server to query
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: Credentials
"""

View File

@ -2,12 +2,10 @@ import logging
from typing import Optional, Sequence from typing import Optional, Sequence
from uuid import UUID from uuid import UUID
import requests
from urllib3 import disable_warnings from urllib3 import disable_warnings
from common import AgentRegistrationData from common import AgentRegistrationData
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.credentials import Credentials from common.credentials import Credentials
from common.network.network_utils import get_network_interfaces from common.network.network_utils import get_network_interfaces
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -76,23 +74,14 @@ class ControlChannel(IControlChannel):
raise IslandCommunicationError(e) raise IslandCommunicationError(e)
def get_credentials_for_propagation(self) -> Sequence[Credentials]: def get_credentials_for_propagation(self) -> Sequence[Credentials]:
propagation_credentials_url = (
f"https://{self._control_channel_server}/api/propagation-credentials"
)
try: try:
response = requests.get( # noqa: DUO123 return self._island_api_client.get_credentials_for_propagation(
propagation_credentials_url, self._control_channel_server
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status()
return [Credentials(**credentials) for credentials in response.json()]
except ( except (
requests.exceptions.JSONDecodeError, IslandAPIConnectionError,
requests.exceptions.ConnectionError, IslandAPIRequestError,
requests.exceptions.Timeout, IslandAPIRequestFailedError,
requests.exceptions.TooManyRedirects, IslandAPITimeoutError,
requests.exceptions.HTTPError,
) as e: ) as e:
raise IslandCommunicationError(e) raise IslandCommunicationError(e)

View File

@ -14,6 +14,12 @@ from infection_monkey.master.control_channel import ControlChannel
SERVER = "server" SERVER = "server"
AGENT_ID = "agent" AGENT_ID = "agent"
CONTROL_CHANNEL_API_ERRORS = [
IslandAPIConnectionError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPITimeoutError,
]
@pytest.fixture @pytest.fixture
@ -96,17 +102,24 @@ def test_control_channel__get_config(control_channel, island_api_client):
assert island_api_client.get_config.called_once() assert island_api_client.get_config.called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
"api_error",
[
IslandAPIConnectionError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPITimeoutError,
],
)
def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error): def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error):
island_api_client.get_config.side_effect = api_error() island_api_client.get_config.side_effect = api_error()
with pytest.raises(IslandCommunicationError): with pytest.raises(IslandCommunicationError):
control_channel.get_config() control_channel.get_config()
def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client):
control_channel.get_credentials_for_propagation()
assert island_api_client.get_credentials_for_propagation.called_once()
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
def test_control_channel__get_credentials_for_propagation_raises_error(
control_channel, island_api_client, api_error
):
island_api_client.get_credentials_for_propagation.side_effect = api_error()
with pytest.raises(IslandCommunicationError):
control_channel.get_credentials_for_propagation()