diff --git a/monkey/infection_monkey/network/relay/tcp_connection_handler.py b/monkey/infection_monkey/network/relay/tcp_connection_handler.py index d8c7b8337..4515e364e 100644 --- a/monkey/infection_monkey/network/relay/tcp_connection_handler.py +++ b/monkey/infection_monkey/network/relay/tcp_connection_handler.py @@ -1,11 +1,13 @@ import socket -from threading import Event, Thread +from threading import Thread from typing import Callable, List +from infection_monkey.utils.threading import InterruptableThreadMixin + PROXY_TIMEOUT = 2.5 -class TCPConnectionHandler(Thread): +class TCPConnectionHandler(Thread, InterruptableThreadMixin): """Accepts connections on a TCP socket.""" def __init__( @@ -18,7 +20,6 @@ class TCPConnectionHandler(Thread): self.local_host = bind_host self._client_connected = client_connected super().__init__(name="TCPConnectionHandler", daemon=True) - self._stopped = Event() def run(self): l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -26,7 +27,7 @@ class TCPConnectionHandler(Thread): l_socket.settimeout(PROXY_TIMEOUT) l_socket.listen(5) - while not self._stopped.is_set(): + while not self._interrupted.is_set(): try: source, _ = l_socket.accept() except socket.timeout: @@ -36,6 +37,3 @@ class TCPConnectionHandler(Thread): notify_client_connected(source) l_socket.close() - - def stop(self): - self._stopped.set() diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py index abe7ae56a..a12d8a0d9 100644 --- a/monkey/infection_monkey/tcp_relay.py +++ b/monkey/infection_monkey/tcp_relay.py @@ -1,10 +1,11 @@ -from threading import Event, Lock, Thread +from threading import Lock, Thread from time import sleep from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner +from infection_monkey.utils.threading import InterruptableThreadMixin -class TCPRelay(Thread): +class TCPRelay(Thread, InterruptableThreadMixin): """ Provides and manages a TCP proxy connection. """ @@ -15,8 +16,6 @@ class TCPRelay(Thread): connection_handler: TCPConnectionHandler, pipe_spawner: TCPPipeSpawner, ): - self._stopped = Event() - self._user_handler = relay_user_handler self._connection_handler = connection_handler self._pipe_spawner = pipe_spawner @@ -26,16 +25,13 @@ class TCPRelay(Thread): def run(self): self._connection_handler.start() - self._stopped.wait() + self._interrupted.wait() self._wait_for_users_to_disconnect() self._connection_handler.stop() self._connection_handler.join() self._wait_for_pipes_to_close() - def stop(self): - self._stopped.set() - def _wait_for_users_to_disconnect(self): """ Blocks until the users disconnect or the timeout has elapsed. diff --git a/monkey/infection_monkey/utils/threading.py b/monkey/infection_monkey/utils/threading.py index 0443978e6..be28aa0b1 100644 --- a/monkey/infection_monkey/utils/threading.py +++ b/monkey/infection_monkey/utils/threading.py @@ -107,3 +107,12 @@ def interruptible_function(*, msg: Optional[str] = None, default_return_value: A return _wrapper return _decorator + + +class InterruptableThreadMixin: + def __init__(self): + self._interrupted = Event() + + def stop(self): + """Stop a running thread.""" + self._interrupted.set()