Agent: Update TCPRelay to separate responsbilities

This commit is contained in:
Kekoa Kaaikala 2022-09-02 19:58:01 +00:00
parent bbc9cf16e6
commit 72144faefc
7 changed files with 77 additions and 209 deletions

View File

@ -4,7 +4,7 @@ from threading import Lock
from time import time
from typing import Dict
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
@dataclass
@ -14,7 +14,8 @@ class RelayUser:
class RelayUserHandler:
def __init__(self):
def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT):
self._new_client_timeout = new_client_timeout
self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {}
@ -44,18 +45,6 @@ class RelayUserHandler:
with self._lock:
self._potential_users[user_address] = RelayUser(user_address, time())
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
Disconnect a user with a specific starting data.
:param data: The data that a relay received
:param user_address: An address defining RelayUser which received the data
"""
if data.startswith(RELAY_CONTROL_MESSAGE):
self.disconnect_user(user_address)
return False
return True
def disconnect_user(self, user_address: IPv4Address):
"""
Handle when a user disconnects.
@ -66,6 +55,16 @@ class RelayUserHandler:
if user_address in self._relay_users:
del self._relay_users[user_address]
def get_potential_users(self) -> Dict[IPv4Address, RelayUser]:
with self._lock:
return self._potential_users.copy()
def has_potential_users(self) -> bool:
"""
Return whether or not we have any potential users.
"""
current_time = time()
self._potential_users = dict(
filter(
lambda ru: (current_time - ru[1].last_update_time) < self._new_client_timeout,
self._potential_users.items(),
)
)
return len(self._potential_users) > 0

View File

@ -11,10 +11,6 @@ SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread):
def __init__(
self,
@ -22,7 +18,6 @@ class SocketsPipe(Thread):
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
Thread.__init__(self)
self.source = source
@ -32,13 +27,12 @@ class SocketsPipe(Thread):
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
sockets = [self.source, self.dest]
while self._keep_connection:
self._keep_connection = False
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout)
rlist, _, xlist = select.select(sockets, [], sockets, self.timeout)
if xlist:
break
for r in rlist:
@ -47,7 +41,7 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE)
except Exception:
break
if data and self._client_data_received(data):
if data:
try:
other.sendall(data)
update_last_serve_time()

View File

@ -1,7 +1,6 @@
import socket
from ipaddress import IPv4Address
from threading import Event, Thread
from typing import Callable
from typing import Callable, List
PROXY_TIMEOUT = 2.5
@ -13,7 +12,7 @@ class TCPConnectionHandler(Thread):
self,
local_port: int,
local_host: str = "",
client_connected: Callable[[socket.socket, IPv4Address], None] = None,
client_connected: List[Callable[[socket.socket], None]] = [],
):
self.local_port = local_port
self.local_host = local_host
@ -29,22 +28,14 @@ class TCPConnectionHandler(Thread):
while not self._stopped:
try:
source, address = l_socket.accept()
source, _ = l_socket.accept()
except socket.timeout:
continue
if self._client_connected:
self._client_connected(source, IPv4Address(address[0]))
for notify_client_connected in self._client_connected:
notify_client_connected(source)
l_socket.close()
def stop(self):
self._stopped.set()
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
"""
Register to be notified when a client connects.
:param callback: Callable used to notify when a client connects.
"""
self._client_connected = callback

View File

@ -1,16 +1,21 @@
import socket
from ipaddress import IPv4Address
from typing import Callable
from typing import List
from .tcp import SocketsPipe
class TCPPipeSpawner:
"""
Creates bi-directional pipes between the configured client and other clients.
"""
def __init__(self, target_addr: IPv4Address, target_port: int):
self._target_addr = target_addr
self._target_port = target_port
self._pipes: List[SocketsPipe] = []
def spawn_pipe(self, source: socket.socket) -> SocketsPipe:
def spawn_pipe(self, source: socket.socket):
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self._target_addr, self._target_port))
@ -19,7 +24,11 @@ class TCPPipeSpawner:
dest.close()
raise err
return SocketsPipe(source, dest, client_data_received=self._client_data_received)
# TODO: have SocketsPipe notify TCPPipeSpawner when it's done
pipe = SocketsPipe(source, dest)
self._pipes.append(pipe)
pipe.run()
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
self._client_data_received = callback
def has_open_pipes(self) -> bool:
self._pipes = [p for p in self._pipes if p.is_alive()]
return len(self._pipes) > 0

View File

