Merge pull request #2016 from guardicore/1996-agent-worm-config-decouple

1996 agent worm config decouple
This commit is contained in:
Mike Salvatore 2022-06-14 20:06:25 -04:00 committed by GitHub
commit ad1928db98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 127 additions and 156 deletions

View File

@ -8,7 +8,7 @@ SENSITIVE_FIELDS = [
"exploit_user_list", "exploit_user_list",
"exploit_ssh_keys", "exploit_ssh_keys",
] ]
LOCAL_CONFIG_VARS = ["name", "id", "current_server", "max_depth"] LOCAL_CONFIG_VARS = ["name", "id", "max_depth"]
HIDDEN_FIELD_REPLACEMENT_CONTENT = "hidden" HIDDEN_FIELD_REPLACEMENT_CONTENT = "hidden"
@ -62,10 +62,6 @@ class Configuration(object):
# depth of propagation # depth of propagation
depth = 2 depth = 2
max_depth = None max_depth = None
current_server = ""
# Configuration servers to try to connect to, in this order.
command_servers = []
keep_tunnel_open_time = 30 keep_tunnel_open_time = 30

View File

@ -3,6 +3,7 @@ import logging
import platform import platform
from pprint import pformat from pprint import pformat
from socket import gethostname from socket import gethostname
from typing import Mapping, Optional
import requests import requests
from requests.exceptions import ConnectionError from requests.exceptions import ConnectionError
@ -23,11 +24,17 @@ logger = logging.getLogger(__name__)
PBA_FILE_DOWNLOAD = "https://%s/api/pba/download/%s" PBA_FILE_DOWNLOAD = "https://%s/api/pba/download/%s"
class ControlClient(object): class ControlClient:
proxies = {} # TODO When we have mechanism that support telemetry messenger
# with control clients, then this needs to be removed
# https://github.com/guardicore/monkey/blob/133f7f5da131b481561141171827d1f9943f6aec/monkey/infection_monkey/telemetry/base_telem.py
control_client_object = None
@staticmethod def __init__(self, server_address: str, proxies: Optional[Mapping[str, str]] = None):
def wakeup(parent=None): self.proxies = {} if not proxies else proxies
self.server_address = server_address
def wakeup(self, parent=None):
if parent: if parent:
logger.debug("parent: %s" % (parent,)) logger.debug("parent: %s" % (parent,))
@ -45,67 +52,51 @@ class ControlClient(object):
"launch_time": agent_process.get_start_time(), "launch_time": agent_process.get_start_time(),
} }
if ControlClient.proxies: if self.proxies:
monkey["tunnel"] = ControlClient.proxies.get("https") monkey["tunnel"] = self.proxies.get("https")
requests.post( # noqa: DUO123 requests.post( # noqa: DUO123
f"https://{WormConfiguration.current_server}/api/agent", f"https://{self.server_address}/api/agent",
data=json.dumps(monkey), data=json.dumps(monkey),
headers={"content-type": "application/json"}, headers={"content-type": "application/json"},
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
@staticmethod def find_server(self, default_tunnel=None):
def find_server(default_tunnel=None): logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
logger.debug(
"Trying to wake up with Monkey Island servers list: %r"
% WormConfiguration.command_servers
)
if default_tunnel: if default_tunnel:
logger.debug("default_tunnel: %s" % (default_tunnel,)) logger.debug("default_tunnel: %s" % (default_tunnel,))
current_server = "" try:
debug_message = "Trying to connect to server: %s" % self.server_address
for server in WormConfiguration.command_servers: if self.proxies:
try: debug_message += " through proxies: %s" % self.proxies
current_server = server logger.debug(debug_message)
requests.get( # noqa: DUO123
debug_message = "Trying to connect to server: %s" % server f"https://{self.server_address}/api?action=is-up",
if ControlClient.proxies: verify=False,
debug_message += " through proxies: %s" % ControlClient.proxies proxies=self.proxies,
logger.debug(debug_message) timeout=MEDIUM_REQUEST_TIMEOUT,
requests.get( # noqa: DUO123 )
f"https://{server}/api?action=is-up",
verify=False,
proxies=ControlClient.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
WormConfiguration.current_server = current_server
break
except ConnectionError as exc:
current_server = ""
logger.warning("Error connecting to control server %s: %s", server, exc)
if current_server:
return True return True
else: except ConnectionError as exc:
if ControlClient.proxies: logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
ControlClient.set_proxies(proxy_find)
return ControlClient.find_server()
else:
logger.info("No tunnel found")
return False
@staticmethod if self.proxies:
def set_proxies(proxy_find): return False
else:
logger.info("Starting tunnel lookup...")
proxy_find = tunnel.find_tunnel(default=default_tunnel)
if proxy_find:
self.set_proxies(proxy_find)
return self.find_server()
else:
logger.info("No tunnel found")
return False
def set_proxies(self, proxy_find):
""" """
Note: The proxy schema changes between different versions of requests and urllib3, Note: The proxy schema changes between different versions of requests and urllib3,
which causes the machine to not open a tunnel back. which causes the machine to not open a tunnel back.
@ -120,13 +111,12 @@ class ControlClient(object):
proxy_address, proxy_port = proxy_find proxy_address, proxy_port = proxy_find
logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port)) logger.info("Found tunnel at %s:%s" % (proxy_address, proxy_port))
if is_windows_os(): if is_windows_os():
ControlClient.proxies["https"] = f"http://{proxy_address}:{proxy_port}" self.proxies["https"] = f"http://{proxy_address}:{proxy_port}"
else: else:
ControlClient.proxies["https"] = f"{proxy_address}:{proxy_port}" self.proxies["https"] = f"{proxy_address}:{proxy_port}"
@staticmethod def send_telemetry(self, telem_category, json_data: str):
def send_telemetry(telem_category, json_data: str): if not self.server_address:
if not WormConfiguration.current_server:
logger.error( logger.error(
"Trying to send %s telemetry before current server is established, aborting." "Trying to send %s telemetry before current server is established, aborting."
% telem_category % telem_category
@ -135,53 +125,45 @@ class ControlClient(object):
try: try:
telemetry = {"monkey_guid": GUID, "telem_category": telem_category, "data": json_data} telemetry = {"monkey_guid": GUID, "telem_category": telem_category, "data": json_data}
requests.post( # noqa: DUO123 requests.post( # noqa: DUO123
"https://%s/api/telemetry" % (WormConfiguration.current_server,), "https://%s/api/telemetry" % (self.server_address,),
data=json.dumps(telemetry), data=json.dumps(telemetry),
headers={"content-type": "application/json"}, headers={"content-type": "application/json"},
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
@staticmethod def send_log(self, log):
def send_log(log): if not self.server_address:
if not WormConfiguration.current_server:
return return
try: try:
telemetry = {"monkey_guid": GUID, "log": json.dumps(log)} telemetry = {"monkey_guid": GUID, "log": json.dumps(log)}
requests.post( # noqa: DUO123 requests.post( # noqa: DUO123
"https://%s/api/log" % (WormConfiguration.current_server,), "https://%s/api/log" % (self.server_address,),
data=json.dumps(telemetry), data=json.dumps(telemetry),
headers={"content-type": "application/json"}, headers={"content-type": "application/json"},
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
@staticmethod def load_control_config(self):
def load_control_config(): if not self.server_address:
if not WormConfiguration.current_server:
return return
try: try:
reply = requests.get( # noqa: DUO123 reply = requests.get( # noqa: DUO123
f"https://{WormConfiguration.current_server}/api/agent/", f"https://{self.server_address}/api/agent/",
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self.proxies,
timeout=MEDIUM_REQUEST_TIMEOUT, timeout=MEDIUM_REQUEST_TIMEOUT,
) )
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(f"Error connecting to control server {self.server_address}: {exc}")
"Error connecting to control server %s: %s", WormConfiguration.current_server, exc
)
return return
try: try:
@ -194,18 +176,17 @@ class ControlClient(object):
# we don't continue with default conf here because it might be dangerous # we don't continue with default conf here because it might be dangerous
logger.error( logger.error(
"Error parsing JSON reply from control server %s (%s): %s", "Error parsing JSON reply from control server %s (%s): %s",
WormConfiguration.current_server, self.server_address,
reply._content, reply._content,
exc, exc,
) )
raise Exception("Couldn't load from from server's configuration, aborting. %s" % exc) raise Exception("Couldn't load from from server's configuration, aborting. %s" % exc)
@staticmethod def create_control_tunnel(self):
def create_control_tunnel(): if not self.server_address:
if not WormConfiguration.current_server:
return None return None
my_proxy = ControlClient.proxies.get("https", "").replace("https://", "") my_proxy = self.proxies.get("https", "").replace("https://", "")
if my_proxy: if my_proxy:
proxy_class = TcpProxy proxy_class = TcpProxy
try: try:
@ -224,13 +205,12 @@ class ControlClient(object):
target_port=target_port, target_port=target_port,
) )
@staticmethod def get_pba_file(self, filename):
def get_pba_file(filename):
try: try:
return requests.get( # noqa: DUO123 return requests.get( # noqa: DUO123
PBA_FILE_DOWNLOAD % (WormConfiguration.current_server, filename), PBA_FILE_DOWNLOAD % (self.server_address, filename),
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self.proxies,
timeout=LONG_REQUEST_TIMEOUT, timeout=LONG_REQUEST_TIMEOUT,
) )
except requests.exceptions.RequestException: except requests.exceptions.RequestException:

