Merge branch '2216-fix-connection-issues' into 2216-tcp-relay
This commit is contained in:
commit
1e3e6c9052
|
@ -110,6 +110,7 @@ class InfectionMonkey:
|
||||||
# TODO Refactor the telemetry messengers to accept control client
|
# TODO Refactor the telemetry messengers to accept control client
|
||||||
# and remove control_client_object
|
# and remove control_client_object
|
||||||
ControlClient.control_client_object = self._control_client
|
ControlClient.control_client_object = self._control_client
|
||||||
|
self._control_channel = None
|
||||||
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
|
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
|
||||||
self._current_depth = self._opts.depth
|
self._current_depth = self._opts.depth
|
||||||
self._master = None
|
self._master = None
|
||||||
|
@ -136,6 +137,10 @@ class InfectionMonkey:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to connect to the island via any known servers: {self._opts.servers}"
|
f"Failed to connect to the island via any known servers: {self._opts.servers}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note: Since we pass the address for each of our interfaces to the exploited
|
||||||
|
# machines, is it possible for a machine to unintentionally unregister itself from the
|
||||||
|
# relay if it is able to connect to the relay over multiple interfaces?
|
||||||
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
|
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
@ -177,10 +182,10 @@ class InfectionMonkey:
|
||||||
if firewall.is_enabled():
|
if firewall.is_enabled():
|
||||||
firewall.add_firewall_rule()
|
firewall.add_firewall_rule()
|
||||||
|
|
||||||
control_channel = ControlChannel(self._control_client.server_address, GUID)
|
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||||
control_channel.register_agent(self._opts.parent)
|
self._control_channel.register_agent(self._opts.parent)
|
||||||
|
|
||||||
config = control_channel.get_config()
|
config = self._control_channel.get_config()
|
||||||
|
|
||||||
relay_port = get_free_tcp_port()
|
relay_port = get_free_tcp_port()
|
||||||
self._relay = TCPRelay(
|
self._relay = TCPRelay(
|
||||||
|
@ -204,9 +209,8 @@ class InfectionMonkey:
|
||||||
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
|
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
|
||||||
|
|
||||||
# TODO control_channel and control_client have same responsibilities, merge them
|
# TODO control_channel and control_client have same responsibilities, merge them
|
||||||
control_channel = ControlChannel(self._control_client.server_address, GUID)
|
|
||||||
propagation_credentials_repository = AggregatingPropagationCredentialsRepository(
|
propagation_credentials_repository = AggregatingPropagationCredentialsRepository(
|
||||||
control_channel
|
self._control_channel
|
||||||
)
|
)
|
||||||
|
|
||||||
event_queue = PyPubSubAgentEventQueue(Publisher())
|
event_queue = PyPubSubAgentEventQueue(Publisher())
|
||||||
|
@ -226,7 +230,7 @@ class InfectionMonkey:
|
||||||
puppet,
|
puppet,
|
||||||
telemetry_messenger,
|
telemetry_messenger,
|
||||||
victim_host_factory,
|
victim_host_factory,
|
||||||
control_channel,
|
self._control_channel,
|
||||||
local_network_interfaces,
|
local_network_interfaces,
|
||||||
propagation_credentials_repository,
|
propagation_credentials_repository,
|
||||||
)
|
)
|
||||||
|
@ -393,10 +397,7 @@ class InfectionMonkey:
|
||||||
self._master.cleanup()
|
self._master.cleanup()
|
||||||
|
|
||||||
reset_signal_handlers()
|
reset_signal_handlers()
|
||||||
|
self._stop_relay()
|
||||||
if self._relay and self._relay.is_alive():
|
|
||||||
self._relay.stop()
|
|
||||||
self._relay.join(timeout=60)
|
|
||||||
|
|
||||||
if firewall.is_enabled():
|
if firewall.is_enabled():
|
||||||
firewall.remove_firewall_rule()
|
firewall.remove_firewall_rule()
|
||||||
|
@ -420,6 +421,16 @@ class InfectionMonkey:
|
||||||
|
|
||||||
logger.info("Monkey is shutting down")
|
logger.info("Monkey is shutting down")
|
||||||
|
|
||||||
|
def _stop_relay(self):
|
||||||
|
if self._relay and self._relay.is_alive():
|
||||||
|
self._relay.stop()
|
||||||
|
|
||||||
|
while self._relay.is_alive() and not self._control_channel.should_agent_stop():
|
||||||
|
self._relay.join(timeout=5)
|
||||||
|
|
||||||
|
if self._control_channel.should_agent_stop():
|
||||||
|
self._relay.join(timeout=60)
|
||||||
|
|
||||||
def _close_tunnel(self):
|
def _close_tunnel(self):
|
||||||
logger.info(f"Quitting tunnel {self._cmd_island_ip}")
|
logger.info(f"Quitting tunnel {self._cmd_island_ip}")
|
||||||
notify_disconnect(self._cmd_island_ip, self._cmd_island_port)
|
notify_disconnect(self._cmd_island_ip, self._cmd_island_port)
|
||||||
|
|
|
@ -4,10 +4,12 @@ import struct
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ipaddress import IPv4Interface
|
from ipaddress import IPv4Interface
|
||||||
from random import shuffle # noqa: DUO102
|
from random import shuffle # noqa: DUO102
|
||||||
from typing import List
|
from threading import Lock
|
||||||
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
import netifaces
|
import netifaces
|
||||||
import psutil
|
import psutil
|
||||||
|
from egg_timer import EggTimer
|
||||||
|
|
||||||
from infection_monkey.utils.environment import is_windows_os
|
from infection_monkey.utils.environment import is_windows_os
|
||||||
|
|
||||||
|
@ -120,20 +122,93 @@ else:
|
||||||
return routes
|
return routes
|
||||||
|
|
||||||
|
|
||||||
def get_free_tcp_port(min_range=1024, max_range=65535):
|
class TCPPortSelector:
|
||||||
|
"""
|
||||||
|
Select an available TCP port that a new server can listen on
|
||||||
|
|
||||||
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
Examines the system to find which ports are not in use and makes an intelligent decision
|
||||||
|
regarding what port can be used to host a server. In multithreaded applications, a race occurs
|
||||||
|
between the time when the OS reports that a port is free and when the port is actually used. In
|
||||||
|
other words, two threads which request a free port simultaneously may be handed the same port,
|
||||||
|
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
|
||||||
|
reserve a port for a period of time to give the requester ample time to start their server. Once
|
||||||
|
the requester's server is listening on the port, the OS will report the port as "LISTEN".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._leases: Dict[int, EggTimer] = {}
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def get_free_tcp_port(
|
||||||
|
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a free TCP port that a new server can listen on
|
||||||
|
|
||||||
|
This function will attempt to provide a well-known port that the caller can listen on. If no
|
||||||
|
well-known ports are available, a random port will be selected.
|
||||||
|
|
||||||
|
:param min_range: The smallest port number a random port can be chosen from, defaults to
|
||||||
|
1024
|
||||||
|
:param max_range: The largest port number a random port can be chosen from, defaults to
|
||||||
|
65535
|
||||||
|
:param lease_time_sec: The amount of time a port should be reserved for if the OS does not report
|
||||||
|
it as in use, defaults to 30 seconds
|
||||||
|
:return: A TCP port number
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||||
|
|
||||||
|
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
|
||||||
|
if common_port is not None:
|
||||||
|
return common_port
|
||||||
|
|
||||||
|
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
|
||||||
|
|
||||||
|
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
|
||||||
for port in COMMON_PORTS:
|
for port in COMMON_PORTS:
|
||||||
if port not in in_use:
|
if self._port_is_available(port, ports_in_use):
|
||||||
|
self._reserve_port(port, lease_time_sec)
|
||||||
return port
|
return port
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_free_random_port(
|
||||||
|
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
|
||||||
|
):
|
||||||
min_range = max(1, min_range)
|
min_range = max(1, min_range)
|
||||||
max_range = min(65535, max_range)
|
max_range = min(65535, max_range)
|
||||||
ports = list(range(min_range, max_range))
|
ports = list(range(min_range, max_range))
|
||||||
shuffle(ports)
|
shuffle(ports)
|
||||||
for port in ports:
|
for port in ports:
|
||||||
if port not in in_use:
|
if self._port_is_available(port, ports_in_use):
|
||||||
|
self._reserve_port(port, lease_time_sec)
|
||||||
return port
|
return port
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _port_is_available(self, port: int, ports_in_use: Set[int]):
|
||||||
|
if port in ports_in_use:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if port not in self._leases:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._leases[port].is_expired():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _reserve_port(self, port: int, lease_time_sec: float):
|
||||||
|
timer = EggTimer()
|
||||||
|
timer.set(lease_time_sec)
|
||||||
|
self._leases[port] = timer
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This function is here because existing components rely on it. Refactor these components to
|
||||||
|
# accept a TCPPortSelector instance and use that instead.
|
||||||
|
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
|
||||||
|
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
|
||||||
|
|
||||||
|
|
||||||
|
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
SOCKET_TIMEOUT = 10
|
|
@ -1,11 +1,14 @@
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
from .relay_user_handler import RelayUserHandler
|
from .relay_user_handler import RelayUserHandler
|
||||||
from .tcp_pipe_spawner import TCPPipeSpawner
|
from .tcp_pipe_spawner import TCPPipeSpawner
|
||||||
|
|
||||||
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
|
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RelayConnectionHandler:
|
class RelayConnectionHandler:
|
||||||
"""Handles new relay connections."""
|
"""Handles new relay connections."""
|
||||||
|
@ -23,10 +26,15 @@ class RelayConnectionHandler:
|
||||||
addr, _ = sock.getpeername()
|
addr, _ = sock.getpeername()
|
||||||
addr = IPv4Address(addr)
|
addr = IPv4Address(addr)
|
||||||
|
|
||||||
control_message = sock.recv(socket.MSG_PEEK)
|
control_message = sock.recv(
|
||||||
|
len(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST), socket.MSG_PEEK
|
||||||
|
)
|
||||||
|
|
||||||
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
|
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
|
||||||
self._relay_user_handler.disconnect_user(addr)
|
self._relay_user_handler.disconnect_user(addr)
|
||||||
else:
|
else:
|
||||||
self._relay_user_handler.add_relay_user(addr)
|
try:
|
||||||
self._pipe_spawner.spawn_pipe(sock)
|
self._pipe_spawner.spawn_pipe(sock)
|
||||||
|
self._relay_user_handler.add_relay_user(addr)
|
||||||
|
except OSError as err:
|
||||||
|
logger.debug(f"Failed to spawn pipe: {err}")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
from logging import getLogger
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
@ -13,6 +14,9 @@ DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT
|
||||||
DEFAULT_DISCONNECT_TIMEOUT = 60 * 2 # Wait up to 2 minutes for clients to disconnect
|
DEFAULT_DISCONNECT_TIMEOUT = 60 * 2 # Wait up to 2 minutes for clients to disconnect
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RelayUser:
|
class RelayUser:
|
||||||
address: IPv4Address
|
address: IPv4Address
|
||||||
|
@ -48,7 +52,9 @@ class RelayUserHandler:
|
||||||
|
|
||||||
timer = EggTimer()
|
timer = EggTimer()
|
||||||
timer.set(self._client_disconnect_timeout)
|
timer.set(self._client_disconnect_timeout)
|
||||||
self._relay_users[user_address] = RelayUser(user_address, timer)
|
user = RelayUser(user_address, timer)
|
||||||
|
self._relay_users[user_address] = user
|
||||||
|
logger.debug(f"Added relay user {user}")
|
||||||
|
|
||||||
def add_potential_user(self, user_address: IPv4Address):
|
def add_potential_user(self, user_address: IPv4Address):
|
||||||
"""
|
"""
|
||||||
|
@ -60,7 +66,9 @@ class RelayUserHandler:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
timer = EggTimer()
|
timer = EggTimer()
|
||||||
timer.set(self._new_client_timeout)
|
timer.set(self._new_client_timeout)
|
||||||
self._potential_users[user_address] = RelayUser(user_address, timer)
|
user = RelayUser(user_address, timer)
|
||||||
|
self._potential_users[user_address] = user
|
||||||
|
logger.debug(f"Added potential relay user {user}")
|
||||||
|
|
||||||
def disconnect_user(self, user_address: IPv4Address):
|
def disconnect_user(self, user_address: IPv4Address):
|
||||||
"""
|
"""
|
||||||
|
@ -70,6 +78,7 @@ class RelayUserHandler:
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if user_address in self._relay_users:
|
if user_address in self._relay_users:
|
||||||
|
logger.debug(f"Disconnected user {user_address}")
|
||||||
del_key(self._relay_users, user_address)
|
del_key(self._relay_users, user_address)
|
||||||
|
|
||||||
def has_potential_users(self) -> bool:
|
def has_potential_users(self) -> bool:
|
||||||
|
|
|
@ -5,8 +5,9 @@ from logging import getLogger
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
from .consts import SOCKET_TIMEOUT
|
||||||
|
|
||||||
READ_BUFFER_SIZE = 8192
|
READ_BUFFER_SIZE = 8192
|
||||||
SOCKET_READ_TIMEOUT = 10
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -14,25 +15,34 @@ logger = getLogger(__name__)
|
||||||
class SocketsPipe(Thread):
|
class SocketsPipe(Thread):
|
||||||
"""Manages a pipe between two sockets."""
|
"""Manages a pipe between two sockets."""
|
||||||
|
|
||||||
|
_thread_count: int = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
source,
|
source,
|
||||||
dest,
|
dest,
|
||||||
pipe_closed: Callable[[SocketsPipe], None],
|
pipe_closed: Callable[[SocketsPipe], None],
|
||||||
timeout=SOCKET_READ_TIMEOUT,
|
timeout=SOCKET_TIMEOUT,
|
||||||
):
|
):
|
||||||
self.source = source
|
self.source = source
|
||||||
self.dest = dest
|
self.dest = dest
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True)
|
super().__init__(name=f"SocketsPipeThread-{self._next_thread_num()}", daemon=True)
|
||||||
self._pipe_closed = pipe_closed
|
self._pipe_closed = pipe_closed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _next_thread_num(cls):
|
||||||
|
cls._thread_count += 1
|
||||||
|
return cls._thread_count
|
||||||
|
|
||||||
def _pipe(self):
|
def _pipe(self):
|
||||||
sockets = [self.source, self.dest]
|
sockets = [self.source, self.dest]
|
||||||
while True:
|
socket_closed = False
|
||||||
|
|
||||||
|
while not socket_closed:
|
||||||
read_list, _, except_list = select.select(sockets, [], sockets, self.timeout)
|
read_list, _, except_list = select.select(sockets, [], sockets, self.timeout)
|
||||||
if except_list:
|
if except_list:
|
||||||
raise Exception("select() failed on sockets {except_list}")
|
raise OSError("select() failed on sockets {except_list}")
|
||||||
|
|
||||||
if not read_list:
|
if not read_list:
|
||||||
raise TimeoutError("pipe did not receive data for {self.timeout} seconds")
|
raise TimeoutError("pipe did not receive data for {self.timeout} seconds")
|
||||||
|
@ -42,21 +52,24 @@ class SocketsPipe(Thread):
|
||||||
data = r.recv(READ_BUFFER_SIZE)
|
data = r.recv(READ_BUFFER_SIZE)
|
||||||
if data:
|
if data:
|
||||||
other.sendall(data)
|
other.sendall(data)
|
||||||
|
else:
|
||||||
|
socket_closed = True
|
||||||
|
break
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
try:
|
try:
|
||||||
self._pipe()
|
self._pipe()
|
||||||
except Exception as err:
|
except OSError as err:
|
||||||
logger.debug(err)
|
logger.debug(err)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.source.close()
|
self.source.close()
|
||||||
except Exception as err:
|
except OSError as err:
|
||||||
logger.debug(f"Error while closing source socket: {err}")
|
logger.debug(f"Error while closing source socket: {err}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.dest.close()
|
self.dest.close()
|
||||||
except Exception as err:
|
except OSError as err:
|
||||||
logger.debug(f"Error while closing destination socket: {err}")
|
logger.debug(f"Error while closing destination socket: {err}")
|
||||||
|
|
||||||
self._pipe_closed(self)
|
self._pipe_closed(self)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
@ -6,6 +7,8 @@ from infection_monkey.utils.threading import InterruptableThreadMixin
|
||||||
|
|
||||||
PROXY_TIMEOUT = 2.5
|
PROXY_TIMEOUT = 2.5
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
||||||
"""Accepts connections on a TCP socket."""
|
"""Accepts connections on a TCP socket."""
|
||||||
|
@ -24,6 +27,7 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
||||||
InterruptableThreadMixin.__init__(self)
|
InterruptableThreadMixin.__init__(self)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
try:
|
||||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
l_socket.bind((self.bind_host, self.bind_port))
|
l_socket.bind((self.bind_host, self.bind_port))
|
||||||
l_socket.settimeout(PROXY_TIMEOUT)
|
l_socket.settimeout(PROXY_TIMEOUT)
|
||||||
|
@ -35,7 +39,12 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logging.debug(f"New connection received from: {source.getpeername()}")
|
||||||
for notify_client_connected in self._client_connected:
|
for notify_client_connected in self._client_connected:
|
||||||
notify_client_connected(source)
|
notify_client_connected(source)
|
||||||
|
except OSError:
|
||||||
|
logging.exception("Uncaught error in TCPConnectionHandler thread")
|
||||||
|
finally:
|
||||||
l_socket.close()
|
l_socket.close()
|
||||||
|
|
||||||
|
logging.info("Exiting connection handler.")
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
from logging import getLogger
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
|
from .consts import SOCKET_TIMEOUT
|
||||||
from .sockets_pipe import SocketsPipe
|
from .sockets_pipe import SocketsPipe
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPPipeSpawner:
|
class TCPPipeSpawner:
|
||||||
"""
|
"""
|
||||||
|
@ -22,12 +26,13 @@ class TCPPipeSpawner:
|
||||||
Attempt to create a pipe on between the configured client and the provided socket
|
Attempt to create a pipe on between the configured client and the provided socket
|
||||||
|
|
||||||
:param source: A socket to the connecting client.
|
:param source: A socket to the connecting client.
|
||||||
:raises socket.error: If a socket to the configured client could not be created.
|
:raises OSError: If a socket to the configured client could not be created.
|
||||||
"""
|
"""
|
||||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
dest.settimeout(SOCKET_TIMEOUT)
|
||||||
try:
|
try:
|
||||||
dest.connect((self._target_addr, self._target_port))
|
dest.connect((str(self._target_addr), self._target_port))
|
||||||
except socket.error as err:
|
except OSError as err:
|
||||||
source.close()
|
source.close()
|
||||||
dest.close()
|
dest.close()
|
||||||
raise err
|
raise err
|
||||||
|
@ -35,7 +40,8 @@ class TCPPipeSpawner:
|
||||||
pipe = SocketsPipe(source, dest, self._handle_pipe_closed)
|
pipe = SocketsPipe(source, dest, self._handle_pipe_closed)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._pipes.add(pipe)
|
self._pipes.add(pipe)
|
||||||
pipe.run()
|
|
||||||
|
pipe.start()
|
||||||
|
|
||||||
def has_open_pipes(self) -> bool:
|
def has_open_pipes(self) -> bool:
|
||||||
"""Return whether or not the TCPPipeSpawner has any open pipes."""
|
"""Return whether or not the TCPPipeSpawner has any open pipes."""
|
||||||
|
@ -48,4 +54,5 @@ class TCPPipeSpawner:
|
||||||
|
|
||||||
def _handle_pipe_closed(self, pipe: SocketsPipe):
|
def _handle_pipe_closed(self, pipe: SocketsPipe):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
logger.debug(f"Closing pipe {pipe}")
|
||||||
self._pipes.discard(pipe)
|
self._pipes.discard(pipe)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
from logging import getLogger
|
||||||
from threading import Lock, Thread
|
from threading import Lock, Thread
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
|
@ -10,6 +11,8 @@ from infection_monkey.network.relay import (
|
||||||
)
|
)
|
||||||
from infection_monkey.utils.threading import InterruptableThreadMixin
|
from infection_monkey.utils.threading import InterruptableThreadMixin
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPRelay(Thread, InterruptableThreadMixin):
|
class TCPRelay(Thread, InterruptableThreadMixin):
|
||||||
"""
|
"""
|
||||||
|
@ -23,7 +26,10 @@ class TCPRelay(Thread, InterruptableThreadMixin):
|
||||||
dest_port: int,
|
dest_port: int,
|
||||||
client_disconnect_timeout: float,
|
client_disconnect_timeout: float,
|
||||||
):
|
):
|
||||||
self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
|
self._user_handler = RelayUserHandler(
|
||||||
|
new_client_timeout=client_disconnect_timeout,
|
||||||
|
client_disconnect_timeout=client_disconnect_timeout,
|
||||||
|
)
|
||||||
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
|
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
|
||||||
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
|
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
|
||||||
self._connection_handler = TCPConnectionHandler(
|
self._connection_handler = TCPConnectionHandler(
|
||||||
|
@ -46,6 +52,7 @@ class TCPRelay(Thread, InterruptableThreadMixin):
|
||||||
self._connection_handler.stop()
|
self._connection_handler.stop()
|
||||||
self._connection_handler.join()
|
self._connection_handler.join()
|
||||||
self._wait_for_pipes_to_close()
|
self._wait_for_pipes_to_close()
|
||||||
|
logger.info("TCP Relay closed.")
|
||||||
|
|
||||||
def add_potential_user(self, user_address: IPv4Address):
|
def add_potential_user(self, user_address: IPv4Address):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,20 +1,53 @@
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
from contextlib import suppress
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
from typing import Iterable, Optional
|
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
||||||
from common.network.network_utils import address_to_ip_port
|
from common.network.network_utils import address_to_ip_port
|
||||||
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||||
from infection_monkey.utils.threading import create_daemon_thread
|
from infection_monkey.utils.threading import (
|
||||||
|
ThreadSafeIterator,
|
||||||
|
create_daemon_thread,
|
||||||
|
run_worker_threads,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
|
||||||
|
# practical purposes. Revisit this if it's not.
|
||||||
|
NUM_FIND_SERVER_WORKERS = 32
|
||||||
|
|
||||||
|
|
||||||
def find_server(servers: Iterable[str]) -> Optional[str]:
|
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||||
for server in servers:
|
server_list = list(servers)
|
||||||
|
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||||
|
server_results: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
run_worker_threads(
|
||||||
|
_find_island_server,
|
||||||
|
"FindIslandServer",
|
||||||
|
args=(server_iterator, server_results),
|
||||||
|
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
for server in server_list:
|
||||||
|
if server_results[server]:
|
||||||
|
return server
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
|
||||||
|
with suppress(StopIteration):
|
||||||
|
server = next(servers)
|
||||||
|
server_status[server] = _check_if_island_server(server)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_if_island_server(server: str) -> bool:
|
||||||
logger.debug(f"Trying to connect to server: {server}")
|
logger.debug(f"Trying to connect to server: {server}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -24,7 +57,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
return server
|
return True
|
||||||
except requests.exceptions.ConnectionError as err:
|
except requests.exceptions.ConnectionError as err:
|
||||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||||
except TimeoutError as err:
|
except TimeoutError as err:
|
||||||
|
@ -34,7 +67,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||||
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from threading import Event, Thread
|
from threading import Event, Lock, Thread
|
||||||
from typing import Any, Callable, Iterable, Optional, Tuple
|
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop a running thread."""
|
"""Stop a running thread."""
|
||||||
self._interrupted.set()
|
self._interrupted.set()
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSafeIterator(Iterator[T]):
|
||||||
|
"""Provides a thread-safe iterator that wraps another iterator"""
|
||||||
|
|
||||||
|
def __init__(self, iterator: Iterator[T]):
|
||||||
|
self._lock = Lock()
|
||||||
|
self._iterator = iterator
|
||||||
|
|
||||||
|
def __next__(self) -> T:
|
||||||
|
while True:
|
||||||
|
with self._lock:
|
||||||
|
return next(self._iterator)
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from monkey.infection_monkey.network.relay import SocketsPipe
|
||||||
|
|
||||||
|
|
||||||
|
def test_sockets_pipe__name_increments():
|
||||||
|
sock_in = MagicMock()
|
||||||
|
sock_out = MagicMock()
|
||||||
|
|
||||||
|
pipe1 = SocketsPipe(sock_in, sock_out, None)
|
||||||
|
assert pipe1.name.endswith("1")
|
||||||
|
|
||||||
|
pipe2 = SocketsPipe(sock_in, sock_out, None)
|
||||||
|
assert pipe2.name.endswith("2")
|
|
@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
(
|
(
|
||||||
SERVER_2,
|
SERVER_2,
|
||||||
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
|
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
|
||||||
+ [(server, {"text": ""}) for server in servers[1:]],
|
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
|
||||||
mock.get(f"https://{server}/api?action=is-up", **response)
|
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||||
|
|
||||||
assert find_server(servers) is expected_server
|
assert find_server(servers) is expected_server
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_server__multiple_successes():
|
||||||
|
with requests_mock.Mocker() as mock:
|
||||||
|
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
|
||||||
|
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
|
||||||
|
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
|
||||||
|
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
|
||||||
|
|
||||||
|
assert find_server(servers) == SERVER_2
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from infection_monkey.network.info import get_free_tcp_port
|
from infection_monkey.network.info import TCPPortSelector
|
||||||
from infection_monkey.network.ports import COMMON_PORTS
|
from infection_monkey.network.ports import COMMON_PORTS
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,28 +13,48 @@ class Connection:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("port", COMMON_PORTS)
|
@pytest.mark.parametrize("port", COMMON_PORTS)
|
||||||
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
|
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
assert get_free_tcp_port() is port
|
assert tcp_port_selector.get_free_tcp_port() is port
|
||||||
|
|
||||||
|
|
||||||
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
|
|
||||||
assert get_free_tcp_port() is not None
|
assert tcp_port_selector.get_free_tcp_port() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
|
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
|
|
||||||
assert get_free_tcp_port() is None
|
assert tcp_port_selector.get_free_tcp_port() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("common_port", COMMON_PORTS)
|
||||||
|
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
|
)
|
||||||
|
|
||||||
|
free_port_1 = tcp_port_selector.get_free_tcp_port()
|
||||||
|
free_port_2 = tcp_port_selector.get_free_tcp_port()
|
||||||
|
|
||||||
|
assert free_port_1 == common_port
|
||||||
|
assert free_port_2 != common_port
|
||||||
|
assert free_port_2 is not None
|
||||||
|
assert free_port_2 not in COMMON_PORTS
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
|
from itertools import zip_longest
|
||||||
from threading import Event, current_thread
|
from threading import Event, current_thread
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from infection_monkey.utils.threading import (
|
from infection_monkey.utils.threading import (
|
||||||
|
ThreadSafeIterator,
|
||||||
create_daemon_thread,
|
create_daemon_thread,
|
||||||
interruptible_function,
|
interruptible_function,
|
||||||
interruptible_iter,
|
interruptible_iter,
|
||||||
|
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
|
||||||
|
|
||||||
assert return_value == 777
|
assert return_value == 777
|
||||||
assert fn.call_count == 0
|
assert fn.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_safe_iterator():
|
||||||
|
test_list = [1, 2, 3, 4, 5]
|
||||||
|
tsi = ThreadSafeIterator(test_list.__iter__())
|
||||||
|
|
||||||
|
for actual, expected in zip_longest(tsi, test_list):
|
||||||
|
assert actual == expected
|
||||||
|
|
Loading…
Reference in New Issue