diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index b7eed8b0a..b1a222af0 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -5,7 +5,7 @@ import subprocess import sys from ipaddress import IPv4Address, IPv4Interface from pathlib import Path, WindowsPath -from typing import List +from typing import List, Tuple from pubsub.core import Publisher @@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.i_puppet import IPuppet, PluginType +from infection_monkey.island_api_client import IIslandAPIClient from infection_monkey.master import AutomatedMaster from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory @@ -110,7 +111,7 @@ class InfectionMonkey: self._opts = self._get_arguments(args) # TODO: Revisit variable names - server = self._get_server() + server, island_api_client = self._get_server() # TODO: `address_to_port()` should return the port as an integer. self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_port = int(self._cmd_island_port) @@ -136,9 +137,9 @@ class InfectionMonkey: return opts - def _get_server(self): + def _get_server(self) -> Tuple[str, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") - server = find_server(self._opts.servers) + server, island_api_client = find_server(self._opts.servers) if server: logger.info(f"Successfully connected to the island via {server}") else: @@ -152,7 +153,7 @@ class InfectionMonkey: servers_to_close = (s for s in self._opts.servers if s != server) send_remove_from_waitlist_control_message_to_relays(servers_to_close) - return server + return server, island_api_client @staticmethod def _log_arguments(args): diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index c2ffc64ba..acbd9cc46 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -2,11 +2,12 @@ import logging import socket from contextlib import suppress from ipaddress import IPv4Address -from typing import Dict, Iterable, Iterator, MutableMapping, Optional +from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple from common.network.network_utils import address_to_ip_port from infection_monkey.island_api_client import ( HTTPIslandAPIClient, + IIslandAPIClient, IslandAPIConnectionError, IslandAPIError, IslandAPITimeoutError, @@ -25,10 +26,10 @@ logger = logging.getLogger(__name__) NUM_FIND_SERVER_WORKERS = 32 -def find_server(servers: Iterable[str]) -> Optional[str]: +def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]: server_list = list(servers) server_iterator = ThreadSafeIterator(server_list.__iter__()) - server_results: Dict[str, bool] = {} + server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {} run_worker_threads( _find_island_server, @@ -39,24 +40,27 @@ def find_server(servers: Iterable[str]) -> Optional[str]: for server in server_list: if server_results[server]: - return server + island_api_client = server_results[server] + return server, island_api_client - return None + return (None, None) -def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): +def _find_island_server( + servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]] +): with suppress(StopIteration): server = next(servers) server_status[server] = _check_if_island_server(server) -def _check_if_island_server(server: str) -> bool: +def _check_if_island_server(server: str) -> IIslandAPIClient: logger.debug(f"Trying to connect to server: {server}") try: - HTTPIslandAPIClient(server) + island_api_client = HTTPIslandAPIClient(server) - return True + return island_api_client except IslandAPIConnectionError as err: logger.error(f"Unable to connect to server/relay {server}: {err}") except IslandAPITimeoutError as err: @@ -66,7 +70,7 @@ def _check_if_island_server(server: str) -> bool: f"Exception encountered when trying to connect to server/relay {server}: {err}" ) - return False + return None def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index c4e69bd1a..4979cce43 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,7 +1,7 @@ import pytest import requests_mock -from infection_monkey.island_api_client import IslandAPIConnectionError +from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError from infection_monkey.network.relay.utils import find_server SERVER_1 = "1.1.1.1:12312" @@ -14,7 +14,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] @pytest.mark.parametrize( - "expected_server,server_response_pairs", + "expected_server, server_response_pairs", [ (None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]), ( @@ -29,7 +29,9 @@ def test_find_server(expected_server, server_response_pairs): for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) - assert find_server(servers) is expected_server + actual_server, _ = find_server(servers) + + assert actual_server is expected_server def test_find_server__multiple_successes(): @@ -39,4 +41,7 @@ def test_find_server__multiple_successes(): mock.get(f"https://{SERVER_3}/api?action=is-up", text="") mock.get(f"https://{SERVER_4}/api?action=is-up", text="") - assert find_server(servers) == SERVER_2 + actual_server, actual_island_api_client = find_server(servers) + + assert actual_server == SERVER_2 + assert isinstance(actual_island_api_client, IIslandAPIClient)