Agent: Refactor TCPRelay
Integrate TCPConnectionHandler and RelayUserHandler into TCPRelay Remove TCPProxy
This commit is contained in:
parent
d6931a6414
commit
0e869462b5
|
@ -0,0 +1,3 @@
|
||||||
|
from .relay_user_handler import RelayUser, RelayUserHandler
|
||||||
|
from .tcp_connection_handler import TCPConnectionHandler
|
||||||
|
from .tcp import SocketsPipe
|
|
@ -1,15 +1,9 @@
|
||||||
import select
|
import select
|
||||||
import socket
|
|
||||||
from functools import partial
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from infection_monkey.transport.base import (
|
from infection_monkey.transport.base import update_last_serve_time
|
||||||
PROXY_TIMEOUT,
|
|
||||||
TransportProxyBase,
|
|
||||||
update_last_serve_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
READ_BUFFER_SIZE = 8192
|
READ_BUFFER_SIZE = 8192
|
||||||
SOCKET_READ_TIMEOUT = 10
|
SOCKET_READ_TIMEOUT = 10
|
||||||
|
@ -65,63 +59,3 @@ class SocketsPipe(Thread):
|
||||||
self.dest.close()
|
self.dest.close()
|
||||||
if self._client_disconnected:
|
if self._client_disconnected:
|
||||||
self._client_disconnected()
|
self._client_disconnected()
|
||||||
|
|
||||||
|
|
||||||
class TcpProxy(TransportProxyBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
local_port,
|
|
||||||
dest_host=None,
|
|
||||||
dest_port=None,
|
|
||||||
local_host="",
|
|
||||||
client_connected: Callable[[str], None] = None,
|
|
||||||
client_disconnected: Callable[[str], None] = None,
|
|
||||||
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
|
|
||||||
):
|
|
||||||
super().__init__(local_port, dest_host, dest_port, local_host)
|
|
||||||
self._client_connected = client_connected
|
|
||||||
# TODO: Rethink client_disconnected
|
|
||||||
self._client_disconnected = client_disconnected
|
|
||||||
self._client_data_received = client_data_received
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
pipes = []
|
|
||||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
l_socket.bind((self.local_host, self.local_port))
|
|
||||||
l_socket.settimeout(PROXY_TIMEOUT)
|
|
||||||
l_socket.listen(5)
|
|
||||||
|
|
||||||
while not self._stopped:
|
|
||||||
try:
|
|
||||||
source, address = l_socket.accept()
|
|
||||||
except socket.timeout:
|
|
||||||
continue
|
|
||||||
|
|
||||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
try:
|
|
||||||
dest.connect((self.dest_host, self.dest_port))
|
|
||||||
except socket.error:
|
|
||||||
source.close()
|
|
||||||
dest.close()
|
|
||||||
continue
|
|
||||||
|
|
||||||
on_disconnect = (
|
|
||||||
partial(self._client_connected, address[0]) if self._client_connected else None
|
|
||||||
)
|
|
||||||
on_data_received = partial(self._client_data_received, client=address[0])
|
|
||||||
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
|
|
||||||
pipes.append(pipe)
|
|
||||||
logger.debug(
|
|
||||||
"piping sockets %s:%s->%s:%s",
|
|
||||||
address[0],
|
|
||||||
address[1],
|
|
||||||
self.dest_host,
|
|
||||||
self.dest_port,
|
|
||||||
)
|
|
||||||
if self._client_connected:
|
|
||||||
self._client_connected(address[0])
|
|
||||||
pipe.start()
|
|
||||||
|
|
||||||
l_socket.close()
|
|
||||||
for pipe in pipes:
|
|
||||||
pipe.join()
|
|
||||||
|
|
|
@ -1,19 +1,17 @@
|
||||||
from dataclasses import dataclass
|
import socket
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
from typing import Dict
|
from typing import List
|
||||||
|
|
||||||
from infection_monkey.network.relay.tcp import TcpProxy
|
from infection_monkey.network.relay import (
|
||||||
|
RelayUser,
|
||||||
|
RelayUserHandler,
|
||||||
|
SocketsPipe,
|
||||||
|
TCPConnectionHandler,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
||||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RelayUser:
|
|
||||||
address: IPv4Address
|
|
||||||
last_update_time: float
|
|
||||||
|
|
||||||
|
|
||||||
class TCPRelay(Thread):
|
class TCPRelay(Thread):
|
||||||
|
@ -29,91 +27,63 @@ class TCPRelay(Thread):
|
||||||
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
|
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
|
||||||
):
|
):
|
||||||
self._stopped = Event()
|
self._stopped = Event()
|
||||||
|
|
||||||
|
self._user_handler = RelayUserHandler()
|
||||||
|
self._connection_handler = TCPConnectionHandler(
|
||||||
|
local_port, client_connected=self._user_connected
|
||||||
|
)
|
||||||
self._local_port = local_port
|
self._local_port = local_port
|
||||||
self._target_addr = target_addr
|
self._target_addr = target_addr
|
||||||
self._target_port = target_port
|
self._target_port = target_port
|
||||||
self._new_client_timeout = new_client_timeout
|
self._new_client_timeout = new_client_timeout
|
||||||
super().__init__(name="MonkeyTcpRelayThread")
|
super().__init__(name="MonkeyTcpRelayThread")
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
|
||||||
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
self._pipes: List[SocketsPipe] = []
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
proxy = TcpProxy(
|
self._connection_handler.start()
|
||||||
local_port=self._local_port,
|
|
||||||
dest_host=self._target_addr,
|
|
||||||
dest_port=self._target_port,
|
|
||||||
client_connected=self.add_relay_user,
|
|
||||||
client_data_received=self.on_user_data_received,
|
|
||||||
)
|
|
||||||
proxy.start()
|
|
||||||
|
|
||||||
self._stopped.wait()
|
self._stopped.wait()
|
||||||
|
|
||||||
self._wait_for_users_to_disconnect()
|
self._wait_for_users_to_disconnect()
|
||||||
|
|
||||||
proxy.stop()
|
self._connection_handler.stop()
|
||||||
proxy.join()
|
self._connection_handler.join()
|
||||||
|
|
||||||
|
[pipe.join() for pipe in self._pipes]
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._stopped.set()
|
self._stopped.set()
|
||||||
|
|
||||||
def add_relay_user(self, user_address: IPv4Address):
|
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
|
||||||
"""
|
self._user_handler.add_relay_user(user_addr)
|
||||||
Handle new user connection.
|
self._spawn_pipe(source)
|
||||||
|
|
||||||
:param user: A user which will be added to the relay
|
def _spawn_pipe(self, source: socket.socket):
|
||||||
"""
|
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
with self._lock:
|
try:
|
||||||
if user_address in self._potential_users:
|
dest.connect((self._target_addr, self._target_port))
|
||||||
del self._potential_users[user_address]
|
except socket.error:
|
||||||
|
source.close()
|
||||||
|
dest.close()
|
||||||
|
|
||||||
self._relay_users[user_address] = RelayUser(user_address, time())
|
pipe = SocketsPipe(
|
||||||
|
source, dest, client_data_received=self._user_handler.on_user_data_received
|
||||||
def relay_users(self) -> Dict[IPv4Address, RelayUser]:
|
)
|
||||||
"""
|
self._pipes.append(pipe)
|
||||||
Get the list of users connected to the relay.
|
pipe.run()
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
return self._relay_users.copy()
|
|
||||||
|
|
||||||
def add_potential_user(self, user_address: IPv4Address):
|
|
||||||
"""
|
|
||||||
Notify TCPRelay that a new user may try and connect.
|
|
||||||
|
|
||||||
:param user: A potential user that tries to connect to the relay
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
self._potential_users[user_address] = RelayUser(user_address, time())
|
|
||||||
|
|
||||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
|
||||||
"""
|
|
||||||
Disconnect a user which a specific starting data.
|
|
||||||
|
|
||||||
:param data: The data that a relay recieved
|
|
||||||
:param user: User which send the data
|
|
||||||
"""
|
|
||||||
if data.startswith(RELAY_CONTROL_MESSAGE):
|
|
||||||
self._disconnect_user(user_address)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _disconnect_user(self, user_address: IPv4Address):
|
|
||||||
with self._lock:
|
|
||||||
if user_address in self._relay_users:
|
|
||||||
del self._relay_users[user_address]
|
|
||||||
|
|
||||||
def _wait_for_users_to_disconnect(self):
|
def _wait_for_users_to_disconnect(self):
|
||||||
stop = False
|
stop = False
|
||||||
while not stop:
|
while not stop:
|
||||||
sleep(0.01)
|
sleep(0.01)
|
||||||
current_time = time()
|
current_time = time()
|
||||||
|
potential_users = self._user_handler.get_potential_users()
|
||||||
most_recent_potential_time = max(
|
most_recent_potential_time = max(
|
||||||
self._potential_users.values(),
|
potential_users.values(),
|
||||||
key=lambda ru: ru.last_update_time,
|
key=lambda ru: ru.last_update_time,
|
||||||
default=RelayUser(IPv4Address(""), 0.0),
|
default=RelayUser(IPv4Address(""), 0.0),
|
||||||
).last_update_time
|
).last_update_time
|
||||||
potential_elapsed = current_time - most_recent_potential_time
|
potential_elapsed = current_time - most_recent_potential_time
|
||||||
|
|
||||||
stop = not self._potential_users or potential_elapsed > self._new_client_timeout
|
stop = not potential_users or potential_elapsed > self._new_client_timeout
|
||||||
|
|
Loading…
Reference in New Issue