diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 90a7b6078..d11de18fc 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -292,7 +292,12 @@ class ControlClient(object): proxy_class = HTTPConnectProxy target_addr, target_port = None, None - return tunnel.MonkeyTunnel(proxy_class, target_addr=target_addr, target_port=target_port) + return tunnel.MonkeyTunnel( + proxy_class, + keep_tunnel_open_time=WormConfiguration.keep_tunnel_open_time, + target_addr=target_addr, + target_port=target_port, + ) @staticmethod def get_pba_file(filename): diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index db0ba58c7..cb1c47f6d 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -3,7 +3,6 @@ import logging import os import subprocess import sys -import time from typing import List import infection_monkey.tunnel as tunnel @@ -36,6 +35,9 @@ from infection_monkey.puppet.puppet import Puppet from infection_monkey.system_singleton import SystemSingleton from infection_monkey.telemetry.attack.t1106_telem import T1106Telem from infection_monkey.telemetry.attack.t1107_telem import T1107Telem +from infection_monkey.telemetry.messengers.exploit_intercepting_telemetry_messenger import ( + ExploitInterceptingTelemetryMessenger, +) from infection_monkey.telemetry.messengers.legacy_telemetry_messenger_adapter import ( LegacyTelemetryMessengerAdapter, ) @@ -164,9 +166,13 @@ class InfectionMonkey: victim_host_factory = self._build_victim_host_factory(local_network_interfaces) + telemetry_messenger = ExploitInterceptingTelemetryMessenger( + self.telemetry_messenger, self._monkey_inbound_tunnel + ) + self._master = AutomatedMaster( puppet, - self.telemetry_messenger, + telemetry_messenger, victim_host_factory, ControlChannel(self._default_server, GUID), local_network_interfaces, @@ -237,7 +243,6 @@ class InfectionMonkey: def cleanup(self): logger.info("Monkey cleanup started") - self._wait_for_exploited_machine_connection() try: if self._master: self._master.cleanup() @@ -270,19 +275,6 @@ class InfectionMonkey: logger.info("Monkey is shutting down") - def _wait_for_exploited_machine_connection(self): - # TODO check for actual exploitation - machines_exploited = False - # if host was exploited, before continue to closing the tunnel ensure the exploited - # host had its chance to - # connect to the tunnel - if machines_exploited: - time_to_sleep = WormConfiguration.keep_tunnel_open_time - logger.info( - "Sleeping %d seconds for exploited machines to connect to tunnel", time_to_sleep - ) - time.sleep(time_to_sleep) - @staticmethod def _close_tunnel(): tunnel_address = ( diff --git a/monkey/infection_monkey/telemetry/messengers/__init__.py b/monkey/infection_monkey/telemetry/messengers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py b/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py new file mode 100644 index 000000000..3b92235fb --- /dev/null +++ b/monkey/infection_monkey/telemetry/messengers/exploit_intercepting_telemetry_messenger.py @@ -0,0 +1,30 @@ +from functools import singledispatch + +from infection_monkey.telemetry.i_telem import ITelem +from infection_monkey.telemetry.exploit_telem import ExploitTelem +from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger +from infection_monkey.tunnel import MonkeyTunnel + + +class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger): + def __init__(self, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel): + self._telemetry_messenger = telemetry_messenger + self._tunnel = tunnel + + def send_telemetry(self, telemetry: ITelem): + _send_telemetry(telemetry, self._telemetry_messenger, self._tunnel) + + +# Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or +# later. +@singledispatch +def _send_telemetry( + telemetry: ITelem, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel +): + telemetry_messenger.send_telemetry(telemetry) + + +@_send_telemetry.register +def _(telemetry: ExploitTelem, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel): + tunnel.set_wait_for_exploited_machines() + telemetry_messenger.send_telemetry(telemetry) diff --git a/monkey/infection_monkey/tunnel.py b/monkey/infection_monkey/tunnel.py index 4aa90e80f..769260d6b 100644 --- a/monkey/infection_monkey/tunnel.py +++ b/monkey/infection_monkey/tunnel.py @@ -2,7 +2,7 @@ import logging import socket import struct import time -from threading import Thread +from threading import Event, Thread from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, local_ips @@ -109,18 +109,27 @@ def quit_tunnel(address, timeout=DEFAULT_TIMEOUT): class MonkeyTunnel(Thread): - def __init__(self, proxy_class, target_addr=None, target_port=None, timeout=DEFAULT_TIMEOUT): + def __init__( + self, + proxy_class, + keep_tunnel_open_time, + target_addr=None, + target_port=None, + timeout=DEFAULT_TIMEOUT, + ): self._target_addr = target_addr self._target_port = target_port self._proxy_class = proxy_class + self._keep_tunnel_open_time = keep_tunnel_open_time self._broad_sock = None self._timeout = timeout - self._stopped = False + self._stopped = Event() self._clients = [] self.local_port = None super(MonkeyTunnel, self).__init__() self.daemon = True self.l_ips = None + self._wait_for_exploited_machines = Event() def run(self): self._broad_sock = _set_multicast_socket(self._timeout) @@ -146,7 +155,7 @@ class MonkeyTunnel(Thread): ) proxy.start() - while not self._stopped: + while not self._stopped.is_set(): try: search, address = self._broad_sock.recvfrom(BUFFER_READ) if b"?" == search: @@ -195,5 +204,17 @@ class MonkeyTunnel(Thread): ip_match = get_interface_to_target(ip) return "%s:%d" % (ip_match, self.local_port) + def set_wait_for_exploited_machines(self): + self._wait_for_exploited_machines.set() + def stop(self): - self._stopped = True + self._wait_for_exploited_machine_connection() + self._stopped.set() + + def _wait_for_exploited_machine_connection(self): + if self._wait_for_exploited_machines.is_set(): + logger.info( + f"Waiting {self._keep_tunnel_open_time} seconds for exploited machines to connect " + "to the tunnel." + ) + time.sleep(self._keep_tunnel_open_time) diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py new file mode 100644 index 000000000..c6b85df3e --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py @@ -0,0 +1,55 @@ +from unittest.mock import MagicMock + +from infection_monkey.telemetry.base_telem import BaseTelem +from infection_monkey.telemetry.exploit_telem import ExploitTelem +from infection_monkey.telemetry.i_telem import ITelem +from infection_monkey.telemetry.messengers.exploit_intercepting_telemetry_messenger import ( + ExploitInterceptingTelemetryMessenger, +) + + +class TestTelem(BaseTelem): + telem_category = None + + def __init__(self): + pass + + def get_data(self): + return {} + + +class MockExpliotTelem(ExploitTelem): + def __init__(self): + pass + + def get_data(self): + return {} + + +def test_generic_telemetry(): + mock_telemetry_messenger = MagicMock() + mock_tunnel = MagicMock() + + telemetry_messenger = ExploitInterceptingTelemetryMessenger( + mock_telemetry_messenger, mock_tunnel + ) + + telemetry_messenger.send_telemetry(TestTelem()) + + assert mock_telemetry_messenger.send_telemetry.called + assert not mock_tunnel.set_wait_for_exploited_machines.called + + +def test_expliot_telemetry(): + mock_telemetry_messenger = MagicMock() + mock_tunnel = MagicMock() + mock_expliot_telem = MockExpliotTelem() + + telemetry_messenger = ExploitInterceptingTelemetryMessenger( + mock_telemetry_messenger, mock_tunnel + ) + + telemetry_messenger.send_telemetry(mock_expliot_telem) + + assert mock_telemetry_messenger.send_telemetry.called + assert mock_tunnel.set_wait_for_exploited_machines.called