Agent: Refactor TCPRelay

Integrate TCPConnectionHandler and RelayUserHandler into TCPRelay
Remove TCPProxy
This commit is contained in:
Kekoa Kaaikala 2022-09-01 15:06:47 +00:00
parent d6931a6414
commit 0e869462b5
3 changed files with 41 additions and 134 deletions

View File

@ -0,0 +1,3 @@
from .relay_user_handler import RelayUser, RelayUserHandler
from .tcp_connection_handler import TCPConnectionHandler
from .tcp import SocketsPipe

View File

@ -1,15 +1,9 @@
import select import select
import socket
from functools import partial
from logging import getLogger from logging import getLogger
from threading import Thread from threading import Thread
from typing import Callable from typing import Callable
from infection_monkey.transport.base import ( from infection_monkey.transport.base import update_last_serve_time
PROXY_TIMEOUT,
TransportProxyBase,
update_last_serve_time,
)
READ_BUFFER_SIZE = 8192 READ_BUFFER_SIZE = 8192
SOCKET_READ_TIMEOUT = 10 SOCKET_READ_TIMEOUT = 10
@ -65,63 +59,3 @@ class SocketsPipe(Thread):
self.dest.close() self.dest.close()
if self._client_disconnected: if self._client_disconnected:
self._client_disconnected() self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.local_host, self.local_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._stopped:
try:
source, address = l_socket.accept()
except socket.timeout:
continue
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self.dest_host, self.dest_port))
except socket.error:
source.close()
dest.close()
continue
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
address[0],
address[1],
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()
for pipe in pipes:
pipe.join()

View File

@ -1,19 +1,17 @@
from dataclasses import dataclass import socket
from ipaddress import IPv4Address from ipaddress import IPv4Address
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from time import sleep, time from time import sleep, time
from typing import Dict from typing import List
from infection_monkey.network.relay.tcp import TcpProxy from infection_monkey.network.relay import (
RelayUser,
RelayUserHandler,
SocketsPipe,
TCPConnectionHandler,
)
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
@dataclass
class RelayUser:
address: IPv4Address
last_update_time: float
class TCPRelay(Thread): class TCPRelay(Thread):
@ -29,91 +27,63 @@ class TCPRelay(Thread):
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
): ):
self._stopped = Event() self._stopped = Event()
self._user_handler = RelayUserHandler()
self._connection_handler = TCPConnectionHandler(
local_port, client_connected=self._user_connected
)
self._local_port = local_port self._local_port = local_port
self._target_addr = target_addr self._target_addr = target_addr
self._target_port = target_port self._target_port = target_port
self._new_client_timeout = new_client_timeout self._new_client_timeout = new_client_timeout
super().__init__(name="MonkeyTcpRelayThread") super().__init__(name="MonkeyTcpRelayThread")
self.daemon = True self.daemon = True
self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {}
self._lock = Lock() self._lock = Lock()
self._pipes: List[SocketsPipe] = []
def run(self): def run(self):
proxy = TcpProxy( self._connection_handler.start()
local_port=self._local_port,
dest_host=self._target_addr,
dest_port=self._target_port,
client_connected=self.add_relay_user,
client_data_received=self.on_user_data_received,
)
proxy.start()
self._stopped.wait() self._stopped.wait()
self._wait_for_users_to_disconnect() self._wait_for_users_to_disconnect()
proxy.stop() self._connection_handler.stop()
proxy.join() self._connection_handler.join()
[pipe.join() for pipe in self._pipes]
def stop(self): def stop(self):
self._stopped.set() self._stopped.set()
def add_relay_user(self, user_address: IPv4Address): def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
""" self._user_handler.add_relay_user(user_addr)
Handle new user connection. self._spawn_pipe(source)
:param user: A user which will be added to the relay def _spawn_pipe(self, source: socket.socket):
""" dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with self._lock: try:
if user_address in self._potential_users: dest.connect((self._target_addr, self._target_port))
del self._potential_users[user_address] except socket.error:
source.close()
dest.close()
self._relay_users[user_address] = RelayUser(user_address, time()) pipe = SocketsPipe(
source, dest, client_data_received=self._user_handler.on_user_data_received
def relay_users(self) -> Dict[IPv4Address, RelayUser]: )
""" self._pipes.append(pipe)
Get the list of users connected to the relay. pipe.run()
"""
with self._lock:
return self._relay_users.copy()
def add_potential_user(self, user_address: IPv4Address):
"""
Notify TCPRelay that a new user may try and connect.
:param user: A potential user that tries to connect to the relay
"""
with self._lock:
self._potential_users[user_address] = RelayUser(user_address, time())
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
Disconnect a user which a specific starting data.
:param data: The data that a relay recieved
:param user: User which send the data
"""
if data.startswith(RELAY_CONTROL_MESSAGE):
self._disconnect_user(user_address)
return False
return True
def _disconnect_user(self, user_address: IPv4Address):
with self._lock:
if user_address in self._relay_users:
del self._relay_users[user_address]
def _wait_for_users_to_disconnect(self): def _wait_for_users_to_disconnect(self):
stop = False stop = False
while not stop: while not stop:
sleep(0.01) sleep(0.01)
current_time = time() current_time = time()
potential_users = self._user_handler.get_potential_users()
most_recent_potential_time = max( most_recent_potential_time = max(
self._potential_users.values(), potential_users.values(),
key=lambda ru: ru.last_update_time, key=lambda ru: ru.last_update_time,
default=RelayUser(IPv4Address(""), 0.0), default=RelayUser(IPv4Address(""), 0.0),
).last_update_time ).last_update_time
potential_elapsed = current_time - most_recent_potential_time potential_elapsed = current_time - most_recent_potential_time
stop = not self._potential_users or potential_elapsed > self._new_client_timeout stop = not potential_users or potential_elapsed > self._new_client_timeout