Merge pull request #2249 from guardicore/2216-wait-for-relay-users-to-disconnect

Agent: Wait for relay users to disconnect
This commit is contained in:
Kekoa Kaaikala 2022-09-07 08:12:52 -04:00 committed by Mike Salvatore
commit 0b27e12b0f
3 changed files with 42 additions and 6 deletions

View File

@ -10,6 +10,7 @@ from common.utils.code_utils import del_key
# Wait for potential new clients to connect
DEFAULT_NEW_CLIENT_TIMEOUT = 2.5 * MEDIUM_REQUEST_TIMEOUT
DEFAULT_DISCONNECT_TIMEOUT = 60 * 10 # Wait up to 10 minutes for clients to disconnect
@dataclass
@ -21,8 +22,13 @@ class RelayUser:
class RelayUserHandler:
"""Manages membership to a network relay."""
def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT):
def __init__(
self,
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
client_disconnect_timeout: float = DEFAULT_DISCONNECT_TIMEOUT,
):
self._new_client_timeout = new_client_timeout
self._client_disconnect_timeout = client_disconnect_timeout
self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {}
@ -41,6 +47,7 @@ class RelayUserHandler:
del_key(self._potential_users, user_address)
timer = EggTimer()
timer.set(self._client_disconnect_timeout)
self._relay_users[user_address] = RelayUser(user_address, timer)
def add_potential_user(self, user_address: IPv4Address):
@ -69,8 +76,22 @@ class RelayUserHandler:
"""
Return whether or not we have any potential users.
"""
self._potential_users = dict(
filter(lambda ru: not ru[1].timer.is_expired(), self._potential_users.items())
)
with self._lock:
self._potential_users = RelayUserHandler._remove_expired_users(self._potential_users)
return len(self._potential_users) > 0
return len(self._potential_users) > 0
def has_connected_users(self) -> bool:
"""
Return whether or not we have any relay users.
"""
with self._lock:
self._relay_users = RelayUserHandler._remove_expired_users(self._relay_users)
return len(self._relay_users) > 0
@staticmethod
def _remove_expired_users(
user_list: Dict[IPv4Address, RelayUser]
) -> Dict[IPv4Address, RelayUser]:
return dict(filter(lambda ru: not ru[1].timer.is_expired(), user_list.items()))

View File

@ -36,7 +36,7 @@ class TCPRelay(Thread, InterruptableThreadMixin):
"""
Blocks until the users disconnect or the timeout has elapsed.
"""
while self._user_handler.has_potential_users():
while self._user_handler.has_potential_users() or self._user_handler.has_connected_users():
sleep(0.5)
def _wait_for_pipes_to_close(self):

View File

@ -33,3 +33,18 @@ def test_potential_users_time_out():
sleep(0.003)
assert not handler.has_potential_users()
def test_relay_users_added(handler):
assert not handler.has_connected_users()
handler.add_relay_user(USER_ADDRESS)
assert handler.has_connected_users()
def test_relay_users_time_out():
handler = RelayUserHandler(client_disconnect_timeout=0.001)
handler.add_relay_user(USER_ADDRESS)
sleep(0.003)
assert not handler.has_connected_users()