From fba5bea912a7a8081e98509b2317c871ca23ede2 Mon Sep 17 00:00:00 2001
From: itsikkes <itsikkes@users.noreply.github.com>
Date: Sat, 13 Aug 2016 18:38:31 +0300
Subject: [PATCH] Tunnel improvements - bugfix for using default tunnel,
 improvement in tunnel shutdown

1) Bugfix when searching for tunnel - registration packet might be sent
from wrong interface in case of the default tunnel
2) Tunnel shutdown now verifies that no one used the tunnel before
shutting it down (added code to allow tracing of last used time)
3) Timeouts increasments
---
 chaos_monkey/transport/base.py | 14 +++++++++
 chaos_monkey/transport/http.py | 10 +++---
 chaos_monkey/transport/tcp.py  |  5 +--
 chaos_monkey/tunnel.py         | 57 +++++++++++++++++++++++++---------
 4 files changed, 66 insertions(+), 20 deletions(-)

diff --git a/chaos_monkey/transport/base.py b/chaos_monkey/transport/base.py
index ecb6656a1..dae0ff072 100644
--- a/chaos_monkey/transport/base.py
+++ b/chaos_monkey/transport/base.py
@@ -1,8 +1,12 @@
+import time
 from threading import Thread
 
+g_last_served = None
 
 class TransportProxyBase(Thread):
     def __init__(self, local_port, dest_host=None, dest_port=None, local_host=''):
+        global g_last_served
+
         self.local_host = local_host
         self.local_port = local_port
         self.dest_host = dest_host
@@ -13,3 +17,13 @@ class TransportProxyBase(Thread):
 
     def stop(self):
         self._stopped = True
+
+
+def update_last_serve_time():
+    global g_last_served
+    g_last_served = time.time()
+
+
+def get_last_serve_time():
+    global g_last_served
+    return g_last_served
\ No newline at end of file
diff --git a/chaos_monkey/transport/http.py b/chaos_monkey/transport/http.py
index 6f10c1811..a3a2cae9c 100644
--- a/chaos_monkey/transport/http.py
+++ b/chaos_monkey/transport/http.py
@@ -1,7 +1,7 @@
 import urllib, BaseHTTPServer, threading, os.path
 import monkeyfs
 from logging import getLogger
-from base import TransportProxyBase
+from base import TransportProxyBase, update_last_serve_time
 from urlparse import urlsplit
 import select
 import socket
@@ -101,7 +101,7 @@ class FileServHTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 
 
 class HTTPConnectProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
-    timeout = 2               # timeout with clients, set to None not to make persistent connection
+    timeout = 30              # timeout with clients, set to None not to make persistent connection
     proxy_via = None          # pseudonym of the proxy in Via header, set to None not to modify original Via header
     protocol_version = "HTTP/1.1"    
 
@@ -118,7 +118,8 @@ class HTTPConnectProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
         address = (u.hostname, u.port or 443)
         try:
             conn = socket.create_connection(address)
-        except socket.error:
+        except socket.error, e:
+            LOG.debug("HTTPConnectProxyHandler: Got exception while trying to connect to %s: %s" % (repr(address), e))
             self.send_error(504)    # 504 Gateway Timeout
             return
         self.send_response(200, 'Connection Established')
@@ -138,6 +139,7 @@ class HTTPConnectProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
                 if data:
                     other.sendall(data)
                     keep_connection = True
+                    update_last_serve_time()
         conn.close()
 
     def log_message(self, format, *args):
@@ -191,6 +193,6 @@ class HTTPServer(threading.Thread):
 class HTTPConnectProxy(TransportProxyBase):
     def run(self):
         httpd = InternalHTTPServer((self.local_host, self.local_port), HTTPConnectProxyHandler)
-        httpd.timeout = 10
+        httpd.timeout = 30
         while not self._stopped:
             httpd.handle_request()
diff --git a/chaos_monkey/transport/tcp.py b/chaos_monkey/transport/tcp.py
index 0b5c94187..ee3a05442 100644
--- a/chaos_monkey/transport/tcp.py
+++ b/chaos_monkey/transport/tcp.py
@@ -1,11 +1,11 @@
 import socket
 import select
 from threading import Thread
-from base import TransportProxyBase
+from base import TransportProxyBase, update_last_serve_time
 from logging import getLogger
 
 READ_BUFFER_SIZE = 8192
-DEFAULT_TIMEOUT = 10
+DEFAULT_TIMEOUT = 30
 
 LOG = getLogger(__name__)
 
