Agent: Refactor TCPRelay
Integrate TCPConnectionHandler and RelayUserHandler into TCPRelay Remove TCPProxy
This commit is contained in:
parent
d6931a6414
commit
0e869462b5
|
@ -0,0 +1,3 @@
|
|||
from .relay_user_handler import RelayUser, RelayUserHandler
|
||||
from .tcp_connection_handler import TCPConnectionHandler
|
||||
from .tcp import SocketsPipe
|
|
@ -1,15 +1,9 @@
|
|||
import select
|
||||
import socket
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from infection_monkey.transport.base import (
|
||||
PROXY_TIMEOUT,
|
||||
TransportProxyBase,
|
||||
update_last_serve_time,
|
||||
)
|
||||
from infection_monkey.transport.base import update_last_serve_time
|
||||
|
||||
READ_BUFFER_SIZE = 8192
|
||||
SOCKET_READ_TIMEOUT = 10
|
||||
|
@ -65,63 +59,3 @@ class SocketsPipe(Thread):
|
|||
self.dest.close()
|
||||
if self._client_disconnected:
|
||||
self._client_disconnected()
|
||||
|
||||
|
||||
class TcpProxy(TransportProxyBase):
|
||||
def __init__(
|
||||
self,
|
||||
local_port,
|
||||
dest_host=None,
|
||||
dest_port=None,
|
||||
local_host="",
|
||||
client_connected: Callable[[str], None] = None,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
|
||||
):
|
||||
super().__init__(local_port, dest_host, dest_port, local_host)
|
||||
self._client_connected = client_connected
|
||||
# TODO: Rethink client_disconnected
|
||||
self._client_disconnected = client_disconnected
|
||||
self._client_data_received = client_data_received
|
||||
|
||||
def run(self):
|
||||
pipes = []
|
||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
l_socket.bind((self.local_host, self.local_port))
|
||||
l_socket.settimeout(PROXY_TIMEOUT)
|
||||
l_socket.listen(5)
|
||||
|
||||
while not self._stopped:
|
||||
try:
|
||||
source, address = l_socket.accept()
|
||||
except socket.timeout:
|
||||
continue
|
||||
|
||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
dest.connect((self.dest_host, self.dest_port))
|
||||
except socket.error:
|
||||
source.close()
|
||||
dest.close()
|
||||
continue
|
||||
|
||||
on_disconnect = (
|
||||
partial(self._client_connected, address[0]) if self._client_connected else None
|
||||
)
|
||||
on_data_received = partial(self._client_data_received, client=address[0])
|
||||
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
|
||||
pipes.append(pipe)
|
||||
logger.debug(
|
||||
"piping sockets %s:%s->%s:%s",
|
||||
address[0],
|
||||
address[1],
|
||||
self.dest_host,
|
||||
self.dest_port,
|
||||
)
|
||||
if self._client_connected:
|
||||
self._client_connected(address[0])
|
||||
pipe.start()
|
||||
|
||||
l_socket.close()
|
||||
for pipe in pipes:
|
||||
pipe.join()
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
from dataclasses import dataclass
|
||||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Event, Lock, Thread
|
||||
from time import sleep, time
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from infection_monkey.network.relay.tcp import TcpProxy
|
||||
from infection_monkey.network.relay import (
|
||||
RelayUser,
|
||||
RelayUserHandler,
|
||||
SocketsPipe,
|
||||
TCPConnectionHandler,
|
||||
)
|
||||
|
||||
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayUser:
|
||||
address: IPv4Address
|
||||
last_update_time: float
|
||||
|
||||
|
||||
class TCPRelay(Thread):
|
||||
|
@ -29,91 +27,63 @@ class TCPRelay(Thread):
|
|||
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
|
||||
):
|
||||
self._stopped = Event()
|
||||
|
||||
self._user_handler = RelayUserHandler()
|
||||
self._connection_handler = TCPConnectionHandler(
|
||||
local_port, client_connected=self._user_connected
|
||||
)
|
||||
self._local_port = local_port
|
||||
self._target_addr = target_addr
|
||||
self._target_port = target_port
|
||||
self._new_client_timeout = new_client_timeout
|
||||
super().__init__(name="MonkeyTcpRelayThread")
|
||||
self.daemon = True
|
||||
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
||||
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
||||
self._lock = Lock()
|
||||
self._pipes: List[SocketsPipe] = []
|
||||
|
||||
def run(self):
|
||||
proxy = TcpProxy(
|
||||
local_port=self._local_port,
|
||||
dest_host=self._target_addr,
|
||||
dest_port=self._target_port,
|
||||
client_connected=self.add_relay_user,
|
||||
client_data_received=self.on_user_data_received,
|
||||
)
|
||||
proxy.start()
|
||||
self._connection_handler.start()
|
||||
|
||||
self._stopped.wait()
|
||||
|
||||
self._wait_for_users_to_disconnect()
|
||||
|
||||
proxy.stop()
|
||||
proxy.join()
|
||||
self._connection_handler.stop()
|
||||
self._connection_handler.join()
|
||||
|
||||
[pipe.join() for pipe in self._pipes]
|
||||
|
||||
def stop(self):
|
||||
self._stopped.set()
|
||||
|
||||
def add_relay_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
Handle new user connection.
|
||||
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
|
||||
self._user_handler.add_relay_user(user_addr)
|
||||
self._spawn_pipe(source)
|
||||
|
||||
:param user: A user which will be added to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
if user_address in self._potential_users:
|
||||
del self._potential_users[user_address]
|
||||
def _spawn_pipe(self, source: socket.socket):
|
||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
dest.connect((self._target_addr, self._target_port))
|
||||
except socket.error:
|
||||
source.close()
|
||||
dest.close()
|
||||
|
||||
self._relay_users[user_address] = RelayUser(user_address, time())
|
||||
|
||||
def relay_users(self) -> Dict[IPv4Address, RelayUser]:
|
||||
"""
|
||||
Get the list of users connected to the relay.
|
||||
"""
|
||||
with self._lock:
|
||||
return self._relay_users.copy()
|
||||
|
||||
def add_potential_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
Notify TCPRelay that a new user may try and connect.
|
||||
|
||||
:param user: A potential user that tries to connect to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
self._potential_users[user_address] = RelayUser(user_address, time())
|
||||
|
||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
||||
"""
|
||||
Disconnect a user which a specific starting data.
|
||||
|
||||
:param data: The data that a relay recieved
|
||||
:param user: User which send the data
|
||||
"""
|
||||
if data.startswith(RELAY_CONTROL_MESSAGE):
|
||||
self._disconnect_user(user_address)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _disconnect_user(self, user_address: IPv4Address):
|
||||
with self._lock:
|
||||
if user_address in self._relay_users:
|
||||
del self._relay_users[user_address]
|
||||
pipe = SocketsPipe(
|
||||
source, dest, client_data_received=self._user_handler.on_user_data_received
|
||||
)
|
||||
self._pipes.append(pipe)
|
||||
pipe.run()
|
||||
|
||||
def _wait_for_users_to_disconnect(self):
|
||||
stop = False
|
||||
while not stop:
|
||||
sleep(0.01)
|
||||
current_time = time()
|
||||
potential_users = self._user_handler.get_potential_users()
|
||||
most_recent_potential_time = max(
|
||||
self._potential_users.values(),
|
||||
potential_users.values(),
|
||||
key=lambda ru: ru.last_update_time,
|
||||
default=RelayUser(IPv4Address(""), 0.0),
|
||||
).last_update_time
|
||||
potential_elapsed = current_time - most_recent_potential_time
|
||||
|
||||
stop = not self._potential_users or potential_elapsed > self._new_client_timeout
|
||||
stop = not potential_users or potential_elapsed > self._new_client_timeout
|
||||
|
|
Loading…
Reference in New Issue