Merge branch '2216-modify-controlclient-find-server' into 2216-tcp-relay
PR #2250
This commit is contained in:
commit
5366bba389
|
@ -5,7 +5,6 @@ from socket import gethostname
|
|||
from typing import Mapping, Optional
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
import infection_monkey.tunnel as tunnel
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||
|
@ -63,38 +62,6 @@ class ControlClient:
|
|||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
def find_server(self, default_tunnel=None):
|
||||
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
|
||||
if default_tunnel:
|
||||
logger.debug("default_tunnel: %s" % (default_tunnel,))
|
||||
|
||||
try:
|
||||
debug_message = "Trying to connect to server: %s" % self.server_address
|
||||
if self.proxies:
|
||||
debug_message += " through proxies: %s" % self.proxies
|
||||
logger.debug(debug_message)
|
||||
requests.get( # noqa: DUO123
|
||||
f"https://{self.server_address}/api?action=is-up",
|
||||
verify=False,
|
||||
proxies=self.proxies,
|
||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
return True
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
|
||||
|
||||
if self.proxies:
|
||||
return False
|
||||
else:
|
||||
logger.info("Starting tunnel lookup...")
|
||||
proxy_find = tunnel.find_tunnel(default=default_tunnel)
|
||||
if proxy_find:
|
||||
self.set_proxies(proxy_find)
|
||||
return self.find_server()
|
||||
else:
|
||||
logger.info("No tunnel found")
|
||||
return False
|
||||
|
||||
def set_proxies(self, proxy_find):
|
||||
"""
|
||||
Note: The proxy schema changes between different versions of requests and urllib3,
|
||||
|
|
|
@ -43,6 +43,10 @@ from infection_monkey.model import VictimHostFactory
|
|||
from infection_monkey.network.firewall import app as firewall
|
||||
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
|
||||
from infection_monkey.network.relay import TCPRelay
|
||||
from infection_monkey.network.relay.utils import (
|
||||
find_server,
|
||||
send_remove_from_waitlist_control_message_to_relays,
|
||||
)
|
||||
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
|
||||
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
|
||||
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
|
||||
|
@ -96,8 +100,12 @@ class InfectionMonkey:
|
|||
logger.info("Monkey is initializing...")
|
||||
self._singleton = SystemSingleton()
|
||||
self._opts = self._get_arguments(args)
|
||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(self._opts.servers)
|
||||
self._control_client = ControlClient(self._opts.servers)
|
||||
|
||||
# TODO: Revisit variable names
|
||||
server = self._get_server()
|
||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||
self._control_client = ControlClient(server_address=server)
|
||||
|
||||
# TODO Refactor the telemetry messengers to accept control client
|
||||
# and remove control_client_object
|
||||
ControlClient.control_client_object = self._control_client
|
||||
|
@ -117,6 +125,19 @@ class InfectionMonkey:
|
|||
|
||||
return opts
|
||||
|
||||
def _get_server(self):
|
||||
servers_iterator = (s for s in self._opts.servers)
|
||||
server = find_server(servers_iterator)
|
||||
if server:
|
||||
logger.info(f"Successfully connected to the island via {server}")
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to connect to the island via any known servers: {self._opts.servers}"
|
||||
)
|
||||
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
|
||||
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def _log_arguments(args):
|
||||
arg_string = " ".join([f"{key}: {value}" for key, value in vars(args).items()])
|
||||
|
@ -130,7 +151,7 @@ class InfectionMonkey:
|
|||
logger.info("Agent is starting...")
|
||||
logger.info(f"Agent GUID: {GUID}")
|
||||
|
||||
self._connect_to_island()
|
||||
self._control_client.wakeup(parent=self._opts.parent)
|
||||
|
||||
# TODO: Reevaluate who is responsible to send this information
|
||||
if is_windows_os():
|
||||
|
@ -148,24 +169,6 @@ class InfectionMonkey:
|
|||
self._setup()
|
||||
self._master.start()
|
||||
|
||||
def _connect_to_island(self):
|
||||
# Sets island's IP and port for monkey to communicate to
|
||||
if self._current_server_is_set():
|
||||
logger.debug(f"Default server set to: {self._control_client.server_address}")
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to connect to the island via "
|
||||
f"any known server address: {self._opts.servers}"
|
||||
)
|
||||
|
||||
self._control_client.wakeup(parent=self._opts.parent)
|
||||
|
||||
def _current_server_is_set(self) -> bool:
|
||||
if self._control_client.find_server(default_tunnel=self._opts.servers):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _setup(self):
|
||||
logger.debug("Starting the setup phase.")
|
||||
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from .relay_connection_handler import RelayConnectionHandler, RELAY_CONTROL_MESSAGE
|
||||
from .relay_connection_handler import (
|
||||
RelayConnectionHandler,
|
||||
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
|
||||
)
|
||||
from .relay_user_handler import RelayUser, RelayUserHandler
|
||||
from .sockets_pipe import SocketsPipe
|
||||
from .tcp_connection_handler import TCPConnectionHandler
|
||||
|
|
|
@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
|||
from .relay_user_handler import RelayUserHandler
|
||||
from .tcp_pipe_spawner import TCPPipeSpawner
|
||||
|
||||
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
||||
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST = b"infection-monkey-relay-control-message: -"
|
||||
|
||||
|
||||
class RelayConnectionHandler:
|
||||
|
@ -25,7 +25,7 @@ class RelayConnectionHandler:
|
|||
|
||||
control_message = sock.recv(socket.MSG_PEEK)
|
||||
|
||||
if control_message.startswith(RELAY_CONTROL_MESSAGE):
|
||||
if control_message.startswith(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST):
|
||||
self._relay_user_handler.disconnect_user(addr)
|
||||
else:
|
||||
self._relay_user_handler.add_relay_user(addr)
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
import logging
|
||||
import socket
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||
from infection_monkey.utils.threading import create_daemon_thread
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
|
||||
|
||||
for server in servers:
|
||||
logger.debug(f"Trying to connect to server: {server}")
|
||||
|
||||
try:
|
||||
requests.get( # noqa: DUO123
|
||||
f"https://{server}/api?action=is-up",
|
||||
verify=False,
|
||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
return server
|
||||
except requests.exceptions.ConnectionError as err:
|
||||
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||
except TimeoutError as err:
|
||||
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||
for server in servers:
|
||||
t = create_daemon_thread(
|
||||
target=_send_remove_from_waitlist_control_message_to_relay,
|
||||
name="SendRemoveFromWaitlistControlMessageToRelaysThread",
|
||||
args=(server,),
|
||||
)
|
||||
t.start()
|
||||
|
||||
|
||||
def _send_remove_from_waitlist_control_message_to_relay(server: str):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
|
||||
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
|
||||
|
||||
ip, port = address_to_ip_port(server)
|
||||
logger.info(f"Control message was sent to the server/relay {server}")
|
||||
|
||||
try:
|
||||
d_socket.connect((ip, int(port)))
|
||||
d_socket.send(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST)
|
||||
except OSError as err:
|
||||
logger.error(f"Error connecting to socket {server}: {err}")
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
from monkey.infection_monkey.network.relay import (
|
||||
RELAY_CONTROL_MESSAGE,
|
||||
RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST,
|
||||
RelayConnectionHandler,
|
||||
RelayUserHandler,
|
||||
TCPPipeSpawner,
|
||||
|
@ -27,7 +27,7 @@ def relay_user_handler():
|
|||
@pytest.fixture
|
||||
def close_socket():
|
||||
sock = MagicMock(spec=socket.socket)
|
||||
sock.recv.return_value = RELAY_CONTROL_MESSAGE
|
||||
sock.recv.return_value = RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
|
||||
sock.getpeername.return_value = (USER_ADDRESS, 12345)
|
||||
return sock
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
import requests
|
||||
import requests_mock
|
||||
|
||||
from infection_monkey.network.relay.utils import find_server
|
||||
|
||||
SERVER_1 = "1.1.1.1:12312"
|
||||
SERVER_2 = "2.2.2.2:4321"
|
||||
SERVER_3 = "3.3.3.3:3142"
|
||||
SERVER_4 = "4.4.4.4:5000"
|
||||
|
||||
|
||||
servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_server,server_response_pairs",
|
||||
[
|
||||
(None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]),
|
||||
(
|
||||
SERVER_2,
|
||||
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
|
||||
+ [(server, {"text": ""}) for server in servers[1:]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_server(expected_server, server_response_pairs):
|
||||
with requests_mock.Mocker() as mock:
|
||||
for server, response in server_response_pairs:
|
||||
mock.get(f"https://{server}/api?action=is-up", **response)
|
||||
|
||||
assert find_server(servers) is expected_server
|
Loading…
Reference in New Issue