Merge pull request #2251 from guardicore/2216-use-tcprelay-in-agent

2216 use tcprelay in agent
This commit is contained in:
Mike Salvatore 2022-09-07 15:51:32 -04:00
commit f3ff4176b2
13 changed files with 148 additions and 61 deletions

View File

@ -2,6 +2,13 @@ from typing import Optional, Tuple
def address_to_ip_port(address: str) -> Tuple[str, Optional[str]]: def address_to_ip_port(address: str) -> Tuple[str, Optional[str]]:
"""
Split a string containing an IP address (and optionally a port) into IP and Port components.
Currently only works for IPv4 addresses.
:param address: The address string.
:return: Tuple of IP and port strings. The port may be None if no port was in the address.
"""
if ":" in address: if ":" in address:
ip, port = address.split(":") ip, port = address.split(":")
return ip, port or None return ip, port or None

View File

@ -176,7 +176,7 @@ class AutomatedMaster(IMaster):
current_depth = self._current_depth if self._current_depth is not None else 0 current_depth = self._current_depth if self._current_depth is not None else 0
logger.info(f"Current depth is {current_depth}") logger.info(f"Current depth is {current_depth}")
if maximum_depth_reached(config.propagation.maximum_depth, current_depth): if not maximum_depth_reached(config.propagation.maximum_depth, current_depth):
self._propagator.propagate(config.propagation, current_depth, self._servers, self._stop) self._propagator.propagate(config.propagation, current_depth, self._servers, self._stop)
else: else:
logger.info("Skipping propagation: maximum depth reached") logger.info("Skipping propagation: maximum depth reached")

View File

