From 72144faefcb6d3d6e2bbb9ea71f5901b80aad00a Mon Sep 17 00:00:00 2001 From: Kekoa Kaaikala Date: Fri, 2 Sep 2022 19:58:01 +0000 Subject: [PATCH] Agent: Update TCPRelay to separate responsbilities --- .../network/relay/relay_user_handler.py | 33 +++--- monkey/infection_monkey/network/relay/tcp.py | 10 +- .../network/relay/tcp_connection_handler.py | 19 +-- .../network/relay/tcp_pipe_spawner.py | 19 ++- monkey/infection_monkey/tcp_relay.py | 53 +++------ .../network/relay/test_relay_user_handler.py | 42 ++++--- .../infection_monkey/test_tcp_relay.py | 110 ------------------ 7 files changed, 77 insertions(+), 209 deletions(-) delete mode 100644 monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py diff --git a/monkey/infection_monkey/network/relay/relay_user_handler.py b/monkey/infection_monkey/network/relay/relay_user_handler.py index 0a32988b5..bce43cc4e 100644 --- a/monkey/infection_monkey/network/relay/relay_user_handler.py +++ b/monkey/infection_monkey/network/relay/relay_user_handler.py @@ -4,7 +4,7 @@ from threading import Lock from time import time from typing import Dict -RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" +DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect @dataclass @@ -14,7 +14,8 @@ class RelayUser: class RelayUserHandler: - def __init__(self): + def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT): + self._new_client_timeout = new_client_timeout self._relay_users: Dict[IPv4Address, RelayUser] = {} self._potential_users: Dict[IPv4Address, RelayUser] = {} @@ -44,18 +45,6 @@ class RelayUserHandler: with self._lock: self._potential_users[user_address] = RelayUser(user_address, time()) - def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool: - """ - Disconnect a user with a specific starting data. - - :param data: The data that a relay received - :param user_address: An address defining RelayUser which received the data - """ - if data.startswith(RELAY_CONTROL_MESSAGE): - self.disconnect_user(user_address) - return False - return True - def disconnect_user(self, user_address: IPv4Address): """ Handle when a user disconnects. @@ -66,6 +55,16 @@ class RelayUserHandler: if user_address in self._relay_users: del self._relay_users[user_address] - def get_potential_users(self) -> Dict[IPv4Address, RelayUser]: - with self._lock: - return self._potential_users.copy() + def has_potential_users(self) -> bool: + """ + Return whether or not we have any potential users. + """ + current_time = time() + self._potential_users = dict( + filter( + lambda ru: (current_time - ru[1].last_update_time) < self._new_client_timeout, + self._potential_users.items(), + ) + ) + + return len(self._potential_users) > 0 diff --git a/monkey/infection_monkey/network/relay/tcp.py b/monkey/infection_monkey/network/relay/tcp.py index 0ae6e630b..d60f954ab 100644 --- a/monkey/infection_monkey/network/relay/tcp.py +++ b/monkey/infection_monkey/network/relay/tcp.py @@ -11,10 +11,6 @@ SOCKET_READ_TIMEOUT = 10 logger = getLogger(__name__) -def _default_client_data_received(_: bytes, client=None) -> bool: - return True - - class SocketsPipe(Thread): def __init__( self, @@ -22,7 +18,6 @@ class SocketsPipe(Thread): dest, timeout=SOCKET_READ_TIMEOUT, client_disconnected: Callable[[str], None] = None, - client_data_received: Callable[[bytes], bool] = _default_client_data_received, ): Thread.__init__(self) self.source = source @@ -32,13 +27,12 @@ class SocketsPipe(Thread): super(SocketsPipe, self).__init__() self.daemon = True self._client_disconnected = client_disconnected - self._client_data_received = client_data_received def run(self): sockets = [self.source, self.dest] while self._keep_connection: self._keep_connection = False - rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout) + rlist, _, xlist = select.select(sockets, [], sockets, self.timeout) if xlist: break for r in rlist: @@ -47,7 +41,7 @@ class SocketsPipe(Thread): data = r.recv(READ_BUFFER_SIZE) except Exception: break - if data and self._client_data_received(data): + if data: try: other.sendall(data) update_last_serve_time() diff --git a/monkey/infection_monkey/network/relay/tcp_connection_handler.py b/monkey/infection_monkey/network/relay/tcp_connection_handler.py index bcd44b3d5..dd8d315e3 100644 --- a/monkey/infection_monkey/network/relay/tcp_connection_handler.py +++ b/monkey/infection_monkey/network/relay/tcp_connection_handler.py @@ -1,7 +1,6 @@ import socket -from ipaddress import IPv4Address from threading import Event, Thread -from typing import Callable +from typing import Callable, List PROXY_TIMEOUT = 2.5 @@ -13,7 +12,7 @@ class TCPConnectionHandler(Thread): self, local_port: int, local_host: str = "", - client_connected: Callable[[socket.socket, IPv4Address], None] = None, + client_connected: List[Callable[[socket.socket], None]] = [], ): self.local_port = local_port self.local_host = local_host @@ -29,22 +28,14 @@ class TCPConnectionHandler(Thread): while not self._stopped: try: - source, address = l_socket.accept() + source, _ = l_socket.accept() except socket.timeout: continue - if self._client_connected: - self._client_connected(source, IPv4Address(address[0])) + for notify_client_connected in self._client_connected: + notify_client_connected(source) l_socket.close() def stop(self): self._stopped.set() - - def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]): - """ - Register to be notified when a client connects. - - :param callback: Callable used to notify when a client connects. - """ - self._client_connected = callback diff --git a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py index 16523dfe3..651043380 100644 --- a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py +++ b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py @@ -1,16 +1,21 @@ import socket from ipaddress import IPv4Address -from typing import Callable +from typing import List from .tcp import SocketsPipe class TCPPipeSpawner: + """ + Creates bi-directional pipes between the configured client and other clients. + """ + def __init__(self, target_addr: IPv4Address, target_port: int): self._target_addr = target_addr self._target_port = target_port + self._pipes: List[SocketsPipe] = [] - def spawn_pipe(self, source: socket.socket) -> SocketsPipe: + def spawn_pipe(self, source: socket.socket): dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: dest.connect((self._target_addr, self._target_port)) @@ -19,7 +24,11 @@ class TCPPipeSpawner: dest.close() raise err - return SocketsPipe(source, dest, client_data_received=self._client_data_received) + # TODO: have SocketsPipe notify TCPPipeSpawner when it's done + pipe = SocketsPipe(source, dest) + self._pipes.append(pipe) + pipe.run() - def notify_client_data_received(self, callback: Callable[[bytes], bool]): - self._client_data_received = callback + def has_open_pipes(self) -> bool: + self._pipes = [p for p in self._pipes if p.is_alive()] + return len(self._pipes) > 0 diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py index 66097b335..7f8a9abe5 100644 --- a/monkey/infection_monkey/tcp_relay.py +++ b/monkey/infection_monkey/tcp_relay.py @@ -1,18 +1,7 @@ -import socket -from ipaddress import IPv4Address from threading import Event, Lock, Thread -from time import sleep, time -from typing import List +from time import sleep -from infection_monkey.network.relay import ( - RelayUser, - RelayUserHandler, - SocketsPipe, - TCPConnectionHandler, - TCPPipeSpawner, -) - -DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect +from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner class TCPRelay(Thread): @@ -25,20 +14,15 @@ class TCPRelay(Thread): relay_user_handler: RelayUserHandler, connection_handler: TCPConnectionHandler, pipe_spawner: TCPPipeSpawner, - new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT, ): self._stopped = Event() self._user_handler = relay_user_handler self._connection_handler = connection_handler - self._connection_handler.notify_client_connected(self._user_connected) self._pipe_spawner = pipe_spawner - self._pipe_spawner.notify_client_data_received(self._user_handler.on_user_data_received) - self._new_client_timeout = new_client_timeout super().__init__(name="MonkeyTcpRelayThread") self.daemon = True self._lock = Lock() - self._pipes: List[SocketsPipe] = [] def run(self): self._connection_handler.start() @@ -48,32 +32,21 @@ class TCPRelay(Thread): self._connection_handler.stop() self._connection_handler.join() - - [pipe.join() for pipe in self._pipes] + self._wait_for_pipes_to_close() def stop(self): self._stopped.set() - def _user_connected(self, source: socket.socket, user_addr: IPv4Address): - self._user_handler.add_relay_user(user_addr) - self._spawn_pipe(source) - - def _spawn_pipe(self, source: socket.socket): - pipe = self._pipe_spawner.spawn_pipe(source) - self._pipes.append(pipe) - pipe.run() - def _wait_for_users_to_disconnect(self): - stop = False - while not stop: + """ + Blocks until the users disconnect or the timeout has elapsed. + """ + while self._user_handler.has_potential_users(): sleep(0.01) - current_time = time() - potential_users = self._user_handler.get_potential_users() - most_recent_potential_time = max( - potential_users.values(), - key=lambda ru: ru.last_update_time, - default=RelayUser(IPv4Address(""), 0.0), - ).last_update_time - potential_elapsed = current_time - most_recent_potential_time - stop = not potential_users or potential_elapsed > self._new_client_timeout + def _wait_for_pipes_to_close(self): + """ + Blocks until the pipes have closed. + """ + while self._pipe_spawner.has_open_pipes(): + sleep(0.01) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py index 34aaf6bc7..ca0eb4103 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_user_handler.py @@ -1,23 +1,35 @@ from ipaddress import IPv4Address +from time import sleep + +import pytest from monkey.infection_monkey.network.relay import RelayUserHandler - -def test_potential_users_added(): - user_address = IPv4Address("0.0.0.0") - handler = RelayUserHandler() - - assert len(handler.get_potential_users()) == 0 - handler.add_potential_user(user_address) - assert len(handler.get_potential_users()) == 1 - assert user_address in handler.get_potential_users() +USER_ADDRESS = IPv4Address("0.0.0.0") -def test_potential_user_removed_on_matching_user_added(): - user_address = IPv4Address("0.0.0.0") - handler = RelayUserHandler() +@pytest.fixture +def handler(): + return RelayUserHandler() - handler.add_potential_user(user_address) - handler.add_relay_user(user_address) - assert len(handler.get_potential_users()) == 0 +def test_potential_users_added(handler): + assert not handler.has_potential_users() + handler.add_potential_user(USER_ADDRESS) + assert handler.has_potential_users() + + +def test_potential_user_removed_on_matching_user_added(handler): + handler.add_potential_user(USER_ADDRESS) + handler.add_relay_user(USER_ADDRESS) + + assert not handler.has_potential_users() + + +def test_potential_users_time_out(): + handler = RelayUserHandler(new_client_timeout=0.001) + + handler.add_potential_user(USER_ADDRESS) + sleep(0.003) + + assert not handler.has_potential_users() diff --git a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py deleted file mode 100644 index c913e21ef..000000000 --- a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py +++ /dev/null @@ -1,110 +0,0 @@ -import socket -from ipaddress import IPv4Address -from threading import Thread -from typing import Callable -from unittest.mock import MagicMock - -import pytest - -from monkey.infection_monkey.network.relay.relay_user_handler import ( # RELAY_CONTROL_MESSAGE, - RelayUserHandler, -) -from monkey.infection_monkey.tcp_relay import TCPRelay - -NEW_USER_ADDRESS = IPv4Address("0.0.0.1") -LOCAL_PORT = 9975 -TARGET_ADDRESS = "0.0.0.0" -TARGET_PORT = 9976 - - -class FakeConnectionHandler: - def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]): - self.cb = callback - - def client_connected(self, socket: socket.socket, addr: IPv4Address): - self.cb(socket, addr) - - def start(self): - pass - - def stop(self): - pass - - def join(self): - pass - - -class FakePipeSpawner: - spawn_pipe = MagicMock() - - def notify_client_data_received(self, callback: Callable[[bytes], bool]): - self.cb = callback - - def send_client_data(self, data: bytes): - self.cb(data) - - -@pytest.fixture -def relay_user_handler() -> RelayUserHandler: - return RelayUserHandler() - - -@pytest.fixture -def pipe_spawner(): - return FakePipeSpawner() - - -@pytest.fixture -def connection_handler(): - return FakeConnectionHandler() - - -@pytest.fixture -def tcp_relay(relay_user_handler, connection_handler, pipe_spawner) -> TCPRelay: - return TCPRelay(relay_user_handler, connection_handler, pipe_spawner) - - -def join_or_kill_thread(thread: Thread, timeout: float): - """Whether or not the thread joined in the given timeout period.""" - thread.join(timeout) - if thread.is_alive(): - # Cannot set daemon status of active thread: thread.daemon = True - return False - return True - - -def test_user_added_when_user_connected(connection_handler, relay_user_handler, tcp_relay): - # sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # connection_handler.client_connected(sock, NEW_USER_ADDRESS) - # assert len(relay_user_handler.get_relay_users()) == 1 - pass - - -def test_pipe_created_when_user_connected(connection_handler, pipe_spawner, tcp_relay): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - connection_handler.client_connected(sock, NEW_USER_ADDRESS) - assert pipe_spawner.spawn_pipe.called - - -def test_user_removed_on_request(relay_user_handler, pipe_spawner, tcp_relay): - relay_user_handler.add_relay_user(NEW_USER_ADDRESS) - - # pipe_spawner.send_client_data(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS) - - # users = relay_user_handler.get_relay_users() - # assert len(users) == 0 - pass - - -# This will fail unless TcpProxy is updated to do non-blocking accepts -# @pytest.mark.slow -# def test_waits_for_exploited_machines(): -# relay = TCPRelay(9975, "0.0.0.0", 9976, new_client_timeout=0.2) -# new_user = "0.0.0.1" -# relay.start() - -# relay.add_potential_user(new_user) -# relay.stop() - -# assert not join_or_kill_thread(relay, 0.1) # Should be waiting -# assert join_or_kill_thread(relay, 1) # Should be done waiting