forked from p15670423/monkey
Merge branch '2216-modify-controlclient-find-server' into 2216-tcp-relay
PR #2250
This commit is contained in:
commit
5366bba389
|
@ -5,7 +5,6 @@ from socket import gethostname
|
||||||
from typing import Mapping, Optional
|
from typing import Mapping, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from requests.exceptions import ConnectionError
|
|
||||||
|
|
||||||
import infection_monkey.tunnel as tunnel
|
import infection_monkey.tunnel as tunnel
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||||
|
@ -63,38 +62,6 @@ class ControlClient:
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_server(self, default_tunnel=None):
|
|
||||||
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
|
|
||||||
if default_tunnel:
|
|
||||||
logger.debug("default_tunnel: %s" % (default_tunnel,))
|
|
||||||
|
|
||||||
try:
|
|
||||||
debug_message = "Trying to connect to server: %s" % self.server_address
|
|
||||||
if self.proxies:
|
|
||||||
debug_message += " through proxies: %s" % self.proxies
|
|
||||||
logger.debug(debug_message)
|
|
||||||
requests.get( # noqa: DUO123
|
|
||||||
f"https://{self.server_address}/api?action=is-up",
|
|
||||||
verify=False,
|
|
||||||
proxies=self.proxies,
|
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except ConnectionError as exc:
|
|
||||||
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
|
|
||||||
|
|
||||||
if self.proxies:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
logger.info("Starting tunnel lookup...")
|
|
||||||
proxy_find = tunnel.find_tunnel(default=default_tunnel)
|
|
||||||
if proxy_find:
|
|
||||||
self.set_proxies(proxy_find)
|
|
||||||
return self.find_server()
|
|
||||||
else:
|
|
||||||
logger.info("No tunnel found")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def set_proxies(self, proxy_find):
|
def set_proxies(self, proxy_find):
|
||||||
"""
|
"""
|
||||||
Note: The proxy schema changes between different versions of requests and urllib3,
|
Note: The proxy schema changes between different versions of requests and urllib3,
|
||||||
|
|
|
@ -43,6 +43,10 @@ from infection_monkey.model import VictimHostFactory
|
||||||
from infection_monkey.network.firewall import app as firewall
|
from infection_monkey.network.firewall import app as firewall
|
||||||
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
|
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
|
||||||
from infection_monkey.network.relay import TCPRelay
|
from infection_monkey.network.relay import TCPRelay
|
||||||
|
from infection_monkey.network.relay.utils import (
|
||||||
|
find_server,
|
||||||
|
send_remove_from_waitlist_control_message_to_relays,
|
||||||
|
)
|
||||||
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
|
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
|
||||||
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
|
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
|
||||||
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
|
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
|
||||||
|
@ -96,8 +100,12 @@ class InfectionMonkey:
|
||||||
logger.info("Monkey is initializing...")
|
logger.info("Monkey is initializing...")
|
||||||
self._singleton = SystemSingleton()
|
self._singleton = SystemSingleton()
|
||||||
self._opts = self._get_arguments(args)
|
self._opts = self._get_arguments(args)
|
||||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.servers)
|
|
||||||
self._control_client = ControlClient(self._opts.servers)
|
# TODO: Revisit variable names
|
||||||
|
server = self._get_server()
|
||||||
|
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||||
|
self._control_client = ControlClient(server_address=server)
|
||||||
|
|
||||||
# TODO Refactor the telemetry messengers to accept control client
|
# TODO Refactor the telemetry messengers to accept control client
|
||||||
# and remove control_client_object
|
# and remove control_client_object
|
||||||
ControlClient.control_client_object = self._control_client
|
ControlClient.control_client_object = self._control_client
|
||||||
|
@ -117,6 +125,19 @@ class InfectionMonkey:
|
||||||
|
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
|
def _get_server(self):
|
||||||
|
servers_iterator = (s for s in self._opts.servers)
|
||||||
|
server = find_server(servers_iterator)
|
||||||
|
if server:
|
||||||
|
logger.info(f"Successfully connected to the island via {server}")
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to connect to the island via any known servers: {self._opts.servers}"
|
||||||
|
)
|
||||||
|
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
|
||||||
|
|
||||||
|
return server
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _log_arguments(args):
|
def _log_arguments(args):
|
||||||
arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])
|
arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])
|
||||||
|
@ -130,7 +151,7 @@ class InfectionMonkey:
|
||||||
logger.info("Agent is starting...")
|
logger.info("Agent is starting...")
|
||||||
logger.info(f"Agent GUID: {GUID}")
|
logger.info(f"Agent GUID: {GUID}")
|
||||||
|
|
||||||
self._connect_to_island()
|
self._control_client.wakeup(parent=self._opts.parent)
|
||||||
|
|
||||||
# TODO: Reevaluate who is responsible to send this information
|
# TODO: Reevaluate who is responsible to send this information
|
||||||
if is_windows_os():
|
if is_windows_os():
|
||||||
|
@ -148,24 +169,6 @@ class InfectionMonkey:
|
||||||
self._setup()
|
self._setup()
|
||||||
self._master.start()
|
self._master.start()
|
||||||
|
|
||||||
def _connect_to_island(self):
|
|
||||||
# Sets island's IP and port for monkey to communicate to
|
|
||||||
if self._current_server_is_set():
|
|
||||||
logger.debug(f"Default server set to: {self._control_client.server_address}")
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to connect to the island via "
|
|
||||||
f"any known server address: {self._opts.servers}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._control_client.wakeup(parent=self._opts.parent)
|
|
||||||
|
|
||||||
def _current_server_is_set(self) -> bool:
|
|
||||||
if self._control_client.find_server(default_tunnel=self._opts.servers):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _setup(self):
|
def _setup(self):
|
||||||
logger.debug("Starting the setup phase.")
|
logger.debug("Starting the setup phase.")
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
from .relay_connection_handler import RelayConnectionHandler, RELAY_CONTROL_MESSAGE
|
from .relay_connection_handler import (
|
||||||
|
RelayConnectionHandler,
|
||||||
|
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
|
||||||
|
)
|
||||||
from .relay_user_handler import RelayUser, RelayUserHandler
|
from .relay_user_handler import RelayUser, RelayUserHandler
|
||||||
from .sockets_pipe import SocketsPipe
|
from .sockets_pipe import SocketsPipe
|
||||||
from .tcp_connection_handler import TCPConnectionHandler
|
from .tcp_connection_handler import TCPConnectionHandler
|
||||||
|
|
|
@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||||
from .relay_user_handler import RelayUserHandler
|
from .relay_user_handler import RelayUserHandler
|
||||||
from .tcp_pipe_spawner import TCPPipeSpawner
|
from .tcp_pipe_spawner import TCPPipeSpawner
|
||||||
|
|
||||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
|
||||||
|
|
||||||
|
|
||||||
class RelayConnectionHandler:
|
class RelayConnectionHandler:
|
||||||
|
@ -25,7 +25,7 @@ class RelayConnectionHandler:
|
||||||
|
|
||||||
control_message = sock.recv(socket.MSG_PEEK)
|
control_message = sock.recv(socket.MSG_PEEK)
|
||||||
|
|
||||||
if control_message.startswith(RELAY_CONTROL_MESSAGE):
|
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
|
||||||
self._relay_user_handler.disconnect_user(addr)
|
self._relay_user_handler.disconnect_user(addr)
|
||||||
else:
|
else:
|
||||||
self._relay_user_handler.add_relay_user(addr)
|
self._relay_user_handler.add_relay_user(addr)
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
||||||
|
from common.network.network_utils import address_to_ip_port
|
||||||
|
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||||
|
from infection_monkey.utils.threading import create_daemon_thread
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||||
|
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
|
||||||
|
|
||||||
|
for server in servers:
|
||||||
|
logger.debug(f"Trying to connect to server: {server}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
requests.get( # noqa: DUO123
|
||||||
|
f"https://{server}/api?action=is-up",
|
||||||
|
verify=False,
|
||||||
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
return server
|
||||||
|
except requests.exceptions.ConnectionError as err:
|
||||||
|
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||||
|
except TimeoutError as err:
|
||||||
|
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
|
||||||
|
except Exception as err:
|
||||||
|
logger.error(
|
||||||
|
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||||
|
for server in servers:
|
||||||
|
t = create_daemon_thread(
|
||||||
|
target=_send_remove_from_waitlist_control_message_to_relay,
|
||||||
|
name="SendRemoveFromWaitlistControlMessageToRelaysThread",
|
||||||
|
args=(server,),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
def _send_remove_from_waitlist_control_message_to_relay(server: str):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
|
||||||
|
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
|
||||||
|
|
||||||
|
ip, port = address_to_ip_port(server)
|
||||||
|
logger.info(f"Control message was sent to the server/relay {server}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
d_socket.connect((ip, int(port)))
|
||||||
|
d_socket.send(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST)
|
||||||
|
except OSError as err:
|
||||||
|
logger.error(f"Error connecting to socket {server}: {err}")
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from monkey.infection_monkey.network.relay import (
|
from monkey.infection_monkey.network.relay import (
|
||||||
RELAY_CONTROL_MESSAGE,
|
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
|
||||||
RelayConnectionHandler,
|
RelayConnectionHandler,
|
||||||
RelayUserHandler,
|
RelayUserHandler,
|
||||||
TCPPipeSpawner,
|
TCPPipeSpawner,
|
||||||
|
@ -27,7 +27,7 @@ def relay_user_handler():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def close_socket():
|
def close_socket():
|
||||||
sock = MagicMock(spec=socket.socket)
|
sock = MagicMock(spec=socket.socket)
|
||||||
sock.recv.return_value = RELAY_CONTROL_MESSAGE
|
sock.recv.return_value = RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||||
sock.getpeername.return_value = (USER_ADDRESS, 12345)
|
sock.getpeername.return_value = (USER_ADDRESS, 12345)
|
||||||
return sock
|
return sock
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import requests_mock
|
||||||
|
|
||||||
|
from infection_monkey.network.relay.utils import find_server
|
||||||
|
|
||||||
|
SERVER_1 = "1.1.1.1:12312"
|
||||||
|
SERVER_2 = "2.2.2.2:4321"
|
||||||
|
SERVER_3 = "3.3.3.3:3142"
|
||||||
|
SERVER_4 = "4.4.4.4:5000"
|
||||||
|
|
||||||
|
|
||||||
|
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"expected_server,server_response_pairs",
|
||||||
|
[
|
||||||
|
(None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]),
|
||||||
|
(
|
||||||
|
SERVER_2,
|
||||||
|
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
|
||||||
|
+ [(server, {"text": ""}) for server in servers[1:]],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_find_server(expected_server, server_response_pairs):
|
||||||
|
with requests_mock.Mocker() as mock:
|
||||||
|
for server, response in server_response_pairs:
|
||||||
|
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||||
|
|
||||||
|
assert find_server(servers) is expected_server
|
Loading…
Reference in New Issue