From e539495545ff554c7787ec027c8a91b97f7b0f94 Mon Sep 17 00:00:00 2001
From: Ilija Lazoroski <ilija.la@live.com>
Date: Wed, 7 Sep 2022 16:47:47 +0200
Subject: [PATCH] Agent: Find server and send control relay message to all
 other servers

---
 monkey/infection_monkey/monkey.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py
index 596cd0d7f..539d5c61d 100644
--- a/monkey/infection_monkey/monkey.py
+++ b/monkey/infection_monkey/monkey.py
@@ -42,7 +42,7 @@ from infection_monkey.master.control_channel import ControlChannel
 from infection_monkey.model import VictimHostFactory
 from infection_monkey.network.firewall import app as firewall
 from infection_monkey.network.info import get_network_interfaces
-from infection_monkey.network.relay.utils import find_server
+from infection_monkey.network.relay.utils import find_server, send_relay_control_message
 from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
 from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
 from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
@@ -98,9 +98,9 @@ class InfectionMonkey:
         self._opts = self._get_arguments(args)
 
         # TODO: Revisit variable names
-        server = find_server(self._opts.servers)
-        self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
-        self._control_client = ControlClient(server_address=server)
+        self._get_server()
+        self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self.server)
+        self._control_client = ControlClient(server_address=self.server)
 
         # TODO Refactor the telemetry messengers to accept control client
         # and remove control_client_object
@@ -122,6 +122,11 @@ class InfectionMonkey:
 
         return opts
 
+    def _get_server(self):
+        servers_iterator = (s for s in self._opts.servers)
+        self.server = find_server(servers_iterator)
+        send_relay_control_message(servers_iterator)
+
     @staticmethod
     def _log_arguments(args):
         arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])
@@ -159,14 +164,13 @@ class InfectionMonkey:
             logger.debug(f"Default server set to: {self._control_client.server_address}")
         else:
             raise Exception(
-                f"Failed to connect to the island via "
-                f"any known server address: {self._opts.servers}"
+                f"Failed to connect to the island via " f"any known server address: {self.server}"
             )
 
         self._control_client.wakeup(parent=self._opts.parent)
 
     def _current_server_is_set(self) -> bool:
-        if find_server(servers=self._opts.servers):
+        if self.server:
             return True
 
         return False