Agent: Update TCPRelay to separate responsbilities

This commit is contained in:
Kekoa Kaaikala 2022-09-02 19:58:01 +00:00
parent bbc9cf16e6
commit 72144faefc
7 changed files with 77 additions and 209 deletions

View File

@ -4,7 +4,7 @@ from threading import Lock
from time import time from time import time
from typing import Dict from typing import Dict
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
@dataclass @dataclass
@ -14,7 +14,8 @@ class RelayUser:
class RelayUserHandler: class RelayUserHandler:
def __init__(self): def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT):
self._new_client_timeout = new_client_timeout
self._relay_users: Dict[IPv4Address, RelayUser] = {} self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {} self._potential_users: Dict[IPv4Address, RelayUser] = {}
@ -44,18 +45,6 @@ class RelayUserHandler:
with self._lock: with self._lock:
self._potential_users[user_address] = RelayUser(user_address, time()) self._potential_users[user_address] = RelayUser(user_address, time())
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
Disconnect a user with a specific starting data.
:param data: The data that a relay received
:param user_address: An address defining RelayUser which received the data
"""
if data.startswith(RELAY_CONTROL_MESSAGE):
self.disconnect_user(user_address)
return False
return True
def disconnect_user(self, user_address: IPv4Address): def disconnect_user(self, user_address: IPv4Address):
""" """
Handle when a user disconnects. Handle when a user disconnects.
@ -66,6 +55,16 @@ class RelayUserHandler:
if user_address in self._relay_users: if user_address in self._relay_users:
del self._relay_users[user_address] del self._relay_users[user_address]
def get_potential_users(self) -> Dict[IPv4Address, RelayUser]: def has_potential_users(self) -> bool:
with self._lock: """
return self._potential_users.copy() Return whether or not we have any potential users.
"""
current_time = time()
self._potential_users = dict(
filter(
lambda ru: (current_time - ru[1].last_update_time) < self._new_client_timeout,
self._potential_users.items(),
)
)
return len(self._potential_users) > 0

View File

@ -11,10 +11,6 @@ SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__) logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread): class SocketsPipe(Thread):
def __init__( def __init__(
self, self,
@ -22,7 +18,6 @@ class SocketsPipe(Thread):
dest, dest,
timeout=SOCKET_READ_TIMEOUT, timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None, client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
): ):
Thread.__init__(self) Thread.__init__(self)
self.source = source self.source = source
@ -32,13 +27,12 @@ class SocketsPipe(Thread):
super(SocketsPipe, self).__init__() super(SocketsPipe, self).__init__()
self.daemon = True self.daemon = True
self._client_disconnected = client_disconnected self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self): def run(self):
sockets = [self.source, self.dest] sockets = [self.source, self.dest]
while self._keep_connection: while self._keep_connection:
self._keep_connection = False self._keep_connection = False
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout) rlist, _, xlist = select.select(sockets, [], sockets, self.timeout)
if xlist: if xlist:
break break
for r in rlist: for r in rlist:
@ -47,7 +41,7 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE) data = r.recv(READ_BUFFER_SIZE)
except Exception: except Exception:
break break
if data and self._client_data_received(data): if data:
try: try:
other.sendall(data) other.sendall(data)
update_last_serve_time() update_last_serve_time()

View File

@ -1,7 +1,6 @@
import socket import socket
from ipaddress import IPv4Address
from threading import Event, Thread from threading import Event, Thread
from typing import Callable from typing import Callable, List
PROXY_TIMEOUT = 2.5 PROXY_TIMEOUT = 2.5
@ -13,7 +12,7 @@ class TCPConnectionHandler(Thread):
self, self,
local_port: int, local_port: int,
local_host: str = "", local_host: str = "",
client_connected: Callable[[socket.socket, IPv4Address], None] = None, client_connected: List[Callable[[socket.socket], None]] = [],
): ):
self.local_port = local_port self.local_port = local_port
self.local_host = local_host self.local_host = local_host
@ -29,22 +28,14 @@ class TCPConnectionHandler(Thread):
while not self._stopped: while not self._stopped:
try: try:
source, address = l_socket.accept() source, _ = l_socket.accept()
except socket.timeout: except socket.timeout:
continue continue
if self._client_connected: for notify_client_connected in self._client_connected:
self._client_connected(source, IPv4Address(address[0])) notify_client_connected(source)
l_socket.close() l_socket.close()
def stop(self): def stop(self):
self._stopped.set() self._stopped.set()
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
"""
Register to be notified when a client connects.
:param callback: Callable used to notify when a client connects.
"""
self._client_connected = callback

View File

@ -1,16 +1,21 @@
import socket import socket
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Callable from typing import List
from .tcp import SocketsPipe from .tcp import SocketsPipe
class TCPPipeSpawner: class TCPPipeSpawner:
"""
Creates bi-directional pipes between the configured client and other clients.
"""
def __init__(self, target_addr: IPv4Address, target_port: int): def __init__(self, target_addr: IPv4Address, target_port: int):
self._target_addr = target_addr self._target_addr = target_addr
self._target_port = target_port self._target_port = target_port
self._pipes: List[SocketsPipe] = []
def spawn_pipe(self, source: socket.socket) -> SocketsPipe: def spawn_pipe(self, source: socket.socket):
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try: try:
dest.connect((self._target_addr, self._target_port)) dest.connect((self._target_addr, self._target_port))
@ -19,7 +24,11 @@ class TCPPipeSpawner:
dest.close() dest.close()
raise err raise err
return SocketsPipe(source, dest, client_data_received=self._client_data_received) # TODO: have SocketsPipe notify TCPPipeSpawner when it's done
pipe = SocketsPipe(source, dest)
self._pipes.append(pipe)
pipe.run()
def notify_client_data_received(self, callback: Callable[[bytes], bool]): def has_open_pipes(self) -> bool:
self._client_data_received = callback self._pipes = [p for p in self._pipes if p.is_alive()]
return len(self._pipes) > 0

View File

@ -1,18 +1,7 @@
import socket
from ipaddress import IPv4Address
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from time import sleep, time from time import sleep
from typing import List
from infection_monkey.network.relay import ( from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
RelayUser,
RelayUserHandler,
SocketsPipe,
TCPConnectionHandler,
TCPPipeSpawner,
)
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
class TCPRelay(Thread): class TCPRelay(Thread):
@ -25,20 +14,15 @@ class TCPRelay(Thread):
relay_user_handler: RelayUserHandler, relay_user_handler: RelayUserHandler,
connection_handler: TCPConnectionHandler, connection_handler: TCPConnectionHandler,
pipe_spawner: TCPPipeSpawner, pipe_spawner: TCPPipeSpawner,
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
): ):
self._stopped = Event() self._stopped = Event()
self._user_handler = relay_user_handler self._user_handler = relay_user_handler
self._connection_handler = connection_handler self._connection_handler = connection_handler
self._connection_handler.notify_client_connected(self._user_connected)
self._pipe_spawner = pipe_spawner self._pipe_spawner = pipe_spawner
self._pipe_spawner.notify_client_data_received(self._user_handler.on_user_data_received)
self._new_client_timeout = new_client_timeout
super().__init__(name="MonkeyTcpRelayThread") super().__init__(name="MonkeyTcpRelayThread")
self.daemon = True self.daemon = True
self._lock = Lock() self._lock = Lock()
self._pipes: List[SocketsPipe] = []
def run(self): def run(self):
self._connection_handler.start() self._connection_handler.start()
@ -48,32 +32,21 @@ class TCPRelay(Thread):
self._connection_handler.stop() self._connection_handler.stop()
self._connection_handler.join() self._connection_handler.join()
self._wait_for_pipes_to_close()
[pipe.join() for pipe in self._pipes]
def stop(self): def stop(self):
self._stopped.set() self._stopped.set()
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
self._user_handler.add_relay_user(user_addr)
self._spawn_pipe(source)
def _spawn_pipe(self, source: socket.socket):
pipe = self._pipe_spawner.spawn_pipe(source)
self._pipes.append(pipe)
pipe.run()
def _wait_for_users_to_disconnect(self): def _wait_for_users_to_disconnect(self):
stop = False """
while not stop: Blocks until the users disconnect or the timeout has elapsed.
"""
while self._user_handler.has_potential_users():
sleep(0.01) sleep(0.01)
current_time = time()
potential_users = self._user_handler.get_potential_users()
most_recent_potential_time = max(
potential_users.values(),
key=lambda ru: ru.last_update_time,
default=RelayUser(IPv4Address(""), 0.0),
).last_update_time
potential_elapsed = current_time - most_recent_potential_time
stop = not potential_users or potential_elapsed > self._new_client_timeout def _wait_for_pipes_to_close(self):
"""
Blocks until the pipes have closed.
"""
while self._pipe_spawner.has_open_pipes():
sleep(0.01)

View File

@ -1,23 +1,35 @@
from ipaddress import IPv4Address from ipaddress import IPv4Address
from time import sleep
import pytest
from monkey.infection_monkey.network.relay import RelayUserHandler from monkey.infection_monkey.network.relay import RelayUserHandler
USER_ADDRESS = IPv4Address("0.0.0.0")
def test_potential_users_added():
user_address = IPv4Address("0.0.0.0")
handler = RelayUserHandler()
assert len(handler.get_potential_users()) == 0
handler.add_potential_user(user_address)
assert len(handler.get_potential_users()) == 1
assert user_address in handler.get_potential_users()
def test_potential_user_removed_on_matching_user_added(): @pytest.fixture
user_address = IPv4Address("0.0.0.0") def handler():
handler = RelayUserHandler() return RelayUserHandler()
handler.add_potential_user(user_address)
handler.add_relay_user(user_address)
assert len(handler.get_potential_users()) == 0 def test_potential_users_added(handler):
assert not handler.has_potential_users()
handler.add_potential_user(USER_ADDRESS)
assert handler.has_potential_users()
def test_potential_user_removed_on_matching_user_added(handler):
handler.add_potential_user(USER_ADDRESS)
handler.add_relay_user(USER_ADDRESS)
assert not handler.has_potential_users()
def test_potential_users_time_out():
handler = RelayUserHandler(new_client_timeout=0.001)
handler.add_potential_user(USER_ADDRESS)
sleep(0.003)
assert not handler.has_potential_users()

View File

@ -1,110 +0,0 @@
import socket
from ipaddress import IPv4Address
from threading import Thread
from typing import Callable
from unittest.mock import MagicMock
import pytest
from monkey.infection_monkey.network.relay.relay_user_handler import ( # RELAY_CONTROL_MESSAGE,
RelayUserHandler,
)
from monkey.infection_monkey.tcp_relay import TCPRelay
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
LOCAL_PORT = 9975
TARGET_ADDRESS = "0.0.0.0"
TARGET_PORT = 9976
class FakeConnectionHandler:
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
self.cb = callback
def client_connected(self, socket: socket.socket, addr: IPv4Address):
self.cb(socket, addr)
def start(self):
pass
def stop(self):
pass
def join(self):
pass
class FakePipeSpawner:
spawn_pipe = MagicMock()
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
self.cb = callback
def send_client_data(self, data: bytes):
self.cb(data)
@pytest.fixture
def relay_user_handler() -> RelayUserHandler:
return RelayUserHandler()
@pytest.fixture
def pipe_spawner():
return FakePipeSpawner()
@pytest.fixture
def connection_handler():
return FakeConnectionHandler()
@pytest.fixture
def tcp_relay(relay_user_handler, connection_handler, pipe_spawner) -> TCPRelay:
return TCPRelay(relay_user_handler, connection_handler, pipe_spawner)
def join_or_kill_thread(thread: Thread, timeout: float):
"""Whether or not the thread joined in the given timeout period."""
thread.join(timeout)
if thread.is_alive():
# Cannot set daemon status of active thread: thread.daemon = True
return False
return True
def test_user_added_when_user_connected(connection_handler, relay_user_handler, tcp_relay):
# sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# connection_handler.client_connected(sock, NEW_USER_ADDRESS)
# assert len(relay_user_handler.get_relay_users()) == 1
pass
def test_pipe_created_when_user_connected(connection_handler, pipe_spawner, tcp_relay):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection_handler.client_connected(sock, NEW_USER_ADDRESS)
assert pipe_spawner.spawn_pipe.called
def test_user_removed_on_request(relay_user_handler, pipe_spawner, tcp_relay):
relay_user_handler.add_relay_user(NEW_USER_ADDRESS)
# pipe_spawner.send_client_data(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
# users = relay_user_handler.get_relay_users()
# assert len(users) == 0
pass
# This will fail unless TcpProxy is updated to do non-blocking accepts
# @pytest.mark.slow
# def test_waits_for_exploited_machines():
# relay = TCPRelay(9975, "0.0.0.0", 9976, new_client_timeout=0.2)
# new_user = "0.0.0.1"
# relay.start()
# relay.add_potential_user(new_user)
# relay.stop()
# assert not join_or_kill_thread(relay, 0.1) # Should be waiting
# assert join_or_kill_thread(relay, 1) # Should be done waiting