diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index eaec7e6bf..680db9f6f 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -5,7 +5,7 @@ import subprocess import sys from ipaddress import IPv4Address, IPv4Interface from pathlib import Path, WindowsPath -from typing import List, Optional +from typing import List, Mapping, Optional, Tuple from pubsub.core import Publisher @@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.i_puppet import IPuppet, PluginType +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.master import AutomatedMaster from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory @@ -52,7 +53,7 @@ from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port from infection_monkey.network.relay import TCPRelay from infection_monkey.network.relay.utils import ( - find_server, + find_available_island_apis, notify_disconnect, send_remove_from_waitlist_control_message_to_relays, ) @@ -110,7 +111,7 @@ class InfectionMonkey: self._opts = self._get_arguments(args) # TODO: Revisit variable names - server = self._get_server() + server, island_api_client = self._connect_to_island_api() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) @@ -123,7 +124,7 @@ class InfectionMonkey: self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._current_depth = self._opts.depth self._master = None - self._relay: Optional[TCPRelay] = None + self._relay: TCPRelay @staticmethod def _get_arguments(args): @@ -136,10 +137,13 @@ class InfectionMonkey: return opts - def _get_server(self): + # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server` + def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - servers_iterator = (s for s in self._opts.servers) - server = find_server(servers_iterator) + server_clients = find_available_island_apis(self._opts.servers) + + server, island_api_client = self._select_server(server_clients) + if server: logger.info(f"Successfully connected to the island via {server}") else: @@ -147,12 +151,22 @@ class InfectionMonkey: f"Failed to connect to the island via any known servers: {self._opts.servers}" ) - # Note: Since we pass the address for each of our interfaces to the exploited + # NOTE: Since we pass the address for each of our interfaces to the exploited # machines, is it possible for a machine to unintentionally unregister itself from the # relay if it is able to connect to the relay over multiple interfaces? - send_remove_from_waitlist_control_message_to_relays(servers_iterator) + servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s]) + send_remove_from_waitlist_control_message_to_relays(servers_to_close) - return server + return server, island_api_client + + def _select_server( + self, server_clients: Mapping[str, IIslandAPIClient] + ) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: + for server in self._opts.servers: + if server_clients[server]: + return server, server_clients[server] + + return None, None @staticmethod def _log_arguments(args): diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index a19f316db..451fd65e7 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -2,12 +2,17 @@ import logging import socket from contextlib import suppress from ipaddress import IPv4Address -from typing import Dict, Iterable, Iterator, MutableMapping, Optional +from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple -import requests - -from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port +from infection_monkey.island_api_client import ( + HTTPIslandAPIClient, + IIslandAPIClient, + IslandAPIConnectionError, + IslandAPIError, + IslandAPITimeoutError, +) from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.utils.threading import ( ThreadSafeIterator, @@ -22,10 +27,10 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_server(servers: Iterable[str]) -> Optional[str]: +def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) - server_results: Dict[str, bool] = {} + server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} run_worker_threads( _find_island_server, @@ -34,47 +39,39 @@ def find_server(servers: Iterable[str]) -> Optional[str]: num_workers=NUM_FIND_SERVER_WORKERS, ) - for server in server_list: - if server_results[server]: - return server - - return None + return server_results -def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): +def _find_island_server( + servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]] +): with suppress(StopIteration): server = next(servers) server_status[server] = _check_if_island_server(server) -def _check_if_island_server(server: str) -> bool: +def _check_if_island_server(server: str) -> IIslandAPIClient: logger.debug(f"Trying to connect to server: {server}") try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - - return True - except requests.exceptions.ConnectionError as err: + return HTTPIslandAPIClient(server) + except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: + except IslandAPITimeoutError as err: logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: + except IslandAPIError as err: logger.error( f"Exception encountered when trying to connect to server/relay {server}: {err}" ) - return False + return None def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): - for server in servers: + for i, server in enumerate(servers, start=1): t = create_daemon_thread( target=_send_remove_from_waitlist_control_message_to_relay, - name="SendRemoveFromWaitlistControlMessageToRelaysThread", + name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}", args=(server,), ) t.start() @@ -93,8 +90,10 @@ def notify_disconnect(server_ip: IPv4Address, server_port: int): :param server_port: The port of the server to notify. """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: + d_socket.settimeout(LONG_REQUEST_TIMEOUT) + try: - d_socket.connect((server_ip, server_port)) + d_socket.connect((str(server_ip), server_port)) d_socket.sendall(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST) logger.info(f"Control message was sent to the server/relay {server_ip}:{server_port}") except OSError as err: diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index ac7eb1b16..dd00ebc94 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,8 +1,8 @@ import pytest -import requests import requests_mock -from infection_monkey.network.relay.utils import find_server +from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError +from infection_monkey.network.relay.utils import find_available_island_apis SERVER_1 = "1.1.1.1:12312" SERVER_2 = "2.2.2.2:4321" @@ -14,29 +14,42 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] @pytest.mark.parametrize( - "expected_server,server_response_pairs", + "expected_available_servers, server_response_pairs", [ - (None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]), + ([], [(server, {"exc": IslandAPIConnectionError}) for server in servers]), ( - SERVER_2, - [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] + servers[1:], + [(SERVER_1, {"exc": IslandAPIConnectionError})] + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], ) -def test_find_server(expected_server, server_response_pairs): +def test_find_available_island_apis(expected_available_servers, server_response_pairs): with requests_mock.Mocker() as mock: for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) - assert find_server(servers) is expected_server + available_apis = find_available_island_apis(servers) + + assert len(available_apis) == len(server_response_pairs) + + for server, island_api_client in available_apis.items(): + if server in expected_available_servers: + assert island_api_client is not None + else: + assert island_api_client is None -def test_find_server__multiple_successes(): +def test_find_available_island_apis__multiple_successes(): + available_servers = [SERVER_2, SERVER_3] with requests_mock.Mocker() as mock: - mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) - mock.get(f"https://{SERVER_2}/api?action=is-up", text="") - mock.get(f"https://{SERVER_3}/api?action=is-up", text="") - mock.get(f"https://{SERVER_4}/api?action=is-up", text="") + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) + for server in available_servers: + mock.get(f"https://{server}/api?action=is-up", text="") - assert find_server(servers) == SERVER_2 + available_apis = find_available_island_apis(servers) + + assert available_apis[SERVER_1] is None + assert available_apis[SERVER_4] is None + for server in available_servers: + assert isinstance(available_apis[server], IIslandAPIClient)