From 9425a9463a8ddfeab2248a938850dfcac0df1fdd Mon Sep 17 00:00:00 2001
From: Kekoa Kaaikala <kekoa.kaaikala@gmail.com>
Date: Tue, 30 Aug 2022 20:53:16 +0000
Subject: [PATCH] Agent: Track relay users in TCPRelay

---
 monkey/infection_monkey/tcp_relay.py          | 29 +++++++++++++-
 monkey/infection_monkey/transport/tcp.py      | 33 +++++++++++++++-
 .../infection_monkey/test_tcp_relay.py        | 39 +++++++++++++++++++
 3 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py

diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py
index 23a6cb843..ed8340dd7 100644
--- a/monkey/infection_monkey/tcp_relay.py
+++ b/monkey/infection_monkey/tcp_relay.py
@@ -1,9 +1,16 @@
-from threading import Event, Thread
+from dataclasses import dataclass
+from threading import Event, Lock, Thread
 from time import sleep
+from typing import List
 
 from infection_monkey.transport.tcp import TcpProxy
 
 
+@dataclass
+class RelayUser:
+    address: str
+
+
 class TCPRelay(Thread):
     """Provides and manages a TCP proxy connection."""
 
@@ -14,10 +21,16 @@ class TCPRelay(Thread):
         self._target_port = target_port
         super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
         self.daemon = True
+        self._relay_users: List[RelayUser] = []
+        self._lock = Lock()
 
     def run(self):
         proxy = TcpProxy(
-            local_port=self._local_port, dest_host=self._target_addr, dest_port=self._target_port
+            local_port=self._local_port,
+            dest_host=self._target_addr,
+            dest_port=self._target_port,
+            client_connected=self.on_user_connected,
+            client_disconnected=self.on_user_disconnected,
         )
         proxy.start()
 
@@ -29,3 +42,15 @@ class TCPRelay(Thread):
 
     def stop(self):
         self._stopped.set()
+
+    def on_user_connected(self, user: str):
+        with self._lock:
+            self._relay_users.append(RelayUser(user))
+
+    def on_user_disconnected(self, user: str):
+        with self._lock:
+            self._relay_users = [u for u in self._relay_users if u.address != user]
+
+    def relay_users(self) -> List[RelayUser]:
+        with self._lock:
+            return self._relay_users.copy()
diff --git a/monkey/infection_monkey/transport/tcp.py b/monkey/infection_monkey/transport/tcp.py
index 83c631c3b..637d095d0 100644
--- a/monkey/infection_monkey/transport/tcp.py
+++ b/monkey/infection_monkey/transport/tcp.py
@@ -1,7 +1,9 @@
 import select
 import socket
+from functools import partial
 from logging import getLogger
 from threading import Thread
+from typing import Callable
 
 from infection_monkey.transport.base import (
     PROXY_TIMEOUT,
@@ -16,7 +18,13 @@ logger = getLogger(__name__)
 
 
 class SocketsPipe(Thread):
-    def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT):
+    def __init__(
+        self,
+        source,
+        dest,
+        timeout=SOCKET_READ_TIMEOUT,
+        client_disconnected: Callable[[str], None] = None,
+    ):
         Thread.__init__(self)
         self.source = source
         self.dest = dest
@@ -24,6 +32,7 @@ class SocketsPipe(Thread):
         self._keep_connection = True
         super(SocketsPipe, self).__init__()
         self.daemon = True
+        self._client_disconnected = client_disconnected
 
     def run(self):
         sockets = [self.source, self.dest]
@@ -48,9 +57,24 @@ class SocketsPipe(Thread):
 
         self.source.close()
         self.dest.close()
+        if self._client_disconnected:
+            self._client_disconnected()
 
 
 class TcpProxy(TransportProxyBase):
+    def __init__(
+        self,
+        local_port,
+        dest_host=None,
+        dest_port=None,
+        local_host="",
+        client_connected: Callable[[str], None] = None,
+        client_disconnected: Callable[[str], None] = None,
+    ):
+        super().__init__(local_port, dest_host, dest_port, local_host)
+        self._client_connected = client_connected
+        self._client_disconnected = client_disconnected
+
     def run(self):
         pipes = []
         l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -72,7 +96,10 @@ class TcpProxy(TransportProxyBase):
                 dest.close()
                 continue
 
-            pipe = SocketsPipe(source, dest)
+            on_disconnect = (
+                partial(self._client_connected, address[0]) if self._client_connected else None
+            )
+            pipe = SocketsPipe(source, dest, on_disconnect)
             pipes.append(pipe)
             logger.debug(
                 "piping sockets %s:%s->%s:%s",
@@ -81,6 +108,8 @@ class TcpProxy(TransportProxyBase):
                 self.dest_host,
                 self.dest_port,
             )
+            if self._client_connected:
+                self._client_connected(address[0])
             pipe.start()
 
         l_socket.close()
diff --git a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py
new file mode 100644
index 000000000..4c0dc2bc9
--- /dev/null
+++ b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py
@@ -0,0 +1,39 @@
+from threading import Thread
+
+from monkey.infection_monkey.tcp_relay import TCPRelay
+
+
+def join_or_kill_thread(thread: Thread, timeout: float):
+    thread.join(timeout)
+    if thread.is_alive():
+        thread.daemon = True
+        return False
+    return True
+
+
+def test_stops():
+    relay = TCPRelay(9975, "0.0.0.0", 9976)
+    relay.start()
+    relay.stop()
+
+    assert join_or_kill_thread(relay, 0.1)
+
+
+def test_user_added():
+    relay = TCPRelay(9975, "0.0.0.0", 9976)
+    new_user = "0.0.0.1"
+    relay.on_user_connected(new_user)
+
+    users = relay.relay_users()
+    assert len(users) == 1
+    assert users[0].address == new_user
+
+
+def test_user_removed():
+    relay = TCPRelay(9975, "0.0.0.0", 9976)
+    new_user = "0.0.0.1"
+    relay.on_user_connected(new_user)
+    relay.on_user_disconnected(new_user)
+
+    users = relay.relay_users()
+    assert len(users) == 0