Agent: Track relay users in TCPRelay

This commit is contained in:
Kekoa Kaaikala 2022-08-30 20:53:16 +00:00
parent 79d5b8bed1
commit 9425a9463a
3 changed files with 97 additions and 4 deletions

View File

@ -1,9 +1,16 @@
from threading import Event, Thread from dataclasses import dataclass
from threading import Event, Lock, Thread
from time import sleep from time import sleep
from typing import List
from infection_monkey.transport.tcp import TcpProxy from infection_monkey.transport.tcp import TcpProxy
@dataclass
class RelayUser:
address: str
class TCPRelay(Thread): class TCPRelay(Thread):
"""Provides and manages a TCP proxy connection.""" """Provides and manages a TCP proxy connection."""
@ -14,10 +21,16 @@ class TCPRelay(Thread):
self._target_port = target_port self._target_port = target_port
super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread") super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
self.daemon = True self.daemon = True
self._relay_users: List[RelayUser] = []
self._lock = Lock()
def run(self): def run(self):
proxy = TcpProxy( proxy = TcpProxy(
local_port=self._local_port, dest_host=self._target_addr, dest_port=self._target_port local_port=self._local_port,
dest_host=self._target_addr,
dest_port=self._target_port,
client_connected=self.on_user_connected,
client_disconnected=self.on_user_disconnected,
) )
proxy.start() proxy.start()
@ -29,3 +42,15 @@ class TCPRelay(Thread):
def stop(self): def stop(self):
self._stopped.set() self._stopped.set()
def on_user_connected(self, user: str):
with self._lock:
self._relay_users.append(RelayUser(user))
def on_user_disconnected(self, user: str):
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user]
def relay_users(self) -> List[RelayUser]:
with self._lock:
return self._relay_users.copy()

View File

@ -1,7 +1,9 @@
import select import select
import socket import socket
from functools import partial
from logging import getLogger from logging import getLogger
from threading import Thread from threading import Thread
from typing import Callable
from infection_monkey.transport.base import ( from infection_monkey.transport.base import (
PROXY_TIMEOUT, PROXY_TIMEOUT,
@ -16,7 +18,13 @@ logger = getLogger(__name__)
class SocketsPipe(Thread): class SocketsPipe(Thread):
def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT): def __init__(
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
):
Thread.__init__(self) Thread.__init__(self)
self.source = source self.source = source
self.dest = dest self.dest = dest
@ -24,6 +32,7 @@ class SocketsPipe(Thread):
self._keep_connection = True self._keep_connection = True
super(SocketsPipe, self).__init__() super(SocketsPipe, self).__init__()
self.daemon = True self.daemon = True
self._client_disconnected = client_disconnected
def run(self): def run(self):
sockets = [self.source, self.dest] sockets = [self.source, self.dest]
@ -48,9 +57,24 @@ class SocketsPipe(Thread):
self.source.close() self.source.close()
self.dest.close() self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase): class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
self._client_disconnected = client_disconnected
def run(self): def run(self):
pipes = [] pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -72,7 +96,10 @@ class TcpProxy(TransportProxyBase):
dest.close() dest.close()
continue continue
pipe = SocketsPipe(source, dest) on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
pipe = SocketsPipe(source, dest, on_disconnect)
pipes.append(pipe) pipes.append(pipe)
logger.debug( logger.debug(
"piping sockets %s:%s->%s:%s", "piping sockets %s:%s->%s:%s",
@ -81,6 +108,8 @@ class TcpProxy(TransportProxyBase):
self.dest_host, self.dest_host,
self.dest_port, self.dest_port,
) )
if self._client_connected:
self._client_connected(address[0])
pipe.start() pipe.start()
l_socket.close() l_socket.close()

View File

@ -0,0 +1,39 @@
from threading import Thread
from monkey.infection_monkey.tcp_relay import TCPRelay
def join_or_kill_thread(thread: Thread, timeout: float):
thread.join(timeout)
if thread.is_alive():
thread.daemon = True
return False
return True
def test_stops():
relay = TCPRelay(9975, "0.0.0.0", 9976)
relay.start()
relay.stop()
assert join_or_kill_thread(relay, 0.1)
def test_user_added():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
users = relay.relay_users()
assert len(users) == 1
assert users[0].address == new_user
def test_user_removed():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_disconnected(new_user)
users = relay.relay_users()
assert len(users) == 0