Agent: Revert old TcpProxy and create a new one

This commit is contained in:
Kekoa Kaaikala 2022-09-01 13:14:07 +00:00
parent c9b7f924a3
commit 2e7be823a1
3 changed files with 131 additions and 43 deletions

View File

@ -0,0 +1,127 @@
import select
import socket
from functools import partial
from logging import getLogger
from threading import Thread
from typing import Callable
from infection_monkey.transport.base import (
PROXY_TIMEOUT,
TransportProxyBase,
update_last_serve_time,
)
READ_BUFFER_SIZE = 8192
SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread):
def __init__(
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
Thread.__init__(self)
self.source = source
self.dest = dest
self.timeout = timeout
self._keep_connection = True
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
sockets = [self.source, self.dest]
while self._keep_connection:
self._keep_connection = False
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout)
if xlist:
break
for r in rlist:
other = self.dest if r is self.source else self.source
try:
data = r.recv(READ_BUFFER_SIZE)
except Exception:
break
if data and self._client_data_received(data):
try:
other.sendall(data)
update_last_serve_time()
except Exception:
break
self._keep_connection = True
self.source.close()
self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.local_host, self.local_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._stopped:
try:
source, address = l_socket.accept()
except socket.timeout:
continue
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self.dest_host, self.dest_port))
except socket.error:
source.close()
dest.close()
continue
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
address[0],
address[1],
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()
for pipe in pipes:
pipe.join()

View File

@ -4,7 +4,7 @@ from threading import Event, Lock, Thread
from time import sleep, time
from typing import Dict
from infection_monkey.transport.tcp import TcpProxy
from infection_monkey.network.relay.tcp import TcpProxy
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"

View File

@ -1,9 +1,7 @@
import select
import socket
from functools import partial
from logging import getLogger
from threading import Thread
from typing import Callable
from infection_monkey.transport.base import (
PROXY_TIMEOUT,
@ -17,19 +15,8 @@ SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread):
def __init__(
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT):
Thread.__init__(self)
self.source = source
self.dest = dest
@ -37,8 +24,6 @@ class SocketsPipe(Thread):
self._keep_connection = True
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
sockets = [self.source, self.dest]
@ -53,7 +38,7 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE)
except Exception:
break
if data and self._client_data_received(data):
if data:
try:
other.sendall(data)
update_last_serve_time()
@ -63,27 +48,9 @@ class SocketsPipe(Thread):
self.source.close()
self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -105,11 +72,7 @@ class TcpProxy(TransportProxyBase):
dest.close()
continue
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipe = SocketsPipe(source, dest)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
@ -118,8 +81,6 @@ class TcpProxy(TransportProxyBase):
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()