From 54ef77698c391cb21423386d05fddb56671cf863 Mon Sep 17 00:00:00 2001
From: Kekoa Kaaikala <kekoa.kaaikala@gmail.com>
Date: Mon, 19 Sep 2022 16:08:42 +0000
Subject: [PATCH] Agent: Add register_agent to IslandAPIClient

---
 .../http_island_api_client.py                 | 27 ++++++++++-
 .../island_api_client/i_island_api_client.py  | 14 +++++-
 .../master/control_channel.py                 | 24 ++++------
 .../master/test_control_channel.py            | 45 +++++++++++++++++++
 4 files changed, 92 insertions(+), 18 deletions(-)
 create mode 100644 monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py

diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py
index a0bebc5a0..407b6562e 100644
--- a/monkey/infection_monkey/island_api_client/http_island_api_client.py
+++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py
@@ -4,10 +4,14 @@ from typing import List, Sequence
 
 import requests
 
-from common import OperatingSystem
+from common import AgentRegistrationData, OperatingSystem
 from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
 from common.agent_events import AbstractAgentEvent
-from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
+from common.common_consts.timeouts import (
+    LONG_REQUEST_TIMEOUT,
+    MEDIUM_REQUEST_TIMEOUT,
+    SHORT_REQUEST_TIMEOUT,
+)
 
 from . import (
     AbstractIslandAPIClientFactory,
@@ -116,6 +120,25 @@ class HTTPIslandAPIClient(IIslandAPIClient):
 
         response.raise_for_status()
 
+    def register_agent(self, agent_registration_data: AgentRegistrationData):
+        try:
+            url = f"https://{agent_registration_data.cc_server}/api/agents"
+            response = requests.post(  # noqa: DUO123
+                url,
+                json=agent_registration_data.dict(simplify=True),
+                verify=False,
+                timeout=SHORT_REQUEST_TIMEOUT,
+            )
+            response.raise_for_status()
+        except (
+            requests.exceptions.ConnectionError,
+            requests.exceptions.TooManyRedirects,
+            requests.exceptions.HTTPError,
+        ) as e:
+            raise IslandAPIConnectionError(e)
+        except requests.exceptions.Timeout as e:
+            raise IslandAPITimeoutError(e)
+
     def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
         serialized_events: List[JSONSerializable] = []
 
diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py
index 5bebc79c1..cc32555dd 100644
--- a/monkey/infection_monkey/island_api_client/i_island_api_client.py
+++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py
@@ -4,6 +4,8 @@ from typing import Optional, Sequence
 from common import OperatingSystem
 from common.agent_events import AbstractAgentEvent
 
+from common import AgentRegistrationData
+
 
 class IIslandAPIClient(ABC):
     """
@@ -74,7 +76,6 @@ class IIslandAPIClient(ABC):
         :raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
         :raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
                                 agent binary
-
         """
 
     @abstractmethod
@@ -92,3 +93,14 @@ class IIslandAPIClient(ABC):
         :raises IslandAPIError: If an unexpected error occurs while attempting to send events to
                                 the island
         """
+
+    @abstractmethod
+    def register_agent(self, agent_registration_data: AgentRegistrationData):
+        """
+        Register an agent with the Island
+
+        :param agent_registration_data: Information about the agent to register
+            with the island
+        :raises IslandAPIConnectionError: If the client could not connect to the island
+        :raises IslandAPITimeoutError: If the command timed out
+        """
diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py
index 76be63b5d..cd4496a9d 100644
--- a/monkey/infection_monkey/master/control_channel.py
+++ b/monkey/infection_monkey/master/control_channel.py
@@ -13,6 +13,11 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
 from common.credentials import Credentials
 from common.network.network_utils import get_network_interfaces
 from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
+from infection_monkey.island_api_client import (
+    IIslandAPIClient,
+    IslandAPIConnectionError,
+    IslandAPITimeoutError,
+)
 from infection_monkey.utils import agent_process
 from infection_monkey.utils.ids import get_agent_id, get_machine_id
 
@@ -22,9 +27,10 @@ logger = logging.getLogger(__name__)
 
 
 class ControlChannel(IControlChannel):
-    def __init__(self, server: str, agent_id: str):
+    def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient):
         self._agent_id = agent_id
         self._control_channel_server = server
+        self._island_api_client = api_client
 
     def register_agent(self, parent: Optional[UUID] = None):
         agent_registration_data = AgentRegistrationData(
@@ -38,20 +44,8 @@ class ControlChannel(IControlChannel):
         )
 
         try:
-            url = f"https://{self._control_channel_server}/api/agents"
-            response = requests.post(  # noqa: DUO123
-                url,
-                json=agent_registration_data.dict(simplify=True),
-                verify=False,
-                timeout=SHORT_REQUEST_TIMEOUT,
-            )
-            response.raise_for_status()
-        except (
-            requests.exceptions.ConnectionError,
-            requests.exceptions.Timeout,
-            requests.exceptions.TooManyRedirects,
-            requests.exceptions.HTTPError,
-        ) as e:
+            self._island_api_client.register_agent(agent_registration_data)
+        except (IslandAPIConnectionError, IslandAPITimeoutError) as e:
             raise IslandCommunicationError(e)
 
     def should_agent_stop(self) -> bool:
diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py
new file mode 100644
index 000000000..75a3eb149
--- /dev/null
+++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py
@@ -0,0 +1,45 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from infection_monkey.i_control_channel import IslandCommunicationError
+from infection_monkey.island_api_client import (
+    IIslandAPIClient,
+    IslandAPIConnectionError,
+    IslandAPITimeoutError,
+)
+from infection_monkey.master.control_channel import ControlChannel
+
+
+@pytest.fixture
+def island_api_client() -> IIslandAPIClient:
+    client = MagicMock()
+    return client
+
+
+@pytest.fixture
+def control_channel(island_api_client) -> ControlChannel:
+    return ControlChannel("server", "agent-id", island_api_client)
+
+
+def test_control_channel__register_agent(control_channel, island_api_client):
+    control_channel.register_agent()
+    assert island_api_client.register_agent.called_once()
+
+
+def test_control_channel__register_agent_raises_on_connection_error(
+    control_channel, island_api_client
+):
+    island_api_client.register_agent.side_effect = IslandAPIConnectionError()
+
+    with pytest.raises(IslandCommunicationError):
+        control_channel.register_agent()
+
+
+def test_control_channel__register_agent_raises_on_timeout_error(
+    control_channel, island_api_client
+):
+    island_api_client.register_agent.side_effect = IslandAPITimeoutError()
+
+    with pytest.raises(IslandCommunicationError):
+        control_channel.register_agent()