From a9edbb2874891cd0d1204ef1da9abce682e7c4b1 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 12 Sep 2022 09:55:04 -0400 Subject: [PATCH 1/3] Agent: Add ThreadSafeIterator --- monkey/infection_monkey/utils/threading.py | 20 +++++++++++++++++-- .../infection_monkey/utils/test_threading.py | 10 ++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index be28aa0b1..6d8b28253 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -1,8 +1,8 @@ import logging from functools import wraps from itertools import count -from threading import Event, Thread -from typing import Any, Callable, Iterable, Optional, Tuple +from threading import Event, Lock, Thread +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar logger = logging.getLogger(__name__) @@ -116,3 +116,19 @@ class InterruptableThreadMixin: def stop(self): """Stop a running thread.""" self._interrupted.set() + + +T = TypeVar("T") + + +class ThreadSafeIterator(Iterator[T]): + """Provides a thread-safe iterator that wraps another iterator""" + + def __init__(self, iterator: Iterator[T]): + self._lock = Lock() + self._iterator = iterator + + def __next__(self) -> T: + while True: + with self._lock: + return next(self._iterator) diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py index 96a289096..05b813b66 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_threading.py @@ -1,8 +1,10 @@ import logging +from itertools import zip_longest from threading import Event, current_thread from typing import Any from infection_monkey.utils.threading import ( + ThreadSafeIterator, create_daemon_thread, interruptible_function, interruptible_iter, @@ -127,3 +129,11 @@ def test_interruptible_decorator_returns_default_value_on_interrupt(): assert return_value == 777 assert fn.call_count == 0 + + +def test_thread_safe_iterator(): + test_list = [1, 2, 3, 4, 5] + tsi = ThreadSafeIterator(test_list.__iter__()) + + for actual, expected in zip_longest(tsi, test_list): + assert actual == expected From 67893b7825409783681e8dd2896bf849635ca537 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 12 Sep 2022 10:01:13 -0400 Subject: [PATCH 2/3] Agent: Find/check island servers concurrently --- .../infection_monkey/network/relay/utils.py | 69 ++++++++++++++----- .../network/relay/test_utils.py | 12 +++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 4a5532c30..a19f316db 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -1,42 +1,75 @@ import logging import socket +from contextlib import suppress from ipaddress import IPv4Address -from typing import Iterable, Optional +from typing import Dict, Iterable, Iterator, MutableMapping, Optional import requests from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from common.network.network_utils import address_to_ip_port from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST -from infection_monkey.utils.threading import create_daemon_thread +from infection_monkey.utils.threading import ( + ThreadSafeIterator, + create_daemon_thread, + run_worker_threads, +) logger = logging.getLogger(__name__) +# The number of Island servers to test simultaneously. 32 threads seems large enough for all +# practical purposes. Revisit this if it's not. +NUM_FIND_SERVER_WORKERS = 32 + def find_server(servers: Iterable[str]) -> Optional[str]: - for server in servers: - logger.debug(f"Trying to connect to server: {server}") + server_list = list(servers) + server_iterator = ThreadSafeIterator(server_list.__iter__()) + server_results: Dict[str, bool] = {} - try: - requests.get( # noqa: DUO123 - f"https://{server}/api?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) + run_worker_threads( + _find_island_server, + "FindIslandServer", + args=(server_iterator, server_results), + num_workers=NUM_FIND_SERVER_WORKERS, + ) + for server in server_list: + if server_results[server]: return server - except requests.exceptions.ConnectionError as err: - logger.error(f"Unable to connect to server/relay {server}: {err}") - except TimeoutError as err: - logger.error(f"Timed out while connecting to server/relay {server}: {err}") - except Exception as err: - logger.error( - f"Exception encountered when trying to connect to server/relay {server}: {err}" - ) return None +def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): + with suppress(StopIteration): + server = next(servers) + server_status[server] = _check_if_island_server(server) + + +def _check_if_island_server(server: str) -> bool: + logger.debug(f"Trying to connect to server: {server}") + + try: + requests.get( # noqa: DUO123 + f"https://{server}/api?action=is-up", + verify=False, + timeout=MEDIUM_REQUEST_TIMEOUT, + ) + + return True + except requests.exceptions.ConnectionError as err: + logger.error(f"Unable to connect to server/relay {server}: {err}") + except TimeoutError as err: + logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except Exception as err: + logger.error( + f"Exception encountered when trying to connect to server/relay {server}: {err}" + ) + + return False + + def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): for server in servers: t = create_daemon_thread( diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index 1acd08012..ac7eb1b16 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -20,7 +20,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4] ( SERVER_2, [(SERVER_1, {"exc": requests.exceptions.ConnectionError})] - + [(server, {"text": ""}) for server in servers[1:]], + + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] ), ], ) @@ -30,3 +30,13 @@ def test_find_server(expected_server, server_response_pairs): mock.get(f"https://{server}/api?action=is-up", **response) assert find_server(servers) is expected_server + + +def test_find_server__multiple_successes(): + with requests_mock.Mocker() as mock: + mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) + mock.get(f"https://{SERVER_2}/api?action=is-up", text="") + mock.get(f"https://{SERVER_3}/api?action=is-up", text="") + mock.get(f"https://{SERVER_4}/api?action=is-up", text="") + + assert find_server(servers) == SERVER_2 From a01785838d5a9ec39a103371b08c45a3b96c8154 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 12 Sep 2022 10:24:26 -0400 Subject: [PATCH 3/3] Agent: Pass keep_tunnel_open timeout as new_client_timeout --- monkey/infection_monkey/network/relay/tcp_relay.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monkey/infection_monkey/network/relay/tcp_relay.py b/monkey/infection_monkey/network/relay/tcp_relay.py index 8e62a2169..5fac87ecc 100644 --- a/monkey/infection_monkey/network/relay/tcp_relay.py +++ b/monkey/infection_monkey/network/relay/tcp_relay.py @@ -23,7 +23,10 @@ class TCPRelay(Thread, InterruptableThreadMixin): dest_port: int, client_disconnect_timeout: float, ): - self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout) + self._user_handler = RelayUserHandler( + new_client_timeout=client_disconnect_timeout, + client_disconnect_timeout=client_disconnect_timeout, + ) self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port) relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler) self._connection_handler = TCPConnectionHandler(