Merge pull request #2251 from guardicore/2216-use-tcprelay-in-agent

2216 use tcprelay in agent
This commit is contained in:
Mike Salvatore 2022-09-07 15:51:32 -04:00
commit f3ff4176b2
13 changed files with 148 additions and 61 deletions

View File

@ -2,6 +2,13 @@ from typing import Optional, Tuple
def address_to_ip_port(address: str) -> Tuple[str, Optional[str]]:
"""
Split a string containing an IP address (and optionally a port) into IP and Port components.
Currently only works for IPv4 addresses.
:param address: The address string.
:return: Tuple of IP and port strings. The port may be None if no port was in the address.
"""
if ":" in address:
ip, port = address.split(":")
return ip, port or None

View File

@ -176,7 +176,7 @@ class AutomatedMaster(IMaster):
current_depth = self._current_depth if self._current_depth is not None else 0
logger.info(f"Current depth is {current_depth}")
if maximum_depth_reached(config.propagation.maximum_depth, current_depth):
if not maximum_depth_reached(config.propagation.maximum_depth, current_depth):
self._propagator.propagate(config.propagation, current_depth, self._servers, self._stop)
else:
logger.info("Skipping propagation: maximum depth reached")

View File

@ -3,7 +3,7 @@ import logging
import os
import subprocess
import sys
from ipaddress import IPv4Interface
from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath
from typing import List
@ -41,7 +41,8 @@ from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_network_interfaces
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
@ -100,11 +101,10 @@ class InfectionMonkey:
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
self._monkey_inbound_tunnel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
self._master = None
self._inbound_tunnel_opened = False
self._relay: TCPRelay
@staticmethod
def _get_arguments(args):
@ -180,14 +180,17 @@ class InfectionMonkey:
control_channel.register_agent(self._opts.parent)
config = control_channel.get_config()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel(
config.keep_tunnel_open_time
relay_port = get_free_tcp_port()
self._relay = TCPRelay(
relay_port,
IPv4Address(self._cmd_island_ip),
self._cmd_island_port,
client_disconnect_timeout=config.keep_tunnel_open_time,
)
if self._monkey_inbound_tunnel and maximum_depth_reached(
config.propagation.maximum_depth, self._current_depth
):
self._inbound_tunnel_opened = True
self._monkey_inbound_tunnel.start()
if not maximum_depth_reached(config.propagation.maximum_depth, self._current_depth):
self._relay.start()
StateTelem(is_done=False, version=get_version()).send()
TunnelTelem(self._control_client.proxies).send()
@ -215,7 +218,7 @@ class InfectionMonkey:
victim_host_factory = self._build_victim_host_factory(local_network_interfaces)
telemetry_messenger = ExploitInterceptingTelemetryMessenger(
self._telemetry_messenger, self._monkey_inbound_tunnel
self._telemetry_messenger, self._relay
)
self._master = AutomatedMaster(
@ -374,9 +377,7 @@ class InfectionMonkey:
on_island = self._running_on_island(local_network_interfaces)
logger.debug(f"This agent is running on the island: {on_island}")
return VictimHostFactory(
self._monkey_inbound_tunnel, self._cmd_island_ip, self._cmd_island_port, on_island
)
return VictimHostFactory(None, self._cmd_island_ip, self._cmd_island_port, on_island)
def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool:
server_ip, _ = address_to_ip_port(self._control_client.server_address)
@ -394,9 +395,9 @@ class InfectionMonkey:
reset_signal_handlers()
if self._inbound_tunnel_opened:
self._monkey_inbound_tunnel.stop()
self._monkey_inbound_tunnel.join()
if self._relay and self._relay.is_alive():
self._relay.stop()
self._relay.join(timeout=60)
if firewall.is_enabled():
firewall.remove_firewall_rule()

View File

@ -3,7 +3,7 @@ import socket
import struct
from dataclasses import dataclass
from ipaddress import IPv4Interface
from random import randint # noqa: DUO102
from random import shuffle # noqa: DUO102
from typing import List
import netifaces
@ -11,6 +11,8 @@ import psutil
from infection_monkey.utils.environment import is_windows_os
from .ports import COMMON_PORTS
# Timeout for monkey connections
LOOPBACK_NAME = b"lo"
SIOCGIFADDR = 0x8915 # get PA address
@ -119,14 +121,18 @@ else:
def get_free_tcp_port(min_range=1024, max_range=65535):
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
for port in COMMON_PORTS:
if port not in in_use:
return port
min_range = max(1, min_range)
max_range = min(65535, max_range)
in_use = [conn.laddr[1] for conn in psutil.net_connections()]
for i in range(min_range, max_range):
port = randint(min_range, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if port not in in_use:
return port

View File

@ -0,0 +1,15 @@
from typing import List
COMMON_PORTS: List[int] = [
1025, # NFS, IIS
1433, # Microsoft SQL Server
1434, # Microsoft SQL Monitor
1720, # h323q931
1723, # Microsoft PPTP VPN
3306, # mysql
3389, # Windows Terminal Server (RDP)
5900, # vnc
6001, # X11:1
8080, # http-proxy
8888, # sun-answerbook
]

View File

@ -3,3 +3,4 @@ from .relay_user_handler import RelayUser, RelayUserHandler
from .sockets_pipe import SocketsPipe
from .tcp_connection_handler import TCPConnectionHandler
from .tcp_pipe_spawner import TCPPipeSpawner
from .tcp_relay import TCPRelay

View File

@ -1,7 +1,13 @@
from ipaddress import IPv4Address
from threading import Lock, Thread
from time import sleep
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
from infection_monkey.network.relay import (
RelayConnectionHandler,
RelayUserHandler,
TCPConnectionHandler,
TCPPipeSpawner,
)
from infection_monkey.utils.threading import InterruptableThreadMixin
@ -12,13 +18,21 @@ class TCPRelay(Thread, InterruptableThreadMixin):
def __init__(
self,
relay_user_handler: RelayUserHandler,
connection_handler: TCPConnectionHandler,
pipe_spawner: TCPPipeSpawner,
relay_port: int,
dest_addr: IPv4Address,
dest_port: int,
client_disconnect_timeout: float,
):
self._user_handler = relay_user_handler
self._connection_handler = connection_handler
self._pipe_spawner = pipe_spawner
self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
self._connection_handler = TCPConnectionHandler(
bind_host="",
bind_port=relay_port,
client_connected=[
relay_filter.handle_new_connection,
],
)
super().__init__(name="MonkeyTcpRelayThread", daemon=True)
self._lock = Lock()
@ -32,6 +46,14 @@ class TCPRelay(Thread, InterruptableThreadMixin):
self._connection_handler.join()
self._wait_for_pipes_to_close()
def add_potential_user(self, user_address: IPv4Address):
"""
Notify TCPRelay of a user that may try to connect.
:param user_address: The address of the potential new user.
"""
self._user_handler.add_potential_user(user_address)
def _wait_for_users_to_disconnect(self):
"""
Blocks until the users disconnect or the timeout has elapsed.

View File

@ -1,23 +1,19 @@
from functools import singledispatch
from ipaddress import IPv4Address
from infection_monkey.network.relay.tcp_relay import TCPRelay
from infection_monkey.network.relay import TCPRelay
from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.tunnel import MonkeyTunnel
class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger):
def __init__(
self, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel, relay: TCPRelay
):
def __init__(self, telemetry_messenger: ITelemetryMessenger, relay: TCPRelay):
self._telemetry_messenger = telemetry_messenger
self._tunnel = tunnel
self._relay = relay
def send_telemetry(self, telemetry: ITelem):
_send_telemetry(telemetry, self._telemetry_messenger, self._tunnel, self._relay)
_send_telemetry(telemetry, self._telemetry_messenger, self._relay)
# Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or
@ -26,7 +22,6 @@ class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger):
def _send_telemetry(
telemetry: ITelem,
telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay,
):
telemetry_messenger.send_telemetry(telemetry)
@ -36,12 +31,11 @@ def _send_telemetry(
def _(
telemetry: ExploitTelem,
telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay,
):
if telemetry.propagation_result is True:
tunnel.set_wait_for_exploited_machines()
if relay:
relay.add_potential_user(IPv4Address(telemetry.host["ip_addr"]))
address = IPv4Address(str(telemetry.host["ip_addr"]))
relay.add_potential_user(address)
telemetry_messenger.send_telemetry(telemetry)

View File

@ -1,2 +1,10 @@
def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool:
return maximum_depth > current_depth
"""
Return whether or not the current depth has eclipsed the maximum depth.
Values are nonnegative. Depth should increase from zero.
:param maximum_depth: The maximum depth.
:param current_depth: The current depth.
:return: True if the current depth has reached the maximum depth, otherwise False.
"""
return current_depth >= maximum_depth

View File

@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Tuple
import pytest
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.ports import COMMON_PORTS
@dataclass
class Connection:
laddr: Tuple[str, int]
@pytest.mark.parametrize("port", COMMON_PORTS)
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is port
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is not None
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in range(65535)]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is None

View File

@ -20,49 +20,43 @@ class MockExploitTelem(ExploitTelem):
def test_generic_telemetry(TestTelem):
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()
telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)
telemetry_messenger.send_telemetry(TestTelem())
assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called
def test_propagation_successful_exploit_telemetry():
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(True)
telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)
telemetry_messenger.send_telemetry(mock_exploit_telem)
assert mock_telemetry_messenger.send_telemetry.called
assert mock_tunnel.set_wait_for_exploited_machines.called
assert mock_relay.add_potential_user.called
def test_propagation_failed_exploit_telemetry():
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(False)
telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)
telemetry_messenger.send_telemetry(mock_exploit_telem)
assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called

View File

@ -5,18 +5,18 @@ def test_maximum_depth_reached__current_less_than_max():
maximum_depth = 2
current_depth = 1
assert maximum_depth_reached(maximum_depth, current_depth) is True
assert maximum_depth_reached(maximum_depth, current_depth) is False
def test_maximum_depth_reached__current_greater_than_max():
maximum_depth = 2
current_depth = 3
assert maximum_depth_reached(maximum_depth, current_depth) is False
assert maximum_depth_reached(maximum_depth, current_depth) is True
def test_maximum_depth_reached__current_equal_to_max():
maximum_depth = 2
current_depth = maximum_depth
assert maximum_depth_reached(maximum_depth, current_depth) is False
assert maximum_depth_reached(maximum_depth, current_depth) is True

View File

@ -299,10 +299,9 @@ event
deserialize
serialized_event
# TODO: Remove after #2231 is closed
relay_users
last_update_time
add_relay_user
# TODO: Remove when removing Tunnel code
create_control_tunnel
set_wait_for_exploited_machines
# pydantic base models
underscore_attrs_are_private