Agent: Update TCPRelay to separate responsbilities
This commit is contained in:
parent
bbc9cf16e6
commit
72144faefc
|
@ -4,7 +4,7 @@ from threading import Lock
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -14,7 +14,8 @@ class RelayUser:
|
||||||
|
|
||||||
|
|
||||||
class RelayUserHandler:
|
class RelayUserHandler:
|
||||||
def __init__(self):
|
def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT):
|
||||||
|
self._new_client_timeout = new_client_timeout
|
||||||
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
||||||
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
||||||
|
|
||||||
|
@ -44,18 +45,6 @@ class RelayUserHandler:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._potential_users[user_address] = RelayUser(user_address, time())
|
self._potential_users[user_address] = RelayUser(user_address, time())
|
||||||
|
|
||||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
|
||||||
"""
|
|
||||||
Disconnect a user with a specific starting data.
|
|
||||||
|
|
||||||
:param data: The data that a relay received
|
|
||||||
:param user_address: An address defining RelayUser which received the data
|
|
||||||
"""
|
|
||||||
if data.startswith(RELAY_CONTROL_MESSAGE):
|
|
||||||
self.disconnect_user(user_address)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def disconnect_user(self, user_address: IPv4Address):
|
def disconnect_user(self, user_address: IPv4Address):
|
||||||
"""
|
"""
|
||||||
Handle when a user disconnects.
|
Handle when a user disconnects.
|
||||||
|
@ -66,6 +55,16 @@ class RelayUserHandler:
|
||||||
if user_address in self._relay_users:
|
if user_address in self._relay_users:
|
||||||
del self._relay_users[user_address]
|
del self._relay_users[user_address]
|
||||||
|
|
||||||
def get_potential_users(self) -> Dict[IPv4Address, RelayUser]:
|
def has_potential_users(self) -> bool:
|
||||||
with self._lock:
|
"""
|
||||||
return self._potential_users.copy()
|
Return whether or not we have any potential users.
|
||||||
|
"""
|
||||||
|
current_time = time()
|
||||||
|
self._potential_users = dict(
|
||||||
|
filter(
|
||||||
|
lambda ru: (current_time - ru[1].last_update_time) < self._new_client_timeout,
|
||||||
|
self._potential_users.items(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(self._potential_users) > 0
|
||||||
|
|
|
@ -11,10 +11,6 @@ SOCKET_READ_TIMEOUT = 10
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _default_client_data_received(_: bytes, client=None) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class SocketsPipe(Thread):
|
class SocketsPipe(Thread):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -22,7 +18,6 @@ class SocketsPipe(Thread):
|
||||||
dest,
|
dest,
|
||||||
timeout=SOCKET_READ_TIMEOUT,
|
timeout=SOCKET_READ_TIMEOUT,
|
||||||
client_disconnected: Callable[[str], None] = None,
|
client_disconnected: Callable[[str], None] = None,
|
||||||
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
|
|
||||||
):
|
):
|
||||||
Thread.__init__(self)
|
Thread.__init__(self)
|
||||||
self.source = source
|
self.source = source
|
||||||
|
@ -32,13 +27,12 @@ class SocketsPipe(Thread):
|
||||||
super(SocketsPipe, self).__init__()
|
super(SocketsPipe, self).__init__()
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self._client_disconnected = client_disconnected
|
self._client_disconnected = client_disconnected
|
||||||
self._client_data_received = client_data_received
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
sockets = [self.source, self.dest]
|
sockets = [self.source, self.dest]
|
||||||
while self._keep_connection:
|
while self._keep_connection:
|
||||||
self._keep_connection = False
|
self._keep_connection = False
|
||||||
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout)
|
rlist, _, xlist = select.select(sockets, [], sockets, self.timeout)
|
||||||
if xlist:
|
if xlist:
|
||||||
break
|
break
|
||||||
for r in rlist:
|
for r in rlist:
|
||||||
|
@ -47,7 +41,7 @@ class SocketsPipe(Thread):
|
||||||
data = r.recv(READ_BUFFER_SIZE)
|
data = r.recv(READ_BUFFER_SIZE)
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
break
|
||||||
if data and self._client_data_received(data):
|
if data:
|
||||||
try:
|
try:
|
||||||
other.sendall(data)
|
other.sendall(data)
|
||||||
update_last_serve_time()
|
update_last_serve_time()
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import IPv4Address
|
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
from typing import Callable
|
from typing import Callable, List
|
||||||
|
|
||||||
PROXY_TIMEOUT = 2.5
|
PROXY_TIMEOUT = 2.5
|
||||||
|
|
||||||
|
@ -13,7 +12,7 @@ class TCPConnectionHandler(Thread):
|
||||||
self,
|
self,
|
||||||
local_port: int,
|
local_port: int,
|
||||||
local_host: str = "",
|
local_host: str = "",
|
||||||
client_connected: Callable[[socket.socket, IPv4Address], None] = None,
|
client_connected: List[Callable[[socket.socket], None]] = [],
|
||||||
):
|
):
|
||||||
self.local_port = local_port
|
self.local_port = local_port
|
||||||
self.local_host = local_host
|
self.local_host = local_host
|
||||||
|
@ -29,22 +28,14 @@ class TCPConnectionHandler(Thread):
|
||||||
|
|
||||||
while not self._stopped:
|
while not self._stopped:
|
||||||
try:
|
try:
|
||||||
source, address = l_socket.accept()
|
source, _ = l_socket.accept()
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._client_connected:
|
for notify_client_connected in self._client_connected:
|
||||||
self._client_connected(source, IPv4Address(address[0]))
|
notify_client_connected(source)
|
||||||
|
|
||||||
l_socket.close()
|
l_socket.close()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._stopped.set()
|
self._stopped.set()
|
||||||
|
|
||||||
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
|
|
||||||
"""
|
|
||||||
Register to be notified when a client connects.
|
|
||||||
|
|
||||||
:param callback: Callable used to notify when a client connects.
|
|
||||||
"""
|
|
||||||
self._client_connected = callback
|
|
||||||
|
|
|
@ -1,16 +1,21 @@
|
||||||
import socket
|
import socket
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
from typing import Callable
|
from typing import List
|
||||||
|
|
||||||
from .tcp import SocketsPipe
|
from .tcp import SocketsPipe
|
||||||
|
|
||||||
|
|
||||||
class TCPPipeSpawner:
|
class TCPPipeSpawner:
|
||||||
|
"""
|
||||||
|
Creates bi-directional pipes between the configured client and other clients.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, target_addr: IPv4Address, target_port: int):
|
def __init__(self, target_addr: IPv4Address, target_port: int):
|
||||||
self._target_addr = target_addr
|
self._target_addr = target_addr
|
||||||
self._target_port = target_port
|
self._target_port = target_port
|
||||||
|
self._pipes: List[SocketsPipe] = []
|
||||||
|
|
||||||
def spawn_pipe(self, source: socket.socket) -> SocketsPipe:
|
def spawn_pipe(self, source: socket.socket):
|
||||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
try:
|
try:
|
||||||
dest.connect((self._target_addr, self._target_port))
|
dest.connect((self._target_addr, self._target_port))
|
||||||
|
@ -19,7 +24,11 @@ class TCPPipeSpawner:
|
||||||
dest.close()
|
dest.close()
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
return SocketsPipe(source, dest, client_data_received=self._client_data_received)
|
# TODO: have SocketsPipe notify TCPPipeSpawner when it's done
|
||||||
|
pipe = SocketsPipe(source, dest)
|
||||||
|
self._pipes.append(pipe)
|
||||||
|
pipe.run()
|
||||||
|
|
||||||
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
|
def has_open_pipes(self) -> bool:
|
||||||
self._client_data_received = callback
|
self._pipes = [p for p in self._pipes if p.is_alive()]
|
||||||
|
return len(self._pipes) > 0
|
||||||
|
|
|
@ -1,18 +1,7 @@
|
||||||
import socket
|
|
||||||
from ipaddress import IPv4Address
|
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
from time import sleep, time
|
from time import sleep
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from infection_monkey.network.relay import (
|
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
|
||||||
RelayUser,
|
|
||||||
RelayUserHandler,
|
|
||||||
SocketsPipe,
|
|
||||||
TCPConnectionHandler,
|
|
||||||
TCPPipeSpawner,
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
|
||||||
|
|
||||||
|
|
||||||
class TCPRelay(Thread):
|
class TCPRelay(Thread):
|
||||||
|
@ -25,20 +14,15 @@ class TCPRelay(Thread):
|
||||||
relay_user_handler: RelayUserHandler,
|
relay_user_handler: RelayUserHandler,
|
||||||
connection_handler: TCPConnectionHandler,
|
connection_handler: TCPConnectionHandler,
|
||||||
pipe_spawner: TCPPipeSpawner,
|
pipe_spawner: TCPPipeSpawner,
|
||||||
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
|
|
||||||
):
|
):
|
||||||
self._stopped = Event()
|
self._stopped = Event()
|
||||||
|
|
||||||
self._user_handler = relay_user_handler
|
self._user_handler = relay_user_handler
|
||||||
self._connection_handler = connection_handler
|
self._connection_handler = connection_handler
|
||||||
self._connection_handler.notify_client_connected(self._user_connected)
|
|
||||||
self._pipe_spawner = pipe_spawner
|
self._pipe_spawner = pipe_spawner
|
||||||
self._pipe_spawner.notify_client_data_received(self._user_handler.on_user_data_received)
|
|
||||||
self._new_client_timeout = new_client_timeout
|
|
||||||
super().__init__(name="MonkeyTcpRelayThread")
|
super().__init__(name="MonkeyTcpRelayThread")
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
self._pipes: List[SocketsPipe] = []
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self._connection_handler.start()
|
self._connection_handler.start()
|
||||||
|
@ -48,32 +32,21 @@ class TCPRelay(Thread):
|
||||||
|
|
||||||
self._connection_handler.stop()
|
self._connection_handler.stop()
|
||||||
self._connection_handler.join()
|
self._connection_handler.join()
|
||||||
|
self._wait_for_pipes_to_close()
|
||||||
[pipe.join() for pipe in self._pipes]
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._stopped.set()
|
self._stopped.set()
|
||||||
|
|
||||||
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
|
|
||||||
self._user_handler.add_relay_user(user_addr)
|
|
||||||
self._spawn_pipe(source)
|
|
||||||
|
|
||||||
def _spawn_pipe(self, source: socket.socket):
|
|
||||||
pipe = self._pipe_spawner.spawn_pipe(source)
|
|
||||||
self._pipes.append(pipe)
|
|
||||||
pipe.run()
|
|
||||||
|
|
||||||
def _wait_for_users_to_disconnect(self):
|
def _wait_for_users_to_disconnect(self):
|
||||||
stop = False
|
"""
|
||||||
while not stop:
|
Blocks until the users disconnect or the timeout has elapsed.
|
||||||
|
"""
|
||||||
|
while self._user_handler.has_potential_users():
|
||||||
sleep(0.01)
|
sleep(0.01)
|
||||||
current_time = time()
|
|
||||||
potential_users = self._user_handler.get_potential_users()
|
|
||||||
most_recent_potential_time = max(
|
|
||||||
potential_users.values(),
|
|
||||||
key=lambda ru: ru.last_update_time,
|
|
||||||
default=RelayUser(IPv4Address(""), 0.0),
|
|
||||||
).last_update_time
|
|
||||||
potential_elapsed = current_time - most_recent_potential_time
|
|
||||||
|
|
||||||
stop = not potential_users or potential_elapsed > self._new_client_timeout
|
def _wait_for_pipes_to_close(self):
|
||||||
|
"""
|
||||||
|
Blocks until the pipes have closed.
|
||||||
|
"""
|
||||||
|
while self._pipe_spawner.has_open_pipes():
|
||||||
|
sleep(0.01)
|
||||||
|
|
|
@ -1,23 +1,35 @@
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from monkey.infection_monkey.network.relay import RelayUserHandler
|
from monkey.infection_monkey.network.relay import RelayUserHandler
|
||||||
|
|
||||||
|
USER_ADDRESS = IPv4Address("0.0.0.0")
|
||||||
def test_potential_users_added():
|
|
||||||
user_address = IPv4Address("0.0.0.0")
|
|
||||||
handler = RelayUserHandler()
|
|
||||||
|
|
||||||
assert len(handler.get_potential_users()) == 0
|
|
||||||
handler.add_potential_user(user_address)
|
|
||||||
assert len(handler.get_potential_users()) == 1
|
|
||||||
assert user_address in handler.get_potential_users()
|
|
||||||
|
|
||||||
|
|
||||||
def test_potential_user_removed_on_matching_user_added():
|
@pytest.fixture
|
||||||
user_address = IPv4Address("0.0.0.0")
|
def handler():
|
||||||
handler = RelayUserHandler()
|
return RelayUserHandler()
|
||||||
|
|
||||||
handler.add_potential_user(user_address)
|
|
||||||
handler.add_relay_user(user_address)
|
|
||||||
|
|
||||||
assert len(handler.get_potential_users()) == 0
|
def test_potential_users_added(handler):
|
||||||
|
assert not handler.has_potential_users()
|
||||||
|
handler.add_potential_user(USER_ADDRESS)
|
||||||
|
assert handler.has_potential_users()
|
||||||
|
|
||||||
|
|
||||||
|
def test_potential_user_removed_on_matching_user_added(handler):
|
||||||
|
handler.add_potential_user(USER_ADDRESS)
|
||||||
|
handler.add_relay_user(USER_ADDRESS)
|
||||||
|
|
||||||
|
assert not handler.has_potential_users()
|
||||||
|
|
||||||
|
|
||||||
|
def test_potential_users_time_out():
|
||||||
|
handler = RelayUserHandler(new_client_timeout=0.001)
|
||||||
|
|
||||||
|
handler.add_potential_user(USER_ADDRESS)
|
||||||
|
sleep(0.003)
|
||||||
|
|
||||||
|
assert not handler.has_potential_users()
|
||||||
|
|
|
@ -1,110 +0,0 @@
|
||||||
import socket
|
|
||||||
from ipaddress import IPv4Address
|
|
||||||
from threading import Thread
|
|
||||||
from typing import Callable
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from monkey.infection_monkey.network.relay.relay_user_handler import ( # RELAY_CONTROL_MESSAGE,
|
|
||||||
RelayUserHandler,
|
|
||||||
)
|
|
||||||
from monkey.infection_monkey.tcp_relay import TCPRelay
|
|
||||||
|
|
||||||
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
|
|
||||||
LOCAL_PORT = 9975
|
|
||||||
TARGET_ADDRESS = "0.0.0.0"
|
|
||||||
TARGET_PORT = 9976
|
|
||||||
|
|
||||||
|
|
||||||
class FakeConnectionHandler:
|
|
||||||
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
|
|
||||||
self.cb = callback
|
|
||||||
|
|
||||||
def client_connected(self, socket: socket.socket, addr: IPv4Address):
|
|
||||||
self.cb(socket, addr)
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def join(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FakePipeSpawner:
|
|
||||||
spawn_pipe = MagicMock()
|
|
||||||
|
|
||||||
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
|
|
||||||
self.cb = callback
|
|
||||||
|
|
||||||
def send_client_data(self, data: bytes):
|
|
||||||
self.cb(data)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def relay_user_handler() -> RelayUserHandler:
|
|
||||||
return RelayUserHandler()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def pipe_spawner():
|
|
||||||
return FakePipeSpawner()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def connection_handler():
|
|
||||||
return FakeConnectionHandler()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tcp_relay(relay_user_handler, connection_handler, pipe_spawner) -> TCPRelay:
|
|
||||||
return TCPRelay(relay_user_handler, connection_handler, pipe_spawner)
|
|
||||||
|
|
||||||
|
|
||||||
def join_or_kill_thread(thread: Thread, timeout: float):
|
|
||||||
"""Whether or not the thread joined in the given timeout period."""
|
|
||||||
thread.join(timeout)
|
|
||||||
if thread.is_alive():
|
|
||||||
# Cannot set daemon status of active thread: thread.daemon = True
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_added_when_user_connected(connection_handler, relay_user_handler, tcp_relay):
|
|
||||||
# sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
# connection_handler.client_connected(sock, NEW_USER_ADDRESS)
|
|
||||||
# assert len(relay_user_handler.get_relay_users()) == 1
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_pipe_created_when_user_connected(connection_handler, pipe_spawner, tcp_relay):
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
connection_handler.client_connected(sock, NEW_USER_ADDRESS)
|
|
||||||
assert pipe_spawner.spawn_pipe.called
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_removed_on_request(relay_user_handler, pipe_spawner, tcp_relay):
|
|
||||||
relay_user_handler.add_relay_user(NEW_USER_ADDRESS)
|
|
||||||
|
|
||||||
# pipe_spawner.send_client_data(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
|
|
||||||
|
|
||||||
# users = relay_user_handler.get_relay_users()
|
|
||||||
# assert len(users) == 0
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# This will fail unless TcpProxy is updated to do non-blocking accepts
|
|
||||||
# @pytest.mark.slow
|
|
||||||
# def test_waits_for_exploited_machines():
|
|
||||||
# relay = TCPRelay(9975, "0.0.0.0", 9976, new_client_timeout=0.2)
|
|
||||||
# new_user = "0.0.0.1"
|
|
||||||
# relay.start()
|
|
||||||
|
|
||||||
# relay.add_potential_user(new_user)
|
|
||||||
# relay.stop()
|
|
||||||
|
|
||||||
# assert not join_or_kill_thread(relay, 0.1) # Should be waiting
|
|
||||||
# assert join_or_kill_thread(relay, 1) # Should be done waiting
|
|
Loading…
Reference in New Issue