Merge branch '2216-find-servers-concurrently' into 2216-fix-connection-issues

This commit is contained in:
Mike Salvatore 2022-09-12 16:49:07 -04:00
commit 70978f9b30
4 changed files with 90 additions and 21 deletions

View File

@ -1,42 +1,75 @@
import logging import logging
import socket import socket
from contextlib import suppress
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Iterable, Optional from typing import Dict, Iterable, Iterator, MutableMapping, Optional
import requests import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
run_worker_threads,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]: def find_server(servers: Iterable[str]) -> Optional[str]:
for server in servers: server_list = list(servers)
logger.debug(f"Trying to connect to server: {server}") server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
try: run_worker_threads(
requests.get( # noqa: DUO123 _find_island_server,
f"https://{server}/api?action=is-up", "FindIslandServer",
verify=False, args=(server_iterator, server_results),
timeout=MEDIUM_REQUEST_TIMEOUT, num_workers=NUM_FIND_SERVER_WORKERS,
) )
for server in server_list:
if server_results[server]:
return server return server
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return None return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
logger.debug(f"Trying to connect to server: {server}")
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return False
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers: for server in servers:
t = create_daemon_thread( t = create_daemon_thread(

View File

@ -1,8 +1,8 @@
import logging import logging
from functools import wraps from functools import wraps
from itertools import count from itertools import count
from threading import Event, Thread from threading import Event, Lock, Thread
from typing import Any, Callable, Iterable, Optional, Tuple from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
def stop(self): def stop(self):
"""Stop a running thread.""" """Stop a running thread."""
self._interrupted.set() self._interrupted.set()
T = TypeVar("T")
class ThreadSafeIterator(Iterator[T]):
"""Provides a thread-safe iterator that wraps another iterator"""
def __init__(self, iterator: Iterator[T]):
self._lock = Lock()
self._iterator = iterator
def __next__(self) -> T:
while True:
with self._lock:
return next(self._iterator)

View File

@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
( (
SERVER_2, SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})] [(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]], + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
), ),
], ],
) )
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
mock.get(f"https://{server}/api?action=is-up", **response) mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server assert find_server(servers) is expected_server
def test_find_server__multiple_successes():
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2

View File

@ -1,8 +1,10 @@
import logging import logging
from itertools import zip_longest
from threading import Event, current_thread from threading import Event, current_thread
from typing import Any from typing import Any
from infection_monkey.utils.threading import ( from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread, create_daemon_thread,
interruptible_function, interruptible_function,
interruptible_iter, interruptible_iter,
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
assert return_value == 777 assert return_value == 777
assert fn.call_count == 0 assert fn.call_count == 0
def test_thread_safe_iterator():
test_list = [1, 2, 3, 4, 5]
tsi = ThreadSafeIterator(test_list.__iter__())
for actual, expected in zip_longest(tsi, test_list):
assert actual == expected