# monkey/chaos_monkey/tunnel.py
#
# Repository viewer metadata: 193 lines, 6.7 KiB, Python

import socket
import struct
import logging
from threading import Thread
from network.info import local_ips, get_free_tcp_port
from network.firewall import app as firewall
from transport.base import get_last_serve_time
from difflib import get_close_matches
from network.tools import check_port_tcp
from model import VictimHost
import time
__author__ = "hoffer"

# Module-level logger shared by every function and class in this file.
LOG = logging.getLogger(__name__)

# Multicast group and UDP port on which tunnel control datagrams
# ("?", "+", "-" and "ip:port" answers) are exchanged.
MCAST_GROUP = "224.1.1.1"
MCAST_PORT = 5007

# Maximum number of bytes read from the control socket per datagram.
BUFFER_READ = 1024

# Default control-socket timeout, in seconds.
DEFAULT_TIMEOUT = 10

# Seconds without served traffic after which a stopping tunnel gives up
# waiting for its remaining clients.
QUIT_TIMEOUT = 10 * 60  # 10 minutes
def _set_multicast_socket(timeout=DEFAULT_TIMEOUT, adapter=''):
    """Create a UDP socket bound to MCAST_PORT and joined to MCAST_GROUP.

    :param timeout: socket timeout, in seconds, applied to blocking calls.
    :param adapter: local address to bind to; '' binds to all addresses.
    :return: the configured socket. The caller is responsible for closing it.
    :raises: whatever the underlying socket call raised; the half-configured
             socket is closed before the exception propagates.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        sock.settimeout(timeout)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((adapter, MCAST_PORT))
        # "=" selects standard sizes with no padding, so the request is always
        # the 8-byte (group address, interface address) pair. The native
        # "4sl" format pads to 16 bytes on platforms where long is 8 bytes.
        membership = struct.pack("=4sl",
                                 socket.inet_aton(MCAST_GROUP),
                                 socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, membership)
    except Exception:
        # Do not leak the descriptor when binding or joining the group fails.
        sock.close()
        raise
    return sock
def _check_tunnel(address, port, existing_sock=None):
    """Check that a tunnel is reachable and register with it.

    The tunnel's TCP port is probed with check_port_tcp. When it is open, a
    "+" datagram is sent to the tunnel's control port; a failure to send it
    is only logged and does not change the result.

    :param address: tunnel address.
    :param port: tunnel TCP port, as a number or numeric string.
    :param existing_sock: control socket to reuse. When omitted, a temporary
                          socket is created and closed before returning.
    :return: True if the tunnel's TCP port is open, False otherwise
             (including when `port` is not a valid number).
    """
    owns_sock = not existing_sock
    sock = _set_multicast_socket() if owns_sock else existing_sock

    try:
        LOG.debug("Checking tunnel %s:%s", address, port)

        try:
            port_number = int(port)
        except (TypeError, ValueError):
            # A malformed answer can never be a usable tunnel.
            LOG.debug("Invalid tunnel port %r for address %s", port, address)
            return False

        is_open, _ = check_port_tcp(address, port_number)
        if not is_open:
            LOG.debug("Could not connect to %s:%s", address, port)
            return False

        try:
            sock.sendto("+", (address, MCAST_PORT))
        except Exception as exc:
            LOG.debug("Caught exception in tunnel registration: %s", exc)

        return True
    finally:
        # Only close a socket created here, and do so on every exit path,
        # including unexpected exceptions from the port check.
        if owns_sock:
            sock.close()
def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
    """Locate a usable tunnel.

    The `default` tunnel ("address:port") is tried first. If it is missing or
    unreachable, a "?" datagram is sent to the multicast group from every
    local address, up to `attempts` times each. Answers are collected until
    the socket times out, and the first answered tunnel that passes
    _check_tunnel is returned. Answers naming one of this machine's own
    addresses are skipped.

    :param default: optional "address:port" string to try before discovery.
    :param attempts: number of discovery rounds per local address.
    :param timeout: control-socket timeout, in seconds, for each round.
    :return: (address, port) of the tunnel, both as strings, or None when no
             tunnel was found.
    """
    l_ips = local_ips()

    if default and ':' in default:
        address, port = default.split(':', 1)
        if _check_tunnel(address, port):
            return address, port

    for adapter in l_ips:
        for _attempt in range(attempts):
            sock = None
            try:
                LOG.info("Trying to find using adapter %s", adapter)
                sock = _set_multicast_socket(timeout, adapter)
                sock.sendto("?", (MCAST_GROUP, MCAST_PORT))

                # Gather every answer that arrives before the socket times
                # out, ignoring other hosts' control messages.
                tunnels = []
                while True:
                    try:
                        answer, address = sock.recvfrom(BUFFER_READ)
                    except socket.timeout:
                        break
                    if answer not in ('?', '+', '-'):
                        tunnels.append(answer)

                for tunnel in tunnels:
                    if ':' not in tunnel:
                        continue
                    address, port = tunnel.split(':', 1)
                    if address in l_ips:
                        continue
                    if _check_tunnel(address, port, sock):
                        return address, port
            except Exception as exc:
                LOG.debug("Caught exception in tunnel lookup: %s", exc)
            finally:
                # Close the discovery socket after every round, not only on
                # success, so failed rounds do not leak descriptors.
                if sock is not None:
                    sock.close()

    return None
def quit_tunnel(address, timeout=DEFAULT_TIMEOUT):
    """Tell the tunnel at `address` that this host no longer uses it.

    Sends a "-" datagram to the tunnel's control port. This is best effort:
    any failure is logged at debug level and otherwise ignored.

    :param address: address of the tunnel to notify.
    :param timeout: control-socket timeout, in seconds.
    """
    try:
        sock = _set_multicast_socket(timeout)
        try:
            sock.sendto("-", (address, MCAST_PORT))
        finally:
            # Close the socket even when the send fails.
            sock.close()
        LOG.debug("Success quitting tunnel")
    except Exception as exc:
        LOG.debug("Exception quitting tunnel: %s", exc)
class MonkeyTunnel(Thread):
    """Thread that runs a proxy and advertises it as a tunnel.

    While running, it answers control datagrams received on the multicast
    socket:

    * "?" - reply to the sender with "ip:port" of the local proxy.
    * "+" - add the sender to the list of tunnel clients.
    * "-" - remove the sender from the list of tunnel clients.

    After stop() is called, the thread keeps waiting for the remaining
    clients to send "-", or until the proxy has not served anything for
    QUIT_TIMEOUT seconds, and then shuts the proxy down.
    """

    def __init__(self, proxy_class, target_addr=None, target_port=None, timeout=DEFAULT_TIMEOUT):
        """
        :param proxy_class: class instantiated with local_port, dest_host and
                            dest_port keyword arguments; the instance must
                            provide start(), stop() and join().
        :param target_addr: destination host the proxy routes to.
        :param target_port: destination port the proxy routes to.
        :param timeout: control-socket timeout, in seconds.
        """
        self._target_addr = target_addr
        self._target_port = target_port
        self._proxy_class = proxy_class
        self._broad_sock = None
        self._timeout = timeout
        self._stopped = False
        self._clients = []
        # Stays None until run() has confirmed the proxy may listen.
        self.local_port = None
        super(MonkeyTunnel, self).__init__()
        self.daemon = True
        self.l_ips = None

    def _remove_client(self, client_ip):
        """Drop `client_ip` from the list of tunnel clients."""
        LOG.debug("Tunnel control: Removed %s from watchlist", client_ip)
        self._clients = [client for client in self._clients if client != client_ip]

    def run(self):
        """Start the proxy and serve control messages until stopped."""
        self._broad_sock = _set_multicast_socket(self._timeout)
        try:
            self.l_ips = local_ips()
            local_port = get_free_tcp_port()
            if not local_port:
                return

            if not firewall.listen_allowed(localport=local_port):
                LOG.info("Machine firewalled, listen not allowed, not running tunnel.")
                return

            # Publish the port only once listening is known to be allowed, so
            # set_tunnel_for_host() never advertises a port nothing serves.
            self.local_port = local_port

            proxy = self._proxy_class(local_port=self.local_port,
                                      dest_host=self._target_addr,
                                      dest_port=self._target_port)
            LOG.info("Running tunnel using proxy class: %s, listening on port %s, routing to: %s:%s",
                     proxy.__class__.__name__,
                     self.local_port,
                     self._target_addr,
                     self._target_port)
            proxy.start()

            while not self._stopped:
                try:
                    search, address = self._broad_sock.recvfrom(BUFFER_READ)
                    sender_ip = address[0]
                    if '?' == search:
                        # Prefer the local address most similar to the sender's.
                        ip_match = get_close_matches(sender_ip, self.l_ips) or self.l_ips
                        if ip_match:
                            answer = '%s:%d' % (ip_match[0], self.local_port)
                            LOG.debug("Got tunnel request from %s, answering with %s", sender_ip, answer)
                            self._broad_sock.sendto(answer, (sender_ip, MCAST_PORT))
                    elif '+' == search:
                        if sender_ip not in self._clients:
                            LOG.debug("Tunnel control: Added %s to watchlist", sender_ip)
                            self._clients.append(sender_ip)
                    elif '-' == search:
                        self._remove_client(sender_ip)
                except socket.timeout:
                    continue

            LOG.info("Stopping tunnel, waiting for clients: %s", repr(self._clients))

            # Wait until all of the tunnel clients have disconnected, or no
            # one has used the tunnel in QUIT_TIMEOUT seconds.
            while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
                try:
                    search, address = self._broad_sock.recvfrom(BUFFER_READ)
                    if '-' == search:
                        self._remove_client(address[0])
                except socket.timeout:
                    continue

            LOG.info("Closing tunnel")
        finally:
            # Release the control socket on every exit path, including the
            # early returns above and unexpected errors.
            self._broad_sock.close()

        proxy.stop()
        proxy.join()

    def set_tunnel_for_host(self, host):
        """Set `host.default_tunnel` to this tunnel's "ip:port".

        Does nothing when the tunnel has no listening port or when no local
        address is available to advertise.

        :param host: VictimHost whose default tunnel should be set.
        """
        assert isinstance(host, VictimHost)

        if not self.local_port:
            return

        ip_match = get_close_matches(host.ip_addr, local_ips()) or self.l_ips
        if not ip_match:
            # Same guard as in run(): no local address to offer.
            return
        host.default_tunnel = '%s:%d' % (ip_match[0], self.local_port)

    def stop(self):
        """Ask the thread to stop serving and begin its shutdown wait."""
        self._stopped = True