diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index a19f316db..2d2bcc3d7 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -4,11 +4,14 @@ from contextlib import suppress from ipaddress import IPv4Address from typing import Dict, Iterable, Iterator, MutableMapping, Optional -import requests - -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST +from infection_monkey.transport import IslandApiClient +from infection_monkey.transport.island_api_client_errors import ( + IslandAPIConnectionError, + IslandAPIError, + IslandAPITimeoutError, +) from infection_monkey.utils.threading import ( ThreadSafeIterator, create_daemon_thread, @@ -51,18 +54,14 @@ def _check_if_island_server(server: str) -> bool: logger.debug(f"Trying to connect to server: {server}") try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + _ = IslandApiClient(server) return True - except requests.exceptions.ConnectionError as err: + except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: + except IslandAPITimeoutError as err: logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: + except IslandAPIError as err: logger.error( f"Exception encountered when trying to connect to server/relay {server}: {err}" ) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index ac7eb1b16..58f2f1be7 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,8 +1,8 @@ import pytest -import requests import requests_mock from infection_monkey.network.relay.utils import find_server +from infection_monkey.transport.island_api_client_errors import IslandAPIConnectionError SERVER_1 = "1.1.1.1:12312" SERVER_2 = "2.2.2.2:4321" @@ -16,10 +16,10 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] @pytest.mark.parametrize( "expected_server,server_response_pairs", [ - (None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]), + (None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]), ( SERVER_2, - [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] + [(SERVER_1, {"exc": IslandAPIConnectionError})] + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], @@ -34,7 +34,7 @@ def test_find_server(expected_server, server_response_pairs): def test_find_server__multiple_successes(): with requests_mock.Mocker() as mock: - mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) mock.get(f"https://{SERVER_2}/api?action=is-up", text="") mock.get(f"https://{SERVER_3}/api?action=is-up", text="") mock.get(f"https://{SERVER_4}/api?action=is-up", text="")