diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 4ddc8655c..440ab7213 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -110,6 +110,7 @@ class InfectionMonkey: # TODO Refactor the telemetry messengers to accept control client # and remove control_client_object ControlClient.control_client_object = self._control_client + self._control_channel = None self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._current_depth = self._opts.depth self._master = None @@ -136,6 +137,10 @@ class InfectionMonkey: raise Exception( f"Failed to connect to the island via any known servers: {self._opts.servers}" ) + + # Note: Since we pass the address for each of our interfaces to the exploited + # machines, is it possible for a machine to unintentionally unregister itself from the + # relay if it is able to connect to the relay over multiple interfaces? send_remove_from_waitlist_control_message_to_relays(servers_iterator) return server @@ -177,10 +182,10 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - control_channel = ControlChannel(self._control_client.server_address, GUID) - control_channel.register_agent(self._opts.parent) + self._control_channel = ControlChannel(self._control_client.server_address, GUID) + self._control_channel.register_agent(self._opts.parent) - config = control_channel.get_config() + config = self._control_channel.get_config() relay_port = get_free_tcp_port() self._relay = TCPRelay( @@ -204,9 +209,8 @@ class InfectionMonkey: local_network_interfaces = InfectionMonkey._get_local_network_interfaces() # TODO control_channel and control_client have same responsibilities, merge them - control_channel = ControlChannel(self._control_client.server_address, GUID) propagation_credentials_repository = AggregatingPropagationCredentialsRepository( - control_channel + self._control_channel ) event_queue = PyPubSubAgentEventQueue(Publisher()) @@ -226,7 +230,7 @@ class InfectionMonkey: puppet, telemetry_messenger, victim_host_factory, - control_channel, + self._control_channel, local_network_interfaces, propagation_credentials_repository, ) @@ -393,10 +397,7 @@ class InfectionMonkey: self._master.cleanup() reset_signal_handlers() - - if self._relay and self._relay.is_alive(): - self._relay.stop() - self._relay.join(timeout=60) + self._stop_relay() if firewall.is_enabled(): firewall.remove_firewall_rule() @@ -420,6 +421,16 @@ class InfectionMonkey: logger.info("Monkey is shutting down") + def _stop_relay(self): + if self._relay and self._relay.is_alive(): + self._relay.stop() + + while self._relay.is_alive() and not self._control_channel.should_agent_stop(): + self._relay.join(timeout=5) + + if self._control_channel.should_agent_stop(): + self._relay.join(timeout=60) + def _close_tunnel(self): logger.info(f"Quitting tunnel {self._cmd_island_ip}") notify_disconnect(self._cmd_island_ip, self._cmd_island_port) diff --git a/monkey/infection_monkey/network/info.py b/monkey/infection_monkey/network/info.py index 6fdef3597..6efcf3cd2 100644 --- a/monkey/infection_monkey/network/info.py +++ b/monkey/infection_monkey/network/info.py @@ -4,10 +4,12 @@ import struct from dataclasses import dataclass from ipaddress import IPv4Interface from random import shuffle # noqa: DUO102 -from typing import List +from threading import Lock +from typing import Dict, List, Set import netifaces import psutil +from egg_timer import EggTimer from infection_monkey.utils.environment import is_windows_os @@ -120,20 +122,93 @@ else: return routes -def get_free_tcp_port(min_range=1024, max_range=65535): +class TCPPortSelector: + """ + Select an available TCP port that a new server can listen on - in_use = {conn.laddr[1] for conn in psutil.net_connections()} + Examines the system to find which ports are not in use and makes an intelligent decision + regarding what port can be used to host a server. In multithreaded applications, a race occurs + between the time when the OS reports that a port is free and when the port is actually used. In + other words, two threads which request a free port simultaneously may be handed the same port, + as the OS will report that the port is not in use. To combat this, the TCPPortSelector will + reserve a port for a period of time to give the requester ample time to start their server. Once + the requester's server is listening on the port, the OS will report the port as "LISTEN". + """ - for port in COMMON_PORTS: - if port not in in_use: - return port + def __init__(self): + self._leases: Dict[int, EggTimer] = {} + self._lock = Lock() - min_range = max(1, min_range) - max_range = min(65535, max_range) - ports = list(range(min_range, max_range)) - shuffle(ports) - for port in ports: - if port not in in_use: - return port + def get_free_tcp_port( + self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30 + ): + """ + Get a free TCP port that a new server can listen on - return None + This function will attempt to provide a well-known port that the caller can listen on. If no + well-known ports are available, a random port will be selected. + + :param min_range: The smallest port number a random port can be chosen from, defaults to + 1024 + :param max_range: The largest port number a random port can be chosen from, defaults to + 65535 + :param lease_time_sec: The amount of time a port should be reserved for if the OS does not report + it as in use, defaults to 30 seconds + :return: A TCP port number + """ + with self._lock: + ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()} + + common_port = self._get_free_common_port(ports_in_use, lease_time_sec) + if common_port is not None: + return common_port + + return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec) + + def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec): + for port in COMMON_PORTS: + if self._port_is_available(port, ports_in_use): + self._reserve_port(port, lease_time_sec) + return port + + return None + + def _get_free_random_port( + self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float + ): + min_range = max(1, min_range) + max_range = min(65535, max_range) + ports = list(range(min_range, max_range)) + shuffle(ports) + for port in ports: + if self._port_is_available(port, ports_in_use): + self._reserve_port(port, lease_time_sec) + return port + + return None + + def _port_is_available(self, port: int, ports_in_use: Set[int]): + if port in ports_in_use: + return False + + if port not in self._leases: + return True + + if self._leases[port].is_expired(): + return True + + return False + + def _reserve_port(self, port: int, lease_time_sec: float): + timer = EggTimer() + timer.set(lease_time_sec) + self._leases[port] = timer + + +# TODO: This function is here because existing components rely on it. Refactor these components to +# accept a TCPPortSelector instance and use that instead. +def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30): + return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec) + + +get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined] diff --git a/monkey/infection_monkey/network/relay/consts.py b/monkey/infection_monkey/network/relay/consts.py new file mode 100644 index 000000000..eb5d3f75a --- /dev/null +++ b/monkey/infection_monkey/network/relay/consts.py @@ -0,0 +1 @@ +SOCKET_TIMEOUT = 10 diff --git a/monkey/infection_monkey/network/relay/relay_connection_handler.py b/monkey/infection_monkey/network/relay/relay_connection_handler.py index 4b4475e52..91f37b520 100644 --- a/monkey/infection_monkey/network/relay/relay_connection_handler.py +++ b/monkey/infection_monkey/network/relay/relay_connection_handler.py @@ -1,11 +1,14 @@ import socket from ipaddress import IPv4Address +from logging import getLogger from .relay_user_handler import RelayUserHandler from .tcp_pipe_spawner import TCPPipeSpawner RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -" +logger = getLogger(__name__) + class RelayConnectionHandler: """Handles new relay connections.""" @@ -23,10 +26,15 @@ class RelayConnectionHandler: addr, _ = sock.getpeername() addr = IPv4Address(addr) - control_message = sock.recv(socket.MSG_PEEK) + control_message = sock.recv( + len(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST), socket.MSG_PEEK + ) if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST): self._relay_user_handler.disconnect_user(addr) else: - self._relay_user_handler.add_relay_user(addr) - self._pipe_spawner.spawn_pipe(sock) + try: + self._pipe_spawner.spawn_pipe(sock) + self._relay_user_handler.add_relay_user(addr) + except OSError as err: + logger.debug(f"Failed to spawn pipe: {err}") diff --git a/monkey/infection_monkey/network/relay/relay_user_handler.py b/monkey/infection_monkey/network/relay/relay_user_handler.py index 63c4000e0..390f1d31f 100644 --- a/monkey/infection_monkey/network/relay/relay_user_handler.py +++ b/monkey/infection_monkey/network/relay/relay_user_handler.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from ipaddress import IPv4Address +from logging import getLogger from threading import Lock from typing import Dict @@ -13,6 +14,9 @@ DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT DEFAULT_DISCONNECT_TIMEOUT = 60 * 2 # Wait up to 2 minutes for clients to disconnect +logger = getLogger(__name__) + + @dataclass class RelayUser: address: IPv4Address @@ -48,7 +52,9 @@ class RelayUserHandler: timer = EggTimer() timer.set(self._client_disconnect_timeout) - self._relay_users[user_address] = RelayUser(user_address, timer) + user = RelayUser(user_address, timer) + self._relay_users[user_address] = user + logger.debug(f"Added relay user {user}") def add_potential_user(self, user_address: IPv4Address): """ @@ -60,7 +66,9 @@ class RelayUserHandler: with self._lock: timer = EggTimer() timer.set(self._new_client_timeout) - self._potential_users[user_address] = RelayUser(user_address, timer) + user = RelayUser(user_address, timer) + self._potential_users[user_address] = user + logger.debug(f"Added potential relay user {user}") def disconnect_user(self, user_address: IPv4Address): """ @@ -70,6 +78,7 @@ class RelayUserHandler: """ with self._lock: if user_address in self._relay_users: + logger.debug(f"Disconnected user {user_address}") del_key(self._relay_users, user_address) def has_potential_users(self) -> bool: diff --git a/monkey/infection_monkey/network/relay/sockets_pipe.py b/monkey/infection_monkey/network/relay/sockets_pipe.py index e31dfaaf4..b4d59416a 100644 --- a/monkey/infection_monkey/network/relay/sockets_pipe.py +++ b/monkey/infection_monkey/network/relay/sockets_pipe.py @@ -5,8 +5,9 @@ from logging import getLogger from threading import Thread from typing import Callable +from .consts import SOCKET_TIMEOUT + READ_BUFFER_SIZE = 8192 -SOCKET_READ_TIMEOUT = 10 logger = getLogger(__name__) @@ -14,25 +15,34 @@ logger = getLogger(__name__) class SocketsPipe(Thread): """Manages a pipe between two sockets.""" + _thread_count: int = 0 + def __init__( self, source, dest, pipe_closed: Callable[[SocketsPipe], None], - timeout=SOCKET_READ_TIMEOUT, + timeout=SOCKET_TIMEOUT, ): self.source = source self.dest = dest self.timeout = timeout - super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True) + super().__init__(name=f"SocketsPipeThread-{self._next_thread_num()}", daemon=True) self._pipe_closed = pipe_closed + @classmethod + def _next_thread_num(cls): + cls._thread_count += 1 + return cls._thread_count + def _pipe(self): sockets = [self.source, self.dest] - while True: + socket_closed = False + + while not socket_closed: read_list, _, except_list = select.select(sockets, [], sockets, self.timeout) if except_list: - raise Exception("select() failed on sockets {except_list}") + raise OSError("select() failed on sockets {except_list}") if not read_list: raise TimeoutError("pipe did not receive data for {self.timeout} seconds") @@ -42,21 +52,24 @@ class SocketsPipe(Thread): data = r.recv(READ_BUFFER_SIZE) if data: other.sendall(data) + else: + socket_closed = True + break def run(self): try: self._pipe() - except Exception as err: + except OSError as err: logger.debug(err) try: self.source.close() - except Exception as err: + except OSError as err: logger.debug(f"Error while closing source socket: {err}") try: self.dest.close() - except Exception as err: + except OSError as err: logger.debug(f"Error while closing destination socket: {err}") self._pipe_closed(self) diff --git a/monkey/infection_monkey/network/relay/tcp_connection_handler.py b/monkey/infection_monkey/network/relay/tcp_connection_handler.py index b70a133bf..f61bc2d3a 100644 --- a/monkey/infection_monkey/network/relay/tcp_connection_handler.py +++ b/monkey/infection_monkey/network/relay/tcp_connection_handler.py @@ -1,3 +1,4 @@ +import logging import socket from threading import Thread from typing import Callable, List @@ -6,6 +7,8 @@ from infection_monkey.utils.threading import InterruptableThreadMixin PROXY_TIMEOUT = 2.5 +logger = logging.getLogger(__name__) + class TCPConnectionHandler(Thread, InterruptableThreadMixin): """Accepts connections on a TCP socket.""" @@ -24,18 +27,24 @@ class TCPConnectionHandler(Thread, InterruptableThreadMixin): InterruptableThreadMixin.__init__(self) def run(self): - l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - l_socket.bind((self.bind_host, self.bind_port)) - l_socket.settimeout(PROXY_TIMEOUT) - l_socket.listen(5) + try: + l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + l_socket.bind((self.bind_host, self.bind_port)) + l_socket.settimeout(PROXY_TIMEOUT) + l_socket.listen(5) - while not self._interrupted.is_set(): - try: - source, _ = l_socket.accept() - except socket.timeout: - continue + while not self._interrupted.is_set(): + try: + source, _ = l_socket.accept() + except socket.timeout: + continue - for notify_client_connected in self._client_connected: - notify_client_connected(source) + logging.debug(f"New connection received from: {source.getpeername()}") + for notify_client_connected in self._client_connected: + notify_client_connected(source) + except OSError: + logging.exception("Uncaught error in TCPConnectionHandler thread") + finally: + l_socket.close() - l_socket.close() + logging.info("Exiting connection handler.") diff --git a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py index 5ffebaaea..0c23f0380 100644 --- a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py +++ b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py @@ -1,10 +1,14 @@ import socket from ipaddress import IPv4Address +from logging import getLogger from threading import Lock from typing import Set +from .consts import SOCKET_TIMEOUT from .sockets_pipe import SocketsPipe +logger = getLogger(__name__) + class TCPPipeSpawner: """ @@ -22,12 +26,13 @@ class TCPPipeSpawner: Attempt to create a pipe on between the configured client and the provided socket :param source: A socket to the connecting client. - :raises socket.error: If a socket to the configured client could not be created. + :raises OSError: If a socket to the configured client could not be created. """ dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + dest.settimeout(SOCKET_TIMEOUT) try: - dest.connect((self._target_addr, self._target_port)) - except socket.error as err: + dest.connect((str(self._target_addr), self._target_port)) + except OSError as err: source.close() dest.close() raise err @@ -35,7 +40,8 @@ class TCPPipeSpawner: pipe = SocketsPipe(source, dest, self._handle_pipe_closed) with self._lock: self._pipes.add(pipe) - pipe.run() + + pipe.start() def has_open_pipes(self) -> bool: """Return whether or not the TCPPipeSpawner has any open pipes.""" @@ -48,4 +54,5 @@ class TCPPipeSpawner: def _handle_pipe_closed(self, pipe: SocketsPipe): with self._lock: + logger.debug(f"Closing pipe {pipe}") self._pipes.discard(pipe) diff --git a/monkey/infection_monkey/network/relay/tcp_relay.py b/monkey/infection_monkey/network/relay/tcp_relay.py index 8e62a2169..c2663cb1d 100644 --- a/monkey/infection_monkey/network/relay/tcp_relay.py +++ b/monkey/infection_monkey/network/relay/tcp_relay.py @@ -1,4 +1,5 @@ from ipaddress import IPv4Address +from logging import getLogger from threading import Lock, Thread from time import sleep @@ -10,6 +11,8 @@ from infection_monkey.network.relay import ( ) from infection_monkey.utils.threading import InterruptableThreadMixin +logger = getLogger(__name__) + class TCPRelay(Thread, InterruptableThreadMixin): """ @@ -23,7 +26,10 @@ class TCPRelay(Thread, InterruptableThreadMixin): dest_port: int, client_disconnect_timeout: float, ): - self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout) + self._user_handler = RelayUserHandler( + new_client_timeout=client_disconnect_timeout, + client_disconnect_timeout=client_disconnect_timeout, + ) self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port) relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler) self._connection_handler = TCPConnectionHandler( @@ -46,6 +52,7 @@ class TCPRelay(Thread, InterruptableThreadMixin): self._connection_handler.stop() self._connection_handler.join() self._wait_for_pipes_to_close() + logger.info("TCP Relay closed.") def add_potential_user(self, user_address: IPv4Address): """ diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 4a5532c30..a19f316db 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -1,42 +1,75 @@ import logging import socket +from contextlib import suppress from ipaddress import IPv4Address -from typing import Iterable, Optional +from typing import Dict, Iterable, Iterator, MutableMapping, Optional import requests from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST -from infection_monkey.utils.threading import create_daemon_thread +from infection_monkey.utils.threading import ( + ThreadSafeIterator, + create_daemon_thread, + run_worker_threads, +) logger = logging.getLogger(__name__) +# The number of Island servers to test simultaneously. 32 threads seems large enough for all +# practical purposes. Revisit this if it's not. +NUM_FIND_SERVER_WORKERS = 32 + def find_server(servers: Iterable[str]) -> Optional[str]: - for server in servers: - logger.debug(f"Trying to connect to server: {server}") + server_list = list(servers) + server_iterator = ThreadSafeIterator(server_list.__iter__()) + server_results: Dict[str, bool] = {} - try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + run_worker_threads( + _find_island_server, + "FindIslandServer", + args=(server_iterator, server_results), + num_workers=NUM_FIND_SERVER_WORKERS, + ) + for server in server_list: + if server_results[server]: return server - except requests.exceptions.ConnectionError as err: - logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: - logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: - logger.error( - f"Exception encountered when trying to connect to server/relay {server}: {err}" - ) return None +def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): + with suppress(StopIteration): + server = next(servers) + server_status[server] = _check_if_island_server(server) + + +def _check_if_island_server(server: str) -> bool: + logger.debug(f"Trying to connect to server: {server}") + + try: + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + return True + except requests.exceptions.ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + + return False + + def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): for server in servers: t = create_daemon_thread( diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index be28aa0b1..6d8b28253 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,8 +1,8 @@ import logging from functools import wraps from itertools import count -from threading import Event, Thread -from typing import Any, Callable, Iterable, Optional, Tuple +from threading import Event, Lock, Thread +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar logger = logging.getLogger(__name__) @@ -116,3 +116,19 @@ class InterruptableThreadMixin: def stop(self): """Stop a running thread.""" self._interrupted.set() + + +T = TypeVar("T") + + +class ThreadSafeIterator(Iterator[T]): + """Provides a thread-safe iterator that wraps another iterator""" + + def __init__(self, iterator: Iterator[T]): + self._lock = Lock() + self._iterator = iterator + + def __next__(self) -> T: + while True: + with self._lock: + return next(self._iterator) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_sockets_pipe.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_sockets_pipe.py new file mode 100644 index 000000000..0a98c0247 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_sockets_pipe.py @@ -0,0 +1,14 @@ +from unittest.mock import MagicMock + +from monkey.infection_monkey.network.relay import SocketsPipe + + +def test_sockets_pipe__name_increments(): + sock_in = MagicMock() + sock_out = MagicMock() + + pipe1 = SocketsPipe(sock_in, sock_out, None) + assert pipe1.name.endswith("1") + + pipe2 = SocketsPipe(sock_in, sock_out, None) + assert pipe2.name.endswith("2") diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index 1acd08012..ac7eb1b16 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ( SERVER_2, [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] - + [(server, {"text": ""}) for server in servers[1:]], + + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], ) @@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs): mock.get(f"https://{server}/api?action=is-up", **response) assert find_server(servers) is expected_server + + +def test_find_server__multiple_successes(): + with requests_mock.Mocker() as mock: + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) + mock.get(f"https://{SERVER_2}/api?action=is-up", text="") + mock.get(f"https://{SERVER_3}/api?action=is-up", text="") + mock.get(f"https://{SERVER_4}/api?action=is-up", text="") + + assert find_server(servers) == SERVER_2 diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_info.py b/monkey/tests/unit_tests/infection_monkey/network/test_info.py index 8dab89e4b..f0508ce98 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_info.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_info.py @@ -3,7 +3,7 @@ from typing import Tuple import pytest -from infection_monkey.network.info import get_free_tcp_port +from infection_monkey.network.info import TCPPortSelector from infection_monkey.network.ports import COMMON_PORTS @@ -13,28 +13,48 @@ class Connection: @pytest.mark.parametrize("port", COMMON_PORTS) -def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch): +def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is port + assert tcp_port_selector.get_free_tcp_port() is port -def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch): +def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is not None + assert tcp_port_selector.get_free_tcp_port() is not None -def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch): +def test_tcp_port_selector__none_if_no_available_ports(monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in range(65535)] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is None + assert tcp_port_selector.get_free_tcp_port() is None + + +@pytest.mark.parametrize("common_port", COMMON_PORTS) +def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch): + tcp_port_selector = TCPPortSelector() + unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port] + monkeypatch.setattr( + "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports + ) + + free_port_1 = tcp_port_selector.get_free_tcp_port() + free_port_2 = tcp_port_selector.get_free_tcp_port() + + assert free_port_1 == common_port + assert free_port_2 != common_port + assert free_port_2 is not None + assert free_port_2 not in COMMON_PORTS diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 96a289096..05b813b66 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,8 +1,10 @@ import logging +from itertools import zip_longest from threading import Event, current_thread from typing import Any from infection_monkey.utils.threading import ( + ThreadSafeIterator, create_daemon_thread, interruptible_function, interruptible_iter, @@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt(): assert return_value == 777 assert fn.call_count == 0 + + +def test_thread_safe_iterator(): + test_list = [1, 2, 3, 4, 5] + tsi = ThreadSafeIterator(test_list.__iter__()) + + for actual, expected in zip_longest(tsi, test_list): + assert actual == expected