From 804bd4eadbfe63ca49d94cbab0eb4ca03fee2696 Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Wed, 7 Sep 2022 13:29:37 +0200 Subject: [PATCH] Agent: Modify find_server to accept list of servers --- monkey/infection_monkey/control.py | 93 +++++++++++++------ .../infection_monkey/test_control.py | 54 +++++++++++ 2 files changed, 117 insertions(+), 30 deletions(-) diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 8d1e48a22..5ed3f24a5 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -1,20 +1,24 @@ import json import logging import platform +import socket from socket import gethostname -from typing import Mapping, Optional +from typing import Mapping, Optional, Sequence import requests from requests.exceptions import ConnectionError import infection_monkey.tunnel as tunnel from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT +from common.network.network_utils import address_to_ip_port from infection_monkey.config import GUID from infection_monkey.network.info import get_host_subnets, local_ips +from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE from infection_monkey.transport.http import HTTPConnectProxy from infection_monkey.transport.tcp import TcpProxy from infection_monkey.utils import agent_process from infection_monkey.utils.environment import is_windows_os +from infection_monkey.utils.threading import create_daemon_thread requests.packages.urllib3.disable_warnings() @@ -63,37 +67,66 @@ class ControlClient: timeout=MEDIUM_REQUEST_TIMEOUT, ) - def find_server(self, default_tunnel=None): - logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}") - if default_tunnel: - logger.debug("default_tunnel: %s" % (default_tunnel,)) + def find_server(self, servers: Sequence[str]): + logger.debug(f"Trying to wake up with servers: {', '.join(servers)}") - try: - debug_message = "Trying to connect to server: %s" % self.server_address - if self.proxies: - debug_message += " through proxies: %s" % self.proxies - logger.debug(debug_message) - requests.get( # noqa: DUO123 - f"https://{self.server_address}/api?action=is-up", - verify=False, - proxies=self.proxies, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - return True - except ConnectionError as exc: - logger.warning("Error connecting to control server %s: %s", self.server_address, exc) + while servers: + server = servers[0] - if self.proxies: - return False - else: - logger.info("Starting tunnel lookup...") - proxy_find = tunnel.find_tunnel(default=default_tunnel) - if proxy_find: - self.set_proxies(proxy_find) - return self.find_server() - else: - logger.info("No tunnel found") - return False + try: + debug_message = f"Trying to connect to server: {server}" + logger.debug(debug_message) + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + # We remove the server that has been succesfull + servers.remove(server) + + # TODO: Check how we are going to set the server address that the ControlCLient + # is going to use + # self.server_address = server + + # If we have any other server we send them RELAY_CONTROL_MESSAGE + if servers: + for ss in servers: + t = create_daemon_thread( + target=ControlClient._send_relay_control_message, + name="SendControlRelayMessageThread", + args=(ss,), + ) + t.start() + t.join() + + return True + except ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + servers.remove(server) + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + servers.remove(server) + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + servers.remove(server) + + return False + + @staticmethod + def _send_relay_control_message(server: str): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: + d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT) + + try: + address, port = address_to_ip_port(server) + d_socket.connect((address, int(port))) + d_socket.send(RELAY_CONTROL_MESSAGE) + logger.info(f"Control message was sent to the server/relay {server}") + except OSError as err: + logger.error(f"Error connecting to socket {server}: {err}") def set_proxies(self, proxy_find): """ diff --git a/monkey/tests/unit_tests/infection_monkey/test_control.py b/monkey/tests/unit_tests/infection_monkey/test_control.py index b90087ebf..6cd3df844 100644 --- a/monkey/tests/unit_tests/infection_monkey/test_control.py +++ b/monkey/tests/unit_tests/infection_monkey/test_control.py @@ -1,7 +1,31 @@ +from unittest.mock import MagicMock + import pytest +import requests from monkey.infection_monkey.control import ControlClient +SERVER_1 = "1.1.1.1:12312" +SERVER_2 = "2.2.2.2:4321" +SERVER_3 = "3.3.3.3:3142" +SERVER_4 = "4.4.4.4:5000" + + +class MockConnectionError: + def __init__(self, *args, **kwargs): + raise requests.exceptions.ConnectionError + + +class RequestsGetArgument: + def __init__(self, *args, **kwargs): + if SERVER_1 in args[0]: + MockConnectionError() + + +@pytest.fixture +def servers(): + return [SERVER_1, SERVER_2, SERVER_3, SERVER_4] + @pytest.mark.parametrize( "is_windows_os,expected_proxy_string", @@ -14,3 +38,33 @@ def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string): control_client.set_proxies(("8.8.8.8", "45455")) assert control_client.proxies["https"] == expected_proxy_string + + +def test_control_find_server_any_exception(monkeypatch, servers): + monkeypatch.setattr("infection_monkey.control.requests.get", MockConnectionError) + + cc = ControlClient(servers) + + return_value = cc.find_server(servers) + + assert return_value is False + assert servers == [] + + +def test_control_find_server_socket(monkeypatch, servers): + mock_connect = MagicMock() + mock_send = MagicMock() + monkeypatch.setattr("infection_monkey.control.requests.get", RequestsGetArgument) + monkeypatch.setattr("infection_monkey.control.socket.socket.connect", mock_connect) + monkeypatch.setattr("infection_monkey.control.socket.socket.send", mock_send) + + cc = ControlClient(servers) + + return_value = cc.find_server(servers) + + assert len(servers) == 2 + assert return_value is True + assert mock_connect.call_count == 2 + assert mock_send.call_count == 2 + # TODO: be sure that connect is called with SERVER_3 and SERVER_4 + # assert mock_connect.call_args == SERVER_3