diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 4a5532c30..a19f316db 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -1,42 +1,75 @@ import logging import socket +from contextlib import suppress from ipaddress import IPv4Address -from typing import Iterable, Optional +from typing import Dict, Iterable, Iterator, MutableMapping, Optional import requests from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST -from infection_monkey.utils.threading import create_daemon_thread +from infection_monkey.utils.threading import ( + ThreadSafeIterator, + create_daemon_thread, + run_worker_threads, +) logger = logging.getLogger(__name__) +# The number of Island servers to test simultaneously. 32 threads seems large enough for all +# practical purposes. Revisit this if it's not. +NUM_FIND_SERVER_WORKERS = 32 + def find_server(servers: Iterable[str]) -> Optional[str]: - for server in servers: - logger.debug(f"Trying to connect to server: {server}") + server_list = list(servers) + server_iterator = ThreadSafeIterator(server_list.__iter__()) + server_results: Dict[str, bool] = {} - try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + run_worker_threads( + _find_island_server, + "FindIslandServer", + args=(server_iterator, server_results), + num_workers=NUM_FIND_SERVER_WORKERS, + ) + for server in server_list: + if server_results[server]: return server - except requests.exceptions.ConnectionError as err: - logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: - logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: - logger.error( - f"Exception encountered when trying to connect to server/relay {server}: {err}" - ) return None +def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): + with suppress(StopIteration): + server = next(servers) + server_status[server] = _check_if_island_server(server) + + +def _check_if_island_server(server: str) -> bool: + logger.debug(f"Trying to connect to server: {server}") + + try: + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + return True + except requests.exceptions.ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + + return False + + def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): for server in servers: t = create_daemon_thread( diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index be28aa0b1..6d8b28253 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,8 +1,8 @@ import logging from functools import wraps from itertools import count -from threading import Event, Thread -from typing import Any, Callable, Iterable, Optional, Tuple +from threading import Event, Lock, Thread +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar logger = logging.getLogger(__name__) @@ -116,3 +116,19 @@ class InterruptableThreadMixin: def stop(self): """Stop a running thread.""" self._interrupted.set() + + +T = TypeVar("T") + + +class ThreadSafeIterator(Iterator[T]): + """Provides a thread-safe iterator that wraps another iterator""" + + def __init__(self, iterator: Iterator[T]): + self._lock = Lock() + self._iterator = iterator + + def __next__(self) -> T: + while True: + with self._lock: + return next(self._iterator) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index 1acd08012..ac7eb1b16 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ( SERVER_2, [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] - + [(server, {"text": ""}) for server in servers[1:]], + + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], ) @@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs): mock.get(f"https://{server}/api?action=is-up", **response) assert find_server(servers) is expected_server + + +def test_find_server__multiple_successes(): + with requests_mock.Mocker() as mock: + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) + mock.get(f"https://{SERVER_2}/api?action=is-up", text="") + mock.get(f"https://{SERVER_3}/api?action=is-up", text="") + mock.get(f"https://{SERVER_4}/api?action=is-up", text="") + + assert find_server(servers) == SERVER_2 diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 96a289096..05b813b66 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,8 +1,10 @@ import logging +from itertools import zip_longest from threading import Event, current_thread from typing import Any from infection_monkey.utils.threading import ( + ThreadSafeIterator, create_daemon_thread, interruptible_function, interruptible_iter, @@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt(): assert return_value == 777 assert fn.call_count == 0 + + +def test_thread_safe_iterator(): + test_list = [1, 2, 3, 4, 5] + tsi = ThreadSafeIterator(test_list.__iter__()) + + for actual, expected in zip_longest(tsi, test_list): + assert actual == expected