@ -3,7 +3,7 @@ import logging
import os import os
import subprocess import subprocess
import sys import sys
from ipaddress import IPv4Interface from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List from typing import List
@ -41,7 +41,8 @@ from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
from infection_monkey.network.firewall import app as firewall from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_network_interfaces from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
@ -100,11 +101,10 @@ class InfectionMonkey:
# TODO Refactor the telemetry messengers to accept control client # TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object # and remove control_client_object
ControlClient.control_client_object = self._control_client ControlClient.control_client_object = self._control_client
self._monkey_inbound_tunnel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth self._current_depth = self._opts.depth
self._master = None self._master = None
self._inbound_tunnel_opened = False self._relay: TCPRelay
@staticmethod @staticmethod
def _get_arguments(args): def _get_arguments(args):
@ -180,14 +180,17 @@ class InfectionMonkey:
control_channel.register_agent(self._opts.parent) control_channel.register_agent(self._opts.parent)
config = control_channel.get_config() config = control_channel.get_config()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel(
config.keep_tunnel_open_time relay_port = get_free_tcp_port()
self._relay = TCPRelay(
relay_port,
IPv4Address(self._cmd_island_ip),
self._cmd_island_port,
client_disconnect_timeout=config.keep_tunnel_open_time,
) )
if self._monkey_inbound_tunnel and maximum_depth_reached(
config.propagation.maximum_depth, self._current_depth if not maximum_depth_reached(config.propagation.maximum_depth, self._current_depth):
): self._relay.start()
self._inbound_tunnel_opened = True
self._monkey_inbound_tunnel.start()
StateTelem(is_done=False, version=get_version()).send() StateTelem(is_done=False, version=get_version()).send()
TunnelTelem(self._control_client.proxies).send() TunnelTelem(self._control_client.proxies).send()
@ -215,7 +218,7 @@ class InfectionMonkey:
victim_host_factory = self._build_victim_host_factory(local_network_interfaces) victim_host_factory = self._build_victim_host_factory(local_network_interfaces)
telemetry_messenger = ExploitInterceptingTelemetryMessenger( telemetry_messenger = ExploitInterceptingTelemetryMessenger(
self._telemetry_messenger, self._monkey_inbound_tunnel self._telemetry_messenger, self._relay
) )
self._master = AutomatedMaster( self._master = AutomatedMaster(
@ -374,9 +377,7 @@ class InfectionMonkey:
on_island = self._running_on_island(local_network_interfaces) on_island = self._running_on_island(local_network_interfaces)
logger.debug(f"This agent is running on the island: {on_island}") logger.debug(f"This agent is running on the island: {on_island}")
return VictimHostFactory( return VictimHostFactory(None, self._cmd_island_ip, self._cmd_island_port, on_island)
self._monkey_inbound_tunnel, self._cmd_island_ip, self._cmd_island_port, on_island
)
def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool: def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool:
server_ip, _ = address_to_ip_port(self._control_client.server_address) server_ip, _ = address_to_ip_port(self._control_client.server_address)
@ -394,9 +395,9 @@ class InfectionMonkey:
reset_signal_handlers() reset_signal_handlers()
if self._inbound_tunnel_opened: if self._relay and self._relay.is_alive():
self._monkey_inbound_tunnel.stop() self._relay.stop()
self._monkey_inbound_tunnel.join() self._relay.join(timeout=60)
if firewall.is_enabled(): if firewall.is_enabled():
firewall.remove_firewall_rule() firewall.remove_firewall_rule()

View File

@ -3,7 +3,7 @@ import socket
import struct import struct
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from random import randint # noqa: DUO102 from random import shuffle # noqa: DUO102
from typing import List from typing import List
import netifaces import netifaces
@ -11,6 +11,8 @@ import psutil
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from .ports import COMMON_PORTS
# Timeout for monkey connections # Timeout for monkey connections
LOOPBACK_NAME = b"lo" LOOPBACK_NAME = b"lo"
SIOCGIFADDR = 0x8915 # get PA address SIOCGIFADDR = 0x8915 # get PA address
@ -119,14 +121,18 @@ else:
def get_free_tcp_port(min_range=1024, max_range=65535): def get_free_tcp_port(min_range=1024, max_range=65535):
in_use = {conn.laddr[1] for conn in psutil.net_connections()}
for port in COMMON_PORTS:
if port not in in_use:
return port
min_range = max(1, min_range) min_range = max(1, min_range)
max_range = min(65535, max_range) max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
in_use = [conn.laddr[1] for conn in psutil.net_connections()] shuffle(ports)
for port in ports:
for i in range(min_range, max_range):
port = randint(min_range, max_range)
if port not in in_use: if port not in in_use:
return port return port

View File

@ -0,0 +1,15 @@
from typing import List
COMMON_PORTS: List[int] = [
1025, # NFS, IIS
1433, # Microsoft SQL Server
1434, # Microsoft SQL Monitor
1720, # h323q931
1723, # Microsoft PPTP VPN
3306, # mysql
3389, # Windows Terminal Server (RDP)
5900, # vnc
6001, # X11:1
8080, # http-proxy
8888, # sun-answerbook
]

View File

@ -3,3 +3,4 @@ from .relay_user_handler import RelayUser, RelayUserHandler
from .sockets_pipe import SocketsPipe from .sockets_pipe import SocketsPipe
from .tcp_connection_handler import TCPConnectionHandler from .tcp_connection_handler import TCPConnectionHandler
from .tcp_pipe_spawner import TCPPipeSpawner from .tcp_pipe_spawner import TCPPipeSpawner
from .tcp_relay import TCPRelay

View File

@ -1,7 +1,13 @@
from ipaddress import IPv4Address
from threading import Lock, Thread from threading import Lock, Thread
from time import sleep from time import sleep
from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner from infection_monkey.network.relay import (
RelayConnectionHandler,
RelayUserHandler,
TCPConnectionHandler,
TCPPipeSpawner,
)
from infection_monkey.utils.threading import InterruptableThreadMixin from infection_monkey.utils.threading import InterruptableThreadMixin
@ -12,13 +18,21 @@ class TCPRelay(Thread, InterruptableThreadMixin):
def __init__( def __init__(
self, self,
relay_user_handler: RelayUserHandler, relay_port: int,
connection_handler: TCPConnectionHandler, dest_addr: IPv4Address,
pipe_spawner: TCPPipeSpawner, dest_port: int,
client_disconnect_timeout: float,
): ):
self._user_handler = relay_user_handler self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
self._connection_handler = connection_handler self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
self._pipe_spawner = pipe_spawner relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
self._connection_handler = TCPConnectionHandler(
bind_host="",
bind_port=relay_port,
client_connected=[
relay_filter.handle_new_connection,
],
)
super().__init__(name="MonkeyTcpRelayThread", daemon=True) super().__init__(name="MonkeyTcpRelayThread", daemon=True)
self._lock = Lock() self._lock = Lock()
@ -32,6 +46,14 @@ class TCPRelay(Thread, InterruptableThreadMixin):
self._connection_handler.join() self._connection_handler.join()
self._wait_for_pipes_to_close() self._wait_for_pipes_to_close()
def add_potential_user(self, user_address: IPv4Address):
"""
Notify TCPRelay of a user that may try to connect.
:param user_address: The address of the potential new user.
"""
self._user_handler.add_potential_user(user_address)
def _wait_for_users_to_disconnect(self): def _wait_for_users_to_disconnect(self):
""" """
Blocks until the users disconnect or the timeout has elapsed. Blocks until the users disconnect or the timeout has elapsed.

View File

@ -1,23 +1,19 @@
from functools import singledispatch from functools import singledispatch
from ipaddress import IPv4Address from ipaddress import IPv4Address
from infection_monkey.network.relay.tcp_relay import TCPRelay from infection_monkey.network.relay import TCPRelay
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.tunnel import MonkeyTunnel
class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger): class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger):
def __init__( def __init__(self, telemetry_messenger: ITelemetryMessenger, relay: TCPRelay):
self, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel, relay: TCPRelay
):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
self._tunnel = tunnel
self._relay = relay self._relay = relay
def send_telemetry(self, telemetry: ITelem): def send_telemetry(self, telemetry: ITelem):
_send_telemetry(telemetry, self._telemetry_messenger, self._tunnel, self._relay) _send_telemetry(telemetry, self._telemetry_messenger, self._relay)
# Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or # Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or
@ -26,7 +22,6 @@ class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger):
def _send_telemetry( def _send_telemetry(
telemetry: ITelem, telemetry: ITelem,
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay, relay: TCPRelay,
): ):
telemetry_messenger.send_telemetry(telemetry) telemetry_messenger.send_telemetry(telemetry)
@ -36,12 +31,11 @@ def _send_telemetry(
def _( def _(
telemetry: ExploitTelem, telemetry: ExploitTelem,
telemetry_messenger: ITelemetryMessenger, telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay, relay: TCPRelay,
): ):
if telemetry.propagation_result is True: if telemetry.propagation_result is True:
tunnel.set_wait_for_exploited_machines()
if relay: if relay:
relay.add_potential_user(IPv4Address(telemetry.host["ip_addr"])) address = IPv4Address(str(telemetry.host["ip_addr"]))
relay.add_potential_user(address)
telemetry_messenger.send_telemetry(telemetry) telemetry_messenger.send_telemetry(telemetry)

View File

@ -1,2 +1,10 @@
def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool: def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool:
return maximum_depth > current_depth """
Return whether or not the current depth has eclipsed the maximum depth.
Values are nonnegative. Depth should increase from zero.
:param maximum_depth: The maximum depth.
:param current_depth: The current depth.
:return: True if the current depth has reached the maximum depth, otherwise False.
"""
return current_depth >= maximum_depth

View File

@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Tuple
import pytest
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.ports import COMMON_PORTS
@dataclass
class Connection:
laddr: Tuple[str, int]
@pytest.mark.parametrize("port", COMMON_PORTS)
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is port
def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is not None
def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in range(65535)]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is None

