Agent: Add disconnect protocol to TCPRelay

This commit is contained in:
Kekoa Kaaikala 2022-08-31 12:06:13 +00:00
parent cd0b3077cf
commit 4b5d93beb0
3 changed files with 36 additions and 5 deletions

View File

@ -32,6 +32,7 @@ class TCPRelay(Thread):
dest_port=self._target_port,
client_connected=self.on_user_connected,
client_disconnected=self.on_user_disconnected,
client_data_received=self.on_user_data_received,
)
proxy.start()
@ -47,12 +48,12 @@ class TCPRelay(Thread):
def on_user_connected(self, user: str):
"""Handle new user connection."""
with self._lock:
self._potential_users = [u for u in self._potential_users if u.address != user]
self._relay_users.append(RelayUser(user))
def on_user_disconnected(self, user: str):
"""Handle user disconnection."""
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user]
pass
def relay_users(self) -> List[RelayUser]:
"""Get the list of users connected to the relay."""
@ -63,3 +64,13 @@ class TCPRelay(Thread):
"""Notify TCPRelay that a new user may try and connect."""
with self._lock:
self._potential_users.append(RelayUser(user))
def on_user_data_received(self, data: bytes, user: str) -> bool:
if data.startswith(b"-"):
self._disconnect_user(user)
return False
return True
def _disconnect_user(self, user: str):
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user]

View File

@ -17,6 +17,10 @@ SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread):
def __init__(
self,
@ -24,6 +28,7 @@ class SocketsPipe(Thread):
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
Thread.__init__(self)
self.source = source
@ -33,6 +38,7 @@ class SocketsPipe(Thread):
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
sockets = [self.source, self.dest]
@ -47,7 +53,7 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE)
except Exception:
break
if data:
if data and self._client_data_received(data):
try:
other.sendall(data)
update_last_serve_time()
@ -70,10 +76,12 @@ class TcpProxy(TransportProxyBase):
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
@ -99,7 +107,8 @@ class TcpProxy(TransportProxyBase):
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
pipe = SocketsPipe(source, dest, on_disconnect)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",

View File

@ -29,11 +29,22 @@ def test_user_added():
assert users[0].address == new_user
def test_user_removed():
def test_user_not_removed_on_disconnect():
# A user should only be disconnected when they send a disconnect request
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_disconnected(new_user)
users = relay.relay_users()
assert len(users) == 1
def test_user_removed_on_request():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_data_received(b"-", "0.0.0.1")
users = relay.relay_users()
assert len(users) == 0