Agent: Track relay users in TCPRelay

This commit is contained in:
Kekoa Kaaikala 2022-08-30 20:53:16 +00:00
parent 79d5b8bed1
commit 9425a9463a
3 changed files with 97 additions and 4 deletions

View File

@ -1,9 +1,16 @@
from threading import Event, Thread
from dataclasses import dataclass
from threading import Event, Lock, Thread
from time import sleep
from typing import List
from infection_monkey.transport.tcp import TcpProxy
@dataclass
class RelayUser:
address: str
class TCPRelay(Thread):
"""Provides and manages a TCP proxy connection."""
@ -14,10 +21,16 @@ class TCPRelay(Thread):
self._target_port = target_port
super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
self.daemon = True
self._relay_users: List[RelayUser] = []
self._lock = Lock()
def run(self):
proxy = TcpProxy(
local_port=self._local_port, dest_host=self._target_addr, dest_port=self._target_port
local_port=self._local_port,
dest_host=self._target_addr,
dest_port=self._target_port,
client_connected=self.on_user_connected,
client_disconnected=self.on_user_disconnected,
)
proxy.start()
@ -29,3 +42,15 @@ class TCPRelay(Thread):
def stop(self):
self._stopped.set()
def on_user_connected(self, user: str):
with self._lock:
self._relay_users.append(RelayUser(user))
def on_user_disconnected(self, user: str):
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user]
def relay_users(self) -> List[RelayUser]:
with self._lock:
return self._relay_users.copy()

View File

@ -1,7 +1,9 @@
import select
import socket
from functools import partial
from logging import getLogger
from threading import Thread
from typing import Callable
from infection_monkey.transport.base import (
PROXY_TIMEOUT,
@ -16,7 +18,13 @@ logger = getLogger(__name__)
class SocketsPipe(Thread):
def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT):
def __init__(
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
):
Thread.__init__(self)
self.source = source
self.dest = dest
@ -24,6 +32,7 @@ class SocketsPipe(Thread):
self._keep_connection = True
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
def run(self):
sockets = [self.source, self.dest]
@ -48,9 +57,24 @@ class SocketsPipe(Thread):
self.source.close()
self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
self._client_disconnected = client_disconnected
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -72,7 +96,10 @@ class TcpProxy(TransportProxyBase):
dest.close()
continue
pipe = SocketsPipe(source, dest)
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
pipe = SocketsPipe(source, dest, on_disconnect)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
@ -81,6 +108,8 @@ class TcpProxy(TransportProxyBase):
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()

View File

@ -0,0 +1,39 @@
from threading import Thread
from monkey.infection_monkey.tcp_relay import TCPRelay
def join_or_kill_thread(thread: Thread, timeout: float):
thread.join(timeout)
if thread.is_alive():
thread.daemon = True
return False
return True
def test_stops():
relay = TCPRelay(9975, "0.0.0.0", 9976)
relay.start()
relay.stop()
assert join_or_kill_thread(relay, 0.1)
def test_user_added():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
users = relay.relay_users()
assert len(users) == 1
assert users[0].address == new_user
def test_user_removed():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_disconnected(new_user)
users = relay.relay_users()
assert len(users) == 0