From 0e869462b54e759dd8a08dea7f6097cdfc435d00 Mon Sep 17 00:00:00 2001 From: Kekoa Kaaikala Date: Thu, 1 Sep 2022 15:06:47 +0000 Subject: [PATCH] Agent: Refactor TCPRelay Integrate TCPConnectionHandler and RelayUserHandler into TCPRelay Remove TCPProxy --- .../network/relay/__init__.py | 3 + monkey/infection_monkey/network/relay/tcp.py | 68 +----------- monkey/infection_monkey/tcp_relay.py | 104 +++++++----------- 3 files changed, 41 insertions(+), 134 deletions(-) create mode 100644 monkey/infection_monkey/network/relay/__init__.py diff --git a/monkey/infection_monkey/network/relay/__init__.py b/monkey/infection_monkey/network/relay/__init__.py new file mode 100644 index 000000000..a3a83d9a2 --- /dev/null +++ b/monkey/infection_monkey/network/relay/__init__.py @@ -0,0 +1,3 @@ +from .relay_user_handler import RelayUser, RelayUserHandler +from .tcp_connection_handler import TCPConnectionHandler +from .tcp import SocketsPipe diff --git a/monkey/infection_monkey/network/relay/tcp.py b/monkey/infection_monkey/network/relay/tcp.py index dc87c67d0..0ae6e630b 100644 --- a/monkey/infection_monkey/network/relay/tcp.py +++ b/monkey/infection_monkey/network/relay/tcp.py @@ -1,15 +1,9 @@ import select -import socket -from functools import partial from logging import getLogger from threading import Thread from typing import Callable -from infection_monkey.transport.base import ( - PROXY_TIMEOUT, - TransportProxyBase, - update_last_serve_time, -) +from infection_monkey.transport.base import update_last_serve_time READ_BUFFER_SIZE = 8192 SOCKET_READ_TIMEOUT = 10 @@ -65,63 +59,3 @@ class SocketsPipe(Thread): self.dest.close() if self._client_disconnected: self._client_disconnected() - - -class TcpProxy(TransportProxyBase): - def __init__( - self, - local_port, - dest_host=None, - dest_port=None, - local_host="", - client_connected: Callable[[str], None] = None, - client_disconnected: Callable[[str], None] = None, - client_data_received: Callable[[bytes, str], bool] = _default_client_data_received, - ): - super().__init__(local_port, dest_host, dest_port, local_host) - self._client_connected = client_connected - # TODO: Rethink client_disconnected - self._client_disconnected = client_disconnected - self._client_data_received = client_data_received - - def run(self): - pipes = [] - l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - l_socket.bind((self.local_host, self.local_port)) - l_socket.settimeout(PROXY_TIMEOUT) - l_socket.listen(5) - - while not self._stopped: - try: - source, address = l_socket.accept() - except socket.timeout: - continue - - dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - dest.connect((self.dest_host, self.dest_port)) - except socket.error: - source.close() - dest.close() - continue - - on_disconnect = ( - partial(self._client_connected, address[0]) if self._client_connected else None - ) - on_data_received = partial(self._client_data_received, client=address[0]) - pipe = SocketsPipe(source, dest, on_disconnect, on_data_received) - pipes.append(pipe) - logger.debug( - "piping sockets %s:%s->%s:%s", - address[0], - address[1], - self.dest_host, - self.dest_port, - ) - if self._client_connected: - self._client_connected(address[0]) - pipe.start() - - l_socket.close() - for pipe in pipes: - pipe.join() diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py index c49bdad22..586663a0a 100644 --- a/monkey/infection_monkey/tcp_relay.py +++ b/monkey/infection_monkey/tcp_relay.py @@ -1,19 +1,17 @@ -from dataclasses import dataclass +import socket from ipaddress import IPv4Address from threading import Event, Lock, Thread from time import sleep, time -from typing import Dict +from typing import List -from infection_monkey.network.relay.tcp import TcpProxy +from infection_monkey.network.relay import ( + RelayUser, + RelayUserHandler, + SocketsPipe, + TCPConnectionHandler, +) DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect -RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" - - -@dataclass -class RelayUser: - address: IPv4Address - last_update_time: float class TCPRelay(Thread): @@ -29,91 +27,63 @@ class TCPRelay(Thread): new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT, ): self._stopped = Event() + + self._user_handler = RelayUserHandler() + self._connection_handler = TCPConnectionHandler( + local_port, client_connected=self._user_connected + ) self._local_port = local_port self._target_addr = target_addr self._target_port = target_port self._new_client_timeout = new_client_timeout super().__init__(name="MonkeyTcpRelayThread") self.daemon = True - self._relay_users: Dict[IPv4Address, RelayUser] = {} - self._potential_users: Dict[IPv4Address, RelayUser] = {} self._lock = Lock() + self._pipes: List[SocketsPipe] = [] def run(self): - proxy = TcpProxy( - local_port=self._local_port, - dest_host=self._target_addr, - dest_port=self._target_port, - client_connected=self.add_relay_user, - client_data_received=self.on_user_data_received, - ) - proxy.start() + self._connection_handler.start() self._stopped.wait() - self._wait_for_users_to_disconnect() - proxy.stop() - proxy.join() + self._connection_handler.stop() + self._connection_handler.join() + + [pipe.join() for pipe in self._pipes] def stop(self): self._stopped.set() - def add_relay_user(self, user_address: IPv4Address): - """ - Handle new user connection. + def _user_connected(self, source: socket.socket, user_addr: IPv4Address): + self._user_handler.add_relay_user(user_addr) + self._spawn_pipe(source) - :param user: A user which will be added to the relay - """ - with self._lock: - if user_address in self._potential_users: - del self._potential_users[user_address] + def _spawn_pipe(self, source: socket.socket): + dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + dest.connect((self._target_addr, self._target_port)) + except socket.error: + source.close() + dest.close() - self._relay_users[user_address] = RelayUser(user_address, time()) - - def relay_users(self) -> Dict[IPv4Address, RelayUser]: - """ - Get the list of users connected to the relay. - """ - with self._lock: - return self._relay_users.copy() - - def add_potential_user(self, user_address: IPv4Address): - """ - Notify TCPRelay that a new user may try and connect. - - :param user: A potential user that tries to connect to the relay - """ - with self._lock: - self._potential_users[user_address] = RelayUser(user_address, time()) - - def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool: - """ - Disconnect a user which a specific starting data. - - :param data: The data that a relay recieved - :param user: User which send the data - """ - if data.startswith(RELAY_CONTROL_MESSAGE): - self._disconnect_user(user_address) - return False - return True - - def _disconnect_user(self, user_address: IPv4Address): - with self._lock: - if user_address in self._relay_users: - del self._relay_users[user_address] + pipe = SocketsPipe( + source, dest, client_data_received=self._user_handler.on_user_data_received + ) + self._pipes.append(pipe) + pipe.run() def _wait_for_users_to_disconnect(self): stop = False while not stop: sleep(0.01) current_time = time() + potential_users = self._user_handler.get_potential_users() most_recent_potential_time = max( - self._potential_users.values(), + potential_users.values(), key=lambda ru: ru.last_update_time, default=RelayUser(IPv4Address(""), 0.0), ).last_update_time potential_elapsed = current_time - most_recent_potential_time - stop = not self._potential_users or potential_elapsed > self._new_client_timeout + stop = not potential_users or potential_elapsed > self._new_client_timeout