Merge pull request #2016 from guardicore/1996-agent-worm-config-decouple

1996 agent worm config decouple
This commit is contained in:
Mike Salvatore 2022-06-14 20:06:25 -04:00 committed by GitHub
commit ad1928db98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 127 additions and 156 deletions

View File

@ -8,7 +8,7 @@ SENSITIVE_FIELDS = [
"exploit_user_list",
"exploit_ssh_keys",
]
LOCAL_CONFIG_VARS = ["name", "id", "current_server", "max_depth"]
LOCAL_CONFIG_VARS = ["name", "id", "max_depth"]
HIDDEN_FIELD_REPLACEMENT_CONTENT = "hidden"
@ -62,10 +62,6 @@ class Configuration(object):
# depth of propagation
depth = 2
max_depth = None
current_server = ""
# Configuration servers to try to connect to, in this order.
command_servers = []
keep_tunnel_open_time = 30

View File

@ -3,6 +3,7 @@ import logging
import platform
from pprint import pformat
from socket import gethostname
from typing import Mapping, Optional
import requests
from requests.exceptions import ConnectionError
@ -23,11 +24,17 @@ logger = logging.getLogger(__name__)
PBA_FILE_DOWNLOAD = "https://%s/api/pba/download/%s"
class ControlClient(object):
proxies = {}
class ControlClient:
# TODO When we have mechanism that support telemetry messenger
# with control clients, then this needs to be removed
# https://github.com/guardicore/monkey/blob/133f7f5da131b481561141171827d1f9943f6aec/monkey/infection_monkey/telemetry/base_telem.py
control_client_object = None
@staticmethod
def wakeup(parent=None):
def __init__(self, server_address: str, proxies: Optional[Mapping[str, str]] = None):
self.proxies = {} if not proxies else proxies
self.server_address = server_address
def wakeup(self, parent=None):
if parent:
logger.debug("parent: %s" % (parent,))
@ -45,67 +52,51 @@ class ControlClient(object):
"launch_time": agent_process.get_start_time(),
}
if ControlClient.proxies:
monkey["tunnel"] = ControlClient.proxies.get("https")
if self.proxies:
monkey["tunnel"] = self.proxies.get("https")
requests.post( # noqa: DUO123
f"https://{WormConfiguration.current_server}/api/agent",
f"https://{self.server_address}/api/agent",
data=json.dumps(monkey),
headers={"content-type": "application/json"},
verify=False,
proxies=ControlClient.proxies,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
@staticmethod
def find_server(default_tunnel=None):
logger.debug(
"Trying to wake up with Monkey Island servers list: %r"
% WormConfiguration.command_servers
)
def find_server(self, default_tunnel=None):
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
if default_tunnel:
logger.debug("default_tunnel: %s" % (default_tunnel,))
current_server = ""
for server in WormConfiguration.command_servers:
try:
current_server = server
debug_message = "Trying to connect to server: %s" % server
if ControlClient.proxies:
debug_message += " through proxies: %s" % ControlClient.proxies
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
proxies=ControlClient.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
WormConfiguration.current_server = current_server
break
except ConnectionError as exc:
current_server = ""
logger.warning("Error connecting to control server %s: %s", server, exc)
if current_server:
try:
debug_message = "Trying to connect to server: %s" % self.server_address
if self.proxies:
debug_message += " through proxies: %s" % self.proxies
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{self.server_address}/api?action=is-up",
verify=False,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
else:
if ControlClient.proxies:
return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
ControlClient.set_proxies(proxy_find)
return ControlClient.find_server()
else:
logger.info("No tunnel found")
return False
except ConnectionError as exc:
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
@staticmethod
def set_proxies(proxy_find):
if self.proxies:
return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
self.set_proxies(proxy_find)
return self.find_server()
else:
logger.info("No tunnel found")
return False
def set_proxies(self, proxy_find):
"""
Note: The proxy schema changes between different versions of requests and urllib3,
which causes the machine to not open a tunnel back.
@ -120,13 +111,12 @@ class ControlClient(object):
proxy_address, proxy_port = proxy_find
logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port))
if is_windows_os():
ControlClient.proxies["https"] = f"http://{proxy_address}:{proxy_port}"
self.proxies["https"] = f"http://{proxy_address}:{proxy_port}"
else:
ControlClient.proxies["https"] = f"{proxy_address}:{proxy_port}"
self.proxies["https"] = f"{proxy_address}:{proxy_port}"
@staticmethod
def send_telemetry(telem_category, json_data: str):
if not WormConfiguration.current_server:
def send_telemetry(self, telem_category, json_data: str):
if not self.server_address:
logger.error(
"Trying to send %s telemetry before current server is established, aborting."
% telem_category
@ -135,53 +125,45 @@ class ControlClient(object):
try:
telemetry = {"monkey_guid": GUID, "telem_category": telem_category, "data": json_data}
requests.post( # noqa: DUO123
"https://%s/api/telemetry" % (WormConfiguration.current_server,),
"https://%s/api/telemetry" % (self.server_address,),
data=json.dumps(telemetry),
headers={"content-type": "application/json"},
verify=False,
proxies=ControlClient.proxies,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
except Exception as exc:
logger.warning(
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
@staticmethod
def send_log(log):
if not WormConfiguration.current_server:
def send_log(self, log):
if not self.server_address:
return
try:
telemetry = {"monkey_guid": GUID, "log": json.dumps(log)}
requests.post( # noqa: DUO123
"https://%s/api/log" % (WormConfiguration.current_server,),
"https://%s/api/log" % (self.server_address,),
data=json.dumps(telemetry),
headers={"content-type": "application/json"},
verify=False,
proxies=ControlClient.proxies,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
except Exception as exc:
logger.warning(
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
@staticmethod
def load_control_config():
if not WormConfiguration.current_server:
def load_control_config(self):
if not self.server_address:
return
try:
reply = requests.get( # noqa: DUO123
f"https://{WormConfiguration.current_server}/api/agent/",
f"https://{self.server_address}/api/agent/",
verify=False,
proxies=ControlClient.proxies,
proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
except Exception as exc:
logger.warning(
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
return
try:
@ -194,18 +176,17 @@ class ControlClient(object):
# we don't continue with default conf here because it might be dangerous
logger.error(
"Error parsing JSON reply from control server %s (%s): %s",
WormConfiguration.current_server,
self.server_address,
reply._content,
exc,
)
raise Exception("Couldn't load from from server's configuration, aborting. %s" % exc)
@staticmethod
def create_control_tunnel():
if not WormConfiguration.current_server:
def create_control_tunnel(self):
if not self.server_address:
return None
my_proxy = ControlClient.proxies.get("https", "").replace("https://", "")
my_proxy = self.proxies.get("https", "").replace("https://", "")
if my_proxy:
proxy_class = TcpProxy
try:
@ -224,13 +205,12 @@ class ControlClient(object):
target_port=target_port,
)
@staticmethod
def get_pba_file(filename):
def get_pba_file(self, filename):
try:
return requests.get( # noqa: DUO123
PBA_FILE_DOWNLOAD % (WormConfiguration.current_server, filename),
PBA_FILE_DOWNLOAD % (self.server_address, filename),
verify=False,
proxies=ControlClient.proxies,
proxies=self.proxies,
timeout=LONG_REQUEST_TIMEOUT,
)
except requests.exceptions.RequestException:

View File

@ -1,10 +1,10 @@
import json
import logging
from typing import Mapping
import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from infection_monkey.control import ControlClient
from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -14,9 +14,10 @@ logger = logging.getLogger(__name__)
class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str):
def __init__(self, server: str, agent_id: str, proxies: Mapping[str, str]):
self._agent_id = agent_id
self._control_channel_server = server
self._proxies = proxies
def should_agent_stop(self) -> bool:
if not self._control_channel_server:
@ -30,7 +31,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123
url,
verify=False,
proxies=ControlClient.proxies,
proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
@ -51,7 +52,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/agent",
verify=False,
proxies=ControlClient.proxies,
proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
@ -74,7 +75,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123
propagation_credentials_url,
verify=False,
proxies=ControlClient.proxies,
proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()

View File

@ -10,7 +10,7 @@ import infection_monkey.tunnel as tunnel
from common.network.network_utils import address_to_ip_port
from common.utils.attack_utils import ScanStatus, UsageEnum
from common.version import get_version
from infection_monkey.config import GUID, WormConfiguration
from infection_monkey.config import GUID
from infection_monkey.control import ControlClient
from infection_monkey.credential_collectors import (
MimikatzCredentialCollector,
@ -89,7 +89,10 @@ class InfectionMonkey:
self._singleton = SystemSingleton()
self._opts = self._get_arguments(args)
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.server)
self._default_server = self._opts.server
self._control_client = ControlClient(self._opts.server)
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
self._monkey_inbound_tunnel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
@ -120,7 +123,6 @@ class InfectionMonkey:
logger.info("Agent is starting...")
logger.info(f"Agent GUID: {GUID}")
self._add_default_server_to_config(self._opts.server)
self._connect_to_island()
# TODO: Reevaluate who is responsible to send this information
@ -129,7 +131,9 @@ class InfectionMonkey:
run_aws_environment_check(self._telemetry_messenger)
should_stop = ControlChannel(WormConfiguration.current_server, GUID).should_agent_stop()
should_stop = ControlChannel(
self._control_client.server_address, GUID, self._control_client.proxies
).should_agent_stop()
if should_stop:
logger.info("The Monkey Island has instructed this agent to stop")
return
@ -137,27 +141,18 @@ class InfectionMonkey:
self._setup()
self._master.start()
@staticmethod
def _add_default_server_to_config(default_server: str):
if default_server:
logger.debug("Added default server: %s" % default_server)
WormConfiguration.command_servers.insert(0, default_server)
def _connect_to_island(self):
# Sets island's IP and port for monkey to communicate to
if self._current_server_is_set():
self._default_server = WormConfiguration.current_server
logger.debug("Default server set to: %s" % self._default_server)
logger.debug(f"Default server set to: {self._control_client.server_address}")
else:
raise Exception(
"Monkey couldn't find server with {} default tunnel.".format(self._opts.tunnel)
)
raise Exception(f"Monkey couldn't find server with {self._opts.tunnel} default tunnel.")
ControlClient.wakeup(parent=self._opts.parent)
ControlClient.load_control_config()
self._control_client.wakeup(parent=self._opts.parent)
self._control_client.load_control_config()
def _current_server_is_set(self) -> bool:
if ControlClient.find_server(default_tunnel=self._opts.tunnel):
if self._control_client.find_server(default_tunnel=self._opts.tunnel):
return True
return False
@ -170,12 +165,12 @@ class InfectionMonkey:
if firewall.is_enabled():
firewall.add_firewall_rule()
self._monkey_inbound_tunnel = ControlClient.create_control_tunnel()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel()
if self._monkey_inbound_tunnel and self._propagation_enabled():
self._monkey_inbound_tunnel.start()
StateTelem(is_done=False, version=get_version()).send()
TunnelTelem().send()
TunnelTelem(self._control_client.proxies).send()
self._build_master()
@ -184,7 +179,10 @@ class InfectionMonkey:
def _build_master(self):
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
control_channel = ControlChannel(self._default_server, GUID)
# TODO control_channel and control_client have same responsibilities, merge them
control_channel = ControlChannel(
self._control_client.server_address, GUID, self._control_client.proxies
)
credentials_store = AggregatingCredentialsStore(control_channel)
puppet = self._build_puppet(credentials_store)
@ -237,7 +235,7 @@ class InfectionMonkey:
puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER)
agent_repository = CachingAgentRepository(
f"https://{self._default_server}", ControlClient.proxies
f"https://{self._control_client.server_address}", self._control_client.proxies
)
exploit_wrapper = ExploiterWrapper(self._telemetry_messenger, agent_repository)
@ -320,7 +318,9 @@ class InfectionMonkey:
PluginType.POST_BREACH_ACTION,
)
puppet.load_plugin(
"CustomPBA", CustomPBA(self._telemetry_messenger), PluginType.POST_BREACH_ACTION
"CustomPBA",
CustomPBA(self._telemetry_messenger, self._control_client),
PluginType.POST_BREACH_ACTION,
)
puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD)
@ -338,7 +338,7 @@ class InfectionMonkey:
)
def _running_on_island(self, local_network_interfaces: List[NetworkInterface]) -> bool:
server_ip, _ = address_to_ip_port(self._default_server)
server_ip, _ = address_to_ip_port(self._control_client.server_address)
return server_ip in {interface.address for interface in local_network_interfaces}
def _is_another_monkey_running(self):
@ -363,13 +363,13 @@ class InfectionMonkey:
deleted = InfectionMonkey._self_delete()
InfectionMonkey._send_log()
self._send_log()
StateTelem(
is_done=True, version=get_version()
).send() # Signal the server (before closing the tunnel)
InfectionMonkey._close_tunnel()
self._close_tunnel()
self._singleton.unlock()
except Exception as e:
logger.error(f"An error occurred while cleaning up the monkey agent: {e}")
@ -384,15 +384,15 @@ class InfectionMonkey:
# maximum depth from the server
return self._current_depth is None or self._current_depth > 0
@staticmethod
def _close_tunnel():
tunnel_address = ControlClient.proxies.get("https", "").replace("http://", "").split(":")[0]
def _close_tunnel(self):
tunnel_address = (
self._control_client.proxies.get("https", "").replace("http://", "").split(":")[0]
)
if tunnel_address:
logger.info("Quitting tunnel %s", tunnel_address)
tunnel.quit_tunnel(tunnel_address)
@staticmethod
def _send_log():
def _send_log(self):
monkey_log_path = get_agent_log_path()
if monkey_log_path.is_file():
with open(monkey_log_path, "r") as f:
@ -400,7 +400,7 @@ class InfectionMonkey:
else:
log = ""
ControlClient.send_log(log)
self._control_client.send_log(log)
@staticmethod
def _self_delete() -> bool:

View File

@ -25,20 +25,18 @@ class CustomPBA(PBA):
Defines user's configured post breach action.
"""
def __init__(self, telemetry_messenger: ITelemetryMessenger):
def __init__(self, telemetry_messenger: ITelemetryMessenger, control_client: ControlClient):
super(CustomPBA, self).__init__(
telemetry_messenger, POST_BREACH_FILE_EXECUTION, timeout=None
)
self.filename = ""
self.control_client = control_client
def run(self, options: Dict) -> Iterable[PostBreachData]:
self._set_options(options)
return super().run(options)
def _set_options(self, options: Dict):
# Required for attack telemetry
self.current_server = options["current_server"]
if is_windows_os():
# Add windows commands to PBA's
if options["windows_filename"]:
@ -75,7 +73,7 @@ class CustomPBA(PBA):
:return: True if successful, false otherwise
"""
pba_file_contents = ControlClient.get_pba_file(filename)
pba_file_contents = self.control_client.get_pba_file(filename)
status = None
if not pba_file_contents or not pba_file_contents.content:
@ -88,8 +86,8 @@ class CustomPBA(PBA):
self.telemetry_messenger.send_telemetry(
T1105Telem(
status,
self.current_server.split(":")[0],
get_interface_to_target(self.current_server.split(":")[0]),
self.control_client.server_address.split(":")[0],
get_interface_to_target(self.control_client.server_address.split(":")[0]),
filename,
)
)

View File

@ -35,7 +35,7 @@ class BaseTelem(ITelem, metaclass=abc.ABCMeta):
data = self.get_data()
serialized_data = json.dumps(data, cls=self.json_encoder)
self._log_telem_sending(serialized_data, log_data)
ControlClient.send_telemetry(self.telem_category, serialized_data)
ControlClient.control_client_object.send_telemetry(self.telem_category, serialized_data)
@property
def json_encoder(self):

View File

@ -1,15 +1,16 @@
from typing import Mapping
from common.common_consts.telem_categories import TelemCategoryEnum
from infection_monkey.control import ControlClient
from infection_monkey.telemetry.base_telem import BaseTelem
class TunnelTelem(BaseTelem):
def __init__(self):
def __init__(self, proxy: Mapping[str, str]):
"""
Default tunnel telemetry constructor
"""
super(TunnelTelem, self).__init__()
self.proxy = ControlClient.proxies.get("https")
self.proxy = proxy.get("https")
telem_category = TelemCategoryEnum.TUNNEL

View File

@ -43,13 +43,11 @@ def fake_custom_pba_linux_options():
"linux_filename": CUSTOM_LINUX_FILENAME,
"windows_command": "",
"windows_filename": "",
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
}
def test_command_linux_custom_file_and_cmd(fake_custom_pba_linux_options, set_os_linux):
pba = CustomPBA(MagicMock())
pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_custom_pba_linux_options)
expected_command = f"cd {MONKEY_DIR_PATH} ; {CUSTOM_LINUX_CMD}"
assert pba.command == expected_command
@ -63,14 +61,12 @@ def fake_custom_pba_windows_options():
"linux_filename": "",
"windows_command": CUSTOM_WINDOWS_CMD,
"windows_filename": CUSTOM_WINDOWS_FILENAME,
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
}
def test_command_windows_custom_file_and_cmd(fake_custom_pba_windows_options, set_os_windows):
pba = CustomPBA(MagicMock())
pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_custom_pba_windows_options)
expected_command = f"cd {MONKEY_DIR_PATH} & {CUSTOM_WINDOWS_CMD}"
assert pba.command == expected_command
@ -84,14 +80,12 @@ def fake_options_files_only():
"linux_filename": CUSTOM_LINUX_FILENAME,
"windows_command": "",
"windows_filename": CUSTOM_WINDOWS_FILENAME,
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
}
@pytest.mark.parametrize("os", [set_os_linux, set_os_windows])
def test_files_only(fake_options_files_only, os):
pba = CustomPBA(MagicMock())
pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_files_only)
assert pba.command == ""
@ -103,20 +97,18 @@ def fake_options_commands_only():
"linux_filename": "",
"windows_command": CUSTOM_WINDOWS_CMD,
"windows_filename": "",
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
}
def test_commands_only(fake_options_commands_only, set_os_linux):
pba = CustomPBA(MagicMock())
pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_commands_only)
assert pba.command == CUSTOM_LINUX_CMD
assert pba.filename == ""
def test_commands_only_windows(fake_options_commands_only, set_os_windows):
pba = CustomPBA(MagicMock())
pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_commands_only)
assert pba.command == CUSTOM_WINDOWS_CMD
assert pba.filename == ""

View File

@ -1,3 +1,5 @@
from unittest.mock import MagicMock
import pytest
from infection_monkey.control import ControlClient
@ -11,5 +13,6 @@ def spy_send_telemetry(monkeypatch):
_spy_send_telemetry.telem_category = None
_spy_send_telemetry.data = None
monkeypatch.setattr(ControlClient, "send_telemetry", _spy_send_telemetry)
ControlClient.control_client_object = MagicMock()
ControlClient.control_client_object.send_telemetry = MagicMock(side_effect=_spy_send_telemetry)
return _spy_send_telemetry

View File

@ -7,7 +7,7 @@ from infection_monkey.telemetry.tunnel_telem import TunnelTelem
@pytest.fixture
def tunnel_telem_test_instance():
return TunnelTelem()
return TunnelTelem({})
def test_tunnel_telem_send(tunnel_telem_test_instance, spy_send_telemetry):

View File

@ -9,7 +9,7 @@ from monkey.infection_monkey.control import ControlClient
)
def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
monkeypatch.setattr("monkey.infection_monkey.control.is_windows_os", lambda: is_windows_os)
control_client = ControlClient()
control_client = ControlClient("8.8.8.8:5000")
control_client.set_proxies(("8.8.8.8", "45455"))