Agent: Modify find_server to return tuple of server and IIslandAPIClient

This commit is contained in:
Ilija Lazoroski 2022-09-19 17:24:00 +02:00
parent f4b47f8238
commit bc19b5ea93
3 changed files with 29 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import subprocess
import sys
from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath
from typing import List
from typing import List, Tuple
from pubsub.core import Publisher
@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.exploit.wmiexec import WmiExploiter
from infection_monkey.exploit.zerologon import ZerologonExploiter
from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory
@ -110,7 +111,7 @@ class InfectionMonkey:
self._opts = self._get_arguments(args)
# TODO: Revisit variable names
server = self._get_server()
server, island_api_client = self._get_server()
# TODO: `address_to_port()` should return the port as an integer.
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._cmd_island_port = int(self._cmd_island_port)
@ -136,9 +137,9 @@ class InfectionMonkey:
return opts
def _get_server(self):
def _get_server(self) -> Tuple[str, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
server = find_server(self._opts.servers)
server, island_api_client = find_server(self._opts.servers)
if server:
logger.info(f"Successfully connected to the island via {server}")
else:
@ -152,7 +153,7 @@ class InfectionMonkey:
servers_to_close = (s for s in self._opts.servers if s != server)
send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server
return server, island_api_client
@staticmethod
def _log_arguments(args):

View File

@ -2,11 +2,12 @@ import logging
import socket
from contextlib import suppress
from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, MutableMapping, Optional
from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple
from common.network.network_utils import address_to_ip_port
from infection_monkey.island_api_client import (
HTTPIslandAPIClient,
IIslandAPIClient,
IslandAPIConnectionError,
IslandAPIError,
IslandAPITimeoutError,
@ -25,10 +26,10 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]:
def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {}
server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
run_worker_threads(
_find_island_server,
@ -39,24 +40,27 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
for server in server_list:
if server_results[server]:
return server
island_api_client = server_results[server]
return server, island_api_client
return None
return (None, None)
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]):
def _find_island_server(
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
):
with suppress(StopIteration):
server = next(servers)
server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool:
def _check_if_island_server(server: str) -> IIslandAPIClient:
logger.debug(f"Trying to connect to server: {server}")
try:
HTTPIslandAPIClient(server)
island_api_client = HTTPIslandAPIClient(server)
return True
return island_api_client
except IslandAPIConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except IslandAPITimeoutError as err:
@ -66,7 +70,7 @@ def _check_if_island_server(server: str) -> bool:
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
return False
return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):

View File

@ -1,7 +1,7 @@
import pytest
import requests_mock
from infection_monkey.island_api_client import IslandAPIConnectionError
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
from infection_monkey.network.relay.utils import find_server
SERVER_1 = "1.1.1.1:12312"
@ -14,7 +14,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize(
"expected_server,server_response_pairs",
"expected_server, server_response_pairs",
[
(None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]),
(
@ -29,7 +29,9 @@ def test_find_server(expected_server, server_response_pairs):
for server, response in server_response_pairs:
mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server
actual_server, _ = find_server(servers)
assert actual_server is expected_server
def test_find_server__multiple_successes():
@ -39,4 +41,7 @@ def test_find_server__multiple_successes():
mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2
actual_server, actual_island_api_client = find_server(servers)
assert actual_server == SERVER_2
assert isinstance(actual_island_api_client, IIslandAPIClient)