forked from p15670423/monkey
Agent: Add disconnect protocol to TCPRelay
This commit is contained in:
parent
cd0b3077cf
commit
4b5d93beb0
|
@ -32,6 +32,7 @@ class TCPRelay(Thread):
|
|||
dest_port=self._target_port,
|
||||
client_connected=self.on_user_connected,
|
||||
client_disconnected=self.on_user_disconnected,
|
||||
client_data_received=self.on_user_data_received,
|
||||
)
|
||||
proxy.start()
|
||||
|
||||
|
@ -47,12 +48,12 @@ class TCPRelay(Thread):
|
|||
def on_user_connected(self, user: str):
|
||||
"""Handle new user connection."""
|
||||
with self._lock:
|
||||
self._potential_users = [u for u in self._potential_users if u.address != user]
|
||||
self._relay_users.append(RelayUser(user))
|
||||
|
||||
def on_user_disconnected(self, user: str):
|
||||
"""Handle user disconnection."""
|
||||
with self._lock:
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user]
|
||||
pass
|
||||
|
||||
def relay_users(self) -> List[RelayUser]:
|
||||
"""Get the list of users connected to the relay."""
|
||||
|
@ -63,3 +64,13 @@ class TCPRelay(Thread):
|
|||
"""Notify TCPRelay that a new user may try and connect."""
|
||||
with self._lock:
|
||||
self._potential_users.append(RelayUser(user))
|
||||
|
||||
def on_user_data_received(self, data: bytes, user: str) -> bool:
|
||||
if data.startswith(b"-"):
|
||||
self._disconnect_user(user)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _disconnect_user(self, user: str):
|
||||
with self._lock:
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user]
|
||||
|
|
|
@ -17,6 +17,10 @@ SOCKET_READ_TIMEOUT = 10
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def _default_client_data_received(_: bytes, client=None) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SocketsPipe(Thread):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -24,6 +28,7 @@ class SocketsPipe(Thread):
|
|||
dest,
|
||||
timeout=SOCKET_READ_TIMEOUT,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
|
||||
):
|
||||
Thread.__init__(self)
|
||||
self.source = source
|
||||
|
@ -33,6 +38,7 @@ class SocketsPipe(Thread):
|
|||
super(SocketsPipe, self).__init__()
|
||||
self.daemon = True
|
||||
self._client_disconnected = client_disconnected
|
||||
self._client_data_received = client_data_received
|
||||
|
||||
def run(self):
|
||||
sockets = [self.source, self.dest]
|
||||
|
@ -47,7 +53,7 @@ class SocketsPipe(Thread):
|
|||
data = r.recv(READ_BUFFER_SIZE)
|
||||
except Exception:
|
||||
break
|
||||
if data:
|
||||
if data and self._client_data_received(data):
|
||||
try:
|
||||
other.sendall(data)
|
||||
update_last_serve_time()
|
||||
|
@ -70,10 +76,12 @@ class TcpProxy(TransportProxyBase):
|
|||
local_host="",
|
||||
client_connected: Callable[[str], None] = None,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
|
||||
):
|
||||
super().__init__(local_port, dest_host, dest_port, local_host)
|
||||
self._client_connected = client_connected
|
||||
self._client_disconnected = client_disconnected
|
||||
self._client_data_received = client_data_received
|
||||
|
||||
def run(self):
|
||||
pipes = []
|
||||
|
@ -99,7 +107,8 @@ class TcpProxy(TransportProxyBase):
|
|||
on_disconnect = (
|
||||
partial(self._client_connected, address[0]) if self._client_connected else None
|
||||
)
|
||||
pipe = SocketsPipe(source, dest, on_disconnect)
|
||||
on_data_received = partial(self._client_data_received, client=address[0])
|
||||
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
|
||||
pipes.append(pipe)
|
||||
logger.debug(
|
||||
"piping sockets %s:%s->%s:%s",
|
||||
|
|
|
@ -29,11 +29,22 @@ def test_user_added():
|
|||
assert users[0].address == new_user
|
||||
|
||||
|
||||
def test_user_removed():
|
||||
def test_user_not_removed_on_disconnect():
|
||||
# A user should only be disconnected when they send a disconnect request
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
relay.on_user_disconnected(new_user)
|
||||
|
||||
users = relay.relay_users()
|
||||
assert len(users) == 1
|
||||
|
||||
|
||||
def test_user_removed_on_request():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
relay.on_user_data_received(b"-", "0.0.0.1")
|
||||
|
||||
users = relay.relay_users()
|
||||
assert len(users) == 0
|
||||
|
|
Loading…
Reference in New Issue