diff --git a/monkey/infection_monkey/network/relay/relay_user_handler.py b/monkey/infection_monkey/network/relay/relay_user_handler.py index e9ee00d04..2c83f0b13 100644 --- a/monkey/infection_monkey/network/relay/relay_user_handler.py +++ b/monkey/infection_monkey/network/relay/relay_user_handler.py @@ -10,6 +10,7 @@ from common.utils.code_utils import del_key # Wait for potential new clients to connect DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT +DEFAULT_DISCONNECT_TIMEOUT = 60 * 10 # Wait up to 10 minutes for clients to disconnect @dataclass @@ -21,8 +22,13 @@ class RelayUser: class RelayUserHandler: """Manages membership to a network relay.""" - def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT): + def __init__( + self, + new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT, + client_disconnect_timeout: float = DEFAULT_DISCONNECT_TIMEOUT, + ): self._new_client_timeout = new_client_timeout + self._client_disconnect_timeout = client_disconnect_timeout self._relay_users: Dict[IPv4Address, RelayUser] = {} self._potential_users: Dict[IPv4Address, RelayUser] = {} @@ -41,6 +47,7 @@ class RelayUserHandler: del_key(self._potential_users, user_address) timer = EggTimer() + timer.set(self._client_disconnect_timeout) self._relay_users[user_address] = RelayUser(user_address, timer) def add_potential_user(self, user_address: IPv4Address): @@ -74,3 +81,13 @@ class RelayUserHandler: ) return len(self._potential_users) > 0 + + def has_connected_users(self) -> bool: + """ + Return whether or not we have any relay users. + """ + self._relay_users = dict( + filter(lambda ru: not ru[1].timer.is_expired(), self._relay_users.items()) + ) + + return len(self._relay_users) > 0 diff --git a/monkey/infection_monkey/network/relay/tcp_relay.py b/monkey/infection_monkey/network/relay/tcp_relay.py index a12d8a0d9..f605cdc9d 100644 --- a/monkey/infection_monkey/network/relay/tcp_relay.py +++ b/monkey/infection_monkey/network/relay/tcp_relay.py @@ -36,7 +36,7 @@ class TCPRelay(Thread, InterruptableThreadMixin): """ Blocks until the users disconnect or the timeout has elapsed. """ - while self._user_handler.has_potential_users(): + while self._user_handler.has_potential_users() or self._user_handler.has_connected_users(): sleep(0.5) def _wait_for_pipes_to_close(self): diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py index ca0eb4103..6f3ecb8fa 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py @@ -33,3 +33,18 @@ def test_potential_users_time_out(): sleep(0.003) assert not handler.has_potential_users() + + +def test_relay_users_added(handler): + assert not handler.has_connected_users() + handler.add_relay_user(USER_ADDRESS) + assert handler.has_connected_users() + + +def test_relay_users_time_out(): + handler = RelayUserHandler(client_disconnect_timeout=0.001) + + handler.add_relay_user(USER_ADDRESS) + sleep(0.003) + + assert not handler.has_connected_users()