Agent: Refactor TCPRelay

Integrate TCPConnectionHandler and RelayUserHandler into TCPRelay
Remove TCPProxy
This commit is contained in:
Kekoa Kaaikala 2022-09-01 15:06:47 +00:00
parent d6931a6414
commit 0e869462b5
3 changed files with 41 additions and 134 deletions

View File

@ -0,0 +1,3 @@
from .relay_user_handler import RelayUser, RelayUserHandler
from .tcp_connection_handler import TCPConnectionHandler
from .tcp import SocketsPipe

View File

@ -1,15 +1,9 @@
import select
import socket
from functools import partial
from logging import getLogger
from threading import Thread
from typing import Callable
from infection_monkey.transport.base import (
PROXY_TIMEOUT,
TransportProxyBase,
update_last_serve_time,
)
from infection_monkey.transport.base import update_last_serve_time
READ_BUFFER_SIZE = 8192
SOCKET_READ_TIMEOUT = 10
@ -65,63 +59,3 @@ class SocketsPipe(Thread):
self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.local_host, self.local_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._stopped:
try:
source, address = l_socket.accept()
except socket.timeout:
continue
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self.dest_host, self.dest_port))
except socket.error:
source.close()
dest.close()
continue
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
address[0],
address[1],
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()
for pipe in pipes:
pipe.join()

View File

@ -1,19 +1,17 @@
from dataclasses import dataclass
import socket
from ipaddress import IPv4Address
from threading import Event, Lock, Thread
from time import sleep, time
from typing import Dict
from typing import List
from infection_monkey.network.relay.tcp import TcpProxy
from infection_monkey.network.relay import (
RelayUser,
RelayUserHandler,
SocketsPipe,
TCPConnectionHandler,
)
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
@dataclass
class RelayUser:
address: IPv4Address
last_update_time: float
class TCPRelay(Thread):
@ -29,91 +27,63 @@ class TCPRelay(Thread):
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
):
self._stopped = Event()
self._user_handler = RelayUserHandler()
self._connection_handler = TCPConnectionHandler(
local_port, client_connected=self._user_connected
)
self._local_port = local_port
self._target_addr = target_addr
self._target_port = target_port
self._new_client_timeout = new_client_timeout
super().__init__(name="MonkeyTcpRelayThread")
self.daemon = True
self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {}
self._lock = Lock()
self._pipes: List[SocketsPipe] = []
def run(self):
proxy = TcpProxy(
local_port=self._local_port,
dest_host=self._target_addr,
dest_port=self._target_port,
client_connected=self.add_relay_user,
client_data_received=self.on_user_data_received,
)
proxy.start()
self._connection_handler.start()
self._stopped.wait()
self._wait_for_users_to_disconnect()
proxy.stop()
proxy.join()
self._connection_handler.stop()
self._connection_handler.join()
[pipe.join() for pipe in self._pipes]
def stop(self):
self._stopped.set()
def add_relay_user(self, user_address: IPv4Address):
"""
Handle new user connection.
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
self._user_handler.add_relay_user(user_addr)
self._spawn_pipe(source)
:param user: A user which will be added to the relay
"""
with self._lock:
if user_address in self._potential_users:
del self._potential_users[user_address]
def _spawn_pipe(self, source: socket.socket):
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self._target_addr, self._target_port))
except socket.error:
source.close()
dest.close()
self._relay_users[user_address] = RelayUser(user_address, time())
def relay_users(self) -> Dict[IPv4Address, RelayUser]:
"""
Get the list of users connected to the relay.
"""
with self._lock:
return self._relay_users.copy()
def add_potential_user(self, user_address: IPv4Address):
"""
Notify TCPRelay that a new user may try and connect.
:param user: A potential user that tries to connect to the relay
"""
with self._lock:
self._potential_users[user_address] = RelayUser(user_address, time())
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
Disconnect a user which a specific starting data.
:param data: The data that a relay recieved
:param user: User which send the data
"""
if data.startswith(RELAY_CONTROL_MESSAGE):
self._disconnect_user(user_address)
return False
return True
def _disconnect_user(self, user_address: IPv4Address):
with self._lock:
if user_address in self._relay_users:
del self._relay_users[user_address]
pipe = SocketsPipe(
source, dest, client_data_received=self._user_handler.on_user_data_received
)
self._pipes.append(pipe)
pipe.run()
def _wait_for_users_to_disconnect(self):
stop = False
while not stop:
sleep(0.01)
current_time = time()
potential_users = self._user_handler.get_potential_users()
most_recent_potential_time = max(
self._potential_users.values(),
potential_users.values(),
key=lambda ru: ru.last_update_time,
default=RelayUser(IPv4Address(""), 0.0),
).last_update_time
potential_elapsed = current_time - most_recent_potential_time
stop = not self._potential_users or potential_elapsed > self._new_client_timeout
stop = not potential_users or potential_elapsed > self._new_client_timeout