Agent: Fix server selection logic

This commit is contained in:
Mike Salvatore 2022-09-19 14:05:34 -04:00
parent 9ea291a7fa
commit 2ebb7621e3
2 changed files with 19 additions and 11 deletions

View File

@ -5,7 +5,7 @@ import subprocess
import sys
from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath
from typing import List, Tuple
from typing import List, Mapping, Optional, Tuple
from pubsub.core import Publisher
@ -137,9 +137,13 @@ class InfectionMonkey:
return opts
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
server, island_api_client = find_server(self._opts.servers)
server_clients = find_server(self._opts.servers)
server, island_api_client = self._select_server(server_clients)
if server:
logger.info(f"Successfully connected to the island via {server}")
else:
@ -150,11 +154,20 @@ class InfectionMonkey:
# NOTE: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces?
servers_to_close = (s for s in self._opts.servers if s != server)
servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server, island_api_client
def _select_server(
self, server_clients: Mapping[str, IIslandAPIClient]
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
for server in self._opts.servers:
if server_clients[server]:
return server, server_clients[server]
return None, None
@staticmethod
def _log_arguments(args):
arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])

View File

@ -2,7 +2,7 @@ import logging
import socket
from contextlib import suppress
from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
def find_server(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]:
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
@ -39,12 +39,7 @@ def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIsland
num_workers=NUM_FIND_SERVER_WORKERS,
)
for server in server_list:
if server_results[server]:
island_api_client = server_results[server]
return server, island_api_client
return (None, None)
return server_results
def _find_island_server(