diff --git a/monkey/infection_monkey/network/info.py b/monkey/infection_monkey/network/info.py index 6fdef3597..fdcc45234 100644 --- a/monkey/infection_monkey/network/info.py +++ b/monkey/infection_monkey/network/info.py @@ -4,10 +4,12 @@ import struct from dataclasses import dataclass from ipaddress import IPv4Interface from random import shuffle # noqa: DUO102 -from typing import List +from threading import Lock +from typing import Dict, List, Set import netifaces import psutil +from egg_timer import EggTimer from infection_monkey.utils.environment import is_windows_os @@ -120,20 +122,93 @@ else: return routes -def get_free_tcp_port(min_range=1024, max_range=65535): +class TCPPortSelector: + """ + Select an available TCP port that a new server can listen on - in_use = {conn.laddr[1] for conn in psutil.net_connections()} + Examines the system to find which ports are not in use and makes an intelligent decision + regarding what port can be used to host a server. In multithreaded applications, a race occurs + between the time when the OS reports that a port is free and when the port is actually used. In + other words, two threads which request a free port simultaneously may be handed the same port, + as the OS will report that the port is not in use. To combat this, the TCPPortSelector will + reserve a port for a period of time to give the requester ample time to start their server. Once + the requester's server is listening on the port, the OS will report the port as "LISTEN". + """ - for port in COMMON_PORTS: - if port not in in_use: - return port + def __init__(self): + self._leases: Dict[int, EggTimer] = {} + self._lock = Lock() - min_range = max(1, min_range) - max_range = min(65535, max_range) - ports = list(range(min_range, max_range)) - shuffle(ports) - for port in ports: - if port not in in_use: - return port + def get_free_tcp_port( + self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30 + ): + """ + Get a free TCP port that a new server can listen on - return None + This function will attempt to provide a well-known port that the caller can listen on. If no + well-known ports are available, a random port will be selected. + + :param min_range: The smallest port number a random port can be chosen from, defaults to + 1024 + :param max_range: The largest port number a random port can be chosen from, defaults to + 65535 + :lease_time_sec: The amount of time a port should be reserved for if the OS does not report + it as in use, defaults to 30 seconds + :return: A TCP port number + """ + with self._lock: + ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()} + + common_port = self._get_free_common_port(ports_in_use, lease_time_sec) + if common_port is not None: + return common_port + + return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec) + + def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec): + for port in COMMON_PORTS: + if self._port_is_available(port, ports_in_use): + self._reserve_port(port, lease_time_sec) + return port + + return None + + def _get_free_random_port( + self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float + ): + min_range = max(1, min_range) + max_range = min(65535, max_range) + ports = list(range(min_range, max_range)) + shuffle(ports) + for port in ports: + if self._port_is_available(port, ports_in_use): + self._reserve_port(port, lease_time_sec) + return port + + return None + + def _port_is_available(self, port: int, ports_in_use: Set[int]): + if port in ports_in_use: + return False + + if port not in self._leases: + return True + + if self._leases[port].is_expired(): + return True + + return False + + def _reserve_port(self, port: int, lease_time_sec: float): + timer = EggTimer() + timer.set(lease_time_sec) + self._leases[port] = timer + + +# TODO: This function is here because existing components rely on it. Refactor these components to +# accept a TCPPortSelector instance and use that instead. +def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30): + return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec) + + +get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined] diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_info.py b/monkey/tests/unit_tests/infection_monkey/network/test_info.py index 8dab89e4b..f0508ce98 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/test_info.py +++ b/monkey/tests/unit_tests/infection_monkey/network/test_info.py @@ -3,7 +3,7 @@ from typing import Tuple import pytest -from infection_monkey.network.info import get_free_tcp_port +from infection_monkey.network.info import TCPPortSelector from infection_monkey.network.ports import COMMON_PORTS @@ -13,28 +13,48 @@ class Connection: @pytest.mark.parametrize("port", COMMON_PORTS) -def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch): +def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is port + assert tcp_port_selector.get_free_tcp_port() is port -def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch): +def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is not None + assert tcp_port_selector.get_free_tcp_port() is not None -def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch): +def test_tcp_port_selector__none_if_no_available_ports(monkeypatch): + tcp_port_selector = TCPPortSelector() unavailable_ports = [Connection(("", p)) for p in range(65535)] monkeypatch.setattr( "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports ) - assert get_free_tcp_port() is None + assert tcp_port_selector.get_free_tcp_port() is None + + +@pytest.mark.parametrize("common_port", COMMON_PORTS) +def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch): + tcp_port_selector = TCPPortSelector() + unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port] + monkeypatch.setattr( + "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports + ) + + free_port_1 = tcp_port_selector.get_free_tcp_port() + free_port_2 = tcp_port_selector.get_free_tcp_port() + + assert free_port_1 == common_port + assert free_port_2 != common_port + assert free_port_2 is not None + assert free_port_2 not in COMMON_PORTS