Agent: Add TCPPortSelector

This commit is contained in:
Mike Salvatore 2022-09-13 12:17:45 -04:00
parent 69063de627
commit 04d79a0a35
2 changed files with 116 additions and 21 deletions

View File

@ -4,10 +4,12 @@ import struct
from dataclasses import dataclass
from ipaddress import IPv4Interface
from random import shuffle # noqa: DUO102
from typing import List
from threading import Lock
from typing import Dict, List, Set
import netifaces
import psutil
from egg_timer import EggTimer
from infection_monkey.utils.environment import is_windows_os
@ -120,20 +122,93 @@ else:
return routes
def get_free_tcp_port(min_range=1024, max_range=65535):
class TCPPortSelector:
"""
Select an available TCP port that a new server can listen on
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
Examines the system to find which ports are not in use and makes an intelligent decision
regarding what port can be used to host a server. In multithreaded applications, a race occurs
between the time when the OS reports that a port is free and when the port is actually used. In
other words, two threads which request a free port simultaneously may be handed the same port,
as the OS will report that the port is not in use. To combat this, the TCPPortSelector will
reserve a port for a period of time to give the requester ample time to start their server. Once
the requester's server is listening on the port, the OS will report the port as "LISTEN".
"""
for port in COMMON_PORTS:
if port not in in_use:
return port
def __init__(self):
self._leases: Dict[int, EggTimer] = {}
self._lock = Lock()
min_range = max(1, min_range)
max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if port not in in_use:
return port
def get_free_tcp_port(
self, min_range: int = 1024, max_range: int = 65535, lease_time_sec: float = 30
):
"""
Get a free TCP port that a new server can listen on
return None
This function will attempt to provide a well-known port that the caller can listen on. If no
well-known ports are available, a random port will be selected.
:param min_range: The smallest port number a random port can be chosen from, defaults to
1024
:param max_range: The largest port number a random port can be chosen from, defaults to
65535
:lease_time_sec: The amount of time a port should be reserved for if the OS does not report
it as in use, defaults to 30 seconds
:return: A TCP port number
"""
with self._lock:
ports_in_use = {conn.laddr[1] for conn in psutil.net_connections()}
common_port = self._get_free_common_port(ports_in_use, lease_time_sec)
if common_port is not None:
return common_port
return self._get_free_random_port(ports_in_use, min_range, max_range, lease_time_sec)
def _get_free_common_port(self, ports_in_use: Set[int], lease_time_sec):
for port in COMMON_PORTS:
if self._port_is_available(port, ports_in_use):
self._reserve_port(port, lease_time_sec)
return port
return None
def _get_free_random_port(
self, ports_in_use: Set[int], min_range: int, max_range: int, lease_time_sec: float
):
min_range = max(1, min_range)
max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if self._port_is_available(port, ports_in_use):
self._reserve_port(port, lease_time_sec)
return port
return None
def _port_is_available(self, port: int, ports_in_use: Set[int]):
if port in ports_in_use:
return False
if port not in self._leases:
return True
if self._leases[port].is_expired():
return True
return False
def _reserve_port(self, port: int, lease_time_sec: float):
timer = EggTimer()
timer.set(lease_time_sec)
self._leases[port] = timer
# TODO: This function is here because existing components rely on it. Refactor these components to
# accept a TCPPortSelector instance and use that instead.
def get_free_tcp_port(min_range=1024, max_range=65535, lease_time_sec=30):
return get_free_tcp_port.port_selector.get_free_tcp_port(min_range, max_range, lease_time_sec)
get_free_tcp_port.port_selector = TCPPortSelector() # type: ignore[attr-defined]

View File

@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.info import TCPPortSelector
from infection_monkey.network.ports import COMMON_PORTS
@ -13,28 +13,48 @@ class Connection:
@pytest.mark.parametrize("port", COMMON_PORTS)
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
def test_tcp_port_selector__checks_common_ports(port: int, monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is port
assert tcp_port_selector.get_free_tcp_port() is port
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
def test_tcp_port_selector__checks_other_ports_if_common_ports_unavailable(monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is not None
assert tcp_port_selector.get_free_tcp_port() is not None
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
def test_tcp_port_selector__none_if_no_available_ports(monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in range(65535)]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is None
assert tcp_port_selector.get_free_tcp_port() is None
@pytest.mark.parametrize("common_port", COMMON_PORTS)
def test_tcp_port_selector__checks_common_ports_leases(common_port: int, monkeypatch):
tcp_port_selector = TCPPortSelector()
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not common_port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
free_port_1 = tcp_port_selector.get_free_tcp_port()
free_port_2 = tcp_port_selector.get_free_tcp_port()
assert free_port_1 == common_port
assert free_port_2 != common_port
assert free_port_2 is not None
assert free_port_2 not in COMMON_PORTS