View File

@ -1,10 +1,10 @@
import json import json
import logging import logging
from typing import Mapping
import requests import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from infection_monkey.control import ControlClient
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
@ -14,9 +14,10 @@ logger = logging.getLogger(__name__)
class ControlChannel(IControlChannel): class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str): def __init__(self, server: str, agent_id: str, proxies: Mapping[str, str]):
self._agent_id = agent_id self._agent_id = agent_id
self._control_channel_server = server self._control_channel_server = server
self._proxies = proxies
def should_agent_stop(self) -> bool: def should_agent_stop(self) -> bool:
if not self._control_channel_server: if not self._control_channel_server:
@ -30,7 +31,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
url, url,
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
@ -51,7 +52,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"https://{self._control_channel_server}/api/agent", f"https://{self._control_channel_server}/api/agent",
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
@ -74,7 +75,7 @@ class ControlChannel(IControlChannel):
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
propagation_credentials_url, propagation_credentials_url,
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=self._proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()

View File

@ -10,7 +10,7 @@ import infection_monkey.tunnel as tunnel
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from common.utils.attack_utils import ScanStatus, UsageEnum from common.utils.attack_utils import ScanStatus, UsageEnum
from common.version import get_version from common.version import get_version
from infection_monkey.config import GUID, WormConfiguration from infection_monkey.config import GUID
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from infection_monkey.credential_collectors import ( from infection_monkey.credential_collectors import (
MimikatzCredentialCollector, MimikatzCredentialCollector,
@ -89,7 +89,10 @@ class InfectionMonkey:
self._singleton = SystemSingleton() self._singleton = SystemSingleton()
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.server) self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.server)
self._default_server = self._opts.server self._control_client = ControlClient(self._opts.server)
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
self._monkey_inbound_tunnel = None self._monkey_inbound_tunnel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth self._current_depth = self._opts.depth
@ -120,7 +123,6 @@ class InfectionMonkey:
logger.info("Agent is starting...") logger.info("Agent is starting...")
logger.info(f"Agent GUID: {GUID}") logger.info(f"Agent GUID: {GUID}")
self._add_default_server_to_config(self._opts.server)
self._connect_to_island() self._connect_to_island()
# TODO: Reevaluate who is responsible to send this information # TODO: Reevaluate who is responsible to send this information
@ -129,7 +131,9 @@ class InfectionMonkey:
run_aws_environment_check(self._telemetry_messenger) run_aws_environment_check(self._telemetry_messenger)
should_stop = ControlChannel(WormConfiguration.current_server, GUID).should_agent_stop() should_stop = ControlChannel(
self._control_client.server_address, GUID, self._control_client.proxies
).should_agent_stop()
if should_stop: if should_stop:
logger.info("The Monkey Island has instructed this agent to stop") logger.info("The Monkey Island has instructed this agent to stop")
return return
@ -137,27 +141,18 @@ class InfectionMonkey:
self._setup() self._setup()
self._master.start() self._master.start()
@staticmethod
def _add_default_server_to_config(default_server: str):
if default_server:
logger.debug("Added default server: %s" % default_server)
WormConfiguration.command_servers.insert(0, default_server)
def _connect_to_island(self): def _connect_to_island(self):
# Sets island's IP and port for monkey to communicate to # Sets island's IP and port for monkey to communicate to
if self._current_server_is_set(): if self._current_server_is_set():
self._default_server = WormConfiguration.current_server logger.debug(f"Default server set to: {self._control_client.server_address}")
logger.debug("Default server set to: %s" % self._default_server)
else: else:
raise Exception( raise Exception(f"Monkey couldn't find server with {self._opts.tunnel} default tunnel.")
"Monkey couldn't find server with {} default tunnel.".format(self._opts.tunnel)
)
ControlClient.wakeup(parent=self._opts.parent) self._control_client.wakeup(parent=self._opts.parent)
ControlClient.load_control_config() self._control_client.load_control_config()
def _current_server_is_set(self) -> bool: def _current_server_is_set(self) -> bool:
if ControlClient.find_server(default_tunnel=self._opts.tunnel): if self._control_client.find_server(default_tunnel=self._opts.tunnel):
return True return True
return False return False
@ -170,12 +165,12 @@ class InfectionMonkey:
if firewall.is_enabled(): if firewall.is_enabled():
firewall.add_firewall_rule() firewall.add_firewall_rule()
self._monkey_inbound_tunnel = ControlClient.create_control_tunnel() self._monkey_inbound_tunnel = self._control_client.create_control_tunnel()
if self._monkey_inbound_tunnel and self._propagation_enabled(): if self._monkey_inbound_tunnel and self._propagation_enabled():
self._monkey_inbound_tunnel.start() self._monkey_inbound_tunnel.start()
StateTelem(is_done=False, version=get_version()).send() StateTelem(is_done=False, version=get_version()).send()
TunnelTelem().send() TunnelTelem(self._control_client.proxies).send()
self._build_master() self._build_master()
@ -184,7 +179,10 @@ class InfectionMonkey:
def _build_master(self): def _build_master(self):
local_network_interfaces = InfectionMonkey._get_local_network_interfaces() local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
control_channel = ControlChannel(self._default_server, GUID) # TODO control_channel and control_client have same responsibilities, merge them
control_channel = ControlChannel(
self._control_client.server_address, GUID, self._control_client.proxies
)
credentials_store = AggregatingCredentialsStore(control_channel) credentials_store = AggregatingCredentialsStore(control_channel)
puppet = self._build_puppet(credentials_store) puppet = self._build_puppet(credentials_store)
@ -237,7 +235,7 @@ class InfectionMonkey:
puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER) puppet.load_plugin("ssh", SSHFingerprinter(), PluginType.FINGERPRINTER)
agent_repository = CachingAgentRepository( agent_repository = CachingAgentRepository(
f"https://{self._default_server}", ControlClient.proxies f"https://{self._control_client.server_address}", self._control_client.proxies
) )
exploit_wrapper = ExploiterWrapper(self._telemetry_messenger, agent_repository) exploit_wrapper = ExploiterWrapper(self._telemetry_messenger, agent_repository)
@ -320,7 +318,9 @@ class InfectionMonkey:
PluginType.POST_BREACH_ACTION, PluginType.POST_BREACH_ACTION,
) )
puppet.load_plugin( puppet.load_plugin(
"CustomPBA", CustomPBA(self._telemetry_messenger), PluginType.POST_BREACH_ACTION "CustomPBA",
CustomPBA(self._telemetry_messenger, self._control_client),
PluginType.POST_BREACH_ACTION,
) )
puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD) puppet.load_plugin("ransomware", RansomwarePayload(), PluginType.PAYLOAD)
@ -338,7 +338,7 @@ class InfectionMonkey:
) )
def _running_on_island(self, local_network_interfaces: List[NetworkInterface]) -> bool: def _running_on_island(self, local_network_interfaces: List[NetworkInterface]) -> bool:
server_ip, _ = address_to_ip_port(self._default_server) server_ip, _ = address_to_ip_port(self._control_client.server_address)
return server_ip in {interface.address for interface in local_network_interfaces} return server_ip in {interface.address for interface in local_network_interfaces}
def _is_another_monkey_running(self): def _is_another_monkey_running(self):
@ -363,13 +363,13 @@ class InfectionMonkey:
deleted = InfectionMonkey._self_delete() deleted = InfectionMonkey._self_delete()
InfectionMonkey._send_log() self._send_log()
StateTelem( StateTelem(
is_done=True, version=get_version() is_done=True, version=get_version()
).send() # Signal the server (before closing the tunnel) ).send() # Signal the server (before closing the tunnel)
InfectionMonkey._close_tunnel() self._close_tunnel()
self._singleton.unlock() self._singleton.unlock()
except Exception as e: except Exception as e:
logger.error(f"An error occurred while cleaning up the monkey agent: {e}") logger.error(f"An error occurred while cleaning up the monkey agent: {e}")
@ -384,15 +384,15 @@ class InfectionMonkey:
# maximum depth from the server # maximum depth from the server
return self._current_depth is None or self._current_depth > 0 return self._current_depth is None or self._current_depth > 0
@staticmethod def _close_tunnel(self):
def _close_tunnel(): tunnel_address = (
tunnel_address = ControlClient.proxies.get("https", "").replace("http://", "").split(":")[0] self._control_client.proxies.get("https", "").replace("http://", "").split(":")[0]
)
if tunnel_address: if tunnel_address:
logger.info("Quitting tunnel %s", tunnel_address) logger.info("Quitting tunnel %s", tunnel_address)
tunnel.quit_tunnel(tunnel_address) tunnel.quit_tunnel(tunnel_address)
@staticmethod def _send_log(self):
def _send_log():
monkey_log_path = get_agent_log_path() monkey_log_path = get_agent_log_path()
if monkey_log_path.is_file(): if monkey_log_path.is_file():
with open(monkey_log_path, "r") as f: with open(monkey_log_path, "r") as f:
@ -400,7 +400,7 @@ class InfectionMonkey:
else: else:
log = "" log = ""
ControlClient.send_log(log) self._control_client.send_log(log)
@staticmethod @staticmethod
def _self_delete() -> bool: def _self_delete() -> bool:

View File

@ -25,20 +25,18 @@ class CustomPBA(PBA):
Defines user's configured post breach action. Defines user's configured post breach action.
""" """
def __init__(self, telemetry_messenger: ITelemetryMessenger): def __init__(self, telemetry_messenger: ITelemetryMessenger, control_client: ControlClient):
super(CustomPBA, self).__init__( super(CustomPBA, self).__init__(
telemetry_messenger, POST_BREACH_FILE_EXECUTION, timeout=None telemetry_messenger, POST_BREACH_FILE_EXECUTION, timeout=None
) )
self.filename = "" self.filename = ""
self.control_client = control_client
def run(self, options: Dict) -> Iterable[PostBreachData]: def run(self, options: Dict) -> Iterable[PostBreachData]:
self._set_options(options) self._set_options(options)
return super().run(options) return super().run(options)
def _set_options(self, options: Dict): def _set_options(self, options: Dict):
# Required for attack telemetry
self.current_server = options["current_server"]
if is_windows_os(): if is_windows_os():
# Add windows commands to PBA's # Add windows commands to PBA's
if options["windows_filename"]: if options["windows_filename"]:
@ -75,7 +73,7 @@ class CustomPBA(PBA):
:return: True if successful, false otherwise :return: True if successful, false otherwise
""" """
pba_file_contents = ControlClient.get_pba_file(filename) pba_file_contents = self.control_client.get_pba_file(filename)
status = None status = None
if not pba_file_contents or not pba_file_contents.content: if not pba_file_contents or not pba_file_contents.content:
@ -88,8 +86,8 @@ class CustomPBA(PBA):
self.telemetry_messenger.send_telemetry( self.telemetry_messenger.send_telemetry(
T1105Telem( T1105Telem(
status, status,
self.current_server.split(":")[0], self.control_client.server_address.split(":")[0],
get_interface_to_target(self.current_server.split(":")[0]), get_interface_to_target(self.control_client.server_address.split(":")[0]),
filename, filename,
) )
) )

View File

@ -35,7 +35,7 @@ class BaseTelem(ITelem, metaclass=abc.ABCMeta):
data = self.get_data() data = self.get_data()
serialized_data = json.dumps(data, cls=self.json_encoder) serialized_data = json.dumps(data, cls=self.json_encoder)
self._log_telem_sending(serialized_data, log_data) self._log_telem_sending(serialized_data, log_data)
ControlClient.send_telemetry(self.telem_category, serialized_data) ControlClient.control_client_object.send_telemetry(self.telem_category, serialized_data)
@property @property
def json_encoder(self): def json_encoder(self):

View File

@ -1,15 +1,16 @@
from typing import Mapping
from common.common_consts.telem_categories import TelemCategoryEnum from common.common_consts.telem_categories import TelemCategoryEnum
from infection_monkey.control import ControlClient
from infection_monkey.telemetry.base_telem import BaseTelem from infection_monkey.telemetry.base_telem import BaseTelem
class TunnelTelem(BaseTelem): class TunnelTelem(BaseTelem):
def __init__(self): def __init__(self, proxy: Mapping[str, str]):
""" """
Default tunnel telemetry constructor Default tunnel telemetry constructor
""" """
super(TunnelTelem, self).__init__() super(TunnelTelem, self).__init__()
self.proxy = ControlClient.proxies.get("https") self.proxy = proxy.get("https")
telem_category = TelemCategoryEnum.TUNNEL telem_category = TelemCategoryEnum.TUNNEL

View File

@ -43,13 +43,11 @@ def fake_custom_pba_linux_options():
"linux_filename": CUSTOM_LINUX_FILENAME, "linux_filename": CUSTOM_LINUX_FILENAME,
"windows_command": "", "windows_command": "",
"windows_filename": "", "windows_filename": "",
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
} }
def test_command_linux_custom_file_and_cmd(fake_custom_pba_linux_options, set_os_linux): def test_command_linux_custom_file_and_cmd(fake_custom_pba_linux_options, set_os_linux):
pba = CustomPBA(MagicMock()) pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_custom_pba_linux_options) pba._set_options(fake_custom_pba_linux_options)
expected_command = f"cd {MONKEY_DIR_PATH} ; {CUSTOM_LINUX_CMD}" expected_command = f"cd {MONKEY_DIR_PATH} ; {CUSTOM_LINUX_CMD}"
assert pba.command == expected_command assert pba.command == expected_command
@ -63,14 +61,12 @@ def fake_custom_pba_windows_options():
"linux_filename": "", "linux_filename": "",
"windows_command": CUSTOM_WINDOWS_CMD, "windows_command": CUSTOM_WINDOWS_CMD,
"windows_filename": CUSTOM_WINDOWS_FILENAME, "windows_filename": CUSTOM_WINDOWS_FILENAME,
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
} }
def test_command_windows_custom_file_and_cmd(fake_custom_pba_windows_options, set_os_windows): def test_command_windows_custom_file_and_cmd(fake_custom_pba_windows_options, set_os_windows):
pba = CustomPBA(MagicMock()) pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_custom_pba_windows_options) pba._set_options(fake_custom_pba_windows_options)
expected_command = f"cd {MONKEY_DIR_PATH} & {CUSTOM_WINDOWS_CMD}" expected_command = f"cd {MONKEY_DIR_PATH} & {CUSTOM_WINDOWS_CMD}"
assert pba.command == expected_command assert pba.command == expected_command
@ -84,14 +80,12 @@ def fake_options_files_only():
"linux_filename": CUSTOM_LINUX_FILENAME, "linux_filename": CUSTOM_LINUX_FILENAME,
"windows_command": "", "windows_command": "",
"windows_filename": CUSTOM_WINDOWS_FILENAME, "windows_filename": CUSTOM_WINDOWS_FILENAME,
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
} }
@pytest.mark.parametrize("os", [set_os_linux, set_os_windows]) @pytest.mark.parametrize("os", [set_os_linux, set_os_windows])
def test_files_only(fake_options_files_only, os): def test_files_only(fake_options_files_only, os):
pba = CustomPBA(MagicMock()) pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_files_only) pba._set_options(fake_options_files_only)
assert pba.command == "" assert pba.command == ""
@ -103,20 +97,18 @@ def fake_options_commands_only():
"linux_filename": "", "linux_filename": "",
"windows_command": CUSTOM_WINDOWS_CMD, "windows_command": CUSTOM_WINDOWS_CMD,
"windows_filename": "", "windows_filename": "",
# Current server is used for attack telemetry
"current_server": CUSTOM_SERVER,
} }
def test_commands_only(fake_options_commands_only, set_os_linux): def test_commands_only(fake_options_commands_only, set_os_linux):
pba = CustomPBA(MagicMock()) pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_commands_only) pba._set_options(fake_options_commands_only)
assert pba.command == CUSTOM_LINUX_CMD assert pba.command == CUSTOM_LINUX_CMD
assert pba.filename == "" assert pba.filename == ""
def test_commands_only_windows(fake_options_commands_only, set_os_windows): def test_commands_only_windows(fake_options_commands_only, set_os_windows):
pba = CustomPBA(MagicMock()) pba = CustomPBA(MagicMock(), MagicMock())
pba._set_options(fake_options_commands_only) pba._set_options(fake_options_commands_only)
assert pba.command == CUSTOM_WINDOWS_CMD assert pba.command == CUSTOM_WINDOWS_CMD
assert pba.filename == "" assert pba.filename == ""

