Merge branch '2216-fix-connection-issues' into 2216-tcp-relay

This commit is contained in:
Mike Salvatore 2022-09-14 07:06:03 -04:00
commit 1e3e6c9052
15 changed files with 325 additions and 82 deletions

View File

@ -110,6 +110,7 @@ class InfectionMonkey:
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
self._control_channel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
self._master = None
@ -136,6 +137,10 @@ class InfectionMonkey:
raise Exception(
f"Failed to connect to the island via any known servers: {self._opts.servers}"
)
# Note: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces?
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
return server
@ -177,10 +182,10 @@ class InfectionMonkey:
if firewall.is_enabled():
firewall.add_firewall_rule()
control_channel = ControlChannel(self._control_client.server_address, GUID)
control_channel.register_agent(self._opts.parent)
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
self._control_channel.register_agent(self._opts.parent)
config = control_channel.get_config()
config = self._control_channel.get_config()
relay_port = get_free_tcp_port()
self._relay = TCPRelay(
@ -204,9 +209,8 @@ class InfectionMonkey:
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
# TODO control_channel and control_client have same responsibilities, merge them
control_channel = ControlChannel(self._control_client.server_address, GUID)
propagation_credentials_repository = AggregatingPropagationCredentialsRepository(
control_channel
self._control_channel
)
event_queue = PyPubSubAgentEventQueue(Publisher())
@ -226,7 +230,7 @@ class InfectionMonkey:
puppet,
telemetry_messenger,
victim_host_factory,
control_channel,
self._control_channel,
local_network_interfaces,
propagation_credentials_repository,
)
@ -393,10 +397,7 @@ class InfectionMonkey:
self._master.cleanup()
reset_signal_handlers()
if self._relay and self._relay.is_alive():
self._relay.stop()
self._relay.join(timeout=60)
self._stop_relay()
if firewall.is_enabled():
firewall.remove_firewall_rule()
@ -420,6 +421,16 @@ class InfectionMonkey:
logger.info("Monkey is shutting down")
def _stop_relay(self):
if self._relay and self._relay.is_alive():
self._relay.stop()
while self._relay.is_alive() and not self._control_channel.should_agent_stop():
self._relay.join(timeout=5)
if self._control_channel.should_agent_stop():
self._relay.join(timeout=60)
def _close_tunnel(self):
logger.info(f"Quitting tunnel {self._cmd_island_ip}")
notify_disconnect(self._cmd_island_ip, self._cmd_island_port)

View File

@ -4,10 +4,12 @@ import struct
from dataclasses import dataclass
from ipaddress import IPv4Interface
from random import shuffle # noqa: DUO102
from typing import List
from threading import Lock
from typing import Dict, List, Set
import netifaces
import psutil
from egg_timer import EggTimer
from infection_monkey.utils.environment import is_windows_os
@ -120,20 +122,93 @@ else:
return routes
def get_free_tcp_port(min_range=1024, max_range=65535):
class TCPPortSelector:
"""
Select an available TCP port that a new server can listen on
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
Examines the system to find which ports are not in use and makes an intelligent decision
regarding what port can be used to host a server. In multithreaded applications, a race occurs
between the time when the OS reports that a port is free and when the port is actually used. In
other words, two threads which request a free port simultaneously may be handed the same port,
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
reserve a port for a period of time to give the requester ample time to start their server. Once
the requester's server is listening on the port, the OS will report the port as "LISTEN".
"""
for port in COMMON_PORTS:
if port not in in_use:
return port
def __init__(self):
self._leases: Dict[int, EggTimer] = {}
self._lock = Lock()
min_range = max(1, min_range)
max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if port not in in_use:
return port
def get_free_tcp_port(
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
):
"""
Get a free TCP port that a new server can listen on
return None
This function will attempt to provide a well-known port that the caller can listen on. If no
well-known ports are available, a random port will be selected.
:param min_range: The smallest port number a random port can be chosen from, defaults to
1024
:param max_range: The largest port number a random port can be chosen from, defaults to
65535
:param lease_time_sec: The amount of time a port should be reserved for if the OS does not report
it as in use, defaults to 30 seconds
:return: A TCP port number
"""
with self._lock:
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
if common_port is not None:
return common_port
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
for port in COMMON_PORTS:
if self._port_is_available(port, ports_in_use):
self._reserve_port(port, lease_time_sec)
return port
return None
def _get_free_random_port(
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
):
min_range = max(1, min_range)
max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if self._port_is_available(port, ports_in_use):
self._reserve_port(port, lease_time_sec)
return port
return None
def _port_is_available(self, port: int, ports_in_use: Set[int]):
if port in ports_in_use:
return False
if port not in self._leases:
return True
if self._leases[port].is_expired():
return True
return False
def _reserve_port(self, port: int, lease_time_sec: float):
timer = EggTimer()
timer.set(lease_time_sec)
self._leases[port] = timer
# TODO: This function is here because existing components rely on it. Refactor these components to
# accept a TCPPortSelector instance and use that instead.
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]

