Merge branch '2292-modify-find-server' into develop

PR #2314
This commit is contained in:
Mike Salvatore 2022-09-19 14:42:33 -04:00
commit e674f3ab24
3 changed files with 77 additions and 51 deletions

View File

@ -5,7 +5,7 @@ import subprocess
import sys import sys
from ipaddress import IPv4Address, IPv4Interface from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List, Optional from typing import List, Mapping, Optional, Tuple
from pubsub.core import Publisher from pubsub.core import Publisher
@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.wmiexec import WmiExploiter
from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter
from infection_monkey.i_puppet import IPuppet, PluginType from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.master import AutomatedMaster from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
@ -52,7 +53,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.relay import TCPRelay from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import ( from infection_monkey.network.relay.utils import (
find_server, find_available_island_apis,
notify_disconnect, notify_disconnect,
send_remove_from_waitlist_control_message_to_relays, send_remove_from_waitlist_control_message_to_relays,
) )
@ -110,7 +111,7 @@ class InfectionMonkey:
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
# TODO: Revisit variable names # TODO: Revisit variable names
server = self._get_server() server, island_api_client = self._connect_to_island_api()
# TODO: `address_to_port()` should return the port as an integer. # TODO: `address_to_port()` should return the port as an integer.
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._cmd_island_port = int(self._cmd_island_port) self._cmd_island_port = int(self._cmd_island_port)
@ -123,7 +124,7 @@ class InfectionMonkey:
self._telemetry_messenger = LegacyTelemetryMessengerAdapter() self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth self._current_depth = self._opts.depth
self._master = None self._master = None
self._relay: Optional[TCPRelay] = None self._relay: TCPRelay
@staticmethod @staticmethod
def _get_arguments(args): def _get_arguments(args):
@ -136,10 +137,13 @@ class InfectionMonkey:
return opts return opts
def _get_server(self): # TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
servers_iterator = (s for s in self._opts.servers) server_clients = find_available_island_apis(self._opts.servers)
server = find_server(servers_iterator)
server, island_api_client = self._select_server(server_clients)
if server: if server:
logger.info(f"Successfully connected to the island via {server}") logger.info(f"Successfully connected to the island via {server}")
else: else:
@ -147,12 +151,22 @@ class InfectionMonkey:
f"Failed to connect to the island via any known servers: {self._opts.servers}" f"Failed to connect to the island via any known servers: {self._opts.servers}"
) )
# Note: Since we pass the address for each of our interfaces to the exploited # NOTE: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the # machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces? # relay if it is able to connect to the relay over multiple interfaces?
send_remove_from_waitlist_control_message_to_relays(servers_iterator) servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server return server, island_api_client
def _select_server(
self, server_clients: Mapping[str, IIslandAPIClient]
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
for server in self._opts.servers:
if server_clients[server]:
return server, server_clients[server]
return None, None
@staticmethod @staticmethod
def _log_arguments(args): def _log_arguments(args):

View File

@ -2,12 +2,17 @@ import logging
import socket import socket
from contextlib import suppress from contextlib import suppress
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, MutableMapping, Optional from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple
import requests from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from infection_monkey.island_api_client import (
HTTPIslandAPIClient,
IIslandAPIClient,
IslandAPIConnectionError,
IslandAPIError,
IslandAPITimeoutError,
)
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import ( from infection_monkey.utils.threading import (
ThreadSafeIterator, ThreadSafeIterator,
@ -22,10 +27,10 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32 NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]: def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]:
server_list = list(servers) server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__()) server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {} server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
run_worker_threads( run_worker_threads(
_find_island_server, _find_island_server,
@ -34,47 +39,39 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
num_workers=NUM_FIND_SERVER_WORKERS, num_workers=NUM_FIND_SERVER_WORKERS,
) )
for server in server_list: return server_results
if server_results[server]:
return server
return None
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): def _find_island_server(
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
):
with suppress(StopIteration): with suppress(StopIteration):
server = next(servers) server = next(servers)
server_status[server] = _check_if_island_server(server) server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool: def _check_if_island_server(server: str) -> IIslandAPIClient:
logger.debug(f"Trying to connect to server: {server}") logger.debug(f"Trying to connect to server: {server}")
try: try:
requests.get( # noqa: DUO123 return HTTPIslandAPIClient(server)
f"https://{server}/api?action=is-up", except IslandAPIConnectionError as err:
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}") logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err: except IslandAPITimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}") logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err: except IslandAPIError as err:
logger.error( logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}" f"Exception encountered when trying to connect to server/relay {server}: {err}"
) )
return False return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers: for i, server in enumerate(servers, start=1):
t = create_daemon_thread( t = create_daemon_thread(
target=_send_remove_from_waitlist_control_message_to_relay, target=_send_remove_from_waitlist_control_message_to_relay,
name="SendRemoveFromWaitlistControlMessageToRelaysThread", name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}",
args=(server,), args=(server,),
) )
t.start() t.start()
@ -93,8 +90,10 @@ def notify_disconnect(server_ip: IPv4Address, server_port: int):
:param server_port: The port of the server to notify. :param server_port: The port of the server to notify.
""" """
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(LONG_REQUEST_TIMEOUT)
try: try:
d_socket.connect((server_ip, server_port)) d_socket.connect((str(server_ip), server_port))
d_socket.sendall(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST) d_socket.sendall(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST)
logger.info(f"Control message was sent to the server/relay {server_ip}:{server_port}") logger.info(f"Control message was sent to the server/relay {server_ip}:{server_port}")
except OSError as err: except OSError as err:

View File

@ -1,8 +1,8 @@
import pytest import pytest
import requests
import requests_mock import requests_mock
from infection_monkey.network.relay.utils import find_server from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
from infection_monkey.network.relay.utils import find_available_island_apis
SERVER_1 = "1.1.1.1:12312" SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321" SERVER_2 = "2.2.2.2:4321"
@ -14,29 +14,42 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected_server,server_response_pairs", "expected_available_servers, server_response_pairs",
[ [
(None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]), ([], [(server, {"exc": IslandAPIConnectionError}) for server in servers]),
( (
SERVER_2, servers[1:],
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})] [(SERVER_1, {"exc": IslandAPIConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item] + [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
), ),
], ],
) )
def test_find_server(expected_server, server_response_pairs): def test_find_available_island_apis(expected_available_servers, server_response_pairs):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
for server, response in server_response_pairs: for server, response in server_response_pairs:
mock.get(f"https://{server}/api?action=is-up", **response) mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server available_apis = find_available_island_apis(servers)
assert len(available_apis) == len(server_response_pairs)
for server, island_api_client in available_apis.items():
if server in expected_available_servers:
assert island_api_client is not None
else:
assert island_api_client is None
def test_find_server__multiple_successes(): def test_find_available_island_apis__multiple_successes():
available_servers = [SERVER_2, SERVER_3]
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError) mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="") for server in available_servers:
mock.get(f"https://{SERVER_3}/api?action=is-up", text="") mock.get(f"https://{server}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2 available_apis = find_available_island_apis(servers)
assert available_apis[SERVER_1] is None
assert available_apis[SERVER_4] is None
for server in available_servers:
assert isinstance(available_apis[server], IIslandAPIClient)