forked from p15670423/monkey
Agent: Remove MonkeyTunnel
This commit is contained in:
parent
3516fa1fec
commit
dcb77d6285
|
@ -2,20 +2,17 @@ import json
|
|||
import logging
|
||||
import platform
|
||||
from socket import gethostname
|
||||
from typing import Mapping, Optional
|
||||
from typing import MutableMapping, Optional
|
||||
|
||||
import requests
|
||||
from urllib3 import disable_warnings
|
||||
|
||||
import infection_monkey.tunnel as tunnel
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
from infection_monkey.config import GUID
|
||||
from infection_monkey.network.info import get_host_subnets, local_ips
|
||||
from infection_monkey.transport.http import HTTPConnectProxy
|
||||
from infection_monkey.transport.tcp import TcpProxy
|
||||
from infection_monkey.utils import agent_process
|
||||
from infection_monkey.utils.environment import is_windows_os
|
||||
|
||||
requests.packages.urllib3.disable_warnings()
|
||||
disable_warnings() # noqa DUO131
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -28,7 +25,7 @@ class ControlClient:
|
|||
# https://github.com/guardicore/monkey/blob/133f7f5da131b481561141171827d1f9943f6aec/monkey/infection_monkey/telemetry/base_telem.py
|
||||
control_client_object = None
|
||||
|
||||
def __init__(self, server_address: str, proxies: Optional[Mapping[str, str]] = None):
|
||||
def __init__(self, server_address: str, proxies: Optional[MutableMapping[str, str]] = None):
|
||||
self.proxies = {} if not proxies else proxies
|
||||
self.server_address = server_address
|
||||
|
||||
|
@ -62,25 +59,6 @@ class ControlClient:
|
|||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
def set_proxies(self, proxy_find):
|
||||
"""
|
||||
Note: The proxy schema changes between different versions of requests and urllib3,
|
||||
which causes the machine to not open a tunnel back.
|
||||
If we get "ValueError: check_hostname requires server_hostname" or
|
||||
"Proxy URL had not schema, should start with http:// or https://" errors,
|
||||
the proxy schema needs to be changed.
|
||||
Keep this in mind when upgrading to newer python version or when urllib3 and
|
||||
requests are updated there is possibility that the proxy schema is changed.
|
||||
https://github.com/psf/requests/issues/5297
|
||||
https://github.com/psf/requests/issues/5855
|
||||
"""
|
||||
proxy_address, proxy_port = proxy_find
|
||||
logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port))
|
||||
if is_windows_os():
|
||||
self.proxies["https"] = f"http://{proxy_address}:{proxy_port}"
|
||||
else:
|
||||
self.proxies["https"] = f"{proxy_address}:{proxy_port}"
|
||||
|
||||
def send_telemetry(self, telem_category, json_data: str):
|
||||
if not self.server_address:
|
||||
logger.error(
|
||||
|
@ -117,29 +95,6 @@ class ControlClient:
|
|||
except Exception as exc:
|
||||
logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
|
||||
|
||||
def create_control_tunnel(self, keep_tunnel_open_time: int):
|
||||
if not self.server_address:
|
||||
return None
|
||||
|
||||
my_proxy = self.proxies.get("https", "").replace("https://", "")
|
||||
if my_proxy:
|
||||
proxy_class = TcpProxy
|
||||
try:
|
||||
target_addr, target_port = my_proxy.split(":", 1)
|
||||
target_port = int(target_port)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
proxy_class = HTTPConnectProxy
|
||||
target_addr, target_port = None, None
|
||||
|
||||
return tunnel.MonkeyTunnel(
|
||||
proxy_class,
|
||||
keep_tunnel_open_time=keep_tunnel_open_time,
|
||||
target_addr=target_addr,
|
||||
target_port=target_port,
|
||||
)
|
||||
|
||||
def get_pba_file(self, filename):
|
||||
try:
|
||||
return requests.get( # noqa: DUO123
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
|||
from infection_monkey.model import VictimHost
|
||||
from infection_monkey.network import NetworkAddress
|
||||
from infection_monkey.network.tools import get_interface_to_target
|
||||
from infection_monkey.tunnel import MonkeyTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -12,12 +11,10 @@ logger = logging.getLogger(__name__)
|
|||
class VictimHostFactory:
|
||||
def __init__(
|
||||
self,
|
||||
tunnel: Optional[MonkeyTunnel],
|
||||
island_ip: Optional[str],
|
||||
island_port: Optional[str],
|
||||
on_island: bool,
|
||||
):
|
||||
self.tunnel = tunnel
|
||||
self.island_ip = island_ip
|
||||
self.island_port = island_port
|
||||
self.on_island = on_island
|
||||
|
@ -26,19 +23,15 @@ class VictimHostFactory:
|
|||
domain = network_address.domain or ""
|
||||
victim_host = VictimHost(network_address.ip, domain)
|
||||
|
||||
if self.tunnel:
|
||||
victim_host.default_tunnel = self.tunnel.get_tunnel_for_ip(victim_host.ip_addr)
|
||||
|
||||
if self.island_ip:
|
||||
ip, port = self._choose_island_address(victim_host.ip_addr)
|
||||
victim_host.set_island_address(ip, port)
|
||||
|
||||
logger.debug(f"Default tunnel for {victim_host} set to {victim_host.default_tunnel}")
|
||||
logger.debug(f"Default server for {victim_host} set to {victim_host.default_server}")
|
||||
|
||||
return victim_host
|
||||
|
||||
def _choose_island_address(self, victim_ip: str) -> Tuple[str, Optional[str]]:
|
||||
def _choose_island_address(self, victim_ip: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
# Victims need to connect back to the interface they can reach
|
||||
# On island, choose the right interface to pass to children monkeys
|
||||
if self.on_island:
|
||||
|
|
|
@ -381,7 +381,7 @@ class InfectionMonkey:
|
|||
on_island = self._running_on_island(local_network_interfaces)
|
||||
logger.debug(f"This agent is running on the island: {on_island}")
|
||||
|
||||
return VictimHostFactory(None, self._cmd_island_ip, self._cmd_island_port, on_island)
|
||||
return VictimHostFactory(self._cmd_island_ip, self._cmd_island_port, on_island)
|
||||
|
||||
def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool:
|
||||
server_ip, _ = address_to_ip_port(self._control_client.server_address)
|
||||
|
|
|
@ -1,230 +0,0 @@
|
|||
import logging
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from threading import Event, Thread
|
||||
|
||||
from common.utils import Timer
|
||||
from infection_monkey.network.firewall import app as firewall
|
||||
from infection_monkey.network.info import get_free_tcp_port, local_ips
|
||||
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
|
||||
from infection_monkey.transport.base import get_last_serve_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MCAST_GROUP = "224.1.1.1"
|
||||
MCAST_PORT = 5007
|
||||
BUFFER_READ = 1024
|
||||
DEFAULT_TIMEOUT = 10
|
||||
QUIT_TIMEOUT = 60 * 10 # 10 minutes
|
||||
|
||||
|
||||
def _set_multicast_socket(timeout=DEFAULT_TIMEOUT, adapter=""):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
|
||||
sock.settimeout(timeout)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind((adapter, MCAST_PORT))
|
||||
sock.setsockopt(
|
||||
socket.IPPROTO_IP,
|
||||
socket.IP_ADD_MEMBERSHIP,
|
||||
struct.pack("4sl", socket.inet_aton(MCAST_GROUP), socket.INADDR_ANY),
|
||||
)
|
||||
return sock
|
||||
|
||||
|
||||
def _check_tunnel(address, port, existing_sock=None):
|
||||
if not existing_sock:
|
||||
sock = _set_multicast_socket()
|
||||
else:
|
||||
sock = existing_sock
|
||||
|
||||
logger.debug("Checking tunnel %s:%s", address, port)
|
||||
is_open, _ = check_tcp_port(address, int(port))
|
||||
if not is_open:
|
||||
logger.debug("Could not connect to %s:%s", address, port)
|
||||
if not existing_sock:
|
||||
sock.close()
|
||||
return False
|
||||
|
||||
try:
|
||||
sock.sendto(b"+", (address, MCAST_PORT))
|
||||
except Exception as exc:
|
||||
logger.debug("Caught exception in tunnel registration: %s", exc)
|
||||
|
||||
if not existing_sock:
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
|
||||
def find_tunnel(default=None, attempts=3, timeout=DEFAULT_TIMEOUT):
|
||||
l_ips = local_ips()
|
||||
|
||||
if default:
|
||||
if default.find(":") != -1:
|
||||
address, port = default.split(":", 1)
|
||||
if _check_tunnel(address, port):
|
||||
return address, port
|
||||
|
||||
for adapter in l_ips:
|
||||
for attempt in range(0, attempts):
|
||||
try:
|
||||
logger.info("Trying to find using adapter %s", adapter)
|
||||
sock = _set_multicast_socket(timeout, adapter)
|
||||
sock.sendto(b"?", (MCAST_GROUP, MCAST_PORT))
|
||||
tunnels = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
answer, address = sock.recvfrom(BUFFER_READ)
|
||||
if answer not in [b"?", b"+", b"-"]:
|
||||
tunnels.append(answer)
|
||||
except socket.timeout:
|
||||
break
|
||||
|
||||
for tunnel in tunnels:
|
||||
if tunnel.find(":") != -1:
|
||||
address, port = tunnel.split(":", 1)
|
||||
if address in l_ips:
|
||||
continue
|
||||
|
||||
if _check_tunnel(address, port, sock):
|
||||
sock.close()
|
||||
return address, port
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("Caught exception in tunnel lookup: %s", exc)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def quit_tunnel(address, timeout=DEFAULT_TIMEOUT):
|
||||
try:
|
||||
sock = _set_multicast_socket(timeout)
|
||||
sock.sendto(b"-", (address, MCAST_PORT))
|
||||
sock.close()
|
||||
logger.debug("Success quitting tunnel")
|
||||
except Exception as exc:
|
||||
logger.debug("Exception quitting tunnel: %s", exc)
|
||||
return
|
||||
|
||||
|
||||
class MonkeyTunnel(Thread):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_class,
|
||||
keep_tunnel_open_time,
|
||||
target_addr=None,
|
||||
target_port=None,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
):
|
||||
self._target_addr = target_addr
|
||||
self._target_port = target_port
|
||||
self._proxy_class = proxy_class
|
||||
self._keep_tunnel_open_time = keep_tunnel_open_time
|
||||
self._broad_sock = None
|
||||
self._timeout = timeout
|
||||
self._stopped = Event()
|
||||
self._clients = []
|
||||
self.local_port = None
|
||||
super(MonkeyTunnel, self).__init__(name="MonkeyTunnelThread")
|
||||
self.daemon = True
|
||||
self.l_ips = None
|
||||
self._wait_for_exploited_machines = Event()
|
||||
|
||||
def run(self):
|
||||
self._broad_sock = _set_multicast_socket(self._timeout)
|
||||
self.l_ips = local_ips()
|
||||
self.local_port = get_free_tcp_port()
|
||||
|
||||
if not self.local_port:
|
||||
return
|
||||
|
||||
if not firewall.listen_allowed(localport=self.local_port):
|
||||
logger.info("Machine firewalled, listen not allowed, not running tunnel.")
|
||||
return
|
||||
|
||||
proxy = self._proxy_class(
|
||||
local_port=self.local_port, dest_host=self._target_addr, dest_port=self._target_port
|
||||
)
|
||||
logger.info(
|
||||
"Running tunnel using proxy class: %s, listening on port %s, routing to: %s:%s",
|
||||
proxy.__class__.__name__,
|
||||
self.local_port,
|
||||
self._target_addr,
|
||||
self._target_port,
|
||||
)
|
||||
proxy.start()
|
||||
|
||||
while not self._stopped.is_set():
|
||||
try:
|
||||
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
||||
if b"?" == search:
|
||||
ip_match = get_interface_to_target(address[0])
|
||||
if ip_match:
|
||||
answer = "%s:%d" % (ip_match, self.local_port)
|
||||
logger.debug(
|
||||
"Got tunnel request from %s, answering with %s", address[0], answer
|
||||
)
|
||||
self._broad_sock.sendto(answer.encode(), (address[0], MCAST_PORT))
|
||||
elif b"+" == search:
|
||||
if not address[0] in self._clients:
|
||||
logger.debug("Tunnel control: Added %s to watchlist", address[0])
|
||||
self._clients.append(address[0])
|
||||
elif b"-" == search:
|
||||
logger.debug("Tunnel control: Removed %s from watchlist", address[0])
|
||||
self._clients = [client for client in self._clients if client != address[0]]
|
||||
|
||||
except socket.timeout:
|
||||
continue
|
||||
|
||||
logger.info("Stopping tunnel, waiting for clients: %s" % repr(self._clients))
|
||||
|
||||
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
|
||||
# QUIT_TIMEOUT seconds
|
||||
timer = Timer()
|
||||
timer.set(self._calculate_timeout())
|
||||
while self._clients and not timer.is_expired():
|
||||
try:
|
||||
search, address = self._broad_sock.recvfrom(BUFFER_READ)
|
||||
if b"-" == search:
|
||||
logger.debug("Tunnel control: Removed %s from watchlist", address[0])
|
||||
self._clients = [client for client in self._clients if client != address[0]]
|
||||
except socket.timeout:
|
||||
continue
|
||||
|
||||
timer.set(self._calculate_timeout())
|
||||
|
||||
logger.info("Closing tunnel")
|
||||
self._broad_sock.close()
|
||||
proxy.stop()
|
||||
proxy.join()
|
||||
|
||||
def _calculate_timeout(self) -> float:
|
||||
try:
|
||||
return QUIT_TIMEOUT - (time.time() - get_last_serve_time())
|
||||
except TypeError: # get_last_serve_time() may return None
|
||||
return 0.0
|
||||
|
||||
def get_tunnel_for_ip(self, ip: str):
|
||||
|
||||
if not self.local_port:
|
||||
return
|
||||
|
||||
ip_match = get_interface_to_target(ip)
|
||||
return "%s:%d" % (ip_match, self.local_port)
|
||||
|
||||
def set_wait_for_exploited_machines(self):
|
||||
self._wait_for_exploited_machines.set()
|
||||
|
||||
def stop(self):
|
||||
self._wait_for_exploited_machine_connection()
|
||||
self._stopped.set()
|
||||
|
||||
def _wait_for_exploited_machine_connection(self):
|
||||
if self._wait_for_exploited_machines.is_set():
|
||||
logger.info(
|
||||
f"Waiting {self._keep_tunnel_open_time} seconds for exploited machines to connect "
|
||||
"to the tunnel."
|
||||
)
|
||||
time.sleep(self._keep_tunnel_open_time)
|
|
@ -1,18 +1,9 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.model import VictimHostFactory
|
||||
from infection_monkey.network import NetworkAddress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tunnel():
|
||||
tunnel = MagicMock()
|
||||
tunnel.get_tunnel_for_ip = lambda _: "1.2.3.4:1234"
|
||||
return tunnel
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_get_interface_to_target(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
|
@ -21,9 +12,7 @@ def mock_get_interface_to_target(monkeypatch):
|
|||
|
||||
|
||||
def test_factory_no_tunnel():
|
||||
factory = VictimHostFactory(
|
||||
tunnel=None, island_ip="192.168.56.1", island_port="5000", on_island=False
|
||||
)
|
||||
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="5000", on_island=False)
|
||||
network_address = NetworkAddress("192.168.56.2", None)
|
||||
|
||||
victim = factory.build_victim_host(network_address)
|
||||
|
@ -34,24 +23,8 @@ def test_factory_no_tunnel():
|
|||
assert victim.domain_name == ""
|
||||
|
||||
|
||||
def test_factory_with_tunnel(mock_tunnel):
|
||||
factory = VictimHostFactory(
|
||||
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="5000", on_island=False
|
||||
)
|
||||
network_address = NetworkAddress("192.168.56.2", None)
|
||||
|
||||
victim = factory.build_victim_host(network_address)
|
||||
|
||||
assert victim.default_server == "192.168.56.1:5000"
|
||||
assert victim.ip_addr == "192.168.56.2"
|
||||
assert victim.default_tunnel == "1.2.3.4:1234"
|
||||
assert victim.domain_name == ""
|
||||
|
||||
|
||||
def test_factory_on_island(mock_tunnel):
|
||||
factory = VictimHostFactory(
|
||||
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port="99", on_island=True
|
||||
)
|
||||
def test_factory_on_island():
|
||||
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="99", on_island=True)
|
||||
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
|
||||
|
||||
victim = factory.build_victim_host(network_address)
|
||||
|
@ -63,10 +36,8 @@ def test_factory_on_island(mock_tunnel):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("default_port", ["", None])
|
||||
def test_factory_no_port(mock_tunnel, default_port):
|
||||
factory = VictimHostFactory(
|
||||
tunnel=mock_tunnel, island_ip="192.168.56.1", island_port=default_port, on_island=True
|
||||
)
|
||||
def test_factory_no_port(default_port):
|
||||
factory = VictimHostFactory(island_ip="192.168.56.1", island_port=default_port, on_island=True)
|
||||
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
|
||||
|
||||
victim = factory.build_victim_host(network_address)
|
||||
|
@ -74,8 +45,8 @@ def test_factory_no_port(mock_tunnel, default_port):
|
|||
assert victim.default_server == "1.1.1.1"
|
||||
|
||||
|
||||
def test_factory_no_default_server(mock_tunnel):
|
||||
factory = VictimHostFactory(tunnel=mock_tunnel, island_ip=None, island_port="", on_island=True)
|
||||
def test_factory_no_default_server():
|
||||
factory = VictimHostFactory(island_ip=None, island_port="", on_island=True)
|
||||
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
|
||||
|
||||
victim = factory.build_victim_host(network_address)
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from monkey.infection_monkey.control import ControlClient
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_windows_os,expected_proxy_string",
|
||||
[(True, "http://8.8.8.8:45455"), (False, "8.8.8.8:45455")],
|
||||
)
|
||||
def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
|
||||
monkeypatch.setattr("monkey.infection_monkey.control.is_windows_os", lambda: is_windows_os)
|
||||
control_client = ControlClient("8.8.8.8:5000")
|
||||
|
||||
control_client.set_proxies(("8.8.8.8", "45455"))
|
||||
|
||||
assert control_client.proxies["https"] == expected_proxy_string
|
|
@ -299,9 +299,6 @@ event
|
|||
deserialize
|
||||
serialized_event
|
||||
|
||||
# TODO: Remove when removing Tunnel code
|
||||
create_control_tunnel
|
||||
set_wait_for_exploited_machines
|
||||
|
||||
# pydantic base models
|
||||
underscore_attrs_are_private
|
||||
|
|
Loading…
Reference in New Issue