Agent: Modify find_server to return tuple of server and IIslandAPIClient

This commit is contained in:
Ilija Lazoroski 2022-09-19 17:24:00 +02:00
parent f4b47f8238
commit bc19b5ea93
3 changed files with 29 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import subprocess
import sys import sys
from ipaddress import IPv4Address, IPv4Interface from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List from typing import List, Tuple
from pubsub.core import Publisher from pubsub.core import Publisher
@ -45,6 +45,7 @@ from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.exploit.wmiexec import WmiExploiter from infection_monkey.exploit.wmiexec import WmiExploiter
from infection_monkey.exploit.zerologon import ZerologonExploiter from infection_monkey.exploit.zerologon import ZerologonExploiter
from infection_monkey.i_puppet import IPuppet, PluginType from infection_monkey.i_puppet import IPuppet, PluginType
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.master import AutomatedMaster from infection_monkey.master import AutomatedMaster
from infection_monkey.master.control_channel import ControlChannel from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory from infection_monkey.model import VictimHostFactory
@ -110,7 +111,7 @@ class InfectionMonkey:
self._opts = self._get_arguments(args) self._opts = self._get_arguments(args)
# TODO: Revisit variable names # TODO: Revisit variable names
server = self._get_server() server, island_api_client = self._get_server()
# TODO: `address_to_port()` should return the port as an integer. # TODO: `address_to_port()` should return the port as an integer.
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._cmd_island_port = int(self._cmd_island_port) self._cmd_island_port = int(self._cmd_island_port)
@ -136,9 +137,9 @@ class InfectionMonkey:
return opts return opts
def _get_server(self): def _get_server(self) -> Tuple[str, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}") logger.debug(f"Trying to wake up with servers: {', '.join(self._opts.servers)}")
server = find_server(self._opts.servers) server, island_api_client = find_server(self._opts.servers)
if server: if server:
logger.info(f"Successfully connected to the island via {server}") logger.info(f"Successfully connected to the island via {server}")
else: else:
@ -152,7 +153,7 @@ class InfectionMonkey:
servers_to_close = (s for s in self._opts.servers if s != server) servers_to_close = (s for s in self._opts.servers if s != server)
send_remove_from_waitlist_control_message_to_relays(servers_to_close) send_remove_from_waitlist_control_message_to_relays(servers_to_close)
return server return server, island_api_client
@staticmethod @staticmethod
def _log_arguments(args): def _log_arguments(args):

View File

@ -2,11 +2,12 @@ import logging
import socket import socket
from contextlib import suppress from contextlib import suppress
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, MutableMapping, Optional from typing import Dict, Iterable, Iterator, MutableMapping, Optional, Tuple
from common.network.network_utils import address_to_ip_port from common.network.network_utils import address_to_ip_port
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClient, HTTPIslandAPIClient,
IIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
IslandAPIError, IslandAPIError,
IslandAPITimeoutError, IslandAPITimeoutError,
@ -25,10 +26,10 @@ logger = logging.getLogger(__name__)
NUM_FIND_SERVER_WORKERS = 32 NUM_FIND_SERVER_WORKERS = 32
def find_server(servers: Iterable[str]) -> Optional[str]: def find_server(servers: Iterable[str]) -> Tuple[Optional[str], Optional[IIslandAPIClient]]:
server_list = list(servers) server_list = list(servers)
server_iterator = ThreadSafeIterator(server_list.__iter__()) server_iterator = ThreadSafeIterator(server_list.__iter__())
server_results: Dict[str, bool] = {} server_results: Dict[str, Tuple[bool, IIslandAPIClient]] = {}
run_worker_threads( run_worker_threads(
_find_island_server, _find_island_server,
@ -39,24 +40,27 @@ def find_server(servers: Iterable[str]) -> Optional[str]:
for server in server_list: for server in server_list:
if server_results[server]: if server_results[server]:
return server island_api_client = server_results[server]
return server, island_api_client
return None return (None, None)
def _find_island_server(servers: Iterator[str], server_status: MutableMapping[str, bool]): def _find_island_server(
servers: Iterator[str], server_status: MutableMapping[str, Optional[IIslandAPIClient]]
):
with suppress(StopIteration): with suppress(StopIteration):
server = next(servers) server = next(servers)
server_status[server] = _check_if_island_server(server) server_status[server] = _check_if_island_server(server)
def _check_if_island_server(server: str) -> bool: def _check_if_island_server(server: str) -> IIslandAPIClient:
logger.debug(f"Trying to connect to server: {server}") logger.debug(f"Trying to connect to server: {server}")
try: try:
HTTPIslandAPIClient(server) island_api_client = HTTPIslandAPIClient(server)
return True return island_api_client
except IslandAPIConnectionError as err: except IslandAPIConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}") logger.error(f"Unable to connect to server/relay {server}: {err}")
except IslandAPITimeoutError as err: except IslandAPITimeoutError as err:
@ -66,7 +70,7 @@ def _check_if_island_server(server: str) -> bool:
f"Exception encountered when trying to connect to server/relay {server}: {err}" f"Exception encountered when trying to connect to server/relay {server}: {err}"
) )
return False return None
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):

View File

@ -1,7 +1,7 @@
import pytest import pytest
import requests_mock import requests_mock
from infection_monkey.island_api_client import IslandAPIConnectionError from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIConnectionError
from infection_monkey.network.relay.utils import find_server from infection_monkey.network.relay.utils import find_server
SERVER_1 = "1.1.1.1:12312" SERVER_1 = "1.1.1.1:12312"
@ -14,7 +14,7 @@ servers = [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected_server,server_response_pairs", "expected_server, server_response_pairs",
[ [
(None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]), (None, [(server, {"exc": IslandAPIConnectionError}) for server in servers]),
( (
@ -29,7 +29,9 @@ def test_find_server(expected_server, server_response_pairs):
for server, response in server_response_pairs: for server, response in server_response_pairs:
mock.get(f"https://{server}/api?action=is-up", **response) mock.get(f"https://{server}/api?action=is-up", **response)
assert find_server(servers) is expected_server actual_server, _ = find_server(servers)
assert actual_server is expected_server
def test_find_server__multiple_successes(): def test_find_server__multiple_successes():
@ -39,4 +41,7 @@ def test_find_server__multiple_successes():
mock.get(f"https://{SERVER_3}/api?action=is-up", text="") mock.get(f"https://{SERVER_3}/api?action=is-up", text="")
mock.get(f"https://{SERVER_4}/api?action=is-up", text="") mock.get(f"https://{SERVER_4}/api?action=is-up", text="")
assert find_server(servers) == SERVER_2 actual_server, actual_island_api_client = find_server(servers)
assert actual_server == SERVER_2
assert isinstance(actual_island_api_client, IIslandAPIClient)