forked from p15670423/monkey
Merge pull request #2324 from guardicore/2292-control-channel-client-api-client
2292 control channel client api client
This commit is contained in:
commit
f472963b78
|
@ -1,23 +1,11 @@
|
||||||
import abc
|
import abc
|
||||||
from typing import Optional, Sequence
|
from typing import Sequence
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from common.agent_configuration import AgentConfiguration
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
|
|
||||||
|
|
||||||
class IControlChannel(metaclass=abc.ABCMeta):
|
class IControlChannel(metaclass=abc.ABCMeta):
|
||||||
@abc.abstractmethod
|
|
||||||
def register_agent(self, parent_id: Optional[UUID] = None):
|
|
||||||
"""
|
|
||||||
Registers this agent with the Island when this agent starts
|
|
||||||
|
|
||||||
:param parent: The ID of the parent that spawned this agent, or None if this agent has no
|
|
||||||
parent
|
|
||||||
:raises IslandCommunicationError: If the agent cannot be successfully registered
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def should_agent_stop(self) -> bool:
|
def should_agent_stop(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,13 +1,21 @@
|
||||||
import functools
|
import functools
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from pprint import pformat
|
||||||
from typing import List, Sequence
|
from typing import List, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import AgentRegistrationData, OperatingSystem
|
||||||
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import (
|
||||||
|
LONG_REQUEST_TIMEOUT,
|
||||||
|
MEDIUM_REQUEST_TIMEOUT,
|
||||||
|
SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
from common.credentials import Credentials
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
AbstractIslandAPIClientFactory,
|
AbstractIslandAPIClientFactory,
|
||||||
|
@ -27,7 +35,9 @@ def handle_island_errors(fn):
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
except requests.exceptions.ConnectionError as err:
|
except IslandAPIError as err:
|
||||||
|
raise err
|
||||||
|
except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err:
|
||||||
raise IslandAPIConnectionError(err)
|
raise IslandAPIConnectionError(err)
|
||||||
except requests.exceptions.HTTPError as err:
|
except requests.exceptions.HTTPError as err:
|
||||||
if 400 <= err.response.status_code < 500:
|
if 400 <= err.response.status_code < 500:
|
||||||
|
@ -46,6 +56,17 @@ def handle_island_errors(fn):
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def convert_json_error_to_island_api_error(fn):
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
except json.JSONDecodeError as err:
|
||||||
|
raise IslandAPIRequestFailedError(err)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class HTTPIslandAPIClient(IIslandAPIClient):
|
class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
"""
|
"""
|
||||||
A client for the Island's HTTP API
|
A client for the Island's HTTP API
|
||||||
|
@ -116,6 +137,58 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
url = f"{self._api_url}/agents"
|
||||||
|
response = requests.post( # noqa: DUO123
|
||||||
|
url,
|
||||||
|
json=agent_registration_data.dict(simplify=True),
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
@convert_json_error_to_island_api_error
|
||||||
|
def should_agent_stop(self, agent_id: str) -> bool:
|
||||||
|
url = f"{self._api_url}/monkey-control/needs-to-stop/{agent_id}"
|
||||||
|
response = requests.get( # noqa: DUO123
|
||||||
|
url,
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return response.json()["stop_agent"]
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
@convert_json_error_to_island_api_error
|
||||||
|
def get_config(self) -> AgentConfiguration:
|
||||||
|
response = requests.get( # noqa: DUO123
|
||||||
|
f"{self._api_url}/agent-configuration",
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
config_dict = response.json()
|
||||||
|
|
||||||
|
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
|
||||||
|
|
||||||
|
return AgentConfiguration(**config_dict)
|
||||||
|
|
||||||
|
@handle_island_errors
|
||||||
|
@convert_json_error_to_island_api_error
|
||||||
|
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||||
|
response = requests.get( # noqa: DUO123
|
||||||
|
f"{self._api_url}/propagation-credentials",
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return [Credentials(**credentials) for credentials in response.json()]
|
||||||
|
|
||||||
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
||||||
serialized_events: List[JSONSerializable] = []
|
serialized_events: List[JSONSerializable] = []
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import AgentRegistrationData, OperatingSystem
|
||||||
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
from common.credentials import Credentials
|
||||||
|
|
||||||
|
|
||||||
class IIslandAPIClient(ABC):
|
class IIslandAPIClient(ABC):
|
||||||
|
@ -13,7 +15,7 @@ class IIslandAPIClient(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def connect(self, island_server: str):
|
def connect(self, island_server: str):
|
||||||
"""
|
"""
|
||||||
Connectto the island's API
|
Connect to the island's API
|
||||||
|
|
||||||
:param island_server: The socket address of the API
|
:param island_server: The socket address of the API
|
||||||
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
:raises IslandAPIConnectionError: If the client cannot successfully connect to the island
|
||||||
|
@ -74,7 +76,6 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
||||||
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
||||||
agent binary
|
agent binary
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -92,3 +93,53 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
||||||
the island
|
the island
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
"""
|
||||||
|
Register an agent with the Island
|
||||||
|
|
||||||
|
:param agent_registration_data: Information about the agent to register
|
||||||
|
with the island
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||||
|
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_agent_stop(self, agent_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check with the island to see if the agent should stop
|
||||||
|
|
||||||
|
:param agent_id: The agent identifier for the agent to check
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||||
|
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
:return: True if the agent should stop, otherwise False
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_config(self) -> AgentConfiguration:
|
||||||
|
"""
|
||||||
|
Get agent configuration from the island
|
||||||
|
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||||
|
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
:return: Agent configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||||
|
"""
|
||||||
|
Get credentials from the island
|
||||||
|
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||||
|
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
:return: Credentials
|
||||||
|
"""
|
||||||
|
|
|
@ -1,127 +1,47 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from pprint import pformat
|
from functools import wraps
|
||||||
from typing import Optional, Sequence
|
from typing import Sequence
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from urllib3 import disable_warnings
|
from urllib3 import disable_warnings
|
||||||
|
|
||||||
from common import AgentRegistrationData
|
|
||||||
from common.agent_configuration import AgentConfiguration
|
from common.agent_configuration import AgentConfiguration
|
||||||
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
from common.network.network_utils import get_network_interfaces
|
|
||||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||||
from infection_monkey.utils import agent_process
|
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError
|
||||||
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
|
||||||
|
|
||||||
disable_warnings() # noqa: DUO131
|
disable_warnings() # noqa: DUO131
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_island_api_errors(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
except IslandAPIError as err:
|
||||||
|
raise IslandCommunicationError(err)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class ControlChannel(IControlChannel):
|
class ControlChannel(IControlChannel):
|
||||||
def __init__(self, server: str, agent_id: str):
|
def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient):
|
||||||
self._agent_id = agent_id
|
self._agent_id = agent_id
|
||||||
self._control_channel_server = server
|
self._control_channel_server = server
|
||||||
|
self._island_api_client = api_client
|
||||||
|
|
||||||
def register_agent(self, parent: Optional[UUID] = None):
|
@handle_island_api_errors
|
||||||
agent_registration_data = AgentRegistrationData(
|
|
||||||
id=get_agent_id(),
|
|
||||||
machine_hardware_id=get_machine_id(),
|
|
||||||
start_time=agent_process.get_start_time(),
|
|
||||||
# parent_id=parent,
|
|
||||||
parent_id=None, # None for now, until we change GUID to UUID
|
|
||||||
cc_server=self._control_channel_server,
|
|
||||||
network_interfaces=get_network_interfaces(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
url = f"https://{self._control_channel_server}/api/agents"
|
|
||||||
response = requests.post( # noqa: DUO123
|
|
||||||
url,
|
|
||||||
json=agent_registration_data.dict(simplify=True),
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
except (
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
requests.exceptions.Timeout,
|
|
||||||
requests.exceptions.TooManyRedirects,
|
|
||||||
requests.exceptions.HTTPError,
|
|
||||||
) as e:
|
|
||||||
raise IslandCommunicationError(e)
|
|
||||||
|
|
||||||
def should_agent_stop(self) -> bool:
|
def should_agent_stop(self) -> bool:
|
||||||
if not self._control_channel_server:
|
if not self._control_channel_server:
|
||||||
logger.error("Agent should stop because it can't connect to the C&C server.")
|
logger.error("Agent should stop because it can't connect to the C&C server.")
|
||||||
return True
|
return True
|
||||||
try:
|
return self._island_api_client.should_agent_stop(self._agent_id)
|
||||||
url = (
|
|
||||||
f"https://{self._control_channel_server}/api/monkey-control"
|
|
||||||
f"/needs-to-stop/{self._agent_id}"
|
|
||||||
)
|
|
||||||
response = requests.get( # noqa: DUO123
|
|
||||||
url,
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
json_response = json.loads(response.content.decode())
|
|
||||||
return json_response["stop_agent"]
|
|
||||||
except (
|
|
||||||
json.JSONDecodeError,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
requests.exceptions.Timeout,
|
|
||||||
requests.exceptions.TooManyRedirects,
|
|
||||||
requests.exceptions.HTTPError,
|
|
||||||
) as e:
|
|
||||||
raise IslandCommunicationError(e)
|
|
||||||
|
|
||||||
|
@handle_island_api_errors
|
||||||
def get_config(self) -> AgentConfiguration:
|
def get_config(self) -> AgentConfiguration:
|
||||||
try:
|
return self._island_api_client.get_config()
|
||||||
response = requests.get( # noqa: DUO123
|
|
||||||
f"https://{self._control_channel_server}/api/agent-configuration",
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
config_dict = json.loads(response.text)
|
|
||||||
|
|
||||||
logger.debug(f"Received configuration:\n{pformat(config_dict)}")
|
|
||||||
|
|
||||||
return AgentConfiguration(**config_dict)
|
|
||||||
except (
|
|
||||||
json.JSONDecodeError,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
requests.exceptions.Timeout,
|
|
||||||
requests.exceptions.TooManyRedirects,
|
|
||||||
requests.exceptions.HTTPError,
|
|
||||||
) as e:
|
|
||||||
raise IslandCommunicationError(e)
|
|
||||||
|
|
||||||
|
@handle_island_api_errors
|
||||||
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
|
||||||
propagation_credentials_url = (
|
return self._island_api_client.get_credentials_for_propagation()
|
||||||
f"https://{self._control_channel_server}/api/propagation-credentials"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = requests.get( # noqa: DUO123
|
|
||||||
propagation_credentials_url,
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return [Credentials(**credentials) for credentials in response.json()]
|
|
||||||
except (
|
|
||||||
requests.exceptions.JSONDecodeError,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
requests.exceptions.Timeout,
|
|
||||||
requests.exceptions.TooManyRedirects,
|
|
||||||
requests.exceptions.HTTPError,
|
|
||||||
) as e:
|
|
||||||
raise IslandCommunicationError(e)
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ from common.agent_event_serializers import (
|
||||||
register_common_agent_event_serializers,
|
register_common_agent_event_serializers,
|
||||||
)
|
)
|
||||||
from common.agent_events import CredentialsStolenEvent
|
from common.agent_events import CredentialsStolenEvent
|
||||||
|
from common.agent_registration_data import AgentRegistrationData
|
||||||
from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue
|
from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue
|
||||||
from common.network.network_utils import (
|
from common.network.network_utils import (
|
||||||
address_to_ip_port,
|
address_to_ip_port,
|
||||||
|
@ -88,9 +89,11 @@ from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter im
|
||||||
LegacyTelemetryMessengerAdapter,
|
LegacyTelemetryMessengerAdapter,
|
||||||
)
|
)
|
||||||
from infection_monkey.telemetry.state_telem import StateTelem
|
from infection_monkey.telemetry.state_telem import StateTelem
|
||||||
|
from infection_monkey.utils import agent_process
|
||||||
from infection_monkey.utils.aws_environment_check import run_aws_environment_check
|
from infection_monkey.utils.aws_environment_check import run_aws_environment_check
|
||||||
from infection_monkey.utils.environment import is_windows_os
|
from infection_monkey.utils.environment import is_windows_os
|
||||||
from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows
|
from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows
|
||||||
|
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
||||||
from infection_monkey.utils.monkey_dir import (
|
from infection_monkey.utils.monkey_dir import (
|
||||||
create_monkey_dir,
|
create_monkey_dir,
|
||||||
get_monkey_dir_path,
|
get_monkey_dir_path,
|
||||||
|
@ -120,6 +123,8 @@ class InfectionMonkey:
|
||||||
self._control_client = ControlClient(
|
self._control_client = ControlClient(
|
||||||
server_address=server, island_api_client=self._island_api_client
|
server_address=server, island_api_client=self._island_api_client
|
||||||
)
|
)
|
||||||
|
self._control_channel = ControlChannel(server, GUID, self._island_api_client)
|
||||||
|
self._register_agent()
|
||||||
|
|
||||||
# TODO Refactor the telemetry messengers to accept control client
|
# TODO Refactor the telemetry messengers to accept control client
|
||||||
# and remove control_client_object
|
# and remove control_client_object
|
||||||
|
@ -165,6 +170,18 @@ class InfectionMonkey:
|
||||||
|
|
||||||
return server, island_api_client
|
return server, island_api_client
|
||||||
|
|
||||||
|
def _register_agent(self, server: str):
|
||||||
|
agent_registration_data = AgentRegistrationData(
|
||||||
|
id=get_agent_id(),
|
||||||
|
machine_hardware_id=get_machine_id(),
|
||||||
|
start_time=agent_process.get_start_time(),
|
||||||
|
# parent_id=parent,
|
||||||
|
parent_id=None, # None for now, until we change GUID to UUID
|
||||||
|
cc_server=server,
|
||||||
|
network_interfaces=get_network_interfaces(),
|
||||||
|
)
|
||||||
|
self._island_api_client.register_agent(agent_registration_data)
|
||||||
|
|
||||||
def _select_server(
|
def _select_server(
|
||||||
self, server_clients: Mapping[str, IIslandAPIClient]
|
self, server_clients: Mapping[str, IIslandAPIClient]
|
||||||
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
|
||||||
|
@ -195,7 +212,7 @@ class InfectionMonkey:
|
||||||
|
|
||||||
run_aws_environment_check(self._telemetry_messenger)
|
run_aws_environment_check(self._telemetry_messenger)
|
||||||
|
|
||||||
should_stop = ControlChannel(self._control_client.server_address, GUID).should_agent_stop()
|
should_stop = self._control_channel.should_agent_stop()
|
||||||
if should_stop:
|
if should_stop:
|
||||||
logger.info("The Monkey Island has instructed this agent to stop")
|
logger.info("The Monkey Island has instructed this agent to stop")
|
||||||
return
|
return
|
||||||
|
@ -211,9 +228,6 @@ class InfectionMonkey:
|
||||||
if firewall.is_enabled():
|
if firewall.is_enabled():
|
||||||
firewall.add_firewall_rule()
|
firewall.add_firewall_rule()
|
||||||
|
|
||||||
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
|
||||||
self._control_channel.register_agent(self._opts.parent)
|
|
||||||
|
|
||||||
config = self._control_channel.get_config()
|
config = self._control_channel.get_config()
|
||||||
|
|
||||||
relay_port = get_free_tcp_port()
|
relay_port = get_free_tcp_port()
|
||||||
|
|
|
@ -10,6 +10,7 @@ from common.agent_event_serializers import (
|
||||||
PydanticAgentEventSerializer,
|
PydanticAgentEventSerializer,
|
||||||
)
|
)
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
from common.agent_registration_data import AgentRegistrationData
|
||||||
from infection_monkey.island_api_client import (
|
from infection_monkey.island_api_client import (
|
||||||
HTTPIslandAPIClient,
|
HTTPIslandAPIClient,
|
||||||
IslandAPIConnectionError,
|
IslandAPIConnectionError,
|
||||||
|
@ -22,14 +23,25 @@ from infection_monkey.island_api_client import (
|
||||||
SERVER = "1.1.1.1:9999"
|
SERVER = "1.1.1.1:9999"
|
||||||
PBA_FILE = "dummy.pba"
|
PBA_FILE = "dummy.pba"
|
||||||
WINDOWS = "windows"
|
WINDOWS = "windows"
|
||||||
|
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||||
|
AGENT_REGISTRATION = AgentRegistrationData(
|
||||||
|
id=AGENT_ID,
|
||||||
|
machine_hardware_id=1,
|
||||||
|
start_time=0,
|
||||||
|
parent_id=None,
|
||||||
|
cc_server=SERVER,
|
||||||
|
network_interfaces=[],
|
||||||
|
)
|
||||||
|
|
||||||
ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
ISLAND_URI = f"https://{SERVER}/api?action=is-up"
|
||||||
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/log"
|
||||||
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
ISLAND_GET_PBA_FILE_URI = f"https://{SERVER}/api/pba/download/{PBA_FILE}"
|
||||||
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
ISLAND_GET_AGENT_BINARY_URI = f"https://{SERVER}/api/agent-binaries/{WINDOWS}"
|
||||||
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
ISLAND_SEND_EVENTS_URI = f"https://{SERVER}/api/agent-events"
|
||||||
|
ISLAND_REGISTER_AGENT_URI = f"https://{SERVER}/api/agents"
|
||||||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
ISLAND_AGENT_STOP_URI = f"https://{SERVER}/api/monkey-control/needs-to-stop/{AGENT_ID}"
|
||||||
|
ISLAND_GET_CONFIG_URI = f"https://{SERVER}/api/agent-configuration"
|
||||||
|
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/propagation-credentials"
|
||||||
|
|
||||||
|
|
||||||
class Event1(AbstractAgentEvent):
|
class Event1(AbstractAgentEvent):
|
||||||
|
@ -275,3 +287,177 @@ def test_island_api_client_send_events__status_code(island_api_client, status_co
|
||||||
with pytest.raises(expected_error):
|
with pytest.raises(expected_error):
|
||||||
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
m.post(ISLAND_SEND_EVENTS_URI, status_code=status_code)
|
||||||
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
island_api_client.send_events(events=[Event1(source=AGENT_ID, a=1)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__register_agent(island_api_client, actual_error, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.post(ISLAND_REGISTER_AGENT_URI, exc=actual_error)
|
||||||
|
island_api_client.register_agent(AGENT_REGISTRATION)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_register_agent__status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.post(ISLAND_REGISTER_AGENT_URI, status_code=status_code)
|
||||||
|
island_api_client.register_agent(AGENT_REGISTRATION)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__should_agent_stop(island_api_client, actual_error, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_AGENT_STOP_URI, exc=actual_error)
|
||||||
|
island_api_client.should_agent_stop(AGENT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_should_agent_stop__status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_AGENT_STOP_URI, status_code=status_code)
|
||||||
|
island_api_client.should_agent_stop(AGENT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_should_agent_stop__bad_json(island_api_client):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(IslandAPIRequestFailedError):
|
||||||
|
m.get(ISLAND_AGENT_STOP_URI, content=b"bad")
|
||||||
|
island_api_client.should_agent_stop(AGENT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__get_config(island_api_client, actual_error, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_CONFIG_URI, exc=actual_error)
|
||||||
|
island_api_client.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_get_config__status_code(island_api_client, status_code, expected_error):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_CONFIG_URI, status_code=status_code)
|
||||||
|
island_api_client.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_get_config__bad_json(island_api_client):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(IslandAPIRequestFailedError):
|
||||||
|
m.get(ISLAND_GET_CONFIG_URI, content=b"bad")
|
||||||
|
island_api_client.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"actual_error, expected_error",
|
||||||
|
[
|
||||||
|
(requests.exceptions.ConnectionError, IslandAPIConnectionError),
|
||||||
|
(TimeoutError, IslandAPITimeoutError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client__get_credentials_for_propagation(
|
||||||
|
island_api_client, actual_error, expected_error
|
||||||
|
):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, exc=actual_error)
|
||||||
|
island_api_client.get_credentials_for_propagation()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code, expected_error",
|
||||||
|
[
|
||||||
|
(401, IslandAPIRequestError),
|
||||||
|
(501, IslandAPIRequestFailedError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_island_api_client_get_credentials_for_propagation__status_code(
|
||||||
|
island_api_client, status_code, expected_error
|
||||||
|
):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(expected_error):
|
||||||
|
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, status_code=status_code)
|
||||||
|
island_api_client.get_credentials_for_propagation()
|
||||||
|
|
||||||
|
|
||||||
|
def test_island_api_client_get_credentials_for_propagation__bad_json(island_api_client):
|
||||||
|
with requests_mock.Mocker() as m:
|
||||||
|
m.get(ISLAND_URI)
|
||||||
|
island_api_client.connect(SERVER)
|
||||||
|
|
||||||
|
with pytest.raises(IslandAPIRequestFailedError):
|
||||||
|
m.get(ISLAND_GET_PROPAGATION_CREDENTIALS_URI, content=b"bad")
|
||||||
|
island_api_client.get_credentials_for_propagation()
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.i_control_channel import IslandCommunicationError
|
||||||
|
from infection_monkey.island_api_client import (
|
||||||
|
IIslandAPIClient,
|
||||||
|
IslandAPIConnectionError,
|
||||||
|
IslandAPIRequestError,
|
||||||
|
IslandAPIRequestFailedError,
|
||||||
|
IslandAPITimeoutError,
|
||||||
|
)
|
||||||
|
from infection_monkey.master.control_channel import ControlChannel
|
||||||
|
|
||||||
|
SERVER = "server"
|
||||||
|
AGENT_ID = "agent"
|
||||||
|
CONTROL_CHANNEL_API_ERRORS = [
|
||||||
|
IslandAPIConnectionError,
|
||||||
|
IslandAPIRequestError,
|
||||||
|
IslandAPIRequestFailedError,
|
||||||
|
IslandAPITimeoutError,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def island_api_client() -> IIslandAPIClient:
|
||||||
|
client = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def control_channel(island_api_client) -> ControlChannel:
|
||||||
|
return ControlChannel(SERVER, AGENT_ID, island_api_client)
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__should_agent_stop(control_channel, island_api_client):
|
||||||
|
control_channel.should_agent_stop()
|
||||||
|
assert island_api_client.should_agent_stop.called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||||
|
def test_control_channel__should_agent_stop_raises_error(
|
||||||
|
control_channel, island_api_client, api_error
|
||||||
|
):
|
||||||
|
island_api_client.should_agent_stop.side_effect = api_error()
|
||||||
|
|
||||||
|
with pytest.raises(IslandCommunicationError):
|
||||||
|
control_channel.should_agent_stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__get_config(control_channel, island_api_client):
|
||||||
|
control_channel.get_config()
|
||||||
|
assert island_api_client.get_config.called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||||
|
def test_control_channel__get_config_raises_error(control_channel, island_api_client, api_error):
|
||||||
|
island_api_client.get_config.side_effect = api_error()
|
||||||
|
|
||||||
|
with pytest.raises(IslandCommunicationError):
|
||||||
|
control_channel.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__get_credentials_for_propagation(control_channel, island_api_client):
|
||||||
|
control_channel.get_credentials_for_propagation()
|
||||||
|
assert island_api_client.get_credentials_for_propagation.called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||||
|
def test_control_channel__get_credentials_for_propagation_raises_error(
|
||||||
|
control_channel, island_api_client, api_error
|
||||||
|
):
|
||||||
|
island_api_client.get_credentials_for_propagation.side_effect = api_error()
|
||||||
|
|
||||||
|
with pytest.raises(IslandCommunicationError):
|
||||||
|
control_channel.get_credentials_for_propagation()
|
Loading…
Reference in New Issue