Merge pull request #2324 from guardicore/2292-control-channel-client-api-client
2292 control channel client api client
This commit is contained in:
commit
f472963b78
|
@ -1,23 +1,11 @@
|
|||
import abc
|
||||
from typing import Optional, Sequence
|
||||
from uuid import UUID
|
||||
from typing import Sequence
|
||||
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.credentials import Credentials
|
||||
|
||||
|
||||
class IControlChannel(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def register_agent(self, parent_id: Optional[UUID] = None):
|
||||
"""
|
||||
Registers this agent with the Island when this agent starts
|
||||
|
||||
:param parent: The ID of the parent that spawned this agent, or None if this agent has no
|
||||
parent
|
||||
:raises IslandCommunicationError: If the agent cannot be successfully registered
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def should_agent_stop(self) -> bool:
|
||||
"""
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
import functools
|
||||
import json
|
||||
import logging
|
||||
from pprint import pformat
|
||||
from typing import List, Sequence
|
||||
|
||||
import requests
|
||||
|
||||
from common import OperatingSystem
|
||||
from common import AgentRegistrationData, OperatingSystem
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from common.common_consts.timeouts import (
|
||||
LONG_REQUEST_TIMEOUT,
|
||||
MEDIUM_REQUEST_TIMEOUT,
|
||||
SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
from common.credentials import Credentials
|
||||
|
||||
from . import (
|
||||
AbstractIslandAPIClientFactory,
|
||||
|
@ -27,7 +35,9 @@ def handle_island_errors(fn):
|
|||
def decorated(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except requests.exceptions.ConnectionError as err:
|
||||
except IslandAPIError as err:
|
||||
raise err
|
||||
except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err:
|
||||
raise IslandAPIConnectionError(err)
|
||||
except requests.exceptions.HTTPError as err:
|
||||
if 400 <= err.response.status_code < 500:
|
||||
|
@ -46,6 +56,17 @@ def handle_island_errors(fn):
|
|||
return decorated
|
||||
|
||||
|
||||
def convert_json_error_to_island_api_error(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
fn(*args, **kwargs)
|
||||
except json.JSONDecodeError as err:
|
||||
raise IslandAPIRequestFailedError(err)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class HTTPIslandAPIClient(IIslandAPIClient):
|
||||
"""
|
||||
A client for the Island's HTTP API
|
||||
|
@ -116,6 +137,58 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
|
||||
response.raise_for_status()
|
||||
|
||||
@handle_island_errors
|
||||
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||
url = f"{self._api_url}/agents"
|
||||
response = requests.post( # noqa: DUO123
|
||||
url,
|
||||
json=agent_registration_data.dict(simplify=True),
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@handle_island_errors
|
||||
@convert_json_error_to_island_api_error
|
||||
def should_agent_stop(self, agent_id: str) -> bool:
|
||||
url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}"
|
||||
response = requests.get( # noqa: DUO123
|
||||
url,
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()["stop_agent"]
|
||||
|
||||
@handle_island_errors
|
||||
@convert_json_error_to_island_api_error
|
||||
def get_config(self) -> AgentConfiguration:
|
||||
response = requests.get( # noqa: DUO123
|
||||
f"{self._api_url}/agent-configuration",
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
config_dict = response.json()
|
||||
|
||||
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
|
||||
|
||||
return AgentConfiguration(**config_dict)
|
||||
|
||||
@handle_island_errors
|
||||
@convert_json_error_to_island_api_error
|
||||
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||
response = requests.get( # noqa: DUO123
|
||||
f"{self._api_url}/propagation-credentials",
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return [Credentials(**credentials) for credentials in response.json()]
|
||||
|
||||
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
||||
serialized_events: List[JSONSerializable] = []
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from common import OperatingSystem
|
||||
from common import AgentRegistrationData, OperatingSystem
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.credentials import Credentials
|
||||
|
||||
|
||||
class IIslandAPIClient(ABC):
|
||||
|
@ -13,7 +15,7 @@ class IIslandAPIClient(ABC):
|
|||
@abstractmethod
|
||||
def connect(self, island_server: str):
|
||||
"""
|
||||
Connectto the island's API
|
||||
Connect to the island's API
|
||||
|
||||
:param island_server: The socket address of the API
|
||||
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||
|
@ -74,7 +76,6 @@ class IIslandAPIClient(ABC):
|
|||
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
||||
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
||||
agent binary
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -92,3 +93,53 @@ class IIslandAPIClient(ABC):
|
|||
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
||||
the island
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||
"""
|
||||
Register an agent with the Island
|
||||
|
||||
:param agent_registration_data: Information about the agent to register
|
||||
with the island
|
||||
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||
:raises IslandAPITimeoutError: If the command timed out
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def should_agent_stop(self, agent_id: str) -> bool:
|
||||
"""
|
||||
Check with the island to see if the agent should stop
|
||||
|
||||
:param agent_id: The agent identifier for the agent to check
|
||||
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||
:raises IslandAPITimeoutError: If the command timed out
|
||||
:return: True if the agent should stop, otherwise False
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> AgentConfiguration:
|
||||
"""
|
||||
Get agent configuration from the island
|
||||
|
||||
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||
:raises IslandAPITimeoutError: If the command timed out
|
||||
:return: Agent configuration
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||
"""
|
||||
Get credentials from the island
|
||||
|
||||
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||
:raises IslandAPITimeoutError: If the command timed out
|
||||
:return: Credentials
|
||||
"""
|
||||
|
|
|
@ -1,127 +1,47 @@
|
|||
import json
|
||||
import logging
|
||||
from pprint import pformat
|
||||
from typing import Optional, Sequence
|
||||
from uuid import UUID
|
||||
from functools import wraps
|
||||
from typing import Sequence
|
||||
|
||||
import requests
|
||||
from urllib3 import disable_warnings
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
||||
from common.credentials import Credentials
|
||||
from common.network.network_utils import get_network_interfaces
|
||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||
from infection_monkey.utils import agent_process
|
||||
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
||||
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError
|
||||
|
||||
disable_warnings() # noqa: DUO131
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_island_api_errors(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except IslandAPIError as err:
|
||||
raise IslandCommunicationError(err)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ControlChannel(IControlChannel):
|
||||
def __init__(self, server: str, agent_id: str):
|
||||
def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient):
|
||||
self._agent_id = agent_id
|
||||
self._control_channel_server = server
|
||||
self._island_api_client = api_client
|
||||
|
||||
def register_agent(self, parent: Optional[UUID] = None):
|
||||
agent_registration_data = AgentRegistrationData(
|
||||
id=get_agent_id(),
|
||||
machine_hardware_id=get_machine_id(),
|
||||
start_time=agent_process.get_start_time(),
|
||||
# parent_id=parent,
|
||||
parent_id=None, # None for now, until we change GUID to UUID
|
||||
cc_server=self._control_channel_server,
|
||||
network_interfaces=get_network_interfaces(),
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"https://{self._control_channel_server}/api/agents"
|
||||
response = requests.post( # noqa: DUO123
|
||||
url,
|
||||
json=agent_registration_data.dict(simplify=True),
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except (
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
|
||||
@handle_island_api_errors
|
||||
def should_agent_stop(self) -> bool:
|
||||
if not self._control_channel_server:
|
||||
logger.error("Agent should stop because it can't connect to the C&C server.")
|
||||
return True
|
||||
try:
|
||||
url = (
|
||||
f"https://{self._control_channel_server}/api/monkey-control"
|
||||
f"/needs-to-stop/{self._agent_id}"
|
||||
)
|
||||
response = requests.get( # noqa: DUO123
|
||||
url,
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = json.loads(response.content.decode())
|
||||
return json_response["stop_agent"]
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
return self._island_api_client.should_agent_stop(self._agent_id)
|
||||
|
||||
@handle_island_api_errors
|
||||
def get_config(self) -> AgentConfiguration:
|
||||
try:
|
||||
response = requests.get( # noqa: DUO123
|
||||
f"https://{self._control_channel_server}/api/agent-configuration",
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
config_dict = json.loads(response.text)
|
||||
|
||||
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
|
||||
|
||||
return AgentConfiguration(**config_dict)
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
return self._island_api_client.get_config()
|
||||
|
||||
@handle_island_api_errors
|
||||
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||
propagation_credentials_url = (
|
||||
f"https://{self._control_channel_server}/api/propagation-credentials"
|
||||
)
|
||||
try:
|
||||
response = requests.get( # noqa: DUO123
|
||||
propagation_credentials_url,
|
||||
verify=False,
|
||||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return [Credentials(**credentials) for credentials in response.json()]
|
||||
except (
|
||||
requests.exceptions.JSONDecodeError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
requests.exceptions.TooManyRedirects,
|
||||
requests.exceptions.HTTPError,
|
||||
) as e:
|
||||
raise IslandCommunicationError(e)
|
||||
return self._island_api_client.get_credentials_for_propagation()
|
||||
|
|
|
@ -14,6 +14,7 @@ from common.agent_event_serializers import (
|
|||
register_common_agent_event_serializers,
|
||||
)
|
||||
from common.agent_events import CredentialsStolenEvent
|
||||
from common.agent_registration_data import AgentRegistrationData
|
||||
from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue
|
||||
from common.network.network_utils import (
|
||||
address_to_ip_port,
|
||||
|
@ -88,9 +89,11 @@ from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter im
|
|||
LegacyTelemetryMessengerAdapter,
|
||||
)
|
||||
from infection_monkey.telemetry.state_telem import StateTelem
|
||||
from infection_monkey.utils import agent_process
|
||||
from infection_monkey.utils.aws_environment_check import run_aws_environment_check
|
||||
from infection_monkey.utils.environment import is_windows_os
|
||||
from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows
|
||||
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
||||
from infection_monkey.utils.monkey_dir import (
|
||||
create_monkey_dir,
|
||||
get_monkey_dir_path,
|
||||
|
@ -120,6 +123,8 @@ class InfectionMonkey:
|
|||
self._control_client = ControlClient(
|
||||
server_address=server, island_api_client=self._island_api_client
|
||||
)
|
||||
self._control_channel = ControlChannel(server, GUID, self._island_api_client)
|
||||
self._register_agent()
|
||||
|
||||
# TODO Refactor the telemetry messengers to accept control client
|
||||
# and remove control_client_object
|
||||
|
@ -165,6 +170,18 @@ class InfectionMonkey:
|
|||
|
||||
return server, island_api_client
|
||||
|
||||
def _register_agent(self, server: str):
|
||||
agent_registration_data = AgentRegistrationData(
|
||||
id=get_agent_id(),
|
||||
machine_hardware_id=get_machine_id(),
|
||||
start_time=agent_process.get_start_time(),
|
||||
# parent_id=parent,
|
||||
parent_id=None, # None for now, until we change GUID to UUID
|
||||
cc_server=server,
|
||||
network_interfaces=get_network_interfaces(),
|
||||
)
|
||||
self._island_api_client.register_agent(agent_registration_data)
|
||||
|
||||
def _select_server(
|
||||
self, server_clients: Mapping[str, IIslandAPIClient]
|
||||
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
||||
|
@ -195,7 +212,7 @@ class InfectionMonkey:
|
|||
|
||||
run_aws_environment_check(self._telemetry_messenger)
|
||||
|
||||
should_stop = ControlChannel(self._control_client.server_address, GUID).should_agent_stop()
|
||||
should_stop = self._control_channel.should_agent_stop()
|
||||
if should_stop:
|
||||
logger.info("The Monkey Island has instructed this agent to stop")
|
||||
return
|
||||
|
@ -211,9 +228,6 @@ class InfectionMonkey:
|
|||
if firewall.is_enabled():
|
||||
firewall.add_firewall_rule()
|
||||
|
||||
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||
self._control_channel.register_agent(self._opts.parent)
|
||||
|
||||
config = self._control_channel.get_config()
|
||||
|
||||
relay_port = get_free_tcp_port()
|
||||
|
|
|
@ -10,6 +10,7 @@ from common.agent_event_serializers import (
|
|||
PydanticAgentEventSerializer,
|
||||
)
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.agent_registration_data import AgentRegistrationData
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
|
@ -22,14 +23,25 @@ from infection_monkey.island_api_client import (
|
|||
SERVER = "1.1.1.1:9999"
|
||||
PBA_FILE = "dummy.pba"
|
||||
WINDOWS = "windows"
|
||||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||
AGENT_REGISTRATION = AgentRegistrationData(
|
||||
id=AGENT_ID,
|
||||
machine_hardware_id=1,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server=SERVER,
|
||||
network_interfaces=[],
|
||||
)
|
||||
|
||||
ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
||||
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
||||
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
||||
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
||||
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
||||
|
||||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||
ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents"
|
||||
ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
|
||||
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
|
||||
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
|
||||
|
||||
|
||||
class Event1(AbstractAgentEvent):
|
||||
|
@ -275,3 +287,177 @@ def test_island_api_client_send_events__status_code(island_api_client, status_co
|
|||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
||||
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||
(TimeoutError, IslandAPITimeoutError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__register_agent(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_REGISTER_AGENT_URI, exc=actual_error)
|
||||
island_api_client.register_agent(AGENT_REGISTRATION)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, expected_error",
|
||||
[
|
||||
(401, IslandAPIRequestError),
|
||||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_register_agent__status_code(
|
||||
island_api_client, status_code, expected_error
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.post(ISLAND_REGISTER_AGENT_URI, status_code=status_code)
|
||||
island_api_client.register_agent(AGENT_REGISTRATION)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||
(TimeoutError, IslandAPITimeoutError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_AGENT_STOP_URI, exc=actual_error)
|
||||
island_api_client.should_agent_stop(AGENT_ID)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, expected_error",
|
||||
[
|
||||
(401, IslandAPIRequestError),
|
||||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_should_agent_stop__status_code(
|
||||
island_api_client, status_code, expected_error
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_AGENT_STOP_URI, status_code=status_code)
|
||||
island_api_client.should_agent_stop(AGENT_ID)
|
||||
|
||||
|
||||
def test_island_api_client_should_agent_stop__bad_json(island_api_client):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(IslandAPIRequestFailedError):
|
||||
m.get(ISLAND_AGENT_STOP_URI, content=b"bad")
|
||||
island_api_client.should_agent_stop(AGENT_ID)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||
(TimeoutError, IslandAPITimeoutError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__get_config(island_api_client, actual_error, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_CONFIG_URI, exc=actual_error)
|
||||
island_api_client.get_config()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, expected_error",
|
||||
[
|
||||
(401, IslandAPIRequestError),
|
||||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_get_config__status_code(island_api_client, status_code, expected_error):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_CONFIG_URI, status_code=status_code)
|
||||
island_api_client.get_config()
|
||||
|
||||
|
||||
def test_island_api_client_get_config__bad_json(island_api_client):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(IslandAPIRequestFailedError):
|
||||
m.get(ISLAND_GET_CONFIG_URI, content=b"bad")
|
||||
island_api_client.get_config()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"actual_error, expected_error",
|
||||
[
|
||||
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||
(TimeoutError, IslandAPITimeoutError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client__get_credentials_for_propagation(
|
||||
island_api_client, actual_error, expected_error
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, exc=actual_error)
|
||||
island_api_client.get_credentials_for_propagation()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, expected_error",
|
||||
[
|
||||
(401, IslandAPIRequestError),
|
||||
(501, IslandAPIRequestFailedError),
|
||||
],
|
||||
)
|
||||
def test_island_api_client_get_credentials_for_propagation__status_code(
|
||||
island_api_client, status_code, expected_error
|
||||
):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(expected_error):
|
||||
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, status_code=status_code)
|
||||
island_api_client.get_credentials_for_propagation()
|
||||
|
||||
|
||||
def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_client):
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
with pytest.raises(IslandAPIRequestFailedError):
|
||||
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
|
||||
island_api_client.get_credentials_for_propagation()
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.i_control_channel import IslandCommunicationError
|
||||
from infection_monkey.island_api_client import (
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
IslandAPIRequestError,
|
||||
IslandAPIRequestFailedError,
|
||||
IslandAPITimeoutError,
|
||||
)
|
||||
from infection_monkey.master.control_channel import ControlChannel
|
||||
|
||||
SERVER = "server"
|
||||
AGENT_ID = "agent"
|
||||
CONTROL_CHANNEL_API_ERRORS = [
|
||||
IslandAPIConnectionError,
|
||||
IslandAPIRequestError,
|
||||
IslandAPIRequestFailedError,
|
||||
IslandAPITimeoutError,
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def island_api_client() -> IIslandAPIClient:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def control_channel(island_api_client) -> ControlChannel:
|
||||
return ControlChannel(SERVER, AGENT_ID, island_api_client)
|
||||
|
||||
|
||||
def test_control_channel__should_agent_stop(control_channel, island_api_client):
|
||||
control_channel.should_agent_stop()
|
||||
assert island_api_client.should_agent_stop.called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||
def test_control_channel__should_agent_stop_raises_error(
|
||||
control_channel, island_api_client, api_error
|
||||
):
|
||||
island_api_client.should_agent_stop.side_effect = api_error()
|
||||
|
||||
with pytest.raises(IslandCommunicationError):
|
||||
control_channel.should_agent_stop()
|
||||
|
||||
|
||||
def test_control_channel__get_config(control_channel, island_api_client):
|
||||
control_channel.get_config()
|
||||
assert island_api_client.get_config.called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||
def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error):
|
||||
island_api_client.get_config.side_effect = api_error()
|
||||
|
||||
with pytest.raises(IslandCommunicationError):
|
||||
control_channel.get_config()
|
||||
|
||||
|
||||
def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client):
|
||||
control_channel.get_credentials_for_propagation()
|
||||
assert island_api_client.get_credentials_for_propagation.called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||
def test_control_channel__get_credentials_for_propagation_raises_error(
|
||||
control_channel, island_api_client, api_error
|
||||
):
|
||||
island_api_client.get_credentials_for_propagation.side_effect = api_error()
|
||||
|
||||
with pytest.raises(IslandCommunicationError):
|
||||
control_channel.get_credentials_for_propagation()
|
Loading…
Reference in New Issue