diff --git a/monkey/common/utils/code_utils.py b/monkey/common/utils/code_utils.py index db87903bc..2a2804d54 100644 --- a/monkey/common/utils/code_utils.py +++ b/monkey/common/utils/code_utils.py @@ -1,5 +1,6 @@ import queue -from typing import Any, Dict, List, MutableMapping, Type, TypeVar +from bisect import bisect_left +from typing import Any, Dict, List, MutableMapping, Sequence, Type, TypeVar T = TypeVar("T") @@ -48,3 +49,15 @@ def del_key(mapping: MutableMapping[T, Any], key: T): del mapping[key] except KeyError: pass + + +def in_sorted_sequence(item: Any, seq: Sequence[Any]) -> bool: + """ + Provides fast search in the case that the sequence is sorted. + + :param item: The item to search in the list. + :param seq: The sorted sequence in which to search the item. + :return: True if the item was found in the list, otherwise false. + """ + i = bisect_left(seq, item) + return i != len(seq) and seq[i] == item diff --git a/monkey/infection_monkey/network/info.py b/monkey/infection_monkey/network/info.py index 12748b8a0..9d3347e79 100644 --- a/monkey/infection_monkey/network/info.py +++ b/monkey/infection_monkey/network/info.py @@ -3,14 +3,17 @@ import socket import struct from dataclasses import dataclass from ipaddress import IPv4Interface -from random import randint # noqa: DUO102 +from random import shuffle # noqa: DUO102 from typing import List import netifaces import psutil +from common.utils.code_utils import in_sorted_sequence from infection_monkey.utils.environment import is_windows_os +from .ports import COMMON_PORTS + # Timeout for monkey connections LOOPBACK_NAME = b"lo" SIOCGIFADDR = 0x8915 # get PA address @@ -119,15 +122,19 @@ else: def get_free_tcp_port(min_range=1024, max_range=65535): + + in_use = sorted([conn.laddr[1] for conn in psutil.net_connections()]) + + for port in COMMON_PORTS: + if not in_sorted_sequence(port, in_use): + return port + min_range = max(1, min_range) max_range = min(65535, max_range) - - in_use = [conn.laddr[1] for conn in psutil.net_connections()] - - for i in range(min_range, max_range): - port = randint(min_range, max_range) - - if port not in in_use: + ports = list(range(min_range, max_range)) + shuffle(ports) + for port in ports: + if not in_sorted_sequence(port, in_use): return port return None diff --git a/monkey/infection_monkey/network/ports.py b/monkey/infection_monkey/network/ports.py new file mode 100644 index 000000000..e1b5e4e22 --- /dev/null +++ b/monkey/infection_monkey/network/ports.py @@ -0,0 +1,15 @@ +from typing import List + +COMMON_PORTS: List[int] = [ + 1025, # NFS, IIS + 1433, # Microsoft SQL Server + 1434, # Microsoft SQL Monitor + 1720, # h323q931 + 1723, # Microsoft PPTP VPN + 3306, # mysql + 3389, # Windows Terminal Server (RDP) + 5900, # vnc + 6001, # X11:1 + 8080, # http-proxy + 8888, # sun-answerbook +] diff --git a/monkey/tests/unit_tests/common/utils/test_code_utils.py b/monkey/tests/unit_tests/common/utils/test_code_utils.py index e5980723d..676b8cb88 100644 --- a/monkey/tests/unit_tests/common/utils/test_code_utils.py +++ b/monkey/tests/unit_tests/common/utils/test_code_utils.py @@ -1,6 +1,6 @@ from queue import Queue -from common.utils.code_utils import del_key, queue_to_list +from common.utils.code_utils import del_key, in_sorted_sequence, queue_to_list def test_empty_queue_to_empty_list(): @@ -40,3 +40,11 @@ def test_del_key__nonexistant_key(): # This test passes if the following call does not raise an error del_key(my_dict, key_to_delete) + + +def test_in_sorted_sequence__finds_item(): + assert in_sorted_sequence(99, range(100)) + + +def test_in_sorted_sequence__does_not_find_nonexistent_item(): + assert not in_sorted_sequence(101, range(100)) diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_info.py b/monkey/tests/unit_tests/infection_monkey/network/test_info.py new file mode 100644 index 000000000..8dab89e4b --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/test_info.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import Tuple + +import pytest + +from infection_monkey.network.info import get_free_tcp_port +from infection_monkey.network.ports import COMMON_PORTS + + +@dataclass +class Connection: + laddr: Tuple[str, int] + + +@pytest.mark.parametrize("port", COMMON_PORTS) +def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch): + unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port] + + monkeypatch.setattr( + "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports + ) + assert get_free_tcp_port() is port + + +def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch): + unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS] + monkeypatch.setattr( + "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports + ) + + assert get_free_tcp_port() is not None + + +def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch): + unavailable_ports = [Connection(("", p)) for p in range(65535)] + monkeypatch.setattr( + "infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports + ) + + assert get_free_tcp_port() is None