diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py index 23a6cb843..ed8340dd7 100644 --- a/monkey/infection_monkey/tcp_relay.py +++ b/monkey/infection_monkey/tcp_relay.py @@ -1,9 +1,16 @@ -from threading import Event, Thread +from dataclasses import dataclass +from threading import Event, Lock, Thread from time import sleep +from typing import List from infection_monkey.transport.tcp import TcpProxy +@dataclass +class RelayUser: + address: str + + class TCPRelay(Thread): """Provides and manages a TCP proxy connection.""" @@ -14,10 +21,16 @@ class TCPRelay(Thread): self._target_port = target_port super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread") self.daemon = True + self._relay_users: List[RelayUser] = [] + self._lock = Lock() def run(self): proxy = TcpProxy( - local_port=self._local_port, dest_host=self._target_addr, dest_port=self._target_port + local_port=self._local_port, + dest_host=self._target_addr, + dest_port=self._target_port, + client_connected=self.on_user_connected, + client_disconnected=self.on_user_disconnected, ) proxy.start() @@ -29,3 +42,15 @@ class TCPRelay(Thread): def stop(self): self._stopped.set() + + def on_user_connected(self, user: str): + with self._lock: + self._relay_users.append(RelayUser(user)) + + def on_user_disconnected(self, user: str): + with self._lock: + self._relay_users = [u for u in self._relay_users if u.address != user] + + def relay_users(self) -> List[RelayUser]: + with self._lock: + return self._relay_users.copy() diff --git a/monkey/infection_monkey/transport/tcp.py b/monkey/infection_monkey/transport/tcp.py index 83c631c3b..637d095d0 100644 --- a/monkey/infection_monkey/transport/tcp.py +++ b/monkey/infection_monkey/transport/tcp.py @@ -1,7 +1,9 @@ import select import socket +from functools import partial from logging import getLogger from threading import Thread +from typing import Callable from infection_monkey.transport.base import ( PROXY_TIMEOUT, @@ -16,7 +18,13 @@ logger = getLogger(__name__) class SocketsPipe(Thread): - def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT): + def __init__( + self, + source, + dest, + timeout=SOCKET_READ_TIMEOUT, + client_disconnected: Callable[[str], None] = None, + ): Thread.__init__(self) self.source = source self.dest = dest @@ -24,6 +32,7 @@ class SocketsPipe(Thread): self._keep_connection = True super(SocketsPipe, self).__init__() self.daemon = True + self._client_disconnected = client_disconnected def run(self): sockets = [self.source, self.dest] @@ -48,9 +57,24 @@ class SocketsPipe(Thread): self.source.close() self.dest.close() + if self._client_disconnected: + self._client_disconnected() class TcpProxy(TransportProxyBase): + def __init__( + self, + local_port, + dest_host=None, + dest_port=None, + local_host="", + client_connected: Callable[[str], None] = None, + client_disconnected: Callable[[str], None] = None, + ): + super().__init__(local_port, dest_host, dest_port, local_host) + self._client_connected = client_connected + self._client_disconnected = client_disconnected + def run(self): pipes = [] l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -72,7 +96,10 @@ class TcpProxy(TransportProxyBase): dest.close() continue - pipe = SocketsPipe(source, dest) + on_disconnect = ( + partial(self._client_connected, address[0]) if self._client_connected else None + ) + pipe = SocketsPipe(source, dest, on_disconnect) pipes.append(pipe) logger.debug( "piping sockets %s:%s->%s:%s", @@ -81,6 +108,8 @@ class TcpProxy(TransportProxyBase): self.dest_host, self.dest_port, ) + if self._client_connected: + self._client_connected(address[0]) pipe.start() l_socket.close() diff --git a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py new file mode 100644 index 000000000..4c0dc2bc9 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py @@ -0,0 +1,39 @@ +from threading import Thread + +from monkey.infection_monkey.tcp_relay import TCPRelay + + +def join_or_kill_thread(thread: Thread, timeout: float): + thread.join(timeout) + if thread.is_alive(): + thread.daemon = True + return False + return True + + +def test_stops(): + relay = TCPRelay(9975, "0.0.0.0", 9976) + relay.start() + relay.stop() + + assert join_or_kill_thread(relay, 0.1) + + +def test_user_added(): + relay = TCPRelay(9975, "0.0.0.0", 9976) + new_user = "0.0.0.1" + relay.on_user_connected(new_user) + + users = relay.relay_users() + assert len(users) == 1 + assert users[0].address == new_user + + +def test_user_removed(): + relay = TCPRelay(9975, "0.0.0.0", 9976) + new_user = "0.0.0.1" + relay.on_user_connected(new_user) + relay.on_user_disconnected(new_user) + + users = relay.relay_users() + assert len(users) == 0