View File

@ -20,49 +20,43 @@ class MockExploitTelem(ExploitTelem):
def test_generic_telemetry(TestTelem): def test_generic_telemetry(TestTelem):
mock_telemetry_messenger = MagicMock() mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock() mock_relay = MagicMock()
telemetry_messenger = ExploitInterceptingTelemetryMessenger( telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay mock_telemetry_messenger, mock_relay
) )
telemetry_messenger.send_telemetry(TestTelem()) telemetry_messenger.send_telemetry(TestTelem())
assert mock_telemetry_messenger.send_telemetry.called assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called assert not mock_relay.add_potential_user.called
def test_propagation_successful_exploit_telemetry(): def test_propagation_successful_exploit_telemetry():
mock_telemetry_messenger = MagicMock() mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock() mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(True) mock_exploit_telem = MockExploitTelem(True)
telemetry_messenger = ExploitInterceptingTelemetryMessenger( telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay mock_telemetry_messenger, mock_relay
) )
telemetry_messenger.send_telemetry(mock_exploit_telem) telemetry_messenger.send_telemetry(mock_exploit_telem)
assert mock_telemetry_messenger.send_telemetry.called assert mock_telemetry_messenger.send_telemetry.called
assert mock_tunnel.set_wait_for_exploited_machines.called
assert mock_relay.add_potential_user.called assert mock_relay.add_potential_user.called
def test_propagation_failed_exploit_telemetry(): def test_propagation_failed_exploit_telemetry():
mock_telemetry_messenger = MagicMock() mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock() mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(False) mock_exploit_telem = MockExploitTelem(False)
telemetry_messenger = ExploitInterceptingTelemetryMessenger( telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay mock_telemetry_messenger, mock_relay
) )
telemetry_messenger.send_telemetry(mock_exploit_telem) telemetry_messenger.send_telemetry(mock_exploit_telem)
assert mock_telemetry_messenger.send_telemetry.called assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called assert not mock_relay.add_potential_user.called

View File

@ -5,18 +5,18 @@ def test_maximum_depth_reached__current_less_than_max():
maximum_depth = 2 maximum_depth = 2
current_depth = 1 current_depth = 1
assert maximum_depth_reached(maximum_depth, current_depth) is True assert maximum_depth_reached(maximum_depth, current_depth) is False
def test_maximum_depth_reached__current_greater_than_max(): def test_maximum_depth_reached__current_greater_than_max():
maximum_depth = 2 maximum_depth = 2
current_depth = 3 current_depth = 3
assert maximum_depth_reached(maximum_depth, current_depth) is False assert maximum_depth_reached(maximum_depth, current_depth) is True
def test_maximum_depth_reached__current_equal_to_max(): def test_maximum_depth_reached__current_equal_to_max():
maximum_depth = 2 maximum_depth = 2
current_depth = maximum_depth current_depth = maximum_depth
assert maximum_depth_reached(maximum_depth, current_depth) is False assert maximum_depth_reached(maximum_depth, current_depth) is True

View File

@ -299,10 +299,9 @@ event
deserialize deserialize
serialized_event serialized_event
# TODO: Remove after #2231 is closed # TODO: Remove when removing Tunnel code
relay_users create_control_tunnel
last_update_time set_wait_for_exploited_machines
add_relay_user
# pydantic base models # pydantic base models
underscore_attrs_are_private underscore_attrs_are_private