Agent: Remove closed pipes from TCPPipeSpawner

This commit is contained in:
Kekoa Kaaikala 2022-09-06 18:03:01 +00:00
parent 83cc5fc336
commit 066947c59f
2 changed files with 60 additions and 32 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import select import select
from logging import getLogger from logging import getLogger
from threading import Thread from threading import Thread
@ -14,37 +16,46 @@ class SocketsPipe(Thread):
self, self,
source, source,
dest, dest,
pipe_closed: Callable[[SocketsPipe], None],
timeout=SOCKET_READ_TIMEOUT, timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
): ):
self.source = source self.source = source
self.dest = dest self.dest = dest
self.timeout = timeout self.timeout = timeout
super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True) super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True)
self._client_disconnected = client_disconnected self._pipe_closed = pipe_closed
def _pipe(self):
sockets = [self.source, self.dest]
while True:
# TODO: Figure out how to capture when the socket times out.
read_list, _, except_list = select.select(sockets, [], sockets, self.timeout)
if except_list:
raise Exception("select() failed")
if not read_list:
raise TimeoutError("")
for r in read_list:
other = self.dest if r is self.source else self.source
data = r.recv(READ_BUFFER_SIZE)
if data:
other.sendall(data)
def run(self): def run(self):
sockets = [self.source, self.dest]
keep_connection = True
while keep_connection:
keep_connection = False
rlist, _, xlist = select.select(sockets, [], sockets, self.timeout)
if xlist:
break
for r in rlist:
other = self.dest if r is self.source else self.source
try: try:
data = r.recv(READ_BUFFER_SIZE) self._pipe()
except Exception: except Exception as err:
break logger.debug(err)
if data:
try:
other.sendall(data)
except Exception:
break
keep_connection = True
try:
self.source.close() self.source.close()
except Exception as err:
logger.debug(f"Error while closing source socket: {err}")
try:
self.dest.close() self.dest.close()
if self._client_disconnected: except Exception as err:
self._client_disconnected() logger.debug(f"Error while closing destination socket: {err}")
self._pipe_closed(self)

View File

@ -1,6 +1,7 @@
import socket import socket
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import List from threading import Lock
from typing import Set
from .sockets_pipe import SocketsPipe from .sockets_pipe import SocketsPipe
@ -13,9 +14,16 @@ class TCPPipeSpawner:
def __init__(self, target_addr: IPv4Address, target_port: int): def __init__(self, target_addr: IPv4Address, target_port: int):
self._target_addr = target_addr self._target_addr = target_addr
self._target_port = target_port self._target_port = target_port
self._pipes: List[SocketsPipe] = [] self._pipes: Set[SocketsPipe] = set()
self._lock = Lock()
def spawn_pipe(self, source: socket.socket): def spawn_pipe(self, source: socket.socket):
"""
Attempt to create a pipe on between the configured client and the provided socket
:param source: A socket to the connecting client.
:raises socket.error: If a socket to the configured client could not be created.
"""
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try: try:
dest.connect((self._target_addr, self._target_port)) dest.connect((self._target_addr, self._target_port))
@ -24,11 +32,20 @@ class TCPPipeSpawner:
dest.close() dest.close()
raise err raise err
# TODO: have SocketsPipe notify TCPPipeSpawner when it's done pipe = SocketsPipe(source, dest, self._handle_pipe_closed)
pipe = SocketsPipe(source, dest) with self._lock:
self._pipes.append(pipe) self._pipes.add(pipe)
pipe.run() pipe.run()
def has_open_pipes(self) -> bool: def has_open_pipes(self) -> bool:
self._pipes = [p for p in self._pipes if p.is_alive()] """Return whether or not the TCPPipeSpawner has any open pipes."""
return len(self._pipes) > 0 with self._lock:
for p in self._pipes:
if p.is_alive():
return True
return False
def _handle_pipe_closed(self, pipe: SocketsPipe):
with self._lock:
self._pipes.discard(pipe)