forked from p15670423/monkey
Agent: Add disconnect protocol to TCPRelay
This commit is contained in:
parent
cd0b3077cf
commit
4b5d93beb0
|
@ -32,6 +32,7 @@ class TCPRelay(Thread):
|
||||||
dest_port=self._target_port,
|
dest_port=self._target_port,
|
||||||
client_connected=self.on_user_connected,
|
client_connected=self.on_user_connected,
|
||||||
client_disconnected=self.on_user_disconnected,
|
client_disconnected=self.on_user_disconnected,
|
||||||
|
client_data_received=self.on_user_data_received,
|
||||||
)
|
)
|
||||||
proxy.start()
|
proxy.start()
|
||||||
|
|
||||||
|
@ -47,12 +48,12 @@ class TCPRelay(Thread):
|
||||||
def on_user_connected(self, user: str):
|
def on_user_connected(self, user: str):
|
||||||
"""Handle new user connection."""
|
"""Handle new user connection."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
self._potential_users = [u for u in self._potential_users if u.address != user]
|
||||||
self._relay_users.append(RelayUser(user))
|
self._relay_users.append(RelayUser(user))
|
||||||
|
|
||||||
def on_user_disconnected(self, user: str):
|
def on_user_disconnected(self, user: str):
|
||||||
"""Handle user disconnection."""
|
"""Handle user disconnection."""
|
||||||
with self._lock:
|
pass
|
||||||
self._relay_users = [u for u in self._relay_users if u.address != user]
|
|
||||||
|
|
||||||
def relay_users(self) -> List[RelayUser]:
|
def relay_users(self) -> List[RelayUser]:
|
||||||
"""Get the list of users connected to the relay."""
|
"""Get the list of users connected to the relay."""
|
||||||
|
@ -63,3 +64,13 @@ class TCPRelay(Thread):
|
||||||
"""Notify TCPRelay that a new user may try and connect."""
|
"""Notify TCPRelay that a new user may try and connect."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._potential_users.append(RelayUser(user))
|
self._potential_users.append(RelayUser(user))
|
||||||
|
|
||||||
|
def on_user_data_received(self, data: bytes, user: str) -> bool:
|
||||||
|
if data.startswith(b"-"):
|
||||||
|
self._disconnect_user(user)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _disconnect_user(self, user: str):
|
||||||
|
with self._lock:
|
||||||
|
self._relay_users = [u for u in self._relay_users if u.address != user]
|
||||||
|
|
|
@ -17,6 +17,10 @@ SOCKET_READ_TIMEOUT = 10
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_client_data_received(_: bytes, client=None) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SocketsPipe(Thread):
|
class SocketsPipe(Thread):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -24,6 +28,7 @@ class SocketsPipe(Thread):
|
||||||
dest,
|
dest,
|
||||||
timeout=SOCKET_READ_TIMEOUT,
|
timeout=SOCKET_READ_TIMEOUT,
|
||||||
client_disconnected: Callable[[str], None] = None,
|
client_disconnected: Callable[[str], None] = None,
|
||||||
|
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
|
||||||
):
|
):
|
||||||
Thread.__init__(self)
|
Thread.__init__(self)
|
||||||
self.source = source
|
self.source = source
|
||||||
|
@ -33,6 +38,7 @@ class SocketsPipe(Thread):
|
||||||
super(SocketsPipe, self).__init__()
|
super(SocketsPipe, self).__init__()
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self._client_disconnected = client_disconnected
|
self._client_disconnected = client_disconnected
|
||||||
|
self._client_data_received = client_data_received
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
sockets = [self.source, self.dest]
|
sockets = [self.source, self.dest]
|
||||||
|
@ -47,7 +53,7 @@ class SocketsPipe(Thread):
|
||||||
data = r.recv(READ_BUFFER_SIZE)
|
data = r.recv(READ_BUFFER_SIZE)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
break
|
||||||
if data:
|
if data and self._client_data_received(data):
|
||||||
try:
|
try:
|
||||||
other.sendall(data)
|
other.sendall(data)
|
||||||
update_last_serve_time()
|
update_last_serve_time()
|
||||||
|
@ -70,10 +76,12 @@ class TcpProxy(TransportProxyBase):
|
||||||
local_host="",
|
local_host="",
|
||||||
client_connected: Callable[[str], None] = None,
|
client_connected: Callable[[str], None] = None,
|
||||||
client_disconnected: Callable[[str], None] = None,
|
client_disconnected: Callable[[str], None] = None,
|
||||||
|
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
|
||||||
):
|
):
|
||||||
super().__init__(local_port, dest_host, dest_port, local_host)
|
super().__init__(local_port, dest_host, dest_port, local_host)
|
||||||
self._client_connected = client_connected
|
self._client_connected = client_connected
|
||||||
self._client_disconnected = client_disconnected
|
self._client_disconnected = client_disconnected
|
||||||
|
self._client_data_received = client_data_received
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
pipes = []
|
pipes = []
|
||||||
|
@ -99,7 +107,8 @@ class TcpProxy(TransportProxyBase):
|
||||||
on_disconnect = (
|
on_disconnect = (
|
||||||
partial(self._client_connected, address[0]) if self._client_connected else None
|
partial(self._client_connected, address[0]) if self._client_connected else None
|
||||||
)
|
)
|
||||||
pipe = SocketsPipe(source, dest, on_disconnect)
|
on_data_received = partial(self._client_data_received, client=address[0])
|
||||||
|
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
|
||||||
pipes.append(pipe)
|
pipes.append(pipe)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"piping sockets %s:%s->%s:%s",
|
"piping sockets %s:%s->%s:%s",
|
||||||
|
|
|
@ -29,11 +29,22 @@ def test_user_added():
|
||||||
assert users[0].address == new_user
|
assert users[0].address == new_user
|
||||||
|
|
||||||
|
|
||||||
def test_user_removed():
|
def test_user_not_removed_on_disconnect():
|
||||||
|
# A user should only be disconnected when they send a disconnect request
|
||||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||||
new_user = "0.0.0.1"
|
new_user = "0.0.0.1"
|
||||||
relay.on_user_connected(new_user)
|
relay.on_user_connected(new_user)
|
||||||
relay.on_user_disconnected(new_user)
|
relay.on_user_disconnected(new_user)
|
||||||
|
|
||||||
|
users = relay.relay_users()
|
||||||
|
assert len(users) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_removed_on_request():
|
||||||
|
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||||
|
new_user = "0.0.0.1"
|
||||||
|
relay.on_user_connected(new_user)
|
||||||
|
relay.on_user_data_received(b"-", "0.0.0.1")
|
||||||
|
|
||||||
users = relay.relay_users()
|
users = relay.relay_users()
|
||||||
assert len(users) == 0
|
assert len(users) == 0
|
||||||
|
|
Loading…
Reference in New Issue