@ -1,18 +1,7 @@
import socket
from ipaddress import IPv4Address
from threading import Event, Lock, Thread
from time import sleep, time
from typing import List
from time import sleep
from infection_monkey.network.relay import (
RelayUser,
RelayUserHandler,
SocketsPipe,
TCPConnectionHandler,
TCPPipeSpawner,
)
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
class TCPRelay(Thread):
@ -25,20 +14,15 @@ class TCPRelay(Thread):
relay_user_handler: RelayUserHandler,
connection_handler: TCPConnectionHandler,
pipe_spawner: TCPPipeSpawner,
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
):
self._stopped = Event()
self._user_handler = relay_user_handler
self._connection_handler = connection_handler
self._connection_handler.notify_client_connected(self._user_connected)
self._pipe_spawner = pipe_spawner
self._pipe_spawner.notify_client_data_received(self._user_handler.on_user_data_received)
self._new_client_timeout = new_client_timeout
super().__init__(name="MonkeyTcpRelayThread")
self.daemon = True
self._lock = Lock()
self._pipes: List[SocketsPipe] = []
def run(self):
self._connection_handler.start()
@ -48,32 +32,21 @@ class TCPRelay(Thread):
self._connection_handler.stop()
self._connection_handler.join()
[pipe.join() for pipe in self._pipes]
self._wait_for_pipes_to_close()
def stop(self):
self._stopped.set()
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
self._user_handler.add_relay_user(user_addr)
self._spawn_pipe(source)
def _spawn_pipe(self, source: socket.socket):
pipe = self._pipe_spawner.spawn_pipe(source)
self._pipes.append(pipe)
pipe.run()
def _wait_for_users_to_disconnect(self):
stop = False
while not stop:
"""
Blocks until the users disconnect or the timeout has elapsed.
"""
while self._user_handler.has_potential_users():
sleep(0.01)
current_time = time()
potential_users = self._user_handler.get_potential_users()
most_recent_potential_time = max(
potential_users.values(),
key=lambda ru: ru.last_update_time,
default=RelayUser(IPv4Address(""), 0.0),
).last_update_time
potential_elapsed = current_time - most_recent_potential_time
stop = not potential_users or potential_elapsed > self._new_client_timeout
def _wait_for_pipes_to_close(self):
"""
Blocks until the pipes have closed.
"""
while self._pipe_spawner.has_open_pipes():
sleep(0.01)

View File

@ -1,23 +1,35 @@
from ipaddress import IPv4Address
from time import sleep
import pytest
from monkey.infection_monkey.network.relay import RelayUserHandler
def test_potential_users_added():
user_address = IPv4Address("0.0.0.0")
handler = RelayUserHandler()
assert len(handler.get_potential_users()) == 0
handler.add_potential_user(user_address)
assert len(handler.get_potential_users()) == 1
assert user_address in handler.get_potential_users()
USER_ADDRESS = IPv4Address("0.0.0.0")
def test_potential_user_removed_on_matching_user_added():
user_address = IPv4Address("0.0.0.0")
handler = RelayUserHandler()
@pytest.fixture
def handler():
return RelayUserHandler()
handler.add_potential_user(user_address)
handler.add_relay_user(user_address)
assert len(handler.get_potential_users()) == 0
def test_potential_users_added(handler):
assert not handler.has_potential_users()
handler.add_potential_user(USER_ADDRESS)
assert handler.has_potential_users()
def test_potential_user_removed_on_matching_user_added(handler):
handler.add_potential_user(USER_ADDRESS)
handler.add_relay_user(USER_ADDRESS)
assert not handler.has_potential_users()
def test_potential_users_time_out():
handler = RelayUserHandler(new_client_timeout=0.001)
handler.add_potential_user(USER_ADDRESS)
sleep(0.003)
assert not handler.has_potential_users()

View File

@ -1,110 +0,0 @@
import socket
from ipaddress import IPv4Address
from threading import Thread
from typing import Callable
from unittest.mock import MagicMock
import pytest
from monkey.infection_monkey.network.relay.relay_user_handler import ( # RELAY_CONTROL_MESSAGE,
RelayUserHandler,
)
from monkey.infection_monkey.tcp_relay import TCPRelay
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
LOCAL_PORT = 9975
TARGET_ADDRESS = "0.0.0.0"
TARGET_PORT = 9976
class FakeConnectionHandler:
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
self.cb = callback
def client_connected(self, socket: socket.socket, addr: IPv4Address):
self.cb(socket, addr)
def start(self):
pass
def stop(self):
pass
def join(self):
pass
class FakePipeSpawner:
spawn_pipe = MagicMock()
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
self.cb = callback
def send_client_data(self, data: bytes):
self.cb(data)
@pytest.fixture
def relay_user_handler() -> RelayUserHandler:
return RelayUserHandler()
@pytest.fixture
def pipe_spawner():
return FakePipeSpawner()
@pytest.fixture
def connection_handler():
return FakeConnectionHandler()
@pytest.fixture
def tcp_relay(relay_user_handler, connection_handler, pipe_spawner) -> TCPRelay:
return TCPRelay(relay_user_handler, connection_handler, pipe_spawner)
def join_or_kill_thread(thread: Thread, timeout: float):
"""Whether or not the thread joined in the given timeout period."""
thread.join(timeout)
if thread.is_alive():
# Cannot set daemon status of active thread: thread.daemon = True
return False
return True
def test_user_added_when_user_connected(connection_handler, relay_user_handler, tcp_relay):
# sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# connection_handler.client_connected(sock, NEW_USER_ADDRESS)
# assert len(relay_user_handler.get_relay_users()) == 1
pass
def test_pipe_created_when_user_connected(connection_handler, pipe_spawner, tcp_relay):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection_handler.client_connected(sock, NEW_USER_ADDRESS)
assert pipe_spawner.spawn_pipe.called
def test_user_removed_on_request(relay_user_handler, pipe_spawner, tcp_relay):
relay_user_handler.add_relay_user(NEW_USER_ADDRESS)
# pipe_spawner.send_client_data(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
# users = relay_user_handler.get_relay_users()
# assert len(users) == 0
pass
# This will fail unless TcpProxy is updated to do non-blocking accepts
# @pytest.mark.slow
# def test_waits_for_exploited_machines():
# relay = TCPRelay(9975, "0.0.0.0", 9976, new_client_timeout=0.2)
# new_user = "0.0.0.1"
# relay.start()
# relay.add_potential_user(new_user)
# relay.stop()
# assert not join_or_kill_thread(relay, 0.1) # Should be waiting
# assert join_or_kill_thread(relay, 1) # Should be done waiting