Agent: Add InterruptableThreadMixin

This commit is contained in:
Kekoa Kaaikala 2022-09-06 18:26:45 +00:00
parent 066947c59f
commit 65f4edc625
3 changed files with 18 additions and 15 deletions

View File

@ -1,11 +1,13 @@
import socket
from threading import Event, Thread
from threading import Thread
from typing import Callable, List
from infection_monkey.utils.threading import InterruptableThreadMixin
PROXY_TIMEOUT = 2.5
class TCPConnectionHandler(Thread):
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
"""Accepts connections on a TCP socket."""
def __init__(
@ -18,7 +20,6 @@ class TCPConnectionHandler(Thread):
self.local_host = bind_host
self._client_connected = client_connected
super().__init__(name="TCPConnectionHandler", daemon=True)
self._stopped = Event()
def run(self):
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -26,7 +27,7 @@ class TCPConnectionHandler(Thread):
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._stopped.is_set():
while not self._interrupted.is_set():
try:
source, _ = l_socket.accept()
except socket.timeout:
@ -36,6 +37,3 @@ class TCPConnectionHandler(Thread):
notify_client_connected(source)
l_socket.close()
def stop(self):
self._stopped.set()

View File

@ -1,10 +1,11 @@
from threading import Event, Lock, Thread
from threading import Lock, Thread
from time import sleep
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
from infection_monkey.utils.threading import InterruptableThreadMixin
class TCPRelay(Thread):
class TCPRelay(Thread, InterruptableThreadMixin):
"""
Provides and manages a TCP proxy connection.
"""
@ -15,8 +16,6 @@ class TCPRelay(Thread):
connection_handler: TCPConnectionHandler,
pipe_spawner: TCPPipeSpawner,
):
self._stopped = Event()
self._user_handler = relay_user_handler
self._connection_handler = connection_handler
self._pipe_spawner = pipe_spawner
@ -26,16 +25,13 @@ class TCPRelay(Thread):
def run(self):
self._connection_handler.start()
self._stopped.wait()
self._interrupted.wait()
self._wait_for_users_to_disconnect()
self._connection_handler.stop()
self._connection_handler.join()
self._wait_for_pipes_to_close()
def stop(self):
self._stopped.set()
def _wait_for_users_to_disconnect(self):
"""
Blocks until the users disconnect or the timeout has elapsed.

View File

@ -107,3 +107,12 @@ def interruptible_function(*, msg: Optional[str] = None, default_return_value: A
return _wrapper
return _decorator
class InterruptableThreadMixin:
def __init__(self):
self._interrupted = Event()
def stop(self):
"""Stop a running thread."""
self._interrupted.set()