From 066947c59f0bee20e028fc5729cf073e7a6d083a Mon Sep 17 00:00:00 2001 From: Kekoa Kaaikala Date: Tue, 6 Sep 2022 18:03:01 +0000 Subject: [PATCH] Agent: Remove closed pipes from TCPPipeSpawner --- .../network/relay/sockets_pipe.py | 61 +++++++++++-------- .../network/relay/tcp_pipe_spawner.py | 31 +++++++--- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/monkey/infection_monkey/network/relay/sockets_pipe.py b/monkey/infection_monkey/network/relay/sockets_pipe.py index 0b33cb533..ff8e82c04 100644 --- a/monkey/infection_monkey/network/relay/sockets_pipe.py +++ b/monkey/infection_monkey/network/relay/sockets_pipe.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import select from logging import getLogger from threading import Thread @@ -14,37 +16,46 @@ class SocketsPipe(Thread): self, source, dest, + pipe_closed: Callable[[SocketsPipe], None], timeout=SOCKET_READ_TIMEOUT, - client_disconnected: Callable[[str], None] = None, ): self.source = source self.dest = dest self.timeout = timeout super().__init__(name=f"SocketsPipeThread-{self.ident}", daemon=True) - self._client_disconnected = client_disconnected + self._pipe_closed = pipe_closed + + def _pipe(self): + sockets = [self.source, self.dest] + while True: + # TODO: Figure out how to capture when the socket times out. + read_list, _, except_list = select.select(sockets, [], sockets, self.timeout) + if except_list: + raise Exception("select() failed") + + if not read_list: + raise TimeoutError("") + + for r in read_list: + other = self.dest if r is self.source else self.source + data = r.recv(READ_BUFFER_SIZE) + if data: + other.sendall(data) def run(self): - sockets = [self.source, self.dest] - keep_connection = True - while keep_connection: - keep_connection = False - rlist, _, xlist = select.select(sockets, [], sockets, self.timeout) - if xlist: - break - for r in rlist: - other = self.dest if r is self.source else self.source - try: - data = r.recv(READ_BUFFER_SIZE) - except Exception: - break - if data: - try: - other.sendall(data) - except Exception: - break - keep_connection = True + try: + self._pipe() + except Exception as err: + logger.debug(err) - self.source.close() - self.dest.close() - if self._client_disconnected: - self._client_disconnected() + try: + self.source.close() + except Exception as err: + logger.debug(f"Error while closing source socket: {err}") + + try: + self.dest.close() + except Exception as err: + logger.debug(f"Error while closing destination socket: {err}") + + self._pipe_closed(self) diff --git a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py index 22183dd6c..5ffebaaea 100644 --- a/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py +++ b/monkey/infection_monkey/network/relay/tcp_pipe_spawner.py @@ -1,6 +1,7 @@ import socket from ipaddress import IPv4Address -from typing import List +from threading import Lock +from typing import Set from .sockets_pipe import SocketsPipe @@ -13,9 +14,16 @@ class TCPPipeSpawner: def __init__(self, target_addr: IPv4Address, target_port: int): self._target_addr = target_addr self._target_port = target_port - self._pipes: List[SocketsPipe] = [] + self._pipes: Set[SocketsPipe] = set() + self._lock = Lock() def spawn_pipe(self, source: socket.socket): + """ + Attempt to create a pipe on between the configured client and the provided socket + + :param source: A socket to the connecting client. + :raises socket.error: If a socket to the configured client could not be created. + """ dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: dest.connect((self._target_addr, self._target_port)) @@ -24,11 +32,20 @@ class TCPPipeSpawner: dest.close() raise err - # TODO: have SocketsPipe notify TCPPipeSpawner when it's done - pipe = SocketsPipe(source, dest) - self._pipes.append(pipe) + pipe = SocketsPipe(source, dest, self._handle_pipe_closed) + with self._lock: + self._pipes.add(pipe) pipe.run() def has_open_pipes(self) -> bool: - self._pipes = [p for p in self._pipes if p.is_alive()] - return len(self._pipes) > 0 + """Return whether or not the TCPPipeSpawner has any open pipes.""" + with self._lock: + for p in self._pipes: + if p.is_alive(): + return True + + return False + + def _handle_pipe_closed(self, pipe: SocketsPipe): + with self._lock: + self._pipes.discard(pipe)