diff --git a/monkey/infection_monkey/i_control_channel.py b/monkey/infection_monkey/i_control_channel.py index 39075750a..25135231f 100644 --- a/monkey/infection_monkey/i_control_channel.py +++ b/monkey/infection_monkey/i_control_channel.py @@ -1,23 +1,11 @@ import abc -from typing import Optional, Sequence -from uuid import UUID +from typing import Sequence from common.agent_configuration import AgentConfiguration from common.credentials import Credentials class IControlChannel(metaclass=abc.ABCMeta): - @abc.abstractmethod - def register_agent(self, parent_id: Optional[UUID] = None): - """ - Registers this agent with the Island when this agent starts - - :param parent: The ID of the parent that spawned this agent, or None if this agent has no - parent - :raises IslandCommunicationError: If the agent cannot be successfully registered - """ - pass - @abc.abstractmethod def should_agent_stop(self) -> bool: """ diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index a0bebc5a0..467f97252 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -1,13 +1,21 @@ import functools +import json import logging +from pprint import pformat from typing import List, Sequence import requests -from common import OperatingSystem +from common import AgentRegistrationData, OperatingSystem +from common.agent_configuration import AgentConfiguration from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent -from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.common_consts.timeouts import ( + LONG_REQUEST_TIMEOUT, + MEDIUM_REQUEST_TIMEOUT, + SHORT_REQUEST_TIMEOUT, +) +from common.credentials import Credentials from . import ( AbstractIslandAPIClientFactory, @@ -27,7 +35,9 @@ def handle_island_errors(fn): def decorated(*args, **kwargs): try: return fn(*args, **kwargs) - except requests.exceptions.ConnectionError as err: + except IslandAPIError as err: + raise err + except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err: raise IslandAPIConnectionError(err) except requests.exceptions.HTTPError as err: if 400 <= err.response.status_code < 500: @@ -46,6 +56,17 @@ def handle_island_errors(fn): return decorated +def convert_json_error_to_island_api_error(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + fn(*args, **kwargs) + except json.JSONDecodeError as err: + raise IslandAPIRequestFailedError(err) + + return wrapper + + class HTTPIslandAPIClient(IIslandAPIClient): """ A client for the Island's HTTP API @@ -116,6 +137,58 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() + @handle_island_errors + def register_agent(self, agent_registration_data: AgentRegistrationData): + url = f"{self._api_url}/agents" + response = requests.post( # noqa: DUO123 + url, + json=agent_registration_data.dict(simplify=True), + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + @handle_island_errors + @convert_json_error_to_island_api_error + def should_agent_stop(self, agent_id: str) -> bool: + url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}" + response = requests.get( # noqa: DUO123 + url, + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + return response.json()["stop_agent"] + + @handle_island_errors + @convert_json_error_to_island_api_error + def get_config(self) -> AgentConfiguration: + response = requests.get( # noqa: DUO123 + f"{self._api_url}/agent-configuration", + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + config_dict = response.json() + + logger.debug(f"Received configuration:\n{pformat(config_dict)}") + + return AgentConfiguration(**config_dict) + + @handle_island_errors + @convert_json_error_to_island_api_error + def get_credentials_for_propagation(self) -> Sequence[Credentials]: + response = requests.get( # noqa: DUO123 + f"{self._api_url}/propagation-credentials", + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + + return [Credentials(**credentials) for credentials in response.json()] + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = [] diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 5bebc79c1..3ec15237d 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from typing import Optional, Sequence -from common import OperatingSystem +from common import AgentRegistrationData, OperatingSystem +from common.agent_configuration import AgentConfiguration from common.agent_events import AbstractAgentEvent +from common.credentials import Credentials class IIslandAPIClient(ABC): @@ -13,7 +15,7 @@ class IIslandAPIClient(ABC): @abstractmethod def connect(self, island_server: str): """ - Connectto the island's API + Connect to the island's API :param island_server: The socket address of the API :raises IslandAPIConnectionError: If the client cannot successfully connect to the island @@ -74,7 +76,6 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the agent binary - """ @abstractmethod @@ -92,3 +93,53 @@ class IIslandAPIClient(ABC): :raises IslandAPIError: If an unexpected error occurs while attempting to send events to the island """ + + @abstractmethod + def register_agent(self, agent_registration_data: AgentRegistrationData): + """ + Register an agent with the Island + + :param agent_registration_data: Information about the agent to register + with the island + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + """ + + @abstractmethod + def should_agent_stop(self, agent_id: str) -> bool: + """ + Check with the island to see if the agent should stop + + :param agent_id: The agent identifier for the agent to check + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: True if the agent should stop, otherwise False + """ + + @abstractmethod + def get_config(self) -> AgentConfiguration: + """ + Get agent configuration from the island + + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: Agent configuration + """ + + @abstractmethod + def get_credentials_for_propagation(self) -> Sequence[Credentials]: + """ + Get credentials from the island + + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPIRequestError: If there was a problem with the client request + :raises IslandAPIRequestFailedError: If the server experienced an error + :raises IslandAPITimeoutError: If the command timed out + :return: Credentials + """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 76be63b5d..39fe94d1b 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -1,127 +1,47 @@ -import json import logging -from pprint import pformat -from typing import Optional, Sequence -from uuid import UUID +from functools import wraps +from typing import Sequence -import requests from urllib3 import disable_warnings -from common import AgentRegistrationData from common.agent_configuration import AgentConfiguration -from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.credentials import Credentials -from common.network.network_utils import get_network_interfaces from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError -from infection_monkey.utils import agent_process -from infection_monkey.utils.ids import get_agent_id, get_machine_id +from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError disable_warnings() # noqa: DUO131 logger = logging.getLogger(__name__) +def handle_island_api_errors(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except IslandAPIError as err: + raise IslandCommunicationError(err) + + return wrapper + + class ControlChannel(IControlChannel): - def __init__(self, server: str, agent_id: str): + def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient): self._agent_id = agent_id self._control_channel_server = server + self._island_api_client = api_client - def register_agent(self, parent: Optional[UUID] = None): - agent_registration_data = AgentRegistrationData( - id=get_agent_id(), - machine_hardware_id=get_machine_id(), - start_time=agent_process.get_start_time(), - # parent_id=parent, - parent_id=None, # None for now, until we change GUID to UUID - cc_server=self._control_channel_server, - network_interfaces=get_network_interfaces(), - ) - - try: - url = f"https://{self._control_channel_server}/api/agents" - response = requests.post( # noqa: DUO123 - url, - json=agent_registration_data.dict(simplify=True), - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - except ( - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: - raise IslandCommunicationError(e) - + @handle_island_api_errors def should_agent_stop(self) -> bool: if not self._control_channel_server: logger.error("Agent should stop because it can't connect to the C&C server.") return True - try: - url = ( - f"https://{self._control_channel_server}/api/monkey-control" - f"/needs-to-stop/{self._agent_id}" - ) - response = requests.get( # noqa: DUO123 - url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - - json_response = json.loads(response.content.decode()) - return json_response["stop_agent"] - except ( - json.JSONDecodeError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.should_agent_stop(self._agent_id) + @handle_island_api_errors def get_config(self) -> AgentConfiguration: - try: - response = requests.get( # noqa: DUO123 - f"https://{self._control_channel_server}/api/agent-configuration", - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - - config_dict = json.loads(response.text) - - logger.debug(f"Received configuration:\n{pformat(config_dict)}") - - return AgentConfiguration(**config_dict) - except ( - json.JSONDecodeError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.get_config() + @handle_island_api_errors def get_credentials_for_propagation(self) -> Sequence[Credentials]: - propagation_credentials_url = ( - f"https://{self._control_channel_server}/api/propagation-credentials" - ) - try: - response = requests.get( # noqa: DUO123 - propagation_credentials_url, - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - - return [Credentials(**credentials) for credentials in response.json()] - except ( - requests.exceptions.JSONDecodeError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: - raise IslandCommunicationError(e) + return self._island_api_client.get_credentials_for_propagation() diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 50ed3a3cb..460c07592 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -14,6 +14,7 @@ from common.agent_event_serializers import ( register_common_agent_event_serializers, ) from common.agent_events import CredentialsStolenEvent +from common.agent_registration_data import AgentRegistrationData from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue from common.network.network_utils import ( address_to_ip_port, @@ -88,9 +89,11 @@ from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter im LegacyTelemetryMessengerAdapter, ) from infection_monkey.telemetry.state_telem import StateTelem +from infection_monkey.utils import agent_process from infection_monkey.utils.aws_environment_check import run_aws_environment_check from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows +from infection_monkey.utils.ids import get_agent_id, get_machine_id from infection_monkey.utils.monkey_dir import ( create_monkey_dir, get_monkey_dir_path, @@ -120,6 +123,8 @@ class InfectionMonkey: self._control_client = ControlClient( server_address=server, island_api_client=self._island_api_client ) + self._control_channel = ControlChannel(server, GUID, self._island_api_client) + self._register_agent() # TODO Refactor the telemetry messengers to accept control client # and remove control_client_object @@ -165,6 +170,18 @@ class InfectionMonkey: return server, island_api_client + def _register_agent(self, server: str): + agent_registration_data = AgentRegistrationData( + id=get_agent_id(), + machine_hardware_id=get_machine_id(), + start_time=agent_process.get_start_time(), + # parent_id=parent, + parent_id=None, # None for now, until we change GUID to UUID + cc_server=server, + network_interfaces=get_network_interfaces(), + ) + self._island_api_client.register_agent(agent_registration_data) + def _select_server( self, server_clients: Mapping[str, IIslandAPIClient] ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: @@ -195,7 +212,7 @@ class InfectionMonkey: run_aws_environment_check(self._telemetry_messenger) - should_stop = ControlChannel(self._control_client.server_address, GUID).should_agent_stop() + should_stop = self._control_channel.should_agent_stop() if should_stop: logger.info("The Monkey Island has instructed this agent to stop") return @@ -211,9 +228,6 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - self._control_channel = ControlChannel(self._control_client.server_address, GUID) - self._control_channel.register_agent(self._opts.parent) - config = self._control_channel.get_config() relay_port = get_free_tcp_port() diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 1d0ac17fc..03117b006 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -10,6 +10,7 @@ from common.agent_event_serializers import ( PydanticAgentEventSerializer, ) from common.agent_events import AbstractAgentEvent +from common.agent_registration_data import AgentRegistrationData from infection_monkey.island_api_client import ( HTTPIslandAPIClient, IslandAPIConnectionError, @@ -22,14 +23,25 @@ from infection_monkey.island_api_client import ( SERVER = "1.1.1.1:9999" PBA_FILE = "dummy.pba" WINDOWS = "windows" +AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") +AGENT_REGISTRATION = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=1, + start_time=0, + parent_id=None, + cc_server=SERVER, + network_interfaces=[], +) ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" - -AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") +ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents" +ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}" +ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration" +ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials" class Event1(AbstractAgentEvent): @@ -275,3 +287,177 @@ def test_island_api_client_send_events__status_code(island_api_client, status_co with pytest.raises(expected_error): m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__register_agent(island_api_client, actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_REGISTER_AGENT_URI, exc=actual_error) + island_api_client.register_agent(AGENT_REGISTRATION) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_register_agent__status_code( + island_api_client, status_code, expected_error +): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.post(ISLAND_REGISTER_AGENT_URI, status_code=status_code) + island_api_client.register_agent(AGENT_REGISTRATION) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_AGENT_STOP_URI, exc=actual_error) + island_api_client.should_agent_stop(AGENT_ID) + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_should_agent_stop__status_code( + island_api_client, status_code, expected_error +): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_AGENT_STOP_URI, status_code=status_code) + island_api_client.should_agent_stop(AGENT_ID) + + +def test_island_api_client_should_agent_stop__bad_json(island_api_client): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(IslandAPIRequestFailedError): + m.get(ISLAND_AGENT_STOP_URI, content=b"bad") + island_api_client.should_agent_stop(AGENT_ID) + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__get_config(island_api_client, actual_error, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_CONFIG_URI, exc=actual_error) + island_api_client.get_config() + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_get_config__status_code(island_api_client, status_code, expected_error): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_CONFIG_URI, status_code=status_code) + island_api_client.get_config() + + +def test_island_api_client_get_config__bad_json(island_api_client): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(IslandAPIRequestFailedError): + m.get(ISLAND_GET_CONFIG_URI, content=b"bad") + island_api_client.get_config() + + +@pytest.mark.parametrize( + "actual_error, expected_error", + [ + (requests.exceptions.ConnectionError, IslandAPIConnectionError), + (TimeoutError, IslandAPITimeoutError), + ], +) +def test_island_api_client__get_credentials_for_propagation( + island_api_client, actual_error, expected_error +): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, exc=actual_error) + island_api_client.get_credentials_for_propagation() + + +@pytest.mark.parametrize( + "status_code, expected_error", + [ + (401, IslandAPIRequestError), + (501, IslandAPIRequestFailedError), + ], +) +def test_island_api_client_get_credentials_for_propagation__status_code( + island_api_client, status_code, expected_error +): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(expected_error): + m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, status_code=status_code) + island_api_client.get_credentials_for_propagation() + + +def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_client): + with requests_mock.Mocker() as m: + m.get(ISLAND_URI) + island_api_client.connect(SERVER) + + with pytest.raises(IslandAPIRequestFailedError): + m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad") + island_api_client.get_credentials_for_propagation() diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py new file mode 100644 index 000000000..658635615 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -0,0 +1,76 @@ +from unittest.mock import MagicMock + +import pytest + +from infection_monkey.i_control_channel import IslandCommunicationError +from infection_monkey.island_api_client import ( + IIslandAPIClient, + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, +) +from infection_monkey.master.control_channel import ControlChannel + +SERVER = "server" +AGENT_ID = "agent" +CONTROL_CHANNEL_API_ERRORS = [ + IslandAPIConnectionError, + IslandAPIRequestError, + IslandAPIRequestFailedError, + IslandAPITimeoutError, +] + + +@pytest.fixture +def island_api_client() -> IIslandAPIClient: + client = MagicMock() + return client + + +@pytest.fixture +def control_channel(island_api_client) -> ControlChannel: + return ControlChannel(SERVER, AGENT_ID, island_api_client) + + +def test_control_channel__should_agent_stop(control_channel, island_api_client): + control_channel.should_agent_stop() + assert island_api_client.should_agent_stop.called_once() + + +@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) +def test_control_channel__should_agent_stop_raises_error( + control_channel, island_api_client, api_error +): + island_api_client.should_agent_stop.side_effect = api_error() + + with pytest.raises(IslandCommunicationError): + control_channel.should_agent_stop() + + +def test_control_channel__get_config(control_channel, island_api_client): + control_channel.get_config() + assert island_api_client.get_config.called_once() + + +@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) +def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error): + island_api_client.get_config.side_effect = api_error() + + with pytest.raises(IslandCommunicationError): + control_channel.get_config() + + +def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client): + control_channel.get_credentials_for_propagation() + assert island_api_client.get_credentials_for_propagation.called_once() + + +@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) +def test_control_channel__get_credentials_for_propagation_raises_error( + control_channel, island_api_client, api_error +): + island_api_client.get_credentials_for_propagation.side_effect = api_error() + + with pytest.raises(IslandCommunicationError): + control_channel.get_credentials_for_propagation()