diff --git a/monkey/infection_monkey/control.py b/monkey/infection_monkey/control.py index 77b52fe3f..338a8183e 100644 --- a/monkey/infection_monkey/control.py +++ b/monkey/infection_monkey/control.py @@ -3,6 +3,7 @@ import logging import platform from pprint import pformat from socket import gethostname +from typing import Mapping, Optional import requests from requests.exceptions import ConnectionError @@ -23,11 +24,12 @@ logger = logging.getLogger(__name__) PBA_FILE_DOWNLOAD = "https://%s/api/pba/download/%s" -class ControlClient(object): - proxies = {} +class ControlClient: + def __init__(self, server_address: str, proxies: Optional[Mapping[str, str]] = None): + self.proxies = {} if not proxies else proxies + self.server_address = server_address - @staticmethod - def wakeup(parent=None): + def wakeup(self, parent=None): if parent: logger.debug("parent: %s" % (parent,)) @@ -45,20 +47,19 @@ class ControlClient(object): "launch_time": agent_process.get_start_time(), } - if ControlClient.proxies: - monkey["tunnel"] = ControlClient.proxies.get("https") + if self.proxies: + monkey["tunnel"] = self.proxies.get("https") requests.post( # noqa: DUO123 f"https://{WormConfiguration.current_server}/api/agent", data=json.dumps(monkey), headers={"content-type": "application/json"}, verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=MEDIUM_REQUEST_TIMEOUT, ) - @staticmethod - def find_server(default_tunnel=None): + def find_server(self, default_tunnel=None): logger.debug( "Trying to wake up with Monkey Island servers list: %r" % WormConfiguration.command_servers @@ -73,13 +74,13 @@ class ControlClient(object): current_server = server debug_message = "Trying to connect to server: %s" % server - if ControlClient.proxies: - debug_message += " through proxies: %s" % ControlClient.proxies + if self.proxies: + debug_message += " through proxies: %s" % self.proxies logger.debug(debug_message) requests.get( # noqa: DUO123 f"https://{server}/api?action=is-up", verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=MEDIUM_REQUEST_TIMEOUT, ) WormConfiguration.current_server = current_server @@ -92,20 +93,19 @@ class ControlClient(object): if current_server: return True else: - if ControlClient.proxies: + if self.proxies: return False else: logger.info("Starting tunnel lookup...") proxy_find = tunnel.find_tunnel(default=default_tunnel) if proxy_find: - ControlClient.set_proxies(proxy_find) - return ControlClient.find_server() + self.set_proxies(proxy_find) + return self.find_server() else: logger.info("No tunnel found") return False - @staticmethod - def set_proxies(proxy_find): + def set_proxies(self, proxy_find): """ Note: The proxy schema changes between different versions of requests and urllib3, which causes the machine to not open a tunnel back. @@ -120,12 +120,11 @@ class ControlClient(object): proxy_address, proxy_port = proxy_find logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port)) if is_windows_os(): - ControlClient.proxies["https"] = f"http://{proxy_address}:{proxy_port}" + self.proxies["https"] = f"http://{proxy_address}:{proxy_port}" else: - ControlClient.proxies["https"] = f"{proxy_address}:{proxy_port}" + self.proxies["https"] = f"{proxy_address}:{proxy_port}" - @staticmethod - def send_telemetry(telem_category, json_data: str): + def send_telemetry(self, telem_category, json_data: str): if not WormConfiguration.current_server: logger.error( "Trying to send %s telemetry before current server is established, aborting." @@ -139,7 +138,7 @@ class ControlClient(object): data=json.dumps(telemetry), headers={"content-type": "application/json"}, verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=MEDIUM_REQUEST_TIMEOUT, ) except Exception as exc: @@ -147,8 +146,7 @@ class ControlClient(object): "Error connecting to control server %s: %s", WormConfiguration.current_server, exc ) - @staticmethod - def send_log(log): + def send_log(self, log): if not WormConfiguration.current_server: return try: @@ -158,7 +156,7 @@ class ControlClient(object): data=json.dumps(telemetry), headers={"content-type": "application/json"}, verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=MEDIUM_REQUEST_TIMEOUT, ) except Exception as exc: @@ -166,15 +164,14 @@ class ControlClient(object): "Error connecting to control server %s: %s", WormConfiguration.current_server, exc ) - @staticmethod - def load_control_config(): + def load_control_config(self): if not WormConfiguration.current_server: return try: reply = requests.get( # noqa: DUO123 f"https://{WormConfiguration.current_server}/api/agent/", verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=MEDIUM_REQUEST_TIMEOUT, ) @@ -200,12 +197,11 @@ class ControlClient(object): ) raise Exception("Couldn't load from from server's configuration, aborting. %s" % exc) - @staticmethod - def create_control_tunnel(): + def create_control_tunnel(self): if not WormConfiguration.current_server: return None - my_proxy = ControlClient.proxies.get("https", "").replace("https://", "") + my_proxy = self.proxies.get("https", "").replace("https://", "") if my_proxy: proxy_class = TcpProxy try: @@ -224,13 +220,12 @@ class ControlClient(object): target_port=target_port, ) - @staticmethod - def get_pba_file(filename): + def get_pba_file(self, filename): try: return requests.get( # noqa: DUO123 PBA_FILE_DOWNLOAD % (WormConfiguration.current_server, filename), verify=False, - proxies=ControlClient.proxies, + proxies=self.proxies, timeout=LONG_REQUEST_TIMEOUT, ) except requests.exceptions.RequestException: diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 4d2807194..13930ecda 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -89,7 +89,7 @@ class InfectionMonkey: self._singleton = SystemSingleton() self._opts = self._get_arguments(args) self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.server) - self._default_server = self._opts.server + self.cc_client = ControlClient(self._opts.server) self._monkey_inbound_tunnel = None self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._current_depth = self._opts.depth @@ -120,7 +120,6 @@ class InfectionMonkey: logger.info("Agent is starting...") logger.info(f"Agent GUID: {GUID}") - self._add_default_server_to_config(self._opts.server) self._connect_to_island() # TODO: Reevaluate who is responsible to send this information @@ -137,27 +136,20 @@ class InfectionMonkey: self._setup() self._master.start() - @staticmethod - def _add_default_server_to_config(default_server: str): - if default_server: - logger.debug("Added default server: %s" % default_server) - WormConfiguration.command_servers.insert(0, default_server) - def _connect_to_island(self): # Sets island's IP and port for monkey to communicate to if self._current_server_is_set(): - self._default_server = WormConfiguration.current_server - logger.debug("Default server set to: %s" % self._default_server) + logger.debug("Default server set to: %s" % self.cc_client.server_address) else: raise Exception( "Monkey couldn't find server with {} default tunnel.".format(self._opts.tunnel) ) - ControlClient.wakeup(parent=self._opts.parent) - ControlClient.load_control_config() + self.cc_client.wakeup(parent=self._opts.parent) + self.cc_client.load_control_config() def _current_server_is_set(self) -> bool: - if ControlClient.find_server(default_tunnel=self._opts.tunnel): + if self.cc_client.find_server(default_tunnel=self._opts.tunnel): return True return False @@ -170,7 +162,7 @@ class InfectionMonkey: if firewall.is_enabled(): firewall.add_firewall_rule() - self._monkey_inbound_tunnel = ControlClient.create_control_tunnel() + self._monkey_inbound_tunnel = self.cc_client.create_control_tunnel() if self._monkey_inbound_tunnel and self._propagation_enabled(): self._monkey_inbound_tunnel.start() @@ -184,7 +176,9 @@ class InfectionMonkey: def _build_master(self): local_network_interfaces = InfectionMonkey._get_local_network_interfaces() - control_channel = ControlChannel(self._default_server, GUID) + # TODO control_channel and control_client have same responsibilities, merge them + control_channel = ControlChannel(self.cc_client.server_address, GUID) + control_client = self.cc_client credentials_store = AggregatingCredentialsStore(control_channel) puppet = self._build_puppet(credentials_store) @@ -206,6 +200,7 @@ class InfectionMonkey: control_channel, local_network_interfaces, credentials_store, + control_client, ) @staticmethod @@ -237,7 +232,7 @@ class InfectionMonkey: puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER) agent_repository = CachingAgentRepository( - f"https://{self._default_server}", ControlClient.proxies + f"https://{self.cc_client.server_address}", self.cc_client.proxies ) exploit_wrapper = ExploiterWrapper(self._telemetry_messenger, agent_repository) @@ -320,7 +315,9 @@ class InfectionMonkey: PluginType.POST_BREACH_ACTION, ) puppet.load_plugin( - "CustomPBA", CustomPBA(self._telemetry_messenger), PluginType.POST_BREACH_ACTION + "CustomPBA", + CustomPBA(self._telemetry_messenger, self.cc_client), + PluginType.POST_BREACH_ACTION, ) puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD) @@ -338,7 +335,7 @@ class InfectionMonkey: ) def _running_on_island(self, local_network_interfaces: List[NetworkInterface]) -> bool: - server_ip, _ = address_to_ip_port(self._default_server) + server_ip, _ = address_to_ip_port(self.cc_client.server_address) return server_ip in {interface.address for interface in local_network_interfaces} def _is_another_monkey_running(self): @@ -363,13 +360,13 @@ class InfectionMonkey: deleted = InfectionMonkey._self_delete() - InfectionMonkey._send_log() + self._send_log() StateTelem( is_done=True, version=get_version() ).send() # Signal the server (before closing the tunnel) - InfectionMonkey._close_tunnel() + self._close_tunnel() self._singleton.unlock() except Exception as e: logger.error(f"An error occurred while cleaning up the monkey agent: {e}") @@ -384,15 +381,15 @@ class InfectionMonkey: # maximum depth from the server return self._current_depth is None or self._current_depth > 0 - @staticmethod - def _close_tunnel(): - tunnel_address = ControlClient.proxies.get("https", "").replace("http://", "").split(":")[0] + def _close_tunnel(self): + tunnel_address = ( + self.cc_client.proxies.get("https", "").replace("http://", "").split(":")[0] + ) if tunnel_address: logger.info("Quitting tunnel %s", tunnel_address) tunnel.quit_tunnel(tunnel_address) - @staticmethod - def _send_log(): + def _send_log(self): monkey_log_path = get_agent_log_path() if monkey_log_path.is_file(): with open(monkey_log_path, "r") as f: @@ -400,7 +397,7 @@ class InfectionMonkey: else: log = "" - ControlClient.send_log(log) + self.cc_client.send_log(log) @staticmethod def _self_delete() -> bool: