From 54ef77698c391cb21423386d05fddb56671cf863 Mon Sep 17 00:00:00 2001 From: Kekoa Kaaikala Date: Mon, 19 Sep 2022 16:08:42 +0000 Subject: [PATCH] Agent: Add register_agent to IslandAPIClient --- .../http_island_api_client.py | 27 ++++++++++- .../island_api_client/i_island_api_client.py | 14 +++++- .../master/control_channel.py | 24 ++++------ .../master/test_control_channel.py | 45 +++++++++++++++++++ 4 files changed, 92 insertions(+), 18 deletions(-) create mode 100644 monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index a0bebc5a0..407b6562e 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -4,10 +4,14 @@ from typing import List, Sequence import requests -from common import OperatingSystem +from common import AgentRegistrationData, OperatingSystem from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_events import AbstractAgentEvent -from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.common_consts.timeouts import ( + LONG_REQUEST_TIMEOUT, + MEDIUM_REQUEST_TIMEOUT, + SHORT_REQUEST_TIMEOUT, +) from . import ( AbstractIslandAPIClientFactory, @@ -116,6 +120,25 @@ class HTTPIslandAPIClient(IIslandAPIClient): response.raise_for_status() + def register_agent(self, agent_registration_data: AgentRegistrationData): + try: + url = f"https://{agent_registration_data.cc_server}/api/agents" + response = requests.post( # noqa: DUO123 + url, + json=agent_registration_data.dict(simplify=True), + verify=False, + timeout=SHORT_REQUEST_TIMEOUT, + ) + response.raise_for_status() + except ( + requests.exceptions.ConnectionError, + requests.exceptions.TooManyRedirects, + requests.exceptions.HTTPError, + ) as e: + raise IslandAPIConnectionError(e) + except requests.exceptions.Timeout as e: + raise IslandAPITimeoutError(e) + def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = [] diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 5bebc79c1..cc32555dd 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -4,6 +4,8 @@ from typing import Optional, Sequence from common import OperatingSystem from common.agent_events import AbstractAgentEvent +from common import AgentRegistrationData + class IIslandAPIClient(ABC): """ @@ -74,7 +76,6 @@ class IIslandAPIClient(ABC): :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the agent binary - """ @abstractmethod @@ -92,3 +93,14 @@ class IIslandAPIClient(ABC): :raises IslandAPIError: If an unexpected error occurs while attempting to send events to the island """ + + @abstractmethod + def register_agent(self, agent_registration_data: AgentRegistrationData): + """ + Register an agent with the Island + + :param agent_registration_data: Information about the agent to register + with the island + :raises IslandAPIConnectionError: If the client could not connect to the island + :raises IslandAPITimeoutError: If the command timed out + """ diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 76be63b5d..cd4496a9d 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -13,6 +13,11 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.credentials import Credentials from common.network.network_utils import get_network_interfaces from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError +from infection_monkey.island_api_client import ( + IIslandAPIClient, + IslandAPIConnectionError, + IslandAPITimeoutError, +) from infection_monkey.utils import agent_process from infection_monkey.utils.ids import get_agent_id, get_machine_id @@ -22,9 +27,10 @@ logger = logging.getLogger(__name__) class ControlChannel(IControlChannel): - def __init__(self, server: str, agent_id: str): + def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient): self._agent_id = agent_id self._control_channel_server = server + self._island_api_client = api_client def register_agent(self, parent: Optional[UUID] = None): agent_registration_data = AgentRegistrationData( @@ -38,20 +44,8 @@ class ControlChannel(IControlChannel): ) try: - url = f"https://{self._control_channel_server}/api/agents" - response = requests.post( # noqa: DUO123 - url, - json=agent_registration_data.dict(simplify=True), - verify=False, - timeout=SHORT_REQUEST_TIMEOUT, - ) - response.raise_for_status() - except ( - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.TooManyRedirects, - requests.exceptions.HTTPError, - ) as e: + self._island_api_client.register_agent(agent_registration_data) + except (IslandAPIConnectionError, IslandAPITimeoutError) as e: raise IslandCommunicationError(e) def should_agent_stop(self) -> bool: diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py new file mode 100644 index 000000000..75a3eb149 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +import pytest + +from infection_monkey.i_control_channel import IslandCommunicationError +from infection_monkey.island_api_client import ( + IIslandAPIClient, + IslandAPIConnectionError, + IslandAPITimeoutError, +) +from infection_monkey.master.control_channel import ControlChannel + + +@pytest.fixture +def island_api_client() -> IIslandAPIClient: + client = MagicMock() + return client + + +@pytest.fixture +def control_channel(island_api_client) -> ControlChannel: + return ControlChannel("server", "agent-id", island_api_client) + + +def test_control_channel__register_agent(control_channel, island_api_client): + control_channel.register_agent() + assert island_api_client.register_agent.called_once() + + +def test_control_channel__register_agent_raises_on_connection_error( + control_channel, island_api_client +): + island_api_client.register_agent.side_effect = IslandAPIConnectionError() + + with pytest.raises(IslandCommunicationError): + control_channel.register_agent() + + +def test_control_channel__register_agent_raises_on_timeout_error( + control_channel, island_api_client +): + island_api_client.register_agent.side_effect = IslandAPITimeoutError() + + with pytest.raises(IslandCommunicationError): + control_channel.register_agent()