From 1c6ca24a4755be5a241c3a3544965e83050c25b2 Mon Sep 17 00:00:00 2001
From: Kekoa Kaaikala <kekoa.kaaikala@gmail.com>
Date: Tue, 20 Sep 2022 18:21:55 +0000
Subject: [PATCH] Agent: Move register_agent out of ControlChannel

---
 monkey/infection_monkey/i_control_channel.py  | 14 +------------
 .../master/control_channel.py                 | 21 +------------------
 monkey/infection_monkey/monkey.py             | 18 ++++++++++++++--
 .../master/test_control_channel.py            | 15 -------------
 4 files changed, 18 insertions(+), 50 deletions(-)

diff --git a/monkey/infection_monkey/i_control_channel.py b/monkey/infection_monkey/i_control_channel.py
index 39075750a..25135231f 100644
--- a/monkey/infection_monkey/i_control_channel.py
+++ b/monkey/infection_monkey/i_control_channel.py
@@ -1,23 +1,11 @@
 import abc
-from typing import Optional, Sequence
-from uuid import UUID
+from typing import Sequence
 
 from common.agent_configuration import AgentConfiguration
 from common.credentials import Credentials
 
 
 class IControlChannel(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def register_agent(self, parent_id: Optional[UUID] = None):
-        """
-        Registers this agent with the Island when this agent starts
-
-        :param parent: The ID of the parent that spawned this agent, or None if this agent has no
-                       parent
-        :raises IslandCommunicationError: If the agent cannot be successfully registered
-        """
-        pass
-
     @abc.abstractmethod
     def should_agent_stop(self) -> bool:
         """
diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py
index 7a2dd547a..eddeb8090 100644
--- a/monkey/infection_monkey/master/control_channel.py
+++ b/monkey/infection_monkey/master/control_channel.py
@@ -1,14 +1,11 @@
 import logging
 from functools import wraps
-from typing import Optional, Sequence
-from uuid import UUID
+from typing import Sequence
 
 from urllib3 import disable_warnings
 
-from common import AgentRegistrationData
 from common.agent_configuration import AgentConfiguration
 from common.credentials import Credentials
-from common.network.network_utils import get_network_interfaces
 from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
 from infection_monkey.island_api_client import (
     IIslandAPIClient,
@@ -17,8 +14,6 @@ from infection_monkey.island_api_client import (
     IslandAPIRequestFailedError,
     IslandAPITimeoutError,
 )
-from infection_monkey.utils import agent_process
-from infection_monkey.utils.ids import get_agent_id, get_machine_id
 
 disable_warnings()  # noqa: DUO131
 
@@ -48,20 +43,6 @@ class ControlChannel(IControlChannel):
         self._control_channel_server = server
         self._island_api_client = api_client
 
-    @handle_island_api_errors
-    def register_agent(self, parent: Optional[UUID] = None):
-        agent_registration_data = AgentRegistrationData(
-            id=get_agent_id(),
-            machine_hardware_id=get_machine_id(),
-            start_time=agent_process.get_start_time(),
-            # parent_id=parent,
-            parent_id=None,  # None for now, until we change GUID to UUID
-            cc_server=self._control_channel_server,
-            network_interfaces=get_network_interfaces(),
-        )
-
-        self._island_api_client.register_agent(agent_registration_data)
-
     @handle_island_api_errors
     def should_agent_stop(self) -> bool:
         if not self._control_channel_server:
diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py
index 403617721..460c07592 100644
--- a/monkey/infection_monkey/monkey.py
+++ b/monkey/infection_monkey/monkey.py
@@ -14,6 +14,7 @@ from common.agent_event_serializers import (
     register_common_agent_event_serializers,
 )
 from common.agent_events import CredentialsStolenEvent
+from common.agent_registration_data import AgentRegistrationData
 from common.event_queue import IAgentEventQueue, PyPubSubAgentEventQueue
 from common.network.network_utils import (
     address_to_ip_port,
@@ -88,9 +89,11 @@ from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter im
     LegacyTelemetryMessengerAdapter,
 )
 from infection_monkey.telemetry.state_telem import StateTelem
+from infection_monkey.utils import agent_process
 from infection_monkey.utils.aws_environment_check import run_aws_environment_check
 from infection_monkey.utils.environment import is_windows_os
 from infection_monkey.utils.file_utils import mark_file_for_deletion_on_windows
+from infection_monkey.utils.ids import get_agent_id, get_machine_id
 from infection_monkey.utils.monkey_dir import (
     create_monkey_dir,
     get_monkey_dir_path,
@@ -121,6 +124,7 @@ class InfectionMonkey:
             server_address=server, island_api_client=self._island_api_client
         )
         self._control_channel = ControlChannel(server, GUID, self._island_api_client)
+        self._register_agent()
 
         # TODO Refactor the telemetry messengers to accept control client
         # and remove control_client_object
@@ -166,6 +170,18 @@ class InfectionMonkey:
 
         return server, island_api_client
 
+    def _register_agent(self, server: str):
+        agent_registration_data = AgentRegistrationData(
+            id=get_agent_id(),
+            machine_hardware_id=get_machine_id(),
+            start_time=agent_process.get_start_time(),
+            # parent_id=parent,
+            parent_id=None,  # None for now, until we change GUID to UUID
+            cc_server=server,
+            network_interfaces=get_network_interfaces(),
+        )
+        self._island_api_client.register_agent(agent_registration_data)
+
     def _select_server(
         self, server_clients: Mapping[str, IIslandAPIClient]
     ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
@@ -212,8 +228,6 @@ class InfectionMonkey:
         if firewall.is_enabled():
             firewall.add_firewall_rule()
 
-        self._control_channel.register_agent(self._opts.parent)
-
         config = self._control_channel.get_config()
 
         relay_port = get_free_tcp_port()
diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py
index 2521442c6..658635615 100644
--- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py
+++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py
@@ -33,21 +33,6 @@ def control_channel(island_api_client) -> ControlChannel:
     return ControlChannel(SERVER, AGENT_ID, island_api_client)
 
 
-def test_control_channel__register_agent(control_channel, island_api_client):
-    control_channel.register_agent()
-    assert island_api_client.register_agent.called_once()
-
-
-@pytest.mark.parametrize("api_error", [IslandAPIConnectionError, IslandAPITimeoutError])
-def test_control_channel__register_agent_raises_error(
-    control_channel, island_api_client, api_error
-):
-    island_api_client.register_agent.side_effect = api_error()
-
-    with pytest.raises(IslandCommunicationError):
-        control_channel.register_agent()
-
-
 def test_control_channel__should_agent_stop(control_channel, island_api_client):
     control_channel.should_agent_stop()
     assert island_api_client.should_agent_stop.called_once()