View File

@ -1,3 +1,5 @@
from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
@ -11,5 +13,6 @@ def spy_send_telemetry(monkeypatch):
_spy_send_telemetry.telem_category = None _spy_send_telemetry.telem_category = None
_spy_send_telemetry.data = None _spy_send_telemetry.data = None
monkeypatch.setattr(ControlClient, "send_telemetry", _spy_send_telemetry) ControlClient.control_client_object = MagicMock()
ControlClient.control_client_object.send_telemetry = MagicMock(side_effect=_spy_send_telemetry)
return _spy_send_telemetry return _spy_send_telemetry

View File

@ -7,7 +7,7 @@ from infection_monkey.telemetry.tunnel_telem import TunnelTelem
@pytest.fixture @pytest.fixture
def tunnel_telem_test_instance(): def tunnel_telem_test_instance():
return TunnelTelem() return TunnelTelem({})
def test_tunnel_telem_send(tunnel_telem_test_instance, spy_send_telemetry): def test_tunnel_telem_send(tunnel_telem_test_instance, spy_send_telemetry):

View File

@ -9,7 +9,7 @@ from monkey.infection_monkey.control import ControlClient
) )
def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string): def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
monkeypatch.setattr("monkey.infection_monkey.control.is_windows_os", lambda: is_windows_os) monkeypatch.setattr("monkey.infection_monkey.control.is_windows_os", lambda: is_windows_os)
control_client = ControlClient() control_client = ControlClient("8.8.8.8:5000")
control_client.set_proxies(("8.8.8.8", "45455")) control_client.set_proxies(("8.8.8.8", "45455"))