Agent: Revert old TcpProxy and create a new one

This commit is contained in:
Kekoa Kaaikala 2022-09-01 13:14:07 +00:00
parent c9b7f924a3
commit 2e7be823a1
3 changed files with 131 additions and 43 deletions

View File

@ -0,0 +1,127 @@
import select
import socket
from functools import partial
from logging import getLogger
from threading import Thread
from typing import Callable
from infection_monkey.transport.base import (
PROXY_TIMEOUT,
TransportProxyBase,
update_last_serve_time,
)
READ_BUFFER_SIZE = 8192
SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread):
def __init__(
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
Thread.__init__(self)
self.source = source
self.dest = dest
self.timeout = timeout
self._keep_connection = True
super(SocketsPipe, self).__init__()
self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
sockets = [self.source, self.dest]
while self._keep_connection:
self._keep_connection = False
rlist, wlist, xlist = select.select(sockets, [], sockets, self.timeout)
if xlist:
break
for r in rlist:
other = self.dest if r is self.source else self.source
try:
data = r.recv(READ_BUFFER_SIZE)
except Exception:
break
if data and self._client_data_received(data):
try:
other.sendall(data)
update_last_serve_time()
except Exception:
break
self._keep_connection = True
self.source.close()
self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self):
pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
l_socket.bind((self.local_host, self.local_port))
l_socket.settimeout(PROXY_TIMEOUT)
l_socket.listen(5)
while not self._stopped:
try:
source, address = l_socket.accept()
except socket.timeout:
continue
dest = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
dest.connect((self.dest_host, self.dest_port))
except socket.error:
source.close()
dest.close()
continue
on_disconnect = (
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe)
logger.debug(
"piping sockets %s:%s->%s:%s",
address[0],
address[1],
self.dest_host,
self.dest_port,
)
if self._client_connected:
self._client_connected(address[0])
pipe.start()
l_socket.close()
for pipe in pipes:
pipe.join()

View File

@ -4,7 +4,7 @@ from threading import Event, Lock, Thread
from time import sleep, time from time import sleep, time
from typing import Dict from typing import Dict
from infection_monkey.transport.tcp import TcpProxy from infection_monkey.network.relay.tcp import TcpProxy
DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect DEFAULT_NEW_CLIENT_TIMEOUT = 3 # Wait up to 3 seconds for potential new clients to connect
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"

View File

@ -1,9 +1,7 @@
import select import select
import socket import socket
from functools import partial
from logging import getLogger from logging import getLogger
from threading import Thread from threading import Thread
from typing import Callable
from infection_monkey.transport.base import ( from infection_monkey.transport.base import (
PROXY_TIMEOUT, PROXY_TIMEOUT,
@ -17,19 +15,8 @@ SOCKET_READ_TIMEOUT = 10
logger = getLogger(__name__) logger = getLogger(__name__)
def _default_client_data_received(_: bytes, client=None) -> bool:
return True
class SocketsPipe(Thread): class SocketsPipe(Thread):
def __init__( def __init__(self, source, dest, timeout=SOCKET_READ_TIMEOUT):
self,
source,
dest,
timeout=SOCKET_READ_TIMEOUT,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes], bool] = _default_client_data_received,
):
Thread.__init__(self) Thread.__init__(self)
self.source = source self.source = source
self.dest = dest self.dest = dest
@ -37,8 +24,6 @@ class SocketsPipe(Thread):
self._keep_connection = True self._keep_connection = True
super(SocketsPipe, self).__init__() super(SocketsPipe, self).__init__()
self.daemon = True self.daemon = True
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self): def run(self):
sockets = [self.source, self.dest] sockets = [self.source, self.dest]
@ -53,7 +38,7 @@ class SocketsPipe(Thread):
data = r.recv(READ_BUFFER_SIZE) data = r.recv(READ_BUFFER_SIZE)
except Exception: except Exception:
break break
if data and self._client_data_received(data): if data:
try: try:
other.sendall(data) other.sendall(data)
update_last_serve_time() update_last_serve_time()
@ -63,27 +48,9 @@ class SocketsPipe(Thread):
self.source.close() self.source.close()
self.dest.close() self.dest.close()
if self._client_disconnected:
self._client_disconnected()
class TcpProxy(TransportProxyBase): class TcpProxy(TransportProxyBase):
def __init__(
self,
local_port,
dest_host=None,
dest_port=None,
local_host="",
client_connected: Callable[[str], None] = None,
client_disconnected: Callable[[str], None] = None,
client_data_received: Callable[[bytes, str], bool] = _default_client_data_received,
):
super().__init__(local_port, dest_host, dest_port, local_host)
self._client_connected = client_connected
# TODO: Rethink client_disconnected
self._client_disconnected = client_disconnected
self._client_data_received = client_data_received
def run(self): def run(self):
pipes = [] pipes = []
l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) l_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -105,11 +72,7 @@ class TcpProxy(TransportProxyBase):
dest.close() dest.close()
continue continue
on_disconnect = ( pipe = SocketsPipe(source, dest)
partial(self._client_connected, address[0]) if self._client_connected else None
)
on_data_received = partial(self._client_data_received, client=address[0])
pipe = SocketsPipe(source, dest, on_disconnect, on_data_received)
pipes.append(pipe) pipes.append(pipe)
logger.debug( logger.debug(
"piping sockets %s:%s->%s:%s", "piping sockets %s:%s->%s:%s",
@ -118,8 +81,6 @@ class TcpProxy(TransportProxyBase):
self.dest_host, self.dest_host,
self.dest_port, self.dest_port,
) )
if self._client_connected:
self._client_connected(address[0])
pipe.start() pipe.start()
l_socket.close() l_socket.close()