View File

@ -0,0 +1 @@
SOCKET_TIMEOUT = 10

View File

@ -1,11 +1,14 @@
import socket
from ipaddress import IPv4Address
from logging import getLogger
from .relay_user_handler import RelayUserHandler
from .tcp_pipe_spawner import TCPPipeSpawner
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
logger = getLogger(__name__)
class RelayConnectionHandler:
"""Handles new relay connections."""
@ -23,10 +26,15 @@ class RelayConnectionHandler:
addr, _ = sock.getpeername()
addr = IPv4Address(addr)
control_message = sock.recv(socket.MSG_PEEK)
control_message = sock.recv(
len(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST), socket.MSG_PEEK
)
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
self._relay_user_handler.disconnect_user(addr)
else:
self._relay_user_handler.add_relay_user(addr)
self._pipe_spawner.spawn_pipe(sock)
try:
self._pipe_spawner.spawn_pipe(sock)
self._relay_user_handler.add_relay_user(addr)
except OSError as err:
logger.debug(f"Failed to spawn pipe: {err}")

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass
from ipaddress import IPv4Address
from logging import getLogger
from threading import Lock
from typing import Dict
@ -13,6 +14,9 @@ DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT
DEFAULT_DISCONNECT_TIMEOUT = 60 * 2 # Wait up to 2 minutes for clients to disconnect
logger = getLogger(__name__)
@dataclass
class RelayUser:
address: IPv4Address
@ -48,7 +52,9 @@ class RelayUserHandler:
timer = EggTimer()
timer.set(self._client_disconnect_timeout)
self._relay_users[user_address] = RelayUser(user_address, timer)
user = RelayUser(user_address, timer)
self._relay_users[user_address] = user
logger.debug(f"Added relay user {user}")
def add_potential_user(self, user_address: IPv4Address):
"""
@ -60,7 +66,9 @@ class RelayUserHandler:
with self._lock:
timer = EggTimer()
timer.set(self._new_client_timeout)
self._potential_users[user_address] = RelayUser(user_address, timer)
user = RelayUser(user_address, timer)
self._potential_users[user_address] = user
logger.debug(f"Added potential relay user {user}")
def disconnect_user(self, user_address: IPv4Address):
"""
@ -70,6 +78,7 @@ class RelayUserHandler:
"""
with self._lock:
if user_address in self._relay_users:
logger.debug(f"Disconnected user {user_address}")
del_key(self._relay_users, user_address)
def has_potential_users(self) -> bool:

View File

@ -5,8 +5,9 @@ from logging import getLogger
from threading import Thread
from typing import Callable
from .consts import SOCKET_TIMEOUT
READ_BUFFER_SIZE = 8192
SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
@ -14,25 +15,34 @@ logger = getLogger(__name__)
class SocketsPipe(Thread):
"""Manages a pipe between two sockets."""
_thread_count: int = 0
def __init__(
self,
source,
dest,
pipe_closed: Callable[[SocketsPipe], None],
timeout=SOCKET_READ_TIMEOUT,
timeout=SOCKET_TIMEOUT,
):
self.source = source
self.dest = dest
self.timeout = timeout
super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True)
super().__init__(name=f"SocketsPipeThread-{self._next_thread_num()}", daemon=True)
self._pipe_closed = pipe_closed
@classmethod
def _next_thread_num(cls):
cls._thread_count += 1
return cls._thread_count
def _pipe(self):
sockets = [self.source, self.dest]
while True:
socket_closed = False
while not socket_closed:
read_list, _, except_list = select.select(sockets, [], sockets, self.timeout)
if except_list:
raise Exception("select() failed on sockets {except_list}")
raise OSError("select() failed on sockets {except_list}")
if not read_list:
raise TimeoutError("pipe did not receive data for {self.timeout} seconds")
@ -42,21 +52,24 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE)
if data:
other.sendall(data)
else:
socket_closed = True
break
def run(self):
try:
self._pipe()
except Exception as err:
except OSError as err:
logger.debug(err)
try:
self.source.close()
except Exception as err:
except OSError as err:
logger.debug(f"Error while closing source socket: {err}")
try:
self.dest.close()
except Exception as err:
except OSError as err:
logger.debug(f"Error while closing destination socket: {err}")
self._pipe_closed(self)

View File

@ -1,3 +1,4 @@
import logging
import socket
from threading import Thread
from typing import Callable, List
@ -6,6 +7,8 @@ from infection_monkey.utils.threading import InterruptableThreadMixin
PROXY_TIMEOUT = 2.5
logger = logging.getLogger(__name__)
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
"""Accepts connections on a TCP socket."""
@ -24,18 +27,24 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin):
InterruptableThreadMixin.__init__(self)
def run(self):
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.bind_host, self.bind_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
try:
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.bind_host, self.bind_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._interrupted.is_set():
try:
source, _ = l_socket.accept()
except socket.timeout:
continue
while not self._interrupted.is_set():
try:
source, _ = l_socket.accept()
except socket.timeout:
continue
for notify_client_connected in self._client_connected:
notify_client_connected(source)
logging.debug(f"New connection received from: {source.getpeername()}")
for notify_client_connected in self._client_connected:
notify_client_connected(source)
except OSError:
logging.exception("Uncaught error in TCPConnectionHandler thread")
finally:
l_socket.close()
l_socket.close()
logging.info("Exiting connection handler.")

View File

@ -1,10 +1,14 @@
import socket
from ipaddress import IPv4Address
from logging import getLogger
from threading import Lock
from typing import Set
from .consts import SOCKET_TIMEOUT
from .sockets_pipe import SocketsPipe
logger = getLogger(__name__)
class TCPPipeSpawner:
"""
@ -22,12 +26,13 @@ class TCPPipeSpawner:
Attempt to create a pipe on between the configured client and the provided socket
:param source: A socket to the connecting client.
:raises socket.error: If a socket to the configured client could not be created.
:raises OSError: If a socket to the configured client could not be created.
"""
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
dest.settimeout(SOCKET_TIMEOUT)
try:
dest.connect((self._target_addr, self._target_port))
except socket.error as err:
dest.connect((str(self._target_addr), self._target_port))
except OSError as err:
source.close()
dest.close()
raise err
@ -35,7 +40,8 @@ class TCPPipeSpawner:
pipe = SocketsPipe(source, dest, self._handle_pipe_closed)
with self._lock:
self._pipes.add(pipe)
pipe.run()
pipe.start()
def has_open_pipes(self) -> bool:
"""Return whether or not the TCPPipeSpawner has any open pipes."""
@ -48,4 +54,5 @@ class TCPPipeSpawner:
def _handle_pipe_closed(self, pipe: SocketsPipe):
with self._lock:
logger.debug(f"Closing pipe {pipe}")
self._pipes.discard(pipe)

View File

@ -1,4 +1,5 @@
from ipaddress import IPv4Address
from logging import getLogger
from threading import Lock, Thread
from time import sleep
@ -10,6 +11,8 @@ from infection_monkey.network.relay import (
)
from infection_monkey.utils.threading import InterruptableThreadMixin
logger = getLogger(__name__)
class TCPRelay(Thread, InterruptableThreadMixin):
"""
@ -23,7 +26,10 @@ class TCPRelay(Thread, InterruptableThreadMixin):
dest_port: int,
client_disconnect_timeout: float,
):
self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
self._user_handler = RelayUserHandler(
new_client_timeout=client_disconnect_timeout,
client_disconnect_timeout=client_disconnect_timeout,
)
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
self._connection_handler = TCPConnectionHandler(
@ -46,6 +52,7 @@ class TCPRelay(Thread, InterruptableThreadMixin):
self._connection_handler.stop()
self._connection_handler.join()
self._wait_for_pipes_to_close()
logger.info("TCP Relay closed.")
def add_potential_user(self, user_address: IPv4Address):
"""

View File

@ -1,42 +1,75 @@
import logging
import socket
from contextlib import suppress
from ipaddress import IPv4Address
from typing import Iterable, Optional
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import create_daemon_thread
from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
run_worker_threads,
)
logger = logging.getLogger(__name__)
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
# practical purposes. Revisit this if it's not.
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]:
for server in servers:
logger.debug(f"Trying to connect to server: {server}")
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
run_worker_threads(
_find_island_server,
"FindIslandServer",
args=(server_iterator, server_results),
num_workers=NUM_FIND_SERVER_WORKERS,
)
for server in server_list:
if server_results[server]:
return server
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
logger.debug(f"Trying to connect to server: {server}")
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return False
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers:
t = create_daemon_thread(

View File

@ -1,8 +1,8 @@
import logging
from functools import wraps
from itertools import count
from threading import Event, Thread
from typing import Any, Callable, Iterable, Optional, Tuple
from threading import Event, Lock, Thread
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
logger = logging.getLogger(__name__)
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
def stop(self):
"""Stop a running thread."""
self._interrupted.set()
T = TypeVar("T")
class ThreadSafeIterator(Iterator[T]):
"""Provides a thread-safe iterator that wraps another iterator"""
def __init__(self, iterator: Iterator[T]):
self._lock = Lock()
self._iterator = iterator
def __next__(self) -> T:
while True:
with self._lock:
return next(self._iterator)

View File

@ -0,0 +1,14 @@
from unittest.mock import MagicMock
from monkey.infection_monkey.network.relay import SocketsPipe
def test_sockets_pipe__name_increments():
sock_in = MagicMock()
sock_out = MagicMock()
pipe1 = SocketsPipe(sock_in, sock_out, None)
assert pipe1.name.endswith("1")
pipe2 = SocketsPipe(sock_in, sock_out, None)
assert pipe2.name.endswith("2")

View File

@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
(
SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]],
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
),
],
)
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server
def test_find_server__multiple_successes():
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2

View File

@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.info import TCPPortSelector
from infection_monkey.network.ports import COMMON_PORTS
@ -13,28 +13,48 @@ class Connection:
@pytest.mark.parametrize("port", COMMON_PORTS)
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is port
assert tcp_port_selector.get_free_tcp_port() is port
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is not None
assert tcp_port_selector.get_free_tcp_port() is not None
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in range(65535)]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is None
assert tcp_port_selector.get_free_tcp_port() is None
@pytest.mark.parametrize("common_port", COMMON_PORTS)
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
free_port_1 = tcp_port_selector.get_free_tcp_port()
free_port_2 = tcp_port_selector.get_free_tcp_port()
assert free_port_1 == common_port
assert free_port_2 != common_port
assert free_port_2 is not None
assert free_port_2 not in COMMON_PORTS

View File

@ -1,8 +1,10 @@
import logging
from itertools import zip_longest
from threading import Event, current_thread
from typing import Any
from infection_monkey.utils.threading import (
ThreadSafeIterator,
create_daemon_thread,
interruptible_function,
interruptible_iter,
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
assert return_value == 777
assert fn.call_count == 0
def test_thread_safe_iterator():
test_list = [1, 2, 3, 4, 5]
tsi = ThreadSafeIterator(test_list.__iter__())
for actual, expected in zip_longest(tsi, test_list):
assert actual == expected