diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 8d1e48a22..3aedeb272 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -5,7 +5,6 @@ from socket import gethostname from typing import Mapping, Optional import requests -from requests.exceptions import ConnectionError import infection_monkey.tunnel as tunnel from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT @@ -63,38 +62,6 @@ class ControlClient: timeout=MEDIUM_REQUEST_TIMEOUT, ) - def find_server(self, default_tunnel=None): - logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}") - if default_tunnel: - logger.debug("default_tunnel: %s" % (default_tunnel,)) - - try: - debug_message = "Trying to connect to server: %s" % self.server_address - if self.proxies: - debug_message += " through proxies: %s" % self.proxies - logger.debug(debug_message) - requests.get( # noqa: DUO123 - f"https://{self.server_address}/api?action=is-up", - verify=False, - proxies=self.proxies, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - return True - except ConnectionError as exc: - logger.warning("Error connecting to control server %s: %s", self.server_address, exc) - - if self.proxies: - return False - else: - logger.info("Starting tunnel lookup...") - proxy_find = tunnel.find_tunnel(default=default_tunnel) - if proxy_find: - self.set_proxies(proxy_find) - return self.find_server() - else: - logger.info("No tunnel found") - return False - def set_proxies(self, proxy_find): """ Note: The proxy schema changes between different versions of requests and urllib3, diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 63e0b71d1..918f5240c 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -43,6 +43,10 @@ from infection_monkey.model import VictimHostFactory from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces from infection_monkey.network.relay import TCPRelay +from infection_monkey.network.relay.utils import ( + find_server, + send_remove_from_waitlist_control_message_to_relays, +) from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter @@ -96,8 +100,12 @@ class InfectionMonkey: logger.info("Monkey is initializing...") self._singleton = SystemSingleton() self._opts = self._get_arguments(args) - self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.servers) - self._control_client = ControlClient(self._opts.servers) + + # TODO: Revisit variable names + server = self._get_server() + self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) + self._control_client = ControlClient(server_address=server) + # TODO Refactor the telemetry messengers to accept control client # and remove control_client_object ControlClient.control_client_object = self._control_client @@ -117,6 +125,19 @@ class InfectionMonkey: return opts + def _get_server(self): + servers_iterator = (s for s in self._opts.servers) + server = find_server(servers_iterator) + if server: + logger.info(f"Successfully connected to the island via {server}") + else: + raise Exception( + f"Failed to connect to the island via any known servers: {self._opts.servers}" + ) + send_remove_from_waitlist_control_message_to_relays(servers_iterator) + + return server + @staticmethod def _log_arguments(args): arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()]) @@ -130,7 +151,7 @@ class InfectionMonkey: logger.info("Agent is starting...") logger.info(f"Agent GUID: {GUID}") - self._connect_to_island() + self._control_client.wakeup(parent=self._opts.parent) # TODO: Reevaluate who is responsible to send this information if is_windows_os(): @@ -148,24 +169,6 @@ class InfectionMonkey: self._setup() self._master.start() - def _connect_to_island(self): - # Sets island's IP and port for monkey to communicate to - if self._current_server_is_set(): - logger.debug(f"Default server set to: {self._control_client.server_address}") - else: - raise Exception( - f"Failed to connect to the island via " - f"any known server address: {self._opts.servers}" - ) - - self._control_client.wakeup(parent=self._opts.parent) - - def _current_server_is_set(self) -> bool: - if self._control_client.find_server(default_tunnel=self._opts.servers): - return True - - return False - def _setup(self): logger.debug("Starting the setup phase.") diff --git a/monkey/infection_monkey/network/relay/__init__.py b/monkey/infection_monkey/network/relay/__init__.py index 563b972c4..b9eb8a009 100644 --- a/monkey/infection_monkey/network/relay/__init__.py +++ b/monkey/infection_monkey/network/relay/__init__.py @@ -1,4 +1,7 @@ -from .relay_connection_handler import RelayConnectionHandler, RELAY_CONTROL_MESSAGE +from .relay_connection_handler import ( + RelayConnectionHandler, + RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST, +) from .relay_user_handler import RelayUser, RelayUserHandler from .sockets_pipe import SocketsPipe from .tcp_connection_handler import TCPConnectionHandler diff --git a/monkey/infection_monkey/network/relay/relay_connection_handler.py b/monkey/infection_monkey/network/relay/relay_connection_handler.py index 3d7755fb9..4b4475e52 100644 --- a/monkey/infection_monkey/network/relay/relay_connection_handler.py +++ b/monkey/infection_monkey/network/relay/relay_connection_handler.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from .relay_user_handler import RelayUserHandler from .tcp_pipe_spawner import TCPPipeSpawner -RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" +RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -" class RelayConnectionHandler: @@ -25,7 +25,7 @@ class RelayConnectionHandler: control_message = sock.recv(socket.MSG_PEEK) - if control_message.startswith(RELAY_CONTROL_MESSAGE): + if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST): self._relay_user_handler.disconnect_user(addr) else: self._relay_user_handler.add_relay_user(addr) diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py new file mode 100644 index 000000000..096500c5b --- /dev/null +++ b/monkey/infection_monkey/network/relay/utils.py @@ -0,0 +1,62 @@ +import logging +import socket +from typing import Iterable, Optional + +import requests + +from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT +from common.network.network_utils import address_to_ip_port +from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST +from infection_monkey.utils.threading import create_daemon_thread + +logger = logging.getLogger(__name__) + + +def find_server(servers: Iterable[str]) -> Optional[str]: + logger.debug(f"Trying to wake up with servers: {', '.join(servers)}") + + for server in servers: + logger.debug(f"Trying to connect to server: {server}") + + try: + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + return server + except requests.exceptions.ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + + return None + + +def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): + for server in servers: + t = create_daemon_thread( + target=_send_remove_from_waitlist_control_message_to_relay, + name="SendRemoveFromWaitlistControlMessageToRelaysThread", + args=(server,), + ) + t.start() + + +def _send_remove_from_waitlist_control_message_to_relay(server: str): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: + d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT) + + ip, port = address_to_ip_port(server) + logger.info(f"Control message was sent to the server/relay {server}") + + try: + d_socket.connect((ip, int(port))) + d_socket.send(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST) + except OSError as err: + logger.error(f"Error connecting to socket {server}: {err}") diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py index 92aff10fc..48f8b6bf5 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest from monkey.infection_monkey.network.relay import ( - RELAY_CONTROL_MESSAGE, + RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST, RelayConnectionHandler, RelayUserHandler, TCPPipeSpawner, @@ -27,7 +27,7 @@ def relay_user_handler(): @pytest.fixture def close_socket(): sock = MagicMock(spec=socket.socket) - sock.recv.return_value = RELAY_CONTROL_MESSAGE + sock.recv.return_value = RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST sock.getpeername.return_value = (USER_ADDRESS, 12345) return sock diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py new file mode 100644 index 000000000..1acd08012 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -0,0 +1,32 @@ +import pytest +import requests +import requests_mock + +from infection_monkey.network.relay.utils import find_server + +SERVER_1 = "1.1.1.1:12312" +SERVER_2 = "2.2.2.2:4321" +SERVER_3 = "3.3.3.3:3142" +SERVER_4 = "4.4.4.4:5000" + + +servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] + + +@pytest.mark.parametrize( + "expected_server,server_response_pairs", + [ + (None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]), + ( + SERVER_2, + [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] + + [(server, {"text": ""}) for server in servers[1:]], + ), + ], +) +def test_find_server(expected_server, server_response_pairs): + with requests_mock.Mocker() as mock: + for server, response in server_response_pairs: + mock.get(f"https://{server}/api?action=is-up", **response) + + assert find_server(servers) is expected_server