forked from p15670423/monkey
Agent: Update TCPRelay to separate responsbilities
This commit is contained in:
parent
bbc9cf16e6
commit
72144faefc
|
@ -4,7 +4,7 @@ from threading import Lock
|
|||
from time import time
|
||||
from typing import Dict
|
||||
|
||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
||||
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -14,7 +14,8 @@ class RelayUser:
|
|||
|
||||
|
||||
class RelayUserHandler:
|
||||
def __init__(self):
|
||||
def __init__(self, new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT):
|
||||
self._new_client_timeout = new_client_timeout
|
||||
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
||||
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
||||
|
||||
|
@ -44,18 +45,6 @@ class RelayUserHandler:
|
|||
with self._lock:
|
||||
self._potential_users[user_address] = RelayUser(user_address, time())
|
||||
|
||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
||||
"""
|
||||
Disconnect a user with a specific starting data.
|
||||
|
||||
:param data: The data that a relay received
|
||||
:param user_address: An address defining RelayUser which received the data
|
||||
"""
|
||||
if data.startswith(RELAY_CONTROL_MESSAGE):
|
||||
self.disconnect_user(user_address)
|
||||
return False
|
||||
return True
|
||||
|
||||
def disconnect_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
Handle when a user disconnects.
|
||||
|
@ -66,6 +55,16 @@ class RelayUserHandler:
|
|||
if user_address in self._relay_users:
|
||||
del self._relay_users[user_address]
|
||||
|
||||
def get_potential_users(self) -> Dict[IPv4Address, RelayUser]:
|
||||
with self._lock:
|
||||
return self._potential_users.copy()
|
||||
def has_potential_users(self) -> bool:
|
||||
"""
|
||||
Return whether or not we have any potential users.
|
||||
"""
|
||||
current_time = time()
|
||||
self._potential_users = dict(
|
||||
filter(
|
||||
lambda ru: (current_time - ru[1].last_update_time) < self._new_client_timeout,
|
||||
self._potential_users.items(),
|
||||
)
|
||||
)
|
||||
|
||||
return len(self._potential_users) > 0
|
||||
|
|
|
@ -11,10 +11,6 @@ SOCKET_READ_TIMEOUT = 10
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def _default_client_data_received(_: bytes, client=None) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SocketsPipe(Thread):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -22,7 +18,6 @@ class SocketsPipe(Thread):
|
|||
dest,
|
||||
timeout=SOCKET_READ_TIMEOUT,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
|
||||
):
|
||||
Thread.__init__(self)
|
||||
self.source = source
|
||||
|
@ -32,13 +27,12 @@ class SocketsPipe(Thread):
|
|||
super(SocketsPipe, self).__init__()
|
||||
self.daemon = True
|
||||
self._client_disconnected = client_disconnected
|
||||
self._client_data_received = client_data_received
|
||||
|
||||
def run(self):
|
||||
sockets = [self.source, self.dest]
|
||||
while self._keep_connection:
|
||||
self._keep_connection = False
|
||||
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout)
|
||||
rlist, _, xlist = select.select(sockets, [], sockets, self.timeout)
|
||||
if xlist:
|
||||
break
|
||||
for r in rlist:
|
||||
|
@ -47,7 +41,7 @@ class SocketsPipe(Thread):
|
|||
data = r.recv(READ_BUFFER_SIZE)
|
||||
except Exception:
|
||||
break
|
||||
if data and self._client_data_received(data):
|
||||
if data:
|
||||
try:
|
||||
other.sendall(data)
|
||||
update_last_serve_time()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Event, Thread
|
||||
from typing import Callable
|
||||
from typing import Callable, List
|
||||
|
||||
PROXY_TIMEOUT = 2.5
|
||||
|
||||
|
@ -13,7 +12,7 @@ class TCPConnectionHandler(Thread):
|
|||
self,
|
||||
local_port: int,
|
||||
local_host: str = "",
|
||||
client_connected: Callable[[socket.socket, IPv4Address], None] = None,
|
||||
client_connected: List[Callable[[socket.socket], None]] = [],
|
||||
):
|
||||
self.local_port = local_port
|
||||
self.local_host = local_host
|
||||
|
@ -29,22 +28,14 @@ class TCPConnectionHandler(Thread):
|
|||
|
||||
while not self._stopped:
|
||||
try:
|
||||
source, address = l_socket.accept()
|
||||
source, _ = l_socket.accept()
|
||||
except socket.timeout:
|
||||
continue
|
||||
|
||||
if self._client_connected:
|
||||
self._client_connected(source, IPv4Address(address[0]))
|
||||
for notify_client_connected in self._client_connected:
|
||||
notify_client_connected(source)
|
||||
|
||||
l_socket.close()
|
||||
|
||||
def stop(self):
|
||||
self._stopped.set()
|
||||
|
||||
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
|
||||
"""
|
||||
Register to be notified when a client connects.
|
||||
|
||||
:param callback: Callable used to notify when a client connects.
|
||||
"""
|
||||
self._client_connected = callback
|
||||
|
|
|
@ -1,16 +1,21 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
|
||||
from .tcp import SocketsPipe
|
||||
|
||||
|
||||
class TCPPipeSpawner:
|
||||
"""
|
||||
Creates bi-directional pipes between the configured client and other clients.
|
||||
"""
|
||||
|
||||
def __init__(self, target_addr: IPv4Address, target_port: int):
|
||||
self._target_addr = target_addr
|
||||
self._target_port = target_port
|
||||
self._pipes: List[SocketsPipe] = []
|
||||
|
||||
def spawn_pipe(self, source: socket.socket) -> SocketsPipe:
|
||||
def spawn_pipe(self, source: socket.socket):
|
||||
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
dest.connect((self._target_addr, self._target_port))
|
||||
|
@ -19,7 +24,11 @@ class TCPPipeSpawner:
|
|||
dest.close()
|
||||
raise err
|
||||
|
||||
return SocketsPipe(source, dest, client_data_received=self._client_data_received)
|
||||
# TODO: have SocketsPipe notify TCPPipeSpawner when it's done
|
||||
pipe = SocketsPipe(source, dest)
|
||||
self._pipes.append(pipe)
|
||||
pipe.run()
|
||||
|
||||
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
|
||||
self._client_data_received = callback
|
||||
def has_open_pipes(self) -> bool:
|
||||
self._pipes = [p for p in self._pipes if p.is_alive()]
|
||||
return len(self._pipes) > 0
|
||||
|
|
|
@ -1,18 +1,7 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Event, Lock, Thread
|
||||
from time import sleep, time
|
||||
from typing import List
|
||||
from time import sleep
|
||||
|
||||
from infection_monkey.network.relay import (
|
||||
RelayUser,
|
||||
RelayUserHandler,
|
||||
SocketsPipe,
|
||||
TCPConnectionHandler,
|
||||
TCPPipeSpawner,
|
||||
)
|
||||
|
||||
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
|
||||
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
|
||||
|
||||
|
||||
class TCPRelay(Thread):
|
||||
|
@ -25,20 +14,15 @@ class TCPRelay(Thread):
|
|||
relay_user_handler: RelayUserHandler,
|
||||
connection_handler: TCPConnectionHandler,
|
||||
pipe_spawner: TCPPipeSpawner,
|
||||
new_client_timeout: float = DEFAULT_NEW_CLIENT_TIMEOUT,
|
||||
):
|
||||
self._stopped = Event()
|
||||
|
||||
self._user_handler = relay_user_handler
|
||||
self._connection_handler = connection_handler
|
||||
self._connection_handler.notify_client_connected(self._user_connected)
|
||||
self._pipe_spawner = pipe_spawner
|
||||
self._pipe_spawner.notify_client_data_received(self._user_handler.on_user_data_received)
|
||||
self._new_client_timeout = new_client_timeout
|
||||
super().__init__(name="MonkeyTcpRelayThread")
|
||||
self.daemon = True
|
||||
self._lock = Lock()
|
||||
self._pipes: List[SocketsPipe] = []
|
||||
|
||||
def run(self):
|
||||
self._connection_handler.start()
|
||||
|
@ -48,32 +32,21 @@ class TCPRelay(Thread):
|
|||
|
||||
self._connection_handler.stop()
|
||||
self._connection_handler.join()
|
||||
|
||||
[pipe.join() for pipe in self._pipes]
|
||||
self._wait_for_pipes_to_close()
|
||||
|
||||
def stop(self):
|
||||
self._stopped.set()
|
||||
|
||||
def _user_connected(self, source: socket.socket, user_addr: IPv4Address):
|
||||
self._user_handler.add_relay_user(user_addr)
|
||||
self._spawn_pipe(source)
|
||||
|
||||
def _spawn_pipe(self, source: socket.socket):
|
||||
pipe = self._pipe_spawner.spawn_pipe(source)
|
||||
self._pipes.append(pipe)
|
||||
pipe.run()
|
||||
|
||||
def _wait_for_users_to_disconnect(self):
|
||||
stop = False
|
||||
while not stop:
|
||||
"""
|
||||
Blocks until the users disconnect or the timeout has elapsed.
|
||||
"""
|
||||
while self._user_handler.has_potential_users():
|
||||
sleep(0.01)
|
||||
current_time = time()
|
||||
potential_users = self._user_handler.get_potential_users()
|
||||
most_recent_potential_time = max(
|
||||
potential_users.values(),
|
||||
key=lambda ru: ru.last_update_time,
|
||||
default=RelayUser(IPv4Address(""), 0.0),
|
||||
).last_update_time
|
||||
potential_elapsed = current_time - most_recent_potential_time
|
||||
|
||||
stop = not potential_users or potential_elapsed > self._new_client_timeout
|
||||
def _wait_for_pipes_to_close(self):
|
||||
"""
|
||||
Blocks until the pipes have closed.
|
||||
"""
|
||||
while self._pipe_spawner.has_open_pipes():
|
||||
sleep(0.01)
|
||||
|
|
|
@ -1,23 +1,35 @@
|
|||
from ipaddress import IPv4Address
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from monkey.infection_monkey.network.relay import RelayUserHandler
|
||||
|
||||
|
||||
def test_potential_users_added():
|
||||
user_address = IPv4Address("0.0.0.0")
|
||||
handler = RelayUserHandler()
|
||||
|
||||
assert len(handler.get_potential_users()) == 0
|
||||
handler.add_potential_user(user_address)
|
||||
assert len(handler.get_potential_users()) == 1
|
||||
assert user_address in handler.get_potential_users()
|
||||
USER_ADDRESS = IPv4Address("0.0.0.0")
|
||||
|
||||
|
||||
def test_potential_user_removed_on_matching_user_added():
|
||||
user_address = IPv4Address("0.0.0.0")
|
||||
handler = RelayUserHandler()
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
return RelayUserHandler()
|
||||
|
||||
handler.add_potential_user(user_address)
|
||||
handler.add_relay_user(user_address)
|
||||
|
||||
assert len(handler.get_potential_users()) == 0
|
||||
def test_potential_users_added(handler):
|
||||
assert not handler.has_potential_users()
|
||||
handler.add_potential_user(USER_ADDRESS)
|
||||
assert handler.has_potential_users()
|
||||
|
||||
|
||||
def test_potential_user_removed_on_matching_user_added(handler):
|
||||
handler.add_potential_user(USER_ADDRESS)
|
||||
handler.add_relay_user(USER_ADDRESS)
|
||||
|
||||
assert not handler.has_potential_users()
|
||||
|
||||
|
||||
def test_potential_users_time_out():
|
||||
handler = RelayUserHandler(new_client_timeout=0.001)
|
||||
|
||||
handler.add_potential_user(USER_ADDRESS)
|
||||
sleep(0.003)
|
||||
|
||||
assert not handler.has_potential_users()
|
||||
|
|
|
@ -1,110 +0,0 @@
|
|||
import socket
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from monkey.infection_monkey.network.relay.relay_user_handler import ( # RELAY_CONTROL_MESSAGE,
|
||||
RelayUserHandler,
|
||||
)
|
||||
from monkey.infection_monkey.tcp_relay import TCPRelay
|
||||
|
||||
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
|
||||
LOCAL_PORT = 9975
|
||||
TARGET_ADDRESS = "0.0.0.0"
|
||||
TARGET_PORT = 9976
|
||||
|
||||
|
||||
class FakeConnectionHandler:
|
||||
def notify_client_connected(self, callback: Callable[[socket.socket, IPv4Address], None]):
|
||||
self.cb = callback
|
||||
|
||||
def client_connected(self, socket: socket.socket, addr: IPv4Address):
|
||||
self.cb(socket, addr)
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def join(self):
|
||||
pass
|
||||
|
||||
|
||||
class FakePipeSpawner:
|
||||
spawn_pipe = MagicMock()
|
||||
|
||||
def notify_client_data_received(self, callback: Callable[[bytes], bool]):
|
||||
self.cb = callback
|
||||
|
||||
def send_client_data(self, data: bytes):
|
||||
self.cb(data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relay_user_handler() -> RelayUserHandler:
|
||||
return RelayUserHandler()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipe_spawner():
|
||||
return FakePipeSpawner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_handler():
|
||||
return FakeConnectionHandler()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tcp_relay(relay_user_handler, connection_handler, pipe_spawner) -> TCPRelay:
|
||||
return TCPRelay(relay_user_handler, connection_handler, pipe_spawner)
|
||||
|
||||
|
||||
def join_or_kill_thread(thread: Thread, timeout: float):
|
||||
"""Whether or not the thread joined in the given timeout period."""
|
||||
thread.join(timeout)
|
||||
if thread.is_alive():
|
||||
# Cannot set daemon status of active thread: thread.daemon = True
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_user_added_when_user_connected(connection_handler, relay_user_handler, tcp_relay):
|
||||
# sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# connection_handler.client_connected(sock, NEW_USER_ADDRESS)
|
||||
# assert len(relay_user_handler.get_relay_users()) == 1
|
||||
pass
|
||||
|
||||
|
||||
def test_pipe_created_when_user_connected(connection_handler, pipe_spawner, tcp_relay):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
connection_handler.client_connected(sock, NEW_USER_ADDRESS)
|
||||
assert pipe_spawner.spawn_pipe.called
|
||||
|
||||
|
||||
def test_user_removed_on_request(relay_user_handler, pipe_spawner, tcp_relay):
|
||||
relay_user_handler.add_relay_user(NEW_USER_ADDRESS)
|
||||
|
||||
# pipe_spawner.send_client_data(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
|
||||
|
||||
# users = relay_user_handler.get_relay_users()
|
||||
# assert len(users) == 0
|
||||
pass
|
||||
|
||||
|
||||
# This will fail unless TcpProxy is updated to do non-blocking accepts
|
||||
# @pytest.mark.slow
|
||||
# def test_waits_for_exploited_machines():
|
||||
# relay = TCPRelay(9975, "0.0.0.0", 9976, new_client_timeout=0.2)
|
||||
# new_user = "0.0.0.1"
|
||||
# relay.start()
|
||||
|
||||
# relay.add_potential_user(new_user)
|
||||
# relay.stop()
|
||||
|
||||
# assert not join_or_kill_thread(relay, 0.1) # Should be waiting
|
||||
# assert join_or_kill_thread(relay, 1) # Should be done waiting
|
Loading…
Reference in New Issue