forked from p15670423/monkey
Merge pull request #2281 from guardicore/2216-prevent-port-collisions
Agent: Add TCPPortSelector
This commit is contained in:
commit
905fb3563e
|
@ -4,10 +4,12 @@ import struct
|
|||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Interface
|
||||
from random import shuffle # noqa: DUO102
|
||||
from typing import List
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import netifaces
|
||||
import psutil
|
||||
from egg_timer import EggTimer
|
||||
|
||||
from infection_monkey.utils.environment import is_windows_os
|
||||
|
||||
|
@ -120,20 +122,93 @@ else:
|
|||
return routes
|
||||
|
||||
|
||||
def get_free_tcp_port(min_range=1024, max_range=65535):
|
||||
class TCPPortSelector:
|
||||
"""
|
||||
Select an available TCP port that a new server can listen on
|
||||
|
||||
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||
Examines the system to find which ports are not in use and makes an intelligent decision
|
||||
regarding what port can be used to host a server. In multithreaded applications, a race occurs
|
||||
between the time when the OS reports that a port is free and when the port is actually used. In
|
||||
other words, two threads which request a free port simultaneously may be handed the same port,
|
||||
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
|
||||
reserve a port for a period of time to give the requester ample time to start their server. Once
|
||||
the requester's server is listening on the port, the OS will report the port as "LISTEN".
|
||||
"""
|
||||
|
||||
for port in COMMON_PORTS:
|
||||
if port not in in_use:
|
||||
return port
|
||||
def __init__(self):
|
||||
self._leases: Dict[int, EggTimer] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
min_range = max(1, min_range)
|
||||
max_range = min(65535, max_range)
|
||||
ports = list(range(min_range, max_range))
|
||||
shuffle(ports)
|
||||
for port in ports:
|
||||
if port not in in_use:
|
||||
return port
|
||||
def get_free_tcp_port(
|
||||
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
|
||||
):
|
||||
"""
|
||||
Get a free TCP port that a new server can listen on
|
||||
|
||||
return None
|
||||
This function will attempt to provide a well-known port that the caller can listen on. If no
|
||||
well-known ports are available, a random port will be selected.
|
||||
|
||||
:param min_range: The smallest port number a random port can be chosen from, defaults to
|
||||
1024
|
||||
:param max_range: The largest port number a random port can be chosen from, defaults to
|
||||
65535
|
||||
:param lease_time_sec: The amount of time a port should be reserved for if the OS does not report
|
||||
it as in use, defaults to 30 seconds
|
||||
:return: A TCP port number
|
||||
"""
|
||||
with self._lock:
|
||||
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||
|
||||
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
|
||||
if common_port is not None:
|
||||
return common_port
|
||||
|
||||
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
|
||||
|
||||
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
|
||||
for port in COMMON_PORTS:
|
||||
if self._port_is_available(port, ports_in_use):
|
||||
self._reserve_port(port, lease_time_sec)
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
def _get_free_random_port(
|
||||
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
|
||||
):
|
||||
min_range = max(1, min_range)
|
||||
max_range = min(65535, max_range)
|
||||
ports = list(range(min_range, max_range))
|
||||
shuffle(ports)
|
||||
for port in ports:
|
||||
if self._port_is_available(port, ports_in_use):
|
||||
self._reserve_port(port, lease_time_sec)
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
def _port_is_available(self, port: int, ports_in_use: Set[int]):
|
||||
if port in ports_in_use:
|
||||
return False
|
||||
|
||||
if port not in self._leases:
|
||||
return True
|
||||
|
||||
if self._leases[port].is_expired():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _reserve_port(self, port: int, lease_time_sec: float):
|
||||
timer = EggTimer()
|
||||
timer.set(lease_time_sec)
|
||||
self._leases[port] = timer
|
||||
|
||||
|
||||
# TODO: This function is here because existing components rely on it. Refactor these components to
|
||||
# accept a TCPPortSelector instance and use that instead.
|
||||
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
|
||||
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
|
||||
|
||||
|
||||
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
|||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.network.info import get_free_tcp_port
|
||||
from infection_monkey.network.info import TCPPortSelector
|
||||
from infection_monkey.network.ports import COMMON_PORTS
|
||||
|
||||
|
||||
|
@ -13,28 +13,48 @@ class Connection:
|
|||
|
||||
|
||||
@pytest.mark.parametrize("port", COMMON_PORTS)
|
||||
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
|
||||
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
assert get_free_tcp_port() is port
|
||||
assert tcp_port_selector.get_free_tcp_port() is port
|
||||
|
||||
|
||||
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
assert get_free_tcp_port() is not None
|
||||
assert tcp_port_selector.get_free_tcp_port() is not None
|
||||
|
||||
|
||||
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
|
||||
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
assert get_free_tcp_port() is None
|
||||
assert tcp_port_selector.get_free_tcp_port() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_port", COMMON_PORTS)
|
||||
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
|
||||
tcp_port_selector = TCPPortSelector()
|
||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||
)
|
||||
|
||||
free_port_1 = tcp_port_selector.get_free_tcp_port()
|
||||
free_port_2 = tcp_port_selector.get_free_tcp_port()
|
||||
|
||||
assert free_port_1 == common_port
|
||||
assert free_port_2 != common_port
|
||||
assert free_port_2 is not None
|
||||
assert free_port_2 not in COMMON_PORTS
|
||||
|
|
Loading…
Reference in New Issue