Merge pull request #2324 from guardicore/2292-control-channel-client-api-client

2292 control channel client api client
This commit is contained in:
Mike Salvatore 2022-09-20 14:47:51 -04:00 committed by GitHub
commit f472963b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 435 additions and 127 deletions

View File

@ -1,23 +1,11 @@
import abc import abc
from typing import Optional, Sequence from typing import Sequence
from uuid import UUID
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.credentials import Credentials from common.credentials import Credentials
class IControlChannel(metaclass=abc.ABCMeta): class IControlChannel(metaclass=abc.ABCMeta):
@abc.abstractmethod
def register_agent(self, parent_id: Optional[UUID] = None):
"""
Registers this agent with the Island when this agent starts
:param parent: The ID of the parent that spawned this agent, or None if this agent has no
parent
:raises IslandCommunicationError: If the agent cannot be successfully registered
"""
pass
@abc.abstractmethod @abc.abstractmethod
def should_agent_stop(self) -> bool: def should_agent_stop(self) -> bool:
""" """

View File

@ -1,13 +1,21 @@
import functools import functools
import json
import logging import logging
from pprint import pformat
from typing import List, Sequence from typing import List, Sequence
import requests import requests
from common import OperatingSystem from common import AgentRegistrationData, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import (
LONG_REQUEST_TIMEOUT,
MEDIUM_REQUEST_TIMEOUT,
SHORT_REQUEST_TIMEOUT,
)
from common.credentials import Credentials
from . import ( from . import (
AbstractIslandAPIClientFactory, AbstractIslandAPIClientFactory,
@ -27,7 +35,9 @@ def handle_island_errors(fn):
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
except requests.exceptions.ConnectionError as err: except IslandAPIError as err:
raise err
except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err:
raise IslandAPIConnectionError(err) raise IslandAPIConnectionError(err)
except requests.exceptions.HTTPError as err: except requests.exceptions.HTTPError as err:
if 400 <= err.response.status_code < 500: if 400 <= err.response.status_code < 500:
@ -46,6 +56,17 @@ def handle_island_errors(fn):
return decorated return decorated
def convert_json_error_to_island_api_error(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
fn(*args, **kwargs)
except json.JSONDecodeError as err:
raise IslandAPIRequestFailedError(err)
return wrapper
class HTTPIslandAPIClient(IIslandAPIClient): class HTTPIslandAPIClient(IIslandAPIClient):
""" """
A client for the Island's HTTP API A client for the Island's HTTP API
@ -116,6 +137,58 @@ class HTTPIslandAPIClient(IIslandAPIClient):
response.raise_for_status() response.raise_for_status()
@handle_island_errors
def register_agent(self, agent_registration_data: AgentRegistrationData):
url = f"{self._api_url}/agents"
response = requests.post( # noqa: DUO123
url,
json=agent_registration_data.dict(simplify=True),
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
@handle_island_errors
@convert_json_error_to_island_api_error
def should_agent_stop(self, agent_id: str) -> bool:
url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}"
response = requests.get( # noqa: DUO123
url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return response.json()["stop_agent"]
@handle_island_errors
@convert_json_error_to_island_api_error
def get_config(self) -> AgentConfiguration:
response = requests.get( # noqa: DUO123
f"{self._api_url}/agent-configuration",
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
config_dict = response.json()
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
return AgentConfiguration(**config_dict)
@handle_island_errors
@convert_json_error_to_island_api_error
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
response = requests.get( # noqa: DUO123
f"{self._api_url}/propagation-credentials",
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return [Credentials(**credentials) for credentials in response.json()]
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
serialized_events: List[JSONSerializable] = [] serialized_events: List[JSONSerializable] = []

View File

@ -1,8 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Sequence from typing import Optional, Sequence
from common import OperatingSystem from common import AgentRegistrationData, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials
class IIslandAPIClient(ABC): class IIslandAPIClient(ABC):
@ -13,7 +15,7 @@ class IIslandAPIClient(ABC):
@abstractmethod @abstractmethod
def connect(self, island_server: str): def connect(self, island_server: str):
""" """
Connectto the island's API Connect to the island's API
:param island_server: The socket address of the API :param island_server: The socket address of the API
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island :raises IslandAPIConnectionError: If the client cannot successfully connect to the island
@ -74,7 +76,6 @@ class IIslandAPIClient(ABC):
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
agent binary agent binary
""" """
@abstractmethod @abstractmethod
@ -92,3 +93,53 @@ class IIslandAPIClient(ABC):
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to :raises IslandAPIError: If an unexpected error occurs while attempting to send events to
the island the island
""" """
@abstractmethod
def register_agent(self, agent_registration_data: AgentRegistrationData):
"""
Register an agent with the Island
:param agent_registration_data: Information about the agent to register
with the island
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
"""
@abstractmethod
def should_agent_stop(self, agent_id: str) -> bool:
"""
Check with the island to see if the agent should stop
:param agent_id: The agent identifier for the agent to check
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: True if the agent should stop, otherwise False
"""
@abstractmethod
def get_config(self) -> AgentConfiguration:
"""
Get agent configuration from the island
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: Agent configuration
"""
@abstractmethod
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
"""
Get credentials from the island
:raises IslandAPIConnectionError: If the client could not connect to the island
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: Credentials
"""

View File

@ -1,127 +1,47 @@
import json
import logging import logging
from pprint import pformat from functools import wraps
from typing import Optional, Sequence from typing import Sequence
from uuid import UUID
import requests
from urllib3 import disable_warnings from urllib3 import disable_warnings
from common import AgentRegistrationData
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.credentials import Credentials from common.credentials import Credentials
from common.network.network_utils import get_network_interfaces
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.utils import agent_process from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError
from infection_monkey.utils.ids import get_agent_id, get_machine_id
disable_warnings() # noqa: DUO131 disable_warnings() # noqa: DUO131
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def handle_island_api_errors(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except IslandAPIError as err:
raise IslandCommunicationError(err)
return wrapper
class ControlChannel(IControlChannel): class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str): def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient):
self._agent_id = agent_id self._agent_id = agent_id
self._control_channel_server = server self._control_channel_server = server
self._island_api_client = api_client
def register_agent(self, parent: Optional[UUID] = None): @handle_island_api_errors
agent_registration_data = AgentRegistrationData(
id=get_agent_id(),
machine_hardware_id=get_machine_id(),
start_time=agent_process.get_start_time(),
# parent_id=parent,
parent_id=None, # None for now, until we change GUID to UUID
cc_server=self._control_channel_server,
network_interfaces=get_network_interfaces(),
)
try:
url = f"https://{self._control_channel_server}/api/agents"
response = requests.post( # noqa: DUO123
url,
json=agent_registration_data.dict(simplify=True),
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)
def should_agent_stop(self) -> bool: def should_agent_stop(self) -> bool:
if not self._control_channel_server: if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.") logger.error("Agent should stop because it can't connect to the C&C server.")
return True return True
try: return self._island_api_client.should_agent_stop(self._agent_id)
url = (
f"https://{self._control_channel_server}/api/monkey-control"
f"/needs-to-stop/{self._agent_id}"
)
response = requests.get( # noqa: DUO123
url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
json_response = json.loads(response.content.decode())
return json_response["stop_agent"]
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)
@handle_island_api_errors
def get_config(self) -> AgentConfiguration: def get_config(self) -> AgentConfiguration:
try: return self._island_api_client.get_config()
response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/agent-configuration",
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
config_dict = json.loads(response.text)
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
return AgentConfiguration(**config_dict)
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)
@handle_island_api_errors
def get_credentials_for_propagation(self) -> Sequence[Credentials]: def get_credentials_for_propagation(self) -> Sequence[Credentials]:
propagation_credentials_url = ( return self._island_api_client.get_credentials_for_propagation()
f"https://{self._control_channel_server}/api/propagation-credentials"
)
try:
response = requests.get( # noqa: DUO123
propagation_credentials_url,
verify=False,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return [Credentials(**credentials) for credentials in response.json()]
except (
requests.exceptions.JSONDecodeError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.HTTPError,
) as e:
raise IslandCommunicationError(e)

View File

@ -14,6 +14,7 @@ from common.agent_event_serializers import (
register_common_agent_event_serializers, register_common_agent_event_serializers,
) )
from common.agent_events import CredentialsStolenEvent from common.agent_events import CredentialsStolenEvent
from common.agent_registration_data import AgentRegistrationData
from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue
from common.network.network_utils import ( from common.network.network_utils import (
address_to_ip_port, address_to_ip_port,
@ -88,9 +89,11 @@ from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter im
LegacyTelemetryMessengerAdapter, LegacyTelemetryMessengerAdapter,
) )
from infection_monkey.telemetry.state_telem import StateTelem from infection_monkey.telemetry.state_telem import StateTelem
from infection_monkey.utils import agent_process
from infection_monkey.utils.aws_environment_check import run_aws_environment_check from infection_monkey.utils.aws_environment_check import run_aws_environment_check
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows
from infection_monkey.utils.ids import get_agent_id, get_machine_id
from infection_monkey.utils.monkey_dir import ( from infection_monkey.utils.monkey_dir import (
create_monkey_dir, create_monkey_dir,
get_monkey_dir_path, get_monkey_dir_path,
@ -120,6 +123,8 @@ class InfectionMonkey:
self._control_client = ControlClient( self._control_client = ControlClient(
server_address=server, island_api_client=self._island_api_client server_address=server, island_api_client=self._island_api_client
) )
self._control_channel = ControlChannel(server, GUID, self._island_api_client)
self._register_agent()
# TODO Refactor the telemetry messengers to accept control client # TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object # and remove control_client_object
@ -165,6 +170,18 @@ class InfectionMonkey:
return server, island_api_client return server, island_api_client
def _register_agent(self, server: str):
agent_registration_data = AgentRegistrationData(
id=get_agent_id(),
machine_hardware_id=get_machine_id(),
start_time=agent_process.get_start_time(),
# parent_id=parent,
parent_id=None, # None for now, until we change GUID to UUID
cc_server=server,
network_interfaces=get_network_interfaces(),
)
self._island_api_client.register_agent(agent_registration_data)
def _select_server( def _select_server(
self, server_clients: Mapping[str, IIslandAPIClient] self, server_clients: Mapping[str, IIslandAPIClient]
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
@ -195,7 +212,7 @@ class InfectionMonkey:
run_aws_environment_check(self._telemetry_messenger) run_aws_environment_check(self._telemetry_messenger)
should_stop = ControlChannel(self._control_client.server_address, GUID).should_agent_stop() should_stop = self._control_channel.should_agent_stop()
if should_stop: if should_stop:
logger.info("The Monkey Island has instructed this agent to stop") logger.info("The Monkey Island has instructed this agent to stop")
return return
@ -211,9 +228,6 @@ class InfectionMonkey:
if firewall.is_enabled(): if firewall.is_enabled():
firewall.add_firewall_rule() firewall.add_firewall_rule()
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
self._control_channel.register_agent(self._opts.parent)
config = self._control_channel.get_config() config = self._control_channel.get_config()
relay_port = get_free_tcp_port() relay_port = get_free_tcp_port()

View File

@ -10,6 +10,7 @@ from common.agent_event_serializers import (
PydanticAgentEventSerializer, PydanticAgentEventSerializer,
) )
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.agent_registration_data import AgentRegistrationData
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClient, HTTPIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
@ -22,14 +23,25 @@ from infection_monkey.island_api_client import (
SERVER = "1.1.1.1:9999" SERVER = "1.1.1.1:9999"
PBA_FILE = "dummy.pba" PBA_FILE = "dummy.pba"
WINDOWS = "windows" WINDOWS = "windows"
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
AGENT_REGISTRATION = AgentRegistrationData(
id=AGENT_ID,
machine_hardware_id=1,
start_time=0,
parent_id=None,
cc_server=SERVER,
network_interfaces=[],
)
ISLAND_URI = f"https://{SERVER}/api?action=is-up" ISLAND_URI = f"https://{SERVER}/api?action=is-up"
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log" ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}" ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}" ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events" ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents"
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
class Event1(AbstractAgentEvent): class Event1(AbstractAgentEvent):
@ -275,3 +287,177 @@ def test_island_api_client_send_events__status_code(island_api_client, status_co
with pytest.raises(expected_error): with pytest.raises(expected_error):
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code) m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)]) island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__register_agent(island_api_client, actual_error, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.post(ISLAND_REGISTER_AGENT_URI, exc=actual_error)
island_api_client.register_agent(AGENT_REGISTRATION)
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_register_agent__status_code(
island_api_client, status_code, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.post(ISLAND_REGISTER_AGENT_URI, status_code=status_code)
island_api_client.register_agent(AGENT_REGISTRATION)
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_AGENT_STOP_URI, exc=actual_error)
island_api_client.should_agent_stop(AGENT_ID)
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_should_agent_stop__status_code(
island_api_client, status_code, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_AGENT_STOP_URI, status_code=status_code)
island_api_client.should_agent_stop(AGENT_ID)
def test_island_api_client_should_agent_stop__bad_json(island_api_client):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(IslandAPIRequestFailedError):
m.get(ISLAND_AGENT_STOP_URI, content=b"bad")
island_api_client.should_agent_stop(AGENT_ID)
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__get_config(island_api_client, actual_error, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_CONFIG_URI, exc=actual_error)
island_api_client.get_config()
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_get_config__status_code(island_api_client, status_code, expected_error):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_CONFIG_URI, status_code=status_code)
island_api_client.get_config()
def test_island_api_client_get_config__bad_json(island_api_client):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(IslandAPIRequestFailedError):
m.get(ISLAND_GET_CONFIG_URI, content=b"bad")
island_api_client.get_config()
@pytest.mark.parametrize(
"actual_error, expected_error",
[
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
(TimeoutError, IslandAPITimeoutError),
],
)
def test_island_api_client__get_credentials_for_propagation(
island_api_client, actual_error, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, exc=actual_error)
island_api_client.get_credentials_for_propagation()
@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
def test_island_api_client_get_credentials_for_propagation__status_code(
island_api_client, status_code, expected_error
):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(expected_error):
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, status_code=status_code)
island_api_client.get_credentials_for_propagation()
def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_client):
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
with pytest.raises(IslandAPIRequestFailedError):
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
island_api_client.get_credentials_for_propagation()

View File

@ -0,0 +1,76 @@
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_control_channel import IslandCommunicationError
from infection_monkey.island_api_client import (
IIslandAPIClient,
IslandAPIConnectionError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPITimeoutError,
)
from infection_monkey.master.control_channel import ControlChannel
SERVER = "server"
AGENT_ID = "agent"
CONTROL_CHANNEL_API_ERRORS = [
IslandAPIConnectionError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPITimeoutError,
]
@pytest.fixture
def island_api_client() -> IIslandAPIClient:
client = MagicMock()
return client
@pytest.fixture
def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client)
def test_control_channel__should_agent_stop(control_channel, island_api_client):
control_channel.should_agent_stop()
assert island_api_client.should_agent_stop.called_once()
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
def test_control_channel__should_agent_stop_raises_error(
control_channel, island_api_client, api_error
):
island_api_client.should_agent_stop.side_effect = api_error()
with pytest.raises(IslandCommunicationError):
control_channel.should_agent_stop()
def test_control_channel__get_config(control_channel, island_api_client):
control_channel.get_config()
assert island_api_client.get_config.called_once()
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error):
island_api_client.get_config.side_effect = api_error()
with pytest.raises(IslandCommunicationError):
control_channel.get_config()
def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client):
control_channel.get_credentials_for_propagation()
assert island_api_client.get_credentials_for_propagation.called_once()
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
def test_control_channel__get_credentials_for_propagation_raises_error(
control_channel, island_api_client, api_error
):
island_api_client.get_credentials_for_propagation.side_effect = api_error()
with pytest.raises(IslandCommunicationError):
control_channel.get_credentials_for_propagation()