Merge branch '2292-modify-find-server' into develop

PR #2314
This commit is contained in:
Mike Salvatore 2022-09-19 14:42:33 -04:00
commit e674f3ab24
3 changed files with 77 additions and 51 deletions

View File

@ -5,7 +5,7 @@ import subprocess
import sys
from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath
from typing import List, Optional
from typing import List, Mapping, Optional, Tuple
from pubsub.core import Publisher
@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.exploit.wmiexec import WmiExploiter
from infection_monkey.exploit.zerologon import ZerologonExploiter
from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory
@ -52,7 +53,7 @@ from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network.relay.utils import (
find_server,
find_available_island_apis,
notify_disconnect,
send_remove_from_waitlist_control_message_to_relays,
)
@ -110,7 +111,7 @@ class InfectionMonkey:
self._opts = self._get_arguments(args)
# TODO: Revisit variable names
server = self._get_server()
server, island_api_client = self._connect_to_island_api()
# TODO: `address_to_port()` should return the port as an integer.
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._cmd_island_port = int(self._cmd_island_port)
@ -123,7 +124,7 @@ class InfectionMonkey:
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
self._master = None
self._relay: Optional[TCPRelay] = None
self._relay: TCPRelay
@staticmethod
def _get_arguments(args):
@ -136,10 +137,13 @@ class InfectionMonkey:
return opts
def _get_server(self):
# TODO: By the time we finish 2292, _connect_to_island_api() may not need to return `server`
def _connect_to_island_api(self) -> Tuple[str, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
servers_iterator = (s for s in self._opts.servers)
server = find_server(servers_iterator)
server_clients = find_available_island_apis(self._opts.servers)
server, island_api_client = self._select_server(server_clients)
if server:
logger.info(f"Successfully connected to the island via {server}")
else:
@ -147,12 +151,22 @@ class InfectionMonkey:
f"Failed to connect to the island via any known servers: {self._opts.servers}"
)
# Note: Since we pass the address for each of our interfaces to the exploited
# NOTE: Since we pass the address for each of our interfaces to the exploited
# machines, is it possible for a machine to unintentionally unregister itself from the
# relay if it is able to connect to the relay over multiple interfaces?
send_remove_from_waitlist_control_message_to_relays(servers_iterator)
servers_to_close = (s for s in self._opts.servers if s != server and server_clients[s])
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server
return server, island_api_client
def _select_server(
self, server_clients: Mapping[str, IIslandAPIClient]
) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
for server in self._opts.servers:
if server_clients[server]:
return server, server_clients[server]
return None, None
@staticmethod
def _log_arguments(args):

View File

@ -2,12 +2,17 @@ import logging
import socket
from contextlib import suppress
from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple
import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.island_api_client import (
HTTPIslandAPIClient,
IIslandAPIClient,
IslandAPIConnectionError,
IslandAPIError,
IslandAPITimeoutError,
)
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import (
ThreadSafeIterator,
@ -22,10 +27,10 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]:
def find_available_island_apis(servers: Iterable[str]) -> Mapping[str, Optional[IIslandAPIClient]]:
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
run_worker_threads(
_find_island_server,
@ -34,47 +39,39 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
num_workers=NUM_FIND_SERVER_WORKERS,
)
for server in server_list:
if server_results[server]:
return server
return None
return server_results
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
def _find_island_server(
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
def _check_if_island_server(server: str) -> IIslandAPIClient:
logger.debug(f"Trying to connect to server: {server}")
try:
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
return True
except requests.exceptions.ConnectionError as err:
return HTTPIslandAPIClient(server)
except IslandAPIConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
except IslandAPITimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
except IslandAPIError as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return False
return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for server in servers:
for i, server in enumerate(servers, start=1):
t = create_daemon_thread(
target=_send_remove_from_waitlist_control_message_to_relay,
name="SendRemoveFromWaitlistControlMessageToRelaysThread",
name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}",
args=(server,),
)
t.start()
@ -93,8 +90,10 @@ def notify_disconnect(server_ip: IPv4Address, server_port: int):
:param server_port: The port of the server to notify.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(LONG_REQUEST_TIMEOUT)
try:
d_socket.connect((server_ip, server_port))
d_socket.connect((str(server_ip), server_port))
d_socket.sendall(RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST)
logger.info(f"Control message was sent to the server/relay {server_ip}:{server_port}")
except OSError as err:

View File

@ -1,8 +1,8 @@
import pytest
import requests
import requests_mock
from infection_monkey.network.relay.utils import find_server
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
from infection_monkey.network.relay.utils import find_available_island_apis
SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321"
@ -14,29 +14,42 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize(
"expected_server,server_response_pairs",
"expected_available_servers, server_response_pairs",
[
(None, [(server, {"exc": requests.exceptions.ConnectionError}) for server in servers]),
([], [(server, {"exc": IslandAPIConnectionError}) for server in servers]),
(
SERVER_2,
[(SERVER_1, {"exc": requests.exceptions.ConnectionError})]
servers[1:],
[(SERVER_1, {"exc": IslandAPIConnectionError})]
+ [(server, {"text": ""}) for server in servers[1:]], # type: ignore[dict-item]
),
],
)
def test_find_server(expected_server, server_response_pairs):
def test_find_available_island_apis(expected_available_servers, server_response_pairs):
with requests_mock.Mocker() as mock:
for server, response in server_response_pairs:
mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server
available_apis = find_available_island_apis(servers)
assert len(available_apis) == len(server_response_pairs)
for server, island_api_client in available_apis.items():
if server in expected_available_servers:
assert island_api_client is not None
else:
assert island_api_client is None
def test_find_server__multiple_successes():
def test_find_available_island_apis__multiple_successes():
available_servers = [SERVER_2, SERVER_3]
with requests_mock.Mocker() as mock:
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=requests.exceptions.ConnectionError)
mock.get(f"https://{SERVER_2}/api?action=is-up", text="")
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError)
for server in available_servers:
mock.get(f"https://{server}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2
available_apis = find_available_island_apis(servers)
assert available_apis[SERVER_1] is None
assert available_apis[SERVER_4] is None
for server in available_servers:
assert isinstance(available_apis[server], IIslandAPIClient)