forked from p15670423/monkey
Agent: Add InterruptableThreadMixin
This commit is contained in:
parent
066947c59f
commit
65f4edc625
|
@ -1,11 +1,13 @@
|
||||||
import socket
|
import socket
|
||||||
from threading import Event, Thread
|
from threading import Thread
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
|
from infection_monkey.utils.threading import InterruptableThreadMixin
|
||||||
|
|
||||||
PROXY_TIMEOUT = 2.5
|
PROXY_TIMEOUT = 2.5
|
||||||
|
|
||||||
|
|
||||||
class TCPConnectionHandler(Thread):
|
class TCPConnectionHandler(Thread, InterruptableThreadMixin):
|
||||||
"""Accepts connections on a TCP socket."""
|
"""Accepts connections on a TCP socket."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -18,7 +20,6 @@ class TCPConnectionHandler(Thread):
|
||||||
self.local_host = bind_host
|
self.local_host = bind_host
|
||||||
self._client_connected = client_connected
|
self._client_connected = client_connected
|
||||||
super().__init__(name="TCPConnectionHandler", daemon=True)
|
super().__init__(name="TCPConnectionHandler", daemon=True)
|
||||||
self._stopped = Event()
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
@ -26,7 +27,7 @@ class TCPConnectionHandler(Thread):
|
||||||
l_socket.settimeout(PROXY_TIMEOUT)
|
l_socket.settimeout(PROXY_TIMEOUT)
|
||||||
l_socket.listen(5)
|
l_socket.listen(5)
|
||||||
|
|
||||||
while not self._stopped.is_set():
|
while not self._interrupted.is_set():
|
||||||
try:
|
try:
|
||||||
source, _ = l_socket.accept()
|
source, _ = l_socket.accept()
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
|
@ -36,6 +37,3 @@ class TCPConnectionHandler(Thread):
|
||||||
notify_client_connected(source)
|
notify_client_connected(source)
|
||||||
|
|
||||||
l_socket.close()
|
l_socket.close()
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._stopped.set()
|
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from threading import Event, Lock, Thread
|
from threading import Lock, Thread
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
|
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
|
||||||
|
from infection_monkey.utils.threading import InterruptableThreadMixin
|
||||||
|
|
||||||
|
|
||||||
class TCPRelay(Thread):
|
class TCPRelay(Thread, InterruptableThreadMixin):
|
||||||
"""
|
"""
|
||||||
Provides and manages a TCP proxy connection.
|
Provides and manages a TCP proxy connection.
|
||||||
"""
|
"""
|
||||||
|
@ -15,8 +16,6 @@ class TCPRelay(Thread):
|
||||||
connection_handler: TCPConnectionHandler,
|
connection_handler: TCPConnectionHandler,
|
||||||
pipe_spawner: TCPPipeSpawner,
|
pipe_spawner: TCPPipeSpawner,
|
||||||
):
|
):
|
||||||
self._stopped = Event()
|
|
||||||
|
|
||||||
self._user_handler = relay_user_handler
|
self._user_handler = relay_user_handler
|
||||||
self._connection_handler = connection_handler
|
self._connection_handler = connection_handler
|
||||||
self._pipe_spawner = pipe_spawner
|
self._pipe_spawner = pipe_spawner
|
||||||
|
@ -26,16 +25,13 @@ class TCPRelay(Thread):
|
||||||
def run(self):
|
def run(self):
|
||||||
self._connection_handler.start()
|
self._connection_handler.start()
|
||||||
|
|
||||||
self._stopped.wait()
|
self._interrupted.wait()
|
||||||
self._wait_for_users_to_disconnect()
|
self._wait_for_users_to_disconnect()
|
||||||
|
|
||||||
self._connection_handler.stop()
|
self._connection_handler.stop()
|
||||||
self._connection_handler.join()
|
self._connection_handler.join()
|
||||||
self._wait_for_pipes_to_close()
|
self._wait_for_pipes_to_close()
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._stopped.set()
|
|
||||||
|
|
||||||
def _wait_for_users_to_disconnect(self):
|
def _wait_for_users_to_disconnect(self):
|
||||||
"""
|
"""
|
||||||
Blocks until the users disconnect or the timeout has elapsed.
|
Blocks until the users disconnect or the timeout has elapsed.
|
||||||
|
|
|
@ -107,3 +107,12 @@ def interruptible_function(*, msg: Optional[str] = None, default_return_value: A
|
||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
return _decorator
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptableThreadMixin:
|
||||||
|
def __init__(self):
|
||||||
|
self._interrupted = Event()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop a running thread."""
|
||||||
|
self._interrupted.set()
|
||||||
|
|
Loading…
Reference in New Issue