diff --git a/monkey/infection_monkey/network/relay/relay_user_handler.py b/monkey/infection_monkey/network/relay/relay_user_handler.py index e9ee00d04..5755d2794 100644 --- a/monkey/infection_monkey/network/relay/relay_user_handler.py +++ b/monkey/infection_monkey/network/relay/relay_user_handler.py @@ -10,6 +10,7 @@ from common.utils.code_utils import del_key # Wait for potential new clients to connect DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT +DEFAULT_DISCONNECT_TIMEOUT = 60 * 10 # Wait up to 10 minutes for clients to disconnect @dataclass @@ -21,8 +22,13 @@ class RelayUser: class RelayUserHandler: """Manages membership to a network relay.""" - def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT): + def __init__( + self, + new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT, + client_disconnect_timeout: float = DEFAULT_DISCONNECT_TIMEOUT, + ): self._new_client_timeout = new_client_timeout + self._client_disconnect_timeout = client_disconnect_timeout self._relay_users: Dict[IPv4Address, RelayUser] = {} self._potential_users: Dict[IPv4Address, RelayUser] = {} @@ -41,6 +47,7 @@ class RelayUserHandler: del_key(self._potential_users, user_address) timer = EggTimer() + timer.set(self._client_disconnect_timeout) self._relay_users[user_address] = RelayUser(user_address, timer) def add_potential_user(self, user_address: IPv4Address): @@ -69,8 +76,22 @@ class RelayUserHandler: """ Return whether or not we have any potential users. """ - self._potential_users = dict( - filter(lambda ru: not ru[1].timer.is_expired(), self._potential_users.items()) - ) + with self._lock: + self._potential_users = RelayUserHandler._remove_expired_users(self._potential_users) - return len(self._potential_users) > 0 + return len(self._potential_users) > 0 + + def has_connected_users(self) -> bool: + """ + Return whether or not we have any relay users. + """ + with self._lock: + self._relay_users = RelayUserHandler._remove_expired_users(self._relay_users) + + return len(self._relay_users) > 0 + + @staticmethod + def _remove_expired_users( + user_list: Dict[IPv4Address, RelayUser] + ) -> Dict[IPv4Address, RelayUser]: + return dict(filter(lambda ru: not ru[1].timer.is_expired(), user_list.items())) diff --git a/monkey/infection_monkey/network/relay/tcp_relay.py b/monkey/infection_monkey/network/relay/tcp_relay.py index a12d8a0d9..f605cdc9d 100644 --- a/monkey/infection_monkey/network/relay/tcp_relay.py +++ b/monkey/infection_monkey/network/relay/tcp_relay.py @@ -36,7 +36,7 @@ class TCPRelay(Thread, InterruptableThreadMixin): """ Blocks until the users disconnect or the timeout has elapsed. """ - while self._user_handler.has_potential_users(): + while self._user_handler.has_potential_users() or self._user_handler.has_connected_users(): sleep(0.5) def _wait_for_pipes_to_close(self): diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py index ca0eb4103..6f3ecb8fa 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py @@ -33,3 +33,18 @@ def test_potential_users_time_out(): sleep(0.003) assert not handler.has_potential_users() + + +def test_relay_users_added(handler): + assert not handler.has_connected_users() + handler.add_relay_user(USER_ADDRESS) + assert handler.has_connected_users() + + +def test_relay_users_time_out(): + handler = RelayUserHandler(client_disconnect_timeout=0.001) + + handler.add_relay_user(USER_ADDRESS) + sleep(0.003) + + assert not handler.has_connected_users()