Agent: Modify find_server to accept list of servers

This commit is contained in:
Ilija Lazoroski 2022-09-07 13:29:37 +02:00 committed by Mike Salvatore
parent ac058c7788
commit 804bd4eadb
2 changed files with 117 additions and 30 deletions

View File

@ -1,20 +1,24 @@
import json
import logging
import platform
import socket
from socket import gethostname
from typing import Mapping, Optional
from typing import Mapping, Optional, Sequence
import requests
from requests.exceptions import ConnectionError
import infection_monkey.tunnel as tunnel
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.config import GUID
from infection_monkey.network.info import get_host_subnets, local_ips
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE
from infection_monkey.transport.http import HTTPConnectProxy
from infection_monkey.transport.tcp import TcpProxy
from infection_monkey.utils import agent_process
from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.threading import create_daemon_thread
requests.packages.urllib3.disable_warnings()
@ -63,37 +67,66 @@ class ControlClient:
timeout=MEDIUM_REQUEST_TIMEOUT,
)
def find_server(self, default_tunnel=None):
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
if default_tunnel:
logger.debug("default_tunnel: %s" % (default_tunnel,))
def find_server(self, servers: Sequence[str]):
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
try:
debug_message = "Trying to connect to server: %s" % self.server_address
if self.proxies:
debug_message += " through proxies: %s" % self.proxies
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{self.server_address}/api?action=is-up",
verify=False,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except ConnectionError as exc:
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
while servers:
server = servers[0]
if self.proxies:
return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
self.set_proxies(proxy_find)
return self.find_server()
else:
logger.info("No tunnel found")
return False
try:
debug_message = f"Trying to connect to server: {server}"
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
# We remove the server that has been succesfull
servers.remove(server)
# TODO: Check how we are going to set the server address that the ControlCLient
# is going to use
# self.server_address = server
# If we have any other server we send them RELAY_CONTROL_MESSAGE
if servers:
for ss in servers:
t = create_daemon_thread(
target=ControlClient._send_relay_control_message,
name="SendControlRelayMessageThread",
args=(ss,),
)
t.start()
t.join()
return True
except ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
servers.remove(server)
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
servers.remove(server)
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
servers.remove(server)
return False
@staticmethod
def _send_relay_control_message(server: str):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
try:
address, port = address_to_ip_port(server)
d_socket.connect((address, int(port)))
d_socket.send(RELAY_CONTROL_MESSAGE)
logger.info(f"Control message was sent to the server/relay {server}")
except OSError as err:
logger.error(f"Error connecting to socket {server}: {err}")
def set_proxies(self, proxy_find):
"""

View File

@ -1,7 +1,31 @@
from unittest.mock import MagicMock
import pytest
import requests
from monkey.infection_monkey.control import ControlClient
SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321"
SERVER_3 = "3.3.3.3:3142"
SERVER_4 = "4.4.4.4:5000"
class MockConnectionError:
def __init__(self, *args, **kwargs):
raise requests.exceptions.ConnectionError
class RequestsGetArgument:
def __init__(self, *args, **kwargs):
if SERVER_1 in args[0]:
MockConnectionError()
@pytest.fixture
def servers():
return [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize(
"is_windows_os,expected_proxy_string",
@ -14,3 +38,33 @@ def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
control_client.set_proxies(("8.8.8.8", "45455"))
assert control_client.proxies["https"] == expected_proxy_string
def test_control_find_server_any_exception(monkeypatch, servers):
monkeypatch.setattr("infection_monkey.control.requests.get", MockConnectionError)
cc = ControlClient(servers)
return_value = cc.find_server(servers)
assert return_value is False
assert servers == []
def test_control_find_server_socket(monkeypatch, servers):
mock_connect = MagicMock()
mock_send = MagicMock()
monkeypatch.setattr("infection_monkey.control.requests.get", RequestsGetArgument)
monkeypatch.setattr("infection_monkey.control.socket.socket.connect", mock_connect)
monkeypatch.setattr("infection_monkey.control.socket.socket.send", mock_send)
cc = ControlClient(servers)
return_value = cc.find_server(servers)
assert len(servers) == 2
assert return_value is True
assert mock_connect.call_count == 2
assert mock_send.call_count == 2
# TODO: be sure that connect is called with SERVER_3 and SERVER_4
# assert mock_connect.call_args == SERVER_3