Merge branch '2216-modify-controlclient-find-server' into 2216-tcp-relay

PR #2250
This commit is contained in:
Mike Salvatore 2022-09-08 07:52:52 -04:00
commit 5366bba389
7 changed files with 126 additions and 59 deletions

View File

@ -5,7 +5,6 @@ from socket import gethostname
from typing import Mapping, Optional
import requests
from requests.exceptions import ConnectionError
import infection_monkey.tunnel as tunnel
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
@ -63,38 +62,6 @@ class ControlClient:
timeout=MEDIUM_REQUEST_TIMEOUT,
)
def find_server(self, default_tunnel=None):
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
if default_tunnel:
logger.debug("default_tunnel: %s" % (default_tunnel,))
try:
debug_message = "Trying to connect to server: %s" % self.server_address
if self.proxies:
debug_message += " through proxies: %s" % self.proxies
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{self.server_address}/api?action=is-up",
verify=False,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except ConnectionError as exc:
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
if self.proxies:
return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
self.set_proxies(proxy_find)
return self.find_server()
else:
logger.info("No tunnel found")
return False
def set_proxies(self, proxy_find):
"""
Note: The proxy schema changes between different versions of requests and urllib3,

View File

@ -43,6 +43,10 @@ from infection_monkey.model import VictimHostFactory
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import (
find_server,
send_remove_from_waitlist_control_message_to_relays,
)
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
@ -96,8 +100,12 @@ class InfectionMonkey:
logger.info("Monkey is initializing...")
self._singleton = SystemSingleton()
self._opts = self._get_arguments(args)
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.servers)
self._control_client = ControlClient(self._opts.servers)
# TODO: Revisit variable names
server = self._get_server()
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._control_client = ControlClient(server_address=server)
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
@ -117,6 +125,19 @@ class InfectionMonkey:
return opts
def _get_server(self):
servers_iterator = (s for s in self._opts.servers)
server = find_server(servers_iterator)
if server:
logger.info(f"Successfully connected to the island via {server}")
else:
raise Exception(
f"Failed to connect to the island via any known servers: {self._opts.servers}"
)
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
return server
@staticmethod
def _log_arguments(args):
arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])
@ -130,7 +151,7 @@ class InfectionMonkey:
logger.info("Agent is starting...")
logger.info(f"Agent GUID: {GUID}")
self._connect_to_island()
self._control_client.wakeup(parent=self._opts.parent)
# TODO: Reevaluate who is responsible to send this information
if is_windows_os():
@ -148,24 +169,6 @@ class InfectionMonkey:
self._setup()
self._master.start()
def _connect_to_island(self):
# Sets island's IP and port for monkey to communicate to
if self._current_server_is_set():
logger.debug(f"Default server set to: {self._control_client.server_address}")
else:
raise Exception(
f"Failed to connect to the island via "
f"any known server address: {self._opts.servers}"
)
self._control_client.wakeup(parent=self._opts.parent)
def _current_server_is_set(self) -> bool:
if self._control_client.find_server(default_tunnel=self._opts.servers):
return True
return False
def _setup(self):
logger.debug("Starting the setup phase.")

View File

@ -1,4 +1,7 @@
from .relay_connection_handler import RelayConnectionHandler, RELAY_CONTROL_MESSAGE
from .relay_connection_handler import (
RelayConnectionHandler,
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
)
from .relay_user_handler import RelayUser, RelayUserHandler
from .sockets_pipe import SocketsPipe
from .tcp_connection_handler import TCPConnectionHandler

View File

@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from .relay_user_handler import RelayUserHandler
from .tcp_pipe_spawner import TCPPipeSpawner
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
class RelayConnectionHandler:
@ -25,7 +25,7 @@ class RelayConnectionHandler:
control_message = sock.recv(socket.MSG_PEEK)
if control_message.startswith(RELAY_CONTROL_MESSAGE):
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
self._relay_user_handler.disconnect_user(addr)
else:
self._relay_user_handler.add_relay_user(addr)

View File

@ -0,0 +1,62 @@
import logging
import socket
from typing import Iterable, Optional
import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread
logger = logging.getLogger(__name__)
def find_server(servers: Iterable[str]) -> Optional[str]:
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
for server in servers:
logger.debug(f"Trying to connect to server: {server}")
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return server
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers:
t = create_daemon_thread(
target=_send_remove_from_waitlist_control_message_to_relay,
name="SendRemoveFromWaitlistControlMessageToRelaysThread",
args=(server,),
)
t.start()
def _send_remove_from_waitlist_control_message_to_relay(server: str):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
ip, port = address_to_ip_port(server)
logger.info(f"Control message was sent to the server/relay {server}")
try:
d_socket.connect((ip, int(port)))
d_socket.send(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST)
except OSError as err:
logger.error(f"Error connecting to socket {server}: {err}")

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest
from monkey.infection_monkey.network.relay import (
RELAY_CONTROL_MESSAGE,
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
RelayConnectionHandler,
RelayUserHandler,
TCPPipeSpawner,
@ -27,7 +27,7 @@ def relay_user_handler():
@pytest.fixture
def close_socket():
sock = MagicMock(spec=socket.socket)
sock.recv.return_value = RELAY_CONTROL_MESSAGE
sock.recv.return_value = RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
sock.getpeername.return_value = (USER_ADDRESS, 12345)
return sock

View File

@ -0,0 +1,32 @@
import pytest
import requests
import requests_mock
from infection_monkey.network.relay.utils import find_server
SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321"
SERVER_3 = "3.3.3.3:3142"
SERVER_4 = "4.4.4.4:5000"
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize(
"expected_server,server_response_pairs",
[
(None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]),
(
SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]],
),
],
)
def test_find_server(expected_server, server_response_pairs):
with requests_mock.Mocker() as mock:
for server, response in server_response_pairs:
mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server