Agent: Find/check island servers concurrently

This commit is contained in:
Mike Salvatore 2022-09-12 10:01:13 -04:00
parent a9edbb2874
commit 67893b7825
2 changed files with 62 additions and 19 deletions

View File

@ -1,42 +1,75 @@
import logging import logging
import socket import socket
from contextlib import suppress
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Iterable, Optional from typing import Dict, Iterable, Iterator, MutableMapping, Optional
import requests import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
run_worker_threads,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]: def find_server(servers: Iterable[str]) -> Optional[str]:
for server in servers: server_list = list(servers)
logger.debug(f"Trying to connect to server: {server}") server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
try: run_worker_threads(
requests.get( # noqa: DUO123 _find_island_server,
f"https://{server}/api?action=is-up", "FindIslandServer",
verify=False, args=(server_iterator, server_results),
timeout=MEDIUM_REQUEST_TIMEOUT, num_workers=NUM_FIND_SERVER_WORKERS,
) )
for server in server_list:
if server_results[server]:
return server return server
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return None return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
logger.debug(f"Trying to connect to server: {server}")
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return False
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers: for server in servers:
t = create_daemon_thread( t = create_daemon_thread(

View File

@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
( (
SERVER_2, SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})] [(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]], + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
), ),
], ],
) )
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
mock.get(f"https://{server}/api?action=is-up", **response) mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server assert find_server(servers) is expected_server
def test_find_server__multiple_successes():
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2