Merge branch '2216-find-servers-concurrently' into 2216-fix-connection-issues

This commit is contained in:
Mike Salvatore 2022-09-12 16:49:07 -04:00
commit 70978f9b30
4 changed files with 90 additions and 21 deletions

View File

@ -1,20 +1,53 @@
import logging
import socket
from contextlib import suppress
from ipaddress import IPv4Address
from typing import Iterable, Optional
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread
from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
run_worker_threads,
)
logger = logging.getLogger(__name__)
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]:
for server in servers:
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
run_worker_threads(
_find_island_server,
"FindIslandServer",
args=(server_iterator, server_results),
num_workers=NUM_FIND_SERVER_WORKERS,
)
for server in server_list:
if server_results[server]:
return server
return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
logger.debug(f"Trying to connect to server: {server}")
try:
@ -24,7 +57,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return server
return True
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
@ -34,7 +67,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return None
return False
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):

View File

@ -1,8 +1,8 @@
import logging
from functools import wraps
from itertools import count
from threading import Event, Thread
from typing import Any, Callable, Iterable, Optional, Tuple
from threading import Event, Lock, Thread
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
logger = logging.getLogger(__name__)
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
def stop(self):
"""Stop a running thread."""
self._interrupted.set()
T = TypeVar("T")
class ThreadSafeIterator(Iterator[T]):
"""Provides a thread-safe iterator that wraps another iterator"""
def __init__(self, iterator: Iterator[T]):
self._lock = Lock()
self._iterator = iterator
def __next__(self) -> T:
while True:
with self._lock:
return next(self._iterator)

View File

@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
(
SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]],
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
),
],
)
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server
def test_find_server__multiple_successes():
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2

View File

@ -1,8 +1,10 @@
import logging
from itertools import zip_longest
from threading import Event, current_thread
from typing import Any
from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
interruptible_function,
interruptible_iter,
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
assert return_value == 777
assert fn.call_count == 0
def test_thread_safe_iterator():
test_list = [1, 2, 3, 4, 5]
tsi = ThreadSafeIterator(test_list.__iter__())
for actual, expected in zip_longest(tsi, test_list):
assert actual == expected