@@ -36,6 +36,7 @@ class SocketsPipe(Thread):
                 if data:
                     try:
                         other.sendall(data)
+                        update_last_serve_time()
                     except:
                         break
                     self._keep_connection = True
diff --git a/chaos_monkey/tunnel.py b/chaos_monkey/tunnel.py
index e8ba530a6..f6f258a84 100644
--- a/chaos_monkey/tunnel.py
+++ b/chaos_monkey/tunnel.py
@@ -4,6 +4,7 @@ import logging
 from threading import Thread
 from network.info import local_ips, get_free_tcp_port
 from network.firewall import app as firewall
+from transport.base import get_last_serve_time
 from difflib import get_close_matches
 from network.tools import check_port_tcp
 from model import VictimHost
@@ -31,9 +32,39 @@ def _set_multicast_socket(timeout=DEFAULT_TIMEOUT, adapter=''):
     return sock
 
 
+def _check_tunnel(address, port, existing_sock=None):
+    if not existing_sock:
+        sock = _set_multicast_socket()
+    else:
+        sock = existing_sock
+
+    LOG.debug("Checking tunnel %s:%s", address, port)
+    is_open, _ = check_port_tcp(address, int(port))
+    if not is_open:
+        LOG.debug("Could not connect to %s:%s", address, port)
+        if not existing_sock:
+            sock.close()
+        return False
+
+    try:
+        sock.sendto("+", (address, MCAST_PORT))
+    except Exception, exc:
+        LOG.debug("Caught exception in tunnel registration: %s", exc)
+
+    if not existing_sock:
+        sock.close()
+    return True
+
+
 def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
     l_ips = local_ips()
 
+    if default:
+        if default.find(':') != -1:
+            address, port = default.split(':', 1)
+            if _check_tunnel(address, port):
+                return address, port
+
     for adapter in l_ips:
         for attempt in range(0, attempts):
             try:
@@ -41,8 +72,6 @@ def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
                 sock = _set_multicast_socket(timeout, adapter)
                 sock.sendto("?", (MCAST_GROUP, MCAST_PORT))
                 tunnels = []
-                if default:
-                    tunnels.append(default)
 
                 while True:
                     try:
@@ -58,15 +87,10 @@ def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
                         if address in l_ips:
                             continue
 
-                        LOG.debug("Checking tunnel %s:%s", address, port)
-                        is_open, _ = check_port_tcp(address, int(port))
-                        if not is_open:
-                            LOG.debug("Could not connect to %s:%s", address, port)
-                            continue
+                        if _check_tunnel(address, port, sock):
+                            sock.close()
+                            return address, port
 
-                        sock.sendto("+", (address, MCAST_PORT))
-                        sock.close()
-                        return address, port
             except Exception, exc:
                 LOG.debug("Caught exception in tunnel lookup: %s", exc)
                 continue
@@ -130,22 +154,27 @@ class MonkeyTunnel(Thread):
                         self._broad_sock.sendto(answer, (address[0], MCAST_PORT))
                 elif '+' == search:
                     if not address[0] in self._clients:
+                        LOG.debug("Tunnel control: Added %s to watchlist", address[0])
                         self._clients.append(address[0])
                 elif '-' == search:
+                        LOG.debug("Tunnel control: Removed %s from watchlist", address[0])
                         self._clients = [client for client in self._clients if client != address[0]]
             
             except socket.timeout:
                 continue
 
-        LOG.info("Stopping tunnel, waiting for clients")
-        stop_time = time.time()
-        while self._clients and (time.time() - stop_time < QUIT_TIMEOUT):
+        LOG.info("Stopping tunnel, waiting for clients: %s" % repr(self._clients))
+
+        # wait till all of the tunnel clients has been disconnected, or no one used the tunnel in QUIT_TIMEOUT seconds
+        while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
             try:
                 search, address = self._broad_sock.recvfrom(BUFFER_READ)
                 if '-' == search:
+                    LOG.debug("Tunnel control: Removed %s from watchlist", address[0])
                     self._clients = [client for client in self._clients if client != address[0]]
             except socket.timeout:
                 continue
+
         LOG.info("Closing tunnel")
         self._broad_sock.close()
         proxy.stop()
@@ -161,4 +190,4 @@ class MonkeyTunnel(Thread):
         host.default_tunnel = '%s:%d' % (ip_match[0], self.local_port)
 
     def stop(self):
-        self._stopped = True
+        self._stopped = True
\ No newline at end of file