From 67893b7825409783681e8dd2896bf849635ca537 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 12 Sep 2022 10:01:13 -0400 Subject: [PATCH] Agent: Find/check island servers concurrently --- .../infection_monkey/network/relay/utils.py | 69 ++++++++++++++----- .../network/relay/test_utils.py | 12 +++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 4a5532c30..a19f316db 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -1,42 +1,75 @@ import logging import socket +from contextlib import suppress from ipaddress import IPv4Address -from typing import Iterable, Optional +from typing import Dict, Iterable, Iterator, MutableMapping, Optional import requests from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST -from infection_monkey.utils.threading import create_daemon_thread +from infection_monkey.utils.threading import ( + ThreadSafeIterator, + create_daemon_thread, + run_worker_threads, +) logger = logging.getLogger(__name__) +# The number of Island servers to test simultaneously. 32 threads seems large enough for all +# practical purposes. Revisit this if it's not. +NUM_FIND_SERVER_WORKERS = 32 + def find_server(servers: Iterable[str]) -> Optional[str]: - for server in servers: - logger.debug(f"Trying to connect to server: {server}") + server_list = list(servers) + server_iterator = ThreadSafeIterator(server_list.__iter__()) + server_results: Dict[str, bool] = {} - try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + run_worker_threads( + _find_island_server, + "FindIslandServer", + args=(server_iterator, server_results), + num_workers=NUM_FIND_SERVER_WORKERS, + ) + for server in server_list: + if server_results[server]: return server - except requests.exceptions.ConnectionError as err: - logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: - logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: - logger.error( - f"Exception encountered when trying to connect to server/relay {server}: {err}" - ) return None +def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): + with suppress(StopIteration): + server = next(servers) + server_status[server] = _check_if_island_server(server) + + +def _check_if_island_server(server: str) -> bool: + logger.debug(f"Trying to connect to server: {server}") + + try: + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + return True + except requests.exceptions.ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + + return False + + def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): for server in servers: t = create_daemon_thread( diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index 1acd08012..ac7eb1b16 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ( SERVER_2, [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] - + [(server, {"text": ""}) for server in servers[1:]], + + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], ) @@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs): mock.get(f"https://{server}/api?action=is-up", **response) assert find_server(servers) is expected_server + + +def test_find_server__multiple_successes(): + with requests_mock.Mocker() as mock: + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) + mock.get(f"https://{SERVER_2}/api?action=is-up", text="") + mock.get(f"https://{SERVER_3}/api?action=is-up", text="") + mock.get(f"https://{SERVER_4}/api?action=is-up", text="") + + assert find_server(servers) == SERVER_2