diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 034f47295..33042d2cf 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -5,7 +5,7 @@ import subprocess import sys from ipaddress import IPv4Address, IPv4Interface from pathlib import Path, WindowsPath -from typing import List, Tuple +from typing import List, Mapping, Optional, Tuple from pubsub.core import Publisher @@ -137,9 +137,13 @@ class InfectionMonkey: return opts + # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - server, island_api_client = find_server(self._opts.servers) + server_clients = find_server(self._opts.servers) + + server, island_api_client = self._select_server(server_clients) + if server: logger.info(f"Successfully connected to the island via {server}") else: @@ -150,11 +154,20 @@ class InfectionMonkey: # NOTE: Since we pass the address for each of our interfaces to the exploited # machines, is it possible for a machine to unintentionally unregister itself from the # relay if it is able to connect to the relay over multiple interfaces? - servers_to_close = (s for s in self._opts.servers if s != server) + servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s]) send_remove_from_waitlist_control_message_to_relays(servers_to_close) return server, island_api_client + def _select_server( + self, server_clients: Mapping[str, IIslandAPIClient] + ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: + for server in self._opts.servers: + if server_clients[server]: + return server, server_clients[server] + + return None, None + @staticmethod def _log_arguments(args): arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()]) diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 4a704bf08..e31c9ac97 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -2,7 +2,7 @@ import logging import socket from contextlib import suppress from ipaddress import IPv4Address -from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple +from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: +def find_server(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} @@ -39,12 +39,7 @@ def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIsland num_workers=NUM_FIND_SERVER_WORKERS, ) - for server in server_list: - if server_results[server]: - island_api_client = server_results[server] - return server, island_api_client - - return (None, None) + return server_results def _find_island_server(