forked from p15670423/monkey
Agent: Track relay users in TCPRelay
This commit is contained in:
parent
79d5b8bed1
commit
9425a9463a
|
@ -1,9 +1,16 @@
|
|||
from threading import Event, Thread
|
||||
from dataclasses import dataclass
|
||||
from threading import Event, Lock, Thread
|
||||
from time import sleep
|
||||
from typing import List
|
||||
|
||||
from infection_monkey.transport.tcp import TcpProxy
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayUser:
|
||||
address: str
|
||||
|
||||
|
||||
class TCPRelay(Thread):
|
||||
"""Provides and manages a TCP proxy connection."""
|
||||
|
||||
|
@ -14,10 +21,16 @@ class TCPRelay(Thread):
|
|||
self._target_port = target_port
|
||||
super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
|
||||
self.daemon = True
|
||||
self._relay_users: List[RelayUser] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def run(self):
|
||||
proxy = TcpProxy(
|
||||
local_port=self._local_port, dest_host=self._target_addr, dest_port=self._target_port
|
||||
local_port=self._local_port,
|
||||
dest_host=self._target_addr,
|
||||
dest_port=self._target_port,
|
||||
client_connected=self.on_user_connected,
|
||||
client_disconnected=self.on_user_disconnected,
|
||||
)
|
||||
proxy.start()
|
||||
|
||||
|
@ -29,3 +42,15 @@ class TCPRelay(Thread):
|
|||
|
||||
def stop(self):
|
||||
self._stopped.set()
|
||||
|
||||
def on_user_connected(self, user: str):
|
||||
with self._lock:
|
||||
self._relay_users.append(RelayUser(user))
|
||||
|
||||
def on_user_disconnected(self, user: str):
|
||||
with self._lock:
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user]
|
||||
|
||||
def relay_users(self) -> List[RelayUser]:
|
||||
with self._lock:
|
||||
return self._relay_users.copy()
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import select
|
||||
import socket
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from infection_monkey.transport.base import (
|
||||
PROXY_TIMEOUT,
|
||||
|
@ -16,7 +18,13 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class SocketsPipe(Thread):
|
||||
def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT):
|
||||
def __init__(
|
||||
self,
|
||||
source,
|
||||
dest,
|
||||
timeout=SOCKET_READ_TIMEOUT,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
):
|
||||
Thread.__init__(self)
|
||||
self.source = source
|
||||
self.dest = dest
|
||||
|
@ -24,6 +32,7 @@ class SocketsPipe(Thread):
|
|||
self._keep_connection = True
|
||||
super(SocketsPipe, self).__init__()
|
||||
self.daemon = True
|
||||
self._client_disconnected = client_disconnected
|
||||
|
||||
def run(self):
|
||||
sockets = [self.source, self.dest]
|
||||
|
@ -48,9 +57,24 @@ class SocketsPipe(Thread):
|
|||
|
||||
self.source.close()
|
||||
self.dest.close()
|
||||
if self._client_disconnected:
|
||||
self._client_disconnected()
|
||||
|
||||
|
||||
class TcpProxy(TransportProxyBase):
|
||||
def __init__(
|
||||
self,
|
||||
local_port,
|
||||
dest_host=None,
|
||||
dest_port=None,
|
||||
local_host="",
|
||||
client_connected: Callable[[str], None] = None,
|
||||
client_disconnected: Callable[[str], None] = None,
|
||||
):
|
||||
super().__init__(local_port, dest_host, dest_port, local_host)
|
||||
self._client_connected = client_connected
|
||||
self._client_disconnected = client_disconnected
|
||||
|
||||
def run(self):
|
||||
pipes = []
|
||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
@ -72,7 +96,10 @@ class TcpProxy(TransportProxyBase):
|
|||
dest.close()
|
||||
continue
|
||||
|
||||
pipe = SocketsPipe(source, dest)
|
||||
on_disconnect = (
|
||||
partial(self._client_connected, address[0]) if self._client_connected else None
|
||||
)
|
||||
pipe = SocketsPipe(source, dest, on_disconnect)
|
||||
pipes.append(pipe)
|
||||
logger.debug(
|
||||
"piping sockets %s:%s->%s:%s",
|
||||
|
@ -81,6 +108,8 @@ class TcpProxy(TransportProxyBase):
|
|||
self.dest_host,
|
||||
self.dest_port,
|
||||
)
|
||||
if self._client_connected:
|
||||
self._client_connected(address[0])
|
||||
pipe.start()
|
||||
|
||||
l_socket.close()
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
from threading import Thread
|
||||
|
||||
from monkey.infection_monkey.tcp_relay import TCPRelay
|
||||
|
||||
|
||||
def join_or_kill_thread(thread: Thread, timeout: float):
|
||||
thread.join(timeout)
|
||||
if thread.is_alive():
|
||||
thread.daemon = True
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_stops():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
relay.start()
|
||||
relay.stop()
|
||||
|
||||
assert join_or_kill_thread(relay, 0.1)
|
||||
|
||||
|
||||
def test_user_added():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
|
||||
users = relay.relay_users()
|
||||
assert len(users) == 1
|
||||
assert users[0].address == new_user
|
||||
|
||||
|
||||
def test_user_removed():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
relay.on_user_disconnected(new_user)
|
||||
|
||||
users = relay.relay_users()
|
||||
assert len(users) == 0
|
Loading…
Reference in New Issue