Agent: Find/check island servers concurrently

This commit is contained in:
Mike Salvatore 2022-09-12 10:01:13 -04:00
parent a9edbb2874
commit 67893b7825
2 changed files with 62 additions and 19 deletions

View File

@ -1,20 +1,53 @@
import logging import logging
import socket import socket
from contextlib import suppress
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Iterable, Optional from typing import Dict, Iterable, Iterator, MutableMapping, Optional
import requests import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
run_worker_threads,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]: def find_server(servers: Iterable[str]) -> Optional[str]:
for server in servers: server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
run_worker_threads(
_find_island_server,
"FindIslandServer",
args=(server_iterator, server_results),
num_workers=NUM_FIND_SERVER_WORKERS,
)
for server in server_list:
if server_results[server]:
return server
return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
logger.debug(f"Trying to connect to server: {server}") logger.debug(f"Trying to connect to server: {server}")
try: try:
@ -24,7 +57,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
return server return True
except requests.exceptions.ConnectionError as err: except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}") logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err: except TimeoutError as err:
@ -34,7 +67,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
f"Exception encountered when trying to connect to server/relay {server}: {err}" f"Exception encountered when trying to connect to server/relay {server}: {err}"
) )
return None return False
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):

View File

@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
( (
SERVER_2, SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})] [(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]], + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
), ),
], ],
) )
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
mock.get(f"https://{server}/api?action=is-up", **response) mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server assert find_server(servers) is expected_server
def test_find_server__multiple_successes():
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2