Merge pull request #2281 from guardicore/2216-prevent-port-collisions
Agent: Add TCPPortSelector
This commit is contained in:
commit
905fb3563e
|
@ -4,10 +4,12 @@ import struct
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ipaddress import IPv4Interface
|
from ipaddress import IPv4Interface
|
||||||
from random import shuffle # noqa: DUO102
|
from random import shuffle # noqa: DUO102
|
||||||
from typing import List
|
from threading import Lock
|
||||||
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
import netifaces
|
import netifaces
|
||||||
import psutil
|
import psutil
|
||||||
|
from egg_timer import EggTimer
|
||||||
|
|
||||||
from infection_monkey.utils.environment import is_windows_os
|
from infection_monkey.utils.environment import is_windows_os
|
||||||
|
|
||||||
|
@ -120,20 +122,93 @@ else:
|
||||||
return routes
|
return routes
|
||||||
|
|
||||||
|
|
||||||
def get_free_tcp_port(min_range=1024, max_range=65535):
|
class TCPPortSelector:
|
||||||
|
"""
|
||||||
|
Select an available TCP port that a new server can listen on
|
||||||
|
|
||||||
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
Examines the system to find which ports are not in use and makes an intelligent decision
|
||||||
|
regarding what port can be used to host a server. In multithreaded applications, a race occurs
|
||||||
|
between the time when the OS reports that a port is free and when the port is actually used. In
|
||||||
|
other words, two threads which request a free port simultaneously may be handed the same port,
|
||||||
|
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
|
||||||
|
reserve a port for a period of time to give the requester ample time to start their server. Once
|
||||||
|
the requester's server is listening on the port, the OS will report the port as "LISTEN".
|
||||||
|
"""
|
||||||
|
|
||||||
for port in COMMON_PORTS:
|
def __init__(self):
|
||||||
if port not in in_use:
|
self._leases: Dict[int, EggTimer] = {}
|
||||||
return port
|
self._lock = Lock()
|
||||||
|
|
||||||
min_range = max(1, min_range)
|
def get_free_tcp_port(
|
||||||
max_range = min(65535, max_range)
|
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
|
||||||
ports = list(range(min_range, max_range))
|
):
|
||||||
shuffle(ports)
|
"""
|
||||||
for port in ports:
|
Get a free TCP port that a new server can listen on
|
||||||
if port not in in_use:
|
|
||||||
return port
|
|
||||||
|
|
||||||
return None
|
This function will attempt to provide a well-known port that the caller can listen on. If no
|
||||||
|
well-known ports are available, a random port will be selected.
|
||||||
|
|
||||||
|
:param min_range: The smallest port number a random port can be chosen from, defaults to
|
||||||
|
1024
|
||||||
|
:param max_range: The largest port number a random port can be chosen from, defaults to
|
||||||
|
65535
|
||||||
|
:param lease_time_sec: The amount of time a port should be reserved for if the OS does not report
|
||||||
|
it as in use, defaults to 30 seconds
|
||||||
|
:return: A TCP port number
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
|
||||||
|
|
||||||
|
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
|
||||||
|
if common_port is not None:
|
||||||
|
return common_port
|
||||||
|
|
||||||
|
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
|
||||||
|
|
||||||
|
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
|
||||||
|
for port in COMMON_PORTS:
|
||||||
|
if self._port_is_available(port, ports_in_use):
|
||||||
|
self._reserve_port(port, lease_time_sec)
|
||||||
|
return port
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_free_random_port(
|
||||||
|
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
|
||||||
|
):
|
||||||
|
min_range = max(1, min_range)
|
||||||
|
max_range = min(65535, max_range)
|
||||||
|
ports = list(range(min_range, max_range))
|
||||||
|
shuffle(ports)
|
||||||
|
for port in ports:
|
||||||
|
if self._port_is_available(port, ports_in_use):
|
||||||
|
self._reserve_port(port, lease_time_sec)
|
||||||
|
return port
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _port_is_available(self, port: int, ports_in_use: Set[int]):
|
||||||
|
if port in ports_in_use:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if port not in self._leases:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._leases[port].is_expired():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _reserve_port(self, port: int, lease_time_sec: float):
|
||||||
|
timer = EggTimer()
|
||||||
|
timer.set(lease_time_sec)
|
||||||
|
self._leases[port] = timer
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This function is here because existing components rely on it. Refactor these components to
|
||||||
|
# accept a TCPPortSelector instance and use that instead.
|
||||||
|
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
|
||||||
|
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
|
||||||
|
|
||||||
|
|
||||||
|
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from infection_monkey.network.info import get_free_tcp_port
|
from infection_monkey.network.info import TCPPortSelector
|
||||||
from infection_monkey.network.ports import COMMON_PORTS
|
from infection_monkey.network.ports import COMMON_PORTS
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,28 +13,48 @@ class Connection:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("port", COMMON_PORTS)
|
@pytest.mark.parametrize("port", COMMON_PORTS)
|
||||||
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
|
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
assert get_free_tcp_port() is port
|
assert tcp_port_selector.get_free_tcp_port() is port
|
||||||
|
|
||||||
|
|
||||||
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
|
|
||||||
assert get_free_tcp_port() is not None
|
assert tcp_port_selector.get_free_tcp_port() is not None
|
||||||
|
|
||||||
|
|
||||||
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
|
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
unavailable_ports = [Connection(("", p)) for p in range(65535)]
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
)
|
)
|
||||||
|
|
||||||
assert get_free_tcp_port() is None
|
assert tcp_port_selector.get_free_tcp_port() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("common_port", COMMON_PORTS)
|
||||||
|
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
|
||||||
|
tcp_port_selector = TCPPortSelector()
|
||||||
|
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
|
||||||
|
)
|
||||||
|
|
||||||
|
free_port_1 = tcp_port_selector.get_free_tcp_port()
|
||||||
|
free_port_2 = tcp_port_selector.get_free_tcp_port()
|
||||||
|
|
||||||
|
assert free_port_1 == common_port
|
||||||
|
assert free_port_2 != common_port
|
||||||
|
assert free_port_2 is not None
|
||||||
|
assert free_port_2 not in COMMON_PORTS
|
||||||
|
|
Loading…
Reference in New Issue