Merge branch '2216-fix-connection-issues' into 2216-tcp-relay
This commit is contained in:
commit
1e3e6c9052
|
@ -110,6 +110,7 @@ class InfectionMonkey:
|
|||
# TODO Refactor the telemetry messengers to accept control client
|
||||
# and remove control_client_object
|
||||
ControlClient.control_client_object = self._control_client
|
||||
self._control_channel = None
|
||||
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
|
||||
self._current_depth = self._opts.depth
|
||||
self._master = None
|
||||
|
@ -136,6 +137,10 @@ class InfectionMonkey:
|
|||
raise Exception(
|
||||
f"Failed to connect to the island via any known servers: {self._opts.servers}"
|
||||
)
|
||||
|
||||
# Note: Since we pass the address for each of our interfaces to the exploited
|
||||
# machines, is it possible for a machine to unintentionally unregister itself from the
|
||||
# relay if it is able to connect to the relay over multiple interfaces?
|
||||
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
|
||||
|
||||
return server
|
||||
|
@ -177,10 +182,10 @@ class InfectionMonkey:
|
|||
if firewall.is_enabled():
|
||||
firewall.add_firewall_rule()
|
||||
|
||||
control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||
control_channel.register_agent(self._opts.parent)
|
||||
self._control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||
self._control_channel.register_agent(self._opts.parent)
|
||||
|
||||
config = control_channel.get_config()
|
||||
config = self._control_channel.get_config()
|
||||
|
||||
relay_port = get_free_tcp_port()
|
||||
self._relay = TCPRelay(
|
||||
|
@ -204,9 +209,8 @@ class InfectionMonkey:
|
|||
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
|
||||
|
||||
# TODO control_channel and control_client have same responsibilities, merge them
|
||||
control_channel = ControlChannel(self._control_client.server_address, GUID)
|
||||
propagation_credentials_repository = AggregatingPropagationCredentialsRepository(
|
||||
control_channel
|
||||
self._control_channel
|
||||
)
|
||||
|
||||
event_queue = PyPubSubAgentEventQueue(Publisher())
|
||||
|
@ -226,7 +230,7 @@ class InfectionMonkey:
|
|||
puppet,
|
||||
telemetry_messenger,
|
||||
victim_host_factory,
|
||||
control_channel,
|
||||
self._control_channel,
|
||||
local_network_interfaces,
|
||||
propagation_credentials_repository,
|
||||
)
|
||||
|
@ -393,10 +397,7 @@ class InfectionMonkey:
|
|||
self._master.cleanup()
|
||||
|
||||
reset_signal_handlers()
|
||||
|
||||
if self._relay and self._relay.is_alive():
|
||||
self._relay.stop()
|
||||
self._relay.join(timeout=60)
|
||||
self._stop_relay()
|
||||
|
||||
if firewall.is_enabled():
|
||||
firewall.remove_firewall_rule()
|
||||
|
@ -420,6 +421,16 @@ class InfectionMonkey:
|
|||
|
||||
logger.info("Monkey is shutting down")
|
||||
|
||||
def _stop_relay(self):
|
||||
if self._relay and self._relay.is_alive():
|
||||
self._relay.stop()
|
||||
|
||||
while self._relay.is_alive() and not self._control_channel.should_agent_stop():
|
||||
self._relay.join(timeout=5)
|
||||
|
||||
if self._control_channel.should_agent_stop():
|
||||
self._relay.join(timeout=60)
|
||||
|
||||
def _close_tunnel(self):
|
||||
logger.info(f"Quitting tunnel {self._cmd_island_ip}")
|
||||
notify_disconnect(self._cmd_island_ip, self._cmd_island_port)
|
||||
|
|
|
@ -4,10 +4,12 @@ import struct
|
|||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Interface
|
||||
from random import shuffle # noqa: DUO102
|
||||
from typing import List
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import netifaces
|
||||
import psutil
|
||||
from egg_timer import EggTimer
|
||||
|
||||
from infection_monkey.utils.environment import is_windows_os
|
||||
|
||||
|
@ -120,20 +122,93 @@ else:
|
|||
return routes
|
||||
|
||||
|
||||
def get_free_tcp_port(min_range=1024, max_range=65535):
|
||||
class TCPPortSelector:
|
||||
"""
|
||||
Select an available TCP port that a new server can listen on
|
||||
|
||||
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||
Examines the system to find which ports are not in use and makes an intelligent decision
|
||||
regarding what port can be used to host a server. In multithreaded applications, a race occurs
|
||||
between the time when the OS reports that a port is free and when the port is actually used. In
|
||||
other words, two threads which request a free port simultaneously may be handed the same port,
|
||||
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
|
||||
reserve a port for a period of time to give the requester ample time to start their server. Once
|
||||
the requester's server is listening on the port, the OS will report the port as "LISTEN".
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._leases: Dict[int, EggTimer] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def get_free_tcp_port(
|
||||
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
|
||||
):
|
||||
"""
|
||||
Get a free TCP port that a new server can listen on
|
||||
|
||||
This function will attempt to provide a well-known port that the caller can listen on. If no
|
||||
well-known ports are available, a random port will be selected.
|
||||
|
||||
:param min_range: The smallest port number a random port can be chosen from, defaults to
|
||||
1024
|
||||
:param max_range: The largest port number a random port can be chosen from, defaults to
|
||||
65535
|
||||
:param lease_time_sec: The amount of time a port should be reserved for if the OS does not report
|
||||
it as in use, defaults to 30 seconds
|
||||
:return: A TCP port number
|
||||
"""
|
||||
with self._lock:
|
||||
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||
|
||||
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
|
||||
if common_port is not None:
|
||||
return common_port
|
||||
|
||||
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
|
||||
|
||||
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
|
||||
for port in COMMON_PORTS:
|
||||
if port not in in_use:
|
||||
if self._port_is_available(port, ports_in_use):
|
||||
self._reserve_port(port, lease_time_sec)
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
def _get_free_random_port(
|
||||
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
|
||||
):
|
||||
min_range = max(1, min_range)
|
||||
max_range = min(65535, max_range)
|
||||
ports = list(range(min_range, max_range))
|
||||
shuffle(ports)
|
||||
for port in ports:
|
||||
if port not in in_use:
|
||||
if self._port_is_available(port, ports_in_use):
|
||||
self._reserve_port(port, lease_time_sec)
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
def _port_is_available(self, port: int, ports_in_use: Set[int]):
|
||||
if port in ports_in_use:
|
||||
return False
|
||||
|
||||
if port not in self._leases:
|
||||
return True
|
||||
|
||||
if self._leases[port].is_expired():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _reserve_port(self, port: int, lease_time_sec: float):
|
||||
timer = EggTimer()
|
||||
timer.set(lease_time_sec)
|
||||
self._leases[port] = timer
|
||||
|
||||
|
||||
# TODO: This function is here because existing components rely on it. Refactor these components to
|
||||
# accept a TCPPortSelector instance and use that instead.
|
||||
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
|
||||
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
|
||||
|
||||
|
||||
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
SOCKET_TIMEOUT = 10
|
|
@ -1,11 +1,14 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from logging import getLogger
|
||||
|
||||
from .relay_user_handler import RelayUserHandler
|
||||
from .tcp_pipe_spawner import TCPPipeSpawner
|
||||
|
||||
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class RelayConnectionHandler:
|
||||
"""Handles new relay connections."""
|
||||
|
@ -23,10 +26,15 @@ class RelayConnectionHandler:
|
|||
addr, _ = sock.getpeername()
|
||||
addr = IPv4Address(addr)
|
||||
|
||||
control_message = sock.recv(socket.MSG_PEEK)
|
||||
control_message = sock.recv(
|
||||
len(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST), socket.MSG_PEEK
|
||||
)
|
||||
|
||||
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
|
||||
self._relay_user_handler.disconnect_user(addr)
|
||||
else:
|
||||
self._relay_user_handler.add_relay_user(addr)
|
||||
try:
|
||||
self._pipe_spawner.spawn_pipe(sock)
|
||||
self._relay_user_handler.add_relay_user(addr)
|
||||
except OSError as err:
|
||||
logger.debug(f"Failed to spawn pipe: {err}")
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address
|
||||
from logging import getLogger
|
||||
from threading import Lock
|
||||
from typing import Dict
|
||||
|
||||
|
@ -13,6 +14,9 @@ DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT
|
|||
DEFAULT_DISCONNECT_TIMEOUT = 60 * 2 # Wait up to 2 minutes for clients to disconnect
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayUser:
|
||||
address: IPv4Address
|
||||
|
@ -48,7 +52,9 @@ class RelayUserHandler:
|
|||
|
||||
timer = EggTimer()
|
||||
timer.set(self._client_disconnect_timeout)
|
||||
self._relay_users[user_address] = RelayUser(user_address, timer)
|
||||
user = RelayUser(user_address, timer)
|
||||
self._relay_users[user_address] = user
|
||||
logger.debug(f"Added relay user {user}")
|
||||
|
||||
def add_potential_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
|
@ -60,7 +66,9 @@ class RelayUserHandler:
|
|||
with self._lock:
|
||||
timer = EggTimer()
|
||||
timer.set(self._new_client_timeout)
|
||||
self._potential_users[user_address] = RelayUser(user_address, timer)
|
||||
user = RelayUser(user_address, timer)
|
||||
self._potential_users[user_address] = user
|
||||
logger.debug(f"Added potential relay user {user}")
|
||||
|
||||
def disconnect_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
|
@ -70,6 +78,7 @@ class RelayUserHandler:
|
|||
"""
|
||||
with self._lock:
|
||||
if user_address in self._relay_users:
|
||||
logger.debug(f"Disconnected user {user_address}")
|
||||
del_key(self._relay_users, user_address)
|
||||
|
||||
def has_potential_users(self) -> bool:
|
||||
|
|
|
@ -5,8 +5,9 @@ from logging import getLogger
|
|||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from .consts import SOCKET_TIMEOUT
|
||||
|
||||
READ_BUFFER_SIZE = 8192
|
||||
SOCKET_READ_TIMEOUT = 10
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -14,25 +15,34 @@ logger = getLogger(__name__)
|
|||
class SocketsPipe(Thread):
|
||||
"""Manages a pipe between two sockets."""
|
||||
|
||||
_thread_count: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source,
|
||||
dest,
|
||||
pipe_closed: Callable[[SocketsPipe], None],
|
||||
timeout=SOCKET_READ_TIMEOUT,
|
||||
timeout=SOCKET_TIMEOUT,
|
||||
):
|
||||
self.source = source
|
||||
self.dest = dest
|
||||
self.timeout = timeout
|
||||
super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True)
|
||||
super().__init__(name=f"SocketsPipeThread-{self._next_thread_num()}", daemon=True)
|
||||
self._pipe_closed = pipe_closed
|
||||
|
||||
@classmethod
|
||||
def _next_thread_num(cls):
|
||||
cls._thread_count += 1
|
||||
return cls._thread_count
|
||||
|
||||
def _pipe(self):
|
||||
sockets = [self.source, self.dest]
|
||||
while True:
|
||||
socket_closed = False
|
||||
|
||||
while not socket_closed:
|
||||
read_list, _, except_list = select.select(sockets, [], sockets, self.timeout)
|
||||
if except_list:
|
||||
raise Exception("select() failed on sockets {except_list}")
|
||||
raise OSError("select() failed on sockets {except_list}")
|
||||
|
||||
if not read_list:
|
||||
raise TimeoutError("pipe did not receive data for {self.timeout} seconds")
|
||||
|
@ -42,21 +52,24 @@ class SocketsPipe(Thread):
|
|||
data = r.recv(READ_BUFFER_SIZE)
|
||||
if data:
|
||||
other.sendall(data)
|
||||
else:
|
||||
socket_closed = True
|
||||
break
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self._pipe()
|
||||
except Exception as err:
|
||||
except OSError as err:
|
||||
logger.debug(err)
|
||||
|
||||
try:
|
||||
self.source.close()
|
||||
except Exception as err:
|
||||
except OSError as err:
|
||||
logger.debug(f"Error while closing source socket: {err}")
|
||||
|
||||
try:
|
||||
self.dest.close()
|
||||
except Exception as err:
|
||||
except OSError as err:
|
||||
logger.debug(f"Error while closing destination socket: {err}")
|
||||
|
||||
self._pipe_closed(self)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import socket
|
||||
from threading import Thread
|
||||
from typing import Callable, List
|
||||
|
@ -6,6 +7,8 @@ from infection_monkey.utils.threading import InterruptableThreadMixin
|
|||
|
||||
PROXY_TIMEOUT = 2.5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
||||
"""Accepts connections on a TCP socket."""
|
||||
|
@ -24,6 +27,7 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
|||
InterruptableThreadMixin.__init__(self)
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
l_socket.bind((self.bind_host, self.bind_port))
|
||||
l_socket.settimeout(PROXY_TIMEOUT)
|
||||
|
@ -35,7 +39,12 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
|||
except socket.timeout:
|
||||
continue
|
||||
|
||||
logging.debug(f"New connection received from: {source.getpeername()}")
|
||||
for notify_client_connected in self._client_connected:
|
||||
notify_client_connected(source)
|
||||
|
||||
except OSError:
|
||||
logging.exception("Uncaught error in TCPConnectionHandler thread")
|
||||
finally:
|
||||
l_socket.close()
|
||||
|
||||
logging.info("Exiting connection handler.")
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from logging import getLogger
|
||||
from threading import Lock
|
||||
from typing import Set
|
||||
|
||||
from .consts import SOCKET_TIMEOUT
|
||||
from .sockets_pipe import SocketsPipe
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class TCPPipeSpawner:
|
||||
"""
|
||||
|
@ -22,12 +26,13 @@ class TCPPipeSpawner:
|
|||
Attempt to create a pipe on between the configured client and the provided socket
|
||||
|
||||
:param source: A socket to the connecting client.
|
||||
:raises socket.error: If a socket to the configured client could not be created.
|
||||
:raises OSError: If a socket to the configured client could not be created.
|
||||
"""
|
||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
dest.settimeout(SOCKET_TIMEOUT)
|
||||
try:
|
||||
dest.connect((self._target_addr, self._target_port))
|
||||
except socket.error as err:
|
||||
dest.connect((str(self._target_addr), self._target_port))
|
||||
except OSError as err:
|
||||
source.close()
|
||||
dest.close()
|
||||
raise err
|
||||
|
@ -35,7 +40,8 @@ class TCPPipeSpawner:
|
|||
pipe = SocketsPipe(source, dest, self._handle_pipe_closed)
|
||||
with self._lock:
|
||||
self._pipes.add(pipe)
|
||||
pipe.run()
|
||||
|
||||
pipe.start()
|
||||
|
||||
def has_open_pipes(self) -> bool:
|
||||
"""Return whether or not the TCPPipeSpawner has any open pipes."""
|
||||
|
@ -48,4 +54,5 @@ class TCPPipeSpawner:
|
|||
|
||||
def _handle_pipe_closed(self, pipe: SocketsPipe):
|
||||
with self._lock:
|
||||
logger.debug(f"Closing pipe {pipe}")
|
||||
self._pipes.discard(pipe)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from ipaddress import IPv4Address
|
||||
from logging import getLogger
|
||||
from threading import Lock, Thread
|
||||
from time import sleep
|
||||
|
||||
|
@ -10,6 +11,8 @@ from infection_monkey.network.relay import (
|
|||
)
|
||||
from infection_monkey.utils.threading import InterruptableThreadMixin
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class TCPRelay(Thread, InterruptableThreadMixin):
|
||||
"""
|
||||
|
@ -23,7 +26,10 @@ class TCPRelay(Thread, InterruptableThreadMixin):
|
|||
dest_port: int,
|
||||
client_disconnect_timeout: float,
|
||||
):
|
||||
self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
|
||||
self._user_handler = RelayUserHandler(
|
||||
new_client_timeout=client_disconnect_timeout,
|
||||
client_disconnect_timeout=client_disconnect_timeout,
|
||||
)
|
||||
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
|
||||
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
|
||||
self._connection_handler = TCPConnectionHandler(
|
||||
|
@ -46,6 +52,7 @@ class TCPRelay(Thread, InterruptableThreadMixin):
|
|||
self._connection_handler.stop()
|
||||
self._connection_handler.join()
|
||||
self._wait_for_pipes_to_close()
|
||||
logger.info("TCP Relay closed.")
|
||||
|
||||
def add_potential_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
|
|
|
@ -1,20 +1,53 @@
|
|||
import logging
|
||||
import socket
|
||||
from contextlib import suppress
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Iterable, Optional
|
||||
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||
from infection_monkey.utils.threading import create_daemon_thread
|
||||
from infection_monkey.utils.threading import (
|
||||
ThreadSafeIterator,
|
||||
create_daemon_thread,
|
||||
run_worker_threads,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The number of Island servers to test simultaneously. 32 threads seems large enough for all
|
||||
# practical purposes. Revisit this if it's not.
|
||||
NUM_FIND_SERVER_WORKERS = 32
|
||||
|
||||
|
||||
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||
for server in servers:
|
||||
server_list = list(servers)
|
||||
server_iterator = ThreadSafeIterator(server_list.__iter__())
|
||||
server_results: Dict[str, bool] = {}
|
||||
|
||||
run_worker_threads(
|
||||
_find_island_server,
|
||||
"FindIslandServer",
|
||||
args=(server_iterator, server_results),
|
||||
num_workers=NUM_FIND_SERVER_WORKERS,
|
||||
)
|
||||
|
||||
for server in server_list:
|
||||
if server_results[server]:
|
||||
return server
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
|
||||
with suppress(StopIteration):
|
||||
server = next(servers)
|
||||
server_status[server] = _check_if_island_server(server)
|
||||
|
||||
|
||||
def _check_if_island_server(server: str) -> bool:
|
||||
logger.debug(f"Trying to connect to server: {server}")
|
||||
|
||||
try:
|
||||
|
@ -24,7 +57,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
|
|||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
return server
|
||||
return True
|
||||
except requests.exceptions.ConnectionError as err:
|
||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||
except TimeoutError as err:
|
||||
|
@ -34,7 +67,7 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
|
|||
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||
)
|
||||
|
||||
return None
|
||||
return False
|
||||
|
||||
|
||||
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
from functools import wraps
|
||||
from itertools import count
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Iterable, Optional, Tuple
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -116,3 +116,19 @@ class InterruptableThreadMixin:
|
|||
def stop(self):
|
||||
"""Stop a running thread."""
|
||||
self._interrupted.set()
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ThreadSafeIterator(Iterator[T]):
|
||||
"""Provides a thread-safe iterator that wraps another iterator"""
|
||||
|
||||
def __init__(self, iterator: Iterator[T]):
|
||||
self._lock = Lock()
|
||||
self._iterator = iterator
|
||||
|
||||
def __next__(self) -> T:
|
||||
while True:
|
||||
with self._lock:
|
||||
return next(self._iterator)
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from monkey.infection_monkey.network.relay import SocketsPipe
|
||||
|
||||
|
||||
def test_sockets_pipe__name_increments():
|
||||
sock_in = MagicMock()
|
||||
sock_out = MagicMock()
|
||||
|
||||
pipe1 = SocketsPipe(sock_in, sock_out, None)
|
||||
assert pipe1.name.endswith("1")
|
||||
|
||||
pipe2 = SocketsPipe(sock_in, sock_out, None)
|
||||
assert pipe2.name.endswith("2")
|
|
@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
|||
(
|
||||
SERVER_2,
|
||||
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
|
||||
+ [(server, {"text": ""}) for server in servers[1:]],
|
||||
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs):
|
|||
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||
|
||||
assert find_server(servers) is expected_server
|
||||
|
||||
|
||||
def test_find_server__multiple_successes():
|
||||
with requests_mock.Mocker() as mock:
|
||||
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
|
||||
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
|
||||
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
|
||||
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
|
||||
|
||||
assert find_server(servers) == SERVER_2
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
|||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.network.info import get_free_tcp_port
|
||||
from infection_monkey.network.info import TCPPortSelector
|
||||
from infection_monkey.network.ports import COMMON_PORTS
|
||||
|
||||
|
||||
|
@ -13,28 +13,48 @@ class Connection:
|
|||
|
||||
|
||||
@pytest.mark.parametrize("port", COMMON_PORTS)
|
||||
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
|
||||
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
assert get_free_tcp_port() is port
|
||||
assert tcp_port_selector.get_free_tcp_port() is port
|
||||
|
||||
|
||||
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
assert get_free_tcp_port() is not None
|
||||
assert tcp_port_selector.get_free_tcp_port() is not None
|
||||
|
||||
|
||||
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
|
||||
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
assert get_free_tcp_port() is None
|
||||
assert tcp_port_selector.get_free_tcp_port() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_port", COMMON_PORTS)
|
||||
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
free_port_1 = tcp_port_selector.get_free_tcp_port()
|
||||
free_port_2 = tcp_port_selector.get_free_tcp_port()
|
||||
|
||||
assert free_port_1 == common_port
|
||||
assert free_port_2 != common_port
|
||||
assert free_port_2 is not None
|
||||
assert free_port_2 not in COMMON_PORTS
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from itertools import zip_longest
|
||||
from threading import Event, current_thread
|
||||
from typing import Any
|
||||
|
||||
from infection_monkey.utils.threading import (
|
||||
ThreadSafeIterator,
|
||||
create_daemon_thread,
|
||||
interruptible_function,
|
||||
interruptible_iter,
|
||||
|
@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt():
|
|||
|
||||
assert return_value == 777
|
||||
assert fn.call_count == 0
|
||||
|
||||
|
||||
def test_thread_safe_iterator():
|
||||
test_list = [1, 2, 3, 4, 5]
|
||||
tsi = ThreadSafeIterator(test_list.__iter__())
|
||||
|
||||
for actual, expected in zip_longest(tsi, test_list):
|
||||
assert actual == expected
|
||||
|
|
Loading…
Reference in New Issue