Agent: Remove MonkeyTunnel

This commit is contained in:
Kekoa Kaaikala 2022-09-07 19:10:08 +00:00 committed by Mike Salvatore
parent 3516fa1fec
commit dcb77d6285
7 changed files with 13 additions and 343 deletions

View File

@ -2,20 +2,17 @@ import json
import logging
import platform
from socket import gethostname
from typing import Mapping, Optional
from typing import MutableMapping, Optional
import requests
from urllib3 import disable_warnings
import infection_monkey.tunnel as tunnel
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from infection_monkey.config import GUID
from infection_monkey.network.info import get_host_subnets, local_ips
from infection_monkey.transport.http import HTTPConnectProxy
from infection_monkey.transport.tcp import TcpProxy
from infection_monkey.utils import agent_process
from infection_monkey.utils.environment import is_windows_os
requests.packages.urllib3.disable_warnings()
disable_warnings() # noqa DUO131
logger = logging.getLogger(__name__)
@ -28,7 +25,7 @@ class ControlClient:
# https://github.com/guardicore/monkey/blob/133f7f5da131b481561141171827d1f9943f6aec/monkey/infection_monkey/telemetry/base_telem.py
control_client_object = None
def __init__(self, server_address: str, proxies: Optional[Mapping[str, str]] = None):
def __init__(self, server_address: str, proxies: Optional[MutableMapping[str, str]] = None):
self.proxies = {} if not proxies else proxies
self.server_address = server_address
@ -62,25 +59,6 @@ class ControlClient:
timeout=MEDIUM_REQUEST_TIMEOUT,
)
def set_proxies(self, proxy_find):
"""
Note: The proxy schema changes between different versions of requests and urllib3,
which causes the machine to not open a tunnel back.
If we get "ValueError: check_hostname requires server_hostname" or
"Proxy URL had not schema, should start with http:// or https://" errors,
the proxy schema needs to be changed.
Keep this in mind when upgrading to newer python version or when urllib3 and
requests are updated there is possibility that the proxy schema is changed.
https://github.com/psf/requests/issues/5297
https://github.com/psf/requests/issues/5855
"""
proxy_address, proxy_port = proxy_find
logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port))
if is_windows_os():
self.proxies["https"] = f"http://{proxy_address}:{proxy_port}"
else:
self.proxies["https"] = f"{proxy_address}:{proxy_port}"
def send_telemetry(self, telem_category, json_data: str):
if not self.server_address:
logger.error(
@ -117,29 +95,6 @@ class ControlClient:
except Exception as exc:
logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
def create_control_tunnel(self, keep_tunnel_open_time: int):
if not self.server_address:
return None
my_proxy = self.proxies.get("https", "").replace("https://", "")
if my_proxy:
proxy_class = TcpProxy
try:
target_addr, target_port = my_proxy.split(":", 1)
target_port = int(target_port)
except ValueError:
return None
else:
proxy_class = HTTPConnectProxy
target_addr, target_port = None, None
return tunnel.MonkeyTunnel(
proxy_class,
keep_tunnel_open_time=keep_tunnel_open_time,
target_addr=target_addr,
target_port=target_port,
)
def get_pba_file(self, filename):
try:
return requests.get( # noqa: DUO123

View File

@ -4,7 +4,6 @@ from typing import Optional, Tuple
from infection_monkey.model import VictimHost
from infection_monkey.network import NetworkAddress
from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.tunnel import MonkeyTunnel
logger = logging.getLogger(__name__)
@ -12,12 +11,10 @@ logger = logging.getLogger(__name__)
class VictimHostFactory:
def __init__(
self,
tunnel: Optional[MonkeyTunnel],
island_ip: Optional[str],
island_port: Optional[str],
on_island: bool,
):
self.tunnel = tunnel
self.island_ip = island_ip
self.island_port = island_port
self.on_island = on_island
@ -26,19 +23,15 @@ class VictimHostFactory:
domain = network_address.domain or ""
victim_host = VictimHost(network_address.ip, domain)
if self.tunnel:
victim_host.default_tunnel = self.tunnel.get_tunnel_for_ip(victim_host.ip_addr)
if self.island_ip:
ip, port = self._choose_island_address(victim_host.ip_addr)
victim_host.set_island_address(ip, port)
logger.debug(f"Default tunnel for {victim_host} set to {victim_host.default_tunnel}")
logger.debug(f"Default server for {victim_host} set to {victim_host.default_server}")
return victim_host
def _choose_island_address(self, victim_ip: str) -> Tuple[str, Optional[str]]:
def _choose_island_address(self, victim_ip: str) -> Tuple[Optional[str], Optional[str]]:
# Victims need to connect back to the interface they can reach
# On island, choose the right interface to pass to children monkeys
if self.on_island:

View File

@ -381,7 +381,7 @@ class InfectionMonkey:
on_island = self._running_on_island(local_network_interfaces)
logger.debug(f"This agent is running on the island: {on_island}")
return VictimHostFactory(None, self._cmd_island_ip, self._cmd_island_port, on_island)
return VictimHostFactory(self._cmd_island_ip, self._cmd_island_port, on_island)
def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool:
server_ip, _ = address_to_ip_port(self._control_client.server_address)

View File

@ -1,230 +0,0 @@
import logging
import socket
import struct
import time
from threading import Event, Thread
from common.utils import Timer
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
from infection_monkey.transport.base import get_last_serve_time
logger = logging.getLogger(__name__)
MCAST_GROUP = "224.1.1.1"
MCAST_PORT = 5007
BUFFER_READ = 1024
DEFAULT_TIMEOUT = 10
QUIT_TIMEOUT = 60 * 10 # 10 minutes
def _set_multicast_socket(timeout=DEFAULT_TIMEOUT, adapter=""):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((adapter, MCAST_PORT))
sock.setsockopt(
socket.IPPROTO_IP,
socket.IP_ADD_MEMBERSHIP,
struct.pack("4sl", socket.inet_aton(MCAST_GROUP), socket.INADDR_ANY),
)
return sock
def _check_tunnel(address, port, existing_sock=None):
if not existing_sock:
sock = _set_multicast_socket()
else:
sock = existing_sock
logger.debug("Checking tunnel %s:%s", address, port)
is_open, _ = check_tcp_port(address, int(port))
if not is_open:
logger.debug("Could not connect to %s:%s", address, port)
if not existing_sock:
sock.close()
return False
try:
sock.sendto(b"+", (address, MCAST_PORT))
except Exception as exc:
logger.debug("Caught exception in tunnel registration: %s", exc)
if not existing_sock:
sock.close()
return True
def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
l_ips = local_ips()
if default:
if default.find(":") != -1:
address, port = default.split(":", 1)
if _check_tunnel(address, port):
return address, port
for adapter in l_ips:
for attempt in range(0, attempts):
try:
logger.info("Trying to find using adapter %s", adapter)
sock = _set_multicast_socket(timeout, adapter)
sock.sendto(b"?", (MCAST_GROUP, MCAST_PORT))
tunnels = []
while True:
try:
answer, address = sock.recvfrom(BUFFER_READ)
if answer not in [b"?", b"+", b"-"]:
tunnels.append(answer)
except socket.timeout:
break
for tunnel in tunnels:
if tunnel.find(":") != -1:
address, port = tunnel.split(":", 1)
if address in l_ips:
continue
if _check_tunnel(address, port, sock):
sock.close()
return address, port
except Exception as exc:
logger.debug("Caught exception in tunnel lookup: %s", exc)
continue
return None
def quit_tunnel(address, timeout=DEFAULT_TIMEOUT):
try:
sock = _set_multicast_socket(timeout)
sock.sendto(b"-", (address, MCAST_PORT))
sock.close()
logger.debug("Success quitting tunnel")
except Exception as exc:
logger.debug("Exception quitting tunnel: %s", exc)
return
class MonkeyTunnel(Thread):
def __init__(
self,
proxy_class,
keep_tunnel_open_time,
target_addr=None,
target_port=None,
timeout=DEFAULT_TIMEOUT,
):
self._target_addr = target_addr
self._target_port = target_port
self._proxy_class = proxy_class
self._keep_tunnel_open_time = keep_tunnel_open_time
self._broad_sock = None
self._timeout = timeout
self._stopped = Event()
self._clients = []
self.local_port = None
super(MonkeyTunnel, self).__init__(name="MonkeyTunnelThread")
self.daemon = True
self.l_ips = None
self._wait_for_exploited_machines = Event()
def run(self):
self._broad_sock = _set_multicast_socket(self._timeout)
self.l_ips = local_ips()
self.local_port = get_free_tcp_port()
if not self.local_port:
return
if not firewall.listen_allowed(localport=self.local_port):
logger.info("Machine firewalled, listen not allowed, not running tunnel.")
return
proxy = self._proxy_class(
local_port=self.local_port, dest_host=self._target_addr, dest_port=self._target_port
)
logger.info(
"Running tunnel using proxy class: %s, listening on port %s, routing to: %s:%s",
proxy.__class__.__name__,
self.local_port,
self._target_addr,
self._target_port,
)
proxy.start()
while not self._stopped.is_set():
try:
search, address = self._broad_sock.recvfrom(BUFFER_READ)
if b"?" == search:
ip_match = get_interface_to_target(address[0])
if ip_match:
answer = "%s:%d" % (ip_match, self.local_port)
logger.debug(
"Got tunnel request from %s, answering with %s", address[0], answer
)
self._broad_sock.sendto(answer.encode(), (address[0], MCAST_PORT))
elif b"+" == search:
if not address[0] in self._clients:
logger.debug("Tunnel control: Added %s to watchlist", address[0])
self._clients.append(address[0])
elif b"-" == search:
logger.debug("Tunnel control: Removed %s from watchlist", address[0])
self._clients = [client for client in self._clients if client != address[0]]
except socket.timeout:
continue
logger.info("Stopping tunnel, waiting for clients: %s" % repr(self._clients))
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
# QUIT_TIMEOUT seconds
timer = Timer()
timer.set(self._calculate_timeout())
while self._clients and not timer.is_expired():
try:
search, address = self._broad_sock.recvfrom(BUFFER_READ)
if b"-" == search:
logger.debug("Tunnel control: Removed %s from watchlist", address[0])
self._clients = [client for client in self._clients if client != address[0]]
except socket.timeout:
continue
timer.set(self._calculate_timeout())
logger.info("Closing tunnel")
self._broad_sock.close()
proxy.stop()
proxy.join()
def _calculate_timeout(self) -> float:
try:
return QUIT_TIMEOUT - (time.time() - get_last_serve_time())
except TypeError: # get_last_serve_time() may return None
return 0.0
def get_tunnel_for_ip(self, ip: str):
if not self.local_port:
return
ip_match = get_interface_to_target(ip)
return "%s:%d" % (ip_match, self.local_port)
def set_wait_for_exploited_machines(self):
self._wait_for_exploited_machines.set()
def stop(self):
self._wait_for_exploited_machine_connection()
self._stopped.set()
def _wait_for_exploited_machine_connection(self):
if self._wait_for_exploited_machines.is_set():
logger.info(
f"Waiting {self._keep_tunnel_open_time} seconds for exploited machines to connect "
"to the tunnel."
)
time.sleep(self._keep_tunnel_open_time)

View File

@ -1,18 +1,9 @@
from unittest.mock import MagicMock
import pytest
from infection_monkey.model import VictimHostFactory
from infection_monkey.network import NetworkAddress
@pytest.fixture
def mock_tunnel():
tunnel = MagicMock()
tunnel.get_tunnel_for_ip = lambda _: "1.2.3.4:1234"
return tunnel
@pytest.fixture(autouse=True)
def mock_get_interface_to_target(monkeypatch):
monkeypatch.setattr(
@ -21,9 +12,7 @@ def mock_get_interface_to_target(monkeypatch):
def test_factory_no_tunnel():
factory = VictimHostFactory(
tunnel=None, island_ip="192.168.56.1", island_port="5000", on_island=False
)
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="5000", on_island=False)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
@ -34,24 +23,8 @@ def test_factory_no_tunnel():
assert victim.domain_name == ""
def test_factory_with_tunnel(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="5000", on_island=False
)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
assert victim.default_server == "192.168.56.1:5000"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel == "1.2.3.4:1234"
assert victim.domain_name == ""
def test_factory_on_island(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="99", on_island=True
)
def test_factory_on_island():
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="99", on_island=True)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
@ -63,10 +36,8 @@ def test_factory_on_island(mock_tunnel):
@pytest.mark.parametrize("default_port", ["", None])
def test_factory_no_port(mock_tunnel, default_port):
factory = VictimHostFactory(
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port=default_port, on_island=True
)
def test_factory_no_port(default_port):
factory = VictimHostFactory(island_ip="192.168.56.1", island_port=default_port, on_island=True)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
@ -74,8 +45,8 @@ def test_factory_no_port(mock_tunnel, default_port):
assert victim.default_server == "1.1.1.1"
def test_factory_no_default_server(mock_tunnel):
factory = VictimHostFactory(tunnel=mock_tunnel, island_ip=None, island_port="", on_island=True)
def test_factory_no_default_server():
factory = VictimHostFactory(island_ip=None, island_port="", on_island=True)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)

View File

@ -1,16 +0,0 @@
import pytest
from monkey.infection_monkey.control import ControlClient
@pytest.mark.parametrize(
"is_windows_os,expected_proxy_string",
[(True, "http://8.8.8.8:45455"), (False, "8.8.8.8:45455")],
)
def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
monkeypatch.setattr("monkey.infection_monkey.control.is_windows_os", lambda: is_windows_os)
control_client = ControlClient("8.8.8.8:5000")
control_client.set_proxies(("8.8.8.8", "45455"))
assert control_client.proxies["https"] == expected_proxy_string

View File

@ -299,9 +299,6 @@ event
deserialize
serialized_event
# TODO: Remove when removing Tunnel code
create_control_tunnel
set_wait_for_exploited_machines
# pydantic base models
underscore_attrs_are_private