diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 26949ab8e..c6337148d 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -42,11 +42,7 @@ from infection_monkey.master.control_channel import ControlChannel from infection_monkey.model import VictimHostFactory from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces -from infection_monkey.network.relay import ( - build_tcprelay_deps, - RelayUserHandler, - TCPRelay, -) +from infection_monkey.network.relay import TCPRelay from infection_monkey.network.tools import connect from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter @@ -109,7 +105,6 @@ class InfectionMonkey: self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._current_depth = self._opts.depth self._master = None - self._relay_user_handler: RelayUserHandler self._relay: TCPRelay @staticmethod @@ -190,18 +185,12 @@ class InfectionMonkey: local_port = get_free_tcp_port() sock, ip_str, port = connect([self._opts.server]) sock.close() - user_handler, connection_handler, pipe_spawner = build_tcprelay_deps( + self._relay = TCPRelay( local_port, IPv4Address(ip_str), port, client_disconnect_timeout=config.keep_tunnel_open_time, ) - self._relay_user_handler = user_handler - self._relay = TCPRelay( - self._relay_user_handler, - connection_handler, - pipe_spawner, - ) if self._relay and maximum_depth_reached( config.propagation.maximum_depth, self._current_depth @@ -234,7 +223,7 @@ class InfectionMonkey: victim_host_factory = self._build_victim_host_factory(local_network_interfaces) telemetry_messenger = ExploitInterceptingTelemetryMessenger( - self._telemetry_messenger, self._relay_user_handler + self._telemetry_messenger, self._relay ) self._master = AutomatedMaster( diff --git a/monkey/infection_monkey/network/relay/__init__.py b/monkey/infection_monkey/network/relay/__init__.py index 8a4ec4731..563b972c4 100644 --- a/monkey/infection_monkey/network/relay/__init__.py +++ b/monkey/infection_monkey/network/relay/__init__.py @@ -4,4 +4,3 @@ from .sockets_pipe import SocketsPipe from .tcp_connection_handler import TCPConnectionHandler from .tcp_pipe_spawner import TCPPipeSpawner from .tcp_relay import TCPRelay -from .utils import build_tcprelay_deps diff --git a/monkey/infection_monkey/network/relay/tcp_relay.py b/monkey/infection_monkey/network/relay/tcp_relay.py index f605cdc9d..a8793d91f 100644 --- a/monkey/infection_monkey/network/relay/tcp_relay.py +++ b/monkey/infection_monkey/network/relay/tcp_relay.py @@ -1,7 +1,13 @@ +from ipaddress import IPv4Address from threading import Lock, Thread from time import sleep -from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner +from infection_monkey.network.relay import ( + RelayConnectionHandler, + RelayUserHandler, + TCPConnectionHandler, + TCPPipeSpawner, +) from infection_monkey.utils.threading import InterruptableThreadMixin @@ -12,13 +18,21 @@ class TCPRelay(Thread, InterruptableThreadMixin): def __init__( self, - relay_user_handler: RelayUserHandler, - connection_handler: TCPConnectionHandler, - pipe_spawner: TCPPipeSpawner, + local_port: int, + dest_addr: IPv4Address, + dest_port: int, + client_disconnect_timeout: float, ): - self._user_handler = relay_user_handler - self._connection_handler = connection_handler - self._pipe_spawner = pipe_spawner + self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout) + self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port) + relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler) + self._connection_handler = TCPConnectionHandler( + bind_host="", + bind_port=local_port, + client_connected=[ + relay_filter.handle_new_connection, + ], + ) super().__init__(name="MonkeyTcpRelayThread", daemon=True) self._lock = Lock() @@ -32,6 +46,14 @@ class TCPRelay(Thread, InterruptableThreadMixin): self._connection_handler.join() self._wait_for_pipes_to_close() + def add_potential_user(self, user_address: IPv4Address): + """ + Notify TCPRelay of a user that may try to connect. + + :param user_address: The address of the potential new user. + """ + self._user_handler.add_potential_user(user_address) + def _wait_for_users_to_disconnect(self): """ Blocks until the users disconnect or the timeout has elapsed. diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py deleted file mode 100644 index 81bb00203..000000000 --- a/monkey/infection_monkey/network/relay/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -from ipaddress import IPv4Address -from typing import Tuple - -from . import RelayConnectionHandler, RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner - - -def build_tcprelay_deps( - local_port: int, dest_addr: IPv4Address, dest_port: int, client_disconnect_timeout: float -) -> Tuple[RelayUserHandler, TCPPipeSpawner, TCPConnectionHandler]: - - relay_user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout) - pipe_spawner = TCPPipeSpawner(dest_addr, dest_port) - relay_filter = RelayConnectionHandler(pipe_spawner, relay_user_handler) - connection_handler = TCPConnectionHandler( - bind_host="", - bind_port=local_port, - client_connected=[ - relay_filter.handle_new_connection, - ], - ) - - return relay_user_handler, pipe_spawner, connection_handler diff --git a/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py b/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py index f151591f3..cb59aa0ae 100644 --- a/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py +++ b/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py @@ -1,21 +1,19 @@ from functools import singledispatch from ipaddress import IPv4Address -from infection_monkey.network.relay import RelayUserHandler +from infection_monkey.network.relay import TCPRelay from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.i_telem import ITelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger): - def __init__( - self, telemetry_messenger: ITelemetryMessenger, relay_user_handler: RelayUserHandler - ): + def __init__(self, telemetry_messenger: ITelemetryMessenger, relay: TCPRelay): self._telemetry_messenger = telemetry_messenger - self._relay_user_handler = relay_user_handler + self._relay = relay def send_telemetry(self, telemetry: ITelem): - _send_telemetry(telemetry, self._telemetry_messenger, self._relay_user_handler) + _send_telemetry(telemetry, self._telemetry_messenger, self._relay) # Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or @@ -24,7 +22,7 @@ class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger): def _send_telemetry( telemetry: ITelem, telemetry_messenger: ITelemetryMessenger, - relay_user_handler: RelayUserHandler, + relay: TCPRelay, ): telemetry_messenger.send_telemetry(telemetry) @@ -33,11 +31,11 @@ def _send_telemetry( def _( telemetry: ExploitTelem, telemetry_messenger: ITelemetryMessenger, - relay_user_handler: RelayUserHandler, + relay: TCPRelay, ): if telemetry.propagation_result is True: - if relay_user_handler: + if relay: address = IPv4Address(str(telemetry.host["ip_addr"])) - relay_user_handler.add_potential_user(address) + relay.add_potential_user(address) telemetry_messenger.send_telemetry(telemetry) diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py index 50533fc75..61ca1b971 100644 --- a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py @@ -20,43 +20,43 @@ class MockExploitTelem(ExploitTelem): def test_generic_telemetry(TestTelem): mock_telemetry_messenger = MagicMock() - mock_relay_user_handler = MagicMock() + mock_relay = MagicMock() telemetry_messenger = ExploitInterceptingTelemetryMessenger( - mock_telemetry_messenger, mock_relay_user_handler + mock_telemetry_messenger, mock_relay ) telemetry_messenger.send_telemetry(TestTelem()) assert mock_telemetry_messenger.send_telemetry.called - assert not mock_relay_user_handler.add_potential_user.called + assert not mock_relay.add_potential_user.called def test_propagation_successful_exploit_telemetry(): mock_telemetry_messenger = MagicMock() - mock_relay_user_handler = MagicMock() + mock_relay = MagicMock() mock_exploit_telem = MockExploitTelem(True) telemetry_messenger = ExploitInterceptingTelemetryMessenger( - mock_telemetry_messenger, mock_relay_user_handler + mock_telemetry_messenger, mock_relay ) telemetry_messenger.send_telemetry(mock_exploit_telem) assert mock_telemetry_messenger.send_telemetry.called - assert mock_relay_user_handler.add_potential_user.called + assert mock_relay.add_potential_user.called def test_propagation_failed_exploit_telemetry(): mock_telemetry_messenger = MagicMock() - mock_relay_user_handler = MagicMock() + mock_relay = MagicMock() mock_exploit_telem = MockExploitTelem(False) telemetry_messenger = ExploitInterceptingTelemetryMessenger( - mock_telemetry_messenger, mock_relay_user_handler + mock_telemetry_messenger, mock_relay ) telemetry_messenger.send_telemetry(mock_exploit_telem) assert mock_telemetry_messenger.send_telemetry.called - assert not mock_relay_user_handler.add_potential_user.called + assert not mock_relay.add